generate_composition_testset.py
12.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
"""生成曲去重评估测试集。
从音频目录随机抽取若干首参照歌入库,对每首用 ffmpeg 生成多个变换版本,
覆盖曲去重测试样本类型.md 中第一类(数字信号变换)和第三类(困难正样本)的可合成部分。
负样本从未入库的 holdout 歌曲生成,以匹配最终接口 duplicate true/false 语义。
用法:
python scripts/generate_composition_testset.py \
--audio-dir /Volumes/移动硬盘/lyric_audio_type11 \
--negative-audio-dir /Volumes/移动硬盘/composition_test \
--out-dir composition_dedup/composition_testset \
--num-songs 80 \
--num-negative-songs 40 \
--negative-variants \
--seed 123
输出:
reference.csv — 参照曲(原始文件),需提前入库
queries.csv — 查询曲,带 variant 和 expected 标注
"""
import argparse
import csv
import logging
import random
import subprocess
import sys
from pathlib import Path
try:
from tqdm import tqdm
except ImportError:
tqdm = None
def _tqdm(iterable, **kwargs):
if tqdm is not None:
return tqdm(iterable, **kwargs)
total = kwargs.get("total", None) or (len(iterable) if hasattr(iterable, "__len__") else None)
desc = kwargs.get("desc", "")
class _Simple:
def __init__(self):
self._i = 0
def __iter__(self):
for item in iterable:
self._i += 1
if total:
print(f"\r{desc}: {self._i}/{total}", end="", flush=True)
yield item
if total:
print()
return _Simple()
logger = logging.getLogger(__name__)
# --------------------------------------------------------------------------
# 第一类:数字信号变换
# --------------------------------------------------------------------------
DSP_VARIANTS: list[tuple[str, str]] = [
# Pitch Shift(±1、±2 半音)
("pitch_up1", "asetrate=22050*1.0595,aresample=22050"), # +1 半音
("pitch_up2", "asetrate=22050*1.1225,aresample=22050"), # +2 半音
("pitch_down1", "asetrate=22050*0.9439,aresample=22050"), # -1 半音
("pitch_down2", "asetrate=22050*0.8909,aresample=22050"), # -2 半音
# Tempo Shift
("tempo_slow", "atempo=0.90"), # 0.9x
("tempo_fast", "atempo=1.10"), # 1.1x
("tempo_faster","atempo=1.20"), # 1.2x
# EQ 变换
("lowpass", "lowpass=f=4000"), # 低通
("highpass", "highpass=f=800"), # 高通
("eq_mid", "equalizer=f=1000:width_type=o:width=2:g=-6"), # 中频衰减
# 压缩编码往返(编码为 mp3 再解回 wav,模拟有损压缩引入的失真)
("codec_320k", "acodec=libmp3lame,b:a=320k"),
("codec_128k", "acodec=libmp3lame,b:a=128k"),
]
# --------------------------------------------------------------------------
# 第三类:困难正样本(可合成部分)
# --------------------------------------------------------------------------
HARD_POSITIVE_VARIANTS: list[tuple[str, str]] = [
# 前奏删减:从 20% 处开始截取(模拟删前奏版本)
("trim_intro", None), # 特殊处理,用 -ss 参数
# 只保留副歌:截取中间 40%(模拟短视频截段)
("chorus_only", None), # 特殊处理,用 -ss + -t 参数
# Pitch + Tempo 叠加(模拟 Live 版同时有调整)
("pitch_up1_tempo_slow", "asetrate=22050*1.0595,aresample=22050,atempo=0.92"),
]
# 负样本变体只使用相对温和的处理,避免把负样本评估变成极端音质测试。
NEGATIVE_VARIANTS: list[tuple[str, str | None]] = [
("negative_lowpass", "lowpass=f=4000"),
("negative_codec_128k", "acodec=libmp3lame,b:a=128k"),
]
def _ffmpeg_variant(src: Path, dst: Path, af: str) -> bool:
"""普通 audio filter 变换。"""
# 压缩编码往返需要两步:先编码为 mp3,再解回 wav
if "acodec" in af:
import tempfile
with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp:
tmp_mp3 = Path(tmp.name)
ok1 = _run_ffmpeg([
"ffmpeg", "-y", "-i", str(src),
"-ar", "22050", "-ac", "1",
"-codec:a", "libmp3lame", "-b:a", af.split("b:a=")[1],
str(tmp_mp3),
])
if not ok1:
return False
ok2 = _run_ffmpeg([
"ffmpeg", "-y", "-i", str(tmp_mp3),
"-ar", "22050", "-ac", "1",
str(dst),
])
tmp_mp3.unlink(missing_ok=True)
return ok2
cmd = [
"ffmpeg", "-y", "-i", str(src),
"-af", af,
"-ar", "22050", "-ac", "1",
str(dst),
]
return _run_ffmpeg(cmd)
def _ffmpeg_trim(src: Path, dst: Path, start_ratio: float, duration_ratio: float) -> bool:
"""按相对位置截取片段。需要先探测时长。"""
duration = _probe_duration(src)
if duration is None:
return False
ss = duration * start_ratio
t = duration * duration_ratio
return _run_ffmpeg([
"ffmpeg", "-y", "-i", str(src),
"-ss", f"{ss:.3f}", "-t", f"{t:.3f}",
"-ar", "22050", "-ac", "1",
str(dst),
])
def _probe_duration(src: Path) -> float | None:
    """Return the container duration of ``src`` in seconds, as reported by ffprobe.

    Returns None when ffprobe's stdout is not parseable as a float, for
    example when it prints nothing or ``N/A``.
    """
    probe_cmd = [
        "ffprobe",
        "-v", "error",
        "-show_entries", "format=duration",
        "-of", "default=noprint_wrappers=1:nokey=1",
        str(src),
    ]
    completed = subprocess.run(probe_cmd, capture_output=True, text=True)
    raw_value = completed.stdout.strip()
    try:
        seconds = float(raw_value)
    except ValueError:
        return None
    return seconds
def _run_ffmpeg(cmd: list[str]) -> bool:
result = subprocess.run(cmd, capture_output=True)
return result.returncode == 0
def _song_id(path: Path) -> str:
return path.stem.split("_")[0]
def _discover_wavs(audio_dir: Path) -> list[Path]:
return [
f for f in sorted(audio_dir.rglob("*.wav"))
if f.is_file() and not f.name.startswith("._")
]
# queries.csv columns; reference.csv uses only _REF_FIELDS.
_QUERY_FIELDS = ["song_id", "audio_path", "variant", "sample_class", "expected_song_id", "expected"]
_REF_FIELDS = ["song_id", "audio_path", "variant"]

# Hard-positive variants produced by cutting instead of filtering:
# variant name -> (start_ratio, duration_ratio) of the source duration.
_TRIM_SPECS: dict[str, tuple[float, float]] = {
    "trim_intro": (0.20, 0.80),  # drop the first 20% (intro removed)
    "chorus_only": (0.30, 0.40),  # keep 30%-70% (short-video style clip)
}


def _parse_args() -> argparse.Namespace:
    """Define and parse the command-line options (usage in the module docstring)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--audio-dir", required=True)
    parser.add_argument(
        "--negative-audio-dir",
        default="/Volumes/移动硬盘/lyric_audio_type11",
        help="负样本来源目录;会排除已抽入 reference.csv 的参照曲 song_id(可与 --audio-dir 相同)",
    )
    parser.add_argument("--out-dir", required=True)
    parser.add_argument("--num-songs", type=int, default=20, help="抽取歌曲数量")
    parser.add_argument("--num-negative-songs", type=int, default=20, help="抽取未入库负样本歌曲数量")
    parser.add_argument(
        "--negative-variants",
        action="store_true",
        help="为负样本额外生成 codec/lowpass 变体",
    )
    parser.add_argument("--seed", type=int, default=42)
    return parser.parse_args()


def _query_row(
    song_id: str,
    audio_path: Path | str,
    variant: str,
    sample_class: str,
    expected_song_id: str,
    expected: str,
) -> dict[str, str]:
    """Build one queries.csv row."""
    return {
        "song_id": song_id,
        "audio_path": str(audio_path),
        "variant": variant,
        "sample_class": sample_class,
        "expected_song_id": expected_song_id,
        "expected": expected,
    }


def _unique_by_song_id(wavs: list[Path]) -> list[Path]:
    """Keep the first file (in input order) for each song_id.

    Variant files are named ``{song_id}_{variant}.wav``, so two sources that
    share a song_id would otherwise overwrite each other's variants.
    """
    seen: set[str] = set()
    unique: list[Path] = []
    for wav in wavs:
        song_id = _song_id(wav)
        if song_id not in seen:
            seen.add(song_id)
            unique.append(wav)
    return unique


def _positive_queries(wav: Path, song_id: str, variants_dir: Path) -> list[dict[str, str]]:
    """Render every DSP and hard-positive variant of one reference song.

    Variants whose generation fails are logged and skipped.
    """
    rows: list[dict[str, str]] = []

    # Category 1: DSP transforms
    for variant_name, af in DSP_VARIANTS:
        dst = variants_dir / f"{song_id}_{variant_name}.wav"
        if not _ffmpeg_variant(wav, dst, af):
            logger.warning("DSP 变换失败,跳过: %s %s", wav.name, variant_name)
            continue
        rows.append(_query_row(song_id, dst, variant_name, "dsp", song_id, "duplicate"))

    # Category 3: hard positives
    for variant_name, af in HARD_POSITIVE_VARIANTS:
        dst = variants_dir / f"{song_id}_{variant_name}.wav"
        trim = _TRIM_SPECS.get(variant_name)
        if trim is not None:
            start_ratio, duration_ratio = trim
            ok = _ffmpeg_trim(wav, dst, start_ratio=start_ratio, duration_ratio=duration_ratio)
        elif af is not None:
            ok = _ffmpeg_variant(wav, dst, af)
        else:
            # A None filter only makes sense for the trim variants above.
            logger.warning("困难正样本缺少 filter 定义,跳过: %s", variant_name)
            continue
        if not ok:
            logger.warning("困难正样本生成失败,跳过: %s %s", wav.name, variant_name)
            continue
        rows.append(_query_row(song_id, dst, variant_name, "hard_positive", song_id, "duplicate"))

    return rows


def _negative_queries(wav: Path, variants_dir: Path, with_variants: bool) -> list[dict[str, str]]:
    """Emit a never-ingested song as a negative query, plus optional mild variants.

    ``expected_song_id`` stays empty because there is no target duplicate.
    Evaluation only checks the final duplicate true/false answer.
    """
    song_id = _song_id(wav)
    rows = [_query_row(song_id, wav, "negative_original", "negative", "", "not_duplicate")]
    if not with_variants:
        return rows
    for variant_name, af in NEGATIVE_VARIANTS:
        dst = variants_dir / f"{song_id}_{variant_name}.wav"
        if not _ffmpeg_variant(wav, dst, af):
            logger.warning("负样本变换失败,跳过: %s %s", wav.name, variant_name)
            continue
        rows.append(_query_row(song_id, dst, variant_name, "negative", "", "not_duplicate"))
    return rows


def _write_csv(path: Path, fieldnames: list[str], rows: list[dict[str, str]]) -> None:
    """Write ``rows`` to ``path`` as UTF-8 CSV with a header line."""
    with path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


def main() -> None:
    """Sample songs, render query variants, and write reference.csv / queries.csv."""
    from collections import Counter

    logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
    args = _parse_args()

    audio_dir = Path(args.audio_dir)
    negative_audio_dir = Path(args.negative_audio_dir)
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    variants_dir = out_dir / "variants"
    variants_dir.mkdir(exist_ok=True)

    reference_pool = _unique_by_song_id(_discover_wavs(audio_dir))
    if len(reference_pool) < args.num_songs:
        logger.error(
            "参照目录下只有 %d 首歌(按 song_id 去重),少于 --num-songs = %d",
            len(reference_pool),
            args.num_songs,
        )
        sys.exit(1)

    random.seed(args.seed)
    selected = random.sample(reference_pool, args.num_songs)
    reference_ids = {_song_id(wav) for wav in selected}

    # Negatives must never appear in reference.csv. Excluding candidates by the
    # *sampled* reference song_ids keeps that invariant. It also still works
    # when --negative-audio-dir and --audio-dir point at the same directory
    # (holdout songs drawn from the same pool).
    negative_candidates = _unique_by_song_id(_discover_wavs(negative_audio_dir))
    negative_pool = [wav for wav in negative_candidates if _song_id(wav) not in reference_ids]
    num_excluded = len(negative_candidates) - len(negative_pool)
    if len(negative_pool) < args.num_negative_songs:
        logger.error(
            "负样本目录排除参照曲 song_id 后只有 %d 首歌,少于 --num-negative-songs = %d",
            len(negative_pool),
            args.num_negative_songs,
        )
        sys.exit(1)
    negative_selected = random.sample(negative_pool, args.num_negative_songs)
    logger.info(
        "已抽取 %d 首参照歌,%d 首未入库负样本歌(负样本来源: %s,已排除 %d 首与参照曲同 song_id 的候选)",
        len(selected),
        len(negative_selected),
        negative_audio_dir,
        num_excluded,
    )

    ref_rows: list[dict[str, str]] = []
    query_rows: list[dict[str, str]] = []
    for wav in _tqdm(selected, desc="生成正样本变体", total=len(selected)):
        song_id = _song_id(wav)
        ref_rows.append({"song_id": song_id, "audio_path": str(wav), "variant": "original"})
        query_rows.extend(_positive_queries(wav, song_id, variants_dir))
    for wav in _tqdm(negative_selected, desc="生成负样本变体", total=len(negative_selected)):
        query_rows.extend(_negative_queries(wav, variants_dir, args.negative_variants))

    ref_path = out_dir / "reference.csv"
    query_path = out_dir / "queries.csv"
    _write_csv(ref_path, _REF_FIELDS, ref_rows)
    _write_csv(query_path, _QUERY_FIELDS, query_rows)

    by_expected = Counter(row["expected"] for row in query_rows)
    logger.info("参照集: %s (%d 条)", ref_path, len(ref_rows))
    logger.info(
        "查询集: %s (%d 条,正样本 %d,负样本 %d)",
        query_path,
        len(query_rows),
        by_expected["duplicate"],
        by_expected["not_duplicate"],
    )
    # Per-sample_class breakdown
    by_class = Counter(row["sample_class"] for row in query_rows)
    for cls, cnt in sorted(by_class.items()):
        logger.info(" %-20s %d 条", cls, cnt)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()