# dejavu_fingerprinter.py (7.78 KB)
"""Dejavu 风格的音频指纹生成。

基于 worldveil/dejavu 的指纹算法提取实现,不依赖 Dejavu 的数据库层。
使用 scipy.signal.spectrogram 替代已废弃的 matplotlib.mlab.specgram。

流程:
1. 音频标准化:ffmpeg 转 44100Hz / Mono / WAV
2. librosa 加载音频
3. 短时傅里叶变换(STFT)→ 对数频谱图
4. 2D 峰值检测:在频谱图中找局部极大值
5. 指纹哈希:对峰值对 (freq1, freq2, time_delta) 做 SHA1,取前 20 位
"""

import hashlib
import logging
import os
import subprocess
import tempfile
from operator import itemgetter
from pathlib import Path

import librosa
import numpy as np
from scipy.ndimage import (
    generate_binary_structure,
    iterate_structure,
    maximum_filter,
)
from scipy.signal import spectrogram

logger = logging.getLogger(__name__)


def _load_env_file() -> None:
    """加载项目根目录 .env,不覆盖已存在的真实环境变量。"""
    env_path = Path(__file__).resolve().parent.parent / ".env"
    if not env_path.exists():
        return
    with env_path.open(encoding="utf-8") as file:
        for raw_line in file:
            line = raw_line.strip()
            if not line or line.startswith("#") or "=" not in line:
                continue
            key, value = line.split("=", 1)
            os.environ.setdefault(key.strip(), value.strip().strip('"').strip("'"))


# Populate os.environ from .env before the constants below read their overrides.
_load_env_file()

# ===== Constants (some can be overridden via environment variables) =====

# Sample rate (Hz) that ffmpeg resamples to and librosa loads at.
DEFAULT_FS = 44100
# Samples per spectrogram segment (passed on as nperseg).
DEFAULT_WINDOW_SIZE = 4096
# Fraction of the window that overlaps the next one (noverlap = window * ratio).
DEFAULT_OVERLAP_RATIO = float(os.environ.get("COMPOSITION_DEJAVU_OVERLAP_RATIO", "0.3"))
# Pairing range bound: each peak is paired with up to fan_value - 1 later peaks.
DEFAULT_FAN_VALUE = int(os.environ.get("COMPOSITION_DEJAVU_FAN_VALUE", "10"))
# A local maximum is kept only if its spectrogram value is strictly above this.
DEFAULT_AMP_MIN = float(os.environ.get("COMPOSITION_DEJAVU_AMP_MIN", "20"))
# Number of iterations used to grow the peak-detection footprint.
PEAK_NEIGHBORHOOD_SIZE = 20
# Inclusive bounds on the time-index difference of a hashed peak pair.
MIN_HASH_TIME_DELTA = 0
MAX_HASH_TIME_DELTA = 200
# When true, peaks are sorted by time index before pairing.
PEAK_SORT = True
# Number of leading hex characters of each SHA-1 digest that are kept.
FINGERPRINT_REDUCTION = 20
# Maximum query duration in seconds; 0 = no limit.
# NOTE(review): not referenced in this part of the file - presumably passed by
# callers as max_duration to load_audio; confirm against the service code.
QUERY_MAX_DURATION_SEC = float(os.environ.get("COMPOSITION_DEJAVU_QUERY_MAX_DURATION", "120"))  # 0 = no limit


def _normalize_audio(audio_path: str, max_duration: float = 0) -> tuple[np.ndarray, int]:
    """Convert an audio file to mono WAV at ``DEFAULT_FS`` and load its samples.

    ffmpeg writes the resampled audio to a temporary WAV file, librosa reads
    that file back, and the temporary file is removed before returning
    (also when an error occurs).

    Args:
        audio_path: Path of the input audio, passed to ffmpeg as ``-i``.
        max_duration: When greater than 0, passed to ffmpeg as ``-t`` so the
            output is limited to that many seconds; otherwise the full
            length is converted.

    Returns:
        ``(samples, sample_rate)`` as returned by ``librosa.load``.

    Raises:
        RuntimeError: ffmpeg exited with a non-zero status; the message
            carries ffmpeg's stderr output.
        FileNotFoundError: Propagated from ``subprocess.run`` when the
            ``ffmpeg`` executable cannot be found.
    """
    # Only the path is needed: the handle is closed right away so ffmpeg can
    # overwrite the file (-y) on every platform.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        tmp_wav = tmp.name

    try:
        cmd = [
            "ffmpeg",
            "-y",
            "-i", audio_path,
            "-ar", str(DEFAULT_FS),
            "-ac", "1",
            "-f", "wav",
        ]
        if max_duration > 0:
            cmd += ["-t", str(max_duration)]
        cmd.append(tmp_wav)

        result = subprocess.run(
            cmd,
            # ffmpeg is interactive on stdin by default; give it an empty
            # stdin so it can neither block on nor consume the parent's input.
            stdin=subprocess.DEVNULL,
            capture_output=True,
            text=True,
            # ffmpeg's log may contain bytes (metadata tags, file names) that
            # are invalid in the locale encoding; strict decoding would raise
            # UnicodeDecodeError even though the conversion itself succeeded.
            errors="replace",
        )
        if result.returncode != 0:
            raise RuntimeError(f"ffmpeg 转换失败: {result.stderr}")

        y, sr = librosa.load(tmp_wav, sr=DEFAULT_FS, mono=True)
        return y, sr
    finally:
        if os.path.exists(tmp_wav):
            os.remove(tmp_wav)


def _specgram(samples: np.ndarray, fs: int, window_size: int, overlap_ratio: float):
    """计算对数频谱图,替代 matplotlib.mlab.specgram。

    Returns:
        arr2D: shape (n_freq, n_time) 的对数频谱矩阵(dBFS 刻度)
    """
    noverlap = int(window_size * overlap_ratio)
    window = np.hanning(window_size)

    freqs, times, Sxx = spectrogram(
        samples,
        fs=fs,
        window=window,
        nperseg=window_size,
        noverlap=noverlap,
    )

    # 转为对数尺度(dBFS,0 dB 为峰值参考)
    # scipy.signal.spectrogram 返回 PSD,mlab.specgram 返回功率,两者量纲不同
    # 统一转为相对于峰值的 dBFS 刻度,使强信号峰值落在 20~80 dB 范围
    arr2D = 10 * np.log10(Sxx + 1e-10)
    arr2D = arr2D - arr2D.max()  # 归一化到峰值为 0 dBFS
    arr2D = arr2D + 80  # 偏移使典型峰值落在 20~80 dB(与 mlab.specgram 一致)
    arr2D[arr2D < -100] = -100  # 限幅
    return arr2D


def _get_2d_peaks(arr2D: np.ndarray, amp_min: float = DEFAULT_AMP_MIN):
    """Locate local maxima of a 2-D spectrogram that exceed an amplitude floor.

    An element counts as a local maximum when it equals the maximum of its
    neighborhood; the neighborhood footprint is the 2-D connectivity-1
    structuring element iterated ``PEAK_NEIGHBORHOOD_SIZE`` times. Because
    the comparison is an equality, every element of a flat region is
    reported as a maximum.

    Args:
        arr2D: 2-D array indexed as (frequency, time).
        amp_min: Only maxima strictly greater than this value are returned.

    Returns:
        ``(frequency_idx, time_idx)``: two equally long lists holding the
        row and column index of each retained peak, in row-major order.
    """
    footprint = iterate_structure(
        generate_binary_structure(2, 1), PEAK_NEIGHBORHOOD_SIZE
    )

    # True wherever a value equals the maximum over its footprint.
    is_local_max = maximum_filter(arr2D, footprint=footprint) == arr2D

    # Apply the amplitude floor in the same boolean mask.
    strong_peaks = is_local_max & (arr2D > amp_min)

    rows, cols = np.nonzero(strong_peaks)
    return list(rows), list(cols)


def _generate_hashes(peaks: list[tuple[int, int]], fan_value: int = DEFAULT_FAN_VALUE):
    """Yield SHA-1 based fingerprint hashes for pairs of spectrogram peaks.

    Each peak (the anchor) is paired with the peaks at positions
    ``1 .. fan_value - 1`` after it in the (optionally time-sorted) list. A
    pair is hashed when its time-index difference lies within
    ``[MIN_HASH_TIME_DELTA, MAX_HASH_TIME_DELTA]``.

    Args:
        peaks: ``[(freq_idx, time_idx), ...]``. The list is not modified.
        fan_value: Pairing range bound; an anchor gets at most
            ``fan_value - 1`` partners, and values <= 1 yield nothing.

    Yields:
        ``(hash_bytes, time_offset)``: the first ``FINGERPRINT_REDUCTION``
        hex characters of ``sha1("freq1|freq2|t_delta")`` as ASCII bytes, and
        the anchor's time index as a plain Python ``int``.
    """
    # sorted() works on a copy, so the caller's list is never reordered.
    ordered = sorted(peaks, key=itemgetter(1)) if PEAK_SORT else list(peaks)
    # Clamp so a non-positive fan_value cannot become a negative slice bound.
    reach = max(fan_value, 1)

    for pos, anchor in enumerate(ordered):
        freq1, t1 = anchor[0], anchor[1]
        # The slice stops at the end of the list without bounds checks.
        for partner in ordered[pos + 1 : pos + reach]:
            freq2, t2 = partner[0], partner[1]
            t_delta = t2 - t1

            if MIN_HASH_TIME_DELTA <= t_delta <= MAX_HASH_TIME_DELTA:
                digest = hashlib.sha1(f"{freq1}|{freq2}|{t_delta}".encode()).hexdigest()
                # int() turns numpy integer indices into built-in ints so the
                # offsets match the annotated type and can be stored/serialized.
                yield (digest[:FINGERPRINT_REDUCTION].encode(), int(t1))


def load_audio(audio_path: str, max_duration: float = 0) -> tuple[np.ndarray, int]:
    """Load an audio file as normalized mono samples.

    Public wrapper around ``_normalize_audio``, meant to let several code
    paths share one decoded copy instead of decoding the file repeatedly.

    Args:
        audio_path: Path of the audio file.
        max_duration: Maximum number of seconds to keep; 0 means no limit.

    Returns:
        ``(samples, sr)`` tuple.
    """
    samples, sr = _normalize_audio(audio_path, max_duration)
    return samples, sr


def fingerprint_from_samples(
    samples: np.ndarray, sr: int, *, compute_sha1: bool = True
) -> tuple[str, list[tuple[bytes, int]]]:
    """Fingerprint audio samples that are already in memory (no file I/O).

    Args:
        samples: Mono audio samples (expected at ``DEFAULT_FS`` = 44100 Hz).
        sr: Sample rate of ``samples``.
        compute_sha1: Whether to compute ``file_sha1``. Callers that ignore
            the digest can pass ``False`` to skip hashing the raw bytes of
            the whole sample buffer.

    Returns:
        ``(file_sha1, fingerprints)``. ``file_sha1`` is the first 16 hex
        characters of the SHA-1 of ``samples.tobytes()`` (a digest of the
        decoded samples, not of the file on disk), or ``""`` when
        ``compute_sha1`` is false. ``fingerprints`` is a list of
        ``(hash_bytes, offset)`` tuples.
    """
    digest = ""
    if compute_sha1:
        digest = hashlib.sha1(samples.tobytes()).hexdigest()[:16]

    spectrum = _specgram(samples, sr, DEFAULT_WINDOW_SIZE, DEFAULT_OVERLAP_RATIO)
    freq_rows, time_cols = _get_2d_peaks(spectrum)
    peak_pairs = list(zip(freq_rows, time_cols))
    return digest, [*_generate_hashes(peak_pairs)]


def fingerprint_audio(audio_path: str) -> tuple[str, list[tuple[bytes, int]]]:
    """Generate Dejavu-style fingerprints for an audio file.

    The file is normalized with ``_normalize_audio`` (full length, no
    duration limit) and then passed to ``fingerprint_from_samples``.

    Args:
        audio_path: Path of the audio file.

    Returns:
        ``(file_sha1, fingerprints)`` where ``fingerprints`` is a list of
        ``(hash_bytes, offset)`` tuples.

    Raises:
        FileNotFoundError: ``audio_path`` is not an existing file.
        RuntimeError: The ffmpeg conversion failed.
    """
    if os.path.isfile(audio_path):
        waveform, rate = _normalize_audio(audio_path)
        digest, prints = fingerprint_from_samples(waveform, rate)
        logger.info("指纹生成完成: audio=%s, 指纹数=%d", audio_path, len(prints))
        return digest, prints

    raise FileNotFoundError(f"音频文件不存在: {audio_path}")