service.py 18 KB
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459
"""作曲去重服务(入库 + 查询)。

查询流程:
1. Chromagram 12路 + DTW 精排(覆盖绝大多数场景,约 1s)
   - 命中(≥ 阈值)→ 直接返回结果
2. Chromagram 未命中 → Dejavu 指纹兜底(处理 chorus_only / trim_intro 等片段场景)
   - 命中(≥ 阈值)→ 返回 duplicate(similarity=1.0)
"""

import logging
import os
import time
from dataclasses import dataclass, field

import numba
import numpy as np
import psycopg
from scipy.spatial.distance import cdist

from .extractor import (
    TARGET_FRAMES,
    extract_chroma_feature_from_samples,
    extract_chroma_matrix,
)
from .dejavu_fingerprinter import fingerprint_audio, fingerprint_from_samples, load_audio, QUERY_MAX_DURATION_SEC

logger = logging.getLogger(__name__)


def _env_bool(name: str, default: bool) -> bool:
    """Read a boolean flag from environment variable *name*.

    Matching is case-insensitive and ignores surrounding whitespace:
    ``1/true/yes/y/on`` -> ``True`` and ``0/false/no/n/off`` -> ``False``.
    An unset variable returns *default*. Any other value logs a warning and
    also falls back to *default*, mirroring how ``_env_int`` / ``_env_float``
    treat unparsable input (previously such values silently became ``False``).
    """
    value = os.getenv(name)
    if value is None:
        return default
    normalized = value.strip().lower()
    if normalized in {"1", "true", "yes", "y", "on"}:
        return True
    if normalized in {"0", "false", "no", "n", "off"}:
        return False
    logger.warning("环境变量 %s=%r 不是布尔值,使用默认值 %s", name, value, default)
    return default


def _env_int(name: str, default: int) -> int:
    """Return environment variable *name* parsed with ``int()``.

    Falls back to *default* when the variable is unset, or when its value is
    not a valid integer literal (a warning is logged in that case).
    """
    raw = os.environ.get(name)
    if raw is not None:
        try:
            return int(raw)
        except ValueError:
            logger.warning("环境变量 %s=%r 不是整数,使用默认值 %d", name, raw, default)
    return default


def _env_float(name: str, default: float) -> float:
    """Return environment variable *name* parsed with ``float()``.

    Falls back to *default* when the variable is unset, or when its value is
    not a valid float literal (a warning is logged in that case).
    """
    raw = os.environ.get(name)
    if raw is not None:
        try:
            return float(raw)
        except ValueError:
            logger.warning("环境变量 %s=%r 不是数字,使用默认值 %.4f", name, raw, default)
    return default


@dataclass
class CompositionCandidate:
    """去重候选结果。"""
    song_id: int
    similarity: float
    source: str = "chromagram"
    dejavu_aligned_count: int | None = None  # 仅 source=dejavu 或 dejavu fallback 未命中时记录


@dataclass
class _DejavuMatch:
    """Dejavu offset 对齐后的命中结果。"""
    song_id: int
    aligned_count: int
    total_collisions: int


@dataclass
class CompositionConfig:
    """Configuration for the composition dedup service.

    Several defaults come from ``COMPOSITION_*`` environment variables. Those
    default expressions are evaluated once, when this class body runs (module
    import time); changing the environment afterwards does not affect them.
    """
    dsn: str = "postgresql:///lyric_dedup"
    statement_timeout_ms: int = 30000
    dtw_rerank_top_k: int = 20  # number of cosine-recalled candidates re-ranked with DTW
    duplicate_threshold: float = _env_float("COMPOSITION_DUPLICATE_THRESHOLD", 0.85)
    # Chromagram extraction settings
    chroma_hop_length: int = _env_int("COMPOSITION_CHROMA_HOP_LENGTH", 512)
    chroma_win_len_smooth: int = _env_int("COMPOSITION_CHROMA_WIN_LEN_SMOOTH", 0)
    # Dejavu fingerprint matching settings
    dejavu_enabled: bool = _env_bool("COMPOSITION_DEJAVU_ENABLED", True)
    dejavu_match_threshold: int = _env_int("COMPOSITION_DEJAVU_MATCH_THRESHOLD", 20)

    def __post_init__(self) -> None:
        """Validate settings and resolve the automatic smoothing window.

        Raises:
            ValueError: if ``chroma_hop_length`` is not positive. Previously this
                surfaced as a ``ZeroDivisionError`` here (hop 0) or a meaningless
                window of 1 (negative hop).
        """
        if self.chroma_hop_length <= 0:
            raise ValueError(
                f"chroma_hop_length must be positive, got {self.chroma_hop_length}"
            )
        # 0 means "auto": the reference window is 41 frames at hop_length 512.
        # Scale it inversely with hop_length so the smoothing window keeps
        # covering the same duration (about 1 second, per the original design).
        if self.chroma_win_len_smooth == 0:
            self.chroma_win_len_smooth = max(1, round(41 * 512 / self.chroma_hop_length))


@dataclass
class CompositionDedupService:
    """Composition dedup service: feature ingestion plus similarity query.

    Query pipeline (see ``query``):
      1. Chromagram recall over all 12 cyclic pitch shifts (pgvector cosine
         distance), then DTW re-ranking of the best ``config.dtw_rerank_top_k``.
      2. Only when step 1 does not indicate a duplicate and Dejavu is enabled:
         Dejavu fingerprint fallback with offset alignment.
    """
    config: CompositionConfig
    _logger: logging.Logger = field(default_factory=lambda: logger, repr=False)

    # Per the pgvector documentation, an HNSW index scan returns at most
    # ``hnsw.ef_search`` rows (40 by default, valid range 1..1000). These are
    # plain (unannotated) class attributes, not dataclass fields.
    _HNSW_EF_SEARCH_DEFAULT = 40
    _HNSW_EF_SEARCH_MAX = 1000

    @staticmethod
    def _elapsed_ms(start: float) -> float:
        """Return milliseconds since ``start`` (a ``time.perf_counter()`` reading), rounded to 0.1."""
        return round((time.perf_counter() - start) * 1000, 1)

    def _apply_statement_timeout(self, cursor: psycopg.Cursor) -> None:
        """Set ``config.statement_timeout_ms`` as the session's statement_timeout.

        ``SET`` cannot take bind parameters, so the value is formatted into the
        SQL after coercion to ``int``, which keeps the statement injection-safe.
        """
        cursor.execute(f"SET statement_timeout = {int(self.config.statement_timeout_ms)}")

    def ingest(self, song_id: int, audio_path: str) -> np.ndarray:
        """Extract features from an audio file and store them in the database.

        The audio is decoded once with ``load_audio``; the same samples feed the
        Chromagram feature and, when ``config.dejavu_enabled``, the Dejavu
        fingerprints.

        Args:
            song_id: Song ID.
            audio_path: Path to the audio file.

        Returns:
            The extracted Chromagram feature vector.
        """
        # Decode once (44100 Hz per the original design note); the chromagram
        # path resamples in memory, so no second ffmpeg pass is needed.
        samples, sr = load_audio(audio_path)
        self._logger.info("音频解码完成: song_id=%s, audio=%s", song_id, audio_path)

        feature = extract_chroma_feature_from_samples(
            samples,
            sr,
            hop_length=self.config.chroma_hop_length,
            win_len_smooth=self.config.chroma_win_len_smooth,
        )
        self._logger.info("提取 Chromagram 特征完成: song_id=%s", song_id)

        with psycopg.connect(self.config.dsn) as conn:
            with conn.cursor() as cursor:
                # NOTE(review): ON CONFLICT DO NOTHING keeps an existing vector when a
                # song is re-ingested, whereas _dejavu_ingest replaces the song's
                # fingerprints. Confirm this asymmetry is intended.
                cursor.execute(
                    """
                    INSERT INTO composition_feature (song_id, feature_vector)
                    VALUES (%s, %s)
                    ON CONFLICT DO NOTHING
                    """,
                    (song_id, feature.tolist()),
                )
            conn.commit()

        self._logger.info("Chromagram 特征入库完成: song_id=%s", song_id)

        if self.config.dejavu_enabled:
            self._dejavu_ingest(song_id, audio_path, samples=samples, sr=sr)

        return feature

    def _dejavu_ingest(
        self,
        song_id: int,
        audio_path: str,
        *,
        samples: np.ndarray | None = None,
        sr: int | None = None,
    ) -> None:
        """Extract Dejavu fingerprints and write them to ``dejavu_fingerprints``.

        When already-decoded ``samples``/``sr`` are supplied they are reused
        (no ffmpeg); otherwise the file at ``audio_path`` is fingerprinted.
        An empty fingerprint list logs a warning and writes nothing.
        """
        if samples is not None and sr is not None:
            _, fingerprints = fingerprint_from_samples(samples, sr, compute_sha1=False)
            self._logger.info("Dejavu 指纹提取完成(共用解码): song_id=%s", song_id)
        else:
            _, fingerprints = fingerprint_audio(audio_path)

        if not fingerprints:
            self._logger.warning("Dejavu 指纹为空: song_id=%s, audio=%s", song_id, audio_path)
            return

        with psycopg.connect(self.config.dsn) as conn:
            with conn.cursor() as cursor:
                # Remove fingerprints left by a previous ingest so re-ingesting
                # the same song is idempotent.
                cursor.execute(
                    "DELETE FROM dejavu_fingerprints WHERE song_id = %s",
                    (song_id,),
                )
                # Batch insert.
                records = [(song_id, h, o) for h, o in fingerprints]
                cursor.executemany(
                    """
                    INSERT INTO dejavu_fingerprints (song_id, hash, "offset")
                    VALUES (%s, %s, %s)
                    """,
                    records,
                )
            conn.commit()

        self._logger.info("Dejavu 指纹入库完成: song_id=%s, 指纹数=%d", song_id, len(fingerprints))

    def query(
        self,
        audio_path: str,
        top_k: int = 100,
        timings: dict | None = None,
    ) -> list[CompositionCandidate]:
        """Extract query features and look up similar songs.

        Pipeline: 12-way Chromagram recall + DTW re-rank, then the Dejavu
        fingerprint fallback. The fallback audio is decoded only when the chroma
        stage did not already indicate a duplicate.

        Args:
            audio_path: Query audio file.
            top_k: Maximum number of chroma candidates to return. It is also the
                per-shift recall limit of the vector search.
            timings: If a dict is passed, per-stage durations in ms are written to
                it: chroma_extract_ms, db_cosine_ms, db_fetch_ms, dtw_ms,
                dejavu_decode_ms, dejavu_fingerprint_ms, dejavu_db_ms. Keys of
                Dejavu stages that did not run are not written.

        Returns:
            On a Dejavu hit, a single candidate (``source="dejavu"``,
            ``similarity=1.0``). Otherwise the chroma candidates: DTW re-ranked
            ones first (by DTW similarity), then the remaining cosine-only ones.
        """
        # 1. Chromagram + DTW (covers most cases). Uses the same decode path as
        #    ingestion so the chroma vectors line up.
        candidates = self._query_chroma(audio_path, top_k, timings=timings)

        # 2. Dejavu fallback for fragment cases (chorus_only / trim_intro ...);
        #    only reached when chroma missed, so most queries skip the 44100 Hz decode.
        if not self.config.dejavu_enabled or self.candidates_indicate_duplicate(candidates):
            return candidates

        start = time.perf_counter()
        samples, sr = load_audio(audio_path, max_duration=QUERY_MAX_DURATION_SEC)
        if timings is not None:
            timings["dejavu_decode_ms"] = self._elapsed_ms(start)

        match = self._dejavu_query(samples, sr, timings=timings)
        if match is None:
            return candidates

        if match.aligned_count >= self.config.dejavu_match_threshold:
            self._logger.info(
                "Dejavu 命中: song_id=%s, aligned_count=%d, total_collisions=%d, decision=duplicate",
                match.song_id,
                match.aligned_count,
                match.total_collisions,
            )
            return [CompositionCandidate(
                song_id=match.song_id,
                similarity=1.0,
                source="dejavu",
                dejavu_aligned_count=match.aligned_count,
            )]

        # Below threshold: attach aligned_count to the chroma top-1 so evaluation
        # scripts can record it.
        if candidates:
            candidates[0].dejavu_aligned_count = match.aligned_count
        return candidates

    def check(self, audio_path: str, top_k: int = 100) -> bool:
        """Return the final duplicate decision for ``audio_path``."""
        return self.candidates_indicate_duplicate(self.query(audio_path, top_k=top_k))

    def candidates_indicate_duplicate(self, candidates: list[CompositionCandidate]) -> bool:
        """Turn a candidate list into the final duplicate bool.

        The public interface only returns true/false, so the decision looks solely
        at whether the best (first) candidate reaches
        ``config.duplicate_threshold``. It does not depend on whether an
        evaluation set's expected_song_id appears in the top-k.
        """
        if not candidates:
            return False
        return candidates[0].similarity >= self.config.duplicate_threshold

    def _query_chroma(
        self,
        audio_path: str,
        top_k: int = 100,
        timings: dict | None = None,
    ) -> list[CompositionCandidate]:
        """Chromagram query: 12-way cyclic pitch alignment for recall, then DTW re-rank.

        Writes chroma_extract_ms, db_cosine_ms, db_fetch_ms and dtw_ms into
        ``timings`` when it is a dict.
        """
        start = time.perf_counter()
        chroma = extract_chroma_matrix(
            audio_path,
            hop_length=self.config.chroma_hop_length,
            win_len_smooth=self.config.chroma_win_len_smooth,
        )
        if timings is not None:
            timings["chroma_extract_ms"] = self._elapsed_ms(start)
        self._logger.info("提取 Chromagram 查询特征完成: audio=%s", audio_path)

        # 1. 12-way cyclic alignment: try every semitone rotation inside a single
        #    SQL statement and keep each song_id's highest cosine similarity.
        shift_vecs = [
            np.roll(chroma, -shift, axis=0).flatten().astype(np.float32).tolist()
            for shift in range(12)
        ]
        # VALUES expands the 12 shifted vectors; the LATERAL subquery runs one
        # HNSW index scan per shift.
        values_clause = ", ".join(f"({i}, %s::vector)" for i in range(12))
        sql = f"""
            WITH shifts(shift_id, vec) AS (
                VALUES {values_clause}
            ),
            candidates AS (
                SELECT
                    cf.song_id,
                    1 - (cf.feature_vector <=> s.vec) AS sim
                FROM shifts s
                CROSS JOIN LATERAL (
                    SELECT song_id, feature_vector
                    FROM composition_feature
                    ORDER BY feature_vector <=> s.vec
                    LIMIT %s
                ) cf
            )
            SELECT song_id, MAX(sim) AS similarity
            FROM candidates
            GROUP BY song_id
            ORDER BY similarity DESC
            LIMIT %s
        """
        # An HNSW scan yields at most hnsw.ef_search rows, so without raising it
        # the per-shift LIMIT top_k would be silently capped at the default 40.
        # Never lower it below the default, and stay within pgvector's maximum.
        ef_search = min(
            self._HNSW_EF_SEARCH_MAX,
            max(self._HNSW_EF_SEARCH_DEFAULT, int(top_k)),
        )

        start = time.perf_counter()
        with psycopg.connect(self.config.dsn) as conn:
            with conn.cursor() as cursor:
                self._apply_statement_timeout(cursor)
                cursor.execute(f"SET hnsw.ef_search = {ef_search}")
                cursor.execute(sql, (*shift_vecs, top_k, top_k))
                best = {int(song_id): float(sim) for song_id, sim in cursor.fetchall()}
            if timings is not None:
                timings["db_cosine_ms"] = self._elapsed_ms(start)

            # 2. Take the top dtw_rerank_top_k and fetch their raw vectors for DTW.
            top = sorted(best.items(), key=lambda item: item[1], reverse=True)
            rerank_ids = [sid for sid, _ in top[:self.config.dtw_rerank_top_k]]

            start = time.perf_counter()
            db_rows: list = []
            if rerank_ids:  # nothing recalled -> skip the pointless round trip
                with conn.cursor() as cursor:
                    cursor.execute(
                        "SELECT song_id, feature_vector::float4[] FROM composition_feature WHERE song_id = ANY(%s)",
                        (rerank_ids,),
                    )
                    db_rows = cursor.fetchall()
            if timings is not None:
                timings["db_fetch_ms"] = self._elapsed_ms(start)

        start = time.perf_counter()
        reranked = []
        for song_id, fv in db_rows:
            # Stored vectors are the flattened 12 x TARGET_FRAMES chroma matrix.
            cand_chroma = np.array(fv, dtype=np.float32).reshape(12, TARGET_FRAMES)
            dtw_sim = _best_shifted_dtw_similarity(
                chroma, cand_chroma, early_exit_threshold=self.config.duplicate_threshold
            )
            reranked.append(CompositionCandidate(song_id=int(song_id), similarity=dtw_sim))
        reranked.sort(key=lambda c: c.similarity, reverse=True)
        if timings is not None:
            timings["dtw_ms"] = self._elapsed_ms(start)

        rerank_id_set = {c.song_id for c in reranked}
        # ``top`` is already in descending cosine order; drop the re-ranked ids and
        # keep the cosine score for the rest.
        non_reranked = [
            CompositionCandidate(song_id=sid, similarity=sim)
            for sid, sim in top
            if sid not in rerank_id_set
        ]

        result = reranked + non_reranked
        top_summary = ", ".join(
            f"{candidate.song_id}:{candidate.similarity:.4f}"
            for candidate in result[:5]
        )
        self._logger.info(
            "Chromagram 查询完成: 返回 %d 个候选, top=%s",
            len(result),
            top_summary or "[]",
        )
        return result

    def _dejavu_query(
        self,
        samples: np.ndarray,
        sr: int,
        timings: dict | None = None,
    ) -> _DejavuMatch | None:
        """Dejavu fingerprint query: the song_id with the most offset-aligned collisions.

        Counting raw hash collisions alone would let common spectral peaks, noisy
        clips or random collisions in a large library short-circuit to
        similarity=1.0. The key Dejavu criterion is that, within one candidate
        song, many colliding hashes share the same db_offset - query_offset.

        The session gets ``config.statement_timeout_ms``, like the chroma query.

        Returns:
            The best match (no threshold applied), or None when there is no
            fingerprint or no collision at all.
        """
        start = time.perf_counter()
        _, fingerprints = fingerprint_from_samples(samples, sr, compute_sha1=False)
        if timings is not None:
            timings["dejavu_fingerprint_ms"] = self._elapsed_ms(start)
        if not fingerprints:
            return None

        hashes = [h for h, _ in fingerprints]
        offsets = [int(o) for _, o in fingerprints]

        start = time.perf_counter()
        with psycopg.connect(self.config.dsn) as conn:
            with conn.cursor() as cursor:
                self._apply_statement_timeout(cursor)
                # Find hash collisions first, then cluster each song's collisions
                # by offset delta.
                cursor.execute(
                    """
                    WITH query_fp(hash, query_offset) AS (
                        SELECT *
                        FROM unnest(%s::bytea[], %s::int[])
                    ),
                    aligned AS (
                        SELECT
                            db.song_id,
                            db."offset" - query_fp.query_offset AS offset_delta,
                            COUNT(*) AS aligned_count
                        FROM query_fp
                        JOIN dejavu_fingerprints db
                          ON db.hash = query_fp.hash
                        GROUP BY db.song_id, offset_delta
                    )
                    SELECT
                        song_id,
                        MAX(aligned_count) AS best_aligned_count,
                        SUM(aligned_count) AS total_collisions
                    FROM aligned
                    GROUP BY song_id
                    ORDER BY best_aligned_count DESC, total_collisions DESC
                    LIMIT 1
                    """,
                    (hashes, offsets),
                )
                row = cursor.fetchone()
        if timings is not None:
            timings["dejavu_db_ms"] = self._elapsed_ms(start)
        if row is None:
            return None
        sid, aligned_count, total_collisions = row
        return _DejavuMatch(
            song_id=int(sid),
            aligned_count=int(aligned_count),
            total_collisions=int(total_collisions),
        )


@numba.njit(cache=True)
def _dtw_dp(cost: np.ndarray) -> float:
    """Fill the DTW accumulated-cost table for ``cost`` and return its final cell.

    ``acc[r, c]`` is ``cost[r, c]`` plus the cheapest of the three predecessors
    (up, left, diagonal). Compiled with numba (``cache=True``).
    """
    rows, cols = cost.shape
    acc = np.empty((rows, cols))
    acc[0, 0] = cost[0, 0]
    # First column / first row each have a single predecessor, so the path is forced.
    for r in range(1, rows):
        acc[r, 0] = acc[r - 1, 0] + cost[r, 0]
    for c in range(1, cols):
        acc[0, c] = acc[0, c - 1] + cost[0, c]
    for r in range(1, rows):
        prev = acc[r - 1]
        cur = acc[r]
        for c in range(1, cols):
            cheapest = min(prev[c], cur[c - 1], prev[c - 1])
            cur[c] = cost[r, c] + cheapest
    return acc[rows - 1, cols - 1]


def _dtw_similarity(query: np.ndarray, candidate: np.ndarray) -> float:
    """DTW similarity between two chromagram matrices, mapped into (0, 1].

    Frames are columns (``.T`` gives one row per frame) and the frame-to-frame
    cost is Euclidean distance. The accumulated DTW cost is divided by the total
    frame count of both inputs and mapped through ``1 / (1 + d)``, so a smaller
    distance yields a higher similarity.
    """
    frame_costs = cdist(query.T, candidate.T, metric="euclidean")
    total_frames = frame_costs.shape[0] + frame_costs.shape[1]
    normalized_dist = _dtw_dp(frame_costs) / total_frames
    return float(1.0 / (1.0 + normalized_dist))


def _best_shifted_dtw_similarity(
    query: np.ndarray,
    candidate: np.ndarray,
    early_exit_threshold: float = 1.1,
) -> float:
    """Best DTW similarity of ``query`` vs ``candidate`` over 12 cyclic pitch shifts.

    Each shift rolls ``query`` by ``-shift`` along axis 0 before comparing.

    Args:
        query: Query chromagram (rolled along axis 0; expected to have 12 rows).
        candidate: Candidate chromagram.
        early_exit_threshold: Stop scanning shifts as soon as the best similarity
            so far reaches this value. Passing duplicate_threshold avoids spending
            work on candidates already known to be duplicates; the result may then
            be below the true maximum, which does not change the duplicate /
            non-duplicate decision. The default 1.1 is unreachable (similarities
            lie in (0, 1]), so early exit is effectively disabled.

    Returns:
        The highest similarity seen, starting from 0.0.
    """
    best_sim = 0.0
    for semitones in range(12):
        shifted = np.roll(query, -semitones, axis=0)
        best_sim = max(best_sim, _dtw_similarity(shifted, candidate))
        if best_sim >= early_exit_threshold:
            return best_sim
    return best_sim