添加曲结构去重
Showing
10 changed files
with
1668 additions
and
0 deletions
composition_dedup/__init__.py
0 → 100644
composition_dedup/dejavu_fingerprinter.py
0 → 100644
| 1 | """Dejavu 风格的音频指纹生成。 | ||
| 2 | |||
| 3 | 基于 worldveil/dejavu 的指纹算法提取实现,不依赖 Dejavu 的数据库层。 | ||
| 4 | 使用 scipy.signal.spectrogram 替代已废弃的 matplotlib.mlab.specgram。 | ||
| 5 | |||
| 6 | 流程: | ||
| 7 | 1. 音频标准化:ffmpeg 转 44100Hz / Mono / WAV | ||
| 8 | 2. librosa 加载音频 | ||
| 9 | 3. 短时傅里叶变换(STFT)→ 对数频谱图 | ||
| 10 | 4. 2D 峰值检测:在频谱图中找局部极大值 | ||
| 11 | 5. 指纹哈希:对峰值对 (freq1, freq2, time_delta) 做 SHA1,取前 20 位 | ||
| 12 | """ | ||
| 13 | |||
| 14 | import hashlib | ||
| 15 | import logging | ||
| 16 | import os | ||
| 17 | import subprocess | ||
| 18 | import tempfile | ||
| 19 | from operator import itemgetter | ||
| 20 | from pathlib import Path | ||
| 21 | |||
| 22 | import librosa | ||
| 23 | import numpy as np | ||
| 24 | from scipy.ndimage import ( | ||
| 25 | binary_erosion, | ||
| 26 | generate_binary_structure, | ||
| 27 | iterate_structure, | ||
| 28 | maximum_filter, | ||
| 29 | ) | ||
| 30 | from scipy.signal import spectrogram | ||
| 31 | |||
| 32 | logger = logging.getLogger(__name__) | ||
| 33 | |||
| 34 | |||
| 35 | def _load_env_file() -> None: | ||
| 36 | """加载项目根目录 .env,不覆盖已存在的真实环境变量。""" | ||
| 37 | env_path = Path(__file__).resolve().parent.parent / ".env" | ||
| 38 | if not env_path.exists(): | ||
| 39 | return | ||
| 40 | with env_path.open(encoding="utf-8") as file: | ||
| 41 | for raw_line in file: | ||
| 42 | line = raw_line.strip() | ||
| 43 | if not line or line.startswith("#") or "=" not in line: | ||
| 44 | continue | ||
| 45 | key, value = line.split("=", 1) | ||
| 46 | os.environ.setdefault(key.strip(), value.strip().strip('"').strip("'")) | ||
| 47 | |||
| 48 | |||
| 49 | _load_env_file() | ||
| 50 | |||
| 51 | # ===== 常量(可通过环境变量覆盖)===== | ||
| 52 | |||
| 53 | DEFAULT_FS = 44100 | ||
| 54 | DEFAULT_WINDOW_SIZE = 4096 | ||
| 55 | DEFAULT_OVERLAP_RATIO = float(os.environ.get("COMPOSITION_DEJAVU_OVERLAP_RATIO", "0.3")) | ||
| 56 | DEFAULT_FAN_VALUE = int(os.environ.get("COMPOSITION_DEJAVU_FAN_VALUE", "10")) | ||
| 57 | DEFAULT_AMP_MIN = float(os.environ.get("COMPOSITION_DEJAVU_AMP_MIN", "20")) | ||
| 58 | PEAK_NEIGHBORHOOD_SIZE = 20 | ||
| 59 | MIN_HASH_TIME_DELTA = 0 | ||
| 60 | MAX_HASH_TIME_DELTA = 200 | ||
| 61 | PEAK_SORT = True | ||
| 62 | FINGERPRINT_REDUCTION = 20 | ||
| 63 | MAX_DURATION_SEC = float(os.environ.get("COMPOSITION_DEJAVU_MAX_DURATION", "120")) # 0=不限制 | ||
| 64 | |||
| 65 | |||
| 66 | def _normalize_audio(audio_path: str, max_duration: float = MAX_DURATION_SEC) -> tuple[np.ndarray, int]: | ||
| 67 | """将音频标准化为单声道 WAV 并加载为 numpy 数组。 | ||
| 68 | |||
| 69 | 使用 ffmpeg 先做重采样,再用 librosa 读取。 | ||
| 70 | 可选限制音频长度,超长音频只取前 N 秒。 | ||
| 71 | """ | ||
| 72 | with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp: | ||
| 73 | tmp_wav = tmp.name | ||
| 74 | |||
| 75 | try: | ||
| 76 | cmd = [ | ||
| 77 | "ffmpeg", | ||
| 78 | "-y", | ||
| 79 | "-i", audio_path, | ||
| 80 | "-ar", str(DEFAULT_FS), | ||
| 81 | "-ac", "1", | ||
| 82 | "-f", "wav", | ||
| 83 | ] | ||
| 84 | if max_duration > 0: | ||
| 85 | cmd += ["-t", str(max_duration)] | ||
| 86 | cmd.append(tmp_wav) | ||
| 87 | |||
| 88 | result = subprocess.run(cmd, capture_output=True, text=True) | ||
| 89 | if result.returncode != 0: | ||
| 90 | raise RuntimeError(f"ffmpeg 转换失败: {result.stderr}") | ||
| 91 | |||
| 92 | y, sr = librosa.load(tmp_wav, sr=DEFAULT_FS, mono=True) | ||
| 93 | return y, sr | ||
| 94 | finally: | ||
| 95 | if os.path.exists(tmp_wav): | ||
| 96 | os.remove(tmp_wav) | ||
| 97 | |||
| 98 | |||
| 99 | def _specgram(samples: np.ndarray, fs: int, window_size: int, overlap_ratio: float): | ||
| 100 | """计算对数频谱图,替代 matplotlib.mlab.specgram。 | ||
| 101 | |||
| 102 | Returns: | ||
| 103 | arr2D: shape (n_freq, n_time) 的对数频谱矩阵(dBFS 刻度) | ||
| 104 | """ | ||
| 105 | noverlap = int(window_size * overlap_ratio) | ||
| 106 | window = np.hanning(window_size) | ||
| 107 | |||
| 108 | freqs, times, Sxx = spectrogram( | ||
| 109 | samples, | ||
| 110 | fs=fs, | ||
| 111 | window=window, | ||
| 112 | nperseg=window_size, | ||
| 113 | noverlap=noverlap, | ||
| 114 | ) | ||
| 115 | |||
| 116 | # 转为对数尺度(dBFS,0 dB 为峰值参考) | ||
| 117 | # scipy.signal.spectrogram 返回 PSD,mlab.specgram 返回功率,两者量纲不同 | ||
| 118 | # 统一转为相对于峰值的 dBFS 刻度,使强信号峰值落在 20~80 dB 范围 | ||
| 119 | arr2D = 10 * np.log10(Sxx + 1e-10) | ||
| 120 | arr2D = arr2D - arr2D.max() # 归一化到峰值为 0 dBFS | ||
| 121 | arr2D = arr2D + 80 # 偏移使典型峰值落在 20~80 dB(与 mlab.specgram 一致) | ||
| 122 | arr2D[arr2D < -100] = -100 # 限幅 | ||
| 123 | return arr2D | ||
| 124 | |||
| 125 | |||
| 126 | def _get_2d_peaks(arr2D: np.ndarray, amp_min: float = DEFAULT_AMP_MIN): | ||
| 127 | """在频谱图中检测 2D 局部极大值。 | ||
| 128 | |||
| 129 | Returns: | ||
| 130 | (frequency_idx, time_idx): 峰值的频率和时间索引列表 | ||
| 131 | """ | ||
| 132 | struct = generate_binary_structure(2, 1) | ||
| 133 | neighborhood = iterate_structure(struct, PEAK_NEIGHBORHOOD_SIZE) | ||
| 134 | |||
| 135 | # 找局部极大值 | ||
| 136 | local_max = maximum_filter(arr2D, footprint=neighborhood) == arr2D | ||
| 137 | background = arr2D == 0 | ||
| 138 | eroded_background = binary_erosion(background, structure=neighborhood, border_value=1) | ||
| 139 | |||
| 140 | # 布尔掩码 | ||
| 141 | detected_peaks = local_max ^ eroded_background | ||
| 142 | |||
| 143 | # 提取峰值 | ||
| 144 | amps = arr2D[detected_peaks] | ||
| 145 | j, i = np.where(detected_peaks) | ||
| 146 | |||
| 147 | # 过滤低于阈值的峰值 | ||
| 148 | peaks = list(zip(i, j, amps)) | ||
| 149 | peaks_filtered = [x for x in peaks if x[2] > amp_min] | ||
| 150 | |||
| 151 | frequency_idx = [x[1] for x in peaks_filtered] | ||
| 152 | time_idx = [x[0] for x in peaks_filtered] | ||
| 153 | |||
| 154 | return frequency_idx, time_idx | ||
| 155 | |||
| 156 | |||
| 157 | def _generate_hashes(peaks: list[tuple[int, int]], fan_value: int = DEFAULT_FAN_VALUE): | ||
| 158 | """根据峰值对生成 SHA1 指纹哈希。 | ||
| 159 | |||
| 160 | Args: | ||
| 161 | peaks: [(freq_idx, time_idx), ...] 列表 | ||
| 162 | fan_value: 每个峰值与后续多少个峰值配对 | ||
| 163 | |||
| 164 | Yields: | ||
| 165 | (hash_bytes, time_offset) 元组 | ||
| 166 | """ | ||
| 167 | if PEAK_SORT: | ||
| 168 | peaks.sort(key=itemgetter(1)) | ||
| 169 | |||
| 170 | for i in range(len(peaks)): | ||
| 171 | for j in range(1, fan_value): | ||
| 172 | if i + j < len(peaks): | ||
| 173 | freq1 = peaks[i][0] | ||
| 174 | freq2 = peaks[i + j][0] | ||
| 175 | t1 = peaks[i][1] | ||
| 176 | t2 = peaks[i + j][1] | ||
| 177 | t_delta = t2 - t1 | ||
| 178 | |||
| 179 | if MIN_HASH_TIME_DELTA <= t_delta <= MAX_HASH_TIME_DELTA: | ||
| 180 | h = hashlib.sha1(f"{freq1}|{freq2}|{t_delta}".encode()) | ||
| 181 | yield (h.hexdigest()[:FINGERPRINT_REDUCTION].encode(), t1) | ||
| 182 | |||
| 183 | |||
| 184 | def fingerprint_audio(audio_path: str) -> tuple[str, list[tuple[bytes, int]]]: | ||
| 185 | """对音频文件生成 Dejavu 风格指纹。 | ||
| 186 | |||
| 187 | Args: | ||
| 188 | audio_path: 音频文件路径。 | ||
| 189 | |||
| 190 | Returns: | ||
| 191 | (file_sha1, fingerprints) 元组, | ||
| 192 | 其中 fingerprints 是 [(hash_bytes, offset), ...] 列表。 | ||
| 193 | |||
| 194 | Raises: | ||
| 195 | FileNotFoundError: 音频文件不存在。 | ||
| 196 | RuntimeError: ffmpeg 转换失败。 | ||
| 197 | """ | ||
| 198 | if not os.path.isfile(audio_path): | ||
| 199 | raise FileNotFoundError(f"音频文件不存在: {audio_path}") | ||
| 200 | |||
| 201 | # 1. 标准化并加载音频(可选限制长度) | ||
| 202 | samples, fs = _normalize_audio(audio_path) | ||
| 203 | |||
| 204 | # 2. 计算文件 SHA1(用于标识) | ||
| 205 | file_sha1 = hashlib.sha1(samples.tobytes()).hexdigest()[:16] | ||
| 206 | |||
| 207 | # 3. 计算频谱图 | ||
| 208 | arr2D = _specgram(samples, fs, DEFAULT_WINDOW_SIZE, DEFAULT_OVERLAP_RATIO) | ||
| 209 | |||
| 210 | # 4. 检测 2D 峰值 | ||
| 211 | freq_idx, time_idx = _get_2d_peaks(arr2D) | ||
| 212 | peaks = list(zip(freq_idx, time_idx)) | ||
| 213 | |||
| 214 | # 5. 生成指纹哈希 | ||
| 215 | fingerprints = list(_generate_hashes(peaks)) | ||
| 216 | |||
| 217 | logger.info("指纹生成完成: audio=%s, 指纹数=%d", audio_path, len(fingerprints)) | ||
| 218 | return file_sha1, fingerprints |
composition_dedup/extractor.py
0 → 100644
| 1 | """Chromagram 特征提取。 | ||
| 2 | |||
| 3 | 流程: | ||
| 4 | 1. 音频标准化:ffmpeg 转 22050Hz / Mono / WAV | ||
| 5 | 2. librosa 加载音频 | ||
| 6 | 3. librosa.feature.chroma_cens() 提取 12×T Chromagram(CENS,对速度/音色鲁棒) | ||
| 7 | 4. 主音对齐:将能量最大的音级滚至第 0 行,实现转调不变性 | ||
| 8 | 5. scipy.signal.resample(chroma, 128, axis=1) 时间归一化到 12×128 | ||
| 9 | 6. .flatten() 展开为 1536 维向量 | ||
| 10 | """ | ||
| 11 | |||
| 12 | import logging | ||
| 13 | import os | ||
| 14 | import subprocess | ||
| 15 | import tempfile | ||
| 16 | |||
| 17 | import librosa | ||
| 18 | import numpy as np | ||
| 19 | from scipy.signal import resample | ||
| 20 | |||
| 21 | logger = logging.getLogger(__name__) | ||
| 22 | |||
| 23 | # 目标采样率和时间帧数 | ||
| 24 | TARGET_SR = 22050 | ||
| 25 | TARGET_FRAMES = 128 | ||
| 26 | VECTOR_DIM = 12 * TARGET_FRAMES # 1536 | ||
| 27 | |||
| 28 | |||
| 29 | def _normalize_audio_ffmpeg(audio_path: str, output_path: str) -> None: | ||
| 30 | """使用 ffmpeg 将音频标准化为 22050Hz / Mono / WAV。""" | ||
| 31 | cmd = [ | ||
| 32 | "ffmpeg", | ||
| 33 | "-y", | ||
| 34 | "-i", audio_path, | ||
| 35 | "-ar", str(TARGET_SR), | ||
| 36 | "-ac", "1", | ||
| 37 | "-f", "wav", | ||
| 38 | output_path, | ||
| 39 | ] | ||
| 40 | result = subprocess.run( | ||
| 41 | cmd, | ||
| 42 | capture_output=True, | ||
| 43 | text=True, | ||
| 44 | ) | ||
| 45 | if result.returncode != 0: | ||
| 46 | raise RuntimeError(f"ffmpeg 转换失败: {result.stderr}") | ||
| 47 | |||
| 48 | |||
| 49 | def extract_chroma_feature(audio_path: str) -> np.ndarray: | ||
| 50 | """从音频文件提取 1536 维 Chromagram 特征向量。 | ||
| 51 | |||
| 52 | Args: | ||
| 53 | audio_path: 音频文件路径。 | ||
| 54 | |||
| 55 | Returns: | ||
| 56 | shape 为 (1536,) 的 numpy 数组。 | ||
| 57 | |||
| 58 | Raises: | ||
| 59 | FileNotFoundError: 音频文件不存在。 | ||
| 60 | RuntimeError: ffmpeg 转换失败。 | ||
| 61 | """ | ||
| 62 | if not os.path.isfile(audio_path): | ||
| 63 | raise FileNotFoundError(f"音频文件不存在: {audio_path}") | ||
| 64 | |||
| 65 | # 1. 音频标准化:ffmpeg 转 WAV | ||
| 66 | with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp: | ||
| 67 | tmp_wav = tmp.name | ||
| 68 | |||
| 69 | try: | ||
| 70 | _normalize_audio_ffmpeg(audio_path, tmp_wav) | ||
| 71 | |||
| 72 | # 2. librosa 加载音频 | ||
| 73 | y, _sr = librosa.load(tmp_wav, sr=TARGET_SR, mono=True) | ||
| 74 | |||
| 75 | # 3. 提取 CENS Chromagram (12×T),对速度变化和音色具有更强鲁棒性 | ||
| 76 | chroma = librosa.feature.chroma_cens(y=y, sr=TARGET_SR) | ||
| 77 | |||
| 78 | # 4. 主音对齐:将全局能量最大的音级循环滚至第 0 行,实现转调不变性 | ||
| 79 | tonic = int(np.argmax(chroma.sum(axis=1))) | ||
| 80 | if tonic != 0: | ||
| 81 | chroma = np.roll(chroma, -tonic, axis=0) | ||
| 82 | |||
| 83 | # 5. 时间归一化到 12×128 | ||
| 84 | if chroma.shape[1] != TARGET_FRAMES: | ||
| 85 | chroma = resample(chroma, TARGET_FRAMES, axis=1) | ||
| 86 | |||
| 87 | # 6. 展开为 1536 维向量 | ||
| 88 | feature = chroma.flatten().astype(np.float32) | ||
| 89 | |||
| 90 | assert feature.shape == (VECTOR_DIM,), ( | ||
| 91 | f"特征维度错误: 期望 {VECTOR_DIM}, 实际 {feature.shape}" | ||
| 92 | ) | ||
| 93 | |||
| 94 | return feature | ||
| 95 | finally: | ||
| 96 | # 清理临时文件 | ||
| 97 | if os.path.exists(tmp_wav): | ||
| 98 | os.remove(tmp_wav) | ||
| 99 | |||
| 100 | |||
| 101 | def extract_chroma_matrix(audio_path: str) -> np.ndarray: | ||
| 102 | """从音频文件提取 12×128 Chromagram 矩阵(未展平,供 DTW 精排使用)。 | ||
| 103 | |||
| 104 | Returns: | ||
| 105 | shape 为 (12, 128) 的 numpy 数组,已做主音对齐。 | ||
| 106 | """ | ||
| 107 | feature = extract_chroma_feature(audio_path) | ||
| 108 | return feature.reshape(12, TARGET_FRAMES) |
composition_dedup/service.py
0 → 100644
| 1 | """作曲去重服务(入库 + 查询)。 | ||
| 2 | |||
| 3 | 查询流程: | ||
| 4 | 1. Dejavu 指纹匹配(毫秒级,子序列匹配,支持 chorus_only / trim_intro) | ||
| 5 | - 命中(≥ 阈值)→ 直接返回 duplicate(短路) | ||
| 6 | 2. 未命中 → Chromagram 12路 + DTW(百毫秒级) | ||
| 7 | - 返回结果 | ||
| 8 | """ | ||
| 9 | |||
| 10 | import logging | ||
| 11 | import os | ||
| 12 | from dataclasses import dataclass, field | ||
| 13 | |||
| 14 | import numpy as np | ||
| 15 | import psycopg | ||
| 16 | from scipy.spatial.distance import cdist | ||
| 17 | |||
| 18 | from .extractor import TARGET_FRAMES, extract_chroma_feature, extract_chroma_matrix | ||
| 19 | from .dejavu_fingerprinter import fingerprint_audio | ||
| 20 | |||
| 21 | logger = logging.getLogger(__name__) | ||
| 22 | |||
| 23 | |||
| 24 | def _env_bool(name: str, default: bool) -> bool: | ||
| 25 | value = os.getenv(name) | ||
| 26 | if value is None: | ||
| 27 | return default | ||
| 28 | return value.strip().lower() in {"1", "true", "yes", "y", "on"} | ||
| 29 | |||
| 30 | |||
| 31 | def _env_int(name: str, default: int) -> int: | ||
| 32 | value = os.getenv(name) | ||
| 33 | if value is None: | ||
| 34 | return default | ||
| 35 | try: | ||
| 36 | return int(value) | ||
| 37 | except ValueError: | ||
| 38 | logger.warning("环境变量 %s=%r 不是整数,使用默认值 %d", name, value, default) | ||
| 39 | return default | ||
| 40 | |||
| 41 | |||
| 42 | def _env_float(name: str, default: float) -> float: | ||
| 43 | value = os.getenv(name) | ||
| 44 | if value is None: | ||
| 45 | return default | ||
| 46 | try: | ||
| 47 | return float(value) | ||
| 48 | except ValueError: | ||
| 49 | logger.warning("环境变量 %s=%r 不是数字,使用默认值 %.4f", name, value, default) | ||
| 50 | return default | ||
| 51 | |||
| 52 | |||
| 53 | @dataclass | ||
| 54 | class CompositionCandidate: | ||
| 55 | """去重候选结果。""" | ||
| 56 | song_id: int | ||
| 57 | similarity: float | ||
| 58 | source: str = "chromagram" | ||
| 59 | |||
| 60 | |||
| 61 | @dataclass | ||
| 62 | class _DejavuMatch: | ||
| 63 | """Dejavu offset 对齐后的命中结果。""" | ||
| 64 | song_id: int | ||
| 65 | aligned_count: int | ||
| 66 | total_collisions: int | ||
| 67 | |||
| 68 | |||
| 69 | @dataclass | ||
| 70 | class CompositionConfig: | ||
| 71 | """作曲去重服务配置。""" | ||
| 72 | dsn: str = "postgresql:///lyric_dedup" | ||
| 73 | statement_timeout_ms: int = 30000 | ||
| 74 | dtw_rerank_top_k: int = 20 # Cosine 召回后做 DTW 精排的候选数量 | ||
| 75 | duplicate_threshold: float = _env_float("COMPOSITION_DUPLICATE_THRESHOLD", 0.85) | ||
| 76 | # Dejavu 指纹匹配配置 | ||
| 77 | dejavu_enabled: bool = _env_bool("COMPOSITION_DEJAVU_ENABLED", True) | ||
| 78 | dejavu_match_threshold: int = _env_int("COMPOSITION_DEJAVU_MATCH_THRESHOLD", 20) | ||
| 79 | |||
| 80 | |||
| 81 | @dataclass | ||
| 82 | class CompositionDedupService: | ||
| 83 | """作曲去重服务:特征入库 + 相似度查询。""" | ||
| 84 | config: CompositionConfig | ||
| 85 | _logger: logging.Logger = field(default_factory=lambda: logger, repr=False) | ||
| 86 | |||
| 87 | def ingest(self, song_id: int, audio_path: str) -> np.ndarray: | ||
| 88 | """提取音频特征并写入数据库。 | ||
| 89 | |||
| 90 | Args: | ||
| 91 | song_id: 歌曲 ID。 | ||
| 92 | audio_path: 音频文件路径。 | ||
| 93 | |||
| 94 | Returns: | ||
| 95 | 提取的特征向量。 | ||
| 96 | """ | ||
| 97 | feature = extract_chroma_feature(audio_path) | ||
| 98 | self._logger.info("提取 Chromagram 特征完成: song_id=%s, audio=%s", song_id, audio_path) | ||
| 99 | |||
| 100 | with psycopg.connect(self.config.dsn) as conn: | ||
| 101 | with conn.cursor() as cursor: | ||
| 102 | cursor.execute( | ||
| 103 | """ | ||
| 104 | INSERT INTO composition_feature (song_id, feature_vector) | ||
| 105 | VALUES (%s, %s) | ||
| 106 | ON CONFLICT DO NOTHING | ||
| 107 | """, | ||
| 108 | (song_id, feature.tolist()), | ||
| 109 | ) | ||
| 110 | conn.commit() | ||
| 111 | |||
| 112 | self._logger.info("Chromagram 特征入库完成: song_id=%s", song_id) | ||
| 113 | |||
| 114 | # Dejavu 指纹同时入库 | ||
| 115 | if self.config.dejavu_enabled: | ||
| 116 | self._dejavu_ingest(song_id, audio_path) | ||
| 117 | |||
| 118 | return feature | ||
| 119 | |||
| 120 | def _dejavu_ingest(self, song_id: int, audio_path: str) -> None: | ||
| 121 | """提取 Dejavu 指纹并写入 dejavu_fingerprints 表。""" | ||
| 122 | file_sha1, fingerprints = fingerprint_audio(audio_path) | ||
| 123 | if not fingerprints: | ||
| 124 | self._logger.warning("Dejavu 指纹为空: song_id=%s, audio=%s", song_id, audio_path) | ||
| 125 | return | ||
| 126 | |||
| 127 | with psycopg.connect(self.config.dsn) as conn: | ||
| 128 | with conn.cursor() as cursor: | ||
| 129 | # 先清理可能残留的旧指纹(幂等写入) | ||
| 130 | cursor.execute( | ||
| 131 | "DELETE FROM dejavu_fingerprints WHERE song_id = %s", | ||
| 132 | (song_id,), | ||
| 133 | ) | ||
| 134 | # 批量写入 | ||
| 135 | records = [(song_id, h, o) for h, o in fingerprints] | ||
| 136 | cursor.executemany( | ||
| 137 | """ | ||
| 138 | INSERT INTO dejavu_fingerprints (song_id, hash, "offset") | ||
| 139 | VALUES (%s, %s, %s) | ||
| 140 | """, | ||
| 141 | records, | ||
| 142 | ) | ||
| 143 | conn.commit() | ||
| 144 | |||
| 145 | self._logger.info("Dejavu 指纹入库完成: song_id=%s, 指纹数=%d", song_id, len(fingerprints)) | ||
| 146 | |||
| 147 | def query(self, audio_path: str, top_k: int = 100) -> list[CompositionCandidate]: | ||
| 148 | """提取音频特征并查询相似结果。 | ||
| 149 | |||
| 150 | 流程:Dejavu 指纹短路匹配 → 12 路循环对齐 Cosine 召回 → DTW 精排。 | ||
| 151 | """ | ||
| 152 | # 1. 优先尝试 Dejavu 指纹匹配(短路) | ||
| 153 | if self.config.dejavu_enabled: | ||
| 154 | match = self._dejavu_query(audio_path) | ||
| 155 | if match is not None: | ||
| 156 | self._logger.info( | ||
| 157 | "Dejavu 命中: song_id=%s, aligned_count=%d, total_collisions=%d, decision=duplicate", | ||
| 158 | match.song_id, | ||
| 159 | match.aligned_count, | ||
| 160 | match.total_collisions, | ||
| 161 | ) | ||
| 162 | return [CompositionCandidate(song_id=match.song_id, similarity=1.0, source="dejavu")] | ||
| 163 | |||
| 164 | # 2. Dejavu 未命中或禁用,走现有 Chromagram 12路 + DTW 流程 | ||
| 165 | return self._query_chroma(audio_path, top_k) | ||
| 166 | |||
| 167 | def check(self, audio_path: str, top_k: int = 100) -> bool: | ||
| 168 | """按最终接口语义返回是否重复。""" | ||
| 169 | return self.candidates_indicate_duplicate(self.query(audio_path, top_k=top_k)) | ||
| 170 | |||
| 171 | def candidates_indicate_duplicate(self, candidates: list[CompositionCandidate]) -> bool: | ||
| 172 | """将候选结果转换为最终 duplicate bool。 | ||
| 173 | |||
| 174 | 最终接口只返回 true/false,因此判定只看当前查询的最佳候选是否超过阈值, | ||
| 175 | 不依赖评测集里的 expected_song_id 是否出现在 top-k。 | ||
| 176 | """ | ||
| 177 | if not candidates: | ||
| 178 | return False | ||
| 179 | return candidates[0].similarity >= self.config.duplicate_threshold | ||
| 180 | |||
| 181 | def _query_chroma(self, audio_path: str, top_k: int = 100) -> list[CompositionCandidate]: | ||
| 182 | """Chromagram 12 路循环对齐 + DTW 精排查询。""" | ||
| 183 | chroma = extract_chroma_matrix(audio_path) | ||
| 184 | self._logger.info("提取 Chromagram 查询特征完成: audio=%s", audio_path) | ||
| 185 | |||
| 186 | # 1. 12 路循环对齐:穷举 12 种半音偏移,单条 SQL 内部展开,按 song_id 取最高 Cosine 相似度 | ||
| 187 | shift_vecs = [ | ||
| 188 | np.roll(chroma, -shift, axis=0).flatten().astype(np.float32).tolist() | ||
| 189 | for shift in range(12) | ||
| 190 | ] | ||
| 191 | # 用 VALUES 展开 12 个偏移向量,LATERAL 子查询对每个偏移各触发一次 HNSW 扫描 | ||
| 192 | values_clause = ", ".join(f"({i}, %s::vector)" for i in range(12)) | ||
| 193 | sql = f""" | ||
| 194 | WITH shifts(shift_id, vec) AS ( | ||
| 195 | VALUES {values_clause} | ||
| 196 | ), | ||
| 197 | candidates AS ( | ||
| 198 | SELECT | ||
| 199 | cf.song_id, | ||
| 200 | 1 - (cf.feature_vector <=> s.vec) AS sim | ||
| 201 | FROM shifts s | ||
| 202 | CROSS JOIN LATERAL ( | ||
| 203 | SELECT song_id, feature_vector | ||
| 204 | FROM composition_feature | ||
| 205 | ORDER BY feature_vector <=> s.vec | ||
| 206 | LIMIT %s | ||
| 207 | ) cf | ||
| 208 | ) | ||
| 209 | SELECT song_id, MAX(sim) AS similarity | ||
| 210 | FROM candidates | ||
| 211 | GROUP BY song_id | ||
| 212 | ORDER BY similarity DESC | ||
| 213 | LIMIT %s | ||
| 214 | """ | ||
| 215 | best: dict[int, float] = {} | ||
| 216 | with psycopg.connect(self.config.dsn) as conn: | ||
| 217 | with conn.cursor() as cursor: | ||
| 218 | cursor.execute( | ||
| 219 | f"SET statement_timeout = {int(self.config.statement_timeout_ms)}" | ||
| 220 | ) | ||
| 221 | cursor.execute(sql, (*shift_vecs, top_k, top_k)) | ||
| 222 | for song_id, sim in cursor.fetchall(): | ||
| 223 | best[int(song_id)] = float(sim) | ||
| 224 | |||
| 225 | # 2. 取 Top dtw_rerank_top_k,从库中取原始向量做 DTW 精排 | ||
| 226 | top = sorted(best.items(), key=lambda x: x[1], reverse=True) | ||
| 227 | rerank_ids = [sid for sid, _ in top[:self.config.dtw_rerank_top_k]] | ||
| 228 | |||
| 229 | with conn.cursor() as cursor: | ||
| 230 | cursor.execute( | ||
| 231 | "SELECT song_id, feature_vector::float4[] FROM composition_feature WHERE song_id = ANY(%s)", | ||
| 232 | (rerank_ids,), | ||
| 233 | ) | ||
| 234 | db_rows = cursor.fetchall() | ||
| 235 | |||
| 236 | reranked = [] | ||
| 237 | for song_id, fv in db_rows: | ||
| 238 | cand_chroma = np.array(fv, dtype=np.float32).reshape(12, TARGET_FRAMES) | ||
| 239 | dtw_sim = _best_shifted_dtw_similarity(chroma, cand_chroma) | ||
| 240 | reranked.append(CompositionCandidate(song_id=int(song_id), similarity=dtw_sim)) | ||
| 241 | reranked.sort(key=lambda c: c.similarity, reverse=True) | ||
| 242 | |||
| 243 | rerank_id_set = {c.song_id for c in reranked} | ||
| 244 | rest = [ | ||
| 245 | CompositionCandidate(song_id=sid, similarity=sim) | ||
| 246 | for sid, sim in top[self.config.dtw_rerank_top_k:] | ||
| 247 | if sid not in rerank_id_set | ||
| 248 | ] | ||
| 249 | |||
| 250 | result = reranked + rest | ||
| 251 | top_summary = ", ".join( | ||
| 252 | f"{candidate.song_id}:{candidate.similarity:.4f}" | ||
| 253 | for candidate in result[:5] | ||
| 254 | ) | ||
| 255 | self._logger.info( | ||
| 256 | "Chromagram 查询完成: 返回 %d 个候选, top=%s", | ||
| 257 | len(result), | ||
| 258 | top_summary or "[]", | ||
| 259 | ) | ||
| 260 | return result | ||
| 261 | |||
| 262 | def _dejavu_query(self, audio_path: str) -> _DejavuMatch | None: | ||
| 263 | """Dejavu 指纹查询,返回 offset 对齐后碰撞数最多的 song_id。 | ||
| 264 | |||
| 265 | 只统计 hash 总碰撞数会让常见频谱峰值、噪声片段或大库随机碰撞直接短路成 | ||
| 266 | similarity=1.0。Dejavu 的关键判据是同一首候选歌里,多个 hash 碰撞的 | ||
| 267 | db_offset - query_offset 落在同一个时间偏移上。 | ||
| 268 | |||
| 269 | Returns: | ||
| 270 | 命中结果,未命中返回 None。 | ||
| 271 | """ | ||
| 272 | file_sha1, fingerprints = fingerprint_audio(audio_path) | ||
| 273 | if not fingerprints: | ||
| 274 | return None | ||
| 275 | |||
| 276 | hashes = [h for h, _ in fingerprints] | ||
| 277 | offsets = [int(o) for _, o in fingerprints] | ||
| 278 | |||
| 279 | with psycopg.connect(self.config.dsn) as conn: | ||
| 280 | with conn.cursor() as cursor: | ||
| 281 | # 先按 hash 找碰撞,再按每个 song_id 的 offset delta 聚类。 | ||
| 282 | cursor.execute( | ||
| 283 | """ | ||
| 284 | WITH query_fp(hash, query_offset) AS ( | ||
| 285 | SELECT * | ||
| 286 | FROM unnest(%s::bytea[], %s::int[]) | ||
| 287 | ), | ||
| 288 | aligned AS ( | ||
| 289 | SELECT | ||
| 290 | db.song_id, | ||
| 291 | db."offset" - query_fp.query_offset AS offset_delta, | ||
| 292 | COUNT(*) AS aligned_count | ||
| 293 | FROM query_fp | ||
| 294 | JOIN dejavu_fingerprints db | ||
| 295 | ON db.hash = query_fp.hash | ||
| 296 | GROUP BY db.song_id, offset_delta | ||
| 297 | ) | ||
| 298 | SELECT | ||
| 299 | song_id, | ||
| 300 | MAX(aligned_count) AS best_aligned_count, | ||
| 301 | SUM(aligned_count) AS total_collisions | ||
| 302 | FROM aligned | ||
| 303 | GROUP BY song_id | ||
| 304 | ORDER BY best_aligned_count DESC, total_collisions DESC | ||
| 305 | LIMIT 1 | ||
| 306 | """, | ||
| 307 | (hashes, offsets), | ||
| 308 | ) | ||
| 309 | row = cursor.fetchone() | ||
| 310 | if row is None: | ||
| 311 | return None | ||
| 312 | sid, aligned_count, total_collisions = row | ||
| 313 | aligned_count = int(aligned_count) | ||
| 314 | if aligned_count >= self.config.dejavu_match_threshold: | ||
| 315 | return _DejavuMatch( | ||
| 316 | song_id=int(sid), | ||
| 317 | aligned_count=aligned_count, | ||
| 318 | total_collisions=int(total_collisions), | ||
| 319 | ) | ||
| 320 | return None | ||
| 321 | |||
| 322 | |||
| 323 | def _dtw_similarity(query: np.ndarray, candidate: np.ndarray) -> float: | ||
| 324 | """计算两个 12×T Chromagram 矩阵之间的 DTW 相似度(映射到 [0,1])。""" | ||
| 325 | # 帧间欧氏距离矩阵 | ||
| 326 | cost = cdist(query.T, candidate.T, metric="euclidean") | ||
| 327 | n, m = cost.shape | ||
| 328 | dp = np.full((n, m), np.inf) | ||
| 329 | dp[0, 0] = cost[0, 0] | ||
| 330 | for i in range(1, n): | ||
| 331 | dp[i, 0] = dp[i - 1, 0] + cost[i, 0] | ||
| 332 | for j in range(1, m): | ||
| 333 | dp[0, j] = dp[0, j - 1] + cost[0, j] | ||
| 334 | for i in range(1, n): | ||
| 335 | for j in range(1, m): | ||
| 336 | dp[i, j] = cost[i, j] + min(dp[i - 1, j], dp[i, j - 1], dp[i - 1, j - 1]) | ||
| 337 | dtw_dist = dp[n - 1, m - 1] / (n + m) | ||
| 338 | # 转换为相似度:距离越小相似度越高 | ||
| 339 | return float(1.0 / (1.0 + dtw_dist)) | ||
| 340 | |||
| 341 | |||
| 342 | def _best_shifted_dtw_similarity(query: np.ndarray, candidate: np.ndarray) -> float: | ||
| 343 | """计算 12 路音高循环位移下的最佳 DTW 相似度。""" | ||
| 344 | return max( | ||
| 345 | _dtw_similarity(np.roll(query, -shift, axis=0), candidate) | ||
| 346 | for shift in range(12) | ||
| 347 | ) |
composition_dedup/similarity.py
0 → 100644
| 1 | """Cosine 相似度计算与去重判定。""" | ||
| 2 | |||
| 3 | from enum import Enum | ||
| 4 | |||
| 5 | import numpy as np | ||
| 6 | |||
| 7 | DUPLICATE_THRESHOLD = 0.95 | ||
| 8 | SUSPECTED_THRESHOLD = 0.85 | ||
| 9 | |||
| 10 | |||
| 11 | class SimilarityDecision(Enum): | ||
| 12 | DUPLICATE = "duplicate" | ||
| 13 | SUSPECTED = "suspected" | ||
| 14 | NEW = "new" | ||
| 15 | |||
| 16 | |||
| 17 | class CompositionSimilarity: | ||
| 18 | @staticmethod | ||
| 19 | def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float: | ||
| 20 | norm_a = np.linalg.norm(a) | ||
| 21 | norm_b = np.linalg.norm(b) | ||
| 22 | if norm_a == 0.0 or norm_b == 0.0: | ||
| 23 | return 0.0 | ||
| 24 | return float(np.dot(a, b) / (norm_a * norm_b)) | ||
| 25 | |||
| 26 | @staticmethod | ||
| 27 | def classify_similarity(similarity: float) -> SimilarityDecision: | ||
| 28 | if similarity >= DUPLICATE_THRESHOLD: | ||
| 29 | return SimilarityDecision.DUPLICATE | ||
| 30 | if similarity >= SUSPECTED_THRESHOLD: | ||
| 31 | return SimilarityDecision.SUSPECTED | ||
| 32 | return SimilarityDecision.NEW | ||
| 33 | |||
| 34 | @staticmethod | ||
| 35 | def compare(a: np.ndarray, b: np.ndarray) -> tuple[float, SimilarityDecision]: | ||
| 36 | sim = CompositionSimilarity.cosine_similarity(a, b) | ||
| 37 | return sim, CompositionSimilarity.classify_similarity(sim) |
scripts/evaluate_composition.py
0 → 100644
| 1 | """曲去重评估脚本。 | ||
| 2 | |||
| 3 | 对 queries.csv 中每条查询音频调用 CompositionDedupService.query(), | ||
| 4 | 按最终接口语义用 top1 分数阈值输出 predicted_duplicate true/false。 | ||
| 5 | expected_song_id 的 top-k/top1 命中只作为诊断字段。 | ||
| 6 | 输出 precision/recall/F1。 | ||
| 7 | |||
| 8 | 用法: | ||
| 9 | python scripts/evaluate_composition.py \ | ||
| 10 | --dsn "postgresql:///lyric_dedup" \ | ||
| 11 | --queries composition_dedup/composition_testset4/queries.csv \ | ||
| 12 | --out composition_dedup/composition_eval/composition_eval_result_v3.csv | ||
| 13 | """ | ||
| 14 | |||
| 15 | import argparse | ||
| 16 | import csv | ||
| 17 | import json | ||
| 18 | import logging | ||
| 19 | import sys | ||
| 20 | from pathlib import Path | ||
| 21 | |||
| 22 | sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) | ||
| 23 | |||
| 24 | from composition_dedup.service import CompositionConfig, CompositionDedupService | ||
| 25 | |||
| 26 | logger = logging.getLogger(__name__) | ||
| 27 | |||
| 28 | |||
| 29 | def _parse_csv_filter(value: str | None) -> set[str] | None: | ||
| 30 | if value is None: | ||
| 31 | return None | ||
| 32 | items = {item.strip() for item in value.split(",") if item.strip()} | ||
| 33 | return items or None | ||
| 34 | |||
| 35 | |||
| 36 | def _song_id_from_audio_path(audio_path: str) -> str: | ||
| 37 | """从音频文件名开头提取 song_id。""" | ||
| 38 | return Path(audio_path).stem.split("_", 1)[0] | ||
| 39 | |||
| 40 | |||
| 41 | def main() -> None: | ||
| 42 | logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") | ||
| 43 | |||
| 44 | parser = argparse.ArgumentParser() | ||
| 45 | parser.add_argument("--dsn", required=True) | ||
| 46 | parser.add_argument("--queries", required=True, help="queries.csv 路径") | ||
| 47 | parser.add_argument("--out", required=True, help="逐条结果输出 CSV") | ||
| 48 | parser.add_argument("--top-k", type=int, default=10) | ||
| 49 | parser.add_argument("--duplicate-threshold", type=float, help="覆盖 COMPOSITION_DUPLICATE_THRESHOLD") | ||
| 50 | parser.add_argument("--variants", help="只评测指定 variant,逗号分隔,如 pitch_up1,pitch_down1") | ||
| 51 | parser.add_argument("--sample-classes", help="只评测指定 sample_class,逗号分隔,如 dsp,negative") | ||
| 52 | parser.add_argument("--expected", choices=["duplicate", "not_duplicate"], help="只评测指定 expected 类型") | ||
| 53 | args = parser.parse_args() | ||
| 54 | |||
| 55 | config = CompositionConfig(dsn=args.dsn) | ||
| 56 | if args.duplicate_threshold is not None: | ||
| 57 | config.duplicate_threshold = args.duplicate_threshold | ||
| 58 | service = CompositionDedupService(config=config) | ||
| 59 | |||
| 60 | with open(args.queries, newline="", encoding="utf-8") as f: | ||
| 61 | rows = list(csv.DictReader(f)) | ||
| 62 | |||
| 63 | variant_filter = _parse_csv_filter(args.variants) | ||
| 64 | sample_class_filter = _parse_csv_filter(args.sample_classes) | ||
| 65 | original_count = len(rows) | ||
| 66 | if variant_filter is not None: | ||
| 67 | rows = [r for r in rows if (r.get("variant") or "") in variant_filter] | ||
| 68 | if sample_class_filter is not None: | ||
| 69 | rows = [r for r in rows if (r.get("sample_class") or "") in sample_class_filter] | ||
| 70 | if args.expected is not None: | ||
| 71 | rows = [r for r in rows if r["expected"].strip().lower() == args.expected] | ||
| 72 | |||
| 73 | logger.info( | ||
| 74 | "评测样本过滤: 原始 %d 条,保留 %d 条 (variants=%s, sample_classes=%s, expected=%s)", | ||
| 75 | original_count, | ||
| 76 | len(rows), | ||
| 77 | ",".join(sorted(variant_filter)) if variant_filter else "ALL", | ||
| 78 | ",".join(sorted(sample_class_filter)) if sample_class_filter else "ALL", | ||
| 79 | args.expected or "ALL", | ||
| 80 | ) | ||
| 81 | |||
| 82 | out_path = Path(args.out) | ||
| 83 | out_path.parent.mkdir(parents=True, exist_ok=True) | ||
| 84 | |||
| 85 | result_rows = [] | ||
| 86 | for i, row in enumerate(rows, 1): | ||
| 87 | audio_path = row["audio_path"] | ||
| 88 | query_song_id = row.get("song_id") or _song_id_from_audio_path(audio_path) | ||
| 89 | audio_song_id = _song_id_from_audio_path(audio_path) | ||
| 90 | expected_song_id = str(row["expected_song_id"]) | ||
| 91 | expected_dup = row["expected"].strip().lower() == "duplicate" | ||
| 92 | invalid_negative_pair = (not expected_dup) and audio_song_id == expected_song_id | ||
| 93 | |||
| 94 | try: | ||
| 95 | candidates = service.query(audio_path, top_k=args.top_k) | ||
| 96 | except Exception as e: | ||
| 97 | logger.error("[%d/%d] 查询失败: %s, %s", i, len(rows), audio_path, e) | ||
| 98 | result_rows.append({ | ||
| 99 | "query_song_id": query_song_id, | ||
| 100 | "audio_song_id": audio_song_id, | ||
| 101 | "audio_path": audio_path, | ||
| 102 | "variant": row.get("variant", ""), | ||
| 103 | "sample_class": row.get("sample_class", ""), | ||
| 104 | "expected_song_id": expected_song_id, | ||
| 105 | "expected": row["expected"], | ||
| 106 | "top1_song_id": "", | ||
| 107 | "top1_similarity": "", | ||
| 108 | "top1_source": "", | ||
| 109 | "top1_hit": False, | ||
| 110 | "topk_hit": False, | ||
| 111 | "expected_rank": "", | ||
| 112 | "expected_similarity": "", | ||
| 113 | "invalid_negative_pair": invalid_negative_pair, | ||
| 114 | "invalid_boolean_sample": False, | ||
| 115 | "expected_duplicate": expected_dup, | ||
| 116 | "predicted_duplicate": False, | ||
| 117 | "correct": not expected_dup, # 查询失败视为 not_duplicate | ||
| 118 | "error": str(e), | ||
| 119 | }) | ||
| 120 | continue | ||
| 121 | |||
| 122 | top1 = candidates[0] if candidates else None | ||
| 123 | top1_song_id = str(top1.song_id) if top1 else "" | ||
| 124 | top1_sim = round(top1.similarity, 4) if top1 else "" | ||
| 125 | top1_source = top1.source if top1 else "" | ||
| 126 | |||
| 127 | # 诊断召回:expected_song_id 是否进入 top1/top-k。 | ||
| 128 | top1_hit = bool(expected_song_id) and top1_song_id == expected_song_id | ||
| 129 | topk_hit = bool(expected_song_id) and any(str(c.song_id) == expected_song_id for c in candidates) | ||
| 130 | expected_rank = "" | ||
| 131 | expected_similarity = "" | ||
| 132 | if expected_song_id: | ||
| 133 | for rank, candidate in enumerate(candidates, 1): | ||
| 134 | if str(candidate.song_id) == expected_song_id: | ||
| 135 | expected_rank = rank | ||
| 136 | expected_similarity = round(candidate.similarity, 4) | ||
| 137 | break | ||
| 138 | |||
| 139 | # 最终接口语义:只返回 duplicate true/false。 | ||
| 140 | predicted_dup = service.candidates_indicate_duplicate(candidates) | ||
| 141 | correct = expected_dup == predicted_dup | ||
| 142 | invalid_boolean_sample = ( | ||
| 143 | (not expected_dup) | ||
| 144 | and bool(top1) | ||
| 145 | and top1_song_id == audio_song_id | ||
| 146 | and predicted_dup | ||
| 147 | ) | ||
| 148 | |||
| 149 | result_rows.append({ | ||
| 150 | "query_song_id": query_song_id, | ||
| 151 | "audio_song_id": audio_song_id, | ||
| 152 | "audio_path": audio_path, | ||
| 153 | "variant": row.get("variant", ""), | ||
| 154 | "sample_class": row.get("sample_class", ""), | ||
| 155 | "expected_song_id": expected_song_id, | ||
| 156 | "expected": row["expected"], | ||
| 157 | "top1_song_id": top1_song_id, | ||
| 158 | "top1_similarity": top1_sim, | ||
| 159 | "top1_source": top1_source, | ||
| 160 | "top1_hit": top1_hit, | ||
| 161 | "topk_hit": topk_hit, | ||
| 162 | "expected_rank": expected_rank, | ||
| 163 | "expected_similarity": expected_similarity, | ||
| 164 | "invalid_negative_pair": invalid_negative_pair, | ||
| 165 | "invalid_boolean_sample": invalid_boolean_sample, | ||
| 166 | "expected_duplicate": expected_dup, | ||
| 167 | "predicted_duplicate": predicted_dup, | ||
| 168 | "correct": correct, | ||
| 169 | "error": "", | ||
| 170 | }) | ||
| 171 | |||
| 172 | logger.info( | ||
| 173 | "[%d/%d] variant=%s source=%s expected=%s predicted_duplicate=%s threshold=%.4f expected_song_id=%s top1=%s sim=%s top1_hit=%s topk_hit=%s expected_rank=%s expected_sim=%s correct=%s", | ||
| 174 | i, | ||
| 175 | len(rows), | ||
| 176 | row.get("variant", ""), | ||
| 177 | top1_source or "-", | ||
| 178 | row["expected"], | ||
| 179 | predicted_dup, | ||
| 180 | service.config.duplicate_threshold, | ||
| 181 | expected_song_id, | ||
| 182 | top1_song_id or "-", | ||
| 183 | top1_sim if top1_sim != "" else "-", | ||
| 184 | top1_hit, | ||
| 185 | topk_hit, | ||
| 186 | expected_rank if expected_rank != "" else "-", | ||
| 187 | expected_similarity if expected_similarity != "" else "-", | ||
| 188 | correct, | ||
| 189 | ) | ||
| 190 | |||
| 191 | if i % 10 == 0 or i == len(rows): | ||
| 192 | logger.info("[%d/%d]", i, len(rows)) | ||
| 193 | |||
| 194 | # 写逐条结果 | ||
| 195 | fieldnames = ["query_song_id", "audio_song_id", "audio_path", "variant", "sample_class", | ||
| 196 | "expected_song_id", "expected", "top1_song_id", "top1_similarity", "top1_source", | ||
| 197 | "top1_hit", "topk_hit", "expected_rank", "expected_similarity", | ||
| 198 | "invalid_negative_pair", "invalid_boolean_sample", | ||
| 199 | "expected_duplicate", "predicted_duplicate", "correct", "error"] | ||
| 200 | with out_path.open("w", newline="", encoding="utf-8") as f: | ||
| 201 | writer = csv.DictWriter(f, fieldnames=fieldnames) | ||
| 202 | writer.writeheader() | ||
| 203 | writer.writerows(result_rows) | ||
| 204 | |||
| 205 | # 汇总指标 | ||
| 206 | def _metrics(rows: list[dict]) -> dict: | ||
| 207 | tp = sum(1 for r in rows if r["expected_duplicate"] and r["predicted_duplicate"]) | ||
| 208 | fp = sum(1 for r in rows if not r["expected_duplicate"] and r["predicted_duplicate"]) | ||
| 209 | tn = sum(1 for r in rows if not r["expected_duplicate"] and not r["predicted_duplicate"]) | ||
| 210 | fn = sum(1 for r in rows if r["expected_duplicate"] and not r["predicted_duplicate"]) | ||
| 211 | precision = tp / (tp + fp) if tp + fp else 0.0 | ||
| 212 | recall = tp / (tp + fn) if tp + fn else 0.0 | ||
| 213 | f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0 | ||
| 214 | accuracy = (tp + tn) / len(rows) if rows else 0.0 | ||
| 215 | return { | ||
| 216 | "total": len(rows), | ||
| 217 | "accuracy": round(accuracy, 4), | ||
| 218 | "precision": round(precision, 4), | ||
| 219 | "recall": round(recall, 4), | ||
| 220 | "f1": round(f1, 4), | ||
| 221 | "tp": tp, | ||
| 222 | "fp": fp, | ||
| 223 | "tn": tn, | ||
| 224 | "fn": fn, | ||
| 225 | } | ||
| 226 | |||
| 227 | metrics = _metrics(result_rows) | ||
| 228 | valid_rows = [ | ||
| 229 | r for r in result_rows | ||
| 230 | if not r["invalid_negative_pair"] and not r["invalid_boolean_sample"] | ||
| 231 | ] | ||
| 232 | valid_metrics = _metrics(valid_rows) | ||
| 233 | |||
| 234 | summary = { | ||
| 235 | "total": len(result_rows), | ||
| 236 | "filters": { | ||
| 237 | "variants": sorted(variant_filter) if variant_filter else None, | ||
| 238 | "sample_classes": sorted(sample_class_filter) if sample_class_filter else None, | ||
| 239 | "expected": args.expected, | ||
| 240 | "original_total": original_count, | ||
| 241 | }, | ||
| 242 | "duplicate_threshold": service.config.duplicate_threshold, | ||
| 243 | "invalid_negative_pairs": sum(1 for r in result_rows if r["invalid_negative_pair"]), | ||
| 244 | "invalid_boolean_samples": sum(1 for r in result_rows if r["invalid_boolean_sample"]), | ||
| 245 | "accuracy": metrics["accuracy"], | ||
| 246 | "precision": metrics["precision"], | ||
| 247 | "recall": metrics["recall"], | ||
| 248 | "f1": metrics["f1"], | ||
| 249 | "tp": metrics["tp"], "fp": metrics["fp"], "tn": metrics["tn"], "fn": metrics["fn"], | ||
| 250 | "valid_only": valid_metrics, | ||
| 251 | "out": str(out_path), | ||
| 252 | } | ||
| 253 | |||
| 254 | # 按 variant 分组,方便看各种变换的通过率 | ||
| 255 | from collections import defaultdict | ||
| 256 | by_variant: dict[str, dict] = defaultdict(lambda: {"correct": 0, "total": 0}) | ||
| 257 | for r in result_rows: | ||
| 258 | v = r["variant"] or "unknown" | ||
| 259 | by_variant[v]["total"] += 1 | ||
| 260 | if r["correct"]: | ||
| 261 | by_variant[v]["correct"] += 1 | ||
| 262 | summary["by_variant"] = { | ||
| 263 | v: {"accuracy": round(d["correct"] / d["total"], 4), "total": d["total"]} | ||
| 264 | for v, d in sorted(by_variant.items()) | ||
| 265 | } | ||
| 266 | |||
| 267 | # 按 sample_class 分组 | ||
| 268 | by_class: dict[str, dict] = defaultdict(lambda: {"correct": 0, "total": 0}) | ||
| 269 | for r in result_rows: | ||
| 270 | sc = r.get("sample_class") or "unknown" | ||
| 271 | by_class[sc]["total"] += 1 | ||
| 272 | if r["correct"]: | ||
| 273 | by_class[sc]["correct"] += 1 | ||
| 274 | summary["by_sample_class"] = { | ||
| 275 | sc: {"accuracy": round(d["correct"] / d["total"], 4), "total": d["total"]} | ||
| 276 | for sc, d in sorted(by_class.items()) | ||
| 277 | } | ||
| 278 | |||
| 279 | summary_path = out_path.with_suffix(".summary.json") | ||
| 280 | summary_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8") | ||
| 281 | print(json.dumps(summary, ensure_ascii=False, indent=2)) | ||
| 282 | |||
| 283 | |||
| 284 | if __name__ == "__main__": | ||
| 285 | main() |
scripts/generate_composition_testset.py
0 → 100644
| 1 | """生成曲去重评估测试集。 | ||
| 2 | |||
| 3 | 从音频目录随机抽取若干首参照歌入库,对每首用 ffmpeg 生成多个变换版本, | ||
| 4 | 覆盖曲去重测试样本类型.md 中第一类(数字信号变换)和第三类(困难正样本)的可合成部分。 | ||
| 5 | 负样本从未入库的 holdout 歌曲生成,以匹配最终接口 duplicate true/false 语义。 | ||
| 6 | |||
| 7 | 用法: | ||
| 8 | python scripts/generate_composition_testset.py \ | ||
| 9 | --audio-dir /Volumes/移动硬盘/lyric_audio_type11 \ | ||
| 10 | --negative-audio-dir /Volumes/移动硬盘/composition_test \ | ||
| 11 | --out-dir composition_dedup/composition_testset \ | ||
| 12 | --num-songs 80 \ | ||
| 13 | --num-negative-songs 40 \ | ||
| 14 | --negative-variants \ | ||
| 15 | --seed 123 | ||
| 16 | |||
| 17 | 输出: | ||
| 18 | reference.csv — 参照曲(原始文件),需提前入库 | ||
| 19 | queries.csv — 查询曲,带 variant 和 expected 标注 | ||
| 20 | """ | ||
| 21 | |||
| 22 | import argparse | ||
| 23 | import csv | ||
| 24 | import logging | ||
| 25 | import random | ||
| 26 | import subprocess | ||
| 27 | import sys | ||
| 28 | from pathlib import Path | ||
| 29 | |||
| 30 | try: | ||
| 31 | from tqdm import tqdm | ||
| 32 | except ImportError: | ||
| 33 | tqdm = None | ||
| 34 | |||
| 35 | |||
| 36 | def _tqdm(iterable, **kwargs): | ||
| 37 | if tqdm is not None: | ||
| 38 | return tqdm(iterable, **kwargs) | ||
| 39 | total = kwargs.get("total", None) or (len(iterable) if hasattr(iterable, "__len__") else None) | ||
| 40 | desc = kwargs.get("desc", "") | ||
| 41 | class _Simple: | ||
| 42 | def __init__(self): | ||
| 43 | self._i = 0 | ||
| 44 | def __iter__(self): | ||
| 45 | for item in iterable: | ||
| 46 | self._i += 1 | ||
| 47 | if total: | ||
| 48 | print(f"\r{desc}: {self._i}/{total}", end="", flush=True) | ||
| 49 | yield item | ||
| 50 | if total: | ||
| 51 | print() | ||
| 52 | return _Simple() | ||
| 53 | |||
| 54 | logger = logging.getLogger(__name__) | ||
| 55 | |||
| 56 | # -------------------------------------------------------------------------- | ||
| 57 | # 第一类:数字信号变换 | ||
| 58 | # -------------------------------------------------------------------------- | ||
| 59 | DSP_VARIANTS: list[tuple[str, str]] = [ | ||
| 60 | # Pitch Shift(±1、±2 半音) | ||
| 61 | ("pitch_up1", "asetrate=22050*1.0595,aresample=22050"), # +1 半音 | ||
| 62 | ("pitch_up2", "asetrate=22050*1.1225,aresample=22050"), # +2 半音 | ||
| 63 | ("pitch_down1", "asetrate=22050*0.9439,aresample=22050"), # -1 半音 | ||
| 64 | ("pitch_down2", "asetrate=22050*0.8909,aresample=22050"), # -2 半音 | ||
| 65 | # Tempo Shift | ||
| 66 | ("tempo_slow", "atempo=0.90"), # 0.9x | ||
| 67 | ("tempo_fast", "atempo=1.10"), # 1.1x | ||
| 68 | ("tempo_faster","atempo=1.20"), # 1.2x | ||
| 69 | # EQ 变换 | ||
| 70 | ("lowpass", "lowpass=f=4000"), # 低通 | ||
| 71 | ("highpass", "highpass=f=800"), # 高通 | ||
| 72 | ("eq_mid", "equalizer=f=1000:width_type=o:width=2:g=-6"), # 中频衰减 | ||
| 73 | # 压缩编码往返(编码为 mp3 再解回 wav,模拟有损压缩引入的失真) | ||
| 74 | ("codec_320k", "acodec=libmp3lame,b:a=320k"), | ||
| 75 | ("codec_128k", "acodec=libmp3lame,b:a=128k"), | ||
| 76 | ] | ||
| 77 | |||
| 78 | # -------------------------------------------------------------------------- | ||
| 79 | # 第三类:困难正样本(可合成部分) | ||
| 80 | # -------------------------------------------------------------------------- | ||
| 81 | HARD_POSITIVE_VARIANTS: list[tuple[str, str]] = [ | ||
| 82 | # 前奏删减:从 20% 处开始截取(模拟删前奏版本) | ||
| 83 | ("trim_intro", None), # 特殊处理,用 -ss 参数 | ||
| 84 | # 只保留副歌:截取中间 40%(模拟短视频截段) | ||
| 85 | ("chorus_only", None), # 特殊处理,用 -ss + -t 参数 | ||
| 86 | # Pitch + Tempo 叠加(模拟 Live 版同时有调整) | ||
| 87 | ("pitch_up1_tempo_slow", "asetrate=22050*1.0595,aresample=22050,atempo=0.92"), | ||
| 88 | ] | ||
| 89 | |||
| 90 | # 负样本变体只使用相对温和的处理,避免把负样本评估变成极端音质测试。 | ||
| 91 | NEGATIVE_VARIANTS: list[tuple[str, str | None]] = [ | ||
| 92 | ("negative_lowpass", "lowpass=f=4000"), | ||
| 93 | ("negative_codec_128k", "acodec=libmp3lame,b:a=128k"), | ||
| 94 | ] | ||
| 95 | |||
| 96 | |||
| 97 | def _ffmpeg_variant(src: Path, dst: Path, af: str) -> bool: | ||
| 98 | """普通 audio filter 变换。""" | ||
| 99 | # 压缩编码往返需要两步:先编码为 mp3,再解回 wav | ||
| 100 | if "acodec" in af: | ||
| 101 | import tempfile | ||
| 102 | with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp: | ||
| 103 | tmp_mp3 = Path(tmp.name) | ||
| 104 | ok1 = _run_ffmpeg([ | ||
| 105 | "ffmpeg", "-y", "-i", str(src), | ||
| 106 | "-ar", "22050", "-ac", "1", | ||
| 107 | "-codec:a", "libmp3lame", "-b:a", af.split("b:a=")[1], | ||
| 108 | str(tmp_mp3), | ||
| 109 | ]) | ||
| 110 | if not ok1: | ||
| 111 | return False | ||
| 112 | ok2 = _run_ffmpeg([ | ||
| 113 | "ffmpeg", "-y", "-i", str(tmp_mp3), | ||
| 114 | "-ar", "22050", "-ac", "1", | ||
| 115 | str(dst), | ||
| 116 | ]) | ||
| 117 | tmp_mp3.unlink(missing_ok=True) | ||
| 118 | return ok2 | ||
| 119 | |||
| 120 | cmd = [ | ||
| 121 | "ffmpeg", "-y", "-i", str(src), | ||
| 122 | "-af", af, | ||
| 123 | "-ar", "22050", "-ac", "1", | ||
| 124 | str(dst), | ||
| 125 | ] | ||
| 126 | return _run_ffmpeg(cmd) | ||
| 127 | |||
| 128 | |||
| 129 | def _ffmpeg_trim(src: Path, dst: Path, start_ratio: float, duration_ratio: float) -> bool: | ||
| 130 | """按相对位置截取片段。需要先探测时长。""" | ||
| 131 | duration = _probe_duration(src) | ||
| 132 | if duration is None: | ||
| 133 | return False | ||
| 134 | ss = duration * start_ratio | ||
| 135 | t = duration * duration_ratio | ||
| 136 | return _run_ffmpeg([ | ||
| 137 | "ffmpeg", "-y", "-i", str(src), | ||
| 138 | "-ss", f"{ss:.3f}", "-t", f"{t:.3f}", | ||
| 139 | "-ar", "22050", "-ac", "1", | ||
| 140 | str(dst), | ||
| 141 | ]) | ||
| 142 | |||
| 143 | |||
| 144 | def _probe_duration(src: Path) -> float | None: | ||
| 145 | result = subprocess.run( | ||
| 146 | ["ffprobe", "-v", "error", "-show_entries", "format=duration", | ||
| 147 | "-of", "default=noprint_wrappers=1:nokey=1", str(src)], | ||
| 148 | capture_output=True, text=True, | ||
| 149 | ) | ||
| 150 | try: | ||
| 151 | return float(result.stdout.strip()) | ||
| 152 | except ValueError: | ||
| 153 | return None | ||
| 154 | |||
| 155 | |||
| 156 | def _run_ffmpeg(cmd: list[str]) -> bool: | ||
| 157 | result = subprocess.run(cmd, capture_output=True) | ||
| 158 | return result.returncode == 0 | ||
| 159 | |||
| 160 | |||
| 161 | def _song_id(path: Path) -> str: | ||
| 162 | return path.stem.split("_")[0] | ||
| 163 | |||
| 164 | |||
| 165 | def _discover_wavs(audio_dir: Path) -> list[Path]: | ||
| 166 | return [ | ||
| 167 | f for f in sorted(audio_dir.rglob("*.wav")) | ||
| 168 | if f.is_file() and not f.name.startswith("._") | ||
| 169 | ] | ||
| 170 | |||
| 171 | |||
| 172 | def main() -> None: | ||
| 173 | logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") | ||
| 174 | |||
| 175 | parser = argparse.ArgumentParser() | ||
| 176 | parser.add_argument("--audio-dir", required=True) | ||
| 177 | parser.add_argument( | ||
| 178 | "--negative-audio-dir", | ||
| 179 | default="/Volumes/移动硬盘/lyric_audio_type11", | ||
| 180 | help="负样本来源目录;会排除 --audio-dir 中已存在的 song_id", | ||
| 181 | ) | ||
| 182 | parser.add_argument("--out-dir", required=True) | ||
| 183 | parser.add_argument("--num-songs", type=int, default=20, help="抽取歌曲数量") | ||
| 184 | parser.add_argument("--num-negative-songs", type=int, default=20, help="抽取未入库负样本歌曲数量") | ||
| 185 | parser.add_argument( | ||
| 186 | "--negative-variants", | ||
| 187 | action="store_true", | ||
| 188 | help="为负样本额外生成 codec/lowpass 变体", | ||
| 189 | ) | ||
| 190 | parser.add_argument("--seed", type=int, default=42) | ||
| 191 | args = parser.parse_args() | ||
| 192 | |||
| 193 | audio_dir = Path(args.audio_dir) | ||
| 194 | negative_audio_dir = Path(args.negative_audio_dir) | ||
| 195 | out_dir = Path(args.out_dir) | ||
| 196 | out_dir.mkdir(parents=True, exist_ok=True) | ||
| 197 | variants_dir = out_dir / "variants" | ||
| 198 | variants_dir.mkdir(exist_ok=True) | ||
| 199 | |||
| 200 | all_wavs = _discover_wavs(audio_dir) | ||
| 201 | negative_wavs = _discover_wavs(negative_audio_dir) | ||
| 202 | |||
| 203 | if len(negative_wavs) < args.num_negative_songs: | ||
| 204 | logger.error( | ||
| 205 | "负样本目录下只有 %d 个 wav,少于 --num-negative-songs = %d", | ||
| 206 | len(negative_wavs), | ||
| 207 | args.num_negative_songs, | ||
| 208 | ) | ||
| 209 | sys.exit(1) | ||
| 210 | |||
| 211 | # 从参照目录中排除负样本目录已有的 song_id,避免参照曲与负样本重叠 | ||
| 212 | negative_song_ids = {_song_id(wav) for wav in negative_wavs} | ||
| 213 | all_wavs = [wav for wav in all_wavs if _song_id(wav) not in negative_song_ids] | ||
| 214 | |||
| 215 | if len(all_wavs) < args.num_songs: | ||
| 216 | logger.error( | ||
| 217 | "参照目录排除负样本 song_id 后只有 %d 个 wav,少于 --num-songs = %d", | ||
| 218 | len(all_wavs), | ||
| 219 | args.num_songs, | ||
| 220 | ) | ||
| 221 | sys.exit(1) | ||
| 222 | |||
| 223 | random.seed(args.seed) | ||
| 224 | selected = random.sample(all_wavs, args.num_songs) | ||
| 225 | negative_selected = random.sample(negative_wavs, args.num_negative_songs) | ||
| 226 | logger.info( | ||
| 227 | "已抽取 %d 首参照歌,%d 首未入库负样本歌(负样本来源: %s,已排除 %d 个负样本 song_id)", | ||
| 228 | len(selected), | ||
| 229 | len(negative_selected), | ||
| 230 | negative_audio_dir, | ||
| 231 | len(negative_song_ids), | ||
| 232 | ) | ||
| 233 | |||
| 234 | ref_rows = [] | ||
| 235 | query_rows = [] | ||
| 236 | |||
| 237 | for wav in _tqdm(selected, desc="生成正样本变体", total=len(selected)): | ||
| 238 | song_id = _song_id(wav) | ||
| 239 | |||
| 240 | ref_rows.append({ | ||
| 241 | "song_id": song_id, | ||
| 242 | "audio_path": str(wav), | ||
| 243 | "variant": "original", | ||
| 244 | }) | ||
| 245 | |||
| 246 | # 第一类:DSP 变换 | ||
| 247 | for variant_name, af in DSP_VARIANTS: | ||
| 248 | dst = variants_dir / f"{song_id}_{variant_name}.wav" | ||
| 249 | ok = _ffmpeg_variant(wav, dst, af) | ||
| 250 | if not ok: | ||
| 251 | logger.warning("DSP 变换失败,跳过: %s %s", wav.name, variant_name) | ||
| 252 | continue | ||
| 253 | query_rows.append({ | ||
| 254 | "song_id": song_id, | ||
| 255 | "audio_path": str(dst), | ||
| 256 | "variant": variant_name, | ||
| 257 | "sample_class": "dsp", | ||
| 258 | "expected_song_id": song_id, | ||
| 259 | "expected": "duplicate", | ||
| 260 | }) | ||
| 261 | |||
| 262 | # 第三类:困难正样本 | ||
| 263 | for variant_name, af in HARD_POSITIVE_VARIANTS: | ||
| 264 | dst = variants_dir / f"{song_id}_{variant_name}.wav" | ||
| 265 | if variant_name == "trim_intro": | ||
| 266 | ok = _ffmpeg_trim(wav, dst, start_ratio=0.20, duration_ratio=0.80) | ||
| 267 | elif variant_name == "chorus_only": | ||
| 268 | ok = _ffmpeg_trim(wav, dst, start_ratio=0.30, duration_ratio=0.40) | ||
| 269 | else: | ||
| 270 | ok = _ffmpeg_variant(wav, dst, af) | ||
| 271 | if not ok: | ||
| 272 | logger.warning("困难正样本生成失败,跳过: %s %s", wav.name, variant_name) | ||
| 273 | continue | ||
| 274 | query_rows.append({ | ||
| 275 | "song_id": song_id, | ||
| 276 | "audio_path": str(dst), | ||
| 277 | "variant": variant_name, | ||
| 278 | "sample_class": "hard_positive", | ||
| 279 | "expected_song_id": song_id, | ||
| 280 | "expected": "duplicate", | ||
| 281 | }) | ||
| 282 | |||
| 283 | # Boolean 接口负样本:查询音频不能在 reference.csv 入库集合中。 | ||
| 284 | # expected_song_id 留空,表示没有目标重复曲;评测只看最终 duplicate true/false。 | ||
| 285 | for wav in _tqdm(negative_selected, desc="生成负样本变体", total=len(negative_selected)): | ||
| 286 | song_id = _song_id(wav) | ||
| 287 | query_rows.append({ | ||
| 288 | "song_id": song_id, | ||
| 289 | "audio_path": str(wav), | ||
| 290 | "variant": "negative_original", | ||
| 291 | "sample_class": "negative", | ||
| 292 | "expected_song_id": "", | ||
| 293 | "expected": "not_duplicate", | ||
| 294 | }) | ||
| 295 | |||
| 296 | if not args.negative_variants: | ||
| 297 | continue | ||
| 298 | |||
| 299 | for variant_name, af in NEGATIVE_VARIANTS: | ||
| 300 | dst = variants_dir / f"{song_id}_{variant_name}.wav" | ||
| 301 | ok = _ffmpeg_variant(wav, dst, af) | ||
| 302 | if not ok: | ||
| 303 | logger.warning("负样本变换失败,跳过: %s %s", wav.name, variant_name) | ||
| 304 | continue | ||
| 305 | query_rows.append({ | ||
| 306 | "song_id": song_id, | ||
| 307 | "audio_path": str(dst), | ||
| 308 | "variant": variant_name, | ||
| 309 | "sample_class": "negative", | ||
| 310 | "expected_song_id": "", | ||
| 311 | "expected": "not_duplicate", | ||
| 312 | }) | ||
| 313 | |||
| 314 | ref_path = out_dir / "reference.csv" | ||
| 315 | query_path = out_dir / "queries.csv" | ||
| 316 | |||
| 317 | fieldnames = ["song_id", "audio_path", "variant", "sample_class", "expected_song_id", "expected"] | ||
| 318 | with ref_path.open("w", newline="", encoding="utf-8") as f: | ||
| 319 | writer = csv.DictWriter(f, fieldnames=["song_id", "audio_path", "variant"]) | ||
| 320 | writer.writeheader() | ||
| 321 | writer.writerows(ref_rows) | ||
| 322 | |||
| 323 | with query_path.open("w", newline="", encoding="utf-8") as f: | ||
| 324 | writer = csv.DictWriter(f, fieldnames=fieldnames) | ||
| 325 | writer.writeheader() | ||
| 326 | writer.writerows(query_rows) | ||
| 327 | |||
| 328 | pos = sum(1 for r in query_rows if r["expected"] == "duplicate") | ||
| 329 | neg = sum(1 for r in query_rows if r["expected"] == "not_duplicate") | ||
| 330 | logger.info("参照集: %s (%d 条)", ref_path, len(ref_rows)) | ||
| 331 | logger.info("查询集: %s (%d 条,正样本 %d,负样本 %d)", query_path, len(query_rows), pos, neg) | ||
| 332 | |||
| 333 | # 按 sample_class 统计 | ||
| 334 | from collections import Counter | ||
| 335 | by_class = Counter(r["sample_class"] for r in query_rows) | ||
| 336 | for cls, cnt in sorted(by_class.items()): | ||
| 337 | logger.info(" %-20s %d 条", cls, cnt) | ||
| 338 | |||
| 339 | |||
| 340 | if __name__ == "__main__": | ||
| 341 | main() |
scripts/import_audio_composition.py
0 → 100644
| 1 | """批量导入音频文件到 composition_feature 表。 | ||
| 2 | |||
| 3 | 用法: | ||
| 4 | python scripts/import_audio_composition.py \ | ||
| 5 | --dsn "postgresql:///lyric_dedup" \ | ||
| 6 | --audio-dir /Volumes/移动硬盘/composition_test \ | ||
| 7 | --ext .wav | ||
| 8 | |||
| 9 | 支持通过 --file-list 指定一个包含音频路径的文本文件(每行一个路径)。 | ||
| 10 | """ | ||
| 11 | |||
| 12 | import argparse | ||
| 13 | import logging | ||
| 14 | import sys | ||
| 15 | from pathlib import Path | ||
| 16 | |||
| 17 | sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) | ||
| 18 | |||
| 19 | from tqdm import tqdm | ||
| 20 | |||
| 21 | from composition_dedup.service import CompositionConfig, CompositionDedupService | ||
| 22 | |||
| 23 | logger = logging.getLogger(__name__) | ||
| 24 | |||
| 25 | SUPPORTED_EXTENSIONS = {".mp3", ".wav", ".flac", ".ogg", ".m4a", ".aac", ".wma"} | ||
| 26 | |||
| 27 | |||
| 28 | def discover_audio_files(audio_dir: str | None, file_list: str | None, ext: str) -> list[tuple[str, str]]: | ||
| 29 | """发现音频文件,返回 [(song_id, audio_path), ...] 列表。 | ||
| 30 | |||
| 31 | 优先使用 --file-list,否则扫描 --audio-dir 目录。 | ||
| 32 | song_id 使用文件名的数字部分或路径的哈希值。 | ||
| 33 | """ | ||
| 34 | results = [] | ||
| 35 | |||
| 36 | if file_list: | ||
| 37 | with open(file_list, "r", encoding="utf-8") as f: | ||
| 38 | for line in f: | ||
| 39 | path = line.strip() | ||
| 40 | if not path: | ||
| 41 | continue | ||
| 42 | song_id = _extract_song_id(path) | ||
| 43 | results.append((song_id, path)) | ||
| 44 | elif audio_dir: | ||
| 45 | audio_dir_path = Path(audio_dir) | ||
| 46 | for audio_file in sorted(audio_dir_path.rglob(f"*{ext}")): | ||
| 47 | if audio_file.is_file() and not audio_file.name.startswith("._"): | ||
| 48 | song_id = _extract_song_id(str(audio_file)) | ||
| 49 | results.append((song_id, str(audio_file))) | ||
| 50 | else: | ||
| 51 | print("错误: 请指定 --audio-dir 或 --file-list") | ||
| 52 | sys.exit(1) | ||
| 53 | |||
| 54 | return results | ||
| 55 | |||
| 56 | |||
| 57 | def _extract_song_id(path: str) -> str: | ||
| 58 | """从路径中提取 song_id。 | ||
| 59 | 优先取文件名第一段(下划线前),若为纯数字则使用,否则用路径哈希。 | ||
| 60 | """ | ||
| 61 | name = Path(path).stem | ||
| 62 | prefix = name.split("_")[0] | ||
| 63 | if prefix.isdigit(): | ||
| 64 | return prefix | ||
| 65 | import hashlib | ||
| 66 | return str(int(hashlib.md5(path.encode()).hexdigest()[:8], 16)) | ||
| 67 | |||
| 68 | |||
| 69 | def main() -> None: | ||
| 70 | logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") | ||
| 71 | |||
| 72 | parser = argparse.ArgumentParser(description="批量导入音频文件到 composition_feature 表") | ||
| 73 | parser.add_argument("--dsn", required=True, help="PostgreSQL DSN 连接串") | ||
| 74 | parser.add_argument("--audio-dir", help="音频文件目录") | ||
| 75 | parser.add_argument("--file-list", help="音频文件路径列表文件") | ||
| 76 | parser.add_argument("--ext", default=".wav", help="音频文件扩展名(默认 .wav)") | ||
| 77 | parser.add_argument("--batch-size", type=int, default=10, help="批次大小(默认 10)") | ||
| 78 | parser.add_argument("--clear", action="store_true", help="导入前清空 composition_feature 和 dejavu_fingerprints 表数据(保留表结构)") | ||
| 79 | args = parser.parse_args() | ||
| 80 | |||
| 81 | config = CompositionConfig(dsn=args.dsn) | ||
| 82 | service = CompositionDedupService(config=config) | ||
| 83 | |||
| 84 | if args.clear: | ||
| 85 | import psycopg | ||
| 86 | with psycopg.connect(args.dsn) as conn: | ||
| 87 | with conn.cursor() as cur: | ||
| 88 | cur.execute("TRUNCATE TABLE composition_feature, dejavu_fingerprints") | ||
| 89 | conn.commit() | ||
| 90 | logger.info("已清空 composition_feature 和 dejavu_fingerprints 表") | ||
| 91 | |||
| 92 | audio_files = discover_audio_files(args.audio_dir, args.file_list, args.ext) | ||
| 93 | logger.info("发现 %d 个音频文件", len(audio_files)) | ||
| 94 | |||
| 95 | success_count = 0 | ||
| 96 | fail_count = 0 | ||
| 97 | |||
| 98 | for start in tqdm(range(0, len(audio_files), args.batch_size), desc="导入进度"): | ||
| 99 | batch = audio_files[start:start + args.batch_size] | ||
| 100 | for song_id, audio_path in batch: | ||
| 101 | try: | ||
| 102 | service.ingest(song_id=int(song_id), audio_path=audio_path) | ||
| 103 | success_count += 1 | ||
| 104 | except Exception as e: | ||
| 105 | logger.error("导入失败: song_id=%s, path=%s, error=%s", song_id, audio_path, e) | ||
| 106 | fail_count += 1 | ||
| 107 | |||
| 108 | logger.info("导入完成: 成功 %d, 失败 %d", success_count, fail_count) | ||
| 109 | |||
| 110 | |||
| 111 | if __name__ == "__main__": | ||
| 112 | main() |
| ... | @@ -40,3 +40,33 @@ on lyric_lines (line_hash); | ... | @@ -40,3 +40,33 @@ on lyric_lines (line_hash); |
| 40 | 40 | ||
| 41 | create index if not exists lyric_lines_lyric_id_idx | 41 | create index if not exists lyric_lines_lyric_id_idx |
| 42 | on lyric_lines (lyric_id); | 42 | on lyric_lines (lyric_id); |
| 43 | |||
| 44 | create extension if not exists vector; | ||
| 45 | |||
| 46 | create table if not exists composition_feature ( | ||
| 47 | id bigserial primary key, | ||
| 48 | song_id bigint not null unique, | ||
| 49 | feature_vector vector(1536) not null, | ||
| 50 | created_at timestamptz not null default now() | ||
| 51 | ); | ||
| 52 | |||
| 53 | create index if not exists composition_feature_hnsw_idx | ||
| 54 | on composition_feature | ||
| 55 | using hnsw (feature_vector vector_cosine_ops) | ||
| 56 | with (m = 16, ef_construction = 64); | ||
| 57 | |||
| 58 | create table if not exists dejavu_fingerprints ( | ||
| 59 | id bigserial primary key, | ||
| 60 | song_id bigint not null references composition_feature(song_id) on delete cascade, | ||
| 61 | hash bytea not null, | ||
| 62 | "offset" int not null | ||
| 63 | ); | ||
| 64 | |||
| 65 | create index if not exists idx_fingerprints_hash | ||
| 66 | on dejavu_fingerprints (hash); | ||
| 67 | |||
| 68 | create index if not exists idx_fingerprints_hash_song_offset | ||
| 69 | on dejavu_fingerprints (hash, song_id, "offset"); | ||
| 70 | |||
| 71 | create index if not exists idx_fingerprints_song_id | ||
| 72 | on dejavu_fingerprints (song_id); | ... | ... |
tests/test_composition_dedup.py
0 → 100644
| 1 | """作曲去重模块测试。 | ||
| 2 | |||
| 3 | 测试覆盖: | ||
| 4 | - Chromagram 提取 | ||
| 5 | - 时间归一化输出维度 | ||
| 6 | - Cosine 相似度计算 | ||
| 7 | - 向量展开维度为 1536 | ||
| 8 | """ | ||
| 9 | |||
| 10 | import os | ||
| 11 | import tempfile | ||
| 12 | import wave | ||
| 13 | |||
| 14 | import numpy as np | ||
| 15 | import pytest | ||
| 16 | from scipy.signal import resample | ||
| 17 | |||
| 18 | from composition_dedup.extractor import extract_chroma_feature, _normalize_audio_ffmpeg | ||
| 19 | from composition_dedup.similarity import ( | ||
| 20 | CompositionSimilarity, | ||
| 21 | SimilarityDecision, | ||
| 22 | DUPLICATE_THRESHOLD, | ||
| 23 | SUSPECTED_THRESHOLD, | ||
| 24 | ) | ||
| 25 | |||
| 26 | |||
| 27 | def _generate_test_wav(duration_sec: float = 1.0, sample_rate: int = 22050, frequency: float = 440.0) -> str: | ||
| 28 | """生成测试用的 WAV 文件(正弦波)。 | ||
| 29 | |||
| 30 | Args: | ||
| 31 | duration_sec: 持续时间(秒)。 | ||
| 32 | sample_rate: 采样率。 | ||
| 33 | frequency: 频率(Hz)。 | ||
| 34 | |||
| 35 | Returns: | ||
| 36 | 临时 WAV 文件路径。 | ||
| 37 | """ | ||
| 38 | t = np.linspace(0, duration_sec, int(sample_rate * duration_sec), endpoint=False) | ||
| 39 | audio_data = (0.5 * np.sin(2 * np.pi * frequency * t)).astype(np.float32) | ||
| 40 | |||
| 41 | tmp_path = tempfile.mktemp(suffix=".wav") | ||
| 42 | with wave.open(tmp_path, "wb") as wf: | ||
| 43 | wf.setnchannels(1) | ||
| 44 | wf.setsampwidth(2) # 16-bit | ||
| 45 | wf.setframerate(sample_rate) | ||
| 46 | wf.writeframes((audio_data * 32767).astype(np.int16).tobytes()) | ||
| 47 | |||
| 48 | return tmp_path | ||
| 49 | |||
| 50 | |||
| 51 | class TestChromaExtraction: | ||
| 52 | """Chromagram 提取测试。""" | ||
| 53 | |||
| 54 | def test_extract_chroma_returns_1536_dim(self): | ||
| 55 | """测试 Chromagram 提取返回 1536 维向量。""" | ||
| 56 | wav_path = _generate_test_wav(duration_sec=2.0, frequency=440.0) | ||
| 57 | try: | ||
| 58 | feature = extract_chroma_feature(wav_path) | ||
| 59 | assert isinstance(feature, np.ndarray) | ||
| 60 | assert feature.shape == (1536,), f"期望 (1536,), 实际 {feature.shape}" | ||
| 61 | assert feature.dtype == np.float32 | ||
| 62 | finally: | ||
| 63 | if os.path.exists(wav_path): | ||
| 64 | os.remove(wav_path) | ||
| 65 | |||
| 66 | def test_extract_chroma_file_not_found(self): | ||
| 67 | """测试不存在的音频文件抛出 FileNotFoundError。""" | ||
| 68 | with pytest.raises(FileNotFoundError): | ||
| 69 | extract_chroma_feature("/nonexistent/path/audio.mp3") | ||
| 70 | |||
| 71 | def test_extract_chroma_different_frequencies(self): | ||
| 72 | """测试不同频率的音频产生不同特征。""" | ||
| 73 | wav_a = _generate_test_wav(duration_sec=2.0, frequency=440.0) | ||
| 74 | wav_b = _generate_test_wav(duration_sec=2.0, frequency=880.0) | ||
| 75 | try: | ||
| 76 | feature_a = extract_chroma_feature(wav_a) | ||
| 77 | feature_b = extract_chroma_feature(wav_b) | ||
| 78 | # 不同频率的音频特征不应完全相同 | ||
| 79 | assert not np.allclose(feature_a, feature_b) | ||
| 80 | finally: | ||
| 81 | for path in [wav_a, wav_b]: | ||
| 82 | if os.path.exists(path): | ||
| 83 | os.remove(path) | ||
| 84 | |||
| 85 | def test_extract_chroma_same_audio_consistent(self): | ||
| 86 | """测试同一音频多次提取结果一致。""" | ||
| 87 | wav_path = _generate_test_wav(duration_sec=1.0, frequency=440.0) | ||
| 88 | try: | ||
| 89 | feature_1 = extract_chroma_feature(wav_path) | ||
| 90 | feature_2 = extract_chroma_feature(wav_path) | ||
| 91 | np.testing.assert_array_almost_equal(feature_1, feature_2, decimal=5) | ||
| 92 | finally: | ||
| 93 | if os.path.exists(wav_path): | ||
| 94 | os.remove(wav_path) | ||
| 95 | |||
| 96 | |||
| 97 | class TestTimeNormalization: | ||
| 98 | """时间归一化测试。""" | ||
| 99 | |||
| 100 | def test_resample_chroma_to_128_frames(self): | ||
| 101 | """测试 Chromagram 时间归一化到 128 帧。""" | ||
| 102 | # 模拟不同长度的 Chromagram | ||
| 103 | for num_frames in [100, 256, 512, 1000, 2000]: | ||
| 104 | chroma = np.random.rand(12, num_frames).astype(np.float32) | ||
| 105 | if chroma.shape[1] != 128: | ||
| 106 | chroma = resample(chroma, 128, axis=1) | ||
| 107 | assert chroma.shape == (12, 128), f"帧数归一化失败: {chroma.shape}" | ||
| 108 | |||
| 109 | def test_flatten_to_1536(self): | ||
| 110 | """测试展平后维度为 1536。""" | ||
| 111 | chroma = np.random.rand(12, 128).astype(np.float32) | ||
| 112 | feature = chroma.flatten() | ||
| 113 | assert feature.shape[0] == 12 * 128 == 1536 | ||
| 114 | |||
| 115 | |||
| 116 | class TestCosineSimilarity: | ||
| 117 | """Cosine 相似度计算测试。""" | ||
| 118 | |||
| 119 | def test_identical_vectors(self): | ||
| 120 | """测试相同向量相似度为 1。""" | ||
| 121 | vec = np.random.rand(1536).astype(np.float32) | ||
| 122 | sim = CompositionSimilarity.cosine_similarity(vec, vec) | ||
| 123 | assert abs(sim - 1.0) < 1e-6 | ||
| 124 | |||
| 125 | def test_orthogonal_vectors(self): | ||
| 126 | """测试正交向量相似度接近 0。""" | ||
| 127 | vec_a = np.zeros(1536) | ||
| 128 | vec_a[0] = 1.0 | ||
| 129 | vec_b = np.zeros(1536) | ||
| 130 | vec_b[1] = 1.0 | ||
| 131 | sim = CompositionSimilarity.cosine_similarity(vec_a, vec_b) | ||
| 132 | assert abs(sim) < 1e-6 | ||
| 133 | |||
| 134 | def test_zero_vector(self): | ||
| 135 | """测试零向量返回 0 相似度。""" | ||
| 136 | vec_a = np.random.rand(1536).astype(np.float32) | ||
| 137 | vec_b = np.zeros(1536) | ||
| 138 | sim = CompositionSimilarity.cosine_similarity(vec_a, vec_b) | ||
| 139 | assert sim == 0.0 | ||
| 140 | |||
| 141 | def test_similarity_range(self): | ||
| 142 | """测试相似度值在 [0, 1] 范围内。""" | ||
| 143 | vec_a = np.random.rand(1536).astype(np.float32) | ||
| 144 | vec_b = np.random.rand(1536).astype(np.float32) | ||
| 145 | sim = CompositionSimilarity.cosine_similarity(vec_a, vec_b) | ||
| 146 | assert 0.0 <= sim <= 1.0 | ||
| 147 | |||
| 148 | def test_classify_duplicate(self): | ||
| 149 | """测试重复判定。""" | ||
| 150 | assert CompositionSimilarity.classify_similarity(0.96) == SimilarityDecision.DUPLICATE | ||
| 151 | assert CompositionSimilarity.classify_similarity(0.95) == SimilarityDecision.DUPLICATE | ||
| 152 | |||
| 153 | def test_classify_suspected(self): | ||
| 154 | """测试疑似判定。""" | ||
| 155 | assert CompositionSimilarity.classify_similarity(0.94) == SimilarityDecision.SUSPECTED | ||
| 156 | assert CompositionSimilarity.classify_similarity(0.85) == SimilarityDecision.SUSPECTED | ||
| 157 | |||
| 158 | def test_classify_new(self): | ||
| 159 | """测试非重复判定。""" | ||
| 160 | assert CompositionSimilarity.classify_similarity(0.84) == SimilarityDecision.NEW | ||
| 161 | assert CompositionSimilarity.classify_similarity(0.5) == SimilarityDecision.NEW | ||
| 162 | |||
| 163 | def test_compare_method(self): | ||
| 164 | """测试 compare 方法同时返回相似度和判定。""" | ||
| 165 | vec = np.random.rand(1536).astype(np.float32) | ||
| 166 | sim, decision = CompositionSimilarity.compare(vec, vec) | ||
| 167 | assert abs(sim - 1.0) < 1e-6 | ||
| 168 | assert decision == SimilarityDecision.DUPLICATE | ||
| 169 | |||
| 170 | |||
| 171 | class TestThresholds: | ||
| 172 | """阈值常量测试。""" | ||
| 173 | |||
| 174 | def test_threshold_order(self): | ||
| 175 | """测试阈值顺序正确。""" | ||
| 176 | assert DUPLICATE_THRESHOLD > SUSPECTED_THRESHOLD | ||
| 177 | |||
| 178 | def test_threshold_values(self): | ||
| 179 | """测试阈值符合设计值。""" | ||
| 180 | assert DUPLICATE_THRESHOLD == 0.95 | ||
| 181 | assert SUSPECTED_THRESHOLD == 0.85 |
-
Please register or sign in to post a comment