优化性能
Showing
8 changed files
with
159 additions
and
86 deletions
| ... | @@ -22,7 +22,6 @@ from pathlib import Path | ... | @@ -22,7 +22,6 @@ from pathlib import Path |
| 22 | import librosa | 22 | import librosa |
| 23 | import numpy as np | 23 | import numpy as np |
| 24 | from scipy.ndimage import ( | 24 | from scipy.ndimage import ( |
| 25 | binary_erosion, | ||
| 26 | generate_binary_structure, | 25 | generate_binary_structure, |
| 27 | iterate_structure, | 26 | iterate_structure, |
| 28 | maximum_filter, | 27 | maximum_filter, |
| ... | @@ -60,10 +59,10 @@ MIN_HASH_TIME_DELTA = 0 | ... | @@ -60,10 +59,10 @@ MIN_HASH_TIME_DELTA = 0 |
| 60 | MAX_HASH_TIME_DELTA = 200 | 59 | MAX_HASH_TIME_DELTA = 200 |
| 61 | PEAK_SORT = True | 60 | PEAK_SORT = True |
| 62 | FINGERPRINT_REDUCTION = 20 | 61 | FINGERPRINT_REDUCTION = 20 |
| 63 | MAX_DURATION_SEC = float(os.environ.get("COMPOSITION_DEJAVU_MAX_DURATION", "120")) # 0=不限制 | 62 | QUERY_MAX_DURATION_SEC = float(os.environ.get("COMPOSITION_DEJAVU_QUERY_MAX_DURATION", "120")) # 0=不限制 |
| 64 | 63 | ||
| 65 | 64 | ||
| 66 | def _normalize_audio(audio_path: str, max_duration: float = MAX_DURATION_SEC) -> tuple[np.ndarray, int]: | 65 | def _normalize_audio(audio_path: str, max_duration: float = 0) -> tuple[np.ndarray, int]: |
| 67 | """将音频标准化为单声道 WAV 并加载为 numpy 数组。 | 66 | """将音频标准化为单声道 WAV 并加载为 numpy 数组。 |
| 68 | 67 | ||
| 69 | 使用 ffmpeg 先做重采样,再用 librosa 读取。 | 68 | 使用 ffmpeg 先做重采样,再用 librosa 读取。 |
| ... | @@ -133,12 +132,7 @@ def _get_2d_peaks(arr2D: np.ndarray, amp_min: float = DEFAULT_AMP_MIN): | ... | @@ -133,12 +132,7 @@ def _get_2d_peaks(arr2D: np.ndarray, amp_min: float = DEFAULT_AMP_MIN): |
| 133 | neighborhood = iterate_structure(struct, PEAK_NEIGHBORHOOD_SIZE) | 132 | neighborhood = iterate_structure(struct, PEAK_NEIGHBORHOOD_SIZE) |
| 134 | 133 | ||
| 135 | # 找局部极大值 | 134 | # 找局部极大值 |
| 136 | local_max = maximum_filter(arr2D, footprint=neighborhood) == arr2D | 135 | detected_peaks = maximum_filter(arr2D, footprint=neighborhood) == arr2D |
| 137 | background = arr2D == 0 | ||
| 138 | eroded_background = binary_erosion(background, structure=neighborhood, border_value=1) | ||
| 139 | |||
| 140 | # 布尔掩码 | ||
| 141 | detected_peaks = local_max ^ eroded_background | ||
| 142 | 136 | ||
| 143 | # 提取峰值 | 137 | # 提取峰值 |
| 144 | amps = arr2D[detected_peaks] | 138 | amps = arr2D[detected_peaks] |
| ... | @@ -181,6 +175,43 @@ def _generate_hashes(peaks: list[tuple[int, int]], fan_value: int = DEFAULT_FAN_ | ... | @@ -181,6 +175,43 @@ def _generate_hashes(peaks: list[tuple[int, int]], fan_value: int = DEFAULT_FAN_ |
| 181 | yield (h.hexdigest()[:FINGERPRINT_REDUCTION].encode(), t1) | 175 | yield (h.hexdigest()[:FINGERPRINT_REDUCTION].encode(), t1) |
| 182 | 176 | ||
| 183 | 177 | ||
| 178 | def load_audio(audio_path: str, max_duration: float = 0) -> tuple[np.ndarray, int]: | ||
| 179 | """加载并标准化音频为 44100Hz 单声道(供多路径共用,避免重复解码)。 | ||
| 180 | |||
| 181 | Args: | ||
| 182 | audio_path: 音频文件路径。 | ||
| 183 | max_duration: 最大截取时长(秒),0 表示不限制。 | ||
| 184 | |||
| 185 | Returns: | ||
| 186 | (samples, sr) 元组。 | ||
| 187 | """ | ||
| 188 | return _normalize_audio(audio_path, max_duration) | ||
| 189 | |||
| 190 | |||
| 191 | def fingerprint_from_samples( | ||
| 192 | samples: np.ndarray, sr: int, *, compute_sha1: bool = True | ||
| 193 | ) -> tuple[str, list[tuple[bytes, int]]]: | ||
| 194 | """对已加载的音频样本生成 Dejavu 风格指纹(不做 I/O)。 | ||
| 195 | |||
| 196 | Args: | ||
| 197 | samples: 单声道音频样本(应为 DEFAULT_FS=44100Hz)。 | ||
| 198 | sr: 采样率。 | ||
| 199 | compute_sha1: 是否计算 file_sha1。service 内部调用时传 False 可跳过 | ||
| 200 | 对 samples.tobytes() 的 21MB 哈希运算(返回值在那些路径中未被使用)。 | ||
| 201 | |||
| 202 | Returns: | ||
| 203 | (file_sha1, fingerprints) 元组, | ||
| 204 | 其中 fingerprints 是 [(hash_bytes, offset), ...] 列表。 | ||
| 205 | compute_sha1=False 时 file_sha1 返回空字符串。 | ||
| 206 | """ | ||
| 207 | file_sha1 = hashlib.sha1(samples.tobytes()).hexdigest()[:16] if compute_sha1 else "" | ||
| 208 | arr2D = _specgram(samples, sr, DEFAULT_WINDOW_SIZE, DEFAULT_OVERLAP_RATIO) | ||
| 209 | freq_idx, time_idx = _get_2d_peaks(arr2D) | ||
| 210 | peaks = list(zip(freq_idx, time_idx)) | ||
| 211 | fingerprints = list(_generate_hashes(peaks)) | ||
| 212 | return file_sha1, fingerprints | ||
| 213 | |||
| 214 | |||
| 184 | def fingerprint_audio(audio_path: str) -> tuple[str, list[tuple[bytes, int]]]: | 215 | def fingerprint_audio(audio_path: str) -> tuple[str, list[tuple[bytes, int]]]: |
| 185 | """对音频文件生成 Dejavu 风格指纹。 | 216 | """对音频文件生成 Dejavu 风格指纹。 |
| 186 | 217 | ||
| ... | @@ -198,21 +229,7 @@ def fingerprint_audio(audio_path: str) -> tuple[str, list[tuple[bytes, int]]]: | ... | @@ -198,21 +229,7 @@ def fingerprint_audio(audio_path: str) -> tuple[str, list[tuple[bytes, int]]]: |
| 198 | if not os.path.isfile(audio_path): | 229 | if not os.path.isfile(audio_path): |
| 199 | raise FileNotFoundError(f"音频文件不存在: {audio_path}") | 230 | raise FileNotFoundError(f"音频文件不存在: {audio_path}") |
| 200 | 231 | ||
| 201 | # 1. 标准化并加载音频(可选限制长度) | ||
| 202 | samples, fs = _normalize_audio(audio_path) | 232 | samples, fs = _normalize_audio(audio_path) |
| 203 | 233 | file_sha1, fingerprints = fingerprint_from_samples(samples, fs) | |
| 204 | # 2. 计算文件 SHA1(用于标识) | ||
| 205 | file_sha1 = hashlib.sha1(samples.tobytes()).hexdigest()[:16] | ||
| 206 | |||
| 207 | # 3. 计算频谱图 | ||
| 208 | arr2D = _specgram(samples, fs, DEFAULT_WINDOW_SIZE, DEFAULT_OVERLAP_RATIO) | ||
| 209 | |||
| 210 | # 4. 检测 2D 峰值 | ||
| 211 | freq_idx, time_idx = _get_2d_peaks(arr2D) | ||
| 212 | peaks = list(zip(freq_idx, time_idx)) | ||
| 213 | |||
| 214 | # 5. 生成指纹哈希 | ||
| 215 | fingerprints = list(_generate_hashes(peaks)) | ||
| 216 | |||
| 217 | logger.info("指纹生成完成: audio=%s, 指纹数=%d", audio_path, len(fingerprints)) | 234 | logger.info("指纹生成完成: audio=%s, 指纹数=%d", audio_path, len(fingerprints)) |
| 218 | return file_sha1, fingerprints | 235 | return file_sha1, fingerprints | ... | ... |
| 1 | """Chromagram 特征提取。 | 1 | """Chromagram 特征提取。 |
| 2 | 2 | ||
| 3 | 流程: | 3 | 流程: |
| 4 | 1. 音频标准化:ffmpeg 转 22050Hz / Mono / WAV | 4 | 1. 音频解码:ffmpeg pipe 输出 22050Hz / Mono / f32le,直接读入内存,无临时文件 |
| 5 | 2. librosa 加载音频 | 5 | 2. librosa.feature.chroma_cens() 提取 12×T Chromagram(CENS,对速度/音色鲁棒) |
| 6 | 3. librosa.feature.chroma_cens() 提取 12×T Chromagram(CENS,对速度/音色鲁棒) | 6 | 3. 主音对齐:将能量最大的音级滚至第 0 行,实现转调不变性 |
| 7 | 4. 主音对齐:将能量最大的音级滚至第 0 行,实现转调不变性 | 7 | 4. scipy.signal.resample(chroma, 128, axis=1) 时间归一化到 12×128 |
| 8 | 5. scipy.signal.resample(chroma, 128, axis=1) 时间归一化到 12×128 | 8 | 5. .flatten() 展开为 1536 维向量 |
| 9 | 6. .flatten() 展开为 1536 维向量 | ||
| 10 | """ | 9 | """ |
| 11 | 10 | ||
| 12 | import logging | 11 | import logging |
| 13 | import os | 12 | import os |
| 14 | import subprocess | 13 | import subprocess |
| 15 | import tempfile | ||
| 16 | 14 | ||
| 17 | import librosa | 15 | import librosa |
| 18 | import numpy as np | 16 | import numpy as np |
| ... | @@ -26,83 +24,99 @@ TARGET_FRAMES = 128 | ... | @@ -26,83 +24,99 @@ TARGET_FRAMES = 128 |
| 26 | VECTOR_DIM = 12 * TARGET_FRAMES # 1536 | 24 | VECTOR_DIM = 12 * TARGET_FRAMES # 1536 |
| 27 | 25 | ||
| 28 | 26 | ||
| 29 | def _normalize_audio_ffmpeg(audio_path: str, output_path: str) -> None: | 27 | def _load_audio_via_pipe(audio_path: str) -> np.ndarray: |
| 30 | """使用 ffmpeg 将音频标准化为 22050Hz / Mono / WAV。""" | 28 | """使用 ffmpeg pipe 将音频解码为 22050Hz mono float32,不落临时文件到磁盘。""" |
| 31 | cmd = [ | 29 | cmd = [ |
| 32 | "ffmpeg", | 30 | "ffmpeg", "-y", |
| 33 | "-y", | ||
| 34 | "-i", audio_path, | 31 | "-i", audio_path, |
| 35 | "-ar", str(TARGET_SR), | 32 | "-ar", str(TARGET_SR), |
| 36 | "-ac", "1", | 33 | "-ac", "1", |
| 37 | "-f", "wav", | 34 | "-f", "f32le", |
| 38 | output_path, | 35 | "pipe:1", |
| 39 | ] | 36 | ] |
| 40 | result = subprocess.run( | 37 | result = subprocess.run(cmd, capture_output=True) |
| 41 | cmd, | ||
| 42 | capture_output=True, | ||
| 43 | text=True, | ||
| 44 | ) | ||
| 45 | if result.returncode != 0: | 38 | if result.returncode != 0: |
| 46 | raise RuntimeError(f"ffmpeg 转换失败: {result.stderr}") | 39 | raise RuntimeError(f"ffmpeg 解码失败: {result.stderr.decode(errors='replace')}") |
| 40 | return np.frombuffer(result.stdout, dtype=np.float32) | ||
| 47 | 41 | ||
| 48 | 42 | ||
| 49 | def extract_chroma_feature(audio_path: str) -> np.ndarray: | 43 | def extract_chroma_feature_from_samples( |
| 50 | """从音频文件提取 1536 维 Chromagram 特征向量。 | 44 | samples: np.ndarray, |
| 45 | sr: int, | ||
| 46 | hop_length: int = 512, | ||
| 47 | win_len_smooth: int = 41, | ||
| 48 | ) -> np.ndarray: | ||
| 49 | """从已加载的音频样本提取 1536 维 Chromagram 特征向量。 | ||
| 50 | |||
| 51 | 若 sr 不等于 TARGET_SR,先用 librosa.resample 在内存中降采样, | ||
| 52 | 避免重新走 ffmpeg 流程。 | ||
| 51 | 53 | ||
| 52 | Args: | 54 | Args: |
| 53 | audio_path: 音频文件路径。 | 55 | samples: 单声道音频样本(任意采样率)。 |
| 56 | sr: samples 对应的采样率。 | ||
| 57 | hop_length: CQT hop 大小,增大可成比例降低计算量,不影响最终 128 帧精度。 | ||
| 58 | win_len_smooth: CENS 平滑窗口帧数,应随 hop_length 等比缩小以保持相同的时间覆盖。 | ||
| 54 | 59 | ||
| 55 | Returns: | 60 | Returns: |
| 56 | shape 为 (1536,) 的 numpy 数组。 | 61 | shape 为 (1536,) 的 numpy 数组。 |
| 57 | |||
| 58 | Raises: | ||
| 59 | FileNotFoundError: 音频文件不存在。 | ||
| 60 | RuntimeError: ffmpeg 转换失败。 | ||
| 61 | """ | 62 | """ |
| 62 | if not os.path.isfile(audio_path): | 63 | y = samples if sr == TARGET_SR else librosa.resample(samples, orig_sr=sr, target_sr=TARGET_SR) |
| 63 | raise FileNotFoundError(f"音频文件不存在: {audio_path}") | ||
| 64 | |||
| 65 | # 1. 音频标准化:ffmpeg 转 WAV | ||
| 66 | with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp: | ||
| 67 | tmp_wav = tmp.name | ||
| 68 | |||
| 69 | try: | ||
| 70 | _normalize_audio_ffmpeg(audio_path, tmp_wav) | ||
| 71 | |||
| 72 | # 2. librosa 加载音频 | ||
| 73 | y, _sr = librosa.load(tmp_wav, sr=TARGET_SR, mono=True) | ||
| 74 | 64 | ||
| 75 | # 3. 提取 CENS Chromagram (12×T),对速度变化和音色具有更强鲁棒性 | 65 | # 提取 CENS Chromagram (12×T) |
| 76 | chroma = librosa.feature.chroma_cens(y=y, sr=TARGET_SR) | 66 | chroma = librosa.feature.chroma_cens(y=y, sr=TARGET_SR, hop_length=hop_length, win_len_smooth=win_len_smooth) |
| 77 | 67 | ||
| 78 | # 4. 主音对齐:将全局能量最大的音级循环滚至第 0 行,实现转调不变性 | 68 | # 主音对齐 |
| 79 | tonic = int(np.argmax(chroma.sum(axis=1))) | 69 | tonic = int(np.argmax(chroma.sum(axis=1))) |
| 80 | if tonic != 0: | 70 | if tonic != 0: |
| 81 | chroma = np.roll(chroma, -tonic, axis=0) | 71 | chroma = np.roll(chroma, -tonic, axis=0) |
| 82 | 72 | ||
| 83 | # 5. 时间归一化到 12×128 | 73 | # 时间归一化到 12×128 |
| 84 | if chroma.shape[1] != TARGET_FRAMES: | 74 | if chroma.shape[1] != TARGET_FRAMES: |
| 85 | chroma = resample(chroma, TARGET_FRAMES, axis=1) | 75 | chroma = resample(chroma, TARGET_FRAMES, axis=1) |
| 86 | 76 | ||
| 87 | # 6. 展开为 1536 维向量 | ||
| 88 | feature = chroma.flatten().astype(np.float32) | 77 | feature = chroma.flatten().astype(np.float32) |
| 89 | |||
| 90 | assert feature.shape == (VECTOR_DIM,), ( | 78 | assert feature.shape == (VECTOR_DIM,), ( |
| 91 | f"特征维度错误: 期望 {VECTOR_DIM}, 实际 {feature.shape}" | 79 | f"特征维度错误: 期望 {VECTOR_DIM}, 实际 {feature.shape}" |
| 92 | ) | 80 | ) |
| 93 | |||
| 94 | return feature | 81 | return feature |
| 95 | finally: | ||
| 96 | # 清理临时文件 | ||
| 97 | if os.path.exists(tmp_wav): | ||
| 98 | os.remove(tmp_wav) | ||
| 99 | 82 | ||
| 100 | 83 | ||
| 101 | def extract_chroma_matrix(audio_path: str) -> np.ndarray: | 84 | def extract_chroma_matrix_from_samples( |
| 85 | samples: np.ndarray, | ||
| 86 | sr: int, | ||
| 87 | hop_length: int = 512, | ||
| 88 | win_len_smooth: int = 41, | ||
| 89 | ) -> np.ndarray: | ||
| 90 | """从已加载的音频样本提取 12×128 Chromagram 矩阵(供 DTW 精排使用)。""" | ||
| 91 | return extract_chroma_feature_from_samples(samples, sr, hop_length=hop_length, win_len_smooth=win_len_smooth).reshape(12, TARGET_FRAMES) | ||
| 92 | |||
| 93 | |||
| 94 | def extract_chroma_feature(audio_path: str, hop_length: int = 512, win_len_smooth: int = 41) -> np.ndarray: | ||
| 95 | """从音频文件提取 1536 维 Chromagram 特征向量。 | ||
| 96 | |||
| 97 | Args: | ||
| 98 | audio_path: 音频文件路径。 | ||
| 99 | hop_length: CQT hop 大小。 | ||
| 100 | win_len_smooth: CENS 平滑窗口帧数。 | ||
| 101 | |||
| 102 | Returns: | ||
| 103 | shape 为 (1536,) 的 numpy 数组。 | ||
| 104 | |||
| 105 | Raises: | ||
| 106 | FileNotFoundError: 音频文件不存在。 | ||
| 107 | RuntimeError: ffmpeg 解码失败。 | ||
| 108 | """ | ||
| 109 | if not os.path.isfile(audio_path): | ||
| 110 | raise FileNotFoundError(f"音频文件不存在: {audio_path}") | ||
| 111 | |||
| 112 | y = _load_audio_via_pipe(audio_path) | ||
| 113 | return extract_chroma_feature_from_samples(y, TARGET_SR, hop_length=hop_length, win_len_smooth=win_len_smooth) | ||
| 114 | |||
| 115 | |||
| 116 | def extract_chroma_matrix(audio_path: str, hop_length: int = 512, win_len_smooth: int = 41) -> np.ndarray: | ||
| 102 | """从音频文件提取 12×128 Chromagram 矩阵(未展平,供 DTW 精排使用)。 | 117 | """从音频文件提取 12×128 Chromagram 矩阵(未展平,供 DTW 精排使用)。 |
| 103 | 118 | ||
| 104 | Returns: | 119 | Returns: |
| 105 | shape 为 (12, 128) 的 numpy 数组,已做主音对齐。 | 120 | shape 为 (12, 128) 的 numpy 数组,已做主音对齐。 |
| 106 | """ | 121 | """ |
| 107 | feature = extract_chroma_feature(audio_path) | 122 | return extract_chroma_feature(audio_path, hop_length=hop_length, win_len_smooth=win_len_smooth).reshape(12, TARGET_FRAMES) |
| 108 | return feature.reshape(12, TARGET_FRAMES) | ... | ... |
This diff is collapsed.
Click to expand it.
| ... | @@ -12,6 +12,7 @@ tqdm>=4.66 | ... | @@ -12,6 +12,7 @@ tqdm>=4.66 |
| 12 | 12 | ||
| 13 | # Audio composition feature extraction | 13 | # Audio composition feature extraction |
| 14 | librosa>=0.10.0 | 14 | librosa>=0.10.0 |
| 15 | numba>=0.59.0 | ||
| 15 | scipy>=1.11 | 16 | scipy>=1.11 |
| 16 | numpy>=1.24 | 17 | numpy>=1.24 |
| 17 | 18 | ||
| ... | @@ -21,3 +22,6 @@ pgvector>=0.2.0 | ... | @@ -21,3 +22,6 @@ pgvector>=0.2.0 |
| 21 | # HTTP API server | 22 | # HTTP API server |
| 22 | fastapi>=0.110.0 | 23 | fastapi>=0.110.0 |
| 23 | uvicorn[standard]>=0.29.0 | 24 | uvicorn[standard]>=0.29.0 |
| 25 | |||
| 26 | # Environment variable loading | ||
| 27 | python-dotenv>=1.0 | ... | ... |
| ... | @@ -8,8 +8,8 @@ expected_song_id 的 top-k/top1 命中只作为诊断字段。 | ... | @@ -8,8 +8,8 @@ expected_song_id 的 top-k/top1 命中只作为诊断字段。 |
| 8 | 用法: | 8 | 用法: |
| 9 | python scripts/evaluate_composition.py \ | 9 | python scripts/evaluate_composition.py \ |
| 10 | --dsn "postgresql:///lyric_dedup" \ | 10 | --dsn "postgresql:///lyric_dedup" \ |
| 11 | --queries composition_dedup/composition_testset4/queries.csv \ | 11 | --queries composition_testset/test_samples.csv \ |
| 12 | --out composition_dedup/composition_eval/composition_eval_result_v3.csv | 12 | --out composition_dedup/composition_eval/nohop_result.csv |
| 13 | """ | 13 | """ |
| 14 | 14 | ||
| 15 | import argparse | 15 | import argparse |
| ... | @@ -17,10 +17,14 @@ import csv | ... | @@ -17,10 +17,14 @@ import csv |
| 17 | import json | 17 | import json |
| 18 | import logging | 18 | import logging |
| 19 | import sys | 19 | import sys |
| 20 | import time | ||
| 20 | from pathlib import Path | 21 | from pathlib import Path |
| 21 | 22 | ||
| 22 | sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) | 23 | sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) |
| 23 | 24 | ||
| 25 | from dotenv import load_dotenv | ||
| 26 | load_dotenv(Path(__file__).resolve().parent.parent / ".env") | ||
| 27 | |||
| 24 | from composition_dedup.service import CompositionConfig, CompositionDedupService | 28 | from composition_dedup.service import CompositionConfig, CompositionDedupService |
| 25 | 29 | ||
| 26 | logger = logging.getLogger(__name__) | 30 | logger = logging.getLogger(__name__) |
| ... | @@ -92,8 +96,12 @@ def main() -> None: | ... | @@ -92,8 +96,12 @@ def main() -> None: |
| 92 | invalid_negative_pair = (not expected_dup) and audio_song_id == expected_song_id | 96 | invalid_negative_pair = (not expected_dup) and audio_song_id == expected_song_id |
| 93 | 97 | ||
| 94 | try: | 98 | try: |
| 95 | candidates = service.query(audio_path, top_k=args.top_k) | 99 | timings: dict = {} |
| 100 | _t0 = time.perf_counter() | ||
| 101 | candidates = service.query(audio_path, top_k=args.top_k, timings=timings) | ||
| 102 | query_time_ms = round((time.perf_counter() - _t0) * 1000, 1) | ||
| 96 | except Exception as e: | 103 | except Exception as e: |
| 104 | query_time_ms = round((time.perf_counter() - _t0) * 1000, 1) | ||
| 97 | logger.error("[%d/%d] 查询失败: %s, %s", i, len(rows), audio_path, e) | 105 | logger.error("[%d/%d] 查询失败: %s, %s", i, len(rows), audio_path, e) |
| 98 | result_rows.append({ | 106 | result_rows.append({ |
| 99 | "query_song_id": query_song_id, | 107 | "query_song_id": query_song_id, |
| ... | @@ -106,6 +114,7 @@ def main() -> None: | ... | @@ -106,6 +114,7 @@ def main() -> None: |
| 106 | "top1_song_id": "", | 114 | "top1_song_id": "", |
| 107 | "top1_similarity": "", | 115 | "top1_similarity": "", |
| 108 | "top1_source": "", | 116 | "top1_source": "", |
| 117 | "dejavu_aligned_count": "", | ||
| 109 | "top1_hit": False, | 118 | "top1_hit": False, |
| 110 | "topk_hit": False, | 119 | "topk_hit": False, |
| 111 | "expected_rank": "", | 120 | "expected_rank": "", |
| ... | @@ -115,6 +124,14 @@ def main() -> None: | ... | @@ -115,6 +124,14 @@ def main() -> None: |
| 115 | "expected_duplicate": expected_dup, | 124 | "expected_duplicate": expected_dup, |
| 116 | "predicted_duplicate": False, | 125 | "predicted_duplicate": False, |
| 117 | "correct": not expected_dup, # 查询失败视为 not_duplicate | 126 | "correct": not expected_dup, # 查询失败视为 not_duplicate |
| 127 | "query_time_ms": query_time_ms, | ||
| 128 | "chroma_extract_ms": timings.get("chroma_extract_ms", ""), | ||
| 129 | "db_cosine_ms": timings.get("db_cosine_ms", ""), | ||
| 130 | "db_fetch_ms": timings.get("db_fetch_ms", ""), | ||
| 131 | "dtw_ms": timings.get("dtw_ms", ""), | ||
| 132 | "dejavu_decode_ms": timings.get("dejavu_decode_ms", ""), | ||
| 133 | "dejavu_fingerprint_ms": timings.get("dejavu_fingerprint_ms", ""), | ||
| 134 | "dejavu_db_ms": timings.get("dejavu_db_ms", ""), | ||
| 118 | "error": str(e), | 135 | "error": str(e), |
| 119 | }) | 136 | }) |
| 120 | continue | 137 | continue |
| ... | @@ -123,6 +140,7 @@ def main() -> None: | ... | @@ -123,6 +140,7 @@ def main() -> None: |
| 123 | top1_song_id = str(top1.song_id) if top1 else "" | 140 | top1_song_id = str(top1.song_id) if top1 else "" |
| 124 | top1_sim = round(top1.similarity, 4) if top1 else "" | 141 | top1_sim = round(top1.similarity, 4) if top1 else "" |
| 125 | top1_source = top1.source if top1 else "" | 142 | top1_source = top1.source if top1 else "" |
| 143 | dejavu_aligned_count = top1.dejavu_aligned_count if top1 else "" | ||
| 126 | 144 | ||
| 127 | # 诊断召回:expected_song_id 是否进入 top1/top-k。 | 145 | # 诊断召回:expected_song_id 是否进入 top1/top-k。 |
| 128 | top1_hit = bool(expected_song_id) and top1_song_id == expected_song_id | 146 | top1_hit = bool(expected_song_id) and top1_song_id == expected_song_id |
| ... | @@ -157,6 +175,7 @@ def main() -> None: | ... | @@ -157,6 +175,7 @@ def main() -> None: |
| 157 | "top1_song_id": top1_song_id, | 175 | "top1_song_id": top1_song_id, |
| 158 | "top1_similarity": top1_sim, | 176 | "top1_similarity": top1_sim, |
| 159 | "top1_source": top1_source, | 177 | "top1_source": top1_source, |
| 178 | "dejavu_aligned_count": dejavu_aligned_count if dejavu_aligned_count is not None else "", | ||
| 160 | "top1_hit": top1_hit, | 179 | "top1_hit": top1_hit, |
| 161 | "topk_hit": topk_hit, | 180 | "topk_hit": topk_hit, |
| 162 | "expected_rank": expected_rank, | 181 | "expected_rank": expected_rank, |
| ... | @@ -166,11 +185,19 @@ def main() -> None: | ... | @@ -166,11 +185,19 @@ def main() -> None: |
| 166 | "expected_duplicate": expected_dup, | 185 | "expected_duplicate": expected_dup, |
| 167 | "predicted_duplicate": predicted_dup, | 186 | "predicted_duplicate": predicted_dup, |
| 168 | "correct": correct, | 187 | "correct": correct, |
| 188 | "query_time_ms": query_time_ms, | ||
| 189 | "chroma_extract_ms": timings.get("chroma_extract_ms", ""), | ||
| 190 | "db_cosine_ms": timings.get("db_cosine_ms", ""), | ||
| 191 | "db_fetch_ms": timings.get("db_fetch_ms", ""), | ||
| 192 | "dtw_ms": timings.get("dtw_ms", ""), | ||
| 193 | "dejavu_decode_ms": timings.get("dejavu_decode_ms", ""), | ||
| 194 | "dejavu_fingerprint_ms": timings.get("dejavu_fingerprint_ms", ""), | ||
| 195 | "dejavu_db_ms": timings.get("dejavu_db_ms", ""), | ||
| 169 | "error": "", | 196 | "error": "", |
| 170 | }) | 197 | }) |
| 171 | 198 | ||
| 172 | logger.info( | 199 | logger.info( |
| 173 | "[%d/%d] variant=%s source=%s expected=%s predicted_duplicate=%s threshold=%.4f expected_song_id=%s top1=%s sim=%s top1_hit=%s topk_hit=%s expected_rank=%s expected_sim=%s correct=%s", | 200 | "[%d/%d] variant=%s source=%s expected=%s predicted_duplicate=%s threshold=%.4f expected_song_id=%s top1=%s sim=%s top1_hit=%s topk_hit=%s expected_rank=%s expected_sim=%s correct=%s time_ms=%s", |
| 174 | i, | 201 | i, |
| 175 | len(rows), | 202 | len(rows), |
| 176 | row.get("variant", ""), | 203 | row.get("variant", ""), |
| ... | @@ -186,6 +213,7 @@ def main() -> None: | ... | @@ -186,6 +213,7 @@ def main() -> None: |
| 186 | expected_rank if expected_rank != "" else "-", | 213 | expected_rank if expected_rank != "" else "-", |
| 187 | expected_similarity if expected_similarity != "" else "-", | 214 | expected_similarity if expected_similarity != "" else "-", |
| 188 | correct, | 215 | correct, |
| 216 | query_time_ms, | ||
| 189 | ) | 217 | ) |
| 190 | 218 | ||
| 191 | if i % 10 == 0 or i == len(rows): | 219 | if i % 10 == 0 or i == len(rows): |
| ... | @@ -194,9 +222,14 @@ def main() -> None: | ... | @@ -194,9 +222,14 @@ def main() -> None: |
| 194 | # 写逐条结果 | 222 | # 写逐条结果 |
| 195 | fieldnames = ["query_song_id", "audio_song_id", "audio_path", "variant", "sample_class", | 223 | fieldnames = ["query_song_id", "audio_song_id", "audio_path", "variant", "sample_class", |
| 196 | "expected_song_id", "expected", "top1_song_id", "top1_similarity", "top1_source", | 224 | "expected_song_id", "expected", "top1_song_id", "top1_similarity", "top1_source", |
| 225 | "dejavu_aligned_count", | ||
| 197 | "top1_hit", "topk_hit", "expected_rank", "expected_similarity", | 226 | "top1_hit", "topk_hit", "expected_rank", "expected_similarity", |
| 198 | "invalid_negative_pair", "invalid_boolean_sample", | 227 | "invalid_negative_pair", "invalid_boolean_sample", |
| 199 | "expected_duplicate", "predicted_duplicate", "correct", "error"] | 228 | "expected_duplicate", "predicted_duplicate", "correct", |
| 229 | "query_time_ms", | ||
| 230 | "chroma_extract_ms", "db_cosine_ms", "db_fetch_ms", "dtw_ms", | ||
| 231 | "dejavu_decode_ms", "dejavu_fingerprint_ms", "dejavu_db_ms", | ||
| 232 | "error"] | ||
| 200 | with out_path.open("w", newline="", encoding="utf-8") as f: | 233 | with out_path.open("w", newline="", encoding="utf-8") as f: |
| 201 | writer = csv.DictWriter(f, fieldnames=fieldnames) | 234 | writer = csv.DictWriter(f, fieldnames=fieldnames) |
| 202 | writer.writeheader() | 235 | writer.writeheader() | ... | ... |
| ... | @@ -6,11 +6,11 @@ | ... | @@ -6,11 +6,11 @@ |
| 6 | 6 | ||
| 7 | 用法: | 7 | 用法: |
| 8 | python scripts/generate_composition_testset.py \ | 8 | python scripts/generate_composition_testset.py \ |
| 9 | --audio-dir /Volumes/移动硬盘/lyric_audio_type11 \ | 9 | --audio-dir /Volumes/移动硬盘/composition_test \ |
| 10 | --negative-audio-dir /Volumes/移动硬盘/composition_test \ | 10 | --negative-audio-dir /Volumes/移动硬盘/composition_drop \ |
| 11 | --out-dir composition_dedup/composition_testset \ | 11 | --out-dir composition_testset \ |
| 12 | --num-songs 80 \ | 12 | --num-songs 100 \ |
| 13 | --num-negative-songs 40 \ | 13 | --num-negative-songs 100 \ |
| 14 | --negative-variants \ | 14 | --negative-variants \ |
| 15 | --seed 123 | 15 | --seed 123 |
| 16 | 16 | ... | ... |
| ... | @@ -16,6 +16,9 @@ from pathlib import Path | ... | @@ -16,6 +16,9 @@ from pathlib import Path |
| 16 | 16 | ||
| 17 | sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) | 17 | sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) |
| 18 | 18 | ||
| 19 | from dotenv import load_dotenv | ||
| 20 | load_dotenv(Path(__file__).resolve().parent.parent / ".env") | ||
| 21 | |||
| 19 | from tqdm import tqdm | 22 | from tqdm import tqdm |
| 20 | 23 | ||
| 21 | from composition_dedup.service import CompositionConfig, CompositionDedupService | 24 | from composition_dedup.service import CompositionConfig, CompositionDedupService | ... | ... |
-
Please register or sign in to post a comment