Commit 974df4ae 974df4ae895606924f14bf5e679276b5dd51920e by 沈秋雨

优化性能

1 parent 7a11a3d4
...@@ -27,3 +27,5 @@ venv/ ...@@ -27,3 +27,5 @@ venv/
27 test_api 27 test_api
28 28
29 composition_dedup/composition_eval 29 composition_dedup/composition_eval
30
31 composition_testset
...\ No newline at end of file ...\ No newline at end of file
......
...@@ -22,7 +22,6 @@ from pathlib import Path ...@@ -22,7 +22,6 @@ from pathlib import Path
22 import librosa 22 import librosa
23 import numpy as np 23 import numpy as np
24 from scipy.ndimage import ( 24 from scipy.ndimage import (
25 binary_erosion,
26 generate_binary_structure, 25 generate_binary_structure,
27 iterate_structure, 26 iterate_structure,
28 maximum_filter, 27 maximum_filter,
...@@ -60,10 +59,10 @@ MIN_HASH_TIME_DELTA = 0 ...@@ -60,10 +59,10 @@ MIN_HASH_TIME_DELTA = 0
60 MAX_HASH_TIME_DELTA = 200 59 MAX_HASH_TIME_DELTA = 200
61 PEAK_SORT = True 60 PEAK_SORT = True
62 FINGERPRINT_REDUCTION = 20 61 FINGERPRINT_REDUCTION = 20
63 MAX_DURATION_SEC = float(os.environ.get("COMPOSITION_DEJAVU_MAX_DURATION", "120")) # 0=不限制 62 QUERY_MAX_DURATION_SEC = float(os.environ.get("COMPOSITION_DEJAVU_QUERY_MAX_DURATION", "120")) # 0=不限制
64 63
65 64
66 def _normalize_audio(audio_path: str, max_duration: float = MAX_DURATION_SEC) -> tuple[np.ndarray, int]: 65 def _normalize_audio(audio_path: str, max_duration: float = 0) -> tuple[np.ndarray, int]:
67 """将音频标准化为单声道 WAV 并加载为 numpy 数组。 66 """将音频标准化为单声道 WAV 并加载为 numpy 数组。
68 67
69 使用 ffmpeg 先做重采样,再用 librosa 读取。 68 使用 ffmpeg 先做重采样,再用 librosa 读取。
...@@ -133,12 +132,7 @@ def _get_2d_peaks(arr2D: np.ndarray, amp_min: float = DEFAULT_AMP_MIN): ...@@ -133,12 +132,7 @@ def _get_2d_peaks(arr2D: np.ndarray, amp_min: float = DEFAULT_AMP_MIN):
133 neighborhood = iterate_structure(struct, PEAK_NEIGHBORHOOD_SIZE) 132 neighborhood = iterate_structure(struct, PEAK_NEIGHBORHOOD_SIZE)
134 133
135 # 找局部极大值 134 # 找局部极大值
136 local_max = maximum_filter(arr2D, footprint=neighborhood) == arr2D 135 detected_peaks = maximum_filter(arr2D, footprint=neighborhood) == arr2D
137 background = arr2D == 0
138 eroded_background = binary_erosion(background, structure=neighborhood, border_value=1)
139
140 # 布尔掩码
141 detected_peaks = local_max ^ eroded_background
142 136
143 # 提取峰值 137 # 提取峰值
144 amps = arr2D[detected_peaks] 138 amps = arr2D[detected_peaks]
...@@ -181,6 +175,43 @@ def _generate_hashes(peaks: list[tuple[int, int]], fan_value: int = DEFAULT_FAN_ ...@@ -181,6 +175,43 @@ def _generate_hashes(peaks: list[tuple[int, int]], fan_value: int = DEFAULT_FAN_
181 yield (h.hexdigest()[:FINGERPRINT_REDUCTION].encode(), t1) 175 yield (h.hexdigest()[:FINGERPRINT_REDUCTION].encode(), t1)
182 176
183 177
def load_audio(audio_path: str, max_duration: float = 0) -> tuple[np.ndarray, int]:
    """Decode an audio file once so several pipelines can share the samples.

    This is the public entry point over ``_normalize_audio``; it exists so that
    callers can decode a file a single time and hand the samples to more than
    one feature extractor instead of decoding repeatedly.

    Args:
        audio_path: Path of the audio file to decode.
        max_duration: Maximum number of seconds to keep; 0 means no limit.

    Returns:
        The ``(samples, sr)`` tuple produced by ``_normalize_audio``
        (documented by the original author as 44100 Hz mono).
    """
    decoded = _normalize_audio(audio_path, max_duration)
    return decoded
189
190
def fingerprint_from_samples(
    samples: np.ndarray, sr: int, *, compute_sha1: bool = True
) -> tuple[str, list[tuple[bytes, int]]]:
    """Build Dejavu-style fingerprints from already-decoded samples (no I/O).

    Args:
        samples: Mono audio samples (the original author expects DEFAULT_FS,
            i.e. 44100 Hz).
        sr: Sampling rate of ``samples``.
        compute_sha1: Whether to hash the raw sample bytes. Callers that ignore
            the returned id can pass False to skip hashing the whole buffer.

    Returns:
        ``(file_sha1, fingerprints)`` where ``fingerprints`` is a list of
        ``(hash_bytes, offset)`` pairs. ``file_sha1`` is the first 16 hex
        characters of the SHA-1 of the sample bytes, or an empty string when
        ``compute_sha1`` is False.
    """
    if compute_sha1:
        file_sha1 = hashlib.sha1(samples.tobytes()).hexdigest()[:16]
    else:
        file_sha1 = ""

    # Spectrogram -> 2D peak picking -> pairwise peak hashing.
    spectrogram = _specgram(samples, sr, DEFAULT_WINDOW_SIZE, DEFAULT_OVERLAP_RATIO)
    freq_idx, time_idx = _get_2d_peaks(spectrogram)
    peak_pairs = list(zip(freq_idx, time_idx))
    return file_sha1, list(_generate_hashes(peak_pairs))
213
214
184 def fingerprint_audio(audio_path: str) -> tuple[str, list[tuple[bytes, int]]]: 215 def fingerprint_audio(audio_path: str) -> tuple[str, list[tuple[bytes, int]]]:
185 """对音频文件生成 Dejavu 风格指纹。 216 """对音频文件生成 Dejavu 风格指纹。
186 217
...@@ -198,21 +229,7 @@ def fingerprint_audio(audio_path: str) -> tuple[str, list[tuple[bytes, int]]]: ...@@ -198,21 +229,7 @@ def fingerprint_audio(audio_path: str) -> tuple[str, list[tuple[bytes, int]]]:
198 if not os.path.isfile(audio_path): 229 if not os.path.isfile(audio_path):
199 raise FileNotFoundError(f"音频文件不存在: {audio_path}") 230 raise FileNotFoundError(f"音频文件不存在: {audio_path}")
200 231
201 # 1. 标准化并加载音频(可选限制长度)
202 samples, fs = _normalize_audio(audio_path) 232 samples, fs = _normalize_audio(audio_path)
203 233 file_sha1, fingerprints = fingerprint_from_samples(samples, fs)
204 # 2. 计算文件 SHA1(用于标识)
205 file_sha1 = hashlib.sha1(samples.tobytes()).hexdigest()[:16]
206
207 # 3. 计算频谱图
208 arr2D = _specgram(samples, fs, DEFAULT_WINDOW_SIZE, DEFAULT_OVERLAP_RATIO)
209
210 # 4. 检测 2D 峰值
211 freq_idx, time_idx = _get_2d_peaks(arr2D)
212 peaks = list(zip(freq_idx, time_idx))
213
214 # 5. 生成指纹哈希
215 fingerprints = list(_generate_hashes(peaks))
216
217 logger.info("指纹生成完成: audio=%s, 指纹数=%d", audio_path, len(fingerprints)) 234 logger.info("指纹生成完成: audio=%s, 指纹数=%d", audio_path, len(fingerprints))
218 return file_sha1, fingerprints 235 return file_sha1, fingerprints
......
1 """Chromagram 特征提取。 1 """Chromagram 特征提取。
2 2
3 流程: 3 流程:
4 1. 音频标准化:ffmpeg 转 22050Hz / Mono / WAV 4 1. 音频解码:ffmpeg pipe 输出 22050Hz / Mono / f32le,直接读入内存,无临时文件
5 2. librosa 加载音频 5 2. librosa.feature.chroma_cens() 提取 12×T Chromagram(CENS,对速度/音色鲁棒)
6 3. librosa.feature.chroma_cens() 提取 12×T Chromagram(CENS,对速度/音色鲁棒) 6 3. 主音对齐:将能量最大的音级滚至第 0 行,实现转调不变性
7 4. 主音对齐:将能量最大的音级滚至第 0 行,实现转调不变性 7 4. scipy.signal.resample(chroma, 128, axis=1) 时间归一化到 12×128
8 5. scipy.signal.resample(chroma, 128, axis=1) 时间归一化到 12×128 8 5. .flatten() 展开为 1536 维向量
9 6. .flatten() 展开为 1536 维向量
10 """ 9 """
11 10
12 import logging 11 import logging
13 import os 12 import os
14 import subprocess 13 import subprocess
15 import tempfile
16 14
17 import librosa 15 import librosa
18 import numpy as np 16 import numpy as np
...@@ -26,83 +24,99 @@ TARGET_FRAMES = 128 ...@@ -26,83 +24,99 @@ TARGET_FRAMES = 128
26 VECTOR_DIM = 12 * TARGET_FRAMES # 1536 24 VECTOR_DIM = 12 * TARGET_FRAMES # 1536
27 25
28 26
def _load_audio_via_pipe(audio_path: str) -> np.ndarray:
    """Decode an audio file to mono float32 PCM at TARGET_SR through an ffmpeg pipe.

    The decoded stream is read from ffmpeg's stdout straight into memory, so no
    temporary file is written to disk.

    Args:
        audio_path: Path of the audio file to decode.

    Returns:
        1-D float32 array of samples. The array is a read-only view over the
        bytes captured from the subprocess; copy it before mutating in place.

    Raises:
        RuntimeError: ffmpeg exited with a non-zero status, or it succeeded but
            produced no audio samples.
    """
    cmd = [
        "ffmpeg",
        # Never let ffmpeg read (or block on) the parent's stdin.
        "-nostdin",
        # Keep stderr limited to real errors so the failure message is readable.
        "-loglevel", "error",
        "-i", audio_path,
        "-ar", str(TARGET_SR),
        "-ac", "1",
        "-f", "f32le",
        "pipe:1",
    ]
    result = subprocess.run(cmd, stdin=subprocess.DEVNULL, capture_output=True)
    if result.returncode != 0:
        raise RuntimeError(f"ffmpeg 解码失败: {result.stderr.decode(errors='replace')}")

    raw = result.stdout
    # The stream is explicitly little-endian float32; parse it as such rather
    # than relying on the host byte order, and ignore any trailing partial sample.
    sample_dtype = np.dtype("<f4")
    sample_count = len(raw) // sample_dtype.itemsize
    if sample_count == 0:
        # Fail here with a clear message instead of letting an empty array
        # reach the feature extractor.
        raise RuntimeError(f"ffmpeg 解码失败: 未输出任何音频样本: {audio_path}")
    return np.frombuffer(raw, dtype=sample_dtype, count=sample_count)
47 41
48 42
def extract_chroma_feature_from_samples(
    samples: np.ndarray,
    sr: int,
    hop_length: int = 512,
    win_len_smooth: int = 41,
) -> np.ndarray:
    """Extract the 1536-dimensional chromagram vector from decoded samples.

    When ``sr`` differs from TARGET_SR the samples are resampled in memory with
    ``librosa.resample`` instead of going through ffmpeg again.

    Args:
        samples: Mono audio samples at any sampling rate.
        sr: Sampling rate of ``samples``.
        hop_length: Hop size passed to ``librosa.feature.chroma_cens``.
        win_len_smooth: CENS smoothing window length, in frames.

    Returns:
        float32 array of shape ``(VECTOR_DIM,)`` (12 pitch classes x TARGET_FRAMES).
    """
    if sr == TARGET_SR:
        y = samples
    else:
        y = librosa.resample(samples, orig_sr=sr, target_sr=TARGET_SR)

    # CENS chromagram, shape 12 x T.
    chroma = librosa.feature.chroma_cens(
        y=y,
        sr=TARGET_SR,
        hop_length=hop_length,
        win_len_smooth=win_len_smooth,
    )

    # Tonic alignment: rotate the strongest pitch class to row 0.
    pitch_class_energy = chroma.sum(axis=1)
    tonic = int(np.argmax(pitch_class_energy))
    if tonic != 0:
        chroma = np.roll(chroma, -tonic, axis=0)

    # Normalize the time axis to exactly TARGET_FRAMES columns.
    frame_count = chroma.shape[1]
    if frame_count != TARGET_FRAMES:
        chroma = resample(chroma, TARGET_FRAMES, axis=1)

    feature = chroma.flatten().astype(np.float32)
    assert feature.shape == (VECTOR_DIM,), (
        f"特征维度错误: 期望 {VECTOR_DIM}, 实际 {feature.shape}"
    )
    return feature
99 82
100 83
def extract_chroma_matrix_from_samples(
    samples: np.ndarray,
    sr: int,
    hop_length: int = 512,
    win_len_smooth: int = 41,
) -> np.ndarray:
    """Extract the 12 x TARGET_FRAMES chromagram matrix from decoded samples.

    Same computation as ``extract_chroma_feature_from_samples``; the flat vector
    is reshaped back to a matrix for DTW re-ranking.
    """
    flat = extract_chroma_feature_from_samples(
        samples,
        sr,
        hop_length=hop_length,
        win_len_smooth=win_len_smooth,
    )
    return flat.reshape(12, TARGET_FRAMES)
92
93
def extract_chroma_feature(audio_path: str, hop_length: int = 512, win_len_smooth: int = 41) -> np.ndarray:
    """Extract the 1536-dimensional chromagram vector from an audio file.

    Args:
        audio_path: Path of the audio file.
        hop_length: Hop size forwarded to the chromagram extraction.
        win_len_smooth: CENS smoothing window length, in frames.

    Returns:
        numpy array of shape ``(1536,)``.

    Raises:
        FileNotFoundError: ``audio_path`` is not an existing file.
        RuntimeError: ffmpeg failed to decode the file.
    """
    file_exists = os.path.isfile(audio_path)
    if not file_exists:
        raise FileNotFoundError(f"音频文件不存在: {audio_path}")

    # The pipe decoder already outputs TARGET_SR, so no resampling happens downstream.
    decoded = _load_audio_via_pipe(audio_path)
    return extract_chroma_feature_from_samples(
        decoded,
        TARGET_SR,
        hop_length=hop_length,
        win_len_smooth=win_len_smooth,
    )
114
115
def extract_chroma_matrix(audio_path: str, hop_length: int = 512, win_len_smooth: int = 41) -> np.ndarray:
    """Extract the unflattened 12 x 128 chromagram matrix from an audio file.

    Used for DTW re-ranking.

    Returns:
        numpy array of shape ``(12, 128)``, already tonic-aligned.
    """
    flat = extract_chroma_feature(
        audio_path,
        hop_length=hop_length,
        win_len_smooth=win_len_smooth,
    )
    return flat.reshape(12, TARGET_FRAMES)
......
1 """作曲去重服务(入库 + 查询)。 1 """作曲去重服务(入库 + 查询)。
2 2
3 查询流程: 3 查询流程:
4 1. Dejavu 指纹匹配(毫秒级,子序列匹配,支持 chorus_only / trim_intro 4 1. Chromagram 12路 + DTW 精排(覆盖绝大多数场景,约 1s
5 - 命中(≥ 阈值)→ 直接返回 duplicate(短路) 5 - 命中(≥ 阈值)→ 直接返回结果
6 2. 未命中 → Chromagram 12路 + DTW(百毫秒级 6 2. Chromagram 未命中 → Dejavu 指纹兜底(处理 chorus_only / trim_intro 等片段场景
7 - 返回结果 7 - 命中(≥ 阈值)→ 返回 duplicate(similarity=1.0)
8 """ 8 """
9 9
10 import logging 10 import logging
11 import os 11 import os
12 import time
12 from dataclasses import dataclass, field 13 from dataclasses import dataclass, field
13 14
15 import numba
14 import numpy as np 16 import numpy as np
15 import psycopg 17 import psycopg
16 from scipy.spatial.distance import cdist 18 from scipy.spatial.distance import cdist
17 19
18 from .extractor import TARGET_FRAMES, extract_chroma_feature, extract_chroma_matrix 20 from .extractor import (
19 from .dejavu_fingerprinter import fingerprint_audio 21 TARGET_FRAMES,
22 extract_chroma_feature_from_samples,
23 extract_chroma_matrix,
24 )
25 from .dejavu_fingerprinter import fingerprint_audio, fingerprint_from_samples, load_audio, QUERY_MAX_DURATION_SEC
20 26
21 logger = logging.getLogger(__name__) 27 logger = logging.getLogger(__name__)
22 28
...@@ -56,6 +62,7 @@ class CompositionCandidate: ...@@ -56,6 +62,7 @@ class CompositionCandidate:
56 song_id: int 62 song_id: int
57 similarity: float 63 similarity: float
58 source: str = "chromagram" 64 source: str = "chromagram"
65 dejavu_aligned_count: int | None = None # 仅 source=dejavu 或 dejavu fallback 未命中时记录
59 66
60 67
61 @dataclass 68 @dataclass
...@@ -73,10 +80,18 @@ class CompositionConfig: ...@@ -73,10 +80,18 @@ class CompositionConfig:
73 statement_timeout_ms: int = 30000 80 statement_timeout_ms: int = 30000
74 dtw_rerank_top_k: int = 20 # Cosine 召回后做 DTW 精排的候选数量 81 dtw_rerank_top_k: int = 20 # Cosine 召回后做 DTW 精排的候选数量
75 duplicate_threshold: float = _env_float("COMPOSITION_DUPLICATE_THRESHOLD", 0.85) 82 duplicate_threshold: float = _env_float("COMPOSITION_DUPLICATE_THRESHOLD", 0.85)
83 # Chromagram 提取配置
84 chroma_hop_length: int = _env_int("COMPOSITION_CHROMA_HOP_LENGTH", 512)
85 chroma_win_len_smooth: int = _env_int("COMPOSITION_CHROMA_WIN_LEN_SMOOTH", 0)
76 # Dejavu 指纹匹配配置 86 # Dejavu 指纹匹配配置
77 dejavu_enabled: bool = _env_bool("COMPOSITION_DEJAVU_ENABLED", True) 87 dejavu_enabled: bool = _env_bool("COMPOSITION_DEJAVU_ENABLED", True)
78 dejavu_match_threshold: int = _env_int("COMPOSITION_DEJAVU_MATCH_THRESHOLD", 20) 88 dejavu_match_threshold: int = _env_int("COMPOSITION_DEJAVU_MATCH_THRESHOLD", 20)
79 89
90 def __post_init__(self) -> None:
91 # 0 表示自动:按 hop_length 等比缩小,保持平滑窗覆盖时长约 1 秒
92 if self.chroma_win_len_smooth == 0:
93 self.chroma_win_len_smooth = max(1, round(41 * 512 / self.chroma_hop_length))
94
80 95
81 @dataclass 96 @dataclass
82 class CompositionDedupService: 97 class CompositionDedupService:
...@@ -94,8 +109,12 @@ class CompositionDedupService: ...@@ -94,8 +109,12 @@ class CompositionDedupService:
94 Returns: 109 Returns:
95 提取的特征向量。 110 提取的特征向量。
96 """ 111 """
97 feature = extract_chroma_feature(audio_path) 112 # 共用一次解码(44100Hz),chromagram 路径在内存中重采样,无需二次 ffmpeg
98 self._logger.info("提取 Chromagram 特征完成: song_id=%s, audio=%s", song_id, audio_path) 113 samples, sr = load_audio(audio_path)
114 self._logger.info("音频解码完成: song_id=%s, audio=%s", song_id, audio_path)
115
116 feature = extract_chroma_feature_from_samples(samples, sr, hop_length=self.config.chroma_hop_length, win_len_smooth=self.config.chroma_win_len_smooth)
117 self._logger.info("提取 Chromagram 特征完成: song_id=%s", song_id)
99 118
100 with psycopg.connect(self.config.dsn) as conn: 119 with psycopg.connect(self.config.dsn) as conn:
101 with conn.cursor() as cursor: 120 with conn.cursor() as cursor:
...@@ -111,15 +130,29 @@ class CompositionDedupService: ...@@ -111,15 +130,29 @@ class CompositionDedupService:
111 130
112 self._logger.info("Chromagram 特征入库完成: song_id=%s", song_id) 131 self._logger.info("Chromagram 特征入库完成: song_id=%s", song_id)
113 132
114 # Dejavu 指纹同时入库
115 if self.config.dejavu_enabled: 133 if self.config.dejavu_enabled:
116 self._dejavu_ingest(song_id, audio_path) 134 self._dejavu_ingest(song_id, audio_path, samples=samples, sr=sr)
117 135
118 return feature 136 return feature
119 137
120 def _dejavu_ingest(self, song_id: int, audio_path: str) -> None: 138 def _dejavu_ingest(
121 """提取 Dejavu 指纹并写入 dejavu_fingerprints 表。""" 139 self,
122 file_sha1, fingerprints = fingerprint_audio(audio_path) 140 song_id: int,
141 audio_path: str,
142 *,
143 samples: np.ndarray | None = None,
144 sr: int | None = None,
145 ) -> None:
146 """提取 Dejavu 指纹并写入 dejavu_fingerprints 表。
147
148 若提供了已解码的 samples/sr,直接使用,跳过 ffmpeg;否则从文件重新加载。
149 """
150 if samples is not None and sr is not None:
151 _, fingerprints = fingerprint_from_samples(samples, sr, compute_sha1=False)
152 self._logger.info("Dejavu 指纹提取完成(共用解码): song_id=%s", song_id)
153 else:
154 _, fingerprints = fingerprint_audio(audio_path)
155
123 if not fingerprints: 156 if not fingerprints:
124 self._logger.warning("Dejavu 指纹为空: song_id=%s, audio=%s", song_id, audio_path) 157 self._logger.warning("Dejavu 指纹为空: song_id=%s, audio=%s", song_id, audio_path)
125 return 158 return
...@@ -144,25 +177,54 @@ class CompositionDedupService: ...@@ -144,25 +177,54 @@ class CompositionDedupService:
144 177
145 self._logger.info("Dejavu 指纹入库完成: song_id=%s, 指纹数=%d", song_id, len(fingerprints)) 178 self._logger.info("Dejavu 指纹入库完成: song_id=%s, 指纹数=%d", song_id, len(fingerprints))
146 179
def query(
    self,
    audio_path: str,
    top_k: int = 100,
    timings: dict | None = None,
) -> list[CompositionCandidate]:
    """Extract features from an audio file and return similar compositions.

    Flow: chromagram 12-shift recall + DTW re-ranking first; the Dejavu
    fingerprint lookup runs only as a fallback when the chromagram result does
    not already indicate a duplicate (intended for partial clips such as
    chorus_only / trim_intro).

    Args:
        audio_path: Path of the query audio file.
        top_k: Number of candidates requested from the chromagram recall.
        timings: Optional dict that receives per-stage durations in
            milliseconds (chroma_extract_ms, db_cosine_ms, db_fetch_ms, dtw_ms,
            dejavu_decode_ms, dejavu_fingerprint_ms, dejavu_db_ms). Keys of the
            Dejavu stage are only written when that stage actually runs.

    Returns:
        The chromagram candidates, or a single Dejavu candidate with
        similarity 1.0 when the fallback reaches the match threshold.
    """
    # Stage 1: chromagram recall + DTW re-ranking.
    candidates = self._query_chroma(audio_path, top_k, timings=timings)

    # Stage 2 is skipped when Dejavu is disabled or stage 1 already found a duplicate.
    if not self.config.dejavu_enabled or self.candidates_indicate_duplicate(candidates):
        return candidates

    # The fingerprint decode only happens on this fallback path.
    decode_started = time.perf_counter()
    samples, sr = load_audio(audio_path, max_duration=QUERY_MAX_DURATION_SEC)
    if timings is not None:
        timings["dejavu_decode_ms"] = round((time.perf_counter() - decode_started) * 1000, 1)

    match = self._dejavu_query(samples, sr, timings=timings)
    if match is None:
        return candidates

    if match.aligned_count < self.config.dejavu_match_threshold:
        # Below threshold: record the aligned count on the chromagram top-1
        # so the evaluation script can report it.
        if candidates:
            candidates[0].dejavu_aligned_count = match.aligned_count
        return candidates

    self._logger.info(
        "Dejavu 命中: song_id=%s, aligned_count=%d, total_collisions=%d, decision=duplicate",
        match.song_id,
        match.aligned_count,
        match.total_collisions,
    )
    dejavu_hit = CompositionCandidate(
        song_id=match.song_id,
        similarity=1.0,
        source="dejavu",
        dejavu_aligned_count=match.aligned_count,
    )
    return [dejavu_hit]
166 228
167 def check(self, audio_path: str, top_k: int = 100) -> bool: 229 def check(self, audio_path: str, top_k: int = 100) -> bool:
168 """按最终接口语义返回是否重复。""" 230 """按最终接口语义返回是否重复。"""
...@@ -178,9 +240,17 @@ class CompositionDedupService: ...@@ -178,9 +240,17 @@ class CompositionDedupService:
178 return False 240 return False
179 return candidates[0].similarity >= self.config.duplicate_threshold 241 return candidates[0].similarity >= self.config.duplicate_threshold
180 242
181 def _query_chroma(self, audio_path: str, top_k: int = 100) -> list[CompositionCandidate]: 243 def _query_chroma(
244 self,
245 audio_path: str,
246 top_k: int = 100,
247 timings: dict | None = None,
248 ) -> list[CompositionCandidate]:
182 """Chromagram 12 路循环对齐 + DTW 精排查询。""" 249 """Chromagram 12 路循环对齐 + DTW 精排查询。"""
183 chroma = extract_chroma_matrix(audio_path) 250 _t = time.perf_counter()
251 chroma = extract_chroma_matrix(audio_path, hop_length=self.config.chroma_hop_length, win_len_smooth=self.config.chroma_win_len_smooth)
252 if timings is not None:
253 timings["chroma_extract_ms"] = round((time.perf_counter() - _t) * 1000, 1)
184 self._logger.info("提取 Chromagram 查询特征完成: audio=%s", audio_path) 254 self._logger.info("提取 Chromagram 查询特征完成: audio=%s", audio_path)
185 255
186 # 1. 12 路循环对齐:穷举 12 种半音偏移,单条 SQL 内部展开,按 song_id 取最高 Cosine 相似度 256 # 1. 12 路循环对齐:穷举 12 种半音偏移,单条 SQL 内部展开,按 song_id 取最高 Cosine 相似度
...@@ -213,6 +283,7 @@ class CompositionDedupService: ...@@ -213,6 +283,7 @@ class CompositionDedupService:
213 LIMIT %s 283 LIMIT %s
214 """ 284 """
215 best: dict[int, float] = {} 285 best: dict[int, float] = {}
286 _t = time.perf_counter()
216 with psycopg.connect(self.config.dsn) as conn: 287 with psycopg.connect(self.config.dsn) as conn:
217 with conn.cursor() as cursor: 288 with conn.cursor() as cursor:
218 cursor.execute( 289 cursor.execute(
...@@ -221,33 +292,44 @@ class CompositionDedupService: ...@@ -221,33 +292,44 @@ class CompositionDedupService:
221 cursor.execute(sql, (*shift_vecs, top_k, top_k)) 292 cursor.execute(sql, (*shift_vecs, top_k, top_k))
222 for song_id, sim in cursor.fetchall(): 293 for song_id, sim in cursor.fetchall():
223 best[int(song_id)] = float(sim) 294 best[int(song_id)] = float(sim)
295 if timings is not None:
296 timings["db_cosine_ms"] = round((time.perf_counter() - _t) * 1000, 1)
224 297
225 # 2. 取 Top dtw_rerank_top_k,从库中取原始向量做 DTW 精排 298 # 2. 取 Top dtw_rerank_top_k,从库中取原始向量做 DTW 精排
226 top = sorted(best.items(), key=lambda x: x[1], reverse=True) 299 top = sorted(best.items(), key=lambda x: x[1], reverse=True)
227 rerank_ids = [sid for sid, _ in top[:self.config.dtw_rerank_top_k]] 300 rerank_ids = [sid for sid, _ in top[:self.config.dtw_rerank_top_k]]
228 301
302 _t = time.perf_counter()
229 with conn.cursor() as cursor: 303 with conn.cursor() as cursor:
230 cursor.execute( 304 cursor.execute(
231 "SELECT song_id, feature_vector::float4[] FROM composition_feature WHERE song_id = ANY(%s)", 305 "SELECT song_id, feature_vector::float4[] FROM composition_feature WHERE song_id = ANY(%s)",
232 (rerank_ids,), 306 (rerank_ids,),
233 ) 307 )
234 db_rows = cursor.fetchall() 308 db_rows = cursor.fetchall()
309 if timings is not None:
310 timings["db_fetch_ms"] = round((time.perf_counter() - _t) * 1000, 1)
235 311
312 _t = time.perf_counter()
236 reranked = [] 313 reranked = []
237 for song_id, fv in db_rows: 314 for song_id, fv in db_rows:
238 cand_chroma = np.array(fv, dtype=np.float32).reshape(12, TARGET_FRAMES) 315 cand_chroma = np.array(fv, dtype=np.float32).reshape(12, TARGET_FRAMES)
239 dtw_sim = _best_shifted_dtw_similarity(chroma, cand_chroma) 316 dtw_sim = _best_shifted_dtw_similarity(
317 chroma, cand_chroma, early_exit_threshold=self.config.duplicate_threshold
318 )
240 reranked.append(CompositionCandidate(song_id=int(song_id), similarity=dtw_sim)) 319 reranked.append(CompositionCandidate(song_id=int(song_id), similarity=dtw_sim))
241 reranked.sort(key=lambda c: c.similarity, reverse=True) 320 reranked.sort(key=lambda c: c.similarity, reverse=True)
321 if timings is not None:
322 timings["dtw_ms"] = round((time.perf_counter() - _t) * 1000, 1)
242 323
243 rerank_id_set = {c.song_id for c in reranked} 324 rerank_id_set = {c.song_id for c in reranked}
244 rest = [ 325 # top 已按 cosine 降序排列;直接从中剔除已精排的候选,剩余保留 cosine 分
326 non_reranked = [
245 CompositionCandidate(song_id=sid, similarity=sim) 327 CompositionCandidate(song_id=sid, similarity=sim)
246 for sid, sim in top[self.config.dtw_rerank_top_k:] 328 for sid, sim in top
247 if sid not in rerank_id_set 329 if sid not in rerank_id_set
248 ] 330 ]
249 331
250 result = reranked + rest 332 result = reranked + non_reranked
251 top_summary = ", ".join( 333 top_summary = ", ".join(
252 f"{candidate.song_id}:{candidate.similarity:.4f}" 334 f"{candidate.song_id}:{candidate.similarity:.4f}"
253 for candidate in result[:5] 335 for candidate in result[:5]
...@@ -259,7 +341,12 @@ class CompositionDedupService: ...@@ -259,7 +341,12 @@ class CompositionDedupService:
259 ) 341 )
260 return result 342 return result
261 343
262 def _dejavu_query(self, audio_path: str) -> _DejavuMatch | None: 344 def _dejavu_query(
345 self,
346 samples: np.ndarray,
347 sr: int,
348 timings: dict | None = None,
349 ) -> _DejavuMatch | None:
263 """Dejavu 指纹查询,返回 offset 对齐后碰撞数最多的 song_id。 350 """Dejavu 指纹查询,返回 offset 对齐后碰撞数最多的 song_id。
264 351
265 只统计 hash 总碰撞数会让常见频谱峰值、噪声片段或大库随机碰撞直接短路成 352 只统计 hash 总碰撞数会让常见频谱峰值、噪声片段或大库随机碰撞直接短路成
...@@ -267,15 +354,19 @@ class CompositionDedupService: ...@@ -267,15 +354,19 @@ class CompositionDedupService:
267 db_offset - query_offset 落在同一个时间偏移上。 354 db_offset - query_offset 落在同一个时间偏移上。
268 355
269 Returns: 356 Returns:
270 命中结果,未命中返回 None。 357 最佳匹配结果(不做阈值过滤),无任何碰撞时返回 None。
271 """ 358 """
272 file_sha1, fingerprints = fingerprint_audio(audio_path) 359 _t = time.perf_counter()
360 _, fingerprints = fingerprint_from_samples(samples, sr, compute_sha1=False)
361 if timings is not None:
362 timings["dejavu_fingerprint_ms"] = round((time.perf_counter() - _t) * 1000, 1)
273 if not fingerprints: 363 if not fingerprints:
274 return None 364 return None
275 365
276 hashes = [h for h, _ in fingerprints] 366 hashes = [h for h, _ in fingerprints]
277 offsets = [int(o) for _, o in fingerprints] 367 offsets = [int(o) for _, o in fingerprints]
278 368
369 _t = time.perf_counter()
279 with psycopg.connect(self.config.dsn) as conn: 370 with psycopg.connect(self.config.dsn) as conn:
280 with conn.cursor() as cursor: 371 with conn.cursor() as cursor:
281 # 先按 hash 找碰撞,再按每个 song_id 的 offset delta 聚类。 372 # 先按 hash 找碰撞,再按每个 song_id 的 offset delta 聚类。
...@@ -307,23 +398,21 @@ class CompositionDedupService: ...@@ -307,23 +398,21 @@ class CompositionDedupService:
307 (hashes, offsets), 398 (hashes, offsets),
308 ) 399 )
309 row = cursor.fetchone() 400 row = cursor.fetchone()
401 if timings is not None:
402 timings["dejavu_db_ms"] = round((time.perf_counter() - _t) * 1000, 1)
310 if row is None: 403 if row is None:
311 return None 404 return None
312 sid, aligned_count, total_collisions = row 405 sid, aligned_count, total_collisions = row
313 aligned_count = int(aligned_count)
314 if aligned_count >= self.config.dejavu_match_threshold:
315 return _DejavuMatch( 406 return _DejavuMatch(
316 song_id=int(sid), 407 song_id=int(sid),
317 aligned_count=aligned_count, 408 aligned_count=int(aligned_count),
318 total_collisions=int(total_collisions), 409 total_collisions=int(total_collisions),
319 ) 410 )
320 return None
321 411
322 412
323 def _dtw_similarity(query: np.ndarray, candidate: np.ndarray) -> float: 413 @numba.njit(cache=True)
324 """计算两个 12×T Chromagram 矩阵之间的 DTW 相似度(映射到 [0,1])。""" 414 def _dtw_dp(cost: np.ndarray) -> float:
325 # 帧间欧氏距离矩阵 415 """DTW DP 填表(numba JIT 编译,数值结果与纯 Python 实现完全一致)。"""
326 cost = cdist(query.T, candidate.T, metric="euclidean")
327 n, m = cost.shape 416 n, m = cost.shape
328 dp = np.full((n, m), np.inf) 417 dp = np.full((n, m), np.inf)
329 dp[0, 0] = cost[0, 0] 418 dp[0, 0] = cost[0, 0]
...@@ -334,14 +423,37 @@ def _dtw_similarity(query: np.ndarray, candidate: np.ndarray) -> float: ...@@ -334,14 +423,37 @@ def _dtw_similarity(query: np.ndarray, candidate: np.ndarray) -> float:
334 for i in range(1, n): 423 for i in range(1, n):
335 for j in range(1, m): 424 for j in range(1, m):
336 dp[i, j] = cost[i, j] + min(dp[i - 1, j], dp[i, j - 1], dp[i - 1, j - 1]) 425 dp[i, j] = cost[i, j] + min(dp[i - 1, j], dp[i, j - 1], dp[i - 1, j - 1])
337 dtw_dist = dp[n - 1, m - 1] / (n + m) 426 return dp[n - 1, m - 1]
427
428
def _dtw_similarity(query: np.ndarray, candidate: np.ndarray) -> float:
    """Return the DTW similarity of two 12 x T chromagram matrices, in (0, 1].

    The accumulated DTW cost is normalized by the sum of both frame counts and
    mapped through ``1 / (1 + d)``, so a smaller distance gives a higher score.
    """
    # Pairwise Euclidean distance between frames (columns) of the two matrices.
    frame_cost = cdist(query.T, candidate.T, metric="euclidean")
    query_frames, candidate_frames = frame_cost.shape
    normalized_dist = _dtw_dp(frame_cost) / (query_frames + candidate_frames)
    return float(1.0 / (1.0 + normalized_dist))
340 437
341 438
def _best_shifted_dtw_similarity(
    query: np.ndarray,
    candidate: np.ndarray,
    early_exit_threshold: float = 1.1,
) -> float:
    """Return the best DTW similarity over the 12 cyclic pitch shifts of ``query``.

    Args:
        query: 12 x T chromagram of the query.
        candidate: 12 x T chromagram of the candidate.
        early_exit_threshold: As soon as the running best reaches this value the
            remaining shifts are skipped. Passing the duplicate threshold saves
            work on confirmed duplicates; the returned score may then be below
            the true maximum, which does not change a duplicate / non-duplicate
            decision made against that same threshold. The default 1.1 is
            above the similarity range, so early exit is effectively disabled.

    Returns:
        The highest similarity seen before the loop ended (0.0 if no shift
        scored above zero).
    """
    best = 0.0
    for semitones in range(12):
        rotated = np.roll(query, -semitones, axis=0)
        score = _dtw_similarity(rotated, candidate)
        if score > best:
            best = score
        if best >= early_exit_threshold:
            return best
    return best
......
...@@ -12,6 +12,7 @@ tqdm>=4.66 ...@@ -12,6 +12,7 @@ tqdm>=4.66
12 12
13 # Audio composition feature extraction 13 # Audio composition feature extraction
14 librosa>=0.10.0 14 librosa>=0.10.0
15 numba>=0.59.0
15 scipy>=1.11 16 scipy>=1.11
16 numpy>=1.24 17 numpy>=1.24
17 18
...@@ -21,3 +22,6 @@ pgvector>=0.2.0 ...@@ -21,3 +22,6 @@ pgvector>=0.2.0
21 # HTTP API server 22 # HTTP API server
22 fastapi>=0.110.0 23 fastapi>=0.110.0
23 uvicorn[standard]>=0.29.0 24 uvicorn[standard]>=0.29.0
25
26 # Environment variable loading
27 python-dotenv>=1.0
......
...@@ -8,8 +8,8 @@ expected_song_id 的 top-k/top1 命中只作为诊断字段。 ...@@ -8,8 +8,8 @@ expected_song_id 的 top-k/top1 命中只作为诊断字段。
8 用法: 8 用法:
9 python scripts/evaluate_composition.py \ 9 python scripts/evaluate_composition.py \
10 --dsn "postgresql:///lyric_dedup" \ 10 --dsn "postgresql:///lyric_dedup" \
11 --queries composition_dedup/composition_testset4/queries.csv \ 11 --queries composition_testset/test_samples.csv \
12 --out composition_dedup/composition_eval/composition_eval_result_v3.csv 12 --out composition_dedup/composition_eval/nohop_result.csv
13 """ 13 """
14 14
15 import argparse 15 import argparse
...@@ -17,10 +17,14 @@ import csv ...@@ -17,10 +17,14 @@ import csv
17 import json 17 import json
18 import logging 18 import logging
19 import sys 19 import sys
20 import time
20 from pathlib import Path 21 from pathlib import Path
21 22
22 sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) 23 sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
23 24
25 from dotenv import load_dotenv
26 load_dotenv(Path(__file__).resolve().parent.parent / ".env")
27
24 from composition_dedup.service import CompositionConfig, CompositionDedupService 28 from composition_dedup.service import CompositionConfig, CompositionDedupService
25 29
26 logger = logging.getLogger(__name__) 30 logger = logging.getLogger(__name__)
...@@ -92,8 +96,12 @@ def main() -> None: ...@@ -92,8 +96,12 @@ def main() -> None:
92 invalid_negative_pair = (not expected_dup) and audio_song_id == expected_song_id 96 invalid_negative_pair = (not expected_dup) and audio_song_id == expected_song_id
93 97
94 try: 98 try:
95 candidates = service.query(audio_path, top_k=args.top_k) 99 timings: dict = {}
100 _t0 = time.perf_counter()
101 candidates = service.query(audio_path, top_k=args.top_k, timings=timings)
102 query_time_ms = round((time.perf_counter() - _t0) * 1000, 1)
96 except Exception as e: 103 except Exception as e:
104 query_time_ms = round((time.perf_counter() - _t0) * 1000, 1)
97 logger.error("[%d/%d] 查询失败: %s, %s", i, len(rows), audio_path, e) 105 logger.error("[%d/%d] 查询失败: %s, %s", i, len(rows), audio_path, e)
98 result_rows.append({ 106 result_rows.append({
99 "query_song_id": query_song_id, 107 "query_song_id": query_song_id,
...@@ -106,6 +114,7 @@ def main() -> None: ...@@ -106,6 +114,7 @@ def main() -> None:
106 "top1_song_id": "", 114 "top1_song_id": "",
107 "top1_similarity": "", 115 "top1_similarity": "",
108 "top1_source": "", 116 "top1_source": "",
117 "dejavu_aligned_count": "",
109 "top1_hit": False, 118 "top1_hit": False,
110 "topk_hit": False, 119 "topk_hit": False,
111 "expected_rank": "", 120 "expected_rank": "",
...@@ -115,6 +124,14 @@ def main() -> None: ...@@ -115,6 +124,14 @@ def main() -> None:
115 "expected_duplicate": expected_dup, 124 "expected_duplicate": expected_dup,
116 "predicted_duplicate": False, 125 "predicted_duplicate": False,
117 "correct": not expected_dup, # 查询失败视为 not_duplicate 126 "correct": not expected_dup, # 查询失败视为 not_duplicate
127 "query_time_ms": query_time_ms,
128 "chroma_extract_ms": timings.get("chroma_extract_ms", ""),
129 "db_cosine_ms": timings.get("db_cosine_ms", ""),
130 "db_fetch_ms": timings.get("db_fetch_ms", ""),
131 "dtw_ms": timings.get("dtw_ms", ""),
132 "dejavu_decode_ms": timings.get("dejavu_decode_ms", ""),
133 "dejavu_fingerprint_ms": timings.get("dejavu_fingerprint_ms", ""),
134 "dejavu_db_ms": timings.get("dejavu_db_ms", ""),
118 "error": str(e), 135 "error": str(e),
119 }) 136 })
120 continue 137 continue
...@@ -123,6 +140,7 @@ def main() -> None: ...@@ -123,6 +140,7 @@ def main() -> None:
123 top1_song_id = str(top1.song_id) if top1 else "" 140 top1_song_id = str(top1.song_id) if top1 else ""
124 top1_sim = round(top1.similarity, 4) if top1 else "" 141 top1_sim = round(top1.similarity, 4) if top1 else ""
125 top1_source = top1.source if top1 else "" 142 top1_source = top1.source if top1 else ""
143 dejavu_aligned_count = top1.dejavu_aligned_count if top1 else ""
126 144
127 # 诊断召回:expected_song_id 是否进入 top1/top-k。 145 # 诊断召回:expected_song_id 是否进入 top1/top-k。
128 top1_hit = bool(expected_song_id) and top1_song_id == expected_song_id 146 top1_hit = bool(expected_song_id) and top1_song_id == expected_song_id
...@@ -157,6 +175,7 @@ def main() -> None: ...@@ -157,6 +175,7 @@ def main() -> None:
157 "top1_song_id": top1_song_id, 175 "top1_song_id": top1_song_id,
158 "top1_similarity": top1_sim, 176 "top1_similarity": top1_sim,
159 "top1_source": top1_source, 177 "top1_source": top1_source,
178 "dejavu_aligned_count": dejavu_aligned_count if dejavu_aligned_count is not None else "",
160 "top1_hit": top1_hit, 179 "top1_hit": top1_hit,
161 "topk_hit": topk_hit, 180 "topk_hit": topk_hit,
162 "expected_rank": expected_rank, 181 "expected_rank": expected_rank,
...@@ -166,11 +185,19 @@ def main() -> None: ...@@ -166,11 +185,19 @@ def main() -> None:
166 "expected_duplicate": expected_dup, 185 "expected_duplicate": expected_dup,
167 "predicted_duplicate": predicted_dup, 186 "predicted_duplicate": predicted_dup,
168 "correct": correct, 187 "correct": correct,
188 "query_time_ms": query_time_ms,
189 "chroma_extract_ms": timings.get("chroma_extract_ms", ""),
190 "db_cosine_ms": timings.get("db_cosine_ms", ""),
191 "db_fetch_ms": timings.get("db_fetch_ms", ""),
192 "dtw_ms": timings.get("dtw_ms", ""),
193 "dejavu_decode_ms": timings.get("dejavu_decode_ms", ""),
194 "dejavu_fingerprint_ms": timings.get("dejavu_fingerprint_ms", ""),
195 "dejavu_db_ms": timings.get("dejavu_db_ms", ""),
169 "error": "", 196 "error": "",
170 }) 197 })
171 198
172 logger.info( 199 logger.info(
173 "[%d/%d] variant=%s source=%s expected=%s predicted_duplicate=%s threshold=%.4f expected_song_id=%s top1=%s sim=%s top1_hit=%s topk_hit=%s expected_rank=%s expected_sim=%s correct=%s", 200 "[%d/%d] variant=%s source=%s expected=%s predicted_duplicate=%s threshold=%.4f expected_song_id=%s top1=%s sim=%s top1_hit=%s topk_hit=%s expected_rank=%s expected_sim=%s correct=%s time_ms=%s",
174 i, 201 i,
175 len(rows), 202 len(rows),
176 row.get("variant", ""), 203 row.get("variant", ""),
...@@ -186,6 +213,7 @@ def main() -> None: ...@@ -186,6 +213,7 @@ def main() -> None:
186 expected_rank if expected_rank != "" else "-", 213 expected_rank if expected_rank != "" else "-",
187 expected_similarity if expected_similarity != "" else "-", 214 expected_similarity if expected_similarity != "" else "-",
188 correct, 215 correct,
216 query_time_ms,
189 ) 217 )
190 218
191 if i % 10 == 0 or i == len(rows): 219 if i % 10 == 0 or i == len(rows):
...@@ -194,9 +222,14 @@ def main() -> None: ...@@ -194,9 +222,14 @@ def main() -> None:
194 # 写逐条结果 222 # 写逐条结果
195 fieldnames = ["query_song_id", "audio_song_id", "audio_path", "variant", "sample_class", 223 fieldnames = ["query_song_id", "audio_song_id", "audio_path", "variant", "sample_class",
196 "expected_song_id", "expected", "top1_song_id", "top1_similarity", "top1_source", 224 "expected_song_id", "expected", "top1_song_id", "top1_similarity", "top1_source",
225 "dejavu_aligned_count",
197 "top1_hit", "topk_hit", "expected_rank", "expected_similarity", 226 "top1_hit", "topk_hit", "expected_rank", "expected_similarity",
198 "invalid_negative_pair", "invalid_boolean_sample", 227 "invalid_negative_pair", "invalid_boolean_sample",
199 "expected_duplicate", "predicted_duplicate", "correct", "error"] 228 "expected_duplicate", "predicted_duplicate", "correct",
229 "query_time_ms",
230 "chroma_extract_ms", "db_cosine_ms", "db_fetch_ms", "dtw_ms",
231 "dejavu_decode_ms", "dejavu_fingerprint_ms", "dejavu_db_ms",
232 "error"]
200 with out_path.open("w", newline="", encoding="utf-8") as f: 233 with out_path.open("w", newline="", encoding="utf-8") as f:
201 writer = csv.DictWriter(f, fieldnames=fieldnames) 234 writer = csv.DictWriter(f, fieldnames=fieldnames)
202 writer.writeheader() 235 writer.writeheader()
......
...@@ -6,11 +6,11 @@ ...@@ -6,11 +6,11 @@
6 6
7 用法: 7 用法:
8 python scripts/generate_composition_testset.py \ 8 python scripts/generate_composition_testset.py \
9 --audio-dir /Volumes/移动硬盘/lyric_audio_type11 \ 9 --audio-dir /Volumes/移动硬盘/composition_test \
10 --negative-audio-dir /Volumes/移动硬盘/composition_test \ 10 --negative-audio-dir /Volumes/移动硬盘/composition_drop \
11 --out-dir composition_dedup/composition_testset \ 11 --out-dir composition_testset \
12 --num-songs 80 \ 12 --num-songs 100 \
13 --num-negative-songs 40 \ 13 --num-negative-songs 100 \
14 --negative-variants \ 14 --negative-variants \
15 --seed 123 15 --seed 123
16 16
......
...@@ -16,6 +16,9 @@ from pathlib import Path ...@@ -16,6 +16,9 @@ from pathlib import Path
16 16
17 sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) 17 sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
18 18
19 from dotenv import load_dotenv
20 load_dotenv(Path(__file__).resolve().parent.parent / ".env")
21
19 from tqdm import tqdm 22 from tqdm import tqdm
20 23
21 from composition_dedup.service import CompositionConfig, CompositionDedupService 24 from composition_dedup.service import CompositionConfig, CompositionDedupService
......