# import_audio_composition.py
"""批量导入音频文件到 composition_feature 表。

用法:
python scripts/import_audio_composition.py \
    --dsn "postgresql:///lyric_dedup" \
    --audio-dir /Volumes/移动硬盘/composition_test \
    --ext .wav

支持通过 --file-list 指定一个包含音频路径的文本文件(每行一个路径)。
"""

import argparse
import logging
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from tqdm import tqdm

from composition_dedup.service import CompositionConfig, CompositionDedupService

# Module-level logger; output format/level are set in main() via logging.basicConfig.
logger = logging.getLogger(__name__)

# Audio file extensions recognized by this importer (lowercase, leading dot).
# NOTE(review): not referenced anywhere in this file — confirm whether --ext
# is meant to be validated against this set.
SUPPORTED_EXTENSIONS = {
    ".aac",
    ".flac",
    ".m4a",
    ".mp3",
    ".ogg",
    ".wav",
    ".wma",
}


def discover_audio_files(audio_dir: str | None, file_list: str | None, ext: str) -> list[tuple[str, str]]:
    """发现音频文件,返回 [(song_id, audio_path), ...] 列表。

    优先使用 --file-list,否则扫描 --audio-dir 目录。
    song_id 使用文件名的数字部分或路径的哈希值。
    """
    results = []

    if file_list:
        with open(file_list, "r", encoding="utf-8") as f:
            for line in f:
                path = line.strip()
                if not path:
                    continue
                song_id = _extract_song_id(path)
                results.append((song_id, path))
    elif audio_dir:
        audio_dir_path = Path(audio_dir)
        for audio_file in sorted(audio_dir_path.rglob(f"*{ext}")):
            if audio_file.is_file() and not audio_file.name.startswith("._"):
                song_id = _extract_song_id(str(audio_file))
                results.append((song_id, str(audio_file)))
    else:
        print("错误: 请指定 --audio-dir 或 --file-list")
        sys.exit(1)

    return results


def _extract_song_id(path: str) -> str:
    """从路径中提取 song_id。
    优先取文件名第一段(下划线前),若为纯数字则使用,否则用路径哈希。
    """
    name = Path(path).stem
    prefix = name.split("_")[0]
    if prefix.isdigit():
        return prefix
    import hashlib
    return str(int(hashlib.md5(path.encode()).hexdigest()[:8], 16))


def main() -> None:
    """Command-line entry point: parse arguments and ingest audio files.

    Parses CLI args, discovers the audio files, optionally truncates the
    ``composition_feature`` and ``dejavu_fingerprints`` tables (``--clear``),
    then passes every file to ``CompositionDedupService.ingest`` in chunks of
    ``--batch-size``. Per-file failures are logged and counted, not raised.
    """
    logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")

    parser = argparse.ArgumentParser(description="批量导入音频文件到 composition_feature 表")
    parser.add_argument("--dsn", required=True, help="PostgreSQL DSN 连接串")
    parser.add_argument("--audio-dir", help="音频文件目录")
    parser.add_argument("--file-list", help="音频文件路径列表文件")
    parser.add_argument("--ext", default=".wav", help="音频文件扩展名(默认 .wav)")
    parser.add_argument("--batch-size", type=int, default=10, help="批次大小(默认 10)")
    parser.add_argument("--clear", action="store_true", help="导入前清空 composition_feature 和 dejavu_fingerprints 表数据(保留表结构)")
    args = parser.parse_args()

    if args.batch_size < 1:
        # range() below raises on a 0 step and silently iterates nothing on a
        # negative one, so reject these up front.
        parser.error("--batch-size 必须是正整数")

    # Discover files BEFORE any destructive --clear: discovery exits on missing
    # or invalid input, and must not do so after the tables were already wiped.
    audio_files = discover_audio_files(args.audio_dir, args.file_list, args.ext)
    logger.info("发现 %d 个音频文件", len(audio_files))

    config = CompositionConfig(dsn=args.dsn)
    service = CompositionDedupService(config=config)

    if args.clear:
        import psycopg
        with psycopg.connect(args.dsn) as conn:
            with conn.cursor() as cur:
                cur.execute("TRUNCATE TABLE composition_feature, dejavu_fingerprints")
            conn.commit()
        logger.info("已清空 composition_feature 和 dejavu_fingerprints 表")

    success_count = 0
    fail_count = 0

    for start in tqdm(range(0, len(audio_files), args.batch_size), desc="导入进度"):
        batch = audio_files[start:start + args.batch_size]
        for song_id, audio_path in batch:
            try:
                service.ingest(song_id=int(song_id), audio_path=audio_path)
            except Exception as e:
                # Best-effort bulk import: record the failure and keep going.
                logger.error("导入失败: song_id=%s, path=%s, error=%s", song_id, audio_path, e)
                fail_count += 1
            else:
                success_count += 1

    logger.info("导入完成: 成功 %d, 失败 %d", success_count, fail_count)

if __name__ == "__main__":
    # Run the importer only when executed as a script, not when imported.
    main()