service.py
13.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
"""作曲去重服务(入库 + 查询)。
查询流程:
1. Dejavu 指纹匹配(毫秒级,子序列匹配,支持 chorus_only / trim_intro)
- 命中(≥ 阈值)→ 直接返回 duplicate(短路)
2. 未命中 → Chromagram 12路 + DTW(百毫秒级)
- 返回结果
"""
import logging
import os
from dataclasses import dataclass, field
import numpy as np
import psycopg
from scipy.spatial.distance import cdist
from .extractor import TARGET_FRAMES, extract_chroma_feature, extract_chroma_matrix
from .dejavu_fingerprinter import fingerprint_audio
logger = logging.getLogger(__name__)
def _env_bool(name: str, default: bool) -> bool:
    """Read a boolean flag from the environment.

    Args:
        name: Environment variable to look up.
        default: Value returned when the variable is not set at all.

    Returns:
        ``default`` when the variable is unset. Otherwise ``True`` only if the
        value, trimmed and lower-cased, is one of ``1``, ``true``, ``yes``,
        ``y`` or ``on``; any other text (including an empty string) yields
        ``False``.
    """
    raw = os.getenv(name)
    if raw is not None:
        normalized = raw.strip().lower()
        return normalized in {"1", "true", "yes", "y", "on"}
    return default
def _env_int(name: str, default: int) -> int:
    """Read an integer setting from the environment.

    Args:
        name: Environment variable to look up.
        default: Fallback used when the variable is unset or not an integer.

    Returns:
        The parsed integer, or ``default`` when the variable is missing or
        ``int()`` rejects its text (in which case a warning is logged).
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    try:
        parsed = int(raw)
    except ValueError:
        logger.warning("环境变量 %s=%r 不是整数,使用默认值 %d", name, raw, default)
    else:
        return parsed
    return default
def _env_float(name: str, default: float) -> float:
    """Read a floating-point setting from the environment.

    Args:
        name: Environment variable to look up.
        default: Fallback used when the variable is unset or not numeric.

    Returns:
        The value parsed by ``float()`` (which also accepts spellings such as
        ``nan`` and ``inf``), or ``default`` when the variable is missing or
        cannot be parsed (in which case a warning is logged).
    """
    text = os.getenv(name)
    if text is None:
        return default
    try:
        number = float(text)
    except ValueError:
        logger.warning("环境变量 %s=%r 不是数字,使用默认值 %.4f", name, text, default)
        number = default
    return number
@dataclass
class CompositionCandidate:
    """One candidate produced by a deduplication query."""

    # Identifier of the library song this candidate points at.
    song_id: int
    # Similarity score of that song against the queried audio.
    similarity: float
    # Matching stage that produced the candidate; the Dejavu short-circuit
    # path overrides the default with "dejavu".
    source: str = "chromagram"
@dataclass
class _DejavuMatch:
    """Best Dejavu hit for a query after offset alignment."""

    # Song whose fingerprints aligned best with the query.
    song_id: int
    # Largest number of hash collisions sharing one offset delta for the song.
    aligned_count: int
    # Hash collisions for the song summed over every offset delta.
    total_collisions: int
@dataclass
class CompositionConfig:
    """Configuration for the composition deduplication service.

    The environment-backed fields are resolved every time a config object is
    created (through ``default_factory``) rather than once when this module is
    imported, so variables that are loaded or changed after import are still
    honoured. Passing a value explicitly always wins over the environment.
    """

    # PostgreSQL connection string.
    dsn: str = "postgresql:///lyric_dedup"
    # Value (milliseconds) sent as ``SET statement_timeout`` before the
    # chromagram recall query.
    statement_timeout_ms: int = 30000
    # How many cosine-recall candidates get DTW re-ranking.
    dtw_rerank_top_k: int = 20
    # Minimum similarity of the best candidate for the query to count as a
    # duplicate (env: COMPOSITION_DUPLICATE_THRESHOLD, default 0.85).
    duplicate_threshold: float = field(
        default_factory=lambda: _env_float("COMPOSITION_DUPLICATE_THRESHOLD", 0.85)
    )

    # Dejavu fingerprint matching settings.
    # Whether the Dejavu stage is used at all
    # (env: COMPOSITION_DEJAVU_ENABLED, default on).
    dejavu_enabled: bool = field(
        default_factory=lambda: _env_bool("COMPOSITION_DEJAVU_ENABLED", True)
    )
    # Minimum offset-aligned collision count for a Dejavu hit
    # (env: COMPOSITION_DEJAVU_MATCH_THRESHOLD, default 20).
    dejavu_match_threshold: int = field(
        default_factory=lambda: _env_int("COMPOSITION_DEJAVU_MATCH_THRESHOLD", 20)
    )
@dataclass
class CompositionDedupService:
    """Composition deduplication service: feature ingestion and similarity lookup.

    Attributes:
        config: Database connection settings and matching thresholds.
    """

    config: CompositionConfig
    _logger: logging.Logger = field(default_factory=lambda: logger, repr=False)

    def ingest(self, song_id: int, audio_path: str) -> np.ndarray:
        """Extract features from an audio file and store them in the database.

        The chromagram feature vector is inserted into ``composition_feature``.
        When Dejavu is enabled, the audio is fingerprinted and stored as well.

        Args:
            song_id: Song identifier.
            audio_path: Path of the audio file.

        Returns:
            The extracted chromagram feature vector.
        """
        feature = extract_chroma_feature(audio_path)
        self._logger.info("提取 Chromagram 特征完成: song_id=%s, audio=%s", song_id, audio_path)

        with psycopg.connect(self.config.dsn) as conn:
            with conn.cursor() as cursor:
                # NOTE: on conflict the existing row is kept untouched, whereas
                # the Dejavu path below replaces the song's fingerprints.
                cursor.execute(
                    """
                    INSERT INTO composition_feature (song_id, feature_vector)
                    VALUES (%s, %s)
                    ON CONFLICT DO NOTHING
                    """,
                    (song_id, feature.tolist()),
                )
            conn.commit()
        self._logger.info("Chromagram 特征入库完成: song_id=%s", song_id)

        # Store Dejavu fingerprints alongside the chromagram feature.
        if self.config.dejavu_enabled:
            self._dejavu_ingest(song_id, audio_path)
        return feature

    def _dejavu_ingest(self, song_id: int, audio_path: str) -> None:
        """Fingerprint the audio and write it to ``dejavu_fingerprints``.

        Existing fingerprints of ``song_id`` are deleted first, so repeating
        the call replaces them instead of accumulating duplicates. Nothing is
        written (only a warning is logged) when no fingerprints are produced.

        Args:
            song_id: Song identifier.
            audio_path: Path of the audio file.
        """
        # The first element returned by the fingerprinter is not needed here.
        _, fingerprints = fingerprint_audio(audio_path)
        if not fingerprints:
            self._logger.warning("Dejavu 指纹为空: song_id=%s, audio=%s", song_id, audio_path)
            return

        records = [(song_id, h, o) for h, o in fingerprints]
        with psycopg.connect(self.config.dsn) as conn:
            with conn.cursor() as cursor:
                # Clear any leftover fingerprints first (idempotent write).
                cursor.execute(
                    "DELETE FROM dejavu_fingerprints WHERE song_id = %s",
                    (song_id,),
                )
                # Bulk insert.
                cursor.executemany(
                    """
                    INSERT INTO dejavu_fingerprints (song_id, hash, "offset")
                    VALUES (%s, %s, %s)
                    """,
                    records,
                )
            conn.commit()
        self._logger.info("Dejavu 指纹入库完成: song_id=%s, 指纹数=%d", song_id, len(fingerprints))

    def query(self, audio_path: str, top_k: int = 100) -> list[CompositionCandidate]:
        """Find library songs similar to the given audio.

        Flow: Dejavu fingerprint short-circuit, then 12-way shifted cosine
        recall, then DTW re-ranking.

        Args:
            audio_path: Path of the audio file to look up.
            top_k: Number of candidates requested from the chromagram recall.

        Returns:
            A single candidate with ``similarity=1.0`` and ``source="dejavu"``
            when the fingerprint stage hits; otherwise the chromagram result.
        """
        # 1. Try Dejavu fingerprint matching first (short-circuit).
        if self.config.dejavu_enabled:
            match = self._dejavu_query(audio_path)
            if match is not None:
                self._logger.info(
                    "Dejavu 命中: song_id=%s, aligned_count=%d, total_collisions=%d, decision=duplicate",
                    match.song_id,
                    match.aligned_count,
                    match.total_collisions,
                )
                return [CompositionCandidate(song_id=match.song_id, similarity=1.0, source="dejavu")]

        # 2. Dejavu missed or is disabled: use the chromagram + DTW pipeline.
        return self._query_chroma(audio_path, top_k)

    def check(self, audio_path: str, top_k: int = 100) -> bool:
        """Return whether the audio is considered a duplicate.

        Args:
            audio_path: Path of the audio file to look up.
            top_k: Forwarded to :meth:`query`.
        """
        return self.candidates_indicate_duplicate(self.query(audio_path, top_k=top_k))

    def candidates_indicate_duplicate(self, candidates: list[CompositionCandidate]) -> bool:
        """Reduce a candidate list to the final duplicate verdict.

        The public interface only answers true/false, so the verdict depends
        solely on whether the first (best) candidate reaches
        ``config.duplicate_threshold``. It does not depend on an evaluation
        set's expected song id appearing somewhere in the top-k.

        Args:
            candidates: Query result, best candidate first.

        Returns:
            ``False`` for an empty list, otherwise whether the first
            candidate's similarity is at least the configured threshold.
        """
        if not candidates:
            return False
        return candidates[0].similarity >= self.config.duplicate_threshold

    def _query_chroma(self, audio_path: str, top_k: int = 100) -> list[CompositionCandidate]:
        """Run the 12-way shifted chromagram recall followed by DTW re-ranking.

        Args:
            audio_path: Path of the audio file to look up.
            top_k: Per-shift recall size and overall recall limit.

        Returns:
            The DTW re-ranked candidates (best first) followed by the remaining
            recall candidates, which keep their cosine similarity.
        """
        chroma = extract_chroma_matrix(audio_path)
        self._logger.info("提取 Chromagram 查询特征完成: audio=%s", audio_path)

        # 1. 12-way circular alignment: enumerate all 12 semitone shifts,
        #    expand them inside one SQL statement and keep the best cosine
        #    similarity per song_id.
        shift_vecs = [
            np.roll(chroma, -shift, axis=0).flatten().astype(np.float32).tolist()
            for shift in range(12)
        ]
        # The VALUES list carries the 12 shifted vectors; the LATERAL subquery
        # runs one nearest-neighbour scan per shift.
        values_clause = ", ".join(f"({i}, %s::vector)" for i in range(12))
        sql = f"""
            WITH shifts(shift_id, vec) AS (
                VALUES {values_clause}
            ),
            candidates AS (
                SELECT
                    cf.song_id,
                    1 - (cf.feature_vector <=> s.vec) AS sim
                FROM shifts s
                CROSS JOIN LATERAL (
                    SELECT song_id, feature_vector
                    FROM composition_feature
                    ORDER BY feature_vector <=> s.vec
                    LIMIT %s
                ) cf
            )
            SELECT song_id, MAX(sim) AS similarity
            FROM candidates
            GROUP BY song_id
            ORDER BY similarity DESC
            LIMIT %s
        """

        rerank_k = self.config.dtw_rerank_top_k
        db_rows = []
        # Both statements share one connection; the second cursor has to be
        # opened while that connection is still open.
        with psycopg.connect(self.config.dsn) as conn:
            with conn.cursor() as cursor:
                cursor.execute(
                    f"SET statement_timeout = {int(self.config.statement_timeout_ms)}"
                )
                cursor.execute(sql, (*shift_vecs, top_k, top_k))
                best: dict[int, float] = {
                    int(song_id): float(sim) for song_id, sim in cursor.fetchall()
                }

            # 2. Take the best ``dtw_rerank_top_k`` recall hits and load their
            #    stored vectors for DTW re-ranking.
            top = sorted(best.items(), key=lambda x: x[1], reverse=True)
            rerank_ids = [sid for sid, _ in top[:rerank_k]]
            # Nothing recalled means nothing to fetch: skip the round trip.
            if rerank_ids:
                with conn.cursor() as cursor:
                    cursor.execute(
                        "SELECT song_id, feature_vector::float4[] FROM composition_feature WHERE song_id = ANY(%s)",
                        (rerank_ids,),
                    )
                    db_rows = cursor.fetchall()

        # DTW is CPU-bound, so it runs after the connection has been released.
        reranked = []
        for song_id, fv in db_rows:
            cand_chroma = np.array(fv, dtype=np.float32).reshape(12, TARGET_FRAMES)
            dtw_sim = _best_shifted_dtw_similarity(chroma, cand_chroma)
            reranked.append(CompositionCandidate(song_id=int(song_id), similarity=dtw_sim))
        reranked.sort(key=lambda c: c.similarity, reverse=True)

        rerank_id_set = {c.song_id for c in reranked}
        rest = [
            CompositionCandidate(song_id=sid, similarity=sim)
            for sid, sim in top[rerank_k:]
            if sid not in rerank_id_set
        ]
        result = reranked + rest

        top_summary = ", ".join(
            f"{candidate.song_id}:{candidate.similarity:.4f}"
            for candidate in result[:5]
        )
        self._logger.info(
            "Chromagram 查询完成: 返回 %d 个候选, top=%s",
            len(result),
            top_summary or "[]",
        )
        return result

    def _dejavu_query(self, audio_path: str) -> _DejavuMatch | None:
        """Look up the audio by Dejavu fingerprints.

        Counting raw hash collisions alone would let common spectral peaks,
        noise segments or random collisions in a large library short-circuit
        straight to ``similarity=1.0``. The deciding Dejavu criterion is that,
        within one candidate song, many colliding hashes share the same
        ``db_offset - query_offset`` time delta.

        Args:
            audio_path: Path of the audio file to look up.

        Returns:
            The song with the highest offset-aligned collision count when that
            count reaches ``config.dejavu_match_threshold``; ``None`` when the
            audio yields no fingerprints, nothing collides, or the count stays
            below the threshold.
        """
        # The first element returned by the fingerprinter is not needed here.
        _, fingerprints = fingerprint_audio(audio_path)
        if not fingerprints:
            return None

        hashes = [h for h, _ in fingerprints]
        offsets = [int(o) for _, o in fingerprints]

        with psycopg.connect(self.config.dsn) as conn:
            with conn.cursor() as cursor:
                # Find collisions by hash, then cluster them per song_id by
                # offset delta.
                cursor.execute(
                    """
                    WITH query_fp(hash, query_offset) AS (
                        SELECT *
                        FROM unnest(%s::bytea[], %s::int[])
                    ),
                    aligned AS (
                        SELECT
                            db.song_id,
                            db."offset" - query_fp.query_offset AS offset_delta,
                            COUNT(*) AS aligned_count
                        FROM query_fp
                        JOIN dejavu_fingerprints db
                            ON db.hash = query_fp.hash
                        GROUP BY db.song_id, offset_delta
                    )
                    SELECT
                        song_id,
                        MAX(aligned_count) AS best_aligned_count,
                        SUM(aligned_count) AS total_collisions
                    FROM aligned
                    GROUP BY song_id
                    ORDER BY best_aligned_count DESC, total_collisions DESC
                    LIMIT 1
                    """,
                    (hashes, offsets),
                )
                row = cursor.fetchone()

        if row is None:
            return None
        sid, aligned_count, total_collisions = row
        aligned_count = int(aligned_count)
        if aligned_count >= self.config.dejavu_match_threshold:
            return _DejavuMatch(
                song_id=int(sid),
                aligned_count=aligned_count,
                total_collisions=int(total_collisions),
            )
        return None
def _dtw_similarity(query: np.ndarray, candidate: np.ndarray) -> float:
    """Return a DTW-based similarity between two chromagram matrices.

    Columns of each matrix are treated as frames. Frames are compared with
    Euclidean distance and aligned by classic dynamic time warping (steps:
    down, right, diagonal). The accumulated cost of the full alignment is
    divided by the total number of frames and mapped through ``1 / (1 + d)``,
    so identical inputs score 1.0 and the score shrinks towards 0 as the
    inputs diverge.

    Args:
        query: Query chromagram, one column per frame.
        candidate: Candidate chromagram, one column per frame.

    Returns:
        The similarity as a Python float; 0.0 when either input is empty
        (there is nothing to align).
    """
    query = np.asarray(query)
    candidate = np.asarray(candidate)
    # Without this guard an input with no frames fails on ``cost[0, 0]``.
    if query.size == 0 or candidate.size == 0:
        return 0.0

    # Frame-to-frame Euclidean distance matrix.
    cost = cdist(query.T, candidate.T, metric="euclidean")
    n, m = cost.shape

    # Every cell is written below: the borders here, the interior in the loop.
    dp = np.empty((n, m))
    # Along the first column / first row only one predecessor exists, so the
    # accumulated cost is a running sum.
    dp[:, 0] = np.cumsum(cost[:, 0])
    dp[0, :] = np.cumsum(cost[0, :])
    for i in range(1, n):
        for j in range(1, m):
            dp[i, j] = cost[i, j] + min(dp[i - 1, j], dp[i, j - 1], dp[i - 1, j - 1])

    dtw_dist = dp[n - 1, m - 1] / (n + m)
    # Convert to a similarity: the smaller the distance, the higher the score.
    return float(1.0 / (1.0 + dtw_dist))
def _best_shifted_dtw_similarity(query: np.ndarray, candidate: np.ndarray) -> float:
    """Return the highest DTW similarity over all 12 circular pitch shifts.

    The query is rolled along its first axis by 0 to 11 positions and each
    rolled version is scored against the unchanged candidate.
    """
    scores = [
        _dtw_similarity(np.roll(query, -semitones, axis=0), candidate)
        for semitones in range(12)
    ]
    return max(scores)