Commit 8413944a 8413944ad675bb85c114f5f012b4257a140fef8e by 沈秋雨

添加曲结构去重

1 parent cdfa3a58
1 from .service import CompositionCandidate, CompositionConfig, CompositionDedupService
2 from .dejavu_fingerprinter import fingerprint_audio
3
4 __all__ = [
5 "CompositionCandidate",
6 "CompositionConfig",
7 "CompositionDedupService",
8 "fingerprint_audio",
9 ]
1 """Dejavu 风格的音频指纹生成。
2
3 基于 worldveil/dejavu 的指纹算法提取实现,不依赖 Dejavu 的数据库层。
4 使用 scipy.signal.spectrogram 替代已废弃的 matplotlib.mlab.specgram。
5
6 流程:
7 1. 音频标准化:ffmpeg 转 44100Hz / Mono / WAV
8 2. librosa 加载音频
9 3. 短时傅里叶变换(STFT)→ 对数频谱图
10 4. 2D 峰值检测:在频谱图中找局部极大值
11 5. 指纹哈希:对峰值对 (freq1, freq2, time_delta) 做 SHA1,取前 20 位
12 """
13
14 import hashlib
15 import logging
16 import os
17 import subprocess
18 import tempfile
19 from operator import itemgetter
20 from pathlib import Path
21
22 import librosa
23 import numpy as np
24 from scipy.ndimage import (
25 binary_erosion,
26 generate_binary_structure,
27 iterate_structure,
28 maximum_filter,
29 )
30 from scipy.signal import spectrogram
31
32 logger = logging.getLogger(__name__)
33
34
35 def _load_env_file() -> None:
36 """加载项目根目录 .env,不覆盖已存在的真实环境变量。"""
37 env_path = Path(__file__).resolve().parent.parent / ".env"
38 if not env_path.exists():
39 return
40 with env_path.open(encoding="utf-8") as file:
41 for raw_line in file:
42 line = raw_line.strip()
43 if not line or line.startswith("#") or "=" not in line:
44 continue
45 key, value = line.split("=", 1)
46 os.environ.setdefault(key.strip(), value.strip().strip('"').strip("'"))
47
48
49 _load_env_file()
50
51 # ===== 常量(可通过环境变量覆盖)=====
52
53 DEFAULT_FS = 44100
54 DEFAULT_WINDOW_SIZE = 4096
55 DEFAULT_OVERLAP_RATIO = float(os.environ.get("COMPOSITION_DEJAVU_OVERLAP_RATIO", "0.3"))
56 DEFAULT_FAN_VALUE = int(os.environ.get("COMPOSITION_DEJAVU_FAN_VALUE", "10"))
57 DEFAULT_AMP_MIN = float(os.environ.get("COMPOSITION_DEJAVU_AMP_MIN", "20"))
58 PEAK_NEIGHBORHOOD_SIZE = 20
59 MIN_HASH_TIME_DELTA = 0
60 MAX_HASH_TIME_DELTA = 200
61 PEAK_SORT = True
62 FINGERPRINT_REDUCTION = 20
63 MAX_DURATION_SEC = float(os.environ.get("COMPOSITION_DEJAVU_MAX_DURATION", "120")) # 0=不限制
64
65
66 def _normalize_audio(audio_path: str, max_duration: float = MAX_DURATION_SEC) -> tuple[np.ndarray, int]:
67 """将音频标准化为单声道 WAV 并加载为 numpy 数组。
68
69 使用 ffmpeg 先做重采样,再用 librosa 读取。
70 可选限制音频长度,超长音频只取前 N 秒。
71 """
72 with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
73 tmp_wav = tmp.name
74
75 try:
76 cmd = [
77 "ffmpeg",
78 "-y",
79 "-i", audio_path,
80 "-ar", str(DEFAULT_FS),
81 "-ac", "1",
82 "-f", "wav",
83 ]
84 if max_duration > 0:
85 cmd += ["-t", str(max_duration)]
86 cmd.append(tmp_wav)
87
88 result = subprocess.run(cmd, capture_output=True, text=True)
89 if result.returncode != 0:
90 raise RuntimeError(f"ffmpeg 转换失败: {result.stderr}")
91
92 y, sr = librosa.load(tmp_wav, sr=DEFAULT_FS, mono=True)
93 return y, sr
94 finally:
95 if os.path.exists(tmp_wav):
96 os.remove(tmp_wav)
97
98
99 def _specgram(samples: np.ndarray, fs: int, window_size: int, overlap_ratio: float):
100 """计算对数频谱图,替代 matplotlib.mlab.specgram。
101
102 Returns:
103 arr2D: shape (n_freq, n_time) 的对数频谱矩阵(dBFS 刻度)
104 """
105 noverlap = int(window_size * overlap_ratio)
106 window = np.hanning(window_size)
107
108 freqs, times, Sxx = spectrogram(
109 samples,
110 fs=fs,
111 window=window,
112 nperseg=window_size,
113 noverlap=noverlap,
114 )
115
116 # 转为对数尺度(dBFS,0 dB 为峰值参考)
117 # scipy.signal.spectrogram 返回 PSD,mlab.specgram 返回功率,两者量纲不同
118 # 统一转为相对于峰值的 dBFS 刻度,使强信号峰值落在 20~80 dB 范围
119 arr2D = 10 * np.log10(Sxx + 1e-10)
120 arr2D = arr2D - arr2D.max() # 归一化到峰值为 0 dBFS
121 arr2D = arr2D + 80 # 偏移使典型峰值落在 20~80 dB(与 mlab.specgram 一致)
122 arr2D[arr2D < -100] = -100 # 限幅
123 return arr2D
124
125
126 def _get_2d_peaks(arr2D: np.ndarray, amp_min: float = DEFAULT_AMP_MIN):
127 """在频谱图中检测 2D 局部极大值。
128
129 Returns:
130 (frequency_idx, time_idx): 峰值的频率和时间索引列表
131 """
132 struct = generate_binary_structure(2, 1)
133 neighborhood = iterate_structure(struct, PEAK_NEIGHBORHOOD_SIZE)
134
135 # 找局部极大值
136 local_max = maximum_filter(arr2D, footprint=neighborhood) == arr2D
137 background = arr2D == 0
138 eroded_background = binary_erosion(background, structure=neighborhood, border_value=1)
139
140 # 布尔掩码
141 detected_peaks = local_max ^ eroded_background
142
143 # 提取峰值
144 amps = arr2D[detected_peaks]
145 j, i = np.where(detected_peaks)
146
147 # 过滤低于阈值的峰值
148 peaks = list(zip(i, j, amps))
149 peaks_filtered = [x for x in peaks if x[2] > amp_min]
150
151 frequency_idx = [x[1] for x in peaks_filtered]
152 time_idx = [x[0] for x in peaks_filtered]
153
154 return frequency_idx, time_idx
155
156
157 def _generate_hashes(peaks: list[tuple[int, int]], fan_value: int = DEFAULT_FAN_VALUE):
158 """根据峰值对生成 SHA1 指纹哈希。
159
160 Args:
161 peaks: [(freq_idx, time_idx), ...] 列表
162 fan_value: 每个峰值与后续多少个峰值配对
163
164 Yields:
165 (hash_bytes, time_offset) 元组
166 """
167 if PEAK_SORT:
168 peaks.sort(key=itemgetter(1))
169
170 for i in range(len(peaks)):
171 for j in range(1, fan_value):
172 if i + j < len(peaks):
173 freq1 = peaks[i][0]
174 freq2 = peaks[i + j][0]
175 t1 = peaks[i][1]
176 t2 = peaks[i + j][1]
177 t_delta = t2 - t1
178
179 if MIN_HASH_TIME_DELTA <= t_delta <= MAX_HASH_TIME_DELTA:
180 h = hashlib.sha1(f"{freq1}|{freq2}|{t_delta}".encode())
181 yield (h.hexdigest()[:FINGERPRINT_REDUCTION].encode(), t1)
182
183
184 def fingerprint_audio(audio_path: str) -> tuple[str, list[tuple[bytes, int]]]:
185 """对音频文件生成 Dejavu 风格指纹。
186
187 Args:
188 audio_path: 音频文件路径。
189
190 Returns:
191 (file_sha1, fingerprints) 元组,
192 其中 fingerprints 是 [(hash_bytes, offset), ...] 列表。
193
194 Raises:
195 FileNotFoundError: 音频文件不存在。
196 RuntimeError: ffmpeg 转换失败。
197 """
198 if not os.path.isfile(audio_path):
199 raise FileNotFoundError(f"音频文件不存在: {audio_path}")
200
201 # 1. 标准化并加载音频(可选限制长度)
202 samples, fs = _normalize_audio(audio_path)
203
204 # 2. 计算文件 SHA1(用于标识)
205 file_sha1 = hashlib.sha1(samples.tobytes()).hexdigest()[:16]
206
207 # 3. 计算频谱图
208 arr2D = _specgram(samples, fs, DEFAULT_WINDOW_SIZE, DEFAULT_OVERLAP_RATIO)
209
210 # 4. 检测 2D 峰值
211 freq_idx, time_idx = _get_2d_peaks(arr2D)
212 peaks = list(zip(freq_idx, time_idx))
213
214 # 5. 生成指纹哈希
215 fingerprints = list(_generate_hashes(peaks))
216
217 logger.info("指纹生成完成: audio=%s, 指纹数=%d", audio_path, len(fingerprints))
218 return file_sha1, fingerprints
1 """Chromagram 特征提取。
2
3 流程:
4 1. 音频标准化:ffmpeg 转 22050Hz / Mono / WAV
5 2. librosa 加载音频
6 3. librosa.feature.chroma_cens() 提取 12×T Chromagram(CENS,对速度/音色鲁棒)
7 4. 主音对齐:将能量最大的音级滚至第 0 行,实现转调不变性
8 5. scipy.signal.resample(chroma, 128, axis=1) 时间归一化到 12×128
9 6. .flatten() 展开为 1536 维向量
10 """
11
12 import logging
13 import os
14 import subprocess
15 import tempfile
16
17 import librosa
18 import numpy as np
19 from scipy.signal import resample
20
21 logger = logging.getLogger(__name__)
22
23 # 目标采样率和时间帧数
24 TARGET_SR = 22050
25 TARGET_FRAMES = 128
26 VECTOR_DIM = 12 * TARGET_FRAMES # 1536
27
28
29 def _normalize_audio_ffmpeg(audio_path: str, output_path: str) -> None:
30 """使用 ffmpeg 将音频标准化为 22050Hz / Mono / WAV。"""
31 cmd = [
32 "ffmpeg",
33 "-y",
34 "-i", audio_path,
35 "-ar", str(TARGET_SR),
36 "-ac", "1",
37 "-f", "wav",
38 output_path,
39 ]
40 result = subprocess.run(
41 cmd,
42 capture_output=True,
43 text=True,
44 )
45 if result.returncode != 0:
46 raise RuntimeError(f"ffmpeg 转换失败: {result.stderr}")
47
48
49 def extract_chroma_feature(audio_path: str) -> np.ndarray:
50 """从音频文件提取 1536 维 Chromagram 特征向量。
51
52 Args:
53 audio_path: 音频文件路径。
54
55 Returns:
56 shape 为 (1536,) 的 numpy 数组。
57
58 Raises:
59 FileNotFoundError: 音频文件不存在。
60 RuntimeError: ffmpeg 转换失败。
61 """
62 if not os.path.isfile(audio_path):
63 raise FileNotFoundError(f"音频文件不存在: {audio_path}")
64
65 # 1. 音频标准化:ffmpeg 转 WAV
66 with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
67 tmp_wav = tmp.name
68
69 try:
70 _normalize_audio_ffmpeg(audio_path, tmp_wav)
71
72 # 2. librosa 加载音频
73 y, _sr = librosa.load(tmp_wav, sr=TARGET_SR, mono=True)
74
75 # 3. 提取 CENS Chromagram (12×T),对速度变化和音色具有更强鲁棒性
76 chroma = librosa.feature.chroma_cens(y=y, sr=TARGET_SR)
77
78 # 4. 主音对齐:将全局能量最大的音级循环滚至第 0 行,实现转调不变性
79 tonic = int(np.argmax(chroma.sum(axis=1)))
80 if tonic != 0:
81 chroma = np.roll(chroma, -tonic, axis=0)
82
83 # 5. 时间归一化到 12×128
84 if chroma.shape[1] != TARGET_FRAMES:
85 chroma = resample(chroma, TARGET_FRAMES, axis=1)
86
87 # 6. 展开为 1536 维向量
88 feature = chroma.flatten().astype(np.float32)
89
90 assert feature.shape == (VECTOR_DIM,), (
91 f"特征维度错误: 期望 {VECTOR_DIM}, 实际 {feature.shape}"
92 )
93
94 return feature
95 finally:
96 # 清理临时文件
97 if os.path.exists(tmp_wav):
98 os.remove(tmp_wav)
99
100
101 def extract_chroma_matrix(audio_path: str) -> np.ndarray:
102 """从音频文件提取 12×128 Chromagram 矩阵(未展平,供 DTW 精排使用)。
103
104 Returns:
105 shape 为 (12, 128) 的 numpy 数组,已做主音对齐。
106 """
107 feature = extract_chroma_feature(audio_path)
108 return feature.reshape(12, TARGET_FRAMES)
1 """Cosine 相似度计算与去重判定。"""
2
3 from enum import Enum
4
5 import numpy as np
6
7 DUPLICATE_THRESHOLD = 0.95
8 SUSPECTED_THRESHOLD = 0.85
9
10
11 class SimilarityDecision(Enum):
12 DUPLICATE = "duplicate"
13 SUSPECTED = "suspected"
14 NEW = "new"
15
16
17 class CompositionSimilarity:
18 @staticmethod
19 def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
20 norm_a = np.linalg.norm(a)
21 norm_b = np.linalg.norm(b)
22 if norm_a == 0.0 or norm_b == 0.0:
23 return 0.0
24 return float(np.dot(a, b) / (norm_a * norm_b))
25
26 @staticmethod
27 def classify_similarity(similarity: float) -> SimilarityDecision:
28 if similarity >= DUPLICATE_THRESHOLD:
29 return SimilarityDecision.DUPLICATE
30 if similarity >= SUSPECTED_THRESHOLD:
31 return SimilarityDecision.SUSPECTED
32 return SimilarityDecision.NEW
33
34 @staticmethod
35 def compare(a: np.ndarray, b: np.ndarray) -> tuple[float, SimilarityDecision]:
36 sim = CompositionSimilarity.cosine_similarity(a, b)
37 return sim, CompositionSimilarity.classify_similarity(sim)
1 """批量导入音频文件到 composition_feature 表。
2
3 用法:
4 python scripts/import_audio_composition.py \
5 --dsn "postgresql:///lyric_dedup" \
6 --audio-dir /Volumes/移动硬盘/composition_test \
7 --ext .wav
8
9 支持通过 --file-list 指定一个包含音频路径的文本文件(每行一个路径)。
10 """
11
12 import argparse
13 import logging
14 import sys
15 from pathlib import Path
16
17 sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
18
19 from tqdm import tqdm
20
21 from composition_dedup.service import CompositionConfig, CompositionDedupService
22
23 logger = logging.getLogger(__name__)
24
25 SUPPORTED_EXTENSIONS = {".mp3", ".wav", ".flac", ".ogg", ".m4a", ".aac", ".wma"}
26
27
28 def discover_audio_files(audio_dir: str | None, file_list: str | None, ext: str) -> list[tuple[str, str]]:
29 """发现音频文件,返回 [(song_id, audio_path), ...] 列表。
30
31 优先使用 --file-list,否则扫描 --audio-dir 目录。
32 song_id 使用文件名的数字部分或路径的哈希值。
33 """
34 results = []
35
36 if file_list:
37 with open(file_list, "r", encoding="utf-8") as f:
38 for line in f:
39 path = line.strip()
40 if not path:
41 continue
42 song_id = _extract_song_id(path)
43 results.append((song_id, path))
44 elif audio_dir:
45 audio_dir_path = Path(audio_dir)
46 for audio_file in sorted(audio_dir_path.rglob(f"*{ext}")):
47 if audio_file.is_file() and not audio_file.name.startswith("._"):
48 song_id = _extract_song_id(str(audio_file))
49 results.append((song_id, str(audio_file)))
50 else:
51 print("错误: 请指定 --audio-dir 或 --file-list")
52 sys.exit(1)
53
54 return results
55
56
57 def _extract_song_id(path: str) -> str:
58 """从路径中提取 song_id。
59 优先取文件名第一段(下划线前),若为纯数字则使用,否则用路径哈希。
60 """
61 name = Path(path).stem
62 prefix = name.split("_")[0]
63 if prefix.isdigit():
64 return prefix
65 import hashlib
66 return str(int(hashlib.md5(path.encode()).hexdigest()[:8], 16))
67
68
69 def main() -> None:
70 logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
71
72 parser = argparse.ArgumentParser(description="批量导入音频文件到 composition_feature 表")
73 parser.add_argument("--dsn", required=True, help="PostgreSQL DSN 连接串")
74 parser.add_argument("--audio-dir", help="音频文件目录")
75 parser.add_argument("--file-list", help="音频文件路径列表文件")
76 parser.add_argument("--ext", default=".wav", help="音频文件扩展名(默认 .wav)")
77 parser.add_argument("--batch-size", type=int, default=10, help="批次大小(默认 10)")
78 parser.add_argument("--clear", action="store_true", help="导入前清空 composition_feature 和 dejavu_fingerprints 表数据(保留表结构)")
79 args = parser.parse_args()
80
81 config = CompositionConfig(dsn=args.dsn)
82 service = CompositionDedupService(config=config)
83
84 if args.clear:
85 import psycopg
86 with psycopg.connect(args.dsn) as conn:
87 with conn.cursor() as cur:
88 cur.execute("TRUNCATE TABLE composition_feature, dejavu_fingerprints")
89 conn.commit()
90 logger.info("已清空 composition_feature 和 dejavu_fingerprints 表")
91
92 audio_files = discover_audio_files(args.audio_dir, args.file_list, args.ext)
93 logger.info("发现 %d 个音频文件", len(audio_files))
94
95 success_count = 0
96 fail_count = 0
97
98 for start in tqdm(range(0, len(audio_files), args.batch_size), desc="导入进度"):
99 batch = audio_files[start:start + args.batch_size]
100 for song_id, audio_path in batch:
101 try:
102 service.ingest(song_id=int(song_id), audio_path=audio_path)
103 success_count += 1
104 except Exception as e:
105 logger.error("导入失败: song_id=%s, path=%s, error=%s", song_id, audio_path, e)
106 fail_count += 1
107
108 logger.info("导入完成: 成功 %d, 失败 %d", success_count, fail_count)
109
110
111 if __name__ == "__main__":
112 main()
...@@ -40,3 +40,33 @@ on lyric_lines (line_hash); ...@@ -40,3 +40,33 @@ on lyric_lines (line_hash);
40 40
41 create index if not exists lyric_lines_lyric_id_idx 41 create index if not exists lyric_lines_lyric_id_idx
42 on lyric_lines (lyric_id); 42 on lyric_lines (lyric_id);
43
44 create extension if not exists vector;
45
46 create table if not exists composition_feature (
47 id bigserial primary key,
48 song_id bigint not null unique,
49 feature_vector vector(1536) not null,
50 created_at timestamptz not null default now()
51 );
52
53 create index if not exists composition_feature_hnsw_idx
54 on composition_feature
55 using hnsw (feature_vector vector_cosine_ops)
56 with (m = 16, ef_construction = 64);
57
58 create table if not exists dejavu_fingerprints (
59 id bigserial primary key,
60 song_id bigint not null references composition_feature(song_id) on delete cascade,
61 hash bytea not null,
62 "offset" int not null
63 );
64
65 create index if not exists idx_fingerprints_hash
66 on dejavu_fingerprints (hash);
67
68 create index if not exists idx_fingerprints_hash_song_offset
69 on dejavu_fingerprints (hash, song_id, "offset");
70
71 create index if not exists idx_fingerprints_song_id
72 on dejavu_fingerprints (song_id);
......
1 """作曲去重模块测试。
2
3 测试覆盖:
4 - Chromagram 提取
5 - 时间归一化输出维度
6 - Cosine 相似度计算
7 - 向量展开维度为 1536
8 """
9
10 import os
11 import tempfile
12 import wave
13
14 import numpy as np
15 import pytest
16 from scipy.signal import resample
17
18 from composition_dedup.extractor import extract_chroma_feature, _normalize_audio_ffmpeg
19 from composition_dedup.similarity import (
20 CompositionSimilarity,
21 SimilarityDecision,
22 DUPLICATE_THRESHOLD,
23 SUSPECTED_THRESHOLD,
24 )
25
26
27 def _generate_test_wav(duration_sec: float = 1.0, sample_rate: int = 22050, frequency: float = 440.0) -> str:
28 """生成测试用的 WAV 文件(正弦波)。
29
30 Args:
31 duration_sec: 持续时间(秒)。
32 sample_rate: 采样率。
33 frequency: 频率(Hz)。
34
35 Returns:
36 临时 WAV 文件路径。
37 """
38 t = np.linspace(0, duration_sec, int(sample_rate * duration_sec), endpoint=False)
39 audio_data = (0.5 * np.sin(2 * np.pi * frequency * t)).astype(np.float32)
40
41 tmp_path = tempfile.mktemp(suffix=".wav")
42 with wave.open(tmp_path, "wb") as wf:
43 wf.setnchannels(1)
44 wf.setsampwidth(2) # 16-bit
45 wf.setframerate(sample_rate)
46 wf.writeframes((audio_data * 32767).astype(np.int16).tobytes())
47
48 return tmp_path
49
50
51 class TestChromaExtraction:
52 """Chromagram 提取测试。"""
53
54 def test_extract_chroma_returns_1536_dim(self):
55 """测试 Chromagram 提取返回 1536 维向量。"""
56 wav_path = _generate_test_wav(duration_sec=2.0, frequency=440.0)
57 try:
58 feature = extract_chroma_feature(wav_path)
59 assert isinstance(feature, np.ndarray)
60 assert feature.shape == (1536,), f"期望 (1536,), 实际 {feature.shape}"
61 assert feature.dtype == np.float32
62 finally:
63 if os.path.exists(wav_path):
64 os.remove(wav_path)
65
66 def test_extract_chroma_file_not_found(self):
67 """测试不存在的音频文件抛出 FileNotFoundError。"""
68 with pytest.raises(FileNotFoundError):
69 extract_chroma_feature("/nonexistent/path/audio.mp3")
70
71 def test_extract_chroma_different_frequencies(self):
72 """测试不同频率的音频产生不同特征。"""
73 wav_a = _generate_test_wav(duration_sec=2.0, frequency=440.0)
74 wav_b = _generate_test_wav(duration_sec=2.0, frequency=880.0)
75 try:
76 feature_a = extract_chroma_feature(wav_a)
77 feature_b = extract_chroma_feature(wav_b)
78 # 不同频率的音频特征不应完全相同
79 assert not np.allclose(feature_a, feature_b)
80 finally:
81 for path in [wav_a, wav_b]:
82 if os.path.exists(path):
83 os.remove(path)
84
85 def test_extract_chroma_same_audio_consistent(self):
86 """测试同一音频多次提取结果一致。"""
87 wav_path = _generate_test_wav(duration_sec=1.0, frequency=440.0)
88 try:
89 feature_1 = extract_chroma_feature(wav_path)
90 feature_2 = extract_chroma_feature(wav_path)
91 np.testing.assert_array_almost_equal(feature_1, feature_2, decimal=5)
92 finally:
93 if os.path.exists(wav_path):
94 os.remove(wav_path)
95
96
97 class TestTimeNormalization:
98 """时间归一化测试。"""
99
100 def test_resample_chroma_to_128_frames(self):
101 """测试 Chromagram 时间归一化到 128 帧。"""
102 # 模拟不同长度的 Chromagram
103 for num_frames in [100, 256, 512, 1000, 2000]:
104 chroma = np.random.rand(12, num_frames).astype(np.float32)
105 if chroma.shape[1] != 128:
106 chroma = resample(chroma, 128, axis=1)
107 assert chroma.shape == (12, 128), f"帧数归一化失败: {chroma.shape}"
108
109 def test_flatten_to_1536(self):
110 """测试展平后维度为 1536。"""
111 chroma = np.random.rand(12, 128).astype(np.float32)
112 feature = chroma.flatten()
113 assert feature.shape[0] == 12 * 128 == 1536
114
115
116 class TestCosineSimilarity:
117 """Cosine 相似度计算测试。"""
118
119 def test_identical_vectors(self):
120 """测试相同向量相似度为 1。"""
121 vec = np.random.rand(1536).astype(np.float32)
122 sim = CompositionSimilarity.cosine_similarity(vec, vec)
123 assert abs(sim - 1.0) < 1e-6
124
125 def test_orthogonal_vectors(self):
126 """测试正交向量相似度接近 0。"""
127 vec_a = np.zeros(1536)
128 vec_a[0] = 1.0
129 vec_b = np.zeros(1536)
130 vec_b[1] = 1.0
131 sim = CompositionSimilarity.cosine_similarity(vec_a, vec_b)
132 assert abs(sim) < 1e-6
133
134 def test_zero_vector(self):
135 """测试零向量返回 0 相似度。"""
136 vec_a = np.random.rand(1536).astype(np.float32)
137 vec_b = np.zeros(1536)
138 sim = CompositionSimilarity.cosine_similarity(vec_a, vec_b)
139 assert sim == 0.0
140
141 def test_similarity_range(self):
142 """测试相似度值在 [0, 1] 范围内。"""
143 vec_a = np.random.rand(1536).astype(np.float32)
144 vec_b = np.random.rand(1536).astype(np.float32)
145 sim = CompositionSimilarity.cosine_similarity(vec_a, vec_b)
146 assert 0.0 <= sim <= 1.0
147
148 def test_classify_duplicate(self):
149 """测试重复判定。"""
150 assert CompositionSimilarity.classify_similarity(0.96) == SimilarityDecision.DUPLICATE
151 assert CompositionSimilarity.classify_similarity(0.95) == SimilarityDecision.DUPLICATE
152
153 def test_classify_suspected(self):
154 """测试疑似判定。"""
155 assert CompositionSimilarity.classify_similarity(0.94) == SimilarityDecision.SUSPECTED
156 assert CompositionSimilarity.classify_similarity(0.85) == SimilarityDecision.SUSPECTED
157
158 def test_classify_new(self):
159 """测试非重复判定。"""
160 assert CompositionSimilarity.classify_similarity(0.84) == SimilarityDecision.NEW
161 assert CompositionSimilarity.classify_similarity(0.5) == SimilarityDecision.NEW
162
163 def test_compare_method(self):
164 """测试 compare 方法同时返回相似度和判定。"""
165 vec = np.random.rand(1536).astype(np.float32)
166 sim, decision = CompositionSimilarity.compare(vec, vec)
167 assert abs(sim - 1.0) < 1e-6
168 assert decision == SimilarityDecision.DUPLICATE
169
170
171 class TestThresholds:
172 """阈值常量测试。"""
173
174 def test_threshold_order(self):
175 """测试阈值顺序正确。"""
176 assert DUPLICATE_THRESHOLD > SUSPECTED_THRESHOLD
177
178 def test_threshold_values(self):
179 """测试阈值符合设计值。"""
180 assert DUPLICATE_THRESHOLD == 0.95
181 assert SUSPECTED_THRESHOLD == 0.85