Commit 974df4ae 974df4ae895606924f14bf5e679276b5dd51920e by 沈秋雨

优化性能

1 parent 7a11a3d4
......@@ -27,3 +27,5 @@ venv/
test_api
composition_dedup/composition_eval
composition_testset
\ No newline at end of file
......
......@@ -22,7 +22,6 @@ from pathlib import Path
import librosa
import numpy as np
from scipy.ndimage import (
binary_erosion,
generate_binary_structure,
iterate_structure,
maximum_filter,
......@@ -60,10 +59,10 @@ MIN_HASH_TIME_DELTA = 0
MAX_HASH_TIME_DELTA = 200
PEAK_SORT = True
FINGERPRINT_REDUCTION = 20
MAX_DURATION_SEC = float(os.environ.get("COMPOSITION_DEJAVU_MAX_DURATION", "120")) # 0=不限制
QUERY_MAX_DURATION_SEC = float(os.environ.get("COMPOSITION_DEJAVU_QUERY_MAX_DURATION", "120")) # 0=不限制
def _normalize_audio(audio_path: str, max_duration: float = MAX_DURATION_SEC) -> tuple[np.ndarray, int]:
def _normalize_audio(audio_path: str, max_duration: float = 0) -> tuple[np.ndarray, int]:
"""将音频标准化为单声道 WAV 并加载为 numpy 数组。
使用 ffmpeg 先做重采样,再用 librosa 读取。
......@@ -133,12 +132,7 @@ def _get_2d_peaks(arr2D: np.ndarray, amp_min: float = DEFAULT_AMP_MIN):
neighborhood = iterate_structure(struct, PEAK_NEIGHBORHOOD_SIZE)
# 找局部极大值
local_max = maximum_filter(arr2D, footprint=neighborhood) == arr2D
background = arr2D == 0
eroded_background = binary_erosion(background, structure=neighborhood, border_value=1)
# 布尔掩码
detected_peaks = local_max ^ eroded_background
detected_peaks = maximum_filter(arr2D, footprint=neighborhood) == arr2D
# 提取峰值
amps = arr2D[detected_peaks]
......@@ -181,6 +175,43 @@ def _generate_hashes(peaks: list[tuple[int, int]], fan_value: int = DEFAULT_FAN_
yield (h.hexdigest()[:FINGERPRINT_REDUCTION].encode(), t1)
def load_audio(audio_path: str, max_duration: float = 0) -> tuple[np.ndarray, int]:
    """Decode an audio file once so that several pipelines can share the samples.

    This is a public pass-through to ``_normalize_audio``; the sample rate and
    channel layout of the result are whatever that helper produces (the
    fingerprinting code expects 44100 Hz mono).

    Args:
        audio_path: Path of the audio file to decode.
        max_duration: Maximum number of seconds to keep; 0 means no limit.

    Returns:
        The ``(samples, sr)`` pair returned by ``_normalize_audio``.
    """
    decoded = _normalize_audio(audio_path, max_duration=max_duration)
    return decoded
def fingerprint_from_samples(
    samples: np.ndarray, sr: int, *, compute_sha1: bool = True
) -> tuple[str, list[tuple[bytes, int]]]:
    """Build Dejavu-style fingerprints from already-decoded samples (no I/O).

    Args:
        samples: Mono audio samples (expected to be at DEFAULT_FS).
        sr: Sample rate of ``samples``.
        compute_sha1: When True, hash the raw sample bytes to produce an
            identifier. Callers that ignore the identifier can pass False to
            skip hashing the whole sample buffer.

    Returns:
        A ``(file_sha1, fingerprints)`` tuple. ``file_sha1`` is the first 16
        hex characters of the SHA-1 of ``samples.tobytes()``, or an empty
        string when ``compute_sha1`` is False. ``fingerprints`` is the list of
        ``(hash_bytes, offset)`` items produced by ``_generate_hashes``.
    """
    if compute_sha1:
        full_digest = hashlib.sha1(samples.tobytes()).hexdigest()
        file_sha1 = full_digest[:16]
    else:
        file_sha1 = ""

    # Spectrogram -> 2D peak picking -> pairwise peak hashing.
    spectrogram = _specgram(samples, sr, DEFAULT_WINDOW_SIZE, DEFAULT_OVERLAP_RATIO)
    freq_idx, time_idx = _get_2d_peaks(spectrogram)
    peak_pairs = [*zip(freq_idx, time_idx)]
    fingerprints = [*_generate_hashes(peak_pairs)]
    return file_sha1, fingerprints
def fingerprint_audio(audio_path: str) -> tuple[str, list[tuple[bytes, int]]]:
"""对音频文件生成 Dejavu 风格指纹。
......@@ -198,21 +229,7 @@ def fingerprint_audio(audio_path: str) -> tuple[str, list[tuple[bytes, int]]]:
if not os.path.isfile(audio_path):
raise FileNotFoundError(f"音频文件不存在: {audio_path}")
# 1. 标准化并加载音频(可选限制长度)
samples, fs = _normalize_audio(audio_path)
# 2. 计算文件 SHA1(用于标识)
file_sha1 = hashlib.sha1(samples.tobytes()).hexdigest()[:16]
# 3. 计算频谱图
arr2D = _specgram(samples, fs, DEFAULT_WINDOW_SIZE, DEFAULT_OVERLAP_RATIO)
# 4. 检测 2D 峰值
freq_idx, time_idx = _get_2d_peaks(arr2D)
peaks = list(zip(freq_idx, time_idx))
# 5. 生成指纹哈希
fingerprints = list(_generate_hashes(peaks))
file_sha1, fingerprints = fingerprint_from_samples(samples, fs)
logger.info("指纹生成完成: audio=%s, 指纹数=%d", audio_path, len(fingerprints))
return file_sha1, fingerprints
......
"""Chromagram 特征提取。
流程:
1. 音频标准化:ffmpeg 转 22050Hz / Mono / WAV
2. librosa 加载音频
3. librosa.feature.chroma_cens() 提取 12×T Chromagram(CENS,对速度/音色鲁棒)
4. 主音对齐:将能量最大的音级滚至第 0 行,实现转调不变性
5. scipy.signal.resample(chroma, 128, axis=1) 时间归一化到 12×128
6. .flatten() 展开为 1536 维向量
1. 音频解码:ffmpeg pipe 输出 22050Hz / Mono / f32le,直接读入内存,无临时文件
2. librosa.feature.chroma_cens() 提取 12×T Chromagram(CENS,对速度/音色鲁棒)
3. 主音对齐:将能量最大的音级滚至第 0 行,实现转调不变性
4. scipy.signal.resample(chroma, 128, axis=1) 时间归一化到 12×128
5. .flatten() 展开为 1536 维向量
"""
import logging
import os
import subprocess
import tempfile
import librosa
import numpy as np
......@@ -26,83 +24,99 @@ TARGET_FRAMES = 128
VECTOR_DIM = 12 * TARGET_FRAMES # 1536
def _normalize_audio_ffmpeg(audio_path: str, output_path: str) -> None:
"""使用 ffmpeg 将音频标准化为 22050Hz / Mono / WAV。"""
def _load_audio_via_pipe(audio_path: str) -> np.ndarray:
"""使用 ffmpeg pipe 将音频解码为 22050Hz mono float32,不落临时文件到磁盘。"""
cmd = [
"ffmpeg",
"-y",
"ffmpeg", "-y",
"-i", audio_path,
"-ar", str(TARGET_SR),
"-ac", "1",
"-f", "wav",
output_path,
"-f", "f32le",
"pipe:1",
]
result = subprocess.run(
cmd,
capture_output=True,
text=True,
)
result = subprocess.run(cmd, capture_output=True)
if result.returncode != 0:
raise RuntimeError(f"ffmpeg 转换失败: {result.stderr}")
raise RuntimeError(f"ffmpeg 解码失败: {result.stderr.decode(errors='replace')}")
return np.frombuffer(result.stdout, dtype=np.float32)
def extract_chroma_feature(audio_path: str) -> np.ndarray:
"""从音频文件提取 1536 维 Chromagram 特征向量。
def extract_chroma_feature_from_samples(
samples: np.ndarray,
sr: int,
hop_length: int = 512,
win_len_smooth: int = 41,
) -> np.ndarray:
"""从已加载的音频样本提取 1536 维 Chromagram 特征向量。
若 sr 不等于 TARGET_SR,先用 librosa.resample 在内存中降采样,
避免重新走 ffmpeg 流程。
Args:
audio_path: 音频文件路径。
samples: 单声道音频样本(任意采样率)。
sr: samples 对应的采样率。
hop_length: CQT hop 大小,增大可成比例降低计算量,不影响最终 128 帧精度。
win_len_smooth: CENS 平滑窗口帧数,应随 hop_length 等比缩小以保持相同的时间覆盖。
Returns:
shape 为 (1536,) 的 numpy 数组。
Raises:
FileNotFoundError: 音频文件不存在。
RuntimeError: ffmpeg 转换失败。
"""
if not os.path.isfile(audio_path):
raise FileNotFoundError(f"音频文件不存在: {audio_path}")
# 1. 音频标准化:ffmpeg 转 WAV
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
tmp_wav = tmp.name
try:
_normalize_audio_ffmpeg(audio_path, tmp_wav)
# 2. librosa 加载音频
y, _sr = librosa.load(tmp_wav, sr=TARGET_SR, mono=True)
y = samples if sr == TARGET_SR else librosa.resample(samples, orig_sr=sr, target_sr=TARGET_SR)
# 3. 提取 CENS Chromagram (12×T),对速度变化和音色具有更强鲁棒性
chroma = librosa.feature.chroma_cens(y=y, sr=TARGET_SR)
# 提取 CENS Chromagram (12×T)
chroma = librosa.feature.chroma_cens(y=y, sr=TARGET_SR, hop_length=hop_length, win_len_smooth=win_len_smooth)
# 4. 主音对齐:将全局能量最大的音级循环滚至第 0 行,实现转调不变性
# 主音对齐
tonic = int(np.argmax(chroma.sum(axis=1)))
if tonic != 0:
chroma = np.roll(chroma, -tonic, axis=0)
# 5. 时间归一化到 12×128
# 时间归一化到 12×128
if chroma.shape[1] != TARGET_FRAMES:
chroma = resample(chroma, TARGET_FRAMES, axis=1)
# 6. 展开为 1536 维向量
feature = chroma.flatten().astype(np.float32)
assert feature.shape == (VECTOR_DIM,), (
f"特征维度错误: 期望 {VECTOR_DIM}, 实际 {feature.shape}"
)
return feature
finally:
# 清理临时文件
if os.path.exists(tmp_wav):
os.remove(tmp_wav)
def extract_chroma_matrix(audio_path: str) -> np.ndarray:
def extract_chroma_matrix_from_samples(
    samples: np.ndarray,
    sr: int,
    hop_length: int = 512,
    win_len_smooth: int = 41,
) -> np.ndarray:
    """Return the chromagram of decoded samples as a 12 x TARGET_FRAMES matrix.

    The flat feature vector from ``extract_chroma_feature_from_samples`` is
    reshaped back to two dimensions (used for DTW re-ranking).

    Args:
        samples: Decoded audio samples.
        sr: Sample rate of ``samples``.
        hop_length: Forwarded unchanged to the feature extractor.
        win_len_smooth: Forwarded unchanged to the feature extractor.

    Returns:
        Array of shape ``(12, TARGET_FRAMES)``.
    """
    flat_feature = extract_chroma_feature_from_samples(
        samples,
        sr,
        hop_length=hop_length,
        win_len_smooth=win_len_smooth,
    )
    return flat_feature.reshape(12, TARGET_FRAMES)
def extract_chroma_feature(audio_path: str, hop_length: int = 512, win_len_smooth: int = 41) -> np.ndarray:
    """Extract the flattened chromagram feature vector from an audio file.

    The file is decoded through ``_load_audio_via_pipe`` and the samples are
    handed to ``extract_chroma_feature_from_samples`` at ``TARGET_SR``.

    Args:
        audio_path: Path of the audio file.
        hop_length: Hop size forwarded to the chroma extraction.
        win_len_smooth: CENS smoothing window (in frames) forwarded to the
            chroma extraction.

    Returns:
        The 1-D feature vector produced by ``extract_chroma_feature_from_samples``.

    Raises:
        FileNotFoundError: ``audio_path`` is not an existing regular file.
        RuntimeError: Propagated from the decoder when ffmpeg fails.
    """
    if os.path.isfile(audio_path):
        decoded = _load_audio_via_pipe(audio_path)
        return extract_chroma_feature_from_samples(
            decoded,
            TARGET_SR,
            hop_length=hop_length,
            win_len_smooth=win_len_smooth,
        )
    raise FileNotFoundError(f"音频文件不存在: {audio_path}")
def extract_chroma_matrix(audio_path: str, hop_length: int = 512, win_len_smooth: int = 41) -> np.ndarray:
"""从音频文件提取 12×128 Chromagram 矩阵(未展平,供 DTW 精排使用)。
Returns:
shape 为 (12, 128) 的 numpy 数组,已做主音对齐。
"""
feature = extract_chroma_feature(audio_path)
return feature.reshape(12, TARGET_FRAMES)
return extract_chroma_feature(audio_path, hop_length=hop_length, win_len_smooth=win_len_smooth).reshape(12, TARGET_FRAMES)
......
"""作曲去重服务(入库 + 查询)。
查询流程:
1. Dejavu 指纹匹配(毫秒级,子序列匹配,支持 chorus_only / trim_intro)
- 命中(≥ 阈值)→ 直接返回 duplicate(短路)
2. 未命中 → Chromagram 12路 + DTW(百毫秒级)
- 返回结果
1. Chromagram 12路 + DTW 精排(覆盖绝大多数场景,约 1s)
- 命中(≥ 阈值)→ 直接返回结果
2. Chromagram 未命中 → Dejavu 指纹兜底(处理 chorus_only / trim_intro 等片段场景)
- 命中(≥ 阈值)→ 返回 duplicate(similarity=1.0)
"""
import logging
import os
import time
from dataclasses import dataclass, field
import numba
import numpy as np
import psycopg
from scipy.spatial.distance import cdist
from .extractor import TARGET_FRAMES, extract_chroma_feature, extract_chroma_matrix
from .dejavu_fingerprinter import fingerprint_audio
from .extractor import (
TARGET_FRAMES,
extract_chroma_feature_from_samples,
extract_chroma_matrix,
)
from .dejavu_fingerprinter import fingerprint_audio, fingerprint_from_samples, load_audio, QUERY_MAX_DURATION_SEC
logger = logging.getLogger(__name__)
......@@ -56,6 +62,7 @@ class CompositionCandidate:
song_id: int
similarity: float
source: str = "chromagram"
dejavu_aligned_count: int | None = None # 仅 source=dejavu 或 dejavu fallback 未命中时记录
@dataclass
......@@ -73,10 +80,18 @@ class CompositionConfig:
statement_timeout_ms: int = 30000
dtw_rerank_top_k: int = 20 # Cosine 召回后做 DTW 精排的候选数量
duplicate_threshold: float = _env_float("COMPOSITION_DUPLICATE_THRESHOLD", 0.85)
# Chromagram 提取配置
chroma_hop_length: int = _env_int("COMPOSITION_CHROMA_HOP_LENGTH", 512)
chroma_win_len_smooth: int = _env_int("COMPOSITION_CHROMA_WIN_LEN_SMOOTH", 0)
# Dejavu 指纹匹配配置
dejavu_enabled: bool = _env_bool("COMPOSITION_DEJAVU_ENABLED", True)
dejavu_match_threshold: int = _env_int("COMPOSITION_DEJAVU_MATCH_THRESHOLD", 20)
def __post_init__(self) -> None:
    """Resolve the automatic CENS smoothing window after dataclass init.

    ``chroma_win_len_smooth == 0`` means "auto": the window is scaled
    inversely with ``chroma_hop_length`` from the reference setting
    (41 frames at hop 512) so that the smoothing span stays roughly constant
    in time, and it is never allowed to drop below 1 frame. Any non-zero
    value is left untouched.

    Raises:
        ValueError: Auto mode is requested but ``chroma_hop_length`` is not
            a positive number, so no window can be derived from it.
    """
    if self.chroma_win_len_smooth != 0:
        return
    # A zero hop would divide by zero and a negative hop would silently
    # collapse the window to 1; reject both with an explicit error.
    if self.chroma_hop_length <= 0:
        raise ValueError(
            f"chroma_hop_length 必须为正整数,实际为 {self.chroma_hop_length}"
        )
    self.chroma_win_len_smooth = max(1, round(41 * 512 / self.chroma_hop_length))
@dataclass
class CompositionDedupService:
......@@ -94,8 +109,12 @@ class CompositionDedupService:
Returns:
提取的特征向量。
"""
feature = extract_chroma_feature(audio_path)
self._logger.info("提取 Chromagram 特征完成: song_id=%s, audio=%s", song_id, audio_path)
# 共用一次解码(44100Hz),chromagram 路径在内存中重采样,无需二次 ffmpeg
samples, sr = load_audio(audio_path)
self._logger.info("音频解码完成: song_id=%s, audio=%s", song_id, audio_path)
feature = extract_chroma_feature_from_samples(samples, sr, hop_length=self.config.chroma_hop_length, win_len_smooth=self.config.chroma_win_len_smooth)
self._logger.info("提取 Chromagram 特征完成: song_id=%s", song_id)
with psycopg.connect(self.config.dsn) as conn:
with conn.cursor() as cursor:
......@@ -111,15 +130,29 @@ class CompositionDedupService:
self._logger.info("Chromagram 特征入库完成: song_id=%s", song_id)
# Dejavu 指纹同时入库
if self.config.dejavu_enabled:
self._dejavu_ingest(song_id, audio_path)
self._dejavu_ingest(song_id, audio_path, samples=samples, sr=sr)
return feature
def _dejavu_ingest(self, song_id: int, audio_path: str) -> None:
"""提取 Dejavu 指纹并写入 dejavu_fingerprints 表。"""
file_sha1, fingerprints = fingerprint_audio(audio_path)
def _dejavu_ingest(
self,
song_id: int,
audio_path: str,
*,
samples: np.ndarray | None = None,
sr: int | None = None,
) -> None:
"""提取 Dejavu 指纹并写入 dejavu_fingerprints 表。
若提供了已解码的 samples/sr,直接使用,跳过 ffmpeg;否则从文件重新加载。
"""
if samples is not None and sr is not None:
_, fingerprints = fingerprint_from_samples(samples, sr, compute_sha1=False)
self._logger.info("Dejavu 指纹提取完成(共用解码): song_id=%s", song_id)
else:
_, fingerprints = fingerprint_audio(audio_path)
if not fingerprints:
self._logger.warning("Dejavu 指纹为空: song_id=%s, audio=%s", song_id, audio_path)
return
......@@ -144,25 +177,54 @@ class CompositionDedupService:
self._logger.info("Dejavu 指纹入库完成: song_id=%s, 指纹数=%d", song_id, len(fingerprints))
def query(self, audio_path: str, top_k: int = 100) -> list[CompositionCandidate]:
def query(
self,
audio_path: str,
top_k: int = 100,
timings: dict | None = None,
) -> list[CompositionCandidate]:
"""提取音频特征并查询相似结果。
流程:Dejavu 指纹短路匹配 → 12 路循环对齐 Cosine 召回 → DTW 精排。
流程:Chromagram 12路 + DTW 精排 → Dejavu 指纹兜底(片段场景,仅在 chroma 未命中时解码)。
Args:
timings: 若传入非 None 的 dict,方法执行完毕后会在其中写入各阶段耗时(单位 ms):
chroma_extract_ms、db_cosine_ms、db_fetch_ms、dtw_ms、
dejavu_decode_ms、dejavu_fingerprint_ms、dejavu_db_ms。
Dejavu 路径未执行时,对应键不会写入。
"""
# 1. 优先尝试 Dejavu 指纹匹配(短路)
if self.config.dejavu_enabled:
match = self._dejavu_query(audio_path)
# 1. Chromagram 12路 + DTW 精排(覆盖绝大多数场景)
# 使用与入库一致的 22050Hz 解码路径,保证 chroma 向量对齐
candidates = self._query_chroma(audio_path, top_k, timings=timings)
# 2. Chromagram 未命中时,用 Dejavu 兜底(处理 chorus_only / trim_intro 等片段场景)
# 只有未命中才解码 44100Hz,大多数情况下无额外 I/O
if self.config.dejavu_enabled and not self.candidates_indicate_duplicate(candidates):
_t = time.perf_counter()
samples, sr = load_audio(audio_path, max_duration=QUERY_MAX_DURATION_SEC)
if timings is not None:
timings["dejavu_decode_ms"] = round((time.perf_counter() - _t) * 1000, 1)
match = self._dejavu_query(samples, sr, timings=timings)
if match is not None:
if match.aligned_count >= self.config.dejavu_match_threshold:
self._logger.info(
"Dejavu 命中: song_id=%s, aligned_count=%d, total_collisions=%d, decision=duplicate",
match.song_id,
match.aligned_count,
match.total_collisions,
)
return [CompositionCandidate(song_id=match.song_id, similarity=1.0, source="dejavu")]
# 2. Dejavu 未命中或禁用,走现有 Chromagram 12路 + DTW 流程
return self._query_chroma(audio_path, top_k)
return [CompositionCandidate(
song_id=match.song_id,
similarity=1.0,
source="dejavu",
dejavu_aligned_count=match.aligned_count,
)]
else:
# 未达阈值:把 aligned_count 附加到 chromagram top1 上供评估脚本记录
if candidates:
candidates[0].dejavu_aligned_count = match.aligned_count
return candidates
def check(self, audio_path: str, top_k: int = 100) -> bool:
"""按最终接口语义返回是否重复。"""
......@@ -178,9 +240,17 @@ class CompositionDedupService:
return False
return candidates[0].similarity >= self.config.duplicate_threshold
def _query_chroma(self, audio_path: str, top_k: int = 100) -> list[CompositionCandidate]:
def _query_chroma(
self,
audio_path: str,
top_k: int = 100,
timings: dict | None = None,
) -> list[CompositionCandidate]:
"""Chromagram 12 路循环对齐 + DTW 精排查询。"""
chroma = extract_chroma_matrix(audio_path)
_t = time.perf_counter()
chroma = extract_chroma_matrix(audio_path, hop_length=self.config.chroma_hop_length, win_len_smooth=self.config.chroma_win_len_smooth)
if timings is not None:
timings["chroma_extract_ms"] = round((time.perf_counter() - _t) * 1000, 1)
self._logger.info("提取 Chromagram 查询特征完成: audio=%s", audio_path)
# 1. 12 路循环对齐:穷举 12 种半音偏移,单条 SQL 内部展开,按 song_id 取最高 Cosine 相似度
......@@ -213,6 +283,7 @@ class CompositionDedupService:
LIMIT %s
"""
best: dict[int, float] = {}
_t = time.perf_counter()
with psycopg.connect(self.config.dsn) as conn:
with conn.cursor() as cursor:
cursor.execute(
......@@ -221,33 +292,44 @@ class CompositionDedupService:
cursor.execute(sql, (*shift_vecs, top_k, top_k))
for song_id, sim in cursor.fetchall():
best[int(song_id)] = float(sim)
if timings is not None:
timings["db_cosine_ms"] = round((time.perf_counter() - _t) * 1000, 1)
# 2. 取 Top dtw_rerank_top_k,从库中取原始向量做 DTW 精排
top = sorted(best.items(), key=lambda x: x[1], reverse=True)
rerank_ids = [sid for sid, _ in top[:self.config.dtw_rerank_top_k]]
_t = time.perf_counter()
with conn.cursor() as cursor:
cursor.execute(
"SELECT song_id, feature_vector::float4[] FROM composition_feature WHERE song_id = ANY(%s)",
(rerank_ids,),
)
db_rows = cursor.fetchall()
if timings is not None:
timings["db_fetch_ms"] = round((time.perf_counter() - _t) * 1000, 1)
_t = time.perf_counter()
reranked = []
for song_id, fv in db_rows:
cand_chroma = np.array(fv, dtype=np.float32).reshape(12, TARGET_FRAMES)
dtw_sim = _best_shifted_dtw_similarity(chroma, cand_chroma)
dtw_sim = _best_shifted_dtw_similarity(
chroma, cand_chroma, early_exit_threshold=self.config.duplicate_threshold
)
reranked.append(CompositionCandidate(song_id=int(song_id), similarity=dtw_sim))
reranked.sort(key=lambda c: c.similarity, reverse=True)
if timings is not None:
timings["dtw_ms"] = round((time.perf_counter() - _t) * 1000, 1)
rerank_id_set = {c.song_id for c in reranked}
rest = [
# top 已按 cosine 降序排列;直接从中剔除已精排的候选,剩余保留 cosine 分
non_reranked = [
CompositionCandidate(song_id=sid, similarity=sim)
for sid, sim in top[self.config.dtw_rerank_top_k:]
for sid, sim in top
if sid not in rerank_id_set
]
result = reranked + rest
result = reranked + non_reranked
top_summary = ", ".join(
f"{candidate.song_id}:{candidate.similarity:.4f}"
for candidate in result[:5]
......@@ -259,7 +341,12 @@ class CompositionDedupService:
)
return result
def _dejavu_query(self, audio_path: str) -> _DejavuMatch | None:
def _dejavu_query(
self,
samples: np.ndarray,
sr: int,
timings: dict | None = None,
) -> _DejavuMatch | None:
"""Dejavu 指纹查询,返回 offset 对齐后碰撞数最多的 song_id。
只统计 hash 总碰撞数会让常见频谱峰值、噪声片段或大库随机碰撞直接短路成
......@@ -267,15 +354,19 @@ class CompositionDedupService:
db_offset - query_offset 落在同一个时间偏移上。
Returns:
命中结果,未命中返回 None。
最佳匹配结果(不做阈值过滤),无任何碰撞时返回 None。
"""
file_sha1, fingerprints = fingerprint_audio(audio_path)
_t = time.perf_counter()
_, fingerprints = fingerprint_from_samples(samples, sr, compute_sha1=False)
if timings is not None:
timings["dejavu_fingerprint_ms"] = round((time.perf_counter() - _t) * 1000, 1)
if not fingerprints:
return None
hashes = [h for h, _ in fingerprints]
offsets = [int(o) for _, o in fingerprints]
_t = time.perf_counter()
with psycopg.connect(self.config.dsn) as conn:
with conn.cursor() as cursor:
# 先按 hash 找碰撞,再按每个 song_id 的 offset delta 聚类。
......@@ -307,23 +398,21 @@ class CompositionDedupService:
(hashes, offsets),
)
row = cursor.fetchone()
if timings is not None:
timings["dejavu_db_ms"] = round((time.perf_counter() - _t) * 1000, 1)
if row is None:
return None
sid, aligned_count, total_collisions = row
aligned_count = int(aligned_count)
if aligned_count >= self.config.dejavu_match_threshold:
return _DejavuMatch(
song_id=int(sid),
aligned_count=aligned_count,
aligned_count=int(aligned_count),
total_collisions=int(total_collisions),
)
return None
def _dtw_similarity(query: np.ndarray, candidate: np.ndarray) -> float:
"""计算两个 12×T Chromagram 矩阵之间的 DTW 相似度(映射到 [0,1])。"""
# 帧间欧氏距离矩阵
cost = cdist(query.T, candidate.T, metric="euclidean")
@numba.njit(cache=True)
def _dtw_dp(cost: np.ndarray) -> float:
"""DTW DP 填表(numba JIT 编译,数值结果与纯 Python 实现完全一致)。"""
n, m = cost.shape
dp = np.full((n, m), np.inf)
dp[0, 0] = cost[0, 0]
......@@ -334,14 +423,37 @@ def _dtw_similarity(query: np.ndarray, candidate: np.ndarray) -> float:
for i in range(1, n):
for j in range(1, m):
dp[i, j] = cost[i, j] + min(dp[i - 1, j], dp[i, j - 1], dp[i - 1, j - 1])
dtw_dist = dp[n - 1, m - 1] / (n + m)
return dp[n - 1, m - 1]
def _dtw_similarity(query: np.ndarray, candidate: np.ndarray) -> float:
    """Score two chromagram matrices with DTW and map the distance to a similarity.

    Both inputs are transposed before the pairwise distance is taken, so
    columns are treated as frames. The accumulated DTW cost from ``_dtw_dp``
    is normalised by the sum of the two frame counts and converted with
    ``1 / (1 + distance)``, so a smaller distance gives a larger score.

    Args:
        query: Chromagram matrix of the query audio.
        candidate: Chromagram matrix of the candidate audio.

    Returns:
        The similarity as a Python float.
    """
    # Euclidean distance between every query frame and every candidate frame.
    frame_cost = cdist(query.T, candidate.T, metric="euclidean")
    query_frames, candidate_frames = frame_cost.shape
    normalized_distance = _dtw_dp(frame_cost) / (query_frames + candidate_frames)
    similarity = 1.0 / (1.0 + normalized_distance)
    return float(similarity)
def _best_shifted_dtw_similarity(query: np.ndarray, candidate: np.ndarray) -> float:
"""计算 12 路音高循环位移下的最佳 DTW 相似度。"""
return max(
_dtw_similarity(np.roll(query, -shift, axis=0), candidate)
for shift in range(12)
)
def _best_shifted_dtw_similarity(
    query: np.ndarray,
    candidate: np.ndarray,
    early_exit_threshold: float = 1.1,
) -> float:
    """Return the best DTW similarity over the 12 cyclic pitch shifts of ``query``.

    The query is rolled along axis 0 by 0..11 rows and each rotation is
    scored against ``candidate`` with ``_dtw_similarity``.

    Args:
        query: Chromagram matrix of the query audio.
        candidate: Chromagram matrix of the candidate audio.
        early_exit_threshold: As soon as the running best reaches this value
            the remaining shifts are skipped. The result may then be lower
            than the true maximum, but it is still at or above the threshold.
            The default of 1.1 is intended to be unreachable, which disables
            the early exit.

    Returns:
        The highest similarity seen before the loop ended (0.0 if no shift
        scored above zero).
    """
    top_score = 0.0
    for semitone in range(12):
        rotated = np.roll(query, -semitone, axis=0)
        top_score = max(top_score, _dtw_similarity(rotated, candidate))
        if top_score >= early_exit_threshold:
            return top_score
    return top_score
......
......@@ -12,6 +12,7 @@ tqdm>=4.66
# Audio composition feature extraction
librosa>=0.10.0
numba>=0.59.0
scipy>=1.11
numpy>=1.24
......@@ -21,3 +22,6 @@ pgvector>=0.2.0
# HTTP API server
fastapi>=0.110.0
uvicorn[standard]>=0.29.0
# Environment variable loading
python-dotenv>=1.0
......
......@@ -8,8 +8,8 @@ expected_song_id 的 top-k/top1 命中只作为诊断字段。
用法:
python scripts/evaluate_composition.py \
--dsn "postgresql:///lyric_dedup" \
--queries composition_dedup/composition_testset4/queries.csv \
--out composition_dedup/composition_eval/composition_eval_result_v3.csv
--queries composition_testset/test_samples.csv \
--out composition_dedup/composition_eval/nohop_result.csv
"""
import argparse
......@@ -17,10 +17,14 @@ import csv
import json
import logging
import sys
import time
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from dotenv import load_dotenv
load_dotenv(Path(__file__).resolve().parent.parent / ".env")
from composition_dedup.service import CompositionConfig, CompositionDedupService
logger = logging.getLogger(__name__)
......@@ -92,8 +96,12 @@ def main() -> None:
invalid_negative_pair = (not expected_dup) and audio_song_id == expected_song_id
try:
candidates = service.query(audio_path, top_k=args.top_k)
timings: dict = {}
_t0 = time.perf_counter()
candidates = service.query(audio_path, top_k=args.top_k, timings=timings)
query_time_ms = round((time.perf_counter() - _t0) * 1000, 1)
except Exception as e:
query_time_ms = round((time.perf_counter() - _t0) * 1000, 1)
logger.error("[%d/%d] 查询失败: %s, %s", i, len(rows), audio_path, e)
result_rows.append({
"query_song_id": query_song_id,
......@@ -106,6 +114,7 @@ def main() -> None:
"top1_song_id": "",
"top1_similarity": "",
"top1_source": "",
"dejavu_aligned_count": "",
"top1_hit": False,
"topk_hit": False,
"expected_rank": "",
......@@ -115,6 +124,14 @@ def main() -> None:
"expected_duplicate": expected_dup,
"predicted_duplicate": False,
"correct": not expected_dup, # 查询失败视为 not_duplicate
"query_time_ms": query_time_ms,
"chroma_extract_ms": timings.get("chroma_extract_ms", ""),
"db_cosine_ms": timings.get("db_cosine_ms", ""),
"db_fetch_ms": timings.get("db_fetch_ms", ""),
"dtw_ms": timings.get("dtw_ms", ""),
"dejavu_decode_ms": timings.get("dejavu_decode_ms", ""),
"dejavu_fingerprint_ms": timings.get("dejavu_fingerprint_ms", ""),
"dejavu_db_ms": timings.get("dejavu_db_ms", ""),
"error": str(e),
})
continue
......@@ -123,6 +140,7 @@ def main() -> None:
top1_song_id = str(top1.song_id) if top1 else ""
top1_sim = round(top1.similarity, 4) if top1 else ""
top1_source = top1.source if top1 else ""
dejavu_aligned_count = top1.dejavu_aligned_count if top1 else ""
# 诊断召回:expected_song_id 是否进入 top1/top-k。
top1_hit = bool(expected_song_id) and top1_song_id == expected_song_id
......@@ -157,6 +175,7 @@ def main() -> None:
"top1_song_id": top1_song_id,
"top1_similarity": top1_sim,
"top1_source": top1_source,
"dejavu_aligned_count": dejavu_aligned_count if dejavu_aligned_count is not None else "",
"top1_hit": top1_hit,
"topk_hit": topk_hit,
"expected_rank": expected_rank,
......@@ -166,11 +185,19 @@ def main() -> None:
"expected_duplicate": expected_dup,
"predicted_duplicate": predicted_dup,
"correct": correct,
"query_time_ms": query_time_ms,
"chroma_extract_ms": timings.get("chroma_extract_ms", ""),
"db_cosine_ms": timings.get("db_cosine_ms", ""),
"db_fetch_ms": timings.get("db_fetch_ms", ""),
"dtw_ms": timings.get("dtw_ms", ""),
"dejavu_decode_ms": timings.get("dejavu_decode_ms", ""),
"dejavu_fingerprint_ms": timings.get("dejavu_fingerprint_ms", ""),
"dejavu_db_ms": timings.get("dejavu_db_ms", ""),
"error": "",
})
logger.info(
"[%d/%d] variant=%s source=%s expected=%s predicted_duplicate=%s threshold=%.4f expected_song_id=%s top1=%s sim=%s top1_hit=%s topk_hit=%s expected_rank=%s expected_sim=%s correct=%s",
"[%d/%d] variant=%s source=%s expected=%s predicted_duplicate=%s threshold=%.4f expected_song_id=%s top1=%s sim=%s top1_hit=%s topk_hit=%s expected_rank=%s expected_sim=%s correct=%s time_ms=%s",
i,
len(rows),
row.get("variant", ""),
......@@ -186,6 +213,7 @@ def main() -> None:
expected_rank if expected_rank != "" else "-",
expected_similarity if expected_similarity != "" else "-",
correct,
query_time_ms,
)
if i % 10 == 0 or i == len(rows):
......@@ -194,9 +222,14 @@ def main() -> None:
# 写逐条结果
fieldnames = ["query_song_id", "audio_song_id", "audio_path", "variant", "sample_class",
"expected_song_id", "expected", "top1_song_id", "top1_similarity", "top1_source",
"dejavu_aligned_count",
"top1_hit", "topk_hit", "expected_rank", "expected_similarity",
"invalid_negative_pair", "invalid_boolean_sample",
"expected_duplicate", "predicted_duplicate", "correct", "error"]
"expected_duplicate", "predicted_duplicate", "correct",
"query_time_ms",
"chroma_extract_ms", "db_cosine_ms", "db_fetch_ms", "dtw_ms",
"dejavu_decode_ms", "dejavu_fingerprint_ms", "dejavu_db_ms",
"error"]
with out_path.open("w", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
......
......@@ -6,11 +6,11 @@
用法:
python scripts/generate_composition_testset.py \
--audio-dir /Volumes/移动硬盘/lyric_audio_type11 \
--negative-audio-dir /Volumes/移动硬盘/composition_test \
--out-dir composition_dedup/composition_testset \
--num-songs 80 \
--num-negative-songs 40 \
--audio-dir /Volumes/移动硬盘/composition_test \
--negative-audio-dir /Volumes/移动硬盘/composition_drop \
--out-dir composition_testset \
--num-songs 100 \
--num-negative-songs 100 \
--negative-variants \
--seed 123
......
......@@ -16,6 +16,9 @@ from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from dotenv import load_dotenv
load_dotenv(Path(__file__).resolve().parent.parent / ".env")
from tqdm import tqdm
from composition_dedup.service import CompositionConfig, CompositionDedupService
......