# eval_dataset.py 8.44 KB
"""Generate labeled evaluation samples from an existing lyric library."""

from __future__ import annotations

import csv
import random
import re
from dataclasses import dataclass
from pathlib import Path

from lyric_dedup.file_import import iter_lyric_files
from lyric_dedup.file_import import read_lyric_file
from lyric_dedup.file_import import record_from_file
from lyric_dedup.normalization import normalize_lyrics


@dataclass(frozen=True)
class GeneratedSample:
    """One generated evaluation sample, i.e. one row of the output CSV.

    Instances are immutable (``frozen=True``). ``generate_eval_set`` writes
    every field to the CSV, with ``sample_id`` stored under the ``id`` column.
    """

    sample_id: str  # unique id such as "pos-001" or "neg-061"
    file: str  # location of the generated sample file as recorded in the CSV
    expected: str  # expected verdict label: "应去重" or "不应去重"
    sample_type: str  # name of the variant used to produce the sample text
    source: str  # library file(s) the sample was derived from
    title: str = ""  # set for positive samples only; empty otherwise
    artist: str = ""  # set for positive samples only; empty otherwise


def generate_eval_set(
    *,
    library_dir: Path,
    output_dir: Path,
    csv_path: Path,
    size: int = 100,
    positive_ratio: float = 0.6,
    seed: int = 20260602,
) -> dict[str, object]:
    """Generate labeled sample files plus a CSV manifest from a lyric library.

    Positive samples cycle through the library files in order; negative
    samples pair each library file with the next one (wrapping around).
    Before generating, every ``.txt``/``.lrc`` file located directly in
    *output_dir* is deleted, so that directory must not hold library files.

    Args:
        library_dir: Directory scanned for source lyric files.
        output_dir: Directory that receives the generated sample files.
        csv_path: Destination of the CSV manifest; its parent is the base
            used for the ``file`` column.
        size: Total number of samples to generate (must be >= 0).
        positive_ratio: Fraction of samples that are positives, in [0, 1].
        seed: Seed for the random generator, making runs reproducible.

    Returns:
        A summary dict with the keys ``size``, ``positive``, ``negative``,
        ``library_files``, ``lyrics_dir`` and ``csv``.

    Raises:
        ValueError: If *size* or *positive_ratio* is out of range, if the
            library contains no lyric files, or if *output_dir* directly
            contains library source files.
    """
    # Validate up front: out-of-range values used to yield negative counts
    # in the returned summary without any error.
    if size < 0:
        raise ValueError(f"size 必须是非负整数，当前为 {size}")
    if not 0 <= positive_ratio <= 1:
        raise ValueError(f"positive_ratio 必须在 0 到 1 之间，当前为 {positive_ratio}")

    rng = random.Random(seed)
    # Materialize so the len()/indexing below works even if the helper
    # returns a lazy iterable.
    source_files = list(iter_lyric_files(library_dir))
    if not source_files:
        raise ValueError(f"{library_dir} 下没有 .lrc/.txt 歌词文件")

    # The cleanup step below removes lyric files directly inside output_dir.
    # Refuse to continue if that would wipe the source library itself.
    resolved_output = output_dir.resolve()
    if any(Path(source).resolve().parent == resolved_output for source in source_files):
        raise ValueError(
            f"输出目录 {output_dir} 中包含歌词库源文件，清理生成目录会删除它们；请使用独立的输出目录"
        )

    output_dir.mkdir(parents=True, exist_ok=True)
    csv_path.parent.mkdir(parents=True, exist_ok=True)
    _clean_generated_output_dir(output_dir)

    positives = round(size * positive_ratio)
    negatives = size - positives
    samples: list[GeneratedSample] = []
    for index in range(positives):
        source = source_files[index % len(source_files)]
        samples.append(_positive_sample(index + 1, source, output_dir, csv_path.parent, rng))
    for index in range(negatives):
        left = source_files[index % len(source_files)]
        right = source_files[(index + 1) % len(source_files)]
        # Negative indices continue after the positives so ids stay unique.
        samples.append(_negative_sample(positives + index + 1, left, right, output_dir, csv_path.parent, rng))

    # Shuffle so positives and negatives are interleaved in the manifest.
    rng.shuffle(samples)
    with csv_path.open("w", encoding="utf-8", newline="") as file:
        writer = csv.DictWriter(file, fieldnames=["id", "file", "expected", "sample_type", "source", "title", "artist"])
        writer.writeheader()
        writer.writerows(
            {
                "id": sample.sample_id,
                "file": sample.file,
                "expected": sample.expected,
                "sample_type": sample.sample_type,
                "source": sample.source,
                "title": sample.title,
                "artist": sample.artist,
            }
            for sample in samples
        )

    return {
        "size": size,
        "positive": positives,
        "negative": negatives,
        "library_files": len(source_files),
        "lyrics_dir": str(output_dir),
        "csv": str(csv_path),
    }


def _positive_sample(index: int, source: Path, output_dir: Path, csv_base: Path, rng: random.Random) -> GeneratedSample:
    """Write one positive ("应去重") sample derived from *source*.

    The variant is chosen round-robin from *index* (1-based), the text is
    written to ``output_dir / "pos_<index>_<variant>.txt"`` and a matching
    ``GeneratedSample`` is returned.

    Args:
        index: 1-based sample number; selects the variant and names the file.
        source: Library lyric file the sample is derived from.
        output_dir: Directory the sample file is written into.
        csv_base: Directory of the CSV manifest; the ``file`` field is made
            relative to it when possible.
        rng: Random generator used by the noise variants.

    Returns:
        The sample description, including title/artist of the source record.
    """
    raw = read_lyric_file(source)
    source_record = record_from_file(source)
    content = _content_lines(raw)
    # NOTE: every variant is built even though only one is kept, so the
    # number of rng draws per sample does not depend on the chosen variant.
    # Keep it eager to preserve the output produced for a given seed.
    variants = [
        ("exact_copy", raw),
        ("timestamped", _add_timestamps(content)),
        ("punctuation_noise", _add_punctuation_noise(content, rng)),
        ("with_platform_noise", _with_platform_noise(content)),
        ("blank_line_noise", _add_blank_line_noise(content)),
        ("lrc_with_platform_noise", _add_timestamps(_content_lines(_with_platform_noise(content)))),
        ("translation_added", _translation_added(content)),
    ]
    sample_type, text = variants[(index - 1) % len(variants)]
    name = f"pos_{index:03d}_{sample_type}.txt"
    path = output_dir / name
    path.write_text(text, encoding="utf-8")
    try:
        file_ref = path.relative_to(csv_base)
    except ValueError:
        # output_dir is not (lexically) below the CSV directory; record an
        # absolute path instead of aborting after the file was written.
        file_ref = path.resolve()
    return GeneratedSample(
        sample_id=f"pos-{index:03d}",
        file=str(file_ref),
        expected="应去重",
        sample_type=sample_type,
        source=str(source),
        title=source_record.title or "",
        artist=source_record.artist or "",
    )


def _negative_sample(index: int, left: Path, right: Path, output_dir: Path, csv_base: Path, rng: random.Random) -> GeneratedSample:
    """Write one negative ("不应去重") sample built from *left* and *right*.

    The variant is chosen round-robin from *index* (1-based), the text is
    written to ``output_dir / "neg_<index>_<variant>.txt"`` and a matching
    ``GeneratedSample`` is returned.

    Args:
        index: 1-based sample number; selects the variant and names the file.
        left: Primary library file the fragments are taken from.
        right: Second library file, used by the mixed-fragment variant.
        output_dir: Directory the sample file is written into.
        csv_base: Directory of the CSV manifest; the ``file`` field is made
            relative to it when possible.
        rng: Random generator used by the sampling variants.

    Returns:
        The sample description; title and artist are left empty.
    """
    left_lines = _normalized_lines(left)
    right_lines = _normalized_lines(right)
    # NOTE: all variants are built eagerly, so rng consumption per sample is
    # the same whichever variant is selected (keeps seeded output stable).
    variants = [
        ("single_song_fragment", _single_song_fragment(left_lines)),
        ("short_shared_snippet", _short_shared_snippet(left_lines, rng)),
        ("mixed_fragments", _mixed_fragments(left_lines, right_lines, rng)),
        ("same_theme_synthetic", _same_theme_synthetic(index)),
        ("translation_only_like", _translation_only_like(left_lines)),
    ]
    sample_type, text = variants[(index - 1) % len(variants)]
    name = f"neg_{index:03d}_{sample_type}.txt"
    path = output_dir / name
    path.write_text(text, encoding="utf-8")
    try:
        file_ref = path.relative_to(csv_base)
    except ValueError:
        # output_dir is not (lexically) below the CSV directory; record an
        # absolute path instead of aborting after the file was written.
        file_ref = path.resolve()
    return GeneratedSample(
        sample_id=f"neg-{index:03d}",
        file=str(file_ref),
        expected="不应去重",
        sample_type=sample_type,
        source=f"{left} | {right}",
    )


def _content_lines(text: str) -> list[str]:
    """Return the stripped, non-blank lines of *text*.

    When no line has content, a single-element list holding the stripped
    text is returned, so the result is never empty.
    """
    stripped = (raw_line.strip() for raw_line in text.splitlines())
    kept = [piece for piece in stripped if piece]
    if kept:
        return kept
    return [text.strip()]


def _clean_generated_output_dir(output_dir: Path) -> None:
    """Delete the ``.txt``/``.lrc`` files located directly in *output_dir*.

    The suffix match is case-insensitive; sub-directories and files with
    other suffixes are left untouched.
    """
    removable_suffixes = {".txt", ".lrc"}
    for entry in output_dir.iterdir():
        if not entry.is_file():
            continue
        if entry.suffix.lower() not in removable_suffixes:
            continue
        entry.unlink()


def _normalized_lines(path: Path) -> list[str]:
    """Read *path*, normalize its lyrics and return the resulting lines.

    The normalized ``primary_lines`` are used when non-empty; otherwise the
    ``unique_lines`` are returned. The result is always a new list.
    """
    lyric_text = read_lyric_file(path)
    normalized = normalize_lyrics(lyric_text)
    chosen = normalized.primary_lines
    if not chosen:
        chosen = normalized.unique_lines
    return list(chosen)


def _add_timestamps(lines: list[str]) -> str:
    """Prefix every line with an LRC-style ``[mm:ss.00]`` timestamp.

    Lines are stamped one second apart, starting at ``00:01``. The minute
    field is carried over, so timestamps keep increasing (and never repeat)
    for texts longer than 59 lines.

    Args:
        lines: Lyric lines to stamp, in order.

    Returns:
        The stamped lines joined with newlines (``""`` for no lines).
    """
    stamped: list[str] = []
    for idx, line in enumerate(lines, start=1):
        # idx counts seconds; split it into minutes and seconds so the
        # seconds field never wraps back to an already-used timestamp.
        minutes, seconds = divmod(idx, 60)
        stamped.append(f"[{minutes:02d}:{seconds:02d}.00]{line}")
    return "\n".join(stamped)


def _add_punctuation_noise(lines: list[str], rng: random.Random) -> str:
    """Append one randomly chosen punctuation mark to every line.

    One ``rng.choice`` draw is made per line, in line order; the noisy
    lines are returned joined with newlines.
    """
    marks = ["!", "?", "...", ",", "。"]
    noisy: list[str] = []
    for line in lines:
        suffix = rng.choice(marks)
        noisy.append(f"{line}{suffix}")
    return "\n".join(noisy)


def _with_platform_noise(lines: list[str]) -> str:
    """Wrap the lines in fixed platform boilerplate.

    Two header lines are placed before the lyrics and one notice line
    after them; everything is joined with newlines.
    """
    header = ["歌词来自QQ音乐", "作词:测试"]
    footer = ["未经著作权人许可 不得翻唱"]
    return "\n".join(header + list(lines) + footer)


def _add_blank_line_noise(lines: list[str]) -> str:
    """Insert an empty line after every fourth lyric line.

    If the number of lines is a multiple of four, the result therefore
    ends with a trailing newline.
    """
    pieces: list[str] = []
    for position, line in enumerate(lines):
        pieces.append(line)
        is_fourth = (position + 1) % 4 == 0
        if is_fourth:
            pieces.append("")
    return "\n".join(pieces)


def _translation_added(lines: list[str]) -> str:
    """Interleave pseudo translations after foreign-looking lines.

    Only lines among the first 24 (1-based position) receive a translation
    line; all original lines are kept in order.
    """
    out: list[str] = []
    for position, line in enumerate(lines, start=1):
        out.append(line)
        if position > 24:
            # Past the translation window: keep the line, add nothing.
            continue
        if _looks_foreign(line):
            out.append(_pseudo_translation(position))
    return "\n".join(out)


def _single_song_fragment(lines: list[str]) -> str:
    """Return a short excerpt of a single song.

    Texts of up to four lines yield their first half (at least one line is
    requested). Longer texts yield a window from the middle whose length
    is a quarter of the text, clamped to between 2 and 8 lines.
    """
    total = len(lines)
    if total > 4:
        span = max(2, min(8, total // 4))
        begin = max(0, (total - span) // 2)
        picked = lines[begin : begin + span]
    else:
        picked = lines[: max(1, total // 2)]
    return "\n".join(picked)


def _short_shared_snippet(lines: list[str], rng: random.Random) -> str:
    """Embed up to two randomly sampled lines inside fixed synthetic text.

    The sampled lines sit between two fixed opening lines and two fixed
    closing lines. No random draw is made when *lines* is empty.
    """
    borrowed: list[str] = []
    if lines:
        borrowed = rng.sample(lines, k=min(2, len(lines)))
    opening = ["清晨的风吹过新的街口", "我把昨天放进安静的口袋"]
    closing = ["故事从这里重新开始", "灯光落下我继续往前走"]
    return "\n".join(opening + borrowed + closing)


def _mixed_fragments(left_lines: list[str], right_lines: list[str], rng: random.Random) -> str:
    """Combine up to two sampled lines from each song around fixed filler.

    The left song is sampled first, then the right one; an empty song
    contributes nothing and triggers no random draw.
    """
    from_left: list[str] = []
    if left_lines:
        from_left = rng.sample(left_lines, k=min(2, len(left_lines)))
    from_right: list[str] = []
    if right_lines:
        from_right = rng.sample(right_lines, k=min(2, len(right_lines)))
    filler = ["新的旋律慢慢靠近", "陌生的名字写在风里", "没有人停在原地"]
    return "\n".join(from_left + filler + from_right)


def _same_theme_synthetic(index: int) -> str:
    """Return a fixed four-line synthetic lyric plus a line carrying *index*.

    The closing line embeds *index*, so different indices give different
    texts.
    """
    fixed_lines = (
        "我在夜里想起远方的你",
        "城市灯火陪我走过雨季",
        "那些没说完的话留在风里",
        "明天醒来我们各自继续",
    )
    closing_line = f"这是第 {index} 个全新测试样本"
    return "\n".join((*fixed_lines, closing_line))


def _translation_only_like(lines: list[str]) -> str:
    """Return text that resembles a translation without the original lyrics.

    With at least two foreign-looking lines, one pseudo translation per
    foreign line is emitted (capped at eight). Otherwise a synthetic lyric
    is returned, indexed by the foreign-line count plus the line count.
    """
    foreign_total = 0
    for line in lines:
        if _looks_foreign(line):
            foreign_total += 1
    if foreign_total >= 2:
        last_index = min(8, foreign_total)
        return "\n".join(map(_pseudo_translation, range(1, last_index + 1)))
    return _same_theme_synthetic(foreign_total + len(lines))


def _pseudo_translation(index: int) -> str:
    """Return the canned translation line for a 1-based *index*.

    The eight canned lines repeat cyclically, so index 9 maps to the same
    line as index 1.
    """
    canned = (
        "今晚我仍然想念你",
        "风会带走所有疲惫",
        "黑暗里也会有光",
        "别让昨天困住自己",
        "我们终会继续向前",
        "雨停以后天空会亮",
        "把遗憾留在旧时光",
        "你已经足够好了",
    )
    slot = (index - 1) % len(canned)
    return canned[slot]


def _looks_foreign(line: str) -> bool:
    """Tell whether *line* has ASCII letters but no CJK ideographs.

    Only the ``\\u4e00``-``\\u9fff`` range counts as CJK here; lines with
    no ASCII letter at all are never treated as foreign.
    """
    has_latin = re.search(r"[A-Za-z]", line) is not None
    if not has_latin:
        return False
    has_cjk = re.search(r"[\u4e00-\u9fff]", line) is not None
    return not has_cjk