Commit 8413944a 8413944ad675bb85c114f5f012b4257a140fef8e by 沈秋雨

添加曲结构去重

1 parent cdfa3a58
1 from .service import CompositionCandidate, CompositionConfig, CompositionDedupService
2 from .dejavu_fingerprinter import fingerprint_audio
3
4 __all__ = [
5 "CompositionCandidate",
6 "CompositionConfig",
7 "CompositionDedupService",
8 "fingerprint_audio",
9 ]
1 """Dejavu 风格的音频指纹生成。
2
3 基于 worldveil/dejavu 的指纹算法提取实现,不依赖 Dejavu 的数据库层。
4 使用 scipy.signal.spectrogram 替代已废弃的 matplotlib.mlab.specgram。
5
6 流程:
7 1. 音频标准化:ffmpeg 转 44100Hz / Mono / WAV
8 2. librosa 加载音频
9 3. 短时傅里叶变换(STFT)→ 对数频谱图
10 4. 2D 峰值检测:在频谱图中找局部极大值
11 5. 指纹哈希:对峰值对 (freq1, freq2, time_delta) 做 SHA1,取前 20 位
12 """
13
14 import hashlib
15 import logging
16 import os
17 import subprocess
18 import tempfile
19 from operator import itemgetter
20 from pathlib import Path
21
22 import librosa
23 import numpy as np
24 from scipy.ndimage import (
25 binary_erosion,
26 generate_binary_structure,
27 iterate_structure,
28 maximum_filter,
29 )
30 from scipy.signal import spectrogram
31
32 logger = logging.getLogger(__name__)
33
34
def _load_env_file() -> None:
    """Seed ``os.environ`` from the ``.env`` file one level above this module's directory.

    Blank lines, lines starting with ``#`` and lines without ``=`` are skipped.
    Every other line is split on its first ``=``; the key is whitespace-stripped
    and the value is stripped of whitespace and of leading/trailing quote
    characters. Variables already present in the environment are not overwritten.
    Does nothing when the file does not exist.
    """
    dotenv = Path(__file__).resolve().parent.parent / ".env"
    if not dotenv.exists():
        return
    with dotenv.open(encoding="utf-8") as handle:
        for entry in map(str.strip, handle):
            if not entry or entry.startswith("#") or "=" not in entry:
                continue
            name, _, raw = entry.partition("=")
            cleaned = raw.strip().strip('"').strip("'")
            # setdefault keeps real environment variables authoritative.
            os.environ.setdefault(name.strip(), cleaned)


# Runs at import time so the constants defined below can read the values.
_load_env_file()
50
51 # ===== 常量(可通过环境变量覆盖)=====
52
# Sample rate (Hz) the audio is converted to before fingerprinting.
DEFAULT_FS = 44100
# FFT window length, in samples.
DEFAULT_WINDOW_SIZE = 4096
# Fraction of the window shared by consecutive frames (env-overridable).
DEFAULT_OVERLAP_RATIO = float(os.getenv("COMPOSITION_DEJAVU_OVERLAP_RATIO", "0.3"))
# Upper bound used when pairing each peak with the peaks that follow it (env-overridable).
DEFAULT_FAN_VALUE = int(os.getenv("COMPOSITION_DEJAVU_FAN_VALUE", "10"))
# Peaks at or below this amplitude are discarded (env-overridable).
DEFAULT_AMP_MIN = float(os.getenv("COMPOSITION_DEJAVU_AMP_MIN", "20"))
# Iteration count of the structuring element used for local-maximum detection.
PEAK_NEIGHBORHOOD_SIZE = 20
# Inclusive bounds on the frame distance between two paired peaks.
MIN_HASH_TIME_DELTA = 0
MAX_HASH_TIME_DELTA = 200
# Whether peaks are sorted by time before pairing.
PEAK_SORT = True
# Number of leading SHA-1 hex characters kept per fingerprint.
FINGERPRINT_REDUCTION = 20
# Maximum number of seconds of audio to analyse; 0 means no limit (env-overridable).
MAX_DURATION_SEC = float(os.getenv("COMPOSITION_DEJAVU_MAX_DURATION", "120"))
64
65
def _normalize_audio(audio_path: str, max_duration: float = MAX_DURATION_SEC) -> tuple[np.ndarray, int]:
    """Convert *audio_path* to a mono WAV at ``DEFAULT_FS`` and load it as samples.

    ffmpeg writes the converted audio to a temporary file which librosa then
    reads; the temporary file is removed before returning, also on failure.

    Args:
        audio_path: Path of the input audio file.
        max_duration: When greater than 0, only this many leading seconds are
            kept (passed to ffmpeg as ``-t``).

    Returns:
        ``(samples, sample_rate)`` as returned by ``librosa.load``.

    Raises:
        RuntimeError: ffmpeg exited with a non-zero status.
    """
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as scratch:
        tmp_wav = scratch.name

    try:
        duration_args = ["-t", str(max_duration)] if max_duration > 0 else []
        cmd = [
            "ffmpeg",
            "-y",
            "-i", audio_path,
            "-ar", str(DEFAULT_FS),
            "-ac", "1",
            "-f", "wav",
            *duration_args,
            tmp_wav,
        ]

        completed = subprocess.run(cmd, capture_output=True, text=True)
        if completed.returncode != 0:
            raise RuntimeError(f"ffmpeg 转换失败: {completed.stderr}")

        samples, rate = librosa.load(tmp_wav, sr=DEFAULT_FS, mono=True)
        return samples, rate
    finally:
        # delete=False above means cleanup is our responsibility.
        if os.path.exists(tmp_wav):
            os.remove(tmp_wav)
97
98
def _specgram(samples: np.ndarray, fs: int, window_size: int, overlap_ratio: float):
    """Compute a log-scaled spectrogram with ``scipy.signal.spectrogram``.

    Args:
        samples: One-dimensional audio samples.
        fs: Sample rate in Hz.
        window_size: Length of the Hann window / segment, in samples.
        overlap_ratio: Fraction of ``window_size`` shared by adjacent segments.

    Returns:
        A 2-D array (frequency bins x time frames) in decibels, shifted so that
        its maximum is exactly 80 and floored at -100.
    """
    hann = np.hanning(window_size)
    _freqs, _times, power = spectrogram(
        samples,
        fs=fs,
        window=hann,
        nperseg=window_size,
        noverlap=int(window_size * overlap_ratio),
    )

    # The small epsilon keeps log10 finite for silent bins.
    level = 10 * np.log10(power + 1e-10)
    level -= level.max()  # strongest bin becomes 0 dB
    level += 80  # fixed offset so the amplitude threshold sees positive peak values
    np.maximum(level, -100, out=level)  # floor
    return level
124
125
def _get_2d_peaks(arr2D: np.ndarray, amp_min: float = DEFAULT_AMP_MIN):
    """Locate 2-D local maxima of a spectrogram that exceed *amp_min*.

    Args:
        arr2D: Spectrogram indexed as ``[frequency, time]``.
        amp_min: Peaks whose amplitude is not strictly greater are dropped.

    Returns:
        ``(frequency_idx, time_idx)``: two equally long lists with the row and
        column index of every retained peak.
    """
    footprint = iterate_structure(generate_binary_structure(2, 1), PEAK_NEIGHBORHOOD_SIZE)

    # A cell is a candidate when it equals the maximum of its neighbourhood.
    is_local_max = maximum_filter(arr2D, footprint=footprint) == arr2D
    # Cells equal to 0 are treated as background and removed from the candidates.
    is_background = arr2D == 0
    eroded = binary_erosion(is_background, structure=footprint, border_value=1)
    peak_mask = is_local_max ^ eroded

    amplitudes = arr2D[peak_mask]
    freq_rows, time_cols = np.where(peak_mask)

    strong = amplitudes > amp_min
    frequency_idx = list(freq_rows[strong])
    time_idx = list(time_cols[strong])
    return frequency_idx, time_idx
155
156
def _generate_hashes(peaks: list[tuple[int, int]], fan_value: int = DEFAULT_FAN_VALUE):
    """Yield truncated SHA-1 fingerprints for pairs of spectrogram peaks.

    Each peak is paired with the peaks at the next ``fan_value - 1`` positions
    of the (optionally time-sorted) sequence. A pair produces a fingerprint
    when its time difference lies within
    ``[MIN_HASH_TIME_DELTA, MAX_HASH_TIME_DELTA]``.

    The input list is left untouched: sorting is done on a copy, so callers can
    keep using *peaks* in its original order.

    Args:
        peaks: ``[(freq_idx, time_idx), ...]``.
        fan_value: Exclusive upper bound on how far ahead a partner peak may be
            in the ordered sequence.

    Yields:
        ``(hash_bytes, time_offset)`` where ``hash_bytes`` is the first
        ``FINGERPRINT_REDUCTION`` hex characters of the SHA-1 of
        ``"freq1|freq2|t_delta"`` and ``time_offset`` is the anchor peak's time.
    """
    # sorted() is stable like list.sort(), so the emitted order is unchanged.
    ordered = sorted(peaks, key=itemgetter(1)) if PEAK_SORT else list(peaks)
    count = len(ordered)

    for i, anchor in enumerate(ordered):
        freq1 = anchor[0]
        t1 = anchor[1]
        for k in range(i + 1, min(i + fan_value, count)):
            partner = ordered[k]
            freq2 = partner[0]
            t_delta = partner[1] - t1

            if MIN_HASH_TIME_DELTA <= t_delta <= MAX_HASH_TIME_DELTA:
                digest = hashlib.sha1(f"{freq1}|{freq2}|{t_delta}".encode())
                yield (digest.hexdigest()[:FINGERPRINT_REDUCTION].encode(), t1)
182
183
def fingerprint_audio(audio_path: str) -> tuple[str, list[tuple[bytes, int]]]:
    """Generate Dejavu-style fingerprints for an audio file.

    Args:
        audio_path: Path of the audio file.

    Returns:
        ``(file_sha1, fingerprints)``. ``file_sha1`` is the first 16 hex
        characters of the SHA-1 of the decoded samples (not of the file bytes);
        ``fingerprints`` is a list of ``(hash_bytes, offset)`` tuples.

    Raises:
        FileNotFoundError: *audio_path* is not an existing file.
        RuntimeError: ffmpeg conversion failed.
    """
    if not os.path.isfile(audio_path):
        raise FileNotFoundError(f"音频文件不存在: {audio_path}")

    # Decode to normalized mono samples (possibly truncated in length).
    samples, fs = _normalize_audio(audio_path)

    # Short identifier derived from the decoded samples.
    sample_digest = hashlib.sha1(samples.tobytes()).hexdigest()
    file_sha1 = sample_digest[:16]

    # Spectrogram -> peaks -> hashes.
    spectrum = _specgram(samples, fs, DEFAULT_WINDOW_SIZE, DEFAULT_OVERLAP_RATIO)
    freq_idx, time_idx = _get_2d_peaks(spectrum)
    peak_pairs = list(zip(freq_idx, time_idx))
    fingerprints = list(_generate_hashes(peak_pairs))

    logger.info("指纹生成完成: audio=%s, 指纹数=%d", audio_path, len(fingerprints))
    return file_sha1, fingerprints
1 """Chromagram 特征提取。
2
3 流程:
4 1. 音频标准化:ffmpeg 转 22050Hz / Mono / WAV
5 2. librosa 加载音频
6 3. librosa.feature.chroma_cens() 提取 12×T Chromagram(CENS,对速度/音色鲁棒)
7 4. 主音对齐:将能量最大的音级滚至第 0 行,实现转调不变性
8 5. scipy.signal.resample(chroma, 128, axis=1) 时间归一化到 12×128
9 6. .flatten() 展开为 1536 维向量
10 """
11
12 import logging
13 import os
14 import subprocess
15 import tempfile
16
17 import librosa
18 import numpy as np
19 from scipy.signal import resample
20
21 logger = logging.getLogger(__name__)
22
# Target sample rate (Hz) and number of time frames of the normalized chromagram.
TARGET_SR = 22050
TARGET_FRAMES = 128
# 12 pitch classes x TARGET_FRAMES frames = 1536 values per feature vector.
VECTOR_DIM = TARGET_FRAMES * 12
27
28
def _normalize_audio_ffmpeg(audio_path: str, output_path: str) -> None:
    """Run ffmpeg to write *audio_path* as a mono WAV at ``TARGET_SR`` to *output_path*.

    An existing output file is overwritten (``-y``).

    Raises:
        RuntimeError: ffmpeg exited with a non-zero status; the message carries
            ffmpeg's stderr.
    """
    completed = subprocess.run(
        [
            "ffmpeg",
            "-y",
            "-i", audio_path,
            "-ar", str(TARGET_SR),
            "-ac", "1",
            "-f", "wav",
            output_path,
        ],
        capture_output=True,
        text=True,
    )
    if completed.returncode == 0:
        return
    raise RuntimeError(f"ffmpeg 转换失败: {completed.stderr}")
47
48
def extract_chroma_feature(audio_path: str) -> np.ndarray:
    """Extract a flattened CENS chromagram feature vector from an audio file.

    Args:
        audio_path: Path of the audio file.

    Returns:
        A float32 array of shape ``(VECTOR_DIM,)``.

    Raises:
        FileNotFoundError: *audio_path* is not an existing file.
        RuntimeError: ffmpeg conversion failed.
    """
    if not os.path.isfile(audio_path):
        raise FileNotFoundError(f"音频文件不存在: {audio_path}")

    # ffmpeg writes the normalized WAV here; removed in the finally clause.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as scratch:
        tmp_wav = scratch.name

    try:
        _normalize_audio_ffmpeg(audio_path, tmp_wav)
        samples, _ = librosa.load(tmp_wav, sr=TARGET_SR, mono=True)

        # CENS chromagram, one row per pitch class.
        chroma = librosa.feature.chroma_cens(y=samples, sr=TARGET_SR)

        # Rotate the pitch class with the largest total energy to row 0.
        strongest = int(np.argmax(chroma.sum(axis=1)))
        if strongest:
            chroma = np.roll(chroma, -strongest, axis=0)

        # Resample along time to a fixed number of frames.
        if chroma.shape[1] != TARGET_FRAMES:
            chroma = resample(chroma, TARGET_FRAMES, axis=1)

        feature = np.ravel(chroma).astype(np.float32)

        assert feature.shape == (VECTOR_DIM,), (
            f"特征维度错误: 期望 {VECTOR_DIM}, 实际 {feature.shape}"
        )

        return feature
    finally:
        if os.path.exists(tmp_wav):
            os.remove(tmp_wav)
99
100
def extract_chroma_matrix(audio_path: str) -> np.ndarray:
    """Return the chromagram of *audio_path* as an un-flattened matrix.

    This is the vector from ``extract_chroma_feature`` reshaped to
    ``(12, TARGET_FRAMES)``, so it carries the same row rotation applied there.

    Returns:
        Array of shape ``(12, TARGET_FRAMES)``.
    """
    flat = extract_chroma_feature(audio_path)
    return flat.reshape(12, TARGET_FRAMES)
1 """作曲去重服务(入库 + 查询)。
2
3 查询流程:
4 1. Dejavu 指纹匹配(毫秒级,子序列匹配,支持 chorus_only / trim_intro)
5 - 命中(≥ 阈值)→ 直接返回 duplicate(短路)
6 2. 未命中 → Chromagram 12路 + DTW(百毫秒级)
7 - 返回结果
8 """
9
10 import logging
11 import os
12 from dataclasses import dataclass, field
13
14 import numpy as np
15 import psycopg
16 from scipy.spatial.distance import cdist
17
18 from .extractor import TARGET_FRAMES, extract_chroma_feature, extract_chroma_matrix
19 from .dejavu_fingerprinter import fingerprint_audio
20
21 logger = logging.getLogger(__name__)
22
23
def _env_bool(name: str, default: bool) -> bool:
    """Read environment variable *name* as a boolean.

    Returns *default* when the variable is unset. Otherwise the value is
    true only if, after stripping and lower-casing, it is one of
    ``1``, ``true``, ``yes``, ``y`` or ``on``.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    truthy = {"1", "true", "yes", "y", "on"}
    return raw.strip().lower() in truthy
29
30
def _env_int(name: str, default: int) -> int:
    """Read environment variable *name* as an integer.

    Returns *default* when the variable is unset, and also (after logging a
    warning) when its value cannot be parsed by ``int``.
    """
    raw = os.environ.get(name)
    if raw is not None:
        try:
            return int(raw)
        except ValueError:
            logger.warning("环境变量 %s=%r 不是整数,使用默认值 %d", name, raw, default)
    return default
40
41
def _env_float(name: str, default: float) -> float:
    """Read environment variable *name* as a float.

    Returns *default* when the variable is unset, and also (after logging a
    warning) when its value cannot be parsed by ``float``.
    """
    raw = os.environ.get(name)
    if raw is not None:
        try:
            return float(raw)
        except ValueError:
            logger.warning("环境变量 %s=%r 不是数字,使用默认值 %.4f", name, raw, default)
    return default
51
52
@dataclass
class CompositionCandidate:
    """One candidate returned by a de-duplication query."""

    # Identifier of the matched song.
    song_id: int
    # Score of the match; higher means more similar.
    similarity: float
    # Name of the stage that produced the candidate.
    source: str = "chromagram"
59
60
@dataclass
class _DejavuMatch:
    """Best fingerprint hit for a query after grouping collisions by offset delta."""

    # Identifier of the matched song.
    song_id: int
    # Largest number of hash collisions sharing a single offset delta.
    aligned_count: int
    # Number of hash collisions with this song over all offset deltas.
    total_collisions: int
67
68
@dataclass
class CompositionConfig:
    """Settings of the composition de-duplication service.

    The environment-backed defaults are evaluated once, when this class is
    defined; later changes to the environment do not affect them.
    """

    # PostgreSQL connection string.
    dsn: str = "postgresql:///lyric_dedup"
    # Value applied via ``SET statement_timeout`` (milliseconds).
    statement_timeout_ms: int = 30000
    # Number of cosine-recalled candidates that are re-ranked with DTW.
    dtw_rerank_top_k: int = 20
    # Minimum top-1 similarity for a query to count as a duplicate.
    duplicate_threshold: float = _env_float("COMPOSITION_DUPLICATE_THRESHOLD", 0.85)
    # Dejavu fingerprint matching: on/off switch and minimum aligned hash count.
    dejavu_enabled: bool = _env_bool("COMPOSITION_DEJAVU_ENABLED", True)
    dejavu_match_threshold: int = _env_int("COMPOSITION_DEJAVU_MATCH_THRESHOLD", 20)
79
80
@dataclass
class CompositionDedupService:
    """Composition de-duplication service: feature ingestion and similarity lookup."""

    config: CompositionConfig
    _logger: logging.Logger = field(default_factory=lambda: logger, repr=False)

    def ingest(self, song_id: int, audio_path: str) -> np.ndarray:
        """Extract the features of *audio_path* and store them for *song_id*.

        The chromagram vector is inserted into ``composition_feature`` (an
        existing conflicting row is left as is). When Dejavu is enabled the
        fingerprints are stored as well.

        Args:
            song_id: Song identifier.
            audio_path: Path of the audio file.

        Returns:
            The extracted chromagram feature vector.
        """
        feature = extract_chroma_feature(audio_path)
        self._logger.info("提取 Chromagram 特征完成: song_id=%s, audio=%s", song_id, audio_path)

        insert_sql = """
            INSERT INTO composition_feature (song_id, feature_vector)
            VALUES (%s, %s)
            ON CONFLICT DO NOTHING
            """
        with psycopg.connect(self.config.dsn) as conn:
            with conn.cursor() as cursor:
                cursor.execute(insert_sql, (song_id, feature.tolist()))
            conn.commit()

        self._logger.info("Chromagram 特征入库完成: song_id=%s", song_id)

        # Fingerprints are stored alongside the chromagram feature.
        if self.config.dejavu_enabled:
            self._dejavu_ingest(song_id, audio_path)

        return feature

    def _dejavu_ingest(self, song_id: int, audio_path: str) -> None:
        """Fingerprint *audio_path* and replace the rows of *song_id* in ``dejavu_fingerprints``.

        Nothing is written (only a warning is logged) when no fingerprint
        could be generated.
        """
        _sample_digest, fingerprints = fingerprint_audio(audio_path)
        if not fingerprints:
            self._logger.warning("Dejavu 指纹为空: song_id=%s, audio=%s", song_id, audio_path)
            return

        insert_sql = """
            INSERT INTO dejavu_fingerprints (song_id, hash, "offset")
            VALUES (%s, %s, %s)
            """
        with psycopg.connect(self.config.dsn) as conn:
            with conn.cursor() as cursor:
                # Remove fingerprints left from an earlier ingest of the same song.
                cursor.execute(
                    "DELETE FROM dejavu_fingerprints WHERE song_id = %s",
                    (song_id,),
                )
                rows = [(song_id, fp_hash, fp_offset) for fp_hash, fp_offset in fingerprints]
                cursor.executemany(insert_sql, rows)
            conn.commit()

        self._logger.info("Dejavu 指纹入库完成: song_id=%s, 指纹数=%d", song_id, len(fingerprints))

    def query(self, audio_path: str, top_k: int = 100) -> list[CompositionCandidate]:
        """Return candidates similar to *audio_path*, best first.

        When Dejavu is enabled and yields a match, a single candidate with
        ``similarity=1.0`` and ``source="dejavu"`` is returned immediately.
        Otherwise the chromagram recall + DTW re-ranking result is returned.
        """
        if self.config.dejavu_enabled and (hit := self._dejavu_query(audio_path)) is not None:
            self._logger.info(
                "Dejavu 命中: song_id=%s, aligned_count=%d, total_collisions=%d, decision=duplicate",
                hit.song_id,
                hit.aligned_count,
                hit.total_collisions,
            )
            return [CompositionCandidate(song_id=hit.song_id, similarity=1.0, source="dejavu")]

        # No fingerprint hit, or Dejavu disabled.
        return self._query_chroma(audio_path, top_k)

    def check(self, audio_path: str, top_k: int = 100) -> bool:
        """Return whether *audio_path* is judged a duplicate."""
        found = self.query(audio_path, top_k=top_k)
        return self.candidates_indicate_duplicate(found)

    def candidates_indicate_duplicate(self, candidates: list[CompositionCandidate]) -> bool:
        """Reduce a candidate list to the final duplicate verdict.

        Only the first candidate is considered: the verdict is whether its
        similarity reaches ``config.duplicate_threshold``. An empty list is
        never a duplicate.
        """
        best = candidates[0] if candidates else None
        return best is not None and best.similarity >= self.config.duplicate_threshold

    def _query_chroma(self, audio_path: str, top_k: int = 100) -> list[CompositionCandidate]:
        """Chromagram lookup: 12 pitch rotations of cosine recall, then DTW re-ranking.

        Returns the DTW-re-ranked candidates (sorted by DTW similarity)
        followed by the remaining recalled candidates with their cosine score.
        """
        chroma = extract_chroma_matrix(audio_path)
        self._logger.info("提取 Chromagram 查询特征完成: audio=%s", audio_path)

        # One query vector per cyclic rotation of the 12 chroma rows.
        shift_vecs = []
        for shift in range(12):
            rotated = np.roll(chroma, -shift, axis=0)
            shift_vecs.append(rotated.flatten().astype(np.float32).tolist())

        # The 12 vectors are passed as a VALUES list; each drives one nearest-neighbour
        # sub-query, and the best similarity per song is kept.
        values_clause = ", ".join(f"({i}, %s::vector)" for i in range(12))
        sql = f"""
            WITH shifts(shift_id, vec) AS (
                VALUES {values_clause}
            ),
            candidates AS (
                SELECT
                    cf.song_id,
                    1 - (cf.feature_vector <=> s.vec) AS sim
                FROM shifts s
                CROSS JOIN LATERAL (
                    SELECT song_id, feature_vector
                    FROM composition_feature
                    ORDER BY feature_vector <=> s.vec
                    LIMIT %s
                ) cf
            )
            SELECT song_id, MAX(sim) AS similarity
            FROM candidates
            GROUP BY song_id
            ORDER BY similarity DESC
            LIMIT %s
            """
        rerank_limit = self.config.dtw_rerank_top_k
        with psycopg.connect(self.config.dsn) as conn:
            with conn.cursor() as cursor:
                cursor.execute(
                    f"SET statement_timeout = {int(self.config.statement_timeout_ms)}"
                )
                cursor.execute(sql, (*shift_vecs, top_k, top_k))
                best: dict[int, float] = {
                    int(sid): float(score) for sid, score in cursor.fetchall()
                }

            # Fetch the stored vectors of the best-scoring songs for DTW re-ranking.
            top = sorted(best.items(), key=lambda pair: pair[1], reverse=True)
            rerank_ids = [sid for sid, _ in top[:rerank_limit]]

            with conn.cursor() as cursor:
                cursor.execute(
                    "SELECT song_id, feature_vector::float4[] FROM composition_feature WHERE song_id = ANY(%s)",
                    (rerank_ids,),
                )
                db_rows = cursor.fetchall()

        reranked = []
        for sid, stored in db_rows:
            stored_chroma = np.array(stored, dtype=np.float32).reshape(12, TARGET_FRAMES)
            score = _best_shifted_dtw_similarity(chroma, stored_chroma)
            reranked.append(CompositionCandidate(song_id=int(sid), similarity=score))
        reranked.sort(key=lambda cand: cand.similarity, reverse=True)

        already_ranked = {cand.song_id for cand in reranked}
        rest = [
            CompositionCandidate(song_id=sid, similarity=score)
            for sid, score in top[rerank_limit:]
            if sid not in already_ranked
        ]

        result = reranked + rest
        top_summary = ", ".join(
            f"{cand.song_id}:{cand.similarity:.4f}" for cand in result[:5]
        )
        self._logger.info(
            "Chromagram 查询完成: 返回 %d 个候选, top=%s",
            len(result),
            top_summary or "[]",
        )
        return result

    def _dejavu_query(self, audio_path: str) -> _DejavuMatch | None:
        """Look up the song whose fingerprints align best with *audio_path*.

        Hash collisions are grouped per song and per
        ``db_offset - query_offset``; a song's score is the size of its largest
        group, so collisions only count together when they agree on one time
        shift. Raw collision totals are used as a tie-breaker only.

        Returns:
            The best match when its aligned count reaches
            ``config.dejavu_match_threshold``, otherwise ``None``.
        """
        _sample_digest, fingerprints = fingerprint_audio(audio_path)
        if not fingerprints:
            return None

        hashes = []
        offsets = []
        for fp_hash, fp_offset in fingerprints:
            hashes.append(fp_hash)
            offsets.append(int(fp_offset))

        lookup_sql = """
            WITH query_fp(hash, query_offset) AS (
                SELECT *
                FROM unnest(%s::bytea[], %s::int[])
            ),
            aligned AS (
                SELECT
                    db.song_id,
                    db."offset" - query_fp.query_offset AS offset_delta,
                    COUNT(*) AS aligned_count
                FROM query_fp
                JOIN dejavu_fingerprints db
                    ON db.hash = query_fp.hash
                GROUP BY db.song_id, offset_delta
            )
            SELECT
                song_id,
                MAX(aligned_count) AS best_aligned_count,
                SUM(aligned_count) AS total_collisions
            FROM aligned
            GROUP BY song_id
            ORDER BY best_aligned_count DESC, total_collisions DESC
            LIMIT 1
            """
        with psycopg.connect(self.config.dsn) as conn:
            with conn.cursor() as cursor:
                cursor.execute(lookup_sql, (hashes, offsets))
                row = cursor.fetchone()

        if row is None:
            return None
        sid, aligned_raw, collisions_raw = row
        aligned_count = int(aligned_raw)
        if aligned_count < self.config.dejavu_match_threshold:
            return None
        return _DejavuMatch(
            song_id=int(sid),
            aligned_count=aligned_count,
            total_collisions=int(collisions_raw),
        )
321
322
def _dtw_similarity(query: np.ndarray, candidate: np.ndarray) -> float:
    """Return a DTW-based similarity in (0, 1] between two feature-by-frame matrices.

    Columns are frames. The accumulated cost of the cheapest warping path over
    the frame-wise Euclidean distances is divided by the total number of
    frames of both inputs, then mapped through ``1 / (1 + d)``.
    """
    cost = cdist(query.T, candidate.T, metric="euclidean")
    rows, cols = cost.shape

    acc = np.full((rows, cols), np.inf)
    # Along the first column / first row only one predecessor exists.
    acc[:, 0] = np.cumsum(cost[:, 0])
    acc[0, :] = np.cumsum(cost[0, :])

    for r in range(1, rows):
        above = acc[r - 1]
        current = acc[r]
        step = cost[r]
        for c in range(1, cols):
            current[c] = step[c] + min(above[c], current[c - 1], above[c - 1])

    normalized = acc[rows - 1, cols - 1] / (rows + cols)
    # Smaller distance -> similarity closer to 1.
    return float(1.0 / (1.0 + normalized))
340
341
def _best_shifted_dtw_similarity(query: np.ndarray, candidate: np.ndarray) -> float:
    """Return the highest DTW similarity over the 12 cyclic row rotations of *query*."""
    scores = []
    for shift in range(12):
        rotated = np.roll(query, -shift, axis=0)
        scores.append(_dtw_similarity(rotated, candidate))
    return max(scores)
1 """Cosine 相似度计算与去重判定。"""
2
3 from enum import Enum
4
5 import numpy as np
6
# Similarity at or above which a pair is classified as a duplicate.
DUPLICATE_THRESHOLD = 0.95
# Similarity at or above which a pair is classified as a suspected duplicate.
SUSPECTED_THRESHOLD = 0.85
9
10
class SimilarityDecision(Enum):
    """Outcome of classifying a similarity score."""

    # Listed from the strongest to the weakest verdict.
    DUPLICATE = "duplicate"
    SUSPECTED = "suspected"
    NEW = "new"
15
16
class CompositionSimilarity:
    """Cosine similarity and threshold-based classification helpers."""

    @staticmethod
    def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
        """Return the cosine similarity of *a* and *b*; 0.0 if either has zero norm."""
        len_a = np.linalg.norm(a)
        len_b = np.linalg.norm(b)
        if len_a == 0.0 or len_b == 0.0:
            # Avoid dividing by zero for all-zero vectors.
            return 0.0
        return float(np.dot(a, b) / (len_a * len_b))

    @staticmethod
    def classify_similarity(similarity: float) -> SimilarityDecision:
        """Map *similarity* to a decision using the module thresholds."""
        # Checked from the strictest threshold downwards.
        ladder = (
            (DUPLICATE_THRESHOLD, SimilarityDecision.DUPLICATE),
            (SUSPECTED_THRESHOLD, SimilarityDecision.SUSPECTED),
        )
        for floor, verdict in ladder:
            if similarity >= floor:
                return verdict
        return SimilarityDecision.NEW

    @staticmethod
    def compare(a: np.ndarray, b: np.ndarray) -> tuple[float, SimilarityDecision]:
        """Return ``(cosine similarity, decision)`` for the pair *a*, *b*."""
        score = CompositionSimilarity.cosine_similarity(a, b)
        verdict = CompositionSimilarity.classify_similarity(score)
        return score, verdict
1 """曲去重评估脚本。
2
3 对 queries.csv 中每条查询音频调用 CompositionDedupService.query(),
4 按最终接口语义用 top1 分数阈值输出 predicted_duplicate true/false。
5 expected_song_id 的 top-k/top1 命中只作为诊断字段。
6 输出 precision/recall/F1。
7
8 用法:
9 python scripts/evaluate_composition.py \
10 --dsn "postgresql:///lyric_dedup" \
11 --queries composition_dedup/composition_testset4/queries.csv \
12 --out composition_dedup/composition_eval/composition_eval_result_v3.csv
13 """
14
15 import argparse
16 import csv
17 import json
18 import logging
19 import sys
20 from pathlib import Path
21
22 sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
23
24 from composition_dedup.service import CompositionConfig, CompositionDedupService
25
26 logger = logging.getLogger(__name__)
27
28
29 def _parse_csv_filter(value: str | None) -> set[str] | None:
30 if value is None:
31 return None
32 items = {item.strip() for item in value.split(",") if item.strip()}
33 return items or None
34
35
def _song_id_from_audio_path(audio_path: str) -> str:
    """Return the part of the file name (without extension) before its first underscore."""
    stem = Path(audio_path).stem
    prefix, _, _ = stem.partition("_")
    return prefix
39
40
def main() -> None:
    """Evaluate ``CompositionDedupService.query`` against a labelled query CSV.

    Reads the query CSV, optionally filters it by variant / sample class /
    expected label, runs every query, writes one result row per query to the
    ``--out`` CSV, and writes (and prints) a JSON summary with
    accuracy/precision/recall/F1 next to it. A query that raises is recorded
    as predicted "not duplicate" with the error text.
    """
    logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")

    parser = argparse.ArgumentParser()
    parser.add_argument("--dsn", required=True)
    parser.add_argument("--queries", required=True, help="queries.csv 路径")
    parser.add_argument("--out", required=True, help="逐条结果输出 CSV")
    parser.add_argument("--top-k", type=int, default=10)
    parser.add_argument("--duplicate-threshold", type=float, help="覆盖 COMPOSITION_DUPLICATE_THRESHOLD")
    parser.add_argument("--variants", help="只评测指定 variant,逗号分隔,如 pitch_up1,pitch_down1")
    parser.add_argument("--sample-classes", help="只评测指定 sample_class,逗号分隔,如 dsp,negative")
    parser.add_argument("--expected", choices=["duplicate", "not_duplicate"], help="只评测指定 expected 类型")
    args = parser.parse_args()

    config = CompositionConfig(dsn=args.dsn)
    if args.duplicate_threshold is not None:
        config.duplicate_threshold = args.duplicate_threshold
    service = CompositionDedupService(config=config)

    with open(args.queries, newline="", encoding="utf-8") as f:
        rows = [*csv.DictReader(f)]

    variant_filter = _parse_csv_filter(args.variants)
    sample_class_filter = _parse_csv_filter(args.sample_classes)
    original_count = len(rows)

    def _selected(r: dict) -> bool:
        # Filters are applied in the order variant -> sample_class -> expected.
        if variant_filter is not None and (r.get("variant") or "") not in variant_filter:
            return False
        if sample_class_filter is not None and (r.get("sample_class") or "") not in sample_class_filter:
            return False
        if args.expected is not None and r["expected"].strip().lower() != args.expected:
            return False
        return True

    rows = [r for r in rows if _selected(r)]

    logger.info(
        "评测样本过滤: 原始 %d 条,保留 %d 条 (variants=%s, sample_classes=%s, expected=%s)",
        original_count,
        len(rows),
        ",".join(sorted(variant_filter)) if variant_filter else "ALL",
        ",".join(sorted(sample_class_filter)) if sample_class_filter else "ALL",
        args.expected or "ALL",
    )

    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    result_rows = []
    for i, row in enumerate(rows, 1):
        audio_path = row["audio_path"]
        audio_song_id = _song_id_from_audio_path(audio_path)
        query_song_id = row.get("song_id") or audio_song_id
        expected_song_id = str(row["expected_song_id"])
        expected_dup = row["expected"].strip().lower() == "duplicate"
        # A negative sample whose own song is the expected song cannot be a valid negative.
        invalid_negative_pair = (not expected_dup) and audio_song_id == expected_song_id

        # Columns that are the same whether or not the query succeeds.
        record = {
            "query_song_id": query_song_id,
            "audio_song_id": audio_song_id,
            "audio_path": audio_path,
            "variant": row.get("variant", ""),
            "sample_class": row.get("sample_class", ""),
            "expected_song_id": expected_song_id,
            "expected": row["expected"],
            "invalid_negative_pair": invalid_negative_pair,
            "expected_duplicate": expected_dup,
        }

        try:
            candidates = service.query(audio_path, top_k=args.top_k)
        except Exception as e:
            logger.error("[%d/%d] 查询失败: %s, %s", i, len(rows), audio_path, e)
            record.update({
                "top1_song_id": "",
                "top1_similarity": "",
                "top1_source": "",
                "top1_hit": False,
                "topk_hit": False,
                "expected_rank": "",
                "expected_similarity": "",
                "invalid_boolean_sample": False,
                "predicted_duplicate": False,
                "correct": not expected_dup,  # a failed query counts as not_duplicate
                "error": str(e),
            })
            result_rows.append(record)
            continue

        top1 = candidates[0] if candidates else None
        if top1:
            top1_song_id = str(top1.song_id)
            top1_sim = round(top1.similarity, 4)
            top1_source = top1.source
        else:
            top1_song_id = top1_sim = top1_source = ""

        # Diagnostics only: does the expected song appear at rank 1 / anywhere in the list?
        top1_hit = bool(expected_song_id) and top1_song_id == expected_song_id
        expected_rank = ""
        expected_similarity = ""
        if expected_song_id:
            for rank, candidate in enumerate(candidates, 1):
                if str(candidate.song_id) == expected_song_id:
                    expected_rank = rank
                    expected_similarity = round(candidate.similarity, 4)
                    break
        topk_hit = expected_rank != ""

        # The verdict itself is the plain duplicate true/false of the service.
        predicted_dup = service.candidates_indicate_duplicate(candidates)
        correct = expected_dup == predicted_dup
        invalid_boolean_sample = (
            (not expected_dup)
            and bool(top1)
            and top1_song_id == audio_song_id
            and predicted_dup
        )

        record.update({
            "top1_song_id": top1_song_id,
            "top1_similarity": top1_sim,
            "top1_source": top1_source,
            "top1_hit": top1_hit,
            "topk_hit": topk_hit,
            "expected_rank": expected_rank,
            "expected_similarity": expected_similarity,
            "invalid_boolean_sample": invalid_boolean_sample,
            "predicted_duplicate": predicted_dup,
            "correct": correct,
            "error": "",
        })
        result_rows.append(record)

        logger.info(
            "[%d/%d] variant=%s source=%s expected=%s predicted_duplicate=%s threshold=%.4f expected_song_id=%s top1=%s sim=%s top1_hit=%s topk_hit=%s expected_rank=%s expected_sim=%s correct=%s",
            i,
            len(rows),
            row.get("variant", ""),
            top1_source or "-",
            row["expected"],
            predicted_dup,
            service.config.duplicate_threshold,
            expected_song_id,
            top1_song_id or "-",
            top1_sim if top1_sim != "" else "-",
            top1_hit,
            topk_hit,
            expected_rank if expected_rank != "" else "-",
            expected_similarity if expected_similarity != "" else "-",
            correct,
        )

        if i % 10 == 0 or i == len(rows):
            logger.info("[%d/%d]", i, len(rows))

    # Per-query results.
    fieldnames = [
        "query_song_id", "audio_song_id", "audio_path", "variant", "sample_class",
        "expected_song_id", "expected", "top1_song_id", "top1_similarity", "top1_source",
        "top1_hit", "topk_hit", "expected_rank", "expected_similarity",
        "invalid_negative_pair", "invalid_boolean_sample",
        "expected_duplicate", "predicted_duplicate", "correct", "error",
    ]
    with out_path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(result_rows)

    def _metrics(rows: list[dict]) -> dict:
        # Confusion-matrix counts and the ratios derived from them.
        tp = fp = tn = fn = 0
        for r in rows:
            if r["expected_duplicate"]:
                if r["predicted_duplicate"]:
                    tp += 1
                else:
                    fn += 1
            elif r["predicted_duplicate"]:
                fp += 1
            else:
                tn += 1
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        accuracy = (tp + tn) / len(rows) if rows else 0.0
        return {
            "total": len(rows),
            "accuracy": round(accuracy, 4),
            "precision": round(precision, 4),
            "recall": round(recall, 4),
            "f1": round(f1, 4),
            "tp": tp,
            "fp": fp,
            "tn": tn,
            "fn": fn,
        }

    def _accuracy_by(column: str) -> dict:
        # Accuracy and sample count per value of *column*; empty values group as "unknown".
        tally: dict[str, list[int]] = {}
        for r in result_rows:
            bucket = tally.setdefault(r.get(column) or "unknown", [0, 0])
            bucket[1] += 1
            if r["correct"]:
                bucket[0] += 1
        return {
            key: {"accuracy": round(hits / seen, 4), "total": seen}
            for key, (hits, seen) in sorted(tally.items())
        }

    metrics = _metrics(result_rows)
    valid_rows = [
        r for r in result_rows
        if not (r["invalid_negative_pair"] or r["invalid_boolean_sample"])
    ]
    valid_metrics = _metrics(valid_rows)

    summary = {
        "total": len(result_rows),
        "filters": {
            "variants": sorted(variant_filter) if variant_filter else None,
            "sample_classes": sorted(sample_class_filter) if sample_class_filter else None,
            "expected": args.expected,
            "original_total": original_count,
        },
        "duplicate_threshold": service.config.duplicate_threshold,
        "invalid_negative_pairs": sum(1 for r in result_rows if r["invalid_negative_pair"]),
        "invalid_boolean_samples": sum(1 for r in result_rows if r["invalid_boolean_sample"]),
        "accuracy": metrics["accuracy"],
        "precision": metrics["precision"],
        "recall": metrics["recall"],
        "f1": metrics["f1"],
        "tp": metrics["tp"], "fp": metrics["fp"], "tn": metrics["tn"], "fn": metrics["fn"],
        "valid_only": valid_metrics,
        "out": str(out_path),
    }

    # Per-variant and per-sample-class pass rates.
    summary["by_variant"] = _accuracy_by("variant")
    summary["by_sample_class"] = _accuracy_by("sample_class")

    rendered = json.dumps(summary, ensure_ascii=False, indent=2)
    summary_path = out_path.with_suffix(".summary.json")
    summary_path.write_text(rendered, encoding="utf-8")
    print(rendered)
282
283
284 if __name__ == "__main__":
285 main()
1 """生成曲去重评估测试集。
2
3 从音频目录随机抽取若干首参照歌入库,对每首用 ffmpeg 生成多个变换版本,
4 覆盖曲去重测试样本类型.md 中第一类(数字信号变换)和第三类(困难正样本)的可合成部分。
5 负样本从未入库的 holdout 歌曲生成,以匹配最终接口 duplicate true/false 语义。
6
7 用法:
8 python scripts/generate_composition_testset.py \
9 --audio-dir /Volumes/移动硬盘/lyric_audio_type11 \
10 --negative-audio-dir /Volumes/移动硬盘/composition_test \
11 --out-dir composition_dedup/composition_testset \
12 --num-songs 80 \
13 --num-negative-songs 40 \
14 --negative-variants \
15 --seed 123
16
17 输出:
18 reference.csv — 参照曲(原始文件),需提前入库
19 queries.csv — 查询曲,带 variant 和 expected 标注
20 """
21
22 import argparse
23 import csv
24 import logging
25 import random
26 import subprocess
27 import sys
28 from pathlib import Path
29
30 try:
31 from tqdm import tqdm
32 except ImportError:
33 tqdm = None
34
35
def _tqdm(iterable, **kwargs):
    """Wrap *iterable* in a progress display.

    Delegates to ``tqdm`` when it was importable. Otherwise returns a minimal
    iterable that prints ``"<desc>: <i>/<total>"`` on one line while iterating,
    provided a non-zero total is known (from the ``total`` keyword or
    ``len(iterable)``).
    """
    if tqdm is not None:
        return tqdm(iterable, **kwargs)

    total = kwargs.get("total", None)
    if not total:
        total = len(iterable) if hasattr(iterable, "__len__") else None
    desc = kwargs.get("desc", "")

    class _PlainProgress:
        def __init__(self):
            self._seen = 0

        def __iter__(self):
            for element in iterable:
                self._seen += 1
                if total:
                    # "\r" rewrites the same console line on every step.
                    print(f"\r{desc}: {self._seen}/{total}", end="", flush=True)
                yield element
            if total:
                print()

    return _PlainProgress()
53
54 logger = logging.getLogger(__name__)
55
# --------------------------------------------------------------------------
# Class 1: digital signal transformations
# --------------------------------------------------------------------------
# Each entry is (variant_name, spec). A spec is normally an ffmpeg "-af" filter
# chain. Specs containing "acodec" are not real filters: _ffmpeg_variant()
# recognises them and performs an MP3 encode/decode round trip, reading the
# bitrate from the text after "b:a=".
#
# The pitch variants rely on asetrate, which relabels the sample rate without
# touching the samples, so it changes speed as well as pitch. The intended
# ratio only holds if the stream entering asetrate is 22050 Hz, hence the
# leading aresample=22050 (a no-op for inputs that are already 22050 Hz).
DSP_VARIANTS: list[tuple[str, str]] = [
    # Pitch shift by +/-1 and +/-2 semitones (ratio = 2 ** (n / 12)).
    ("pitch_up1", "aresample=22050,asetrate=22050*1.0595,aresample=22050"),    # +1 semitone
    ("pitch_up2", "aresample=22050,asetrate=22050*1.1225,aresample=22050"),    # +2 semitones
    ("pitch_down1", "aresample=22050,asetrate=22050*0.9439,aresample=22050"),  # -1 semitone
    ("pitch_down2", "aresample=22050,asetrate=22050*0.8909,aresample=22050"),  # -2 semitones
    # Tempo shift.
    ("tempo_slow", "atempo=0.90"),    # 0.9x
    ("tempo_fast", "atempo=1.10"),    # 1.1x
    ("tempo_faster", "atempo=1.20"),  # 1.2x
    # Equalisation.
    ("lowpass", "lowpass=f=4000"),                               # low-pass
    ("highpass", "highpass=f=800"),                              # high-pass
    ("eq_mid", "equalizer=f=1000:width_type=o:width=2:g=-6"),    # mid-band cut
    # Lossy codec round trip (encode to mp3, decode back to wav).
    ("codec_320k", "acodec=libmp3lame,b:a=320k"),
    ("codec_128k", "acodec=libmp3lame,b:a=128k"),
]
77
78 # --------------------------------------------------------------------------
79 # 第三类:困难正样本(可合成部分)
80 # --------------------------------------------------------------------------
81 HARD_POSITIVE_VARIANTS: list[tuple[str, str]] = [
82 # 前奏删减:从 20% 处开始截取(模拟删前奏版本)
83 ("trim_intro", None), # 特殊处理,用 -ss 参数
84 # 只保留副歌:截取中间 40%(模拟短视频截段)
85 ("chorus_only", None), # 特殊处理,用 -ss + -t 参数
86 # Pitch + Tempo 叠加(模拟 Live 版同时有调整)
87 ("pitch_up1_tempo_slow", "asetrate=22050*1.0595,aresample=22050,atempo=0.92"),
88 ]
89
90 # 负样本变体只使用相对温和的处理,避免把负样本评估变成极端音质测试。
91 NEGATIVE_VARIANTS: list[tuple[str, str | None]] = [
92 ("negative_lowpass", "lowpass=f=4000"),
93 ("negative_codec_128k", "acodec=libmp3lame,b:a=128k"),
94 ]
95
96
def _ffmpeg_variant(src: Path, dst: Path, af: str) -> bool:
    """Render *src* to *dst* (22050 Hz, mono) through one variant spec.

    Args:
        src: Source audio file.
        dst: Output path; overwritten if it exists (ffmpeg ``-y``).
        af: Either an ffmpeg ``-af`` filter chain, or a codec marker that
            contains ``acodec`` and a ``b:a=<bitrate>`` part.

    Returns:
        True when every ffmpeg invocation succeeded, False otherwise.
    """
    # A codec marker means a two-step round trip: encode to mp3, decode to wav.
    if "acodec" in af:
        import tempfile

        bitrate = af.split("b:a=")[1]
        # Only the name is needed; ffmpeg rewrites the file itself.
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp:
            tmp_mp3 = Path(tmp.name)
        try:
            encoded = _run_ffmpeg([
                "ffmpeg", "-y", "-i", str(src),
                "-ar", "22050", "-ac", "1",
                "-codec:a", "libmp3lame", "-b:a", bitrate,
                str(tmp_mp3),
            ])
            if not encoded:
                return False
            return _run_ffmpeg([
                "ffmpeg", "-y", "-i", str(tmp_mp3),
                "-ar", "22050", "-ac", "1",
                str(dst),
            ])
        finally:
            # Remove the intermediate mp3 on every path, including a failed encode.
            tmp_mp3.unlink(missing_ok=True)

    return _run_ffmpeg([
        "ffmpeg", "-y", "-i", str(src),
        "-af", af,
        "-ar", "22050", "-ac", "1",
        str(dst),
    ])
127
128
def _ffmpeg_trim(src: Path, dst: Path, start_ratio: float, duration_ratio: float) -> bool:
    """Cut a segment of *src* whose position is given relative to its length.

    The duration is probed first; *start_ratio* and *duration_ratio* are
    multiplied by it to obtain the ``-ss`` and ``-t`` values in seconds.
    Returns False when the duration cannot be determined or ffmpeg fails.
    """
    total_seconds = _probe_duration(src)
    if total_seconds is None:
        return False

    offset = f"{total_seconds * start_ratio:.3f}"
    length = f"{total_seconds * duration_ratio:.3f}"
    command = ["ffmpeg", "-y", "-i", str(src)]
    command += ["-ss", offset, "-t", length]
    command += ["-ar", "22050", "-ac", "1"]
    command.append(str(dst))
    return _run_ffmpeg(command)
142
143
144 def _probe_duration(src: Path) -> float | None:
145 result = subprocess.run(
146 ["ffprobe", "-v", "error", "-show_entries", "format=duration",
147 "-of", "default=noprint_wrappers=1:nokey=1", str(src)],
148 capture_output=True, text=True,
149 )
150 try:
151 return float(result.stdout.strip())
152 except ValueError:
153 return None
154
155
156 def _run_ffmpeg(cmd: list[str]) -> bool:
157 result = subprocess.run(cmd, capture_output=True)
158 return result.returncode == 0
159
160
161 def _song_id(path: Path) -> str:
162 return path.stem.split("_")[0]
163
164
165 def _discover_wavs(audio_dir: Path) -> list[Path]:
166 return [
167 f for f in sorted(audio_dir.rglob("*.wav"))
168 if f.is_file() and not f.name.startswith("._")
169 ]
170
171
def _query_row(
    song_id: str,
    audio_path: Path,
    variant: str,
    sample_class: str,
    duplicate_of: str | None,
) -> dict[str, str]:
    """Build one queries.csv row; *duplicate_of* is None for negative samples."""
    is_positive = duplicate_of is not None
    return {
        "song_id": song_id,
        "audio_path": str(audio_path),
        "variant": variant,
        "sample_class": sample_class,
        "expected_song_id": duplicate_of if is_positive else "",
        "expected": "duplicate" if is_positive else "not_duplicate",
    }


def main() -> None:
    """Generate reference.csv, queries.csv and the variant audio files.

    Steps:
      1. Discover wav files in the reference and negative directories.
      2. Drop every reference candidate whose song_id also occurs in the
         negative directory, so the two sets cannot overlap.
      3. Sample reference and negative songs with the given seed.
      4. Render DSP and hard-positive variants for each reference song, and
         optionally mild variants for each negative song.
      5. Write both CSV files and log a per-class summary.

    Exits with status 1 when either directory holds fewer wav files than
    requested. Variants whose ffmpeg step fails are logged and skipped.
    """
    logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")

    parser = argparse.ArgumentParser()
    parser.add_argument("--audio-dir", required=True)
    parser.add_argument(
        "--negative-audio-dir",
        default="/Volumes/移动硬盘/lyric_audio_type11",
        # The exclusion works on the reference side (see step 2 above); the
        # help text now says so instead of claiming the opposite direction.
        help="负样本来源目录;该目录中出现的 song_id 会从 --audio-dir 的参照候选中排除",
    )
    parser.add_argument("--out-dir", required=True)
    parser.add_argument("--num-songs", type=int, default=20, help="抽取歌曲数量")
    parser.add_argument("--num-negative-songs", type=int, default=20, help="抽取未入库负样本歌曲数量")
    parser.add_argument(
        "--negative-variants",
        action="store_true",
        help="为负样本额外生成 codec/lowpass 变体",
    )
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    audio_dir = Path(args.audio_dir)
    negative_audio_dir = Path(args.negative_audio_dir)
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    variants_dir = out_dir / "variants"
    variants_dir.mkdir(exist_ok=True)

    all_wavs = _discover_wavs(audio_dir)
    negative_wavs = _discover_wavs(negative_audio_dir)

    if len(negative_wavs) < args.num_negative_songs:
        logger.error(
            "负样本目录下只有 %d 个 wav,少于 --num-negative-songs = %d",
            len(negative_wavs),
            args.num_negative_songs,
        )
        sys.exit(1)

    # Remove from the reference pool every song_id present in the negative
    # directory, so a reference song can never double as a negative sample.
    negative_song_ids = {_song_id(wav) for wav in negative_wavs}
    all_wavs = [wav for wav in all_wavs if _song_id(wav) not in negative_song_ids]

    if len(all_wavs) < args.num_songs:
        logger.error(
            "参照目录排除负样本 song_id 后只有 %d 个 wav,少于 --num-songs = %d",
            len(all_wavs),
            args.num_songs,
        )
        sys.exit(1)

    random.seed(args.seed)
    selected = random.sample(all_wavs, args.num_songs)
    negative_selected = random.sample(negative_wavs, args.num_negative_songs)
    logger.info(
        "已抽取 %d 首参照歌,%d 首未入库负样本歌(负样本来源: %s,已排除 %d 个负样本 song_id)",
        len(selected),
        len(negative_selected),
        negative_audio_dir,
        len(negative_song_ids),
    )

    ref_rows = []
    query_rows = []

    # Hard-positive variants that are produced by trimming instead of a filter:
    # name -> (start_ratio, duration_ratio).
    trim_windows = {
        "trim_intro": (0.20, 0.80),
        "chorus_only": (0.30, 0.40),
    }

    for wav in _tqdm(selected, desc="生成正样本变体", total=len(selected)):
        song_id = _song_id(wav)

        ref_rows.append({
            "song_id": song_id,
            "audio_path": str(wav),
            "variant": "original",
        })

        # Class 1: DSP transformations.
        for variant_name, af in DSP_VARIANTS:
            dst = variants_dir / f"{song_id}_{variant_name}.wav"
            if not _ffmpeg_variant(wav, dst, af):
                logger.warning("DSP 变换失败,跳过: %s %s", wav.name, variant_name)
                continue
            query_rows.append(_query_row(song_id, dst, variant_name, "dsp", song_id))

        # Class 3: hard positives.
        for variant_name, af in HARD_POSITIVE_VARIANTS:
            dst = variants_dir / f"{song_id}_{variant_name}.wav"
            window = trim_windows.get(variant_name)
            if window is not None:
                ok = _ffmpeg_trim(wav, dst, start_ratio=window[0], duration_ratio=window[1])
            else:
                ok = _ffmpeg_variant(wav, dst, af)
            if not ok:
                logger.warning("困难正样本生成失败,跳过: %s %s", wav.name, variant_name)
                continue
            query_rows.append(_query_row(song_id, dst, variant_name, "hard_positive", song_id))

    # Negative samples for the boolean interface: the query audio must not be
    # part of the reference.csv set. expected_song_id stays empty because there
    # is no target duplicate; evaluation only looks at duplicate true/false.
    for wav in _tqdm(negative_selected, desc="生成负样本变体", total=len(negative_selected)):
        song_id = _song_id(wav)
        query_rows.append(_query_row(song_id, wav, "negative_original", "negative", None))

        if not args.negative_variants:
            continue

        for variant_name, af in NEGATIVE_VARIANTS:
            dst = variants_dir / f"{song_id}_{variant_name}.wav"
            if not _ffmpeg_variant(wav, dst, af):
                logger.warning("负样本变换失败,跳过: %s %s", wav.name, variant_name)
                continue
            query_rows.append(_query_row(song_id, dst, variant_name, "negative", None))

    ref_path = out_dir / "reference.csv"
    query_path = out_dir / "queries.csv"

    fieldnames = ["song_id", "audio_path", "variant", "sample_class", "expected_song_id", "expected"]
    with ref_path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=["song_id", "audio_path", "variant"])
        writer.writeheader()
        writer.writerows(ref_rows)

    with query_path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(query_rows)

    pos = sum(1 for r in query_rows if r["expected"] == "duplicate")
    neg = sum(1 for r in query_rows if r["expected"] == "not_duplicate")
    logger.info("参照集: %s (%d 条)", ref_path, len(ref_rows))
    logger.info("查询集: %s (%d 条,正样本 %d,负样本 %d)", query_path, len(query_rows), pos, neg)

    # Per sample_class counts.
    from collections import Counter
    by_class = Counter(r["sample_class"] for r in query_rows)
    for cls, cnt in sorted(by_class.items()):
        logger.info(" %-20s %d 条", cls, cnt)
338
339
340 if __name__ == "__main__":
341 main()
1 """批量导入音频文件到 composition_feature 表。
2
3 用法:
4 python scripts/import_audio_composition.py \
5 --dsn "postgresql:///lyric_dedup" \
6 --audio-dir /Volumes/移动硬盘/composition_test \
7 --ext .wav
8
9 支持通过 --file-list 指定一个包含音频路径的文本文件(每行一个路径)。
10 """
11
12 import argparse
13 import logging
14 import sys
15 from pathlib import Path
16
17 sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
18
19 from tqdm import tqdm
20
21 from composition_dedup.service import CompositionConfig, CompositionDedupService
22
23 logger = logging.getLogger(__name__)
24
25 SUPPORTED_EXTENSIONS = {".mp3", ".wav", ".flac", ".ogg", ".m4a", ".aac", ".wma"}
26
27
28 def discover_audio_files(audio_dir: str | None, file_list: str | None, ext: str) -> list[tuple[str, str]]:
29 """发现音频文件,返回 [(song_id, audio_path), ...] 列表。
30
31 优先使用 --file-list,否则扫描 --audio-dir 目录。
32 song_id 使用文件名的数字部分或路径的哈希值。
33 """
34 results = []
35
36 if file_list:
37 with open(file_list, "r", encoding="utf-8") as f:
38 for line in f:
39 path = line.strip()
40 if not path:
41 continue
42 song_id = _extract_song_id(path)
43 results.append((song_id, path))
44 elif audio_dir:
45 audio_dir_path = Path(audio_dir)
46 for audio_file in sorted(audio_dir_path.rglob(f"*{ext}")):
47 if audio_file.is_file() and not audio_file.name.startswith("._"):
48 song_id = _extract_song_id(str(audio_file))
49 results.append((song_id, str(audio_file)))
50 else:
51 print("错误: 请指定 --audio-dir 或 --file-list")
52 sys.exit(1)
53
54 return results
55
56
57 def _extract_song_id(path: str) -> str:
58 """从路径中提取 song_id。
59 优先取文件名第一段(下划线前),若为纯数字则使用,否则用路径哈希。
60 """
61 name = Path(path).stem
62 prefix = name.split("_")[0]
63 if prefix.isdigit():
64 return prefix
65 import hashlib
66 return str(int(hashlib.md5(path.encode()).hexdigest()[:8], 16))
67
68
def main() -> None:
    """Import audio files into the composition tables via the dedup service.

    Parses the command line, optionally truncates the two tables, discovers
    the audio files, then ingests them one by one. Each failure is logged and
    counted; the run continues with the next file.
    """
    logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")

    parser = argparse.ArgumentParser(description="批量导入音频文件到 composition_feature 表")
    parser.add_argument("--dsn", required=True, help="PostgreSQL DSN 连接串")
    parser.add_argument("--audio-dir", help="音频文件目录")
    parser.add_argument("--file-list", help="音频文件路径列表文件")
    parser.add_argument("--ext", default=".wav", help="音频文件扩展名(默认 .wav)")
    parser.add_argument("--batch-size", type=int, default=10, help="批次大小(默认 10)")
    parser.add_argument("--clear", action="store_true", help="导入前清空 composition_feature 和 dejavu_fingerprints 表数据(保留表结构)")
    args = parser.parse_args()

    # The batch size is the step of range() below: 0 would raise there, and a
    # negative value would yield no batches and report "success 0, failed 0".
    if args.batch_size < 1:
        parser.error("--batch-size 必须是正整数")

    config = CompositionConfig(dsn=args.dsn)
    service = CompositionDedupService(config=config)

    if args.clear:
        import psycopg
        with psycopg.connect(args.dsn) as conn:
            with conn.cursor() as cur:
                cur.execute("TRUNCATE TABLE composition_feature, dejavu_fingerprints")
            conn.commit()
        logger.info("已清空 composition_feature 和 dejavu_fingerprints 表")

    audio_files = discover_audio_files(args.audio_dir, args.file_list, args.ext)
    logger.info("发现 %d 个音频文件", len(audio_files))

    success_count = 0
    fail_count = 0

    # Batches only set the granularity of the progress bar; files are still
    # ingested one at a time.
    for start in tqdm(range(0, len(audio_files), args.batch_size), desc="导入进度"):
        batch = audio_files[start:start + args.batch_size]
        for song_id, audio_path in batch:
            try:
                service.ingest(song_id=int(song_id), audio_path=audio_path)
                success_count += 1
            except Exception as e:
                # Top-level boundary: record the failure and keep importing.
                logger.error("导入失败: song_id=%s, path=%s, error=%s", song_id, audio_path, e)
                fail_count += 1

    logger.info("导入完成: 成功 %d, 失败 %d", success_count, fail_count)
109
110
111 if __name__ == "__main__":
112 main()
...@@ -40,3 +40,33 @@ on lyric_lines (line_hash); ...@@ -40,3 +40,33 @@ on lyric_lines (line_hash);
40 40
41 create index if not exists lyric_lines_lyric_id_idx 41 create index if not exists lyric_lines_lyric_id_idx
42 on lyric_lines (lyric_id); 42 on lyric_lines (lyric_id);
43
-- The "vector" column type and the hnsw index method used below come from
-- this extension.
CREATE EXTENSION IF NOT EXISTS vector;

-- One feature vector per song; song_id is unique.
CREATE TABLE IF NOT EXISTS composition_feature (
    id             bigserial   PRIMARY KEY,
    song_id        bigint      NOT NULL UNIQUE,
    feature_vector vector(1536) NOT NULL,
    created_at     timestamptz NOT NULL DEFAULT now()
);

-- Approximate nearest-neighbour index using cosine distance.
CREATE INDEX IF NOT EXISTS composition_feature_hnsw_idx
    ON composition_feature
    USING hnsw (feature_vector vector_cosine_ops)
    WITH (m = 16, ef_construction = 64);

-- Fingerprint rows belonging to a song; they are removed together with the
-- referenced composition_feature row.
CREATE TABLE IF NOT EXISTS dejavu_fingerprints (
    id       bigserial PRIMARY KEY,
    song_id  bigint    NOT NULL REFERENCES composition_feature (song_id) ON DELETE CASCADE,
    hash     bytea     NOT NULL,
    "offset" int       NOT NULL
);

-- NOTE(review): this single-column index is a prefix of the composite index
-- below; confirm with query plans whether both are needed.
CREATE INDEX IF NOT EXISTS idx_fingerprints_hash
    ON dejavu_fingerprints (hash);

CREATE INDEX IF NOT EXISTS idx_fingerprints_hash_song_offset
    ON dejavu_fingerprints (hash, song_id, "offset");

CREATE INDEX IF NOT EXISTS idx_fingerprints_song_id
    ON dejavu_fingerprints (song_id);
......
1 """作曲去重模块测试。
2
3 测试覆盖:
4 - Chromagram 提取
5 - 时间归一化输出维度
6 - Cosine 相似度计算
7 - 向量展开维度为 1536
8 """
9
10 import os
11 import tempfile
12 import wave
13
14 import numpy as np
15 import pytest
16 from scipy.signal import resample
17
18 from composition_dedup.extractor import extract_chroma_feature, _normalize_audio_ffmpeg
19 from composition_dedup.similarity import (
20 CompositionSimilarity,
21 SimilarityDecision,
22 DUPLICATE_THRESHOLD,
23 SUSPECTED_THRESHOLD,
24 )
25
26
27 def _generate_test_wav(duration_sec: float = 1.0, sample_rate: int = 22050, frequency: float = 440.0) -> str:
28 """生成测试用的 WAV 文件(正弦波)。
29
30 Args:
31 duration_sec: 持续时间(秒)。
32 sample_rate: 采样率。
33 frequency: 频率(Hz)。
34
35 Returns:
36 临时 WAV 文件路径。
37 """
38 t = np.linspace(0, duration_sec, int(sample_rate * duration_sec), endpoint=False)
39 audio_data = (0.5 * np.sin(2 * np.pi * frequency * t)).astype(np.float32)
40
41 tmp_path = tempfile.mktemp(suffix=".wav")
42 with wave.open(tmp_path, "wb") as wf:
43 wf.setnchannels(1)
44 wf.setsampwidth(2) # 16-bit
45 wf.setframerate(sample_rate)
46 wf.writeframes((audio_data * 32767).astype(np.int16).tobytes())
47
48 return tmp_path
49
50
class TestChromaExtraction:
    """Tests for chromagram feature extraction."""

    def test_extract_chroma_returns_1536_dim(self):
        """The extractor returns a float32 vector of shape (1536,)."""
        wav_path = _generate_test_wav(duration_sec=2.0, frequency=440.0)
        try:
            feature = extract_chroma_feature(wav_path)
            assert isinstance(feature, np.ndarray)
            assert feature.shape == (1536,), f"期望 (1536,), 实际 {feature.shape}"
            assert feature.dtype == np.float32
        finally:
            if os.path.exists(wav_path):
                os.remove(wav_path)

    def test_extract_chroma_file_not_found(self):
        """A missing audio file raises FileNotFoundError."""
        with pytest.raises(FileNotFoundError):
            extract_chroma_feature("/nonexistent/path/audio.mp3")

    def test_extract_chroma_different_frequencies(self):
        """Tones of different pitch classes produce different features."""
        # 440 Hz (A4) against 523.25 Hz (C5). An octave pair such as 440/880 Hz
        # shares one pitch class, which a chroma representation folds together,
        # so it would not be a meaningful "different input" case.
        wav_a = _generate_test_wav(duration_sec=2.0, frequency=440.0)
        wav_b = _generate_test_wav(duration_sec=2.0, frequency=523.25)
        try:
            feature_a = extract_chroma_feature(wav_a)
            feature_b = extract_chroma_feature(wav_b)
            assert not np.allclose(feature_a, feature_b)
        finally:
            for path in [wav_a, wav_b]:
                if os.path.exists(path):
                    os.remove(path)

    def test_extract_chroma_same_audio_consistent(self):
        """Extracting the same file twice gives the same result."""
        wav_path = _generate_test_wav(duration_sec=1.0, frequency=440.0)
        try:
            feature_1 = extract_chroma_feature(wav_path)
            feature_2 = extract_chroma_feature(wav_path)
            np.testing.assert_array_almost_equal(feature_1, feature_2, decimal=5)
        finally:
            if os.path.exists(wav_path):
                os.remove(wav_path)
95
96
class TestTimeNormalization:
    """Shape checks for the time-normalisation step."""

    def test_resample_chroma_to_128_frames(self):
        """Chromagrams of various lengths resample to 128 frames."""
        target_frames = 128
        # Simulate chromagrams of different lengths.
        for num_frames in (100, 256, 512, 1000, 2000):
            chroma = np.random.rand(12, num_frames).astype(np.float32)
            needs_resampling = chroma.shape[1] != target_frames
            if needs_resampling:
                chroma = resample(chroma, target_frames, axis=1)
            assert chroma.shape == (12, 128), f"帧数归一化失败: {chroma.shape}"

    def test_flatten_to_1536(self):
        """A flattened 12 x 128 chromagram has 1536 elements."""
        flattened = np.random.rand(12, 128).astype(np.float32).flatten()
        length = flattened.shape[0]
        assert length == 12 * 128 == 1536
114
115
class TestCosineSimilarity:
    """Tests for cosine similarity and the decision it maps to."""

    def test_identical_vectors(self):
        """A vector compared with itself has similarity 1."""
        vec = np.random.rand(1536).astype(np.float32)
        deviation = CompositionSimilarity.cosine_similarity(vec, vec) - 1.0
        assert abs(deviation) < 1e-6

    def test_orthogonal_vectors(self):
        """Orthogonal vectors have similarity close to 0."""
        vec_a, vec_b = np.zeros(1536), np.zeros(1536)
        vec_a[0] = 1.0
        vec_b[1] = 1.0
        assert abs(CompositionSimilarity.cosine_similarity(vec_a, vec_b)) < 1e-6

    def test_zero_vector(self):
        """Comparing against a zero vector yields exactly 0."""
        nonzero = np.random.rand(1536).astype(np.float32)
        zeros = np.zeros(1536)
        assert CompositionSimilarity.cosine_similarity(nonzero, zeros) == 0.0

    def test_similarity_range(self):
        """Similarity of two random non-negative vectors lies in [0, 1]."""
        vec_a = np.random.rand(1536).astype(np.float32)
        vec_b = np.random.rand(1536).astype(np.float32)
        sim = CompositionSimilarity.cosine_similarity(vec_a, vec_b)
        assert 0.0 <= sim <= 1.0

    def test_classify_duplicate(self):
        """Scores of 0.95 and above classify as DUPLICATE."""
        for score in (0.96, 0.95):
            assert CompositionSimilarity.classify_similarity(score) == SimilarityDecision.DUPLICATE

    def test_classify_suspected(self):
        """Scores from 0.85 up to just below 0.95 classify as SUSPECTED."""
        for score in (0.94, 0.85):
            assert CompositionSimilarity.classify_similarity(score) == SimilarityDecision.SUSPECTED

    def test_classify_new(self):
        """Scores below 0.85 classify as NEW."""
        for score in (0.84, 0.5):
            assert CompositionSimilarity.classify_similarity(score) == SimilarityDecision.NEW

    def test_compare_method(self):
        """compare() returns both the similarity and the decision."""
        vec = np.random.rand(1536).astype(np.float32)
        sim, decision = CompositionSimilarity.compare(vec, vec)
        assert abs(sim - 1.0) < 1e-6
        assert decision == SimilarityDecision.DUPLICATE
169
170
171 class TestThresholds:
172 """阈值常量测试。"""
173
174 def test_threshold_order(self):
175 """测试阈值顺序正确。"""
176 assert DUPLICATE_THRESHOLD > SUSPECTED_THRESHOLD
177
178 def test_threshold_values(self):
179 """测试阈值符合设计值。"""
180 assert DUPLICATE_THRESHOLD == 0.95
181 assert SUSPECTED_THRESHOLD == 0.85