Commit 974df4ae 974df4ae895606924f14bf5e679276b5dd51920e by 沈秋雨

优化性能

1 parent 7a11a3d4
...@@ -26,4 +26,6 @@ venv/ ...@@ -26,4 +26,6 @@ venv/
26 26
27 test_api 27 test_api
28 28
29 composition_dedup/composition_eval
...\ No newline at end of file ...\ No newline at end of file
29 composition_dedup/composition_eval
30
31 composition_testset
...\ No newline at end of file ...\ No newline at end of file
......
...@@ -22,7 +22,6 @@ from pathlib import Path ...@@ -22,7 +22,6 @@ from pathlib import Path
22 import librosa 22 import librosa
23 import numpy as np 23 import numpy as np
24 from scipy.ndimage import ( 24 from scipy.ndimage import (
25 binary_erosion,
26 generate_binary_structure, 25 generate_binary_structure,
27 iterate_structure, 26 iterate_structure,
28 maximum_filter, 27 maximum_filter,
...@@ -60,10 +59,10 @@ MIN_HASH_TIME_DELTA = 0 ...@@ -60,10 +59,10 @@ MIN_HASH_TIME_DELTA = 0
60 MAX_HASH_TIME_DELTA = 200 59 MAX_HASH_TIME_DELTA = 200
61 PEAK_SORT = True 60 PEAK_SORT = True
62 FINGERPRINT_REDUCTION = 20 61 FINGERPRINT_REDUCTION = 20
63 MAX_DURATION_SEC = float(os.environ.get("COMPOSITION_DEJAVU_MAX_DURATION", "120")) # 0=不限制 62 QUERY_MAX_DURATION_SEC = float(os.environ.get("COMPOSITION_DEJAVU_QUERY_MAX_DURATION", "120")) # 0=不限制
64 63
65 64
66 def _normalize_audio(audio_path: str, max_duration: float = MAX_DURATION_SEC) -> tuple[np.ndarray, int]: 65 def _normalize_audio(audio_path: str, max_duration: float = 0) -> tuple[np.ndarray, int]:
67 """将音频标准化为单声道 WAV 并加载为 numpy 数组。 66 """将音频标准化为单声道 WAV 并加载为 numpy 数组。
68 67
69 使用 ffmpeg 先做重采样,再用 librosa 读取。 68 使用 ffmpeg 先做重采样,再用 librosa 读取。
...@@ -133,12 +132,7 @@ def _get_2d_peaks(arr2D: np.ndarray, amp_min: float = DEFAULT_AMP_MIN): ...@@ -133,12 +132,7 @@ def _get_2d_peaks(arr2D: np.ndarray, amp_min: float = DEFAULT_AMP_MIN):
133 neighborhood = iterate_structure(struct, PEAK_NEIGHBORHOOD_SIZE) 132 neighborhood = iterate_structure(struct, PEAK_NEIGHBORHOOD_SIZE)
134 133
135 # 找局部极大值 134 # 找局部极大值
136 local_max = maximum_filter(arr2D, footprint=neighborhood) == arr2D 135 detected_peaks = maximum_filter(arr2D, footprint=neighborhood) == arr2D
137 background = arr2D == 0
138 eroded_background = binary_erosion(background, structure=neighborhood, border_value=1)
139
140 # 布尔掩码
141 detected_peaks = local_max ^ eroded_background
142 136
143 # 提取峰值 137 # 提取峰值
144 amps = arr2D[detected_peaks] 138 amps = arr2D[detected_peaks]
...@@ -181,6 +175,43 @@ def _generate_hashes(peaks: list[tuple[int, int]], fan_value: int = DEFAULT_FAN_ ...@@ -181,6 +175,43 @@ def _generate_hashes(peaks: list[tuple[int, int]], fan_value: int = DEFAULT_FAN_
181 yield (h.hexdigest()[:FINGERPRINT_REDUCTION].encode(), t1) 175 yield (h.hexdigest()[:FINGERPRINT_REDUCTION].encode(), t1)
182 176
183 177
def load_audio(audio_path: str, max_duration: float = 0) -> tuple[np.ndarray, int]:
    """Decode an audio file once so that several pipelines can share the samples.

    This is a thin public wrapper: both arguments are forwarded unchanged to
    ``_normalize_audio`` and its result is returned as-is.  The original
    docstring states the output is 44100 Hz mono; that is determined by
    ``_normalize_audio`` and is not re-checked here.

    Args:
        audio_path: Path of the audio file to load.
        max_duration: Maximum number of seconds to keep; 0 means no limit.

    Returns:
        The ``(samples, sr)`` tuple produced by ``_normalize_audio``.
    """
    loaded = _normalize_audio(audio_path, max_duration=max_duration)
    return loaded
189
190
def fingerprint_from_samples(
    samples: np.ndarray, sr: int, *, compute_sha1: bool = True
) -> tuple[str, list[tuple[bytes, int]]]:
    """Build Dejavu-style fingerprints from samples already in memory (no I/O).

    Args:
        samples: Mono audio samples (the original docstring expects them at
            DEFAULT_FS = 44100 Hz; this is not verified here).
        sr: Sampling rate passed on to ``_specgram``.
        compute_sha1: Whether to hash ``samples.tobytes()``.  Callers that do
            not use the identifier can pass False to skip hashing the whole
            sample buffer.

    Returns:
        A ``(file_sha1, fingerprints)`` tuple.  ``file_sha1`` is the first 16
        hex characters of the SHA-1 of the raw sample bytes, or an empty
        string when ``compute_sha1`` is False.  ``fingerprints`` is a list of
        ``(hash_bytes, offset)`` pairs.
    """
    if compute_sha1:
        file_sha1 = hashlib.sha1(samples.tobytes()).hexdigest()[:16]
    else:
        file_sha1 = ""

    # Spectrogram -> 2D peak coordinates -> (frequency, time) pairs -> hashes.
    spectrogram = _specgram(samples, sr, DEFAULT_WINDOW_SIZE, DEFAULT_OVERLAP_RATIO)
    peak_freqs, peak_times = _get_2d_peaks(spectrogram)
    peak_pairs = list(zip(peak_freqs, peak_times))
    return file_sha1, list(_generate_hashes(peak_pairs))
213
214
184 def fingerprint_audio(audio_path: str) -> tuple[str, list[tuple[bytes, int]]]: 215 def fingerprint_audio(audio_path: str) -> tuple[str, list[tuple[bytes, int]]]:
185 """对音频文件生成 Dejavu 风格指纹。 216 """对音频文件生成 Dejavu 风格指纹。
186 217
...@@ -198,21 +229,7 @@ def fingerprint_audio(audio_path: str) -> tuple[str, list[tuple[bytes, int]]]: ...@@ -198,21 +229,7 @@ def fingerprint_audio(audio_path: str) -> tuple[str, list[tuple[bytes, int]]]:
198 if not os.path.isfile(audio_path): 229 if not os.path.isfile(audio_path):
199 raise FileNotFoundError(f"音频文件不存在: {audio_path}") 230 raise FileNotFoundError(f"音频文件不存在: {audio_path}")
200 231
201 # 1. 标准化并加载音频(可选限制长度)
202 samples, fs = _normalize_audio(audio_path) 232 samples, fs = _normalize_audio(audio_path)
203 233 file_sha1, fingerprints = fingerprint_from_samples(samples, fs)
204 # 2. 计算文件 SHA1(用于标识)
205 file_sha1 = hashlib.sha1(samples.tobytes()).hexdigest()[:16]
206
207 # 3. 计算频谱图
208 arr2D = _specgram(samples, fs, DEFAULT_WINDOW_SIZE, DEFAULT_OVERLAP_RATIO)
209
210 # 4. 检测 2D 峰值
211 freq_idx, time_idx = _get_2d_peaks(arr2D)
212 peaks = list(zip(freq_idx, time_idx))
213
214 # 5. 生成指纹哈希
215 fingerprints = list(_generate_hashes(peaks))
216
217 logger.info("指纹生成完成: audio=%s, 指纹数=%d", audio_path, len(fingerprints)) 234 logger.info("指纹生成完成: audio=%s, 指纹数=%d", audio_path, len(fingerprints))
218 return file_sha1, fingerprints 235 return file_sha1, fingerprints
......
1 """Chromagram 特征提取。 1 """Chromagram 特征提取。
2 2
3 流程: 3 流程:
4 1. 音频标准化:ffmpeg 转 22050Hz / Mono / WAV 4 1. 音频解码:ffmpeg pipe 输出 22050Hz / Mono / f32le,直接读入内存,无临时文件
5 2. librosa 加载音频 5 2. librosa.feature.chroma_cens() 提取 12×T Chromagram(CENS,对速度/音色鲁棒)
6 3. librosa.feature.chroma_cens() 提取 12×T Chromagram(CENS,对速度/音色鲁棒) 6 3. 主音对齐:将能量最大的音级滚至第 0 行,实现转调不变性
7 4. 主音对齐:将能量最大的音级滚至第 0 行,实现转调不变性 7 4. scipy.signal.resample(chroma, 128, axis=1) 时间归一化到 12×128
8 5. scipy.signal.resample(chroma, 128, axis=1) 时间归一化到 12×128 8 5. .flatten() 展开为 1536 维向量
9 6. .flatten() 展开为 1536 维向量
10 """ 9 """
11 10
12 import logging 11 import logging
13 import os 12 import os
14 import subprocess 13 import subprocess
15 import tempfile
16 14
17 import librosa 15 import librosa
18 import numpy as np 16 import numpy as np
...@@ -26,83 +24,99 @@ TARGET_FRAMES = 128 ...@@ -26,83 +24,99 @@ TARGET_FRAMES = 128
26 VECTOR_DIM = 12 * TARGET_FRAMES # 1536 24 VECTOR_DIM = 12 * TARGET_FRAMES # 1536
27 25
28 26
def _load_audio_via_pipe(audio_path: str) -> np.ndarray:
    """Decode an audio file to mono float32 samples at TARGET_SR through an ffmpeg pipe.

    ffmpeg writes raw little-endian float32 PCM to its stdout, which is
    captured in memory, so no temporary file is written to disk.

    Args:
        audio_path: Path of the audio file handed to ``ffmpeg -i``.

    Returns:
        A 1-D float32 array of samples.  The array is a read-only view over
        the captured bytes (``np.frombuffer`` on a ``bytes`` object).

    Raises:
        RuntimeError: ffmpeg exited with a non-zero status, or produced no
            complete audio sample.
    """
    bytes_per_sample = 4  # f32le: one 32-bit float per mono sample

    cmd = [
        # -nostdin keeps ffmpeg from treating the inherited stdin as an
        # interactive console (it could otherwise swallow the caller's input).
        "ffmpeg", "-nostdin", "-y",
        "-i", audio_path,
        "-ar", str(TARGET_SR),
        "-ac", "1",
        "-f", "f32le",
        "pipe:1",
    ]
    # stdin is detached as well, so the child can never block on or consume
    # the parent's standard input.
    result = subprocess.run(cmd, stdin=subprocess.DEVNULL, capture_output=True)
    if result.returncode != 0:
        raise RuntimeError(f"ffmpeg 解码失败: {result.stderr.decode(errors='replace')}")

    raw = result.stdout
    sample_count = len(raw) // bytes_per_sample
    if sample_count == 0:
        # Fail here with a clear message instead of letting an empty array
        # reach the feature extraction code.
        raise RuntimeError(f"ffmpeg 解码失败: 未输出任何音频数据: {audio_path}")

    # "<f4" matches the f32le stream regardless of host byte order; `count`
    # ignores a truncated trailing sample instead of raising ValueError.
    return np.frombuffer(raw, dtype="<f4", count=sample_count)
47 41
48 42
def extract_chroma_feature_from_samples(
    samples: np.ndarray,
    sr: int,
    hop_length: int = 512,
    win_len_smooth: int = 41,
) -> np.ndarray:
    """Compute the flattened 12 x TARGET_FRAMES CENS chroma vector from samples.

    When ``sr`` differs from TARGET_SR the samples are first resampled in
    memory with ``librosa.resample`` instead of going through ffmpeg again.

    Args:
        samples: Mono audio samples at sampling rate ``sr``.
        sr: Sampling rate of ``samples``.
        hop_length: Hop length forwarded to ``librosa.feature.chroma_cens``.
        win_len_smooth: Smoothing window length (in frames) forwarded to
            ``librosa.feature.chroma_cens``.

    Returns:
        A float32 array of shape ``(VECTOR_DIM,)``.
    """
    if sr == TARGET_SR:
        audio = samples
    else:
        audio = librosa.resample(samples, orig_sr=sr, target_sr=TARGET_SR)

    # CENS chromagram, 12 pitch classes by T frames.
    cens = librosa.feature.chroma_cens(
        y=audio,
        sr=TARGET_SR,
        hop_length=hop_length,
        win_len_smooth=win_len_smooth,
    )

    # Tonic alignment: rotate the pitch class with the largest total energy
    # to row 0 so that transposed versions line up.
    strongest = int(np.argmax(cens.sum(axis=1)))
    if strongest:
        cens = np.roll(cens, -strongest, axis=0)

    # Normalize the time axis to exactly TARGET_FRAMES columns.
    if cens.shape[1] != TARGET_FRAMES:
        cens = resample(cens, TARGET_FRAMES, axis=1)

    feature = cens.flatten().astype(np.float32)
    assert feature.shape == (VECTOR_DIM,), (
        f"特征维度错误: 期望 {VECTOR_DIM}, 实际 {feature.shape}"
    )
    return feature
74 82
75 # 3. 提取 CENS Chromagram (12×T),对速度变化和音色具有更强鲁棒性
76 chroma = librosa.feature.chroma_cens(y=y, sr=TARGET_SR)
77 83
def extract_chroma_matrix_from_samples(
    samples: np.ndarray,
    sr: int,
    hop_length: int = 512,
    win_len_smooth: int = 41,
) -> np.ndarray:
    """Return the chroma feature of in-memory samples as a 12 x TARGET_FRAMES matrix.

    Calls ``extract_chroma_feature_from_samples`` with the same arguments and
    reshapes its flat vector; the original docstring notes the matrix is used
    for DTW re-ranking.
    """
    flat = extract_chroma_feature_from_samples(
        samples,
        sr,
        hop_length=hop_length,
        win_len_smooth=win_len_smooth,
    )
    return flat.reshape(12, TARGET_FRAMES)
82 92
83 # 5. 时间归一化到 12×128
84 if chroma.shape[1] != TARGET_FRAMES:
85 chroma = resample(chroma, TARGET_FRAMES, axis=1)
86 93
def extract_chroma_feature(audio_path: str, hop_length: int = 512, win_len_smooth: int = 41) -> np.ndarray:
    """Extract the flattened chroma feature vector from an audio file.

    The file is decoded with ``_load_audio_via_pipe`` and the samples are
    passed to ``extract_chroma_feature_from_samples`` at TARGET_SR.

    Args:
        audio_path: Path of the audio file.
        hop_length: Hop length forwarded to the sample-based extractor.
        win_len_smooth: CENS smoothing window length (in frames), forwarded
            to the sample-based extractor.

    Returns:
        The array returned by ``extract_chroma_feature_from_samples``
        (shape ``(1536,)`` according to the original docstring).

    Raises:
        FileNotFoundError: ``audio_path`` is not an existing file.
        RuntimeError: ffmpeg decoding failed.
    """
    if os.path.isfile(audio_path):
        decoded = _load_audio_via_pipe(audio_path)
        return extract_chroma_feature_from_samples(
            decoded,
            TARGET_SR,
            hop_length=hop_length,
            win_len_smooth=win_len_smooth,
        )
    raise FileNotFoundError(f"音频文件不存在: {audio_path}")
96 # 清理临时文件
97 if os.path.exists(tmp_wav):
98 os.remove(tmp_wav)
99 114
100 115
def extract_chroma_matrix(audio_path: str, hop_length: int = 512, win_len_smooth: int = 41) -> np.ndarray:
    """Extract the un-flattened 12 x TARGET_FRAMES chroma matrix from an audio file.

    Calls ``extract_chroma_feature`` with the same arguments and reshapes the
    flat vector; the original docstring notes the matrix is tonic-aligned and
    used for DTW re-ranking.

    Returns:
        An array of shape ``(12, TARGET_FRAMES)``.
    """
    flat = extract_chroma_feature(
        audio_path,
        hop_length=hop_length,
        win_len_smooth=win_len_smooth,
    )
    return flat.reshape(12, TARGET_FRAMES)
......
...@@ -12,6 +12,7 @@ tqdm>=4.66 ...@@ -12,6 +12,7 @@ tqdm>=4.66
12 12
13 # Audio composition feature extraction 13 # Audio composition feature extraction
14 librosa>=0.10.0 14 librosa>=0.10.0
15 numba>=0.59.0
15 scipy>=1.11 16 scipy>=1.11
16 numpy>=1.24 17 numpy>=1.24
17 18
...@@ -21,3 +22,6 @@ pgvector>=0.2.0 ...@@ -21,3 +22,6 @@ pgvector>=0.2.0
21 # HTTP API server 22 # HTTP API server
22 fastapi>=0.110.0 23 fastapi>=0.110.0
23 uvicorn[standard]>=0.29.0 24 uvicorn[standard]>=0.29.0
25
26 # Environment variable loading
27 python-dotenv>=1.0
......
...@@ -8,8 +8,8 @@ expected_song_id 的 top-k/top1 命中只作为诊断字段。 ...@@ -8,8 +8,8 @@ expected_song_id 的 top-k/top1 命中只作为诊断字段。
8 用法: 8 用法:
9 python scripts/evaluate_composition.py \ 9 python scripts/evaluate_composition.py \
10 --dsn "postgresql:///lyric_dedup" \ 10 --dsn "postgresql:///lyric_dedup" \
11 --queries composition_dedup/composition_testset4/queries.csv \ 11 --queries composition_testset/test_samples.csv \
12 --out composition_dedup/composition_eval/composition_eval_result_v3.csv 12 --out composition_dedup/composition_eval/nohop_result.csv
13 """ 13 """
14 14
15 import argparse 15 import argparse
...@@ -17,10 +17,14 @@ import csv ...@@ -17,10 +17,14 @@ import csv
17 import json 17 import json
18 import logging 18 import logging
19 import sys 19 import sys
20 import time
20 from pathlib import Path 21 from pathlib import Path
21 22
22 sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) 23 sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
23 24
25 from dotenv import load_dotenv
26 load_dotenv(Path(__file__).resolve().parent.parent / ".env")
27
24 from composition_dedup.service import CompositionConfig, CompositionDedupService 28 from composition_dedup.service import CompositionConfig, CompositionDedupService
25 29
26 logger = logging.getLogger(__name__) 30 logger = logging.getLogger(__name__)
...@@ -92,8 +96,12 @@ def main() -> None: ...@@ -92,8 +96,12 @@ def main() -> None:
92 invalid_negative_pair = (not expected_dup) and audio_song_id == expected_song_id 96 invalid_negative_pair = (not expected_dup) and audio_song_id == expected_song_id
93 97
94 try: 98 try:
95 candidates = service.query(audio_path, top_k=args.top_k) 99 timings: dict = {}
100 _t0 = time.perf_counter()
101 candidates = service.query(audio_path, top_k=args.top_k, timings=timings)
102 query_time_ms = round((time.perf_counter() - _t0) * 1000, 1)
96 except Exception as e: 103 except Exception as e:
104 query_time_ms = round((time.perf_counter() - _t0) * 1000, 1)
97 logger.error("[%d/%d] 查询失败: %s, %s", i, len(rows), audio_path, e) 105 logger.error("[%d/%d] 查询失败: %s, %s", i, len(rows), audio_path, e)
98 result_rows.append({ 106 result_rows.append({
99 "query_song_id": query_song_id, 107 "query_song_id": query_song_id,
...@@ -106,6 +114,7 @@ def main() -> None: ...@@ -106,6 +114,7 @@ def main() -> None:
106 "top1_song_id": "", 114 "top1_song_id": "",
107 "top1_similarity": "", 115 "top1_similarity": "",
108 "top1_source": "", 116 "top1_source": "",
117 "dejavu_aligned_count": "",
109 "top1_hit": False, 118 "top1_hit": False,
110 "topk_hit": False, 119 "topk_hit": False,
111 "expected_rank": "", 120 "expected_rank": "",
...@@ -115,6 +124,14 @@ def main() -> None: ...@@ -115,6 +124,14 @@ def main() -> None:
115 "expected_duplicate": expected_dup, 124 "expected_duplicate": expected_dup,
116 "predicted_duplicate": False, 125 "predicted_duplicate": False,
117 "correct": not expected_dup, # 查询失败视为 not_duplicate 126 "correct": not expected_dup, # 查询失败视为 not_duplicate
127 "query_time_ms": query_time_ms,
128 "chroma_extract_ms": timings.get("chroma_extract_ms", ""),
129 "db_cosine_ms": timings.get("db_cosine_ms", ""),
130 "db_fetch_ms": timings.get("db_fetch_ms", ""),
131 "dtw_ms": timings.get("dtw_ms", ""),
132 "dejavu_decode_ms": timings.get("dejavu_decode_ms", ""),
133 "dejavu_fingerprint_ms": timings.get("dejavu_fingerprint_ms", ""),
134 "dejavu_db_ms": timings.get("dejavu_db_ms", ""),
118 "error": str(e), 135 "error": str(e),
119 }) 136 })
120 continue 137 continue
...@@ -123,6 +140,7 @@ def main() -> None: ...@@ -123,6 +140,7 @@ def main() -> None:
123 top1_song_id = str(top1.song_id) if top1 else "" 140 top1_song_id = str(top1.song_id) if top1 else ""
124 top1_sim = round(top1.similarity, 4) if top1 else "" 141 top1_sim = round(top1.similarity, 4) if top1 else ""
125 top1_source = top1.source if top1 else "" 142 top1_source = top1.source if top1 else ""
143 dejavu_aligned_count = top1.dejavu_aligned_count if top1 else ""
126 144
127 # 诊断召回:expected_song_id 是否进入 top1/top-k。 145 # 诊断召回:expected_song_id 是否进入 top1/top-k。
128 top1_hit = bool(expected_song_id) and top1_song_id == expected_song_id 146 top1_hit = bool(expected_song_id) and top1_song_id == expected_song_id
...@@ -157,6 +175,7 @@ def main() -> None: ...@@ -157,6 +175,7 @@ def main() -> None:
157 "top1_song_id": top1_song_id, 175 "top1_song_id": top1_song_id,
158 "top1_similarity": top1_sim, 176 "top1_similarity": top1_sim,
159 "top1_source": top1_source, 177 "top1_source": top1_source,
178 "dejavu_aligned_count": dejavu_aligned_count if dejavu_aligned_count is not None else "",
160 "top1_hit": top1_hit, 179 "top1_hit": top1_hit,
161 "topk_hit": topk_hit, 180 "topk_hit": topk_hit,
162 "expected_rank": expected_rank, 181 "expected_rank": expected_rank,
...@@ -166,11 +185,19 @@ def main() -> None: ...@@ -166,11 +185,19 @@ def main() -> None:
166 "expected_duplicate": expected_dup, 185 "expected_duplicate": expected_dup,
167 "predicted_duplicate": predicted_dup, 186 "predicted_duplicate": predicted_dup,
168 "correct": correct, 187 "correct": correct,
188 "query_time_ms": query_time_ms,
189 "chroma_extract_ms": timings.get("chroma_extract_ms", ""),
190 "db_cosine_ms": timings.get("db_cosine_ms", ""),
191 "db_fetch_ms": timings.get("db_fetch_ms", ""),
192 "dtw_ms": timings.get("dtw_ms", ""),
193 "dejavu_decode_ms": timings.get("dejavu_decode_ms", ""),
194 "dejavu_fingerprint_ms": timings.get("dejavu_fingerprint_ms", ""),
195 "dejavu_db_ms": timings.get("dejavu_db_ms", ""),
169 "error": "", 196 "error": "",
170 }) 197 })
171 198
172 logger.info( 199 logger.info(
173 "[%d/%d] variant=%s source=%s expected=%s predicted_duplicate=%s threshold=%.4f expected_song_id=%s top1=%s sim=%s top1_hit=%s topk_hit=%s expected_rank=%s expected_sim=%s correct=%s", 200 "[%d/%d] variant=%s source=%s expected=%s predicted_duplicate=%s threshold=%.4f expected_song_id=%s top1=%s sim=%s top1_hit=%s topk_hit=%s expected_rank=%s expected_sim=%s correct=%s time_ms=%s",
174 i, 201 i,
175 len(rows), 202 len(rows),
176 row.get("variant", ""), 203 row.get("variant", ""),
...@@ -186,6 +213,7 @@ def main() -> None: ...@@ -186,6 +213,7 @@ def main() -> None:
186 expected_rank if expected_rank != "" else "-", 213 expected_rank if expected_rank != "" else "-",
187 expected_similarity if expected_similarity != "" else "-", 214 expected_similarity if expected_similarity != "" else "-",
188 correct, 215 correct,
216 query_time_ms,
189 ) 217 )
190 218
191 if i % 10 == 0 or i == len(rows): 219 if i % 10 == 0 or i == len(rows):
...@@ -194,9 +222,14 @@ def main() -> None: ...@@ -194,9 +222,14 @@ def main() -> None:
194 # 写逐条结果 222 # 写逐条结果
195 fieldnames = ["query_song_id", "audio_song_id", "audio_path", "variant", "sample_class", 223 fieldnames = ["query_song_id", "audio_song_id", "audio_path", "variant", "sample_class",
196 "expected_song_id", "expected", "top1_song_id", "top1_similarity", "top1_source", 224 "expected_song_id", "expected", "top1_song_id", "top1_similarity", "top1_source",
225 "dejavu_aligned_count",
197 "top1_hit", "topk_hit", "expected_rank", "expected_similarity", 226 "top1_hit", "topk_hit", "expected_rank", "expected_similarity",
198 "invalid_negative_pair", "invalid_boolean_sample", 227 "invalid_negative_pair", "invalid_boolean_sample",
199 "expected_duplicate", "predicted_duplicate", "correct", "error"] 228 "expected_duplicate", "predicted_duplicate", "correct",
229 "query_time_ms",
230 "chroma_extract_ms", "db_cosine_ms", "db_fetch_ms", "dtw_ms",
231 "dejavu_decode_ms", "dejavu_fingerprint_ms", "dejavu_db_ms",
232 "error"]
200 with out_path.open("w", newline="", encoding="utf-8") as f: 233 with out_path.open("w", newline="", encoding="utf-8") as f:
201 writer = csv.DictWriter(f, fieldnames=fieldnames) 234 writer = csv.DictWriter(f, fieldnames=fieldnames)
202 writer.writeheader() 235 writer.writeheader()
......
...@@ -6,11 +6,11 @@ ...@@ -6,11 +6,11 @@
6 6
7 用法: 7 用法:
8 python scripts/generate_composition_testset.py \ 8 python scripts/generate_composition_testset.py \
9 --audio-dir /Volumes/移动硬盘/lyric_audio_type11 \ 9 --audio-dir /Volumes/移动硬盘/composition_test \
10 --negative-audio-dir /Volumes/移动硬盘/composition_test \ 10 --negative-audio-dir /Volumes/移动硬盘/composition_drop \
11 --out-dir composition_dedup/composition_testset \ 11 --out-dir composition_testset \
12 --num-songs 80 \ 12 --num-songs 100 \
13 --num-negative-songs 40 \ 13 --num-negative-songs 100 \
14 --negative-variants \ 14 --negative-variants \
15 --seed 123 15 --seed 123
16 16
......
...@@ -16,6 +16,9 @@ from pathlib import Path ...@@ -16,6 +16,9 @@ from pathlib import Path
16 16
17 sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) 17 sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
18 18
19 from dotenv import load_dotenv
20 load_dotenv(Path(__file__).resolve().parent.parent / ".env")
21
19 from tqdm import tqdm 22 from tqdm import tqdm
20 23
21 from composition_dedup.service import CompositionConfig, CompositionDedupService 24 from composition_dedup.service import CompositionConfig, CompositionDedupService
......