优化性能
Showing
8 changed files
with
336 additions
and
151 deletions
| ... | @@ -26,4 +26,6 @@ venv/ | ... | @@ -26,4 +26,6 @@ venv/ |
| 26 | 26 | ||
| 27 | test_api | 27 | test_api |
| 28 | 28 | ||
| 29 | composition_dedup/composition_eval | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
| 29 | composition_dedup/composition_eval | ||
| 30 | |||
| 31 | composition_testset | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file | ... | ... |
| ... | @@ -22,7 +22,6 @@ from pathlib import Path | ... | @@ -22,7 +22,6 @@ from pathlib import Path |
| 22 | import librosa | 22 | import librosa |
| 23 | import numpy as np | 23 | import numpy as np |
| 24 | from scipy.ndimage import ( | 24 | from scipy.ndimage import ( |
| 25 | binary_erosion, | ||
| 26 | generate_binary_structure, | 25 | generate_binary_structure, |
| 27 | iterate_structure, | 26 | iterate_structure, |
| 28 | maximum_filter, | 27 | maximum_filter, |
| ... | @@ -60,10 +59,10 @@ MIN_HASH_TIME_DELTA = 0 | ... | @@ -60,10 +59,10 @@ MIN_HASH_TIME_DELTA = 0 |
| 60 | MAX_HASH_TIME_DELTA = 200 | 59 | MAX_HASH_TIME_DELTA = 200 |
| 61 | PEAK_SORT = True | 60 | PEAK_SORT = True |
| 62 | FINGERPRINT_REDUCTION = 20 | 61 | FINGERPRINT_REDUCTION = 20 |
| 63 | MAX_DURATION_SEC = float(os.environ.get("COMPOSITION_DEJAVU_MAX_DURATION", "120")) # 0=不限制 | 62 | QUERY_MAX_DURATION_SEC = float(os.environ.get("COMPOSITION_DEJAVU_QUERY_MAX_DURATION", "120")) # 0=不限制 |
| 64 | 63 | ||
| 65 | 64 | ||
| 66 | def _normalize_audio(audio_path: str, max_duration: float = MAX_DURATION_SEC) -> tuple[np.ndarray, int]: | 65 | def _normalize_audio(audio_path: str, max_duration: float = 0) -> tuple[np.ndarray, int]: |
| 67 | """将音频标准化为单声道 WAV 并加载为 numpy 数组。 | 66 | """将音频标准化为单声道 WAV 并加载为 numpy 数组。 |
| 68 | 67 | ||
| 69 | 使用 ffmpeg 先做重采样,再用 librosa 读取。 | 68 | 使用 ffmpeg 先做重采样,再用 librosa 读取。 |
| ... | @@ -133,12 +132,7 @@ def _get_2d_peaks(arr2D: np.ndarray, amp_min: float = DEFAULT_AMP_MIN): | ... | @@ -133,12 +132,7 @@ def _get_2d_peaks(arr2D: np.ndarray, amp_min: float = DEFAULT_AMP_MIN): |
| 133 | neighborhood = iterate_structure(struct, PEAK_NEIGHBORHOOD_SIZE) | 132 | neighborhood = iterate_structure(struct, PEAK_NEIGHBORHOOD_SIZE) |
| 134 | 133 | ||
| 135 | # 找局部极大值 | 134 | # 找局部极大值 |
| 136 | local_max = maximum_filter(arr2D, footprint=neighborhood) == arr2D | 135 | detected_peaks = maximum_filter(arr2D, footprint=neighborhood) == arr2D |
| 137 | background = arr2D == 0 | ||
| 138 | eroded_background = binary_erosion(background, structure=neighborhood, border_value=1) | ||
| 139 | |||
| 140 | # 布尔掩码 | ||
| 141 | detected_peaks = local_max ^ eroded_background | ||
| 142 | 136 | ||
| 143 | # 提取峰值 | 137 | # 提取峰值 |
| 144 | amps = arr2D[detected_peaks] | 138 | amps = arr2D[detected_peaks] |
| ... | @@ -181,6 +175,43 @@ def _generate_hashes(peaks: list[tuple[int, int]], fan_value: int = DEFAULT_FAN_ | ... | @@ -181,6 +175,43 @@ def _generate_hashes(peaks: list[tuple[int, int]], fan_value: int = DEFAULT_FAN_ |
| 181 | yield (h.hexdigest()[:FINGERPRINT_REDUCTION].encode(), t1) | 175 | yield (h.hexdigest()[:FINGERPRINT_REDUCTION].encode(), t1) |
| 182 | 176 | ||
| 183 | 177 | ||
| 178 | def load_audio(audio_path: str, max_duration: float = 0) -> tuple[np.ndarray, int]: | ||
| 179 | """加载并标准化音频为 44100Hz 单声道(供多路径共用,避免重复解码)。 | ||
| 180 | |||
| 181 | Args: | ||
| 182 | audio_path: 音频文件路径。 | ||
| 183 | max_duration: 最大截取时长(秒),0 表示不限制。 | ||
| 184 | |||
| 185 | Returns: | ||
| 186 | (samples, sr) 元组。 | ||
| 187 | """ | ||
| 188 | return _normalize_audio(audio_path, max_duration) | ||
| 189 | |||
| 190 | |||
| 191 | def fingerprint_from_samples( | ||
| 192 | samples: np.ndarray, sr: int, *, compute_sha1: bool = True | ||
| 193 | ) -> tuple[str, list[tuple[bytes, int]]]: | ||
| 194 | """对已加载的音频样本生成 Dejavu 风格指纹(不做 I/O)。 | ||
| 195 | |||
| 196 | Args: | ||
| 197 | samples: 单声道音频样本(应为 DEFAULT_FS=44100Hz)。 | ||
| 198 | sr: 采样率。 | ||
| 199 | compute_sha1: 是否计算 file_sha1。service 内部调用时传 False 可跳过 | ||
| 200 | 对 samples.tobytes() 的 21MB 哈希运算(返回值在那些路径中未被使用)。 | ||
| 201 | |||
| 202 | Returns: | ||
| 203 | (file_sha1, fingerprints) 元组, | ||
| 204 | 其中 fingerprints 是 [(hash_bytes, offset), ...] 列表。 | ||
| 205 | compute_sha1=False 时 file_sha1 返回空字符串。 | ||
| 206 | """ | ||
| 207 | file_sha1 = hashlib.sha1(samples.tobytes()).hexdigest()[:16] if compute_sha1 else "" | ||
| 208 | arr2D = _specgram(samples, sr, DEFAULT_WINDOW_SIZE, DEFAULT_OVERLAP_RATIO) | ||
| 209 | freq_idx, time_idx = _get_2d_peaks(arr2D) | ||
| 210 | peaks = list(zip(freq_idx, time_idx)) | ||
| 211 | fingerprints = list(_generate_hashes(peaks)) | ||
| 212 | return file_sha1, fingerprints | ||
| 213 | |||
| 214 | |||
| 184 | def fingerprint_audio(audio_path: str) -> tuple[str, list[tuple[bytes, int]]]: | 215 | def fingerprint_audio(audio_path: str) -> tuple[str, list[tuple[bytes, int]]]: |
| 185 | """对音频文件生成 Dejavu 风格指纹。 | 216 | """对音频文件生成 Dejavu 风格指纹。 |
| 186 | 217 | ||
| ... | @@ -198,21 +229,7 @@ def fingerprint_audio(audio_path: str) -> tuple[str, list[tuple[bytes, int]]]: | ... | @@ -198,21 +229,7 @@ def fingerprint_audio(audio_path: str) -> tuple[str, list[tuple[bytes, int]]]: |
| 198 | if not os.path.isfile(audio_path): | 229 | if not os.path.isfile(audio_path): |
| 199 | raise FileNotFoundError(f"音频文件不存在: {audio_path}") | 230 | raise FileNotFoundError(f"音频文件不存在: {audio_path}") |
| 200 | 231 | ||
| 201 | # 1. 标准化并加载音频(可选限制长度) | ||
| 202 | samples, fs = _normalize_audio(audio_path) | 232 | samples, fs = _normalize_audio(audio_path) |
| 203 | 233 | file_sha1, fingerprints = fingerprint_from_samples(samples, fs) | |
| 204 | # 2. 计算文件 SHA1(用于标识) | ||
| 205 | file_sha1 = hashlib.sha1(samples.tobytes()).hexdigest()[:16] | ||
| 206 | |||
| 207 | # 3. 计算频谱图 | ||
| 208 | arr2D = _specgram(samples, fs, DEFAULT_WINDOW_SIZE, DEFAULT_OVERLAP_RATIO) | ||
| 209 | |||
| 210 | # 4. 检测 2D 峰值 | ||
| 211 | freq_idx, time_idx = _get_2d_peaks(arr2D) | ||
| 212 | peaks = list(zip(freq_idx, time_idx)) | ||
| 213 | |||
| 214 | # 5. 生成指纹哈希 | ||
| 215 | fingerprints = list(_generate_hashes(peaks)) | ||
| 216 | |||
| 217 | logger.info("指纹生成完成: audio=%s, 指纹数=%d", audio_path, len(fingerprints)) | 234 | logger.info("指纹生成完成: audio=%s, 指纹数=%d", audio_path, len(fingerprints)) |
| 218 | return file_sha1, fingerprints | 235 | return file_sha1, fingerprints | ... | ... |
| 1 | """Chromagram 特征提取。 | 1 | """Chromagram 特征提取。 |
| 2 | 2 | ||
| 3 | 流程: | 3 | 流程: |
| 4 | 1. 音频标准化:ffmpeg 转 22050Hz / Mono / WAV | 4 | 1. 音频解码:ffmpeg pipe 输出 22050Hz / Mono / f32le,直接读入内存,无临时文件 |
| 5 | 2. librosa 加载音频 | 5 | 2. librosa.feature.chroma_cens() 提取 12×T Chromagram(CENS,对速度/音色鲁棒) |
| 6 | 3. librosa.feature.chroma_cens() 提取 12×T Chromagram(CENS,对速度/音色鲁棒) | 6 | 3. 主音对齐:将能量最大的音级滚至第 0 行,实现转调不变性 |
| 7 | 4. 主音对齐:将能量最大的音级滚至第 0 行,实现转调不变性 | 7 | 4. scipy.signal.resample(chroma, 128, axis=1) 时间归一化到 12×128 |
| 8 | 5. scipy.signal.resample(chroma, 128, axis=1) 时间归一化到 12×128 | 8 | 5. .flatten() 展开为 1536 维向量 |
| 9 | 6. .flatten() 展开为 1536 维向量 | ||
| 10 | """ | 9 | """ |
| 11 | 10 | ||
| 12 | import logging | 11 | import logging |
| 13 | import os | 12 | import os |
| 14 | import subprocess | 13 | import subprocess |
| 15 | import tempfile | ||
| 16 | 14 | ||
| 17 | import librosa | 15 | import librosa |
| 18 | import numpy as np | 16 | import numpy as np |
| ... | @@ -26,83 +24,99 @@ TARGET_FRAMES = 128 | ... | @@ -26,83 +24,99 @@ TARGET_FRAMES = 128 |
| 26 | VECTOR_DIM = 12 * TARGET_FRAMES # 1536 | 24 | VECTOR_DIM = 12 * TARGET_FRAMES # 1536 |
| 27 | 25 | ||
| 28 | 26 | ||
| 29 | def _normalize_audio_ffmpeg(audio_path: str, output_path: str) -> None: | 27 | def _load_audio_via_pipe(audio_path: str) -> np.ndarray: |
| 30 | """使用 ffmpeg 将音频标准化为 22050Hz / Mono / WAV。""" | 28 | """使用 ffmpeg pipe 将音频解码为 22050Hz mono float32,不落临时文件到磁盘。""" |
| 31 | cmd = [ | 29 | cmd = [ |
| 32 | "ffmpeg", | 30 | "ffmpeg", "-y", |
| 33 | "-y", | ||
| 34 | "-i", audio_path, | 31 | "-i", audio_path, |
| 35 | "-ar", str(TARGET_SR), | 32 | "-ar", str(TARGET_SR), |
| 36 | "-ac", "1", | 33 | "-ac", "1", |
| 37 | "-f", "wav", | 34 | "-f", "f32le", |
| 38 | output_path, | 35 | "pipe:1", |
| 39 | ] | 36 | ] |
| 40 | result = subprocess.run( | 37 | result = subprocess.run(cmd, capture_output=True) |
| 41 | cmd, | ||
| 42 | capture_output=True, | ||
| 43 | text=True, | ||
| 44 | ) | ||
| 45 | if result.returncode != 0: | 38 | if result.returncode != 0: |
| 46 | raise RuntimeError(f"ffmpeg 转换失败: {result.stderr}") | 39 | raise RuntimeError(f"ffmpeg 解码失败: {result.stderr.decode(errors='replace')}") |
| 40 | return np.frombuffer(result.stdout, dtype=np.float32) | ||
| 47 | 41 | ||
| 48 | 42 | ||
| 49 | def extract_chroma_feature(audio_path: str) -> np.ndarray: | 43 | def extract_chroma_feature_from_samples( |
| 50 | """从音频文件提取 1536 维 Chromagram 特征向量。 | 44 | samples: np.ndarray, |
| 45 | sr: int, | ||
| 46 | hop_length: int = 512, | ||
| 47 | win_len_smooth: int = 41, | ||
| 48 | ) -> np.ndarray: | ||
| 49 | """从已加载的音频样本提取 1536 维 Chromagram 特征向量。 | ||
| 50 | |||
| 51 | 若 sr 不等于 TARGET_SR,先用 librosa.resample 在内存中降采样, | ||
| 52 | 避免重新走 ffmpeg 流程。 | ||
| 51 | 53 | ||
| 52 | Args: | 54 | Args: |
| 53 | audio_path: 音频文件路径。 | 55 | samples: 单声道音频样本(任意采样率)。 |
| 56 | sr: samples 对应的采样率。 | ||
| 57 | hop_length: CQT hop 大小,增大可成比例降低计算量,不影响最终 128 帧精度。 | ||
| 58 | win_len_smooth: CENS 平滑窗口帧数,应随 hop_length 等比缩小以保持相同的时间覆盖。 | ||
| 54 | 59 | ||
| 55 | Returns: | 60 | Returns: |
| 56 | shape 为 (1536,) 的 numpy 数组。 | 61 | shape 为 (1536,) 的 numpy 数组。 |
| 57 | |||
| 58 | Raises: | ||
| 59 | FileNotFoundError: 音频文件不存在。 | ||
| 60 | RuntimeError: ffmpeg 转换失败。 | ||
| 61 | """ | 62 | """ |
| 62 | if not os.path.isfile(audio_path): | 63 | y = samples if sr == TARGET_SR else librosa.resample(samples, orig_sr=sr, target_sr=TARGET_SR) |
| 63 | raise FileNotFoundError(f"音频文件不存在: {audio_path}") | ||
| 64 | 64 | ||
| 65 | # 1. 音频标准化:ffmpeg 转 WAV | 65 | # 提取 CENS Chromagram (12×T) |
| 66 | with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp: | 66 | chroma = librosa.feature.chroma_cens(y=y, sr=TARGET_SR, hop_length=hop_length, win_len_smooth=win_len_smooth) |
| 67 | tmp_wav = tmp.name | ||
| 68 | 67 | ||
| 69 | try: | 68 | # 主音对齐 |
| 70 | _normalize_audio_ffmpeg(audio_path, tmp_wav) | 69 | tonic = int(np.argmax(chroma.sum(axis=1))) |
| 70 | if tonic != 0: | ||
| 71 | chroma = np.roll(chroma, -tonic, axis=0) | ||
| 71 | 72 | ||
| 72 | # 2. librosa 加载音频 | 73 | # 时间归一化到 12×128 |
| 73 | y, _sr = librosa.load(tmp_wav, sr=TARGET_SR, mono=True) | 74 | if chroma.shape[1] != TARGET_FRAMES: |
| 75 | chroma = resample(chroma, TARGET_FRAMES, axis=1) | ||
| 76 | |||
| 77 | feature = chroma.flatten().astype(np.float32) | ||
| 78 | assert feature.shape == (VECTOR_DIM,), ( | ||
| 79 | f"特征维度错误: 期望 {VECTOR_DIM}, 实际 {feature.shape}" | ||
| 80 | ) | ||
| 81 | return feature | ||
| 74 | 82 | ||
| 75 | # 3. 提取 CENS Chromagram (12×T),对速度变化和音色具有更强鲁棒性 | ||
| 76 | chroma = librosa.feature.chroma_cens(y=y, sr=TARGET_SR) | ||
| 77 | 83 | ||
| 78 | # 4. 主音对齐:将全局能量最大的音级循环滚至第 0 行,实现转调不变性 | 84 | def extract_chroma_matrix_from_samples( |
| 79 | tonic = int(np.argmax(chroma.sum(axis=1))) | 85 | samples: np.ndarray, |
| 80 | if tonic != 0: | 86 | sr: int, |
| 81 | chroma = np.roll(chroma, -tonic, axis=0) | 87 | hop_length: int = 512, |
| 88 | win_len_smooth: int = 41, | ||
| 89 | ) -> np.ndarray: | ||
| 90 | """从已加载的音频样本提取 12×128 Chromagram 矩阵(供 DTW 精排使用)。""" | ||
| 91 | return extract_chroma_feature_from_samples(samples, sr, hop_length=hop_length, win_len_smooth=win_len_smooth).reshape(12, TARGET_FRAMES) | ||
| 82 | 92 | ||
| 83 | # 5. 时间归一化到 12×128 | ||
| 84 | if chroma.shape[1] != TARGET_FRAMES: | ||
| 85 | chroma = resample(chroma, TARGET_FRAMES, axis=1) | ||
| 86 | 93 | ||
| 87 | # 6. 展开为 1536 维向量 | 94 | def extract_chroma_feature(audio_path: str, hop_length: int = 512, win_len_smooth: int = 41) -> np.ndarray: |
| 88 | feature = chroma.flatten().astype(np.float32) | 95 | """从音频文件提取 1536 维 Chromagram 特征向量。 |
| 89 | 96 | ||
| 90 | assert feature.shape == (VECTOR_DIM,), ( | 97 | Args: |
| 91 | f"特征维度错误: 期望 {VECTOR_DIM}, 实际 {feature.shape}" | 98 | audio_path: 音频文件路径。 |
| 92 | ) | 99 | hop_length: CQT hop 大小。 |
| 100 | win_len_smooth: CENS 平滑窗口帧数。 | ||
| 101 | |||
| 102 | Returns: | ||
| 103 | shape 为 (1536,) 的 numpy 数组。 | ||
| 104 | |||
| 105 | Raises: | ||
| 106 | FileNotFoundError: 音频文件不存在。 | ||
| 107 | RuntimeError: ffmpeg 解码失败。 | ||
| 108 | """ | ||
| 109 | if not os.path.isfile(audio_path): | ||
| 110 | raise FileNotFoundError(f"音频文件不存在: {audio_path}") | ||
| 93 | 111 | ||
| 94 | return feature | 112 | y = _load_audio_via_pipe(audio_path) |
| 95 | finally: | 113 | return extract_chroma_feature_from_samples(y, TARGET_SR, hop_length=hop_length, win_len_smooth=win_len_smooth) |
| 96 | # 清理临时文件 | ||
| 97 | if os.path.exists(tmp_wav): | ||
| 98 | os.remove(tmp_wav) | ||
| 99 | 114 | ||
| 100 | 115 | ||
| 101 | def extract_chroma_matrix(audio_path: str) -> np.ndarray: | 116 | def extract_chroma_matrix(audio_path: str, hop_length: int = 512, win_len_smooth: int = 41) -> np.ndarray: |
| 102 | """从音频文件提取 12×128 Chromagram 矩阵(未展平,供 DTW 精排使用)。 | 117 | """从音频文件提取 12×128 Chromagram 矩阵(未展平,供 DTW 精排使用)。 |
| 103 | 118 | ||
| 104 | Returns: | 119 | Returns: |
| 105 | shape 为 (12, 128) 的 numpy 数组,已做主音对齐。 | 120 | shape 为 (12, 128) 的 numpy 数组,已做主音对齐。 |
| 106 | """ | 121 | """ |
| 107 | feature = extract_chroma_feature(audio_path) | 122 | return extract_chroma_feature(audio_path, hop_length=hop_length, win_len_smooth=win_len_smooth).reshape(12, TARGET_FRAMES) |
| 108 | return feature.reshape(12, TARGET_FRAMES) | ... | ... |
| 1 | """作曲去重服务(入库 + 查询)。 | 1 | """作曲去重服务(入库 + 查询)。 |
| 2 | 2 | ||
| 3 | 查询流程: | 3 | 查询流程: |
| 4 | 1. Dejavu 指纹匹配(毫秒级,子序列匹配,支持 chorus_only / trim_intro) | 4 | 1. Chromagram 12路 + DTW 精排(覆盖绝大多数场景,约 1s) |
| 5 | - 命中(≥ 阈值)→ 直接返回 duplicate(短路) | 5 | - 命中(≥ 阈值)→ 直接返回结果 |
| 6 | 2. 未命中 → Chromagram 12路 + DTW(百毫秒级) | 6 | 2. Chromagram 未命中 → Dejavu 指纹兜底(处理 chorus_only / trim_intro 等片段场景) |
| 7 | - 返回结果 | 7 | - 命中(≥ 阈值)→ 返回 duplicate(similarity=1.0) |
| 8 | """ | 8 | """ |
| 9 | 9 | ||
| 10 | import logging | 10 | import logging |
| 11 | import os | 11 | import os |
| 12 | import time | ||
| 12 | from dataclasses import dataclass, field | 13 | from dataclasses import dataclass, field |
| 13 | 14 | ||
| 15 | import numba | ||
| 14 | import numpy as np | 16 | import numpy as np |
| 15 | import psycopg | 17 | import psycopg |
| 16 | from scipy.spatial.distance import cdist | 18 | from scipy.spatial.distance import cdist |
| 17 | 19 | ||
| 18 | from .extractor import TARGET_FRAMES, extract_chroma_feature, extract_chroma_matrix | 20 | from .extractor import ( |
| 19 | from .dejavu_fingerprinter import fingerprint_audio | 21 | TARGET_FRAMES, |
| 22 | extract_chroma_feature_from_samples, | ||
| 23 | extract_chroma_matrix, | ||
| 24 | ) | ||
| 25 | from .dejavu_fingerprinter import fingerprint_audio, fingerprint_from_samples, load_audio, QUERY_MAX_DURATION_SEC | ||
| 20 | 26 | ||
| 21 | logger = logging.getLogger(__name__) | 27 | logger = logging.getLogger(__name__) |
| 22 | 28 | ||
| ... | @@ -56,6 +62,7 @@ class CompositionCandidate: | ... | @@ -56,6 +62,7 @@ class CompositionCandidate: |
| 56 | song_id: int | 62 | song_id: int |
| 57 | similarity: float | 63 | similarity: float |
| 58 | source: str = "chromagram" | 64 | source: str = "chromagram" |
| 65 | dejavu_aligned_count: int | None = None # 仅 source=dejavu 或 dejavu fallback 未命中时记录 | ||
| 59 | 66 | ||
| 60 | 67 | ||
| 61 | @dataclass | 68 | @dataclass |
| ... | @@ -73,10 +80,18 @@ class CompositionConfig: | ... | @@ -73,10 +80,18 @@ class CompositionConfig: |
| 73 | statement_timeout_ms: int = 30000 | 80 | statement_timeout_ms: int = 30000 |
| 74 | dtw_rerank_top_k: int = 20 # Cosine 召回后做 DTW 精排的候选数量 | 81 | dtw_rerank_top_k: int = 20 # Cosine 召回后做 DTW 精排的候选数量 |
| 75 | duplicate_threshold: float = _env_float("COMPOSITION_DUPLICATE_THRESHOLD", 0.85) | 82 | duplicate_threshold: float = _env_float("COMPOSITION_DUPLICATE_THRESHOLD", 0.85) |
| 83 | # Chromagram 提取配置 | ||
| 84 | chroma_hop_length: int = _env_int("COMPOSITION_CHROMA_HOP_LENGTH", 512) | ||
| 85 | chroma_win_len_smooth: int = _env_int("COMPOSITION_CHROMA_WIN_LEN_SMOOTH", 0) | ||
| 76 | # Dejavu 指纹匹配配置 | 86 | # Dejavu 指纹匹配配置 |
| 77 | dejavu_enabled: bool = _env_bool("COMPOSITION_DEJAVU_ENABLED", True) | 87 | dejavu_enabled: bool = _env_bool("COMPOSITION_DEJAVU_ENABLED", True) |
| 78 | dejavu_match_threshold: int = _env_int("COMPOSITION_DEJAVU_MATCH_THRESHOLD", 20) | 88 | dejavu_match_threshold: int = _env_int("COMPOSITION_DEJAVU_MATCH_THRESHOLD", 20) |
| 79 | 89 | ||
| 90 | def __post_init__(self) -> None: | ||
| 91 | # 0 表示自动:按 hop_length 等比缩小,保持平滑窗覆盖时长约 1 秒 | ||
| 92 | if self.chroma_win_len_smooth == 0: | ||
| 93 | self.chroma_win_len_smooth = max(1, round(41 * 512 / self.chroma_hop_length)) | ||
| 94 | |||
| 80 | 95 | ||
| 81 | @dataclass | 96 | @dataclass |
| 82 | class CompositionDedupService: | 97 | class CompositionDedupService: |
| ... | @@ -94,8 +109,12 @@ class CompositionDedupService: | ... | @@ -94,8 +109,12 @@ class CompositionDedupService: |
| 94 | Returns: | 109 | Returns: |
| 95 | 提取的特征向量。 | 110 | 提取的特征向量。 |
| 96 | """ | 111 | """ |
| 97 | feature = extract_chroma_feature(audio_path) | 112 | # 共用一次解码(44100Hz),chromagram 路径在内存中重采样,无需二次 ffmpeg |
| 98 | self._logger.info("提取 Chromagram 特征完成: song_id=%s, audio=%s", song_id, audio_path) | 113 | samples, sr = load_audio(audio_path) |
| 114 | self._logger.info("音频解码完成: song_id=%s, audio=%s", song_id, audio_path) | ||
| 115 | |||
| 116 | feature = extract_chroma_feature_from_samples(samples, sr, hop_length=self.config.chroma_hop_length, win_len_smooth=self.config.chroma_win_len_smooth) | ||
| 117 | self._logger.info("提取 Chromagram 特征完成: song_id=%s", song_id) | ||
| 99 | 118 | ||
| 100 | with psycopg.connect(self.config.dsn) as conn: | 119 | with psycopg.connect(self.config.dsn) as conn: |
| 101 | with conn.cursor() as cursor: | 120 | with conn.cursor() as cursor: |
| ... | @@ -111,15 +130,29 @@ class CompositionDedupService: | ... | @@ -111,15 +130,29 @@ class CompositionDedupService: |
| 111 | 130 | ||
| 112 | self._logger.info("Chromagram 特征入库完成: song_id=%s", song_id) | 131 | self._logger.info("Chromagram 特征入库完成: song_id=%s", song_id) |
| 113 | 132 | ||
| 114 | # Dejavu 指纹同时入库 | ||
| 115 | if self.config.dejavu_enabled: | 133 | if self.config.dejavu_enabled: |
| 116 | self._dejavu_ingest(song_id, audio_path) | 134 | self._dejavu_ingest(song_id, audio_path, samples=samples, sr=sr) |
| 117 | 135 | ||
| 118 | return feature | 136 | return feature |
| 119 | 137 | ||
| 120 | def _dejavu_ingest(self, song_id: int, audio_path: str) -> None: | 138 | def _dejavu_ingest( |
| 121 | """提取 Dejavu 指纹并写入 dejavu_fingerprints 表。""" | 139 | self, |
| 122 | file_sha1, fingerprints = fingerprint_audio(audio_path) | 140 | song_id: int, |
| 141 | audio_path: str, | ||
| 142 | *, | ||
| 143 | samples: np.ndarray | None = None, | ||
| 144 | sr: int | None = None, | ||
| 145 | ) -> None: | ||
| 146 | """提取 Dejavu 指纹并写入 dejavu_fingerprints 表。 | ||
| 147 | |||
| 148 | 若提供了已解码的 samples/sr,直接使用,跳过 ffmpeg;否则从文件重新加载。 | ||
| 149 | """ | ||
| 150 | if samples is not None and sr is not None: | ||
| 151 | _, fingerprints = fingerprint_from_samples(samples, sr, compute_sha1=False) | ||
| 152 | self._logger.info("Dejavu 指纹提取完成(共用解码): song_id=%s", song_id) | ||
| 153 | else: | ||
| 154 | _, fingerprints = fingerprint_audio(audio_path) | ||
| 155 | |||
| 123 | if not fingerprints: | 156 | if not fingerprints: |
| 124 | self._logger.warning("Dejavu 指纹为空: song_id=%s, audio=%s", song_id, audio_path) | 157 | self._logger.warning("Dejavu 指纹为空: song_id=%s, audio=%s", song_id, audio_path) |
| 125 | return | 158 | return |
| ... | @@ -144,25 +177,54 @@ class CompositionDedupService: | ... | @@ -144,25 +177,54 @@ class CompositionDedupService: |
| 144 | 177 | ||
| 145 | self._logger.info("Dejavu 指纹入库完成: song_id=%s, 指纹数=%d", song_id, len(fingerprints)) | 178 | self._logger.info("Dejavu 指纹入库完成: song_id=%s, 指纹数=%d", song_id, len(fingerprints)) |
| 146 | 179 | ||
| 147 | def query(self, audio_path: str, top_k: int = 100) -> list[CompositionCandidate]: | 180 | def query( |
| 181 | self, | ||
| 182 | audio_path: str, | ||
| 183 | top_k: int = 100, | ||
| 184 | timings: dict | None = None, | ||
| 185 | ) -> list[CompositionCandidate]: | ||
| 148 | """提取音频特征并查询相似结果。 | 186 | """提取音频特征并查询相似结果。 |
| 149 | 187 | ||
| 150 | 流程:Dejavu 指纹短路匹配 → 12 路循环对齐 Cosine 召回 → DTW 精排。 | 188 | 流程:Chromagram 12路 + DTW 精排 → Dejavu 指纹兜底(片段场景,仅在 chroma 未命中时解码)。 |
| 189 | |||
| 190 | Args: | ||
| 191 | timings: 若传入非 None 的 dict,方法执行完毕后会在其中写入各阶段耗时(单位 ms): | ||
| 192 | chroma_extract_ms、db_cosine_ms、db_fetch_ms、dtw_ms、 | ||
| 193 | dejavu_decode_ms、dejavu_fingerprint_ms、dejavu_db_ms。 | ||
| 194 | Dejavu 路径未执行时,对应键不会写入。 | ||
| 151 | """ | 195 | """ |
| 152 | # 1. 优先尝试 Dejavu 指纹匹配(短路) | 196 | # 1. Chromagram 12路 + DTW 精排(覆盖绝大多数场景) |
| 153 | if self.config.dejavu_enabled: | 197 | # 使用与入库一致的 22050Hz 解码路径,保证 chroma 向量对齐 |
| 154 | match = self._dejavu_query(audio_path) | 198 | candidates = self._query_chroma(audio_path, top_k, timings=timings) |
| 199 | |||
| 200 | # 2. Chromagram 未命中时,用 Dejavu 兜底(处理 chorus_only / trim_intro 等片段场景) | ||
| 201 | # 只有未命中才解码 44100Hz,大多数情况下无额外 I/O | ||
| 202 | if self.config.dejavu_enabled and not self.candidates_indicate_duplicate(candidates): | ||
| 203 | _t = time.perf_counter() | ||
| 204 | samples, sr = load_audio(audio_path, max_duration=QUERY_MAX_DURATION_SEC) | ||
| 205 | if timings is not None: | ||
| 206 | timings["dejavu_decode_ms"] = round((time.perf_counter() - _t) * 1000, 1) | ||
| 207 | match = self._dejavu_query(samples, sr, timings=timings) | ||
| 155 | if match is not None: | 208 | if match is not None: |
| 156 | self._logger.info( | 209 | if match.aligned_count >= self.config.dejavu_match_threshold: |
| 157 | "Dejavu 命中: song_id=%s, aligned_count=%d, total_collisions=%d, decision=duplicate", | 210 | self._logger.info( |
| 158 | match.song_id, | 211 | "Dejavu 命中: song_id=%s, aligned_count=%d, total_collisions=%d, decision=duplicate", |
| 159 | match.aligned_count, | 212 | match.song_id, |
| 160 | match.total_collisions, | 213 | match.aligned_count, |
| 161 | ) | 214 | match.total_collisions, |
| 162 | return [CompositionCandidate(song_id=match.song_id, similarity=1.0, source="dejavu")] | 215 | ) |
| 163 | 216 | return [CompositionCandidate( | |
| 164 | # 2. Dejavu 未命中或禁用,走现有 Chromagram 12路 + DTW 流程 | 217 | song_id=match.song_id, |
| 165 | return self._query_chroma(audio_path, top_k) | 218 | similarity=1.0, |
| 219 | source="dejavu", | ||
| 220 | dejavu_aligned_count=match.aligned_count, | ||
| 221 | )] | ||
| 222 | else: | ||
| 223 | # 未达阈值:把 aligned_count 附加到 chromagram top1 上供评估脚本记录 | ||
| 224 | if candidates: | ||
| 225 | candidates[0].dejavu_aligned_count = match.aligned_count | ||
| 226 | |||
| 227 | return candidates | ||
| 166 | 228 | ||
| 167 | def check(self, audio_path: str, top_k: int = 100) -> bool: | 229 | def check(self, audio_path: str, top_k: int = 100) -> bool: |
| 168 | """按最终接口语义返回是否重复。""" | 230 | """按最终接口语义返回是否重复。""" |
| ... | @@ -178,9 +240,17 @@ class CompositionDedupService: | ... | @@ -178,9 +240,17 @@ class CompositionDedupService: |
| 178 | return False | 240 | return False |
| 179 | return candidates[0].similarity >= self.config.duplicate_threshold | 241 | return candidates[0].similarity >= self.config.duplicate_threshold |
| 180 | 242 | ||
| 181 | def _query_chroma(self, audio_path: str, top_k: int = 100) -> list[CompositionCandidate]: | 243 | def _query_chroma( |
| 244 | self, | ||
| 245 | audio_path: str, | ||
| 246 | top_k: int = 100, | ||
| 247 | timings: dict | None = None, | ||
| 248 | ) -> list[CompositionCandidate]: | ||
| 182 | """Chromagram 12 路循环对齐 + DTW 精排查询。""" | 249 | """Chromagram 12 路循环对齐 + DTW 精排查询。""" |
| 183 | chroma = extract_chroma_matrix(audio_path) | 250 | _t = time.perf_counter() |
| 251 | chroma = extract_chroma_matrix(audio_path, hop_length=self.config.chroma_hop_length, win_len_smooth=self.config.chroma_win_len_smooth) | ||
| 252 | if timings is not None: | ||
| 253 | timings["chroma_extract_ms"] = round((time.perf_counter() - _t) * 1000, 1) | ||
| 184 | self._logger.info("提取 Chromagram 查询特征完成: audio=%s", audio_path) | 254 | self._logger.info("提取 Chromagram 查询特征完成: audio=%s", audio_path) |
| 185 | 255 | ||
| 186 | # 1. 12 路循环对齐:穷举 12 种半音偏移,单条 SQL 内部展开,按 song_id 取最高 Cosine 相似度 | 256 | # 1. 12 路循环对齐:穷举 12 种半音偏移,单条 SQL 内部展开,按 song_id 取最高 Cosine 相似度 |
| ... | @@ -213,6 +283,7 @@ class CompositionDedupService: | ... | @@ -213,6 +283,7 @@ class CompositionDedupService: |
| 213 | LIMIT %s | 283 | LIMIT %s |
| 214 | """ | 284 | """ |
| 215 | best: dict[int, float] = {} | 285 | best: dict[int, float] = {} |
| 286 | _t = time.perf_counter() | ||
| 216 | with psycopg.connect(self.config.dsn) as conn: | 287 | with psycopg.connect(self.config.dsn) as conn: |
| 217 | with conn.cursor() as cursor: | 288 | with conn.cursor() as cursor: |
| 218 | cursor.execute( | 289 | cursor.execute( |
| ... | @@ -221,33 +292,44 @@ class CompositionDedupService: | ... | @@ -221,33 +292,44 @@ class CompositionDedupService: |
| 221 | cursor.execute(sql, (*shift_vecs, top_k, top_k)) | 292 | cursor.execute(sql, (*shift_vecs, top_k, top_k)) |
| 222 | for song_id, sim in cursor.fetchall(): | 293 | for song_id, sim in cursor.fetchall(): |
| 223 | best[int(song_id)] = float(sim) | 294 | best[int(song_id)] = float(sim) |
| 295 | if timings is not None: | ||
| 296 | timings["db_cosine_ms"] = round((time.perf_counter() - _t) * 1000, 1) | ||
| 224 | 297 | ||
| 225 | # 2. 取 Top dtw_rerank_top_k,从库中取原始向量做 DTW 精排 | 298 | # 2. 取 Top dtw_rerank_top_k,从库中取原始向量做 DTW 精排 |
| 226 | top = sorted(best.items(), key=lambda x: x[1], reverse=True) | 299 | top = sorted(best.items(), key=lambda x: x[1], reverse=True) |
| 227 | rerank_ids = [sid for sid, _ in top[:self.config.dtw_rerank_top_k]] | 300 | rerank_ids = [sid for sid, _ in top[:self.config.dtw_rerank_top_k]] |
| 228 | 301 | ||
| 302 | _t = time.perf_counter() | ||
| 229 | with conn.cursor() as cursor: | 303 | with conn.cursor() as cursor: |
| 230 | cursor.execute( | 304 | cursor.execute( |
| 231 | "SELECT song_id, feature_vector::float4[] FROM composition_feature WHERE song_id = ANY(%s)", | 305 | "SELECT song_id, feature_vector::float4[] FROM composition_feature WHERE song_id = ANY(%s)", |
| 232 | (rerank_ids,), | 306 | (rerank_ids,), |
| 233 | ) | 307 | ) |
| 234 | db_rows = cursor.fetchall() | 308 | db_rows = cursor.fetchall() |
| 309 | if timings is not None: | ||
| 310 | timings["db_fetch_ms"] = round((time.perf_counter() - _t) * 1000, 1) | ||
| 235 | 311 | ||
| 312 | _t = time.perf_counter() | ||
| 236 | reranked = [] | 313 | reranked = [] |
| 237 | for song_id, fv in db_rows: | 314 | for song_id, fv in db_rows: |
| 238 | cand_chroma = np.array(fv, dtype=np.float32).reshape(12, TARGET_FRAMES) | 315 | cand_chroma = np.array(fv, dtype=np.float32).reshape(12, TARGET_FRAMES) |
| 239 | dtw_sim = _best_shifted_dtw_similarity(chroma, cand_chroma) | 316 | dtw_sim = _best_shifted_dtw_similarity( |
| 317 | chroma, cand_chroma, early_exit_threshold=self.config.duplicate_threshold | ||
| 318 | ) | ||
| 240 | reranked.append(CompositionCandidate(song_id=int(song_id), similarity=dtw_sim)) | 319 | reranked.append(CompositionCandidate(song_id=int(song_id), similarity=dtw_sim)) |
| 241 | reranked.sort(key=lambda c: c.similarity, reverse=True) | 320 | reranked.sort(key=lambda c: c.similarity, reverse=True) |
| 321 | if timings is not None: | ||
| 322 | timings["dtw_ms"] = round((time.perf_counter() - _t) * 1000, 1) | ||
| 242 | 323 | ||
| 243 | rerank_id_set = {c.song_id for c in reranked} | 324 | rerank_id_set = {c.song_id for c in reranked} |
| 244 | rest = [ | 325 | # top 已按 cosine 降序排列;直接从中剔除已精排的候选,剩余保留 cosine 分 |
| 326 | non_reranked = [ | ||
| 245 | CompositionCandidate(song_id=sid, similarity=sim) | 327 | CompositionCandidate(song_id=sid, similarity=sim) |
| 246 | for sid, sim in top[self.config.dtw_rerank_top_k:] | 328 | for sid, sim in top |
| 247 | if sid not in rerank_id_set | 329 | if sid not in rerank_id_set |
| 248 | ] | 330 | ] |
| 249 | 331 | ||
| 250 | result = reranked + rest | 332 | result = reranked + non_reranked |
| 251 | top_summary = ", ".join( | 333 | top_summary = ", ".join( |
| 252 | f"{candidate.song_id}:{candidate.similarity:.4f}" | 334 | f"{candidate.song_id}:{candidate.similarity:.4f}" |
| 253 | for candidate in result[:5] | 335 | for candidate in result[:5] |
| ... | @@ -259,7 +341,12 @@ class CompositionDedupService: | ... | @@ -259,7 +341,12 @@ class CompositionDedupService: |
| 259 | ) | 341 | ) |
| 260 | return result | 342 | return result |
| 261 | 343 | ||
| 262 | def _dejavu_query(self, audio_path: str) -> _DejavuMatch | None: | 344 | def _dejavu_query( |
| 345 | self, | ||
| 346 | samples: np.ndarray, | ||
| 347 | sr: int, | ||
| 348 | timings: dict | None = None, | ||
| 349 | ) -> _DejavuMatch | None: | ||
| 263 | """Dejavu 指纹查询,返回 offset 对齐后碰撞数最多的 song_id。 | 350 | """Dejavu 指纹查询,返回 offset 对齐后碰撞数最多的 song_id。 |
| 264 | 351 | ||
| 265 | 只统计 hash 总碰撞数会让常见频谱峰值、噪声片段或大库随机碰撞直接短路成 | 352 | 只统计 hash 总碰撞数会让常见频谱峰值、噪声片段或大库随机碰撞直接短路成 |
| ... | @@ -267,15 +354,19 @@ class CompositionDedupService: | ... | @@ -267,15 +354,19 @@ class CompositionDedupService: |
| 267 | db_offset - query_offset 落在同一个时间偏移上。 | 354 | db_offset - query_offset 落在同一个时间偏移上。 |
| 268 | 355 | ||
| 269 | Returns: | 356 | Returns: |
| 270 | 命中结果,未命中返回 None。 | 357 | 最佳匹配结果(不做阈值过滤),无任何碰撞时返回 None。 |
| 271 | """ | 358 | """ |
| 272 | file_sha1, fingerprints = fingerprint_audio(audio_path) | 359 | _t = time.perf_counter() |
| 360 | _, fingerprints = fingerprint_from_samples(samples, sr, compute_sha1=False) | ||
| 361 | if timings is not None: | ||
| 362 | timings["dejavu_fingerprint_ms"] = round((time.perf_counter() - _t) * 1000, 1) | ||
| 273 | if not fingerprints: | 363 | if not fingerprints: |
| 274 | return None | 364 | return None |
| 275 | 365 | ||
| 276 | hashes = [h for h, _ in fingerprints] | 366 | hashes = [h for h, _ in fingerprints] |
| 277 | offsets = [int(o) for _, o in fingerprints] | 367 | offsets = [int(o) for _, o in fingerprints] |
| 278 | 368 | ||
| 369 | _t = time.perf_counter() | ||
| 279 | with psycopg.connect(self.config.dsn) as conn: | 370 | with psycopg.connect(self.config.dsn) as conn: |
| 280 | with conn.cursor() as cursor: | 371 | with conn.cursor() as cursor: |
| 281 | # 先按 hash 找碰撞,再按每个 song_id 的 offset delta 聚类。 | 372 | # 先按 hash 找碰撞,再按每个 song_id 的 offset delta 聚类。 |
| ... | @@ -307,23 +398,21 @@ class CompositionDedupService: | ... | @@ -307,23 +398,21 @@ class CompositionDedupService: |
| 307 | (hashes, offsets), | 398 | (hashes, offsets), |
| 308 | ) | 399 | ) |
| 309 | row = cursor.fetchone() | 400 | row = cursor.fetchone() |
| 310 | if row is None: | 401 | if timings is not None: |
| 311 | return None | 402 | timings["dejavu_db_ms"] = round((time.perf_counter() - _t) * 1000, 1) |
| 312 | sid, aligned_count, total_collisions = row | 403 | if row is None: |
| 313 | aligned_count = int(aligned_count) | 404 | return None |
| 314 | if aligned_count >= self.config.dejavu_match_threshold: | 405 | sid, aligned_count, total_collisions = row |
| 315 | return _DejavuMatch( | 406 | return _DejavuMatch( |
| 316 | song_id=int(sid), | 407 | song_id=int(sid), |
| 317 | aligned_count=aligned_count, | 408 | aligned_count=int(aligned_count), |
| 318 | total_collisions=int(total_collisions), | 409 | total_collisions=int(total_collisions), |
| 319 | ) | 410 | ) |
| 320 | return None | ||
| 321 | 411 | ||
| 322 | 412 | ||
| 323 | def _dtw_similarity(query: np.ndarray, candidate: np.ndarray) -> float: | 413 | @numba.njit(cache=True) |
| 324 | """计算两个 12×T Chromagram 矩阵之间的 DTW 相似度(映射到 [0,1])。""" | 414 | def _dtw_dp(cost: np.ndarray) -> float: |
| 325 | # 帧间欧氏距离矩阵 | 415 | """DTW DP 填表(numba JIT 编译,数值结果与纯 Python 实现完全一致)。""" |
| 326 | cost = cdist(query.T, candidate.T, metric="euclidean") | ||
| 327 | n, m = cost.shape | 416 | n, m = cost.shape |
| 328 | dp = np.full((n, m), np.inf) | 417 | dp = np.full((n, m), np.inf) |
| 329 | dp[0, 0] = cost[0, 0] | 418 | dp[0, 0] = cost[0, 0] |
| ... | @@ -334,14 +423,37 @@ def _dtw_similarity(query: np.ndarray, candidate: np.ndarray) -> float: | ... | @@ -334,14 +423,37 @@ def _dtw_similarity(query: np.ndarray, candidate: np.ndarray) -> float: |
| 334 | for i in range(1, n): | 423 | for i in range(1, n): |
| 335 | for j in range(1, m): | 424 | for j in range(1, m): |
| 336 | dp[i, j] = cost[i, j] + min(dp[i - 1, j], dp[i, j - 1], dp[i - 1, j - 1]) | 425 | dp[i, j] = cost[i, j] + min(dp[i - 1, j], dp[i, j - 1], dp[i - 1, j - 1]) |
| 337 | dtw_dist = dp[n - 1, m - 1] / (n + m) | 426 | return dp[n - 1, m - 1] |
| 427 | |||
| 428 | |||
| 429 | def _dtw_similarity(query: np.ndarray, candidate: np.ndarray) -> float: | ||
| 430 | """计算两个 12×T Chromagram 矩阵之间的 DTW 相似度(映射到 [0,1])。""" | ||
| 431 | # 帧间欧氏距离矩阵 | ||
| 432 | cost = cdist(query.T, candidate.T, metric="euclidean") | ||
| 433 | n, m = cost.shape | ||
| 434 | dtw_dist = _dtw_dp(cost) / (n + m) | ||
| 338 | # 转换为相似度:距离越小相似度越高 | 435 | # 转换为相似度:距离越小相似度越高 |
| 339 | return float(1.0 / (1.0 + dtw_dist)) | 436 | return float(1.0 / (1.0 + dtw_dist)) |
| 340 | 437 | ||
| 341 | 438 | ||
| 342 | def _best_shifted_dtw_similarity(query: np.ndarray, candidate: np.ndarray) -> float: | 439 | def _best_shifted_dtw_similarity( |
| 343 | """计算 12 路音高循环位移下的最佳 DTW 相似度。""" | 440 | query: np.ndarray, |
| 344 | return max( | 441 | candidate: np.ndarray, |
| 345 | _dtw_similarity(np.roll(query, -shift, axis=0), candidate) | 442 | early_exit_threshold: float = 1.1, |
| 346 | for shift in range(12) | 443 | ) -> float: |
| 347 | ) | 444 | """计算 12 路音高循环位移下的最佳 DTW 相似度。 |
| 445 | |||
| 446 | Args: | ||
| 447 | early_exit_threshold: 某个 shift 的相似度达到此值时立即返回,跳过剩余 shift。 | ||
| 448 | 传入 duplicate_threshold 即可:对已确认重复的候选不再浪费算力; | ||
| 449 | 返回值可能略低于理论最大值,但不影响 duplicate/non-duplicate 二元判定。 | ||
| 450 | 默认 1.1(> 1 的不可达值,等价于不启用早退)。 | ||
| 451 | """ | ||
| 452 | best = 0.0 | ||
| 453 | for shift in range(12): | ||
| 454 | sim = _dtw_similarity(np.roll(query, -shift, axis=0), candidate) | ||
| 455 | if sim > best: | ||
| 456 | best = sim | ||
| 457 | if best >= early_exit_threshold: | ||
| 458 | break | ||
| 459 | return best | ... | ... |
| ... | @@ -12,6 +12,7 @@ tqdm>=4.66 | ... | @@ -12,6 +12,7 @@ tqdm>=4.66 |
| 12 | 12 | ||
| 13 | # Audio composition feature extraction | 13 | # Audio composition feature extraction |
| 14 | librosa>=0.10.0 | 14 | librosa>=0.10.0 |
| 15 | numba>=0.59.0 | ||
| 15 | scipy>=1.11 | 16 | scipy>=1.11 |
| 16 | numpy>=1.24 | 17 | numpy>=1.24 |
| 17 | 18 | ||
| ... | @@ -21,3 +22,6 @@ pgvector>=0.2.0 | ... | @@ -21,3 +22,6 @@ pgvector>=0.2.0 |
| 21 | # HTTP API server | 22 | # HTTP API server |
| 22 | fastapi>=0.110.0 | 23 | fastapi>=0.110.0 |
| 23 | uvicorn[standard]>=0.29.0 | 24 | uvicorn[standard]>=0.29.0 |
| 25 | |||
| 26 | # Environment variable loading | ||
| 27 | python-dotenv>=1.0 | ... | ... |
| ... | @@ -8,8 +8,8 @@ expected_song_id 的 top-k/top1 命中只作为诊断字段。 | ... | @@ -8,8 +8,8 @@ expected_song_id 的 top-k/top1 命中只作为诊断字段。 |
| 8 | 用法: | 8 | 用法: |
| 9 | python scripts/evaluate_composition.py \ | 9 | python scripts/evaluate_composition.py \ |
| 10 | --dsn "postgresql:///lyric_dedup" \ | 10 | --dsn "postgresql:///lyric_dedup" \ |
| 11 | --queries composition_dedup/composition_testset4/queries.csv \ | 11 | --queries composition_testset/test_samples.csv \ |
| 12 | --out composition_dedup/composition_eval/composition_eval_result_v3.csv | 12 | --out composition_dedup/composition_eval/nohop_result.csv |
| 13 | """ | 13 | """ |
| 14 | 14 | ||
| 15 | import argparse | 15 | import argparse |
| ... | @@ -17,10 +17,14 @@ import csv | ... | @@ -17,10 +17,14 @@ import csv |
| 17 | import json | 17 | import json |
| 18 | import logging | 18 | import logging |
| 19 | import sys | 19 | import sys |
| 20 | import time | ||
| 20 | from pathlib import Path | 21 | from pathlib import Path |
| 21 | 22 | ||
| 22 | sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) | 23 | sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) |
| 23 | 24 | ||
| 25 | from dotenv import load_dotenv | ||
| 26 | load_dotenv(Path(__file__).resolve().parent.parent / ".env") | ||
| 27 | |||
| 24 | from composition_dedup.service import CompositionConfig, CompositionDedupService | 28 | from composition_dedup.service import CompositionConfig, CompositionDedupService |
| 25 | 29 | ||
| 26 | logger = logging.getLogger(__name__) | 30 | logger = logging.getLogger(__name__) |
| ... | @@ -92,8 +96,12 @@ def main() -> None: | ... | @@ -92,8 +96,12 @@ def main() -> None: |
| 92 | invalid_negative_pair = (not expected_dup) and audio_song_id == expected_song_id | 96 | invalid_negative_pair = (not expected_dup) and audio_song_id == expected_song_id |
| 93 | 97 | ||
| 94 | try: | 98 | try: |
| 95 | candidates = service.query(audio_path, top_k=args.top_k) | 99 | timings: dict = {} |
| 100 | _t0 = time.perf_counter() | ||
| 101 | candidates = service.query(audio_path, top_k=args.top_k, timings=timings) | ||
| 102 | query_time_ms = round((time.perf_counter() - _t0) * 1000, 1) | ||
| 96 | except Exception as e: | 103 | except Exception as e: |
| 104 | query_time_ms = round((time.perf_counter() - _t0) * 1000, 1) | ||
| 97 | logger.error("[%d/%d] 查询失败: %s, %s", i, len(rows), audio_path, e) | 105 | logger.error("[%d/%d] 查询失败: %s, %s", i, len(rows), audio_path, e) |
| 98 | result_rows.append({ | 106 | result_rows.append({ |
| 99 | "query_song_id": query_song_id, | 107 | "query_song_id": query_song_id, |
| ... | @@ -106,6 +114,7 @@ def main() -> None: | ... | @@ -106,6 +114,7 @@ def main() -> None: |
| 106 | "top1_song_id": "", | 114 | "top1_song_id": "", |
| 107 | "top1_similarity": "", | 115 | "top1_similarity": "", |
| 108 | "top1_source": "", | 116 | "top1_source": "", |
| 117 | "dejavu_aligned_count": "", | ||
| 109 | "top1_hit": False, | 118 | "top1_hit": False, |
| 110 | "topk_hit": False, | 119 | "topk_hit": False, |
| 111 | "expected_rank": "", | 120 | "expected_rank": "", |
| ... | @@ -115,6 +124,14 @@ def main() -> None: | ... | @@ -115,6 +124,14 @@ def main() -> None: |
| 115 | "expected_duplicate": expected_dup, | 124 | "expected_duplicate": expected_dup, |
| 116 | "predicted_duplicate": False, | 125 | "predicted_duplicate": False, |
| 117 | "correct": not expected_dup, # 查询失败视为 not_duplicate | 126 | "correct": not expected_dup, # 查询失败视为 not_duplicate |
| 127 | "query_time_ms": query_time_ms, | ||
| 128 | "chroma_extract_ms": timings.get("chroma_extract_ms", ""), | ||
| 129 | "db_cosine_ms": timings.get("db_cosine_ms", ""), | ||
| 130 | "db_fetch_ms": timings.get("db_fetch_ms", ""), | ||
| 131 | "dtw_ms": timings.get("dtw_ms", ""), | ||
| 132 | "dejavu_decode_ms": timings.get("dejavu_decode_ms", ""), | ||
| 133 | "dejavu_fingerprint_ms": timings.get("dejavu_fingerprint_ms", ""), | ||
| 134 | "dejavu_db_ms": timings.get("dejavu_db_ms", ""), | ||
| 118 | "error": str(e), | 135 | "error": str(e), |
| 119 | }) | 136 | }) |
| 120 | continue | 137 | continue |
| ... | @@ -123,6 +140,7 @@ def main() -> None: | ... | @@ -123,6 +140,7 @@ def main() -> None: |
| 123 | top1_song_id = str(top1.song_id) if top1 else "" | 140 | top1_song_id = str(top1.song_id) if top1 else "" |
| 124 | top1_sim = round(top1.similarity, 4) if top1 else "" | 141 | top1_sim = round(top1.similarity, 4) if top1 else "" |
| 125 | top1_source = top1.source if top1 else "" | 142 | top1_source = top1.source if top1 else "" |
| 143 | dejavu_aligned_count = top1.dejavu_aligned_count if top1 else "" | ||
| 126 | 144 | ||
| 127 | # 诊断召回:expected_song_id 是否进入 top1/top-k。 | 145 | # 诊断召回:expected_song_id 是否进入 top1/top-k。 |
| 128 | top1_hit = bool(expected_song_id) and top1_song_id == expected_song_id | 146 | top1_hit = bool(expected_song_id) and top1_song_id == expected_song_id |
| ... | @@ -157,6 +175,7 @@ def main() -> None: | ... | @@ -157,6 +175,7 @@ def main() -> None: |
| 157 | "top1_song_id": top1_song_id, | 175 | "top1_song_id": top1_song_id, |
| 158 | "top1_similarity": top1_sim, | 176 | "top1_similarity": top1_sim, |
| 159 | "top1_source": top1_source, | 177 | "top1_source": top1_source, |
| 178 | "dejavu_aligned_count": dejavu_aligned_count if dejavu_aligned_count is not None else "", | ||
| 160 | "top1_hit": top1_hit, | 179 | "top1_hit": top1_hit, |
| 161 | "topk_hit": topk_hit, | 180 | "topk_hit": topk_hit, |
| 162 | "expected_rank": expected_rank, | 181 | "expected_rank": expected_rank, |
| ... | @@ -166,11 +185,19 @@ def main() -> None: | ... | @@ -166,11 +185,19 @@ def main() -> None: |
| 166 | "expected_duplicate": expected_dup, | 185 | "expected_duplicate": expected_dup, |
| 167 | "predicted_duplicate": predicted_dup, | 186 | "predicted_duplicate": predicted_dup, |
| 168 | "correct": correct, | 187 | "correct": correct, |
| 188 | "query_time_ms": query_time_ms, | ||
| 189 | "chroma_extract_ms": timings.get("chroma_extract_ms", ""), | ||
| 190 | "db_cosine_ms": timings.get("db_cosine_ms", ""), | ||
| 191 | "db_fetch_ms": timings.get("db_fetch_ms", ""), | ||
| 192 | "dtw_ms": timings.get("dtw_ms", ""), | ||
| 193 | "dejavu_decode_ms": timings.get("dejavu_decode_ms", ""), | ||
| 194 | "dejavu_fingerprint_ms": timings.get("dejavu_fingerprint_ms", ""), | ||
| 195 | "dejavu_db_ms": timings.get("dejavu_db_ms", ""), | ||
| 169 | "error": "", | 196 | "error": "", |
| 170 | }) | 197 | }) |
| 171 | 198 | ||
| 172 | logger.info( | 199 | logger.info( |
| 173 | "[%d/%d] variant=%s source=%s expected=%s predicted_duplicate=%s threshold=%.4f expected_song_id=%s top1=%s sim=%s top1_hit=%s topk_hit=%s expected_rank=%s expected_sim=%s correct=%s", | 200 | "[%d/%d] variant=%s source=%s expected=%s predicted_duplicate=%s threshold=%.4f expected_song_id=%s top1=%s sim=%s top1_hit=%s topk_hit=%s expected_rank=%s expected_sim=%s correct=%s time_ms=%s", |
| 174 | i, | 201 | i, |
| 175 | len(rows), | 202 | len(rows), |
| 176 | row.get("variant", ""), | 203 | row.get("variant", ""), |
| ... | @@ -186,6 +213,7 @@ def main() -> None: | ... | @@ -186,6 +213,7 @@ def main() -> None: |
| 186 | expected_rank if expected_rank != "" else "-", | 213 | expected_rank if expected_rank != "" else "-", |
| 187 | expected_similarity if expected_similarity != "" else "-", | 214 | expected_similarity if expected_similarity != "" else "-", |
| 188 | correct, | 215 | correct, |
| 216 | query_time_ms, | ||
| 189 | ) | 217 | ) |
| 190 | 218 | ||
| 191 | if i % 10 == 0 or i == len(rows): | 219 | if i % 10 == 0 or i == len(rows): |
| ... | @@ -194,9 +222,14 @@ def main() -> None: | ... | @@ -194,9 +222,14 @@ def main() -> None: |
| 194 | # 写逐条结果 | 222 | # 写逐条结果 |
| 195 | fieldnames = ["query_song_id", "audio_song_id", "audio_path", "variant", "sample_class", | 223 | fieldnames = ["query_song_id", "audio_song_id", "audio_path", "variant", "sample_class", |
| 196 | "expected_song_id", "expected", "top1_song_id", "top1_similarity", "top1_source", | 224 | "expected_song_id", "expected", "top1_song_id", "top1_similarity", "top1_source", |
| 225 | "dejavu_aligned_count", | ||
| 197 | "top1_hit", "topk_hit", "expected_rank", "expected_similarity", | 226 | "top1_hit", "topk_hit", "expected_rank", "expected_similarity", |
| 198 | "invalid_negative_pair", "invalid_boolean_sample", | 227 | "invalid_negative_pair", "invalid_boolean_sample", |
| 199 | "expected_duplicate", "predicted_duplicate", "correct", "error"] | 228 | "expected_duplicate", "predicted_duplicate", "correct", |
| 229 | "query_time_ms", | ||
| 230 | "chroma_extract_ms", "db_cosine_ms", "db_fetch_ms", "dtw_ms", | ||
| 231 | "dejavu_decode_ms", "dejavu_fingerprint_ms", "dejavu_db_ms", | ||
| 232 | "error"] | ||
| 200 | with out_path.open("w", newline="", encoding="utf-8") as f: | 233 | with out_path.open("w", newline="", encoding="utf-8") as f: |
| 201 | writer = csv.DictWriter(f, fieldnames=fieldnames) | 234 | writer = csv.DictWriter(f, fieldnames=fieldnames) |
| 202 | writer.writeheader() | 235 | writer.writeheader() | ... | ... |
| ... | @@ -6,11 +6,11 @@ | ... | @@ -6,11 +6,11 @@ |
| 6 | 6 | ||
| 7 | 用法: | 7 | 用法: |
| 8 | python scripts/generate_composition_testset.py \ | 8 | python scripts/generate_composition_testset.py \ |
| 9 | --audio-dir /Volumes/移动硬盘/lyric_audio_type11 \ | 9 | --audio-dir /Volumes/移动硬盘/composition_test \ |
| 10 | --negative-audio-dir /Volumes/移动硬盘/composition_test \ | 10 | --negative-audio-dir /Volumes/移动硬盘/composition_drop \ |
| 11 | --out-dir composition_dedup/composition_testset \ | 11 | --out-dir composition_testset \ |
| 12 | --num-songs 80 \ | 12 | --num-songs 100 \ |
| 13 | --num-negative-songs 40 \ | 13 | --num-negative-songs 100 \ |
| 14 | --negative-variants \ | 14 | --negative-variants \ |
| 15 | --seed 123 | 15 | --seed 123 |
| 16 | 16 | ... | ... |
| ... | @@ -16,6 +16,9 @@ from pathlib import Path | ... | @@ -16,6 +16,9 @@ from pathlib import Path |
| 16 | 16 | ||
| 17 | sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) | 17 | sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) |
| 18 | 18 | ||
| 19 | from dotenv import load_dotenv | ||
| 20 | load_dotenv(Path(__file__).resolve().parent.parent / ".env") | ||
| 21 | |||
| 19 | from tqdm import tqdm | 22 | from tqdm import tqdm |
| 20 | 23 | ||
| 21 | from composition_dedup.service import CompositionConfig, CompositionDedupService | 24 | from composition_dedup.service import CompositionConfig, CompositionDedupService | ... | ... |
-
Please register or sign in to post a comment