# service.py 13.1 KB
"""作曲去重服务(入库 + 查询)。

查询流程:
1. Dejavu 指纹匹配(毫秒级,子序列匹配,支持 chorus_only / trim_intro)
   - 命中(≥ 阈值)→ 直接返回 duplicate(短路)
2. 未命中 → Chromagram 12路 + DTW(百毫秒级)
   - 返回结果
"""

import logging
import os
from dataclasses import dataclass, field

import numpy as np
import psycopg
from scipy.spatial.distance import cdist

from .extractor import TARGET_FRAMES, extract_chroma_feature, extract_chroma_matrix
from .dejavu_fingerprinter import fingerprint_audio

logger = logging.getLogger(__name__)


def _env_bool(name: str, default: bool) -> bool:
    """Read a boolean flag from the environment.

    Args:
        name: Environment variable name.
        default: Value returned when the variable is not set at all.

    Returns:
        ``default`` when the variable is unset; otherwise ``True`` only if the
        value, stripped and lower-cased, is one of ``1``, ``true``, ``yes``,
        ``y`` or ``on``. Any other value (including the empty string) is
        ``False``.
    """
    raw = os.environ.get(name)
    if raw is not None:
        normalized = raw.strip().lower()
        return normalized in ("1", "true", "yes", "y", "on")
    return default


def _env_int(name: str, default: int) -> int:
    """Read an integer setting from the environment.

    Args:
        name: Environment variable name.
        default: Value returned when the variable is unset or not an integer.

    Returns:
        The parsed integer, or ``default`` if the variable is missing or
        ``int()`` rejects its value (a warning is logged in the latter case).
    """
    raw = os.getenv(name)
    if raw is not None:
        try:
            return int(raw)
        except ValueError:
            # Malformed value: report it and fall through to the default.
            logger.warning("环境变量 %s=%r 不是整数,使用默认值 %d", name, raw, default)
    return default


def _env_float(name: str, default: float) -> float:
    """Read a floating-point setting from the environment.

    Args:
        name: Environment variable name.
        default: Value returned when the variable is unset or not a number.

    Returns:
        The parsed float, or ``default`` if the variable is missing or
        ``float()`` rejects its value (a warning is logged in the latter case).
    """
    raw = os.getenv(name)
    if raw is not None:
        try:
            return float(raw)
        except ValueError:
            # Malformed value: report it and fall through to the default.
            logger.warning("环境变量 %s=%r 不是数字,使用默认值 %.4f", name, raw, default)
    return default


@dataclass
class CompositionCandidate:
    """One deduplication candidate returned by a query."""
    # ID of the stored song this candidate refers to.
    song_id: int
    # Similarity score. The service uses 1.0 for a Dejavu fingerprint hit, a
    # DTW-based score for re-ranked Chromagram candidates, and the cosine
    # similarity for Chromagram candidates that were not re-ranked.
    similarity: float
    # Which matching stage produced the candidate; the service sets "dejavu"
    # for fingerprint hits and leaves the default for Chromagram results.
    source: str = "chromagram"


@dataclass
class _DejavuMatch:
    """Result of a Dejavu fingerprint lookup after offset alignment."""
    # ID of the best-matching stored song.
    song_id: int
    # Largest number of hash collisions for this song that share one
    # (db_offset - query_offset) delta.
    aligned_count: int
    # Total number of hash collisions for this song across all offset deltas.
    total_collisions: int


@dataclass
class CompositionConfig:
    """Configuration for the composition deduplication service.

    NOTE: the defaults built with ``_env_float`` / ``_env_bool`` / ``_env_int``
    are evaluated once, when this class body executes (i.e. at module import).
    Changing those environment variables afterwards does not affect the
    defaults of instances created later.
    """
    # Database connection string handed to psycopg.connect.
    dsn: str = "postgresql:///lyric_dedup"
    # Value (milliseconds) used for PostgreSQL's statement_timeout on query connections.
    statement_timeout_ms: int = 30000
    dtw_rerank_top_k: int = 20  # Number of cosine-recall candidates that get DTW re-ranking
    # Minimum similarity of the best candidate for a query to count as a
    # duplicate (env: COMPOSITION_DUPLICATE_THRESHOLD).
    duplicate_threshold: float = _env_float("COMPOSITION_DUPLICATE_THRESHOLD", 0.85)
    # Dejavu fingerprint matching settings
    # Whether fingerprints are stored at ingest and tried first at query time
    # (env: COMPOSITION_DEJAVU_ENABLED).
    dejavu_enabled: bool = _env_bool("COMPOSITION_DEJAVU_ENABLED", True)
    # Minimum offset-aligned collision count for a fingerprint hit
    # (env: COMPOSITION_DEJAVU_MATCH_THRESHOLD).
    dejavu_match_threshold: int = _env_int("COMPOSITION_DEJAVU_MATCH_THRESHOLD", 20)


@dataclass
class CompositionDedupService:
    """Composition deduplication service: feature ingestion plus similarity queries."""
    # Connection and matching settings read by the service's methods.
    config: CompositionConfig
    # Logger used for progress messages; defaults to this module's logger and
    # is excluded from the dataclass repr.
    _logger: logging.Logger = field(default_factory=lambda: logger, repr=False)

    def ingest(self, song_id: int, audio_path: str) -> np.ndarray:
        """Extract the Chromagram feature of an audio file and store it.

        The feature is inserted into ``composition_feature`` with
        ``ON CONFLICT DO NOTHING``, so a conflicting existing row is left
        untouched. When ``config.dejavu_enabled`` is set, Dejavu fingerprints
        for the same audio are stored afterwards via ``_dejavu_ingest``.

        Args:
            song_id: Song ID.
            audio_path: Path to the audio file.

        Returns:
            The extracted feature vector.
        """
        chroma_vector = extract_chroma_feature(audio_path)
        self._logger.info("提取 Chromagram 特征完成: song_id=%s, audio=%s", song_id, audio_path)

        with psycopg.connect(self.config.dsn) as connection:
            with connection.cursor() as cur:
                cur.execute(
                    """
                    INSERT INTO composition_feature (song_id, feature_vector)
                    VALUES (%s, %s)
                    ON CONFLICT DO NOTHING
                    """,
                    (song_id, chroma_vector.tolist()),
                )
            connection.commit()

        self._logger.info("Chromagram 特征入库完成: song_id=%s", song_id)

        if not self.config.dejavu_enabled:
            return chroma_vector

        # Fingerprints are stored alongside the Chromagram feature.
        self._dejavu_ingest(song_id, audio_path)
        return chroma_vector

    def _dejavu_ingest(self, song_id: int, audio_path: str) -> None:
        """Extract Dejavu fingerprints and write them to ``dejavu_fingerprints``.

        Existing rows for ``song_id`` are deleted before the new ones are
        inserted, so repeating the call replaces rather than duplicates the
        song's fingerprints. If no fingerprints are produced, a warning is
        logged and the table is not touched.

        Args:
            song_id: Song ID the fingerprints belong to.
            audio_path: Path to the audio file.
        """
        # The first element of the returned pair is not needed here.
        _sha1, hash_offset_pairs = fingerprint_audio(audio_path)
        if not hash_offset_pairs:
            self._logger.warning("Dejavu 指纹为空: song_id=%s, audio=%s", song_id, audio_path)
            return

        with psycopg.connect(self.config.dsn) as connection:
            with connection.cursor() as cur:
                # Clear any leftover fingerprints first (idempotent write).
                cur.execute(
                    "DELETE FROM dejavu_fingerprints WHERE song_id = %s",
                    (song_id,),
                )
                # Bulk insert: one row per (hash, offset) pair.
                rows = [
                    (song_id, fp_hash, fp_offset)
                    for fp_hash, fp_offset in hash_offset_pairs
                ]
                cur.executemany(
                    """
                    INSERT INTO dejavu_fingerprints (song_id, hash, "offset")
                    VALUES (%s, %s, %s)
                    """,
                    rows,
                )
            connection.commit()

        self._logger.info(
            "Dejavu 指纹入库完成: song_id=%s, 指纹数=%d", song_id, len(hash_offset_pairs)
        )

    def query(self, audio_path: str, top_k: int = 100) -> list[CompositionCandidate]:
        """提取音频特征并查询相似结果。

        流程:Dejavu 指纹短路匹配 → 12 路循环对齐 Cosine 召回 → DTW 精排。
        """
        # 1. 优先尝试 Dejavu 指纹匹配(短路)
        if self.config.dejavu_enabled:
            match = self._dejavu_query(audio_path)
            if match is not None:
                self._logger.info(
                    "Dejavu 命中: song_id=%s, aligned_count=%d, total_collisions=%d, decision=duplicate",
                    match.song_id,
                    match.aligned_count,
                    match.total_collisions,
                )
                return [CompositionCandidate(song_id=match.song_id, similarity=1.0, source="dejavu")]

        # 2. Dejavu 未命中或禁用,走现有 Chromagram 12路 + DTW 流程
        return self._query_chroma(audio_path, top_k)

    def check(self, audio_path: str, top_k: int = 100) -> bool:
        """按最终接口语义返回是否重复。"""
        return self.candidates_indicate_duplicate(self.query(audio_path, top_k=top_k))

    def candidates_indicate_duplicate(self, candidates: list[CompositionCandidate]) -> bool:
        """将候选结果转换为最终 duplicate bool。

        最终接口只返回 true/false,因此判定只看当前查询的最佳候选是否超过阈值,
        不依赖评测集里的 expected_song_id 是否出现在 top-k。
        """
        if not candidates:
            return False
        return candidates[0].similarity >= self.config.duplicate_threshold

    def _query_chroma(self, audio_path: str, top_k: int = 100) -> list[CompositionCandidate]:
        """Run the Chromagram query: 12-way shifted cosine recall, then DTW re-ranking.

        Args:
            audio_path: Path to the query audio file.
            top_k: Limit used both per pitch shift and for the merged recall list.

        Returns:
            Candidates re-ranked by DTW similarity (best first), followed by
            the remaining recalled candidates, which keep their cosine
            similarity and recall order.
        """
        query_chroma = extract_chroma_matrix(audio_path)
        self._logger.info("提取 Chromagram 查询特征完成: audio=%s", audio_path)

        # Stage 1: 12-way circular alignment. Every semitone shift of the
        # query is flattened into one vector; a single SQL statement expands
        # all of them and keeps the highest cosine similarity per song_id.
        shift_vecs = []
        for semitone in range(12):
            rolled = np.roll(query_chroma, -semitone, axis=0)
            shift_vecs.append(rolled.flatten().astype(np.float32).tolist())

        # VALUES lists the 12 shifted vectors; the LATERAL subquery runs one
        # nearest-neighbour scan per shift (per the original notes, an HNSW scan).
        values_clause = ", ".join(f"({i}, %s::vector)" for i in range(12))
        sql = f"""
            WITH shifts(shift_id, vec) AS (
                VALUES {values_clause}
            ),
            candidates AS (
                SELECT
                    cf.song_id,
                    1 - (cf.feature_vector <=> s.vec) AS sim
                FROM shifts s
                CROSS JOIN LATERAL (
                    SELECT song_id, feature_vector
                    FROM composition_feature
                    ORDER BY feature_vector <=> s.vec
                    LIMIT %s
                ) cf
            )
            SELECT song_id, MAX(sim) AS similarity
            FROM candidates
            GROUP BY song_id
            ORDER BY similarity DESC
            LIMIT %s
        """
        with psycopg.connect(self.config.dsn) as conn:
            with conn.cursor() as recall_cursor:
                recall_cursor.execute(
                    f"SET statement_timeout = {int(self.config.statement_timeout_ms)}"
                )
                recall_cursor.execute(sql, (*shift_vecs, top_k, top_k))
                cosine_by_song = {
                    int(row_id): float(cosine)
                    for row_id, cosine in recall_cursor.fetchall()
                }

            # Stage 2: take the top dtw_rerank_top_k recalled songs and load
            # their stored vectors for DTW re-ranking.
            ranked = sorted(cosine_by_song.items(), key=lambda item: item[1], reverse=True)
            cutoff = self.config.dtw_rerank_top_k
            head, tail = ranked[:cutoff], ranked[cutoff:]
            head_ids = [row_id for row_id, _ in head]

            with conn.cursor() as fetch_cursor:
                fetch_cursor.execute(
                    "SELECT song_id, feature_vector::float4[] FROM composition_feature WHERE song_id = ANY(%s)",
                    (head_ids,),
                )
                stored_rows = fetch_cursor.fetchall()

        # DTW runs after the connection has been closed.
        reranked = [
            CompositionCandidate(
                song_id=int(row_id),
                similarity=_best_shifted_dtw_similarity(
                    query_chroma,
                    # Stored vectors are reshaped back into a 12 x TARGET_FRAMES matrix.
                    np.array(vector, dtype=np.float32).reshape(12, TARGET_FRAMES),
                ),
            )
            for row_id, vector in stored_rows
        ]
        reranked.sort(key=lambda cand: cand.similarity, reverse=True)

        # Candidates beyond the re-rank window keep their cosine similarity.
        reranked_ids = {cand.song_id for cand in reranked}
        remainder = [
            CompositionCandidate(song_id=row_id, similarity=cosine)
            for row_id, cosine in tail
            if row_id not in reranked_ids
        ]

        result = reranked + remainder
        preview = ", ".join(
            f"{cand.song_id}:{cand.similarity:.4f}" for cand in result[:5]
        )
        self._logger.info(
            "Chromagram 查询完成: 返回 %d 个候选, top=%s",
            len(result),
            preview or "[]",
        )
        return result

    def _dejavu_query(self, audio_path: str) -> _DejavuMatch | None:
        """Look up the audio's Dejavu fingerprints and return the best aligned match.

        Counting raw hash collisions alone would let common spectral peaks,
        noisy segments or random collisions in a large library short-circuit
        straight to similarity=1.0. The decisive Dejavu criterion is that,
        within one candidate song, many colliding hashes share the same
        ``db_offset - query_offset`` time delta; the query therefore groups
        collisions per (song_id, offset delta) and ranks songs by their
        largest group.

        The configured ``statement_timeout_ms`` is applied to the connection
        before the lookup, matching the Chromagram query path, so a slow join
        against the fingerprint table cannot run unbounded.

        Args:
            audio_path: Path to the query audio file.

        Returns:
            The best match when its aligned collision count reaches
            ``config.dejavu_match_threshold``; ``None`` when the audio yields
            no fingerprints, nothing collides, or the count is below the
            threshold.
        """
        # Only the (hash, offset) pairs are needed; the first element is unused.
        _, fingerprints = fingerprint_audio(audio_path)
        if not fingerprints:
            return None

        hashes = [h for h, _ in fingerprints]
        offsets = [int(o) for _, o in fingerprints]

        with psycopg.connect(self.config.dsn) as conn:
            with conn.cursor() as cursor:
                # Bound the lookup with the same timeout the Chromagram query uses.
                cursor.execute(
                    f"SET statement_timeout = {int(self.config.statement_timeout_ms)}"
                )
                # Find collisions by hash, then cluster them per song_id by offset delta.
                cursor.execute(
                    """
                    WITH query_fp(hash, query_offset) AS (
                        SELECT *
                        FROM unnest(%s::bytea[], %s::int[])
                    ),
                    aligned AS (
                        SELECT
                            db.song_id,
                            db."offset" - query_fp.query_offset AS offset_delta,
                            COUNT(*) AS aligned_count
                        FROM query_fp
                        JOIN dejavu_fingerprints db
                          ON db.hash = query_fp.hash
                        GROUP BY db.song_id, offset_delta
                    )
                    SELECT
                        song_id,
                        MAX(aligned_count) AS best_aligned_count,
                        SUM(aligned_count) AS total_collisions
                    FROM aligned
                    GROUP BY song_id
                    ORDER BY best_aligned_count DESC, total_collisions DESC
                    LIMIT 1
                    """,
                    (hashes, offsets),
                )
                row = cursor.fetchone()

        # No stored fingerprint shares a hash with the query.
        if row is None:
            return None

        sid, aligned_count, total_collisions = row
        aligned_count = int(aligned_count)
        if aligned_count < self.config.dejavu_match_threshold:
            return None
        return _DejavuMatch(
            song_id=int(sid),
            aligned_count=aligned_count,
            total_collisions=int(total_collisions),
        )


def _dtw_similarity(query: np.ndarray, candidate: np.ndarray) -> float:
    """Compute a DTW-based similarity between two Chromagram matrices.

    Both inputs are laid out as (pitch classes, frames); frames are compared
    with Euclidean distance and aligned by classic dynamic time warping with
    steps from above, from the left and from the diagonal.

    Args:
        query: Query matrix, one column per frame.
        candidate: Candidate matrix, one column per frame.

    Returns:
        ``1 / (1 + d)`` where ``d`` is the accumulated warping cost divided by
        the sum of both frame counts; smaller distance means higher similarity.
    """
    # Pairwise Euclidean distance between every query frame and candidate frame.
    frame_cost = cdist(query.T, candidate.T, metric="euclidean")
    rows, cols = frame_cost.shape

    acc = np.full((rows, cols), np.inf)
    # The first column and first row can only be reached along a straight
    # line, so they are running sums of the frame costs.
    acc[:, 0] = np.cumsum(frame_cost[:, 0])
    acc[0, :] = np.cumsum(frame_cost[0, :])

    for r in range(1, rows):
        above, current, step = acc[r - 1], acc[r], frame_cost[r]
        for c in range(1, cols):
            current[c] = step[c] + min(above[c], current[c - 1], above[c - 1])

    normalised = acc[rows - 1, cols - 1] / (rows + cols)
    # Convert distance to similarity: the smaller the distance, the higher the score.
    return float(1.0 / (1.0 + normalised))


def _best_shifted_dtw_similarity(query: np.ndarray, candidate: np.ndarray) -> float:
    """Return the best DTW similarity over all 12 circular pitch shifts of the query.

    Args:
        query: Query Chromagram matrix; its first axis is rolled by 0..11 steps.
        candidate: Candidate Chromagram matrix, compared unshifted.

    Returns:
        The maximum ``_dtw_similarity`` across the 12 shifted queries.
    """
    scores = []
    for semitone in range(12):
        transposed = np.roll(query, -semitone, axis=0)
        scores.append(_dtw_similarity(transposed, candidate))
    return max(scores)