Commit 8413944a 8413944ad675bb85c114f5f012b4257a140fef8e by 沈秋雨

添加曲结构去重

1 parent cdfa3a58
from .service import CompositionCandidate, CompositionConfig, CompositionDedupService
from .dejavu_fingerprinter import fingerprint_audio
__all__ = [
"CompositionCandidate",
"CompositionConfig",
"CompositionDedupService",
"fingerprint_audio",
]
"""Dejavu 风格的音频指纹生成。
基于 worldveil/dejavu 的指纹算法提取实现,不依赖 Dejavu 的数据库层。
使用 scipy.signal.spectrogram 替代已废弃的 matplotlib.mlab.specgram。
流程:
1. 音频标准化:ffmpeg 转 44100Hz / Mono / WAV
2. librosa 加载音频
3. 短时傅里叶变换(STFT)→ 对数频谱图
4. 2D 峰值检测:在频谱图中找局部极大值
5. 指纹哈希:对峰值对 (freq1, freq2, time_delta) 做 SHA1,取前 20 位
"""
import hashlib
import logging
import os
import subprocess
import tempfile
from operator import itemgetter
from pathlib import Path
import librosa
import numpy as np
from scipy.ndimage import (
binary_erosion,
generate_binary_structure,
iterate_structure,
maximum_filter,
)
from scipy.signal import spectrogram
logger = logging.getLogger(__name__)
def _load_env_file() -> None:
"""加载项目根目录 .env,不覆盖已存在的真实环境变量。"""
env_path = Path(__file__).resolve().parent.parent / ".env"
if not env_path.exists():
return
with env_path.open(encoding="utf-8") as file:
for raw_line in file:
line = raw_line.strip()
if not line or line.startswith("#") or "=" not in line:
continue
key, value = line.split("=", 1)
os.environ.setdefault(key.strip(), value.strip().strip('"').strip("'"))
_load_env_file()
# ===== 常量(可通过环境变量覆盖)=====
DEFAULT_FS = 44100
DEFAULT_WINDOW_SIZE = 4096
DEFAULT_OVERLAP_RATIO = float(os.environ.get("COMPOSITION_DEJAVU_OVERLAP_RATIO", "0.3"))
DEFAULT_FAN_VALUE = int(os.environ.get("COMPOSITION_DEJAVU_FAN_VALUE", "10"))
DEFAULT_AMP_MIN = float(os.environ.get("COMPOSITION_DEJAVU_AMP_MIN", "20"))
PEAK_NEIGHBORHOOD_SIZE = 20
MIN_HASH_TIME_DELTA = 0
MAX_HASH_TIME_DELTA = 200
PEAK_SORT = True
FINGERPRINT_REDUCTION = 20
MAX_DURATION_SEC = float(os.environ.get("COMPOSITION_DEJAVU_MAX_DURATION", "120")) # 0=不限制
def _normalize_audio(audio_path: str, max_duration: float = MAX_DURATION_SEC) -> tuple[np.ndarray, int]:
"""将音频标准化为单声道 WAV 并加载为 numpy 数组。
使用 ffmpeg 先做重采样,再用 librosa 读取。
可选限制音频长度,超长音频只取前 N 秒。
"""
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
tmp_wav = tmp.name
try:
cmd = [
"ffmpeg",
"-y",
"-i", audio_path,
"-ar", str(DEFAULT_FS),
"-ac", "1",
"-f", "wav",
]
if max_duration > 0:
cmd += ["-t", str(max_duration)]
cmd.append(tmp_wav)
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode != 0:
raise RuntimeError(f"ffmpeg 转换失败: {result.stderr}")
y, sr = librosa.load(tmp_wav, sr=DEFAULT_FS, mono=True)
return y, sr
finally:
if os.path.exists(tmp_wav):
os.remove(tmp_wav)
def _specgram(samples: np.ndarray, fs: int, window_size: int, overlap_ratio: float):
"""计算对数频谱图,替代 matplotlib.mlab.specgram。
Returns:
arr2D: shape (n_freq, n_time) 的对数频谱矩阵(dBFS 刻度)
"""
noverlap = int(window_size * overlap_ratio)
window = np.hanning(window_size)
freqs, times, Sxx = spectrogram(
samples,
fs=fs,
window=window,
nperseg=window_size,
noverlap=noverlap,
)
# 转为对数尺度(dBFS,0 dB 为峰值参考)
# scipy.signal.spectrogram 返回 PSD,mlab.specgram 返回功率,两者量纲不同
# 统一转为相对于峰值的 dBFS 刻度,使强信号峰值落在 20~80 dB 范围
arr2D = 10 * np.log10(Sxx + 1e-10)
arr2D = arr2D - arr2D.max() # 归一化到峰值为 0 dBFS
arr2D = arr2D + 80 # 偏移使典型峰值落在 20~80 dB(与 mlab.specgram 一致)
arr2D[arr2D < -100] = -100 # 限幅
return arr2D
def _get_2d_peaks(arr2D: np.ndarray, amp_min: float = DEFAULT_AMP_MIN):
"""在频谱图中检测 2D 局部极大值。
Returns:
(frequency_idx, time_idx): 峰值的频率和时间索引列表
"""
struct = generate_binary_structure(2, 1)
neighborhood = iterate_structure(struct, PEAK_NEIGHBORHOOD_SIZE)
# 找局部极大值
local_max = maximum_filter(arr2D, footprint=neighborhood) == arr2D
background = arr2D == 0
eroded_background = binary_erosion(background, structure=neighborhood, border_value=1)
# 布尔掩码
detected_peaks = local_max ^ eroded_background
# 提取峰值
amps = arr2D[detected_peaks]
j, i = np.where(detected_peaks)
# 过滤低于阈值的峰值
peaks = list(zip(i, j, amps))
peaks_filtered = [x for x in peaks if x[2] > amp_min]
frequency_idx = [x[1] for x in peaks_filtered]
time_idx = [x[0] for x in peaks_filtered]
return frequency_idx, time_idx
def _generate_hashes(peaks: list[tuple[int, int]], fan_value: int = DEFAULT_FAN_VALUE):
"""根据峰值对生成 SHA1 指纹哈希。
Args:
peaks: [(freq_idx, time_idx), ...] 列表
fan_value: 每个峰值与后续多少个峰值配对
Yields:
(hash_bytes, time_offset) 元组
"""
if PEAK_SORT:
peaks.sort(key=itemgetter(1))
for i in range(len(peaks)):
for j in range(1, fan_value):
if i + j < len(peaks):
freq1 = peaks[i][0]
freq2 = peaks[i + j][0]
t1 = peaks[i][1]
t2 = peaks[i + j][1]
t_delta = t2 - t1
if MIN_HASH_TIME_DELTA <= t_delta <= MAX_HASH_TIME_DELTA:
h = hashlib.sha1(f"{freq1}|{freq2}|{t_delta}".encode())
yield (h.hexdigest()[:FINGERPRINT_REDUCTION].encode(), t1)
def fingerprint_audio(audio_path: str) -> tuple[str, list[tuple[bytes, int]]]:
"""对音频文件生成 Dejavu 风格指纹。
Args:
audio_path: 音频文件路径。
Returns:
(file_sha1, fingerprints) 元组,
其中 fingerprints 是 [(hash_bytes, offset), ...] 列表。
Raises:
FileNotFoundError: 音频文件不存在。
RuntimeError: ffmpeg 转换失败。
"""
if not os.path.isfile(audio_path):
raise FileNotFoundError(f"音频文件不存在: {audio_path}")
# 1. 标准化并加载音频(可选限制长度)
samples, fs = _normalize_audio(audio_path)
# 2. 计算文件 SHA1(用于标识)
file_sha1 = hashlib.sha1(samples.tobytes()).hexdigest()[:16]
# 3. 计算频谱图
arr2D = _specgram(samples, fs, DEFAULT_WINDOW_SIZE, DEFAULT_OVERLAP_RATIO)
# 4. 检测 2D 峰值
freq_idx, time_idx = _get_2d_peaks(arr2D)
peaks = list(zip(freq_idx, time_idx))
# 5. 生成指纹哈希
fingerprints = list(_generate_hashes(peaks))
logger.info("指纹生成完成: audio=%s, 指纹数=%d", audio_path, len(fingerprints))
return file_sha1, fingerprints
"""Chromagram 特征提取。
流程:
1. 音频标准化:ffmpeg 转 22050Hz / Mono / WAV
2. librosa 加载音频
3. librosa.feature.chroma_cens() 提取 12×T Chromagram(CENS,对速度/音色鲁棒)
4. 主音对齐:将能量最大的音级滚至第 0 行,实现转调不变性
5. scipy.signal.resample(chroma, 128, axis=1) 时间归一化到 12×128
6. .flatten() 展开为 1536 维向量
"""
import logging
import os
import subprocess
import tempfile
import librosa
import numpy as np
from scipy.signal import resample
logger = logging.getLogger(__name__)
# 目标采样率和时间帧数
TARGET_SR = 22050
TARGET_FRAMES = 128
VECTOR_DIM = 12 * TARGET_FRAMES # 1536
def _normalize_audio_ffmpeg(audio_path: str, output_path: str) -> None:
"""使用 ffmpeg 将音频标准化为 22050Hz / Mono / WAV。"""
cmd = [
"ffmpeg",
"-y",
"-i", audio_path,
"-ar", str(TARGET_SR),
"-ac", "1",
"-f", "wav",
output_path,
]
result = subprocess.run(
cmd,
capture_output=True,
text=True,
)
if result.returncode != 0:
raise RuntimeError(f"ffmpeg 转换失败: {result.stderr}")
def extract_chroma_feature(audio_path: str) -> np.ndarray:
"""从音频文件提取 1536 维 Chromagram 特征向量。
Args:
audio_path: 音频文件路径。
Returns:
shape 为 (1536,) 的 numpy 数组。
Raises:
FileNotFoundError: 音频文件不存在。
RuntimeError: ffmpeg 转换失败。
"""
if not os.path.isfile(audio_path):
raise FileNotFoundError(f"音频文件不存在: {audio_path}")
# 1. 音频标准化:ffmpeg 转 WAV
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
tmp_wav = tmp.name
try:
_normalize_audio_ffmpeg(audio_path, tmp_wav)
# 2. librosa 加载音频
y, _sr = librosa.load(tmp_wav, sr=TARGET_SR, mono=True)
# 3. 提取 CENS Chromagram (12×T),对速度变化和音色具有更强鲁棒性
chroma = librosa.feature.chroma_cens(y=y, sr=TARGET_SR)
# 4. 主音对齐:将全局能量最大的音级循环滚至第 0 行,实现转调不变性
tonic = int(np.argmax(chroma.sum(axis=1)))
if tonic != 0:
chroma = np.roll(chroma, -tonic, axis=0)
# 5. 时间归一化到 12×128
if chroma.shape[1] != TARGET_FRAMES:
chroma = resample(chroma, TARGET_FRAMES, axis=1)
# 6. 展开为 1536 维向量
feature = chroma.flatten().astype(np.float32)
assert feature.shape == (VECTOR_DIM,), (
f"特征维度错误: 期望 {VECTOR_DIM}, 实际 {feature.shape}"
)
return feature
finally:
# 清理临时文件
if os.path.exists(tmp_wav):
os.remove(tmp_wav)
def extract_chroma_matrix(audio_path: str) -> np.ndarray:
"""从音频文件提取 12×128 Chromagram 矩阵(未展平,供 DTW 精排使用)。
Returns:
shape 为 (12, 128) 的 numpy 数组,已做主音对齐。
"""
feature = extract_chroma_feature(audio_path)
return feature.reshape(12, TARGET_FRAMES)
"""作曲去重服务(入库 + 查询)。
查询流程:
1. Dejavu 指纹匹配(毫秒级,子序列匹配,支持 chorus_only / trim_intro)
- 命中(≥ 阈值)→ 直接返回 duplicate(短路)
2. 未命中 → Chromagram 12路 + DTW(百毫秒级)
- 返回结果
"""
import logging
import os
from dataclasses import dataclass, field
import numpy as np
import psycopg
from scipy.spatial.distance import cdist
from .extractor import TARGET_FRAMES, extract_chroma_feature, extract_chroma_matrix
from .dejavu_fingerprinter import fingerprint_audio
logger = logging.getLogger(__name__)
def _env_bool(name: str, default: bool) -> bool:
value = os.getenv(name)
if value is None:
return default
return value.strip().lower() in {"1", "true", "yes", "y", "on"}
def _env_int(name: str, default: int) -> int:
value = os.getenv(name)
if value is None:
return default
try:
return int(value)
except ValueError:
logger.warning("环境变量 %s=%r 不是整数,使用默认值 %d", name, value, default)
return default
def _env_float(name: str, default: float) -> float:
value = os.getenv(name)
if value is None:
return default
try:
return float(value)
except ValueError:
logger.warning("环境变量 %s=%r 不是数字,使用默认值 %.4f", name, value, default)
return default
@dataclass
class CompositionCandidate:
"""去重候选结果。"""
song_id: int
similarity: float
source: str = "chromagram"
@dataclass
class _DejavuMatch:
"""Dejavu offset 对齐后的命中结果。"""
song_id: int
aligned_count: int
total_collisions: int
@dataclass
class CompositionConfig:
"""作曲去重服务配置。"""
dsn: str = "postgresql:///lyric_dedup"
statement_timeout_ms: int = 30000
dtw_rerank_top_k: int = 20 # Cosine 召回后做 DTW 精排的候选数量
duplicate_threshold: float = _env_float("COMPOSITION_DUPLICATE_THRESHOLD", 0.85)
# Dejavu 指纹匹配配置
dejavu_enabled: bool = _env_bool("COMPOSITION_DEJAVU_ENABLED", True)
dejavu_match_threshold: int = _env_int("COMPOSITION_DEJAVU_MATCH_THRESHOLD", 20)
@dataclass
class CompositionDedupService:
"""作曲去重服务:特征入库 + 相似度查询。"""
config: CompositionConfig
_logger: logging.Logger = field(default_factory=lambda: logger, repr=False)
def ingest(self, song_id: int, audio_path: str) -> np.ndarray:
"""提取音频特征并写入数据库。
Args:
song_id: 歌曲 ID。
audio_path: 音频文件路径。
Returns:
提取的特征向量。
"""
feature = extract_chroma_feature(audio_path)
self._logger.info("提取 Chromagram 特征完成: song_id=%s, audio=%s", song_id, audio_path)
with psycopg.connect(self.config.dsn) as conn:
with conn.cursor() as cursor:
cursor.execute(
"""
INSERT INTO composition_feature (song_id, feature_vector)
VALUES (%s, %s)
ON CONFLICT DO NOTHING
""",
(song_id, feature.tolist()),
)
conn.commit()
self._logger.info("Chromagram 特征入库完成: song_id=%s", song_id)
# Dejavu 指纹同时入库
if self.config.dejavu_enabled:
self._dejavu_ingest(song_id, audio_path)
return feature
def _dejavu_ingest(self, song_id: int, audio_path: str) -> None:
"""提取 Dejavu 指纹并写入 dejavu_fingerprints 表。"""
file_sha1, fingerprints = fingerprint_audio(audio_path)
if not fingerprints:
self._logger.warning("Dejavu 指纹为空: song_id=%s, audio=%s", song_id, audio_path)
return
with psycopg.connect(self.config.dsn) as conn:
with conn.cursor() as cursor:
# 先清理可能残留的旧指纹(幂等写入)
cursor.execute(
"DELETE FROM dejavu_fingerprints WHERE song_id = %s",
(song_id,),
)
# 批量写入
records = [(song_id, h, o) for h, o in fingerprints]
cursor.executemany(
"""
INSERT INTO dejavu_fingerprints (song_id, hash, "offset")
VALUES (%s, %s, %s)
""",
records,
)
conn.commit()
self._logger.info("Dejavu 指纹入库完成: song_id=%s, 指纹数=%d", song_id, len(fingerprints))
def query(self, audio_path: str, top_k: int = 100) -> list[CompositionCandidate]:
"""提取音频特征并查询相似结果。
流程:Dejavu 指纹短路匹配 → 12 路循环对齐 Cosine 召回 → DTW 精排。
"""
# 1. 优先尝试 Dejavu 指纹匹配(短路)
if self.config.dejavu_enabled:
match = self._dejavu_query(audio_path)
if match is not None:
self._logger.info(
"Dejavu 命中: song_id=%s, aligned_count=%d, total_collisions=%d, decision=duplicate",
match.song_id,
match.aligned_count,
match.total_collisions,
)
return [CompositionCandidate(song_id=match.song_id, similarity=1.0, source="dejavu")]
# 2. Dejavu 未命中或禁用,走现有 Chromagram 12路 + DTW 流程
return self._query_chroma(audio_path, top_k)
def check(self, audio_path: str, top_k: int = 100) -> bool:
"""按最终接口语义返回是否重复。"""
return self.candidates_indicate_duplicate(self.query(audio_path, top_k=top_k))
def candidates_indicate_duplicate(self, candidates: list[CompositionCandidate]) -> bool:
"""将候选结果转换为最终 duplicate bool。
最终接口只返回 true/false,因此判定只看当前查询的最佳候选是否超过阈值,
不依赖评测集里的 expected_song_id 是否出现在 top-k。
"""
if not candidates:
return False
return candidates[0].similarity >= self.config.duplicate_threshold
def _query_chroma(self, audio_path: str, top_k: int = 100) -> list[CompositionCandidate]:
"""Chromagram 12 路循环对齐 + DTW 精排查询。"""
chroma = extract_chroma_matrix(audio_path)
self._logger.info("提取 Chromagram 查询特征完成: audio=%s", audio_path)
# 1. 12 路循环对齐:穷举 12 种半音偏移,单条 SQL 内部展开,按 song_id 取最高 Cosine 相似度
shift_vecs = [
np.roll(chroma, -shift, axis=0).flatten().astype(np.float32).tolist()
for shift in range(12)
]
# 用 VALUES 展开 12 个偏移向量,LATERAL 子查询对每个偏移各触发一次 HNSW 扫描
values_clause = ", ".join(f"({i}, %s::vector)" for i in range(12))
sql = f"""
WITH shifts(shift_id, vec) AS (
VALUES {values_clause}
),
candidates AS (
SELECT
cf.song_id,
1 - (cf.feature_vector <=> s.vec) AS sim
FROM shifts s
CROSS JOIN LATERAL (
SELECT song_id, feature_vector
FROM composition_feature
ORDER BY feature_vector <=> s.vec
LIMIT %s
) cf
)
SELECT song_id, MAX(sim) AS similarity
FROM candidates
GROUP BY song_id
ORDER BY similarity DESC
LIMIT %s
"""
best: dict[int, float] = {}
with psycopg.connect(self.config.dsn) as conn:
with conn.cursor() as cursor:
cursor.execute(
f"SET statement_timeout = {int(self.config.statement_timeout_ms)}"
)
cursor.execute(sql, (*shift_vecs, top_k, top_k))
for song_id, sim in cursor.fetchall():
best[int(song_id)] = float(sim)
# 2. 取 Top dtw_rerank_top_k,从库中取原始向量做 DTW 精排
top = sorted(best.items(), key=lambda x: x[1], reverse=True)
rerank_ids = [sid for sid, _ in top[:self.config.dtw_rerank_top_k]]
with conn.cursor() as cursor:
cursor.execute(
"SELECT song_id, feature_vector::float4[] FROM composition_feature WHERE song_id = ANY(%s)",
(rerank_ids,),
)
db_rows = cursor.fetchall()
reranked = []
for song_id, fv in db_rows:
cand_chroma = np.array(fv, dtype=np.float32).reshape(12, TARGET_FRAMES)
dtw_sim = _best_shifted_dtw_similarity(chroma, cand_chroma)
reranked.append(CompositionCandidate(song_id=int(song_id), similarity=dtw_sim))
reranked.sort(key=lambda c: c.similarity, reverse=True)
rerank_id_set = {c.song_id for c in reranked}
rest = [
CompositionCandidate(song_id=sid, similarity=sim)
for sid, sim in top[self.config.dtw_rerank_top_k:]
if sid not in rerank_id_set
]
result = reranked + rest
top_summary = ", ".join(
f"{candidate.song_id}:{candidate.similarity:.4f}"
for candidate in result[:5]
)
self._logger.info(
"Chromagram 查询完成: 返回 %d 个候选, top=%s",
len(result),
top_summary or "[]",
)
return result
def _dejavu_query(self, audio_path: str) -> _DejavuMatch | None:
"""Dejavu 指纹查询,返回 offset 对齐后碰撞数最多的 song_id。
只统计 hash 总碰撞数会让常见频谱峰值、噪声片段或大库随机碰撞直接短路成
similarity=1.0。Dejavu 的关键判据是同一首候选歌里,多个 hash 碰撞的
db_offset - query_offset 落在同一个时间偏移上。
Returns:
命中结果,未命中返回 None。
"""
file_sha1, fingerprints = fingerprint_audio(audio_path)
if not fingerprints:
return None
hashes = [h for h, _ in fingerprints]
offsets = [int(o) for _, o in fingerprints]
with psycopg.connect(self.config.dsn) as conn:
with conn.cursor() as cursor:
# 先按 hash 找碰撞,再按每个 song_id 的 offset delta 聚类。
cursor.execute(
"""
WITH query_fp(hash, query_offset) AS (
SELECT *
FROM unnest(%s::bytea[], %s::int[])
),
aligned AS (
SELECT
db.song_id,
db."offset" - query_fp.query_offset AS offset_delta,
COUNT(*) AS aligned_count
FROM query_fp
JOIN dejavu_fingerprints db
ON db.hash = query_fp.hash
GROUP BY db.song_id, offset_delta
)
SELECT
song_id,
MAX(aligned_count) AS best_aligned_count,
SUM(aligned_count) AS total_collisions
FROM aligned
GROUP BY song_id
ORDER BY best_aligned_count DESC, total_collisions DESC
LIMIT 1
""",
(hashes, offsets),
)
row = cursor.fetchone()
if row is None:
return None
sid, aligned_count, total_collisions = row
aligned_count = int(aligned_count)
if aligned_count >= self.config.dejavu_match_threshold:
return _DejavuMatch(
song_id=int(sid),
aligned_count=aligned_count,
total_collisions=int(total_collisions),
)
return None
def _dtw_similarity(query: np.ndarray, candidate: np.ndarray) -> float:
"""计算两个 12×T Chromagram 矩阵之间的 DTW 相似度(映射到 [0,1])。"""
# 帧间欧氏距离矩阵
cost = cdist(query.T, candidate.T, metric="euclidean")
n, m = cost.shape
dp = np.full((n, m), np.inf)
dp[0, 0] = cost[0, 0]
for i in range(1, n):
dp[i, 0] = dp[i - 1, 0] + cost[i, 0]
for j in range(1, m):
dp[0, j] = dp[0, j - 1] + cost[0, j]
for i in range(1, n):
for j in range(1, m):
dp[i, j] = cost[i, j] + min(dp[i - 1, j], dp[i, j - 1], dp[i - 1, j - 1])
dtw_dist = dp[n - 1, m - 1] / (n + m)
# 转换为相似度:距离越小相似度越高
return float(1.0 / (1.0 + dtw_dist))
def _best_shifted_dtw_similarity(query: np.ndarray, candidate: np.ndarray) -> float:
"""计算 12 路音高循环位移下的最佳 DTW 相似度。"""
return max(
_dtw_similarity(np.roll(query, -shift, axis=0), candidate)
for shift in range(12)
)
"""Cosine 相似度计算与去重判定。"""
from enum import Enum
import numpy as np
DUPLICATE_THRESHOLD = 0.95
SUSPECTED_THRESHOLD = 0.85
class SimilarityDecision(Enum):
DUPLICATE = "duplicate"
SUSPECTED = "suspected"
NEW = "new"
class CompositionSimilarity:
@staticmethod
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
norm_a = np.linalg.norm(a)
norm_b = np.linalg.norm(b)
if norm_a == 0.0 or norm_b == 0.0:
return 0.0
return float(np.dot(a, b) / (norm_a * norm_b))
@staticmethod
def classify_similarity(similarity: float) -> SimilarityDecision:
if similarity >= DUPLICATE_THRESHOLD:
return SimilarityDecision.DUPLICATE
if similarity >= SUSPECTED_THRESHOLD:
return SimilarityDecision.SUSPECTED
return SimilarityDecision.NEW
@staticmethod
def compare(a: np.ndarray, b: np.ndarray) -> tuple[float, SimilarityDecision]:
sim = CompositionSimilarity.cosine_similarity(a, b)
return sim, CompositionSimilarity.classify_similarity(sim)
"""曲去重评估脚本。
对 queries.csv 中每条查询音频调用 CompositionDedupService.query(),
按最终接口语义用 top1 分数阈值输出 predicted_duplicate true/false。
expected_song_id 的 top-k/top1 命中只作为诊断字段。
输出 precision/recall/F1。
用法:
python scripts/evaluate_composition.py \
--dsn "postgresql:///lyric_dedup" \
--queries composition_dedup/composition_testset4/queries.csv \
--out composition_dedup/composition_eval/composition_eval_result_v3.csv
"""
import argparse
import csv
import json
import logging
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from composition_dedup.service import CompositionConfig, CompositionDedupService
logger = logging.getLogger(__name__)
def _parse_csv_filter(value: str | None) -> set[str] | None:
if value is None:
return None
items = {item.strip() for item in value.split(",") if item.strip()}
return items or None
def _song_id_from_audio_path(audio_path: str) -> str:
"""从音频文件名开头提取 song_id。"""
return Path(audio_path).stem.split("_", 1)[0]
def main() -> None:
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
parser = argparse.ArgumentParser()
parser.add_argument("--dsn", required=True)
parser.add_argument("--queries", required=True, help="queries.csv 路径")
parser.add_argument("--out", required=True, help="逐条结果输出 CSV")
parser.add_argument("--top-k", type=int, default=10)
parser.add_argument("--duplicate-threshold", type=float, help="覆盖 COMPOSITION_DUPLICATE_THRESHOLD")
parser.add_argument("--variants", help="只评测指定 variant,逗号分隔,如 pitch_up1,pitch_down1")
parser.add_argument("--sample-classes", help="只评测指定 sample_class,逗号分隔,如 dsp,negative")
parser.add_argument("--expected", choices=["duplicate", "not_duplicate"], help="只评测指定 expected 类型")
args = parser.parse_args()
config = CompositionConfig(dsn=args.dsn)
if args.duplicate_threshold is not None:
config.duplicate_threshold = args.duplicate_threshold
service = CompositionDedupService(config=config)
with open(args.queries, newline="", encoding="utf-8") as f:
rows = list(csv.DictReader(f))
variant_filter = _parse_csv_filter(args.variants)
sample_class_filter = _parse_csv_filter(args.sample_classes)
original_count = len(rows)
if variant_filter is not None:
rows = [r for r in rows if (r.get("variant") or "") in variant_filter]
if sample_class_filter is not None:
rows = [r for r in rows if (r.get("sample_class") or "") in sample_class_filter]
if args.expected is not None:
rows = [r for r in rows if r["expected"].strip().lower() == args.expected]
logger.info(
"评测样本过滤: 原始 %d 条,保留 %d 条 (variants=%s, sample_classes=%s, expected=%s)",
original_count,
len(rows),
",".join(sorted(variant_filter)) if variant_filter else "ALL",
",".join(sorted(sample_class_filter)) if sample_class_filter else "ALL",
args.expected or "ALL",
)
out_path = Path(args.out)
out_path.parent.mkdir(parents=True, exist_ok=True)
result_rows = []
for i, row in enumerate(rows, 1):
audio_path = row["audio_path"]
query_song_id = row.get("song_id") or _song_id_from_audio_path(audio_path)
audio_song_id = _song_id_from_audio_path(audio_path)
expected_song_id = str(row["expected_song_id"])
expected_dup = row["expected"].strip().lower() == "duplicate"
invalid_negative_pair = (not expected_dup) and audio_song_id == expected_song_id
try:
candidates = service.query(audio_path, top_k=args.top_k)
except Exception as e:
logger.error("[%d/%d] 查询失败: %s, %s", i, len(rows), audio_path, e)
result_rows.append({
"query_song_id": query_song_id,
"audio_song_id": audio_song_id,
"audio_path": audio_path,
"variant": row.get("variant", ""),
"sample_class": row.get("sample_class", ""),
"expected_song_id": expected_song_id,
"expected": row["expected"],
"top1_song_id": "",
"top1_similarity": "",
"top1_source": "",
"top1_hit": False,
"topk_hit": False,
"expected_rank": "",
"expected_similarity": "",
"invalid_negative_pair": invalid_negative_pair,
"invalid_boolean_sample": False,
"expected_duplicate": expected_dup,
"predicted_duplicate": False,
"correct": not expected_dup, # 查询失败视为 not_duplicate
"error": str(e),
})
continue
top1 = candidates[0] if candidates else None
top1_song_id = str(top1.song_id) if top1 else ""
top1_sim = round(top1.similarity, 4) if top1 else ""
top1_source = top1.source if top1 else ""
# 诊断召回:expected_song_id 是否进入 top1/top-k。
top1_hit = bool(expected_song_id) and top1_song_id == expected_song_id
topk_hit = bool(expected_song_id) and any(str(c.song_id) == expected_song_id for c in candidates)
expected_rank = ""
expected_similarity = ""
if expected_song_id:
for rank, candidate in enumerate(candidates, 1):
if str(candidate.song_id) == expected_song_id:
expected_rank = rank
expected_similarity = round(candidate.similarity, 4)
break
# 最终接口语义:只返回 duplicate true/false。
predicted_dup = service.candidates_indicate_duplicate(candidates)
correct = expected_dup == predicted_dup
invalid_boolean_sample = (
(not expected_dup)
and bool(top1)
and top1_song_id == audio_song_id
and predicted_dup
)
result_rows.append({
"query_song_id": query_song_id,
"audio_song_id": audio_song_id,
"audio_path": audio_path,
"variant": row.get("variant", ""),
"sample_class": row.get("sample_class", ""),
"expected_song_id": expected_song_id,
"expected": row["expected"],
"top1_song_id": top1_song_id,
"top1_similarity": top1_sim,
"top1_source": top1_source,
"top1_hit": top1_hit,
"topk_hit": topk_hit,
"expected_rank": expected_rank,
"expected_similarity": expected_similarity,
"invalid_negative_pair": invalid_negative_pair,
"invalid_boolean_sample": invalid_boolean_sample,
"expected_duplicate": expected_dup,
"predicted_duplicate": predicted_dup,
"correct": correct,
"error": "",
})
logger.info(
"[%d/%d] variant=%s source=%s expected=%s predicted_duplicate=%s threshold=%.4f expected_song_id=%s top1=%s sim=%s top1_hit=%s topk_hit=%s expected_rank=%s expected_sim=%s correct=%s",
i,
len(rows),
row.get("variant", ""),
top1_source or "-",
row["expected"],
predicted_dup,
service.config.duplicate_threshold,
expected_song_id,
top1_song_id or "-",
top1_sim if top1_sim != "" else "-",
top1_hit,
topk_hit,
expected_rank if expected_rank != "" else "-",
expected_similarity if expected_similarity != "" else "-",
correct,
)
if i % 10 == 0 or i == len(rows):
logger.info("[%d/%d]", i, len(rows))
# 写逐条结果
fieldnames = ["query_song_id", "audio_song_id", "audio_path", "variant", "sample_class",
"expected_song_id", "expected", "top1_song_id", "top1_similarity", "top1_source",
"top1_hit", "topk_hit", "expected_rank", "expected_similarity",
"invalid_negative_pair", "invalid_boolean_sample",
"expected_duplicate", "predicted_duplicate", "correct", "error"]
with out_path.open("w", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
writer.writerows(result_rows)
# 汇总指标
def _metrics(rows: list[dict]) -> dict:
tp = sum(1 for r in rows if r["expected_duplicate"] and r["predicted_duplicate"])
fp = sum(1 for r in rows if not r["expected_duplicate"] and r["predicted_duplicate"])
tn = sum(1 for r in rows if not r["expected_duplicate"] and not r["predicted_duplicate"])
fn = sum(1 for r in rows if r["expected_duplicate"] and not r["predicted_duplicate"])
precision = tp / (tp + fp) if tp + fp else 0.0
recall = tp / (tp + fn) if tp + fn else 0.0
f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
accuracy = (tp + tn) / len(rows) if rows else 0.0
return {
"total": len(rows),
"accuracy": round(accuracy, 4),
"precision": round(precision, 4),
"recall": round(recall, 4),
"f1": round(f1, 4),
"tp": tp,
"fp": fp,
"tn": tn,
"fn": fn,
}
metrics = _metrics(result_rows)
valid_rows = [
r for r in result_rows
if not r["invalid_negative_pair"] and not r["invalid_boolean_sample"]
]
valid_metrics = _metrics(valid_rows)
summary = {
"total": len(result_rows),
"filters": {
"variants": sorted(variant_filter) if variant_filter else None,
"sample_classes": sorted(sample_class_filter) if sample_class_filter else None,
"expected": args.expected,
"original_total": original_count,
},
"duplicate_threshold": service.config.duplicate_threshold,
"invalid_negative_pairs": sum(1 for r in result_rows if r["invalid_negative_pair"]),
"invalid_boolean_samples": sum(1 for r in result_rows if r["invalid_boolean_sample"]),
"accuracy": metrics["accuracy"],
"precision": metrics["precision"],
"recall": metrics["recall"],
"f1": metrics["f1"],
"tp": metrics["tp"], "fp": metrics["fp"], "tn": metrics["tn"], "fn": metrics["fn"],
"valid_only": valid_metrics,
"out": str(out_path),
}
# 按 variant 分组,方便看各种变换的通过率
from collections import defaultdict
by_variant: dict[str, dict] = defaultdict(lambda: {"correct": 0, "total": 0})
for r in result_rows:
v = r["variant"] or "unknown"
by_variant[v]["total"] += 1
if r["correct"]:
by_variant[v]["correct"] += 1
summary["by_variant"] = {
v: {"accuracy": round(d["correct"] / d["total"], 4), "total": d["total"]}
for v, d in sorted(by_variant.items())
}
# 按 sample_class 分组
by_class: dict[str, dict] = defaultdict(lambda: {"correct": 0, "total": 0})
for r in result_rows:
sc = r.get("sample_class") or "unknown"
by_class[sc]["total"] += 1
if r["correct"]:
by_class[sc]["correct"] += 1
summary["by_sample_class"] = {
sc: {"accuracy": round(d["correct"] / d["total"], 4), "total": d["total"]}
for sc, d in sorted(by_class.items())
}
summary_path = out_path.with_suffix(".summary.json")
summary_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
print(json.dumps(summary, ensure_ascii=False, indent=2))
if __name__ == "__main__":
main()
"""生成曲去重评估测试集。
从音频目录随机抽取若干首参照歌入库,对每首用 ffmpeg 生成多个变换版本,
覆盖曲去重测试样本类型.md 中第一类(数字信号变换)和第三类(困难正样本)的可合成部分。
负样本从未入库的 holdout 歌曲生成,以匹配最终接口 duplicate true/false 语义。
用法:
python scripts/generate_composition_testset.py \
--audio-dir /Volumes/移动硬盘/lyric_audio_type11 \
--negative-audio-dir /Volumes/移动硬盘/composition_test \
--out-dir composition_dedup/composition_testset \
--num-songs 80 \
--num-negative-songs 40 \
--negative-variants \
--seed 123
输出:
reference.csv — 参照曲(原始文件),需提前入库
queries.csv — 查询曲,带 variant 和 expected 标注
"""
import argparse
import csv
import logging
import random
import subprocess
import sys
from pathlib import Path
try:
from tqdm import tqdm
except ImportError:
tqdm = None
def _tqdm(iterable, **kwargs):
if tqdm is not None:
return tqdm(iterable, **kwargs)
total = kwargs.get("total", None) or (len(iterable) if hasattr(iterable, "__len__") else None)
desc = kwargs.get("desc", "")
class _Simple:
def __init__(self):
self._i = 0
def __iter__(self):
for item in iterable:
self._i += 1
if total:
print(f"\r{desc}: {self._i}/{total}", end="", flush=True)
yield item
if total:
print()
return _Simple()
logger = logging.getLogger(__name__)
# --------------------------------------------------------------------------
# 第一类:数字信号变换
# --------------------------------------------------------------------------
DSP_VARIANTS: list[tuple[str, str]] = [
# Pitch Shift(±1、±2 半音)
("pitch_up1", "asetrate=22050*1.0595,aresample=22050"), # +1 半音
("pitch_up2", "asetrate=22050*1.1225,aresample=22050"), # +2 半音
("pitch_down1", "asetrate=22050*0.9439,aresample=22050"), # -1 半音
("pitch_down2", "asetrate=22050*0.8909,aresample=22050"), # -2 半音
# Tempo Shift
("tempo_slow", "atempo=0.90"), # 0.9x
("tempo_fast", "atempo=1.10"), # 1.1x
("tempo_faster","atempo=1.20"), # 1.2x
# EQ 变换
("lowpass", "lowpass=f=4000"), # 低通
("highpass", "highpass=f=800"), # 高通
("eq_mid", "equalizer=f=1000:width_type=o:width=2:g=-6"), # 中频衰减
# 压缩编码往返(编码为 mp3 再解回 wav,模拟有损压缩引入的失真)
("codec_320k", "acodec=libmp3lame,b:a=320k"),
("codec_128k", "acodec=libmp3lame,b:a=128k"),
]
# --------------------------------------------------------------------------
# 第三类:困难正样本(可合成部分)
# --------------------------------------------------------------------------
HARD_POSITIVE_VARIANTS: list[tuple[str, str]] = [
# 前奏删减:从 20% 处开始截取(模拟删前奏版本)
("trim_intro", None), # 特殊处理,用 -ss 参数
# 只保留副歌:截取中间 40%(模拟短视频截段)
("chorus_only", None), # 特殊处理,用 -ss + -t 参数
# Pitch + Tempo 叠加(模拟 Live 版同时有调整)
("pitch_up1_tempo_slow", "asetrate=22050*1.0595,aresample=22050,atempo=0.92"),
]
# 负样本变体只使用相对温和的处理,避免把负样本评估变成极端音质测试。
NEGATIVE_VARIANTS: list[tuple[str, str | None]] = [
("negative_lowpass", "lowpass=f=4000"),
("negative_codec_128k", "acodec=libmp3lame,b:a=128k"),
]
def _ffmpeg_variant(src: Path, dst: Path, af: str) -> bool:
"""普通 audio filter 变换。"""
# 压缩编码往返需要两步:先编码为 mp3,再解回 wav
if "acodec" in af:
import tempfile
with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp:
tmp_mp3 = Path(tmp.name)
ok1 = _run_ffmpeg([
"ffmpeg", "-y", "-i", str(src),
"-ar", "22050", "-ac", "1",
"-codec:a", "libmp3lame", "-b:a", af.split("b:a=")[1],
str(tmp_mp3),
])
if not ok1:
return False
ok2 = _run_ffmpeg([
"ffmpeg", "-y", "-i", str(tmp_mp3),
"-ar", "22050", "-ac", "1",
str(dst),
])
tmp_mp3.unlink(missing_ok=True)
return ok2
cmd = [
"ffmpeg", "-y", "-i", str(src),
"-af", af,
"-ar", "22050", "-ac", "1",
str(dst),
]
return _run_ffmpeg(cmd)
def _ffmpeg_trim(src: Path, dst: Path, start_ratio: float, duration_ratio: float) -> bool:
"""按相对位置截取片段。需要先探测时长。"""
duration = _probe_duration(src)
if duration is None:
return False
ss = duration * start_ratio
t = duration * duration_ratio
return _run_ffmpeg([
"ffmpeg", "-y", "-i", str(src),
"-ss", f"{ss:.3f}", "-t", f"{t:.3f}",
"-ar", "22050", "-ac", "1",
str(dst),
])
def _probe_duration(src: Path) -> float | None:
result = subprocess.run(
["ffprobe", "-v", "error", "-show_entries", "format=duration",
"-of", "default=noprint_wrappers=1:nokey=1", str(src)],
capture_output=True, text=True,
)
try:
return float(result.stdout.strip())
except ValueError:
return None
def _run_ffmpeg(cmd: list[str]) -> bool:
result = subprocess.run(cmd, capture_output=True)
return result.returncode == 0
def _song_id(path: Path) -> str:
return path.stem.split("_")[0]
def _discover_wavs(audio_dir: Path) -> list[Path]:
return [
f for f in sorted(audio_dir.rglob("*.wav"))
if f.is_file() and not f.name.startswith("._")
]
def main() -> None:
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
parser = argparse.ArgumentParser()
parser.add_argument("--audio-dir", required=True)
parser.add_argument(
"--negative-audio-dir",
default="/Volumes/移动硬盘/lyric_audio_type11",
help="负样本来源目录;会排除 --audio-dir 中已存在的 song_id",
)
parser.add_argument("--out-dir", required=True)
parser.add_argument("--num-songs", type=int, default=20, help="抽取歌曲数量")
parser.add_argument("--num-negative-songs", type=int, default=20, help="抽取未入库负样本歌曲数量")
parser.add_argument(
"--negative-variants",
action="store_true",
help="为负样本额外生成 codec/lowpass 变体",
)
parser.add_argument("--seed", type=int, default=42)
args = parser.parse_args()
audio_dir = Path(args.audio_dir)
negative_audio_dir = Path(args.negative_audio_dir)
out_dir = Path(args.out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
variants_dir = out_dir / "variants"
variants_dir.mkdir(exist_ok=True)
all_wavs = _discover_wavs(audio_dir)
negative_wavs = _discover_wavs(negative_audio_dir)
if len(negative_wavs) < args.num_negative_songs:
logger.error(
"负样本目录下只有 %d 个 wav,少于 --num-negative-songs = %d",
len(negative_wavs),
args.num_negative_songs,
)
sys.exit(1)
# 从参照目录中排除负样本目录已有的 song_id,避免参照曲与负样本重叠
negative_song_ids = {_song_id(wav) for wav in negative_wavs}
all_wavs = [wav for wav in all_wavs if _song_id(wav) not in negative_song_ids]
if len(all_wavs) < args.num_songs:
logger.error(
"参照目录排除负样本 song_id 后只有 %d 个 wav,少于 --num-songs = %d",
len(all_wavs),
args.num_songs,
)
sys.exit(1)
random.seed(args.seed)
selected = random.sample(all_wavs, args.num_songs)
negative_selected = random.sample(negative_wavs, args.num_negative_songs)
logger.info(
"已抽取 %d 首参照歌,%d 首未入库负样本歌(负样本来源: %s,已排除 %d 个负样本 song_id)",
len(selected),
len(negative_selected),
negative_audio_dir,
len(negative_song_ids),
)
ref_rows = []
query_rows = []
for wav in _tqdm(selected, desc="生成正样本变体", total=len(selected)):
song_id = _song_id(wav)
ref_rows.append({
"song_id": song_id,
"audio_path": str(wav),
"variant": "original",
})
# 第一类:DSP 变换
for variant_name, af in DSP_VARIANTS:
dst = variants_dir / f"{song_id}_{variant_name}.wav"
ok = _ffmpeg_variant(wav, dst, af)
if not ok:
logger.warning("DSP 变换失败,跳过: %s %s", wav.name, variant_name)
continue
query_rows.append({
"song_id": song_id,
"audio_path": str(dst),
"variant": variant_name,
"sample_class": "dsp",
"expected_song_id": song_id,
"expected": "duplicate",
})
# 第三类:困难正样本
for variant_name, af in HARD_POSITIVE_VARIANTS:
dst = variants_dir / f"{song_id}_{variant_name}.wav"
if variant_name == "trim_intro":
ok = _ffmpeg_trim(wav, dst, start_ratio=0.20, duration_ratio=0.80)
elif variant_name == "chorus_only":
ok = _ffmpeg_trim(wav, dst, start_ratio=0.30, duration_ratio=0.40)
else:
ok = _ffmpeg_variant(wav, dst, af)
if not ok:
logger.warning("困难正样本生成失败,跳过: %s %s", wav.name, variant_name)
continue
query_rows.append({
"song_id": song_id,
"audio_path": str(dst),
"variant": variant_name,
"sample_class": "hard_positive",
"expected_song_id": song_id,
"expected": "duplicate",
})
# Boolean 接口负样本:查询音频不能在 reference.csv 入库集合中。
# expected_song_id 留空,表示没有目标重复曲;评测只看最终 duplicate true/false。
for wav in _tqdm(negative_selected, desc="生成负样本变体", total=len(negative_selected)):
song_id = _song_id(wav)
query_rows.append({
"song_id": song_id,
"audio_path": str(wav),
"variant": "negative_original",
"sample_class": "negative",
"expected_song_id": "",
"expected": "not_duplicate",
})
if not args.negative_variants:
continue
for variant_name, af in NEGATIVE_VARIANTS:
dst = variants_dir / f"{song_id}_{variant_name}.wav"
ok = _ffmpeg_variant(wav, dst, af)
if not ok:
logger.warning("负样本变换失败,跳过: %s %s", wav.name, variant_name)
continue
query_rows.append({
"song_id": song_id,
"audio_path": str(dst),
"variant": variant_name,
"sample_class": "negative",
"expected_song_id": "",
"expected": "not_duplicate",
})
ref_path = out_dir / "reference.csv"
query_path = out_dir / "queries.csv"
fieldnames = ["song_id", "audio_path", "variant", "sample_class", "expected_song_id", "expected"]
with ref_path.open("w", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=["song_id", "audio_path", "variant"])
writer.writeheader()
writer.writerows(ref_rows)
with query_path.open("w", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
writer.writerows(query_rows)
pos = sum(1 for r in query_rows if r["expected"] == "duplicate")
neg = sum(1 for r in query_rows if r["expected"] == "not_duplicate")
logger.info("参照集: %s (%d 条)", ref_path, len(ref_rows))
logger.info("查询集: %s (%d 条,正样本 %d,负样本 %d)", query_path, len(query_rows), pos, neg)
# 按 sample_class 统计
from collections import Counter
by_class = Counter(r["sample_class"] for r in query_rows)
for cls, cnt in sorted(by_class.items()):
logger.info(" %-20s %d 条", cls, cnt)
if __name__ == "__main__":
main()
"""批量导入音频文件到 composition_feature 表。
用法:
python scripts/import_audio_composition.py \
--dsn "postgresql:///lyric_dedup" \
--audio-dir /Volumes/移动硬盘/composition_test \
--ext .wav
支持通过 --file-list 指定一个包含音频路径的文本文件(每行一个路径)。
"""
import argparse
import logging
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from tqdm import tqdm
from composition_dedup.service import CompositionConfig, CompositionDedupService
logger = logging.getLogger(__name__)
SUPPORTED_EXTENSIONS = {".mp3", ".wav", ".flac", ".ogg", ".m4a", ".aac", ".wma"}
def discover_audio_files(audio_dir: str | None, file_list: str | None, ext: str) -> list[tuple[str, str]]:
"""发现音频文件,返回 [(song_id, audio_path), ...] 列表。
优先使用 --file-list,否则扫描 --audio-dir 目录。
song_id 使用文件名的数字部分或路径的哈希值。
"""
results = []
if file_list:
with open(file_list, "r", encoding="utf-8") as f:
for line in f:
path = line.strip()
if not path:
continue
song_id = _extract_song_id(path)
results.append((song_id, path))
elif audio_dir:
audio_dir_path = Path(audio_dir)
for audio_file in sorted(audio_dir_path.rglob(f"*{ext}")):
if audio_file.is_file() and not audio_file.name.startswith("._"):
song_id = _extract_song_id(str(audio_file))
results.append((song_id, str(audio_file)))
else:
print("错误: 请指定 --audio-dir 或 --file-list")
sys.exit(1)
return results
def _extract_song_id(path: str) -> str:
"""从路径中提取 song_id。
优先取文件名第一段(下划线前),若为纯数字则使用,否则用路径哈希。
"""
name = Path(path).stem
prefix = name.split("_")[0]
if prefix.isdigit():
return prefix
import hashlib
return str(int(hashlib.md5(path.encode()).hexdigest()[:8], 16))
def main() -> None:
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
parser = argparse.ArgumentParser(description="批量导入音频文件到 composition_feature 表")
parser.add_argument("--dsn", required=True, help="PostgreSQL DSN 连接串")
parser.add_argument("--audio-dir", help="音频文件目录")
parser.add_argument("--file-list", help="音频文件路径列表文件")
parser.add_argument("--ext", default=".wav", help="音频文件扩展名(默认 .wav)")
parser.add_argument("--batch-size", type=int, default=10, help="批次大小(默认 10)")
parser.add_argument("--clear", action="store_true", help="导入前清空 composition_feature 和 dejavu_fingerprints 表数据(保留表结构)")
args = parser.parse_args()
config = CompositionConfig(dsn=args.dsn)
service = CompositionDedupService(config=config)
if args.clear:
import psycopg
with psycopg.connect(args.dsn) as conn:
with conn.cursor() as cur:
cur.execute("TRUNCATE TABLE composition_feature, dejavu_fingerprints")
conn.commit()
logger.info("已清空 composition_feature 和 dejavu_fingerprints 表")
audio_files = discover_audio_files(args.audio_dir, args.file_list, args.ext)
logger.info("发现 %d 个音频文件", len(audio_files))
success_count = 0
fail_count = 0
for start in tqdm(range(0, len(audio_files), args.batch_size), desc="导入进度"):
batch = audio_files[start:start + args.batch_size]
for song_id, audio_path in batch:
try:
service.ingest(song_id=int(song_id), audio_path=audio_path)
success_count += 1
except Exception as e:
logger.error("导入失败: song_id=%s, path=%s, error=%s", song_id, audio_path, e)
fail_count += 1
logger.info("导入完成: 成功 %d, 失败 %d", success_count, fail_count)
if __name__ == "__main__":
main()
......@@ -40,3 +40,33 @@ on lyric_lines (line_hash);
create index if not exists lyric_lines_lyric_id_idx
on lyric_lines (lyric_id);
create extension if not exists vector;
create table if not exists composition_feature (
id bigserial primary key,
song_id bigint not null unique,
feature_vector vector(1536) not null,
created_at timestamptz not null default now()
);
create index if not exists composition_feature_hnsw_idx
on composition_feature
using hnsw (feature_vector vector_cosine_ops)
with (m = 16, ef_construction = 64);
create table if not exists dejavu_fingerprints (
id bigserial primary key,
song_id bigint not null references composition_feature(song_id) on delete cascade,
hash bytea not null,
"offset" int not null
);
create index if not exists idx_fingerprints_hash
on dejavu_fingerprints (hash);
create index if not exists idx_fingerprints_hash_song_offset
on dejavu_fingerprints (hash, song_id, "offset");
create index if not exists idx_fingerprints_song_id
on dejavu_fingerprints (song_id);
......
"""作曲去重模块测试。
测试覆盖:
- Chromagram 提取
- 时间归一化输出维度
- Cosine 相似度计算
- 向量展开维度为 1536
"""
import os
import tempfile
import wave
import numpy as np
import pytest
from scipy.signal import resample
from composition_dedup.extractor import extract_chroma_feature, _normalize_audio_ffmpeg
from composition_dedup.similarity import (
CompositionSimilarity,
SimilarityDecision,
DUPLICATE_THRESHOLD,
SUSPECTED_THRESHOLD,
)
def _generate_test_wav(duration_sec: float = 1.0, sample_rate: int = 22050, frequency: float = 440.0) -> str:
"""生成测试用的 WAV 文件(正弦波)。
Args:
duration_sec: 持续时间(秒)。
sample_rate: 采样率。
frequency: 频率(Hz)。
Returns:
临时 WAV 文件路径。
"""
t = np.linspace(0, duration_sec, int(sample_rate * duration_sec), endpoint=False)
audio_data = (0.5 * np.sin(2 * np.pi * frequency * t)).astype(np.float32)
tmp_path = tempfile.mktemp(suffix=".wav")
with wave.open(tmp_path, "wb") as wf:
wf.setnchannels(1)
wf.setsampwidth(2) # 16-bit
wf.setframerate(sample_rate)
wf.writeframes((audio_data * 32767).astype(np.int16).tobytes())
return tmp_path
class TestChromaExtraction:
"""Chromagram 提取测试。"""
def test_extract_chroma_returns_1536_dim(self):
"""测试 Chromagram 提取返回 1536 维向量。"""
wav_path = _generate_test_wav(duration_sec=2.0, frequency=440.0)
try:
feature = extract_chroma_feature(wav_path)
assert isinstance(feature, np.ndarray)
assert feature.shape == (1536,), f"期望 (1536,), 实际 {feature.shape}"
assert feature.dtype == np.float32
finally:
if os.path.exists(wav_path):
os.remove(wav_path)
def test_extract_chroma_file_not_found(self):
"""测试不存在的音频文件抛出 FileNotFoundError。"""
with pytest.raises(FileNotFoundError):
extract_chroma_feature("/nonexistent/path/audio.mp3")
def test_extract_chroma_different_frequencies(self):
"""测试不同频率的音频产生不同特征。"""
wav_a = _generate_test_wav(duration_sec=2.0, frequency=440.0)
wav_b = _generate_test_wav(duration_sec=2.0, frequency=880.0)
try:
feature_a = extract_chroma_feature(wav_a)
feature_b = extract_chroma_feature(wav_b)
# 不同频率的音频特征不应完全相同
assert not np.allclose(feature_a, feature_b)
finally:
for path in [wav_a, wav_b]:
if os.path.exists(path):
os.remove(path)
def test_extract_chroma_same_audio_consistent(self):
"""测试同一音频多次提取结果一致。"""
wav_path = _generate_test_wav(duration_sec=1.0, frequency=440.0)
try:
feature_1 = extract_chroma_feature(wav_path)
feature_2 = extract_chroma_feature(wav_path)
np.testing.assert_array_almost_equal(feature_1, feature_2, decimal=5)
finally:
if os.path.exists(wav_path):
os.remove(wav_path)
class TestTimeNormalization:
"""时间归一化测试。"""
def test_resample_chroma_to_128_frames(self):
"""测试 Chromagram 时间归一化到 128 帧。"""
# 模拟不同长度的 Chromagram
for num_frames in [100, 256, 512, 1000, 2000]:
chroma = np.random.rand(12, num_frames).astype(np.float32)
if chroma.shape[1] != 128:
chroma = resample(chroma, 128, axis=1)
assert chroma.shape == (12, 128), f"帧数归一化失败: {chroma.shape}"
def test_flatten_to_1536(self):
"""测试展平后维度为 1536。"""
chroma = np.random.rand(12, 128).astype(np.float32)
feature = chroma.flatten()
assert feature.shape[0] == 12 * 128 == 1536
class TestCosineSimilarity:
"""Cosine 相似度计算测试。"""
def test_identical_vectors(self):
"""测试相同向量相似度为 1。"""
vec = np.random.rand(1536).astype(np.float32)
sim = CompositionSimilarity.cosine_similarity(vec, vec)
assert abs(sim - 1.0) < 1e-6
def test_orthogonal_vectors(self):
"""测试正交向量相似度接近 0。"""
vec_a = np.zeros(1536)
vec_a[0] = 1.0
vec_b = np.zeros(1536)
vec_b[1] = 1.0
sim = CompositionSimilarity.cosine_similarity(vec_a, vec_b)
assert abs(sim) < 1e-6
def test_zero_vector(self):
"""测试零向量返回 0 相似度。"""
vec_a = np.random.rand(1536).astype(np.float32)
vec_b = np.zeros(1536)
sim = CompositionSimilarity.cosine_similarity(vec_a, vec_b)
assert sim == 0.0
def test_similarity_range(self):
"""测试相似度值在 [0, 1] 范围内。"""
vec_a = np.random.rand(1536).astype(np.float32)
vec_b = np.random.rand(1536).astype(np.float32)
sim = CompositionSimilarity.cosine_similarity(vec_a, vec_b)
assert 0.0 <= sim <= 1.0
def test_classify_duplicate(self):
"""测试重复判定。"""
assert CompositionSimilarity.classify_similarity(0.96) == SimilarityDecision.DUPLICATE
assert CompositionSimilarity.classify_similarity(0.95) == SimilarityDecision.DUPLICATE
def test_classify_suspected(self):
"""测试疑似判定。"""
assert CompositionSimilarity.classify_similarity(0.94) == SimilarityDecision.SUSPECTED
assert CompositionSimilarity.classify_similarity(0.85) == SimilarityDecision.SUSPECTED
def test_classify_new(self):
"""测试非重复判定。"""
assert CompositionSimilarity.classify_similarity(0.84) == SimilarityDecision.NEW
assert CompositionSimilarity.classify_similarity(0.5) == SimilarityDecision.NEW
def test_compare_method(self):
"""测试 compare 方法同时返回相似度和判定。"""
vec = np.random.rand(1536).astype(np.float32)
sim, decision = CompositionSimilarity.compare(vec, vec)
assert abs(sim - 1.0) < 1e-6
assert decision == SimilarityDecision.DUPLICATE
class TestThresholds:
"""阈值常量测试。"""
def test_threshold_order(self):
"""测试阈值顺序正确。"""
assert DUPLICATE_THRESHOLD > SUSPECTED_THRESHOLD
def test_threshold_values(self):
"""测试阈值符合设计值。"""
assert DUPLICATE_THRESHOLD == 0.95
assert SUSPECTED_THRESHOLD == 0.85