# generate_composition_testset.py (12.4 KB)
"""生成曲去重评估测试集。

从音频目录随机抽取若干首参照歌入库,对每首用 ffmpeg 生成多个变换版本,
覆盖曲去重测试样本类型.md 中第一类(数字信号变换)和第三类(困难正样本)的可合成部分。
负样本从未入库的 holdout 歌曲生成,以匹配最终接口 duplicate true/false 语义。

用法:
python scripts/generate_composition_testset.py \
    --audio-dir /Volumes/移动硬盘/composition_test \
    --negative-audio-dir /Volumes/移动硬盘/composition_drop \
    --out-dir composition_testset \
    --num-songs 100 \
    --num-negative-songs 100 \
    --negative-variants \
    --seed 123

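输出(写入 --out-dir):
    reference.csv   — 参照曲(原始文件),需提前入库
    queries.csv     — 查询曲,带 variant、sample_class、expected_song_id 和 expected 标注
    variants/       — ffmpeg 生成的变换音频;negative_original 查询直接引用负样本原文件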
"""

import argparse
import csv
import logging
import random
import subprocess
import sys
from pathlib import Path

try:
    from tqdm import tqdm
except ImportError:
    tqdm = None


def _tqdm(iterable, **kwargs):
    """Wrap ``iterable`` in a progress indicator.

    Delegates to the real ``tqdm`` when it was importable. Otherwise returns a
    minimal re-iterable wrapper that prints ``"\\r{desc}: {i}/{total}"`` after
    each item and a final newline once the iterable is exhausted. The fallback
    prints nothing when no total is known (neither a truthy ``total=`` nor a
    ``len()`` on the iterable).
    """
    if tqdm is not None:
        return tqdm(iterable, **kwargs)

    label = kwargs.get("desc", "")
    expected = kwargs.get("total", None)
    if not expected:
        expected = len(iterable) if hasattr(iterable, "__len__") else None

    def _emit(count):
        print(f"\r{label}: {count}/{expected}", end="", flush=True)

    class _PlainProgress:
        """Fallback progress wrapper; its counter keeps running across passes."""

        def __init__(self):
            self._i = 0

        def __iter__(self):
            show = bool(expected)
            for element in iterable:
                self._i += 1
                if show:
                    _emit(self._i)
                yield element
            if show:
                print()

    return _PlainProgress()

logger = logging.getLogger(__name__)

# --------------------------------------------------------------------------
# 第一类:数字信号变换
# --------------------------------------------------------------------------
DSP_VARIANTS: list[tuple[str, str]] = [
    # Pitch Shift(±1、±2 半音)
    ("pitch_up1",   "asetrate=22050*1.0595,aresample=22050"),   # +1 半音
    ("pitch_up2",   "asetrate=22050*1.1225,aresample=22050"),   # +2 半音
    ("pitch_down1", "asetrate=22050*0.9439,aresample=22050"),   # -1 半音
    ("pitch_down2", "asetrate=22050*0.8909,aresample=22050"),   # -2 半音
    # Tempo Shift
    ("tempo_slow",  "atempo=0.90"),                              # 0.9x
    ("tempo_fast",  "atempo=1.10"),                              # 1.1x
    ("tempo_faster","atempo=1.20"),                              # 1.2x
    # EQ 变换
    ("lowpass",     "lowpass=f=4000"),                          # 低通
    ("highpass",    "highpass=f=800"),                          # 高通
    ("eq_mid",      "equalizer=f=1000:width_type=o:width=2:g=-6"),  # 中频衰减
    # 压缩编码往返(编码为 mp3 再解回 wav,模拟有损压缩引入的失真)
    ("codec_320k",  "acodec=libmp3lame,b:a=320k"),
    ("codec_128k",  "acodec=libmp3lame,b:a=128k"),
]

# --------------------------------------------------------------------------
# 第三类:困难正样本(可合成部分)
# --------------------------------------------------------------------------
HARD_POSITIVE_VARIANTS: list[tuple[str, str]] = [
    # 前奏删减:从 20% 处开始截取(模拟删前奏版本)
    ("trim_intro",    None),   # 特殊处理,用 -ss 参数
    # 只保留副歌:截取中间 40%(模拟短视频截段)
    ("chorus_only",   None),   # 特殊处理,用 -ss + -t 参数
    # Pitch + Tempo 叠加(模拟 Live 版同时有调整)
    ("pitch_up1_tempo_slow", "asetrate=22050*1.0595,aresample=22050,atempo=0.92"),
]

# 负样本变体只使用相对温和的处理,避免把负样本评估变成极端音质测试。
NEGATIVE_VARIANTS: list[tuple[str, str | None]] = [
    ("negative_lowpass", "lowpass=f=4000"),
    ("negative_codec_128k", "acodec=libmp3lame,b:a=128k"),
]


def _ffmpeg_variant(src: Path, dst: Path, af: str) -> bool:
    """Render ``src`` through one variant spec into a 22.05 kHz mono file at ``dst``.

    ``af`` is normally an ffmpeg ``-af`` filter graph. As a special case, a spec
    containing ``"acodec"`` (e.g. ``"acodec=libmp3lame,b:a=128k"``) is not a
    filter: it requests a lossy round trip. The audio is encoded to MP3 at the
    bitrate given after ``b:a=`` and then decoded back into ``dst``.

    Returns True when every ffmpeg invocation exits with status 0.
    Raises ValueError for a codec spec that carries no ``b:a=`` bitrate.
    """
    if "acodec" in af:
        import tempfile

        _, has_bitrate, rest = af.partition("b:a=")
        if not has_bitrate:
            raise ValueError(f"codec variant {af!r} has no 'b:a=' bitrate")
        bitrate = rest.split(",", 1)[0]

        # Reserve a unique temp path; the handle is closed immediately so that
        # ffmpeg can (over)write the file itself.
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp:
            tmp_mp3 = Path(tmp.name)
        try:
            encoded = _run_ffmpeg([
                "ffmpeg", "-y", "-i", str(src),
                "-ar", "22050", "-ac", "1",
                "-codec:a", "libmp3lame", "-b:a", bitrate,
                str(tmp_mp3),
            ])
            if not encoded:
                return False
            return _run_ffmpeg([
                "ffmpeg", "-y", "-i", str(tmp_mp3),
                "-ar", "22050", "-ac", "1",
                str(dst),
            ])
        finally:
            # Remove the intermediate MP3 on every path, including a failed encode.
            tmp_mp3.unlink(missing_ok=True)

    return _run_ffmpeg([
        "ffmpeg", "-y", "-i", str(src),
        "-af", af,
        "-ar", "22050", "-ac", "1",
        str(dst),
    ])


def _ffmpeg_trim(src: Path, dst: Path, start_ratio: float, duration_ratio: float) -> bool:
    """Cut a segment of ``src`` at positions relative to its total length.

    The segment starts at ``duration * start_ratio`` seconds and lasts
    ``duration * duration_ratio`` seconds. The output is 22.05 kHz mono.
    Returns False without running ffmpeg when the duration cannot be probed or
    is not positive; ffmpeg might otherwise write an empty cut that would end
    up as a bogus query file.
    """
    duration = _probe_duration(src)
    # ``not duration > 0`` also rejects NaN.
    if duration is None or not duration > 0:
        return False
    start = duration * start_ratio
    length = duration * duration_ratio
    # -ss / -t placed after -i are output options: ffmpeg decodes the input and
    # discards audio until reaching ``start``.
    return _run_ffmpeg([
        "ffmpeg", "-y", "-i", str(src),
        "-ss", f"{start:.3f}", "-t", f"{length:.3f}",
        "-ar", "22050", "-ac", "1",
        str(dst),
    ])


def _probe_duration(src: Path) -> float | None:
    result = subprocess.run(
        ["ffprobe", "-v", "error", "-show_entries", "format=duration",
         "-of", "default=noprint_wrappers=1:nokey=1", str(src)],
        capture_output=True, text=True,
    )
    try:
        return float(result.stdout.strip())
    except ValueError:
        return None


def _run_ffmpeg(cmd: list[str]) -> bool:
    """Run an external command (an ffmpeg invocation) and return True on exit status 0.

    stdout and stderr are captured rather than shown. stdin is bound to the null
    device because ffmpeg otherwise reads standard input for interactive
    commands, which can swallow keystrokes or stall the job when it runs in the
    background. On failure, the tail of stderr is logged at DEBUG level so the
    cause can be diagnosed.
    """
    result = subprocess.run(cmd, capture_output=True, stdin=subprocess.DEVNULL)
    if result.returncode != 0:
        stderr_tail = result.stderr.decode("utf-8", errors="replace").strip()[-2000:]
        logging.getLogger(__name__).debug(
            "command failed (exit %d): %s\n%s",
            result.returncode,
            " ".join(cmd),
            stderr_tail,
        )
        return False
    return True


def _song_id(path: Path) -> str:
    """Return the song id encoded in a file name: the stem up to its first underscore.

    A stem without any underscore is returned whole.
    """
    head, _sep, _tail = path.stem.partition("_")
    return head


def _discover_wavs(audio_dir: Path) -> list[Path]:
    """Return every ``*.wav`` regular file under ``audio_dir`` (recursive), sorted by path.

    Names starting with ``._`` are skipped (presumably macOS AppleDouble
    metadata files), as are matches that are not regular files.
    """
    found: list[Path] = []
    for candidate in sorted(audio_dir.rglob("*.wav")):
        if candidate.name.startswith("._"):
            continue
        if candidate.is_file():
            found.append(candidate)
    return found


_REFERENCE_FIELDS = ["song_id", "audio_path", "variant"]
_QUERY_FIELDS = ["song_id", "audio_path", "variant", "sample_class", "expected_song_id", "expected"]


def _query_row(
    song_id: str,
    audio_path: Path,
    variant: str,
    sample_class: str,
    expected_song_id: str,
    expected: str,
) -> dict[str, str]:
    """Build one queries.csv row."""
    return {
        "song_id": song_id,
        "audio_path": str(audio_path),
        "variant": variant,
        "sample_class": sample_class,
        "expected_song_id": expected_song_id,
        "expected": expected,
    }


def _positive_rows(wav: Path, variants_dir: Path) -> list[dict[str, str]]:
    """Render all DSP and hard-positive variants of one reference song; return their rows.

    Variants that ffmpeg fails to produce are logged and skipped.
    """
    song_id = _song_id(wav)
    rows: list[dict[str, str]] = []

    # Category 1: DSP transforms.
    for variant_name, af in DSP_VARIANTS:
        dst = variants_dir / f"{song_id}_{variant_name}.wav"
        if not _ffmpeg_variant(wav, dst, af):
            logger.warning("DSP 变换失败,跳过: %s %s", wav.name, variant_name)
            continue
        rows.append(_query_row(song_id, dst, variant_name, "dsp", song_id, "duplicate"))

    # Category 3: hard positives. The two trim variants carry no filter string
    # and are cut by relative position instead.
    for variant_name, af in HARD_POSITIVE_VARIANTS:
        dst = variants_dir / f"{song_id}_{variant_name}.wav"
        if variant_name == "trim_intro":
            ok = _ffmpeg_trim(wav, dst, start_ratio=0.20, duration_ratio=0.80)
        elif variant_name == "chorus_only":
            ok = _ffmpeg_trim(wav, dst, start_ratio=0.30, duration_ratio=0.40)
        else:
            ok = _ffmpeg_variant(wav, dst, af)
        if not ok:
            logger.warning("困难正样本生成失败,跳过: %s %s", wav.name, variant_name)
            continue
        rows.append(_query_row(song_id, dst, variant_name, "hard_positive", song_id, "duplicate"))

    return rows


def _negative_rows(wav: Path, variants_dir: Path, with_variants: bool) -> list[dict[str, str]]:
    """Return the rows for one held-out (not ingested) song, optionally rendering its variants.

    An empty expected_song_id means there is no target duplicate; the evaluation
    only checks the final duplicate true/false answer.
    """
    song_id = _song_id(wav)
    rows = [_query_row(song_id, wav, "negative_original", "negative", "", "not_duplicate")]
    if not with_variants:
        return rows
    for variant_name, af in NEGATIVE_VARIANTS:
        dst = variants_dir / f"{song_id}_{variant_name}.wav"
        if not _ffmpeg_variant(wav, dst, af):
            logger.warning("负样本变换失败,跳过: %s %s", wav.name, variant_name)
            continue
        rows.append(_query_row(song_id, dst, variant_name, "negative", "", "not_duplicate"))
    return rows


def _write_csv(path: Path, fieldnames: list[str], rows: list[dict[str, str]]) -> None:
    """Write ``rows`` to ``path`` as UTF-8 CSV with a header line."""
    with path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


def main() -> None:
    """Build reference.csv and queries.csv, plus rendered variant audio, for the dedup evaluation.

    Samples ``--num-songs`` reference songs from ``--audio-dir``, excluding any
    song_id that also appears in ``--negative-audio-dir``. Samples
    ``--num-negative-songs`` negative songs from ``--negative-audio-dir``.
    Renders their variants with ffmpeg under ``<out-dir>/variants`` and writes
    both CSVs into ``--out-dir``. Exits with status 1 when either directory
    holds too few wav files.
    """
    from collections import Counter

    logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")

    parser = argparse.ArgumentParser()
    parser.add_argument("--audio-dir", required=True)
    parser.add_argument(
        "--negative-audio-dir",
        default="/Volumes/移动硬盘/lyric_audio_type11",
        help="负样本来源目录;其中出现的 song_id 会从 --audio-dir 的参照候选中排除",
    )
    parser.add_argument("--out-dir", required=True)
    parser.add_argument("--num-songs", type=int, default=20, help="抽取歌曲数量")
    parser.add_argument("--num-negative-songs", type=int, default=20, help="抽取未入库负样本歌曲数量")
    parser.add_argument(
        "--negative-variants",
        action="store_true",
        help="为负样本额外生成 codec/lowpass 变体",
    )
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()
    if args.num_songs < 0 or args.num_negative_songs < 0:
        parser.error("--num-songs 和 --num-negative-songs 不能为负数")

    audio_dir = Path(args.audio_dir)
    negative_audio_dir = Path(args.negative_audio_dir)
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    variants_dir = out_dir / "variants"
    variants_dir.mkdir(exist_ok=True)

    all_wavs = _discover_wavs(audio_dir)
    negative_wavs = _discover_wavs(negative_audio_dir)

    if len(negative_wavs) < args.num_negative_songs:
        logger.error(
            "负样本目录下只有 %d 个 wav,少于 --num-negative-songs = %d",
            len(negative_wavs),
            args.num_negative_songs,
        )
        sys.exit(1)

    # Drop every reference candidate whose song_id also occurs in the negative
    # directory, so no sampled negative can share a song_id with a reference.
    negative_song_ids = {_song_id(wav) for wav in negative_wavs}
    candidate_wavs = [wav for wav in all_wavs if _song_id(wav) not in negative_song_ids]
    excluded_count = len(all_wavs) - len(candidate_wavs)

    if len(candidate_wavs) < args.num_songs:
        logger.error(
            "参照目录排除负样本 song_id 后只有 %d 个 wav,少于 --num-songs = %d",
            len(candidate_wavs),
            args.num_songs,
        )
        sys.exit(1)

    # A private RNG draws the same samples as seeding the global ``random``
    # module with the same seed, without mutating global state.
    rng = random.Random(args.seed)
    selected = rng.sample(candidate_wavs, args.num_songs)
    negative_selected = rng.sample(negative_wavs, args.num_negative_songs)
    logger.info(
        "已抽取 %d 首参照歌,%d 首未入库负样本歌(负样本来源: %s,按 %d 个负样本 song_id 从参照候选中排除了 %d 个 wav)",
        len(selected),
        len(negative_selected),
        negative_audio_dir,
        len(negative_song_ids),
        excluded_count,
    )

    # Variant files are named ``{song_id}_{variant}.wav``, so two selected files
    # sharing a song_id would overwrite each other's variants.
    repeated_ids = sorted(
        sid for sid, cnt in Counter(_song_id(wav) for wav in selected).items() if cnt > 1
    )
    if repeated_ids:
        logger.warning("以下 song_id 在参照曲中出现多次,其变体文件会互相覆盖: %s", ", ".join(repeated_ids))

    ref_rows = [
        {"song_id": _song_id(wav), "audio_path": str(wav), "variant": "original"}
        for wav in selected
    ]

    query_rows: list[dict[str, str]] = []
    for wav in _tqdm(selected, desc="生成正样本变体", total=len(selected)):
        query_rows.extend(_positive_rows(wav, variants_dir))

    # Negative queries must not be in the reference.csv ingestion set.
    for wav in _tqdm(negative_selected, desc="生成负样本变体", total=len(negative_selected)):
        query_rows.extend(_negative_rows(wav, variants_dir, args.negative_variants))

    ref_path = out_dir / "reference.csv"
    query_path = out_dir / "queries.csv"
    _write_csv(ref_path, _REFERENCE_FIELDS, ref_rows)
    _write_csv(query_path, _QUERY_FIELDS, query_rows)

    pos = sum(1 for r in query_rows if r["expected"] == "duplicate")
    neg = sum(1 for r in query_rows if r["expected"] == "not_duplicate")
    logger.info("参照集: %s (%d 条)", ref_path, len(ref_rows))
    logger.info("查询集: %s (%d 条,正样本 %d,负样本 %d)", query_path, len(query_rows), pos, neg)

    # Per sample_class breakdown.
    by_class = Counter(r["sample_class"] for r in query_rows)
    for cls, cnt in sorted(by_class.items()):
        logger.info("  %-20s %d 条", cls, cnt)


# Script entry point: build the test set only when run directly, not on import.
if __name__ == "__main__":
    main()