添加曲结构去重
Showing
10 changed files
with
695 additions
and
0 deletions
composition_dedup/__init__.py
0 → 100644
composition_dedup/dejavu_fingerprinter.py
0 → 100644
| 1 | """Dejavu 风格的音频指纹生成。 | ||
| 2 | |||
| 3 | 基于 worldveil/dejavu 的指纹算法提取实现,不依赖 Dejavu 的数据库层。 | ||
| 4 | 使用 scipy.signal.spectrogram 替代已废弃的 matplotlib.mlab.specgram。 | ||
| 5 | |||
| 6 | 流程: | ||
| 7 | 1. 音频标准化:ffmpeg 转 44100Hz / Mono / WAV | ||
| 8 | 2. librosa 加载音频 | ||
| 9 | 3. 短时傅里叶变换(STFT)→ 对数频谱图 | ||
| 10 | 4. 2D 峰值检测:在频谱图中找局部极大值 | ||
| 11 | 5. 指纹哈希:对峰值对 (freq1, freq2, time_delta) 做 SHA1,取前 20 位 | ||
| 12 | """ | ||
| 13 | |||
| 14 | import hashlib | ||
| 15 | import logging | ||
| 16 | import os | ||
| 17 | import subprocess | ||
| 18 | import tempfile | ||
| 19 | from operator import itemgetter | ||
| 20 | from pathlib import Path | ||
| 21 | |||
| 22 | import librosa | ||
| 23 | import numpy as np | ||
| 24 | from scipy.ndimage import ( | ||
| 25 | binary_erosion, | ||
| 26 | generate_binary_structure, | ||
| 27 | iterate_structure, | ||
| 28 | maximum_filter, | ||
| 29 | ) | ||
| 30 | from scipy.signal import spectrogram | ||
| 31 | |||
| 32 | logger = logging.getLogger(__name__) | ||
| 33 | |||
| 34 | |||
| 35 | def _load_env_file() -> None: | ||
| 36 | """加载项目根目录 .env,不覆盖已存在的真实环境变量。""" | ||
| 37 | env_path = Path(__file__).resolve().parent.parent / ".env" | ||
| 38 | if not env_path.exists(): | ||
| 39 | return | ||
| 40 | with env_path.open(encoding="utf-8") as file: | ||
| 41 | for raw_line in file: | ||
| 42 | line = raw_line.strip() | ||
| 43 | if not line or line.startswith("#") or "=" not in line: | ||
| 44 | continue | ||
| 45 | key, value = line.split("=", 1) | ||
| 46 | os.environ.setdefault(key.strip(), value.strip().strip('"').strip("'")) | ||
| 47 | |||
| 48 | |||
| 49 | _load_env_file() | ||
| 50 | |||
# ===== Constants (some can be overridden through environment variables) =====

DEFAULT_FS = 44100  # sample rate (Hz) requested from both ffmpeg and librosa
DEFAULT_WINDOW_SIZE = 4096  # samples per spectrogram segment (passed as nperseg)
# Fraction of each window shared with the next one (noverlap = window * ratio).
DEFAULT_OVERLAP_RATIO = float(os.environ.get("COMPOSITION_DEJAVU_OVERLAP_RATIO", "0.3"))
# Upper bound (exclusive) of the pairing loop in _generate_hashes.
DEFAULT_FAN_VALUE = int(os.environ.get("COMPOSITION_DEJAVU_FAN_VALUE", "10"))
# Peaks whose spectrogram value is not strictly above this are discarded.
DEFAULT_AMP_MIN = float(os.environ.get("COMPOSITION_DEJAVU_AMP_MIN", "20"))
PEAK_NEIGHBORHOOD_SIZE = 20  # iterations used to grow the local-maximum footprint
MIN_HASH_TIME_DELTA = 0  # smallest frame distance allowed between paired peaks
MAX_HASH_TIME_DELTA = 200  # largest frame distance allowed between paired peaks
PEAK_SORT = True  # sort peaks by time index before pairing
FINGERPRINT_REDUCTION = 20  # number of SHA-1 hex characters kept per hash
# Seconds of audio kept by ffmpeg; a value of 0 (or less) disables the limit.
MAX_DURATION_SEC = float(os.environ.get("COMPOSITION_DEJAVU_MAX_DURATION", "120"))
| 64 | |||
| 65 | |||
def _normalize_audio(audio_path: str, max_duration: float = MAX_DURATION_SEC) -> tuple[np.ndarray, int]:
    """Transcode ``audio_path`` to a mono WAV with ffmpeg and load its samples.

    Args:
        audio_path: Path of the input audio file handed to ffmpeg.
        max_duration: When greater than 0, only the first ``max_duration``
            seconds are kept (ffmpeg ``-t``); otherwise the whole file is used.

    Returns:
        ``(samples, sample_rate)`` as returned by ``librosa.load``.

    Raises:
        RuntimeError: ffmpeg is not installed or exited with a non-zero status.
    """
    # Reserve a unique path and close the handle right away so ffmpeg can
    # overwrite the file (``-y``).
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        tmp_wav = tmp.name

    try:
        cmd = [
            "ffmpeg",
            "-y",
            "-i", audio_path,
            "-ar", str(DEFAULT_FS),
            "-ac", "1",
            "-f", "wav",
        ]
        if max_duration > 0:
            cmd += ["-t", str(max_duration)]
        cmd.append(tmp_wav)

        try:
            # errors="replace": ffmpeg echoes the input's metadata tags on
            # stderr, and those are not guaranteed to be valid UTF-8.  Strict
            # decoding would raise UnicodeDecodeError even when the conversion
            # itself succeeded.
            result = subprocess.run(cmd, capture_output=True, text=True, errors="replace")
        except FileNotFoundError as exc:
            # A missing ffmpeg binary must not look like a missing audio file.
            raise RuntimeError(f"ffmpeg 转换失败: {exc}") from exc
        if result.returncode != 0:
            raise RuntimeError(f"ffmpeg 转换失败: {result.stderr}")

        y, sr = librosa.load(tmp_wav, sr=DEFAULT_FS, mono=True)
        return y, sr
    finally:
        # Always remove the temporary WAV, whether or not loading succeeded.
        if os.path.exists(tmp_wav):
            os.remove(tmp_wav)
| 97 | |||
| 98 | |||
| 99 | def _specgram(samples: np.ndarray, fs: int, window_size: int, overlap_ratio: float): | ||
| 100 | """计算对数频谱图,替代 matplotlib.mlab.specgram。 | ||
| 101 | |||
| 102 | Returns: | ||
| 103 | arr2D: shape (n_freq, n_time) 的对数频谱矩阵(dBFS 刻度) | ||
| 104 | """ | ||
| 105 | noverlap = int(window_size * overlap_ratio) | ||
| 106 | window = np.hanning(window_size) | ||
| 107 | |||
| 108 | freqs, times, Sxx = spectrogram( | ||
| 109 | samples, | ||
| 110 | fs=fs, | ||
| 111 | window=window, | ||
| 112 | nperseg=window_size, | ||
| 113 | noverlap=noverlap, | ||
| 114 | ) | ||
| 115 | |||
| 116 | # 转为对数尺度(dBFS,0 dB 为峰值参考) | ||
| 117 | # scipy.signal.spectrogram 返回 PSD,mlab.specgram 返回功率,两者量纲不同 | ||
| 118 | # 统一转为相对于峰值的 dBFS 刻度,使强信号峰值落在 20~80 dB 范围 | ||
| 119 | arr2D = 10 * np.log10(Sxx + 1e-10) | ||
| 120 | arr2D = arr2D - arr2D.max() # 归一化到峰值为 0 dBFS | ||
| 121 | arr2D = arr2D + 80 # 偏移使典型峰值落在 20~80 dB(与 mlab.specgram 一致) | ||
| 122 | arr2D[arr2D < -100] = -100 # 限幅 | ||
| 123 | return arr2D | ||
| 124 | |||
| 125 | |||
def _get_2d_peaks(arr2D: np.ndarray, amp_min: float = DEFAULT_AMP_MIN):
    """Locate 2-D local maxima of a spectrogram.

    Args:
        arr2D: Spectrogram indexed as ``[frequency, time]``.
        amp_min: Peaks must be strictly greater than this value to be kept.

    Returns:
        ``(frequency_idx, time_idx)``: two parallel lists holding the row
        (frequency) and column (time) index of every retained peak.
    """
    # Footprint: the basic 2-D cross structure grown PEAK_NEIGHBORHOOD_SIZE times.
    footprint = iterate_structure(generate_binary_structure(2, 1), PEAK_NEIGHBORHOOD_SIZE)

    # A cell is a local maximum when it equals the maximum of its neighbourhood.
    is_local_max = maximum_filter(arr2D, footprint=footprint) == arr2D

    # Cells that are exactly 0 are treated as background and XOR-ed out.
    eroded_background = binary_erosion(arr2D == 0, structure=footprint, border_value=1)
    peak_mask = is_local_max ^ eroded_background

    rows, cols = np.where(peak_mask)
    frequency_idx = []
    time_idx = []
    for freq_bin, time_bin, amplitude in zip(rows, cols, arr2D[peak_mask]):
        if amplitude > amp_min:
            frequency_idx.append(freq_bin)
            time_idx.append(time_bin)

    return frequency_idx, time_idx
| 155 | |||
| 156 | |||
def _generate_hashes(peaks: list[tuple[int, int]], fan_value: int = DEFAULT_FAN_VALUE):
    """Yield truncated SHA-1 fingerprints for pairs of spectrogram peaks.

    Each peak (the anchor) is paired with the peaks at positions ``+1`` up to
    ``+(fan_value - 1)`` after it.  A pair is hashed only when its time
    distance lies within ``[MIN_HASH_TIME_DELTA, MAX_HASH_TIME_DELTA]``.

    Args:
        peaks: ``[(freq_idx, time_idx), ...]``.  The caller's list is not
            modified.
        fan_value: Exclusive upper bound of the pairing offset, so each anchor
            is paired with at most ``fan_value - 1`` later peaks.

    Yields:
        ``(hash_bytes, time_offset)`` where ``hash_bytes`` is the first
        ``FINGERPRINT_REDUCTION`` hex characters of
        ``sha1("freq1|freq2|t_delta")`` encoded as bytes, and ``time_offset``
        is the anchor's time index.
    """
    # sorted() returns a new list, so the argument is left untouched; both
    # sorted() and list.sort() are stable, so the resulting order is the same.
    ordered = sorted(peaks, key=itemgetter(1)) if PEAK_SORT else peaks
    total = len(ordered)

    for i in range(total):
        anchor = ordered[i]
        freq1 = anchor[0]
        t1 = anchor[1]
        for j in range(1, fan_value):
            if i + j >= total:
                break  # no further partners exist for this anchor
            partner = ordered[i + j]
            freq2 = partner[0]
            t_delta = partner[1] - t1

            if MIN_HASH_TIME_DELTA <= t_delta <= MAX_HASH_TIME_DELTA:
                h = hashlib.sha1(f"{freq1}|{freq2}|{t_delta}".encode())
                yield (h.hexdigest()[:FINGERPRINT_REDUCTION].encode(), t1)
| 182 | |||
| 183 | |||
def fingerprint_audio(audio_path: str) -> tuple[str, list[tuple[bytes, int]]]:
    """Produce Dejavu-style fingerprints for one audio file.

    Args:
        audio_path: Path of the audio file.

    Returns:
        ``(file_sha1, fingerprints)``.  ``file_sha1`` is the first 16 hex
        characters of the SHA-1 of the decoded sample buffer (not of the file
        bytes on disk).  ``fingerprints`` is a list of
        ``(hash_bytes, offset)`` tuples.

    Raises:
        FileNotFoundError: ``audio_path`` is not an existing file.
        RuntimeError: The ffmpeg conversion failed.
    """
    if not os.path.isfile(audio_path):
        raise FileNotFoundError(f"音频文件不存在: {audio_path}")

    # Decode to normalised mono samples (length-limited by the module default).
    samples, fs = _normalize_audio(audio_path)

    # Identifier derived from the decoded samples.
    file_sha1 = hashlib.sha1(samples.tobytes()).hexdigest()[:16]

    # Spectrogram -> peak constellation -> pairwise hashes.
    spectrum = _specgram(samples, fs, DEFAULT_WINDOW_SIZE, DEFAULT_OVERLAP_RATIO)
    peaks = list(zip(*_get_2d_peaks(spectrum)))
    fingerprints = [*_generate_hashes(peaks)]

    logger.info("指纹生成完成: audio=%s, 指纹数=%d", audio_path, len(fingerprints))
    return file_sha1, fingerprints
composition_dedup/extractor.py
0 → 100644
| 1 | """Chromagram 特征提取。 | ||
| 2 | |||
| 3 | 流程: | ||
| 4 | 1. 音频标准化:ffmpeg 转 22050Hz / Mono / WAV | ||
| 5 | 2. librosa 加载音频 | ||
| 6 | 3. librosa.feature.chroma_cens() 提取 12×T Chromagram(CENS,对速度/音色鲁棒) | ||
| 7 | 4. 主音对齐:将能量最大的音级滚至第 0 行,实现转调不变性 | ||
| 8 | 5. scipy.signal.resample(chroma, 128, axis=1) 时间归一化到 12×128 | ||
| 9 | 6. .flatten() 展开为 1536 维向量 | ||
| 10 | """ | ||
| 11 | |||
| 12 | import logging | ||
| 13 | import os | ||
| 14 | import subprocess | ||
| 15 | import tempfile | ||
| 16 | |||
| 17 | import librosa | ||
| 18 | import numpy as np | ||
| 19 | from scipy.signal import resample | ||
| 20 | |||
| 21 | logger = logging.getLogger(__name__) | ||
| 22 | |||
| 23 | # 目标采样率和时间帧数 | ||
| 24 | TARGET_SR = 22050 | ||
| 25 | TARGET_FRAMES = 128 | ||
| 26 | VECTOR_DIM = 12 * TARGET_FRAMES # 1536 | ||
| 27 | |||
| 28 | |||
def _normalize_audio_ffmpeg(audio_path: str, output_path: str) -> None:
    """Transcode ``audio_path`` to a ``TARGET_SR`` Hz mono WAV at ``output_path``.

    Args:
        audio_path: Input audio file handed to ffmpeg.
        output_path: Destination path; an existing file is overwritten (``-y``).

    Raises:
        RuntimeError: ffmpeg is not installed or exited with a non-zero status.
    """
    cmd = [
        "ffmpeg",
        "-y",
        "-i", audio_path,
        "-ar", str(TARGET_SR),
        "-ac", "1",
        "-f", "wav",
        output_path,
    ]
    try:
        # errors="replace": ffmpeg echoes the input's metadata tags on stderr,
        # and those are not guaranteed to be valid UTF-8.  Strict decoding
        # would raise UnicodeDecodeError even when the conversion succeeded.
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            errors="replace",
        )
    except FileNotFoundError as exc:
        # A missing ffmpeg binary must not look like a missing audio file.
        raise RuntimeError(f"ffmpeg 转换失败: {exc}") from exc
    if result.returncode != 0:
        raise RuntimeError(f"ffmpeg 转换失败: {result.stderr}")
| 47 | |||
| 48 | |||
def extract_chroma_feature(audio_path: str) -> np.ndarray:
    """Turn an audio file into a flat, fixed-length CENS chroma vector.

    Args:
        audio_path: Path of the audio file.

    Returns:
        A float32 array of shape ``(VECTOR_DIM,)``: 12 pitch classes times
        ``TARGET_FRAMES`` time frames, flattened row by row.

    Raises:
        FileNotFoundError: ``audio_path`` is not an existing file.
        RuntimeError: The ffmpeg conversion failed.
    """
    if not os.path.isfile(audio_path):
        raise FileNotFoundError(f"音频文件不存在: {audio_path}")

    # Reserve a temporary WAV path for ffmpeg to write into.
    handle = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    handle.close()
    tmp_wav = handle.name

    try:
        # Steps 1-2: normalise with ffmpeg, then decode with librosa.
        _normalize_audio_ffmpeg(audio_path, tmp_wav)
        signal, _ = librosa.load(tmp_wav, sr=TARGET_SR, mono=True)

        # Step 3: 12 x T CENS chromagram.
        chroma = librosa.feature.chroma_cens(y=signal, sr=TARGET_SR)

        # Step 4: rotate the pitch class with the largest summed energy to
        # row 0 so that transposed versions line up.
        dominant = int(np.argmax(chroma.sum(axis=1)))
        chroma = np.roll(chroma, -dominant, axis=0)

        # Step 5: resample the time axis to exactly TARGET_FRAMES columns.
        if chroma.shape[1] != TARGET_FRAMES:
            chroma = resample(chroma, TARGET_FRAMES, axis=1)

        # Step 6: flatten to one float32 vector.
        feature = chroma.reshape(-1).astype(np.float32)

        assert feature.shape == (VECTOR_DIM,), (
            f"特征维度错误: 期望 {VECTOR_DIM}, 实际 {feature.shape}"
        )
        return feature
    finally:
        # Remove the temporary WAV on every exit path.
        if os.path.exists(tmp_wav):
            os.remove(tmp_wav)
| 99 | |||
| 100 | |||
def extract_chroma_matrix(audio_path: str) -> np.ndarray:
    """Return the chroma feature as an unflattened ``(12, TARGET_FRAMES)`` matrix.

    This is the output of ``extract_chroma_feature`` reshaped back to two
    dimensions (intended for DTW re-ranking).

    Returns:
        Array of shape ``(12, TARGET_FRAMES)``.
    """
    flat = extract_chroma_feature(audio_path)
    return np.reshape(flat, (12, TARGET_FRAMES))
composition_dedup/service.py
0 → 100644
This diff is collapsed.
Click to expand it.
composition_dedup/similarity.py
0 → 100644
| 1 | """Cosine 相似度计算与去重判定。""" | ||
| 2 | |||
| 3 | from enum import Enum | ||
| 4 | |||
| 5 | import numpy as np | ||
| 6 | |||
# Similarity at or above this value is classified as a duplicate.
DUPLICATE_THRESHOLD = 0.95
# Similarity at or above this value (but below the duplicate threshold) is
# classified as a suspected duplicate.
SUSPECTED_THRESHOLD = 0.85


class SimilarityDecision(Enum):
    """Outcome of comparing two composition feature vectors."""

    DUPLICATE = "duplicate"
    SUSPECTED = "suspected"
    NEW = "new"


class CompositionSimilarity:
    """Cosine-similarity helpers for composition de-duplication."""

    @staticmethod
    def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
        """Return the cosine similarity of two 1-D vectors.

        The computation runs in float64 and the result is clipped to
        ``[-1.0, 1.0]`` so that rounding error on float32 inputs can never
        produce a value outside the mathematical range.  If either vector has
        zero norm the result is ``0.0``.  NaN inputs still yield NaN.
        """
        vec_a = np.asarray(a, dtype=np.float64)
        vec_b = np.asarray(b, dtype=np.float64)
        norm_a = np.linalg.norm(vec_a)
        norm_b = np.linalg.norm(vec_b)
        if norm_a == 0.0 or norm_b == 0.0:
            return 0.0
        cosine = np.dot(vec_a, vec_b) / (norm_a * norm_b)
        # np.clip (unlike the built-in min/max) propagates NaN unchanged.
        return float(np.clip(cosine, -1.0, 1.0))

    @staticmethod
    def classify_similarity(similarity: float) -> SimilarityDecision:
        """Map a similarity score to a decision using the module thresholds."""
        if similarity >= DUPLICATE_THRESHOLD:
            return SimilarityDecision.DUPLICATE
        if similarity >= SUSPECTED_THRESHOLD:
            return SimilarityDecision.SUSPECTED
        return SimilarityDecision.NEW

    @staticmethod
    def compare(a: np.ndarray, b: np.ndarray) -> tuple[float, SimilarityDecision]:
        """Return ``(similarity, decision)`` for two feature vectors."""
        sim = CompositionSimilarity.cosine_similarity(a, b)
        return sim, CompositionSimilarity.classify_similarity(sim)
scripts/evaluate_composition.py
0 → 100644
This diff is collapsed.
Click to expand it.
scripts/generate_composition_testset.py
0 → 100644
This diff is collapsed.
Click to expand it.
scripts/import_audio_composition.py
0 → 100644
| 1 | """批量导入音频文件到 composition_feature 表。 | ||
| 2 | |||
| 3 | 用法: | ||
| 4 | python scripts/import_audio_composition.py \ | ||
| 5 | --dsn "postgresql:///lyric_dedup" \ | ||
| 6 | --audio-dir /Volumes/移动硬盘/composition_test \ | ||
| 7 | --ext .wav | ||
| 8 | |||
| 9 | 支持通过 --file-list 指定一个包含音频路径的文本文件(每行一个路径)。 | ||
| 10 | """ | ||
| 11 | |||
| 12 | import argparse | ||
| 13 | import logging | ||
| 14 | import sys | ||
| 15 | from pathlib import Path | ||
| 16 | |||
| 17 | sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) | ||
| 18 | |||
| 19 | from tqdm import tqdm | ||
| 20 | |||
| 21 | from composition_dedup.service import CompositionConfig, CompositionDedupService | ||
| 22 | |||
| 23 | logger = logging.getLogger(__name__) | ||
| 24 | |||
| 25 | SUPPORTED_EXTENSIONS = {".mp3", ".wav", ".flac", ".ogg", ".m4a", ".aac", ".wma"} | ||
| 26 | |||
| 27 | |||
def discover_audio_files(audio_dir: str | None, file_list: str | None, ext: str) -> list[tuple[str, str]]:
    """Collect ``(song_id, audio_path)`` pairs to import.

    ``file_list`` takes precedence: every non-blank line of that file is used
    as a path.  Otherwise ``audio_dir`` is scanned recursively for ``*{ext}``
    files, in sorted order, skipping names that start with ``._``.  If neither
    source is given, an error is printed and the process exits with status 1.

    The ``song_id`` of each path comes from ``_extract_song_id``.
    """
    if file_list:
        with open(file_list, "r", encoding="utf-8") as f:
            stripped = (line.strip() for line in f)
            return [(_extract_song_id(entry), entry) for entry in stripped if entry]

    if audio_dir:
        return [
            (_extract_song_id(str(candidate)), str(candidate))
            for candidate in sorted(Path(audio_dir).rglob(f"*{ext}"))
            if candidate.is_file() and not candidate.name.startswith("._")
        ]

    print("错误: 请指定 --audio-dir 或 --file-list")
    sys.exit(1)
| 55 | |||
| 56 | |||
| 57 | def _extract_song_id(path: str) -> str: | ||
| 58 | """从路径中提取 song_id。 | ||
| 59 | 优先取文件名第一段(下划线前),若为纯数字则使用,否则用路径哈希。 | ||
| 60 | """ | ||
| 61 | name = Path(path).stem | ||
| 62 | prefix = name.split("_")[0] | ||
| 63 | if prefix.isdigit(): | ||
| 64 | return prefix | ||
| 65 | import hashlib | ||
| 66 | return str(int(hashlib.md5(path.encode()).hexdigest()[:8], 16)) | ||
| 67 | |||
| 68 | |||
def main() -> None:
    """Command-line entry point: ingest audio files through the dedup service.

    Parses the arguments, resolves the list of audio files, optionally
    truncates the ``composition_feature`` and ``dejavu_fingerprints`` tables
    (``--clear``), then calls ``service.ingest`` for each file and logs how
    many succeeded and how many failed.
    """
    logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")

    parser = argparse.ArgumentParser(description="批量导入音频文件到 composition_feature 表")
    parser.add_argument("--dsn", required=True, help="PostgreSQL DSN 连接串")
    parser.add_argument("--audio-dir", help="音频文件目录")
    parser.add_argument("--file-list", help="音频文件路径列表文件")
    parser.add_argument("--ext", default=".wav", help="音频文件扩展名(默认 .wav)")
    parser.add_argument("--batch-size", type=int, default=10, help="批次大小(默认 10)")
    parser.add_argument("--clear", action="store_true", help="导入前清空 composition_feature 和 dejavu_fingerprints 表数据(保留表结构)")
    args = parser.parse_args()

    # range() below would raise on a zero step, so reject bad sizes up front.
    if args.batch_size < 1:
        parser.error("--batch-size 必须为正整数")

    # Resolve the inputs before touching the database.  discover_audio_files
    # exits when neither source is given, and that must happen before --clear
    # has had a chance to wipe existing data.
    audio_files = discover_audio_files(args.audio_dir, args.file_list, args.ext)
    logger.info("发现 %d 个音频文件", len(audio_files))

    config = CompositionConfig(dsn=args.dsn)
    service = CompositionDedupService(config=config)

    if args.clear:
        import psycopg
        with psycopg.connect(args.dsn) as conn:
            with conn.cursor() as cur:
                cur.execute("TRUNCATE TABLE composition_feature, dejavu_fingerprints")
            conn.commit()
        logger.info("已清空 composition_feature 和 dejavu_fingerprints 表")

    success_count = 0
    fail_count = 0

    # Batching only sets the granularity of the progress bar; files are still
    # ingested one at a time.
    for start in tqdm(range(0, len(audio_files), args.batch_size), desc="导入进度"):
        batch = audio_files[start:start + args.batch_size]
        for song_id, audio_path in batch:
            try:
                service.ingest(song_id=int(song_id), audio_path=audio_path)
                success_count += 1
            except Exception as e:
                # Best effort: one bad file must not abort the whole import.
                logger.error("导入失败: song_id=%s, path=%s, error=%s", song_id, audio_path, e)
                fail_count += 1

    logger.info("导入完成: 成功 %d, 失败 %d", success_count, fail_count)
| 109 | |||
| 110 | |||
| 111 | if __name__ == "__main__": | ||
| 112 | main() |
| ... | @@ -40,3 +40,33 @@ on lyric_lines (line_hash); | ... | @@ -40,3 +40,33 @@ on lyric_lines (line_hash); |
| 40 | 40 | ||
| 41 | create index if not exists lyric_lines_lyric_id_idx | 41 | create index if not exists lyric_lines_lyric_id_idx |
| 42 | on lyric_lines (lyric_id); | 42 | on lyric_lines (lyric_id); |
| 43 | |||
-- The vector column type and the hnsw index method below require this extension.
create extension if not exists vector;

-- One row per song: the flattened chroma feature vector (1536 dimensions).
create table if not exists composition_feature (
id bigserial primary key,
song_id bigint not null unique,
feature_vector vector(1536) not null,
created_at timestamptz not null default now()
);

-- Approximate nearest-neighbour index using cosine distance.
create index if not exists composition_feature_hnsw_idx
on composition_feature
using hnsw (feature_vector vector_cosine_ops)
with (m = 16, ef_construction = 64);

-- One row per fingerprint hash; rows are removed with their parent song.
create table if not exists dejavu_fingerprints (
id bigserial primary key,
song_id bigint not null references composition_feature(song_id) on delete cascade,
hash bytea not null,
"offset" int not null
);

-- NOTE(review): (hash) is a leading prefix of the composite index below, so
-- this single-column index may be redundant; confirm before dropping it.
create index if not exists idx_fingerprints_hash
on dejavu_fingerprints (hash);

create index if not exists idx_fingerprints_hash_song_offset
on dejavu_fingerprints (hash, song_id, "offset");

create index if not exists idx_fingerprints_song_id
on dejavu_fingerprints (song_id);
tests/test_composition_dedup.py
0 → 100644
| 1 | """作曲去重模块测试。 | ||
| 2 | |||
| 3 | 测试覆盖: | ||
| 4 | - Chromagram 提取 | ||
| 5 | - 时间归一化输出维度 | ||
| 6 | - Cosine 相似度计算 | ||
| 7 | - 向量展开维度为 1536 | ||
| 8 | """ | ||
| 9 | |||
| 10 | import os | ||
| 11 | import tempfile | ||
| 12 | import wave | ||
| 13 | |||
| 14 | import numpy as np | ||
| 15 | import pytest | ||
| 16 | from scipy.signal import resample | ||
| 17 | |||
| 18 | from composition_dedup.extractor import extract_chroma_feature, _normalize_audio_ffmpeg | ||
| 19 | from composition_dedup.similarity import ( | ||
| 20 | CompositionSimilarity, | ||
| 21 | SimilarityDecision, | ||
| 22 | DUPLICATE_THRESHOLD, | ||
| 23 | SUSPECTED_THRESHOLD, | ||
| 24 | ) | ||
| 25 | |||
| 26 | |||
| 27 | def _generate_test_wav(duration_sec: float = 1.0, sample_rate: int = 22050, frequency: float = 440.0) -> str: | ||
| 28 | """生成测试用的 WAV 文件(正弦波)。 | ||
| 29 | |||
| 30 | Args: | ||
| 31 | duration_sec: 持续时间(秒)。 | ||
| 32 | sample_rate: 采样率。 | ||
| 33 | frequency: 频率(Hz)。 | ||
| 34 | |||
| 35 | Returns: | ||
| 36 | 临时 WAV 文件路径。 | ||
| 37 | """ | ||
| 38 | t = np.linspace(0, duration_sec, int(sample_rate * duration_sec), endpoint=False) | ||
| 39 | audio_data = (0.5 * np.sin(2 * np.pi * frequency * t)).astype(np.float32) | ||
| 40 | |||
| 41 | tmp_path = tempfile.mktemp(suffix=".wav") | ||
| 42 | with wave.open(tmp_path, "wb") as wf: | ||
| 43 | wf.setnchannels(1) | ||
| 44 | wf.setsampwidth(2) # 16-bit | ||
| 45 | wf.setframerate(sample_rate) | ||
| 46 | wf.writeframes((audio_data * 32767).astype(np.int16).tobytes()) | ||
| 47 | |||
| 48 | return tmp_path | ||
| 49 | |||
| 50 | |||
class TestChromaExtraction:
    """Tests for ``extract_chroma_feature``."""

    def test_extract_chroma_returns_1536_dim(self):
        """A 2-second tone yields a float32 vector of shape (1536,)."""
        tone_path = _generate_test_wav(duration_sec=2.0, frequency=440.0)
        try:
            vector = extract_chroma_feature(tone_path)
            assert isinstance(vector, np.ndarray)
            assert vector.shape == (1536,), f"期望 (1536,), 实际 {vector.shape}"
            assert vector.dtype == np.float32
        finally:
            if os.path.exists(tone_path):
                os.remove(tone_path)

    def test_extract_chroma_file_not_found(self):
        """A path that does not exist raises FileNotFoundError."""
        with pytest.raises(FileNotFoundError):
            extract_chroma_feature("/nonexistent/path/audio.mp3")

    def test_extract_chroma_different_frequencies(self):
        """Tones at different frequencies must not give identical features."""
        # NOTE(review): 440 Hz and 880 Hz are one octave apart and therefore
        # share a pitch class, so this only detects small numeric differences;
        # consider inputs that differ in melodic content.
        paths = [
            _generate_test_wav(duration_sec=2.0, frequency=440.0),
            _generate_test_wav(duration_sec=2.0, frequency=880.0),
        ]
        try:
            low_vector = extract_chroma_feature(paths[0])
            high_vector = extract_chroma_feature(paths[1])
            assert not np.allclose(low_vector, high_vector)
        finally:
            for leftover in paths:
                if os.path.exists(leftover):
                    os.remove(leftover)

    def test_extract_chroma_same_audio_consistent(self):
        """Extracting twice from the same file gives matching vectors."""
        tone_path = _generate_test_wav(duration_sec=1.0, frequency=440.0)
        try:
            first = extract_chroma_feature(tone_path)
            second = extract_chroma_feature(tone_path)
            np.testing.assert_array_almost_equal(first, second, decimal=5)
        finally:
            if os.path.exists(tone_path):
                os.remove(tone_path)
| 95 | |||
| 96 | |||
class TestTimeNormalization:
    """Tests for the time-axis normalisation step."""

    def test_resample_chroma_to_128_frames(self):
        """Chromagrams of various lengths resample to 12 x 128."""
        for num_frames in (100, 256, 512, 1000, 2000):
            matrix = np.random.rand(12, num_frames).astype(np.float32)
            if matrix.shape[1] != 128:
                matrix = resample(matrix, 128, axis=1)
            assert matrix.shape == (12, 128), f"帧数归一化失败: {matrix.shape}"

    def test_flatten_to_1536(self):
        """Flattening a 12 x 128 matrix gives 1536 elements."""
        flat = np.random.rand(12, 128).astype(np.float32).flatten()
        assert flat.shape[0] == 12 * 128 == 1536
| 114 | |||
| 115 | |||
class TestCosineSimilarity:
    """Tests for cosine similarity and decision classification."""

    def test_identical_vectors(self):
        """A vector compared with itself scores 1."""
        sample = np.random.rand(1536).astype(np.float32)
        score = CompositionSimilarity.cosine_similarity(sample, sample)
        assert abs(score - 1.0) < 1e-6

    def test_orthogonal_vectors(self):
        """Orthogonal unit vectors score approximately 0."""
        first_axis = np.zeros(1536)
        second_axis = np.zeros(1536)
        first_axis[0] = 1.0
        second_axis[1] = 1.0
        score = CompositionSimilarity.cosine_similarity(first_axis, second_axis)
        assert abs(score) < 1e-6

    def test_zero_vector(self):
        """Comparing against the zero vector scores exactly 0."""
        sample = np.random.rand(1536).astype(np.float32)
        nothing = np.zeros(1536)
        score = CompositionSimilarity.cosine_similarity(sample, nothing)
        assert score == 0.0

    def test_similarity_range(self):
        """Non-negative random vectors score within [0, 1]."""
        left = np.random.rand(1536).astype(np.float32)
        right = np.random.rand(1536).astype(np.float32)
        score = CompositionSimilarity.cosine_similarity(left, right)
        assert 0.0 <= score <= 1.0

    def test_classify_duplicate(self):
        """Scores at or above the duplicate threshold are DUPLICATE."""
        for score in (0.96, 0.95):
            assert CompositionSimilarity.classify_similarity(score) == SimilarityDecision.DUPLICATE

    def test_classify_suspected(self):
        """Scores between the two thresholds are SUSPECTED."""
        for score in (0.94, 0.85):
            assert CompositionSimilarity.classify_similarity(score) == SimilarityDecision.SUSPECTED

    def test_classify_new(self):
        """Scores below the suspected threshold are NEW."""
        for score in (0.84, 0.5):
            assert CompositionSimilarity.classify_similarity(score) == SimilarityDecision.NEW

    def test_compare_method(self):
        """``compare`` returns both the score and the decision."""
        sample = np.random.rand(1536).astype(np.float32)
        score, verdict = CompositionSimilarity.compare(sample, sample)
        assert abs(score - 1.0) < 1e-6
        assert verdict == SimilarityDecision.DUPLICATE
| 169 | |||
| 170 | |||
class TestThresholds:
    """Tests for the threshold constants."""

    def test_threshold_order(self):
        """The duplicate threshold is stricter than the suspected one."""
        assert SUSPECTED_THRESHOLD < DUPLICATE_THRESHOLD

    def test_threshold_values(self):
        """Both thresholds carry their designed values."""
        assert (DUPLICATE_THRESHOLD, SUSPECTED_THRESHOLD) == (0.95, 0.85)
-
Please register or sign in to post a comment