normalization.py 13 KB
"""Lyric-specific normalization and feature extraction."""

from __future__ import annotations

import re
import string
import unicodedata
from collections import Counter
from dataclasses import dataclass


_TRADITIONAL_TO_SIMPLIFIED = str.maketrans(
    {
        "愛": "爱",
        "會": "会",
        "個": "个",
        "妳": "你",
        "們": "们",
        "麼": "么",
        "夢": "梦",
        "憶": "忆",
        "風": "风",
        "無": "无",
        "與": "与",
        "聽": "听",
        "說": "说",
        "見": "见",
        "話": "话",
        "還": "还",
        "這": "这",
        "那": "那",
        "裡": "里",
        "裏": "里",
        "過": "过",
        "來": "来",
        "進": "进",
        "去": "去",
        "給": "给",
        "讓": "让",
        "嗎": "吗",
        "為": "为",
        "誰": "谁",
        "對": "对",
        "錯": "错",
        "淚": "泪",
        "寫": "写",
        "雲": "云",
        "藍": "蓝",
        "紅": "红",
        "綠": "绿",
        "黃": "黄",
        "長": "长",
        "遠": "远",
        "燈": "灯",
        "臺": "台",
        "台": "台",
        "後": "后",
        "從": "从",
        "時": "时",
        "間": "间",
        "葉": "叶",
        "歲": "岁",
        "聲": "声",
        "邊": "边",
        "歡": "欢",
        "繼": "继",
        "續": "续",
        "難": "难",
        "雙": "双",
        "舊": "旧",
        "離": "离",
    }
)

_TIMESTAMP_RE = re.compile(r"\[((?:\d{1,2}:)?\d{1,2}:\d{2}(?:[.:]\d{1,3})?)\]")
_BRACKET_RE = re.compile(r"[\[((【<《].{0,40}?[\]))】>》]")
_ROLE_PREFIX_RE = re.compile(r"^\s*(?:男|女|合|主歌|副歌|verse|chorus|bridge|rap)\s*[::]\s*", re.IGNORECASE)
_CREDIT_PREFIX_RE = re.compile(
    r"^\s*(?:作词|作詞|作曲|编曲|編曲|制作|製作|监制|監製|录音|錄音|混音|母带|"
    r"出品|发行|發行|歌词|歌詞|lyric(?:s)?|composer|writer|producer|arranger|"
    r"copyright|未经|未經|qq音乐|酷狗|网易云|網易雲|lrc)",
    re.IGNORECASE,
)
_WATERMARK_RE = re.compile(
    r"(?:qq音乐|酷狗音乐|网易云音乐|網易雲音樂|虾米音乐|歌词网|歌詞網|"
    r"music\.163\.com|www\.|http[s]?://|\blrc\b)",
    re.IGNORECASE,
)
_CJK_RE = re.compile(r"[\u4e00-\u9fff]")
_LATIN_RE = re.compile(r"[a-zA-Z]")
_KANA_RE = re.compile(r"[\u3040-\u30ff]")
_HANGUL_RE = re.compile(r"[\uac00-\ud7af]")
_WORD_RE = re.compile(r"[a-z0-9]+|[\u4e00-\u9fff]", re.IGNORECASE)
_INLINE_SPLIT_RE = re.compile(r"\s+(?:/|\|||)\s+|(?<=[A-Za-z])\s*[-—]\s*(?=[\u4e00-\u9fff])")


@dataclass(frozen=True)
class _LineEntry:
    text: str
    timestamp: str | None
    language: str
    source_index: int


@dataclass(frozen=True, eq=True)
class NormalizedLyrics:
    """Result of ``normalize_lyrics``: cleaned lines plus an original/translation split."""

    raw_text: str  # the input text exactly as passed in
    normalized_full_text: str  # all cleaned lines joined with "\n"
    normalized_lines: tuple[str, ...]  # every cleaned line, in order, duplicates kept
    unique_lines: tuple[str, ...]  # cleaned lines deduplicated, first-occurrence order
    line_counts: dict[str, int]  # number of occurrences of each cleaned line
    content_line_count: int  # len(normalized_lines)
    primary_lines: tuple[str, ...]  # deduplicated lines labelled "primary" (original text)
    translation_lines: tuple[str, ...]  # deduplicated lines labelled "translation"
    unknown_lines: tuple[str, ...]  # deduplicated lines labelled "unknown"
    line_roles: tuple[str, ...]  # role per entry, aligned with normalized_lines
    split_confidence: str  # "high", "medium", "low" or "none"
    split_reason: str  # human-readable explanation of how roles were assigned


def normalize_lyrics(text: str) -> NormalizedLyrics:
    """Clean raw lyrics into lines and split them into original vs. translation.

    The text is NFKC-normalized and split into lines. Each line is cleaned:
    timestamps, role prefixes, short bracketed notes, credit/watermark lines
    and punctuation are removed, the text is lower-cased, and common
    traditional characters are simplified. The surviving entries are then
    labelled by ``_assign_line_roles``.

    Args:
        text: Raw lyric text (plain lines or LRC-style timestamped lines).

    Returns:
        A ``NormalizedLyrics``. ``normalized_lines`` and ``line_roles`` are
        aligned one-to-one. The ``primary_lines`` / ``translation_lines`` /
        ``unknown_lines`` collections are deduplicated in first-seen order.
    """
    compatible = unicodedata.normalize("NFKC", text)
    entries: list[_LineEntry] = [
        entry
        for line_no, raw_line in enumerate(compatible.splitlines())
        for entry in _clean_line_entries(raw_line, line_no)
    ]
    lines = [entry.text for entry in entries]
    roles, confidence, reason = _assign_line_roles(entries)

    by_role: dict[str, list[str]] = {"primary": [], "translation": [], "unknown": []}
    for entry, role in zip(entries, roles, strict=False):
        bucket = by_role.get(role)
        if bucket is not None:
            bucket.append(entry.text)

    primary = by_role["primary"]
    if not primary:
        # Nothing was labelled primary: treat every cleaned line as original text.
        primary = list(lines)
        roles = tuple("primary" for _ in lines)
        if lines and confidence == "none":
            reason = "未检测到可分离的翻译结构,全部有效行按原文处理"

    return NormalizedLyrics(
        raw_text=text,
        normalized_full_text="\n".join(lines),
        normalized_lines=tuple(lines),
        unique_lines=tuple(dict.fromkeys(lines)),
        line_counts=dict(Counter(lines)),
        content_line_count=len(lines),
        primary_lines=tuple(dict.fromkeys(primary)),
        translation_lines=tuple(dict.fromkeys(by_role["translation"])),
        unknown_lines=tuple(dict.fromkeys(by_role["unknown"])),
        line_roles=tuple(roles),
        split_confidence=confidence,
        split_reason=reason,
    )


def fingerprint_text(normalized: NormalizedLyrics) -> str:
    """Return a newline-joined text form suitable for exact hashing.

    Uses the deduplicated primary (original-language) lines when there are
    any, otherwise the deduplicated set of all cleaned lines. Both collections
    are already deduplicated, so songs that differ only in how often a chorus
    repeats produce the same text.
    """
    lines = normalized.primary_lines
    if not lines:
        # No primary lines were classified; fall back to every distinct line.
        lines = normalized.unique_lines
    return "\n".join(lines)


def lyric_tokens(
    normalized: NormalizedLyrics,
    ngram_size: int = 3,
    *,
    lines: tuple[str, ...] | None = None,
) -> set[str]:
    """Build the set of mixed CJK/Latin n-gram tokens for a lyric.

    Each line is split into units by ``_token_units`` and every window of
    ``ngram_size`` consecutive units becomes one space-joined token. A line
    with fewer units than ``ngram_size`` contributes a single token made of
    all its units. A line with no units contributes nothing. The result is a
    set, so a line that repeats adds its tokens only once.

    Args:
        normalized: Normalized lyrics; ``unique_lines`` is used when ``lines``
            is not given.
        ngram_size: Number of units per token; must be at least 1.
        lines: Optional explicit lines to tokenize instead of ``unique_lines``.

    Returns:
        The set of n-gram tokens.

    Raises:
        ValueError: If ``ngram_size`` is less than 1. Such a size would
            otherwise silently produce empty-string or meaningless tokens.
    """
    if ngram_size < 1:
        raise ValueError(f"ngram_size must be at least 1, got {ngram_size}")

    tokens: set[str] = set()
    selected_lines = normalized.unique_lines if lines is None else lines
    for line in selected_lines:
        units = _token_units(line)
        if not units:
            continue
        if len(units) < ngram_size:
            # Too short for a full window: keep the whole line as one token.
            tokens.add(" ".join(units))
            continue
        tokens.update(
            " ".join(units[start : start + ngram_size]) for start in range(len(units) - ngram_size + 1)
        )
    return tokens


def _clean_line_entries(raw_line: str, source_index: int) -> list[_LineEntry]:
    """Turn one raw input line into zero, one or two cleaned entries.

    Every LRC timestamp is removed from the text, but only the first one is
    recorded on the resulting entries. A leading role prefix such as "男:" or
    "chorus:" is dropped. If the remainder holds an inline original/translation
    pair, both halves are returned. Otherwise the whole remainder is cleaned as
    a single entry (or discarded as noise).
    """
    first_stamp = _TIMESTAMP_RE.search(raw_line)
    timestamp = first_stamp.group(1) if first_stamp is not None else None
    without_stamps = _TIMESTAMP_RE.sub("", raw_line)
    content = _ROLE_PREFIX_RE.sub("", without_stamps).strip()
    return _split_inline_translation(content, timestamp, source_index) or _entry_from_text(
        content, timestamp, source_index
    )


def _split_inline_translation(line: str, timestamp: str | None, source_index: int) -> list[_LineEntry]:
    """Split a line of the form "<original> / <translation>" into two entries.

    The line is split at most once on ``_INLINE_SPLIT_RE``. The split is kept
    only when one half is detected as a foreign language and the other as
    Chinese ("zh"). In that case the foreign entry is returned first.
    Otherwise an empty list is returned.
    """
    halves = [half.strip() for half in _INLINE_SPLIT_RE.split(line, maxsplit=1)]
    if len(halves) != 2:
        return []
    left, right = (_entry_from_text(half, timestamp, source_index) for half in halves)
    if not left or not right:
        return []
    # Try both orders; the first element of the accepted pair is the original.
    for original, translation in ((left[0], right[0]), (right[0], left[0])):
        if _is_foreign_language(original.language) and translation.language == "zh":
            return [original, translation]
    return []


def _entry_from_text(text: str, timestamp: str | None, source_index: int) -> list[_LineEntry]:
    """Clean one lyric fragment and wrap it in a ``_LineEntry``.

    Cleaning steps, in order:
    - remove short bracketed annotations;
    - strip surrounding whitespace and lower-case;
    - map common traditional characters to simplified ones;
    - after the noise check, remove punctuation.

    Returns a one-element list, or an empty list when the fragment is empty or
    classified as noise.
    """
    candidate = _BRACKET_RE.sub("", text).strip().lower()
    candidate = candidate.translate(_TRADITIONAL_TO_SIMPLIFIED)
    if candidate and not _is_noise_line(candidate):
        candidate = _strip_symbols(candidate)
        if candidate:
            language = _detect_language(candidate)
            return [_LineEntry(text=candidate, timestamp=timestamp, language=language, source_index=source_index)]
    return []


def _assign_line_roles(entries: list[_LineEntry]) -> tuple[tuple[str, ...], str, str]:
    if not entries:
        return (), "none", "没有有效歌词行"

    timestamp_roles = _roles_by_same_timestamp(entries)
    if timestamp_roles is not None:
        return timestamp_roles, "high", "同时间戳下检测到外文行和中文行配对"

    inline_roles = _roles_by_inline_translation(entries)
    if inline_roles is not None:
        return inline_roles, "medium", "同一原始行内检测到明显的外文和中文翻译"

    alternating_roles = _roles_by_alternating_translation(entries)
    if alternating_roles is not None:
        return alternating_roles, "high", "检测到稳定的外文行和中文翻译行交替结构"

    block_roles = _roles_by_translation_block(entries)
    if block_roles is not None:
        return block_roles, "low", "检测到疑似原文段落加中文翻译段落,置信度较低"

    return tuple("primary" for _ in entries), "none", "未检测到可分离的翻译结构,全部有效行按原文处理"


def _roles_by_same_timestamp(entries: list[_LineEntry]) -> tuple[str, ...] | None:
    """Pair lines that share one LRC timestamp.

    A timestamp group qualifies when it is shared by two or more entries and
    contains at least one foreign-language entry and at least one Chinese
    ("zh") entry. In every qualifying group the Chinese entries become
    "translation". All other entries are "primary". Returns ``None`` when no
    group qualifies.
    """
    labels = ["primary"] * len(entries)
    groups: dict[str, list[int]] = {}
    for position, entry in enumerate(entries):
        if not entry.timestamp:
            continue
        groups.setdefault(entry.timestamp, []).append(position)

    matched = 0
    for members in groups.values():
        if len(members) < 2:
            continue
        chinese = [pos for pos in members if entries[pos].language == "zh"]
        has_foreign = any(_is_foreign_language(entries[pos].language) for pos in members)
        if not (chinese and has_foreign):
            continue
        for pos in chinese:
            labels[pos] = "translation"
        matched += 1

    return tuple(labels) if matched else None


def _roles_by_alternating_translation(entries: list[_LineEntry]) -> tuple[str, ...] | None:
    """Detect "foreign line, then Chinese line" pairs that follow each other.

    The scan is greedy and left to right. Each foreign entry immediately
    followed by a "zh" entry forms a pair, and the scan resumes after the pair.
    The result is accepted only with at least two pairs covering at least 65%
    of all entries. The Chinese half of each pair becomes "translation" and
    everything else is "primary". Returns ``None`` otherwise.
    """
    count = len(entries)
    labels = ["primary"] * count
    pairs = 0
    cursor = 0
    while cursor + 1 < count:
        if _is_foreign_language(entries[cursor].language) and entries[cursor + 1].language == "zh":
            labels[cursor + 1] = "translation"
            pairs += 1
            cursor += 2
        else:
            cursor += 1

    # Each pair covers exactly two entries.
    if pairs < 2 or (2 * pairs) / count < 0.65:
        return None
    return tuple(labels)


def _roles_by_inline_translation(entries: list[_LineEntry]) -> tuple[str, ...] | None:
    """Detect raw lines that were split into a foreign + Chinese pair.

    Entries are grouped by ``source_index``. A group of exactly two entries
    where one is foreign and the other "zh" marks the Chinese entry as
    "translation". All other entries stay "primary". Returns ``None`` when no
    such pair exists.
    """
    labels = ["primary"] * len(entries)
    members_by_line: dict[int, list[int]] = {}
    for position, entry in enumerate(entries):
        members_by_line.setdefault(entry.source_index, []).append(position)

    pair_count = 0
    for members in members_by_line.values():
        if len(members) != 2:
            continue
        first, second = members
        first_lang = entries[first].language
        second_lang = entries[second].language
        if _is_foreign_language(first_lang) and second_lang == "zh":
            labels[second] = "translation"
        elif first_lang == "zh" and _is_foreign_language(second_lang):
            labels[first] = "translation"
        else:
            continue
        pair_count += 1

    return tuple(labels) if pair_count else None


def _roles_by_translation_block(entries: list[_LineEntry]) -> tuple[str, ...] | None:
    """Detect an original block followed by a Chinese translation block.

    Requires at least four entries. The entries are cut at ``len // 2``. If at
    least 75% of the first half is foreign and at least 75% of the second half
    is "zh", the first half becomes "primary" and the second "translation".
    Returns ``None`` otherwise.
    """
    total = len(entries)
    if total < 4:
        return None
    half = total // 2
    head, tail = entries[:half], entries[half:]
    foreign_share = sum(_is_foreign_language(entry.language) for entry in head) / len(head)
    chinese_share = sum(entry.language == "zh" for entry in tail) / len(tail)
    if foreign_share < 0.75 or chinese_share < 0.75:
        return None
    return ("primary",) * half + ("translation",) * (total - half)


def _detect_language(line: str) -> str:
    """Classify ``line`` by the scripts it contains.

    Precedence, first match wins:
    - any Hangul -> "kr";
    - any kana -> "jp";
    - CJK ideographs plus ASCII letters -> "mixed";
    - CJK only -> "zh";
    - ASCII letters only -> "latin";
    - anything else -> "other".
    """
    if _HANGUL_RE.search(line):
        return "kr"
    if _KANA_RE.search(line):
        return "jp"
    has_cjk = _CJK_RE.search(line) is not None
    has_latin = _LATIN_RE.search(line) is not None
    if has_cjk:
        return "mixed" if has_latin else "zh"
    return "latin" if has_latin else "other"


def _is_foreign_language(language: str) -> bool:
    return language in {"latin", "jp", "kr", "other"}


def _is_noise_line(line: str) -> bool:
    """Return True when ``line`` carries no lyric content worth keeping.

    A line is noise when any of these holds:
    - it looks like a credit/metadata line or a platform watermark;
    - it contains no character from the scripts this module recognises (CJK
      ideographs, ASCII Latin letters, kana, Hangul);
    - it shrinks to at most one character once punctuation is stripped.
    """
    if _CREDIT_PREFIX_RE.search(line) or _WATERMARK_RE.search(line):
        return True
    # Kana and Hangul count as content as well. Without them, purely Japanese-kana
    # or Korean lines would be dropped here even though _detect_language labels
    # them "jp"/"kr" and _is_foreign_language treats those labels as originals.
    # NOTE(review): _KANA_RE's range also covers the katakana middle dot (U+30FB),
    # so a line made only of "・" characters is now kept.
    has_content = any(pattern.search(line) for pattern in (_CJK_RE, _LATIN_RE, _KANA_RE, _HANGUL_RE))
    if not has_content:
        return True
    return len(_strip_symbols(line)) <= 1


def _strip_symbols(line: str) -> str:
    punctuation = string.punctuation + ",。!?;:、“”‘’·…—~!¥()【】《》〈〉「」『』﹏"
    line = "".join(" " if char in punctuation else char for char in line)
    line = re.sub(r"\s+", " ", line)
    line = re.sub(r"(?<=[\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff])", "", line)
    return line.strip()


def _token_units(line: str) -> list[str]:
    """Split ``line`` into lower-cased token units for n-gram building.

    A unit is either a run of letters/digits or a single CJK ideograph, as
    matched by ``_WORD_RE``. Everything else is skipped, including punctuation,
    whitespace, kana and Hangul.
    """
    # The previous if/else appended the token in both branches; one path suffices.
    return [match.group(0).lower() for match in _WORD_RE.finditer(line)]