import_audio_composition.py
4.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""批量导入音频文件到 composition_feature 表。
用法:
python scripts/import_audio_composition.py \
--dsn "postgresql:///lyric_dedup" \
--audio-dir /Volumes/移动硬盘/composition_test \
--ext .wav
支持通过 --file-list 指定一个包含音频路径的文本文件(每行一个路径)。
"""
import argparse
import logging
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from dotenv import load_dotenv
load_dotenv(Path(__file__).resolve().parent.parent / ".env")
from tqdm import tqdm
from composition_dedup.service import CompositionConfig, CompositionDedupService
# Module-level logger named after this module; handlers/format are configured in main().
logger = logging.getLogger(__name__)
# Set of audio file suffixes (lower-case, with leading dot).
# NOTE(review): this constant is not referenced anywhere in the visible part of
# this file -- confirm whether --ext was meant to be validated against it.
SUPPORTED_EXTENSIONS = {".mp3", ".wav", ".flac", ".ogg", ".m4a", ".aac", ".wma"}
def discover_audio_files(audio_dir: str | None, file_list: str | None, ext: str) -> list[tuple[str, str]]:
"""发现音频文件,返回 [(song_id, audio_path), ...] 列表。
优先使用 --file-list,否则扫描 --audio-dir 目录。
song_id 使用文件名的数字部分或路径的哈希值。
"""
results = []
if file_list:
with open(file_list, "r", encoding="utf-8") as f:
for line in f:
path = line.strip()
if not path:
continue
song_id = _extract_song_id(path)
results.append((song_id, path))
elif audio_dir:
audio_dir_path = Path(audio_dir)
for audio_file in sorted(audio_dir_path.rglob(f"*{ext}")):
if audio_file.is_file() and not audio_file.name.startswith("._"):
song_id = _extract_song_id(str(audio_file))
results.append((song_id, str(audio_file)))
else:
print("错误: 请指定 --audio-dir 或 --file-list")
sys.exit(1)
return results
def _extract_song_id(path: str) -> str:
    """Derive a numeric ``song_id`` string from an audio file path.

    The file stem is split on ``_`` and its first segment is used when it
    consists solely of decimal digits. Otherwise the id is the integer value
    of the first 8 hex characters of the MD5 digest of the full path string.

    Args:
        path: Path of the audio file; only the stem is inspected for a prefix,
            but the whole string is hashed in the fallback case.

    Returns:
        A string that ``int()`` can always parse.
    """
    # Local import: hashlib is not imported at the top of this file.
    import hashlib

    prefix = Path(path).stem.split("_")[0]
    # isdecimal() rather than isdigit(): isdigit() also accepts characters such
    # as "²" or "①" that int() rejects, which would make the caller's
    # int(song_id) conversion fail for those files.
    if prefix.isdecimal():
        return prefix

    # MD5 only provides a stable id here, not any security property.
    digest = hashlib.md5(path.encode(), usedforsecurity=False).hexdigest()
    return str(int(digest[:8], 16))
def main() -> None:
    """Command-line entry point: ingest audio files one by one.

    Steps:
        1. Parse and validate the command-line arguments.
        2. Resolve the list of audio files (exits with status 1 if neither
           ``--audio-dir`` nor ``--file-list`` is given).
        3. Build the ``CompositionDedupService`` for the given DSN.
        4. With ``--clear``, truncate ``composition_feature`` and
           ``dejavu_fingerprints``.
        5. Call ``service.ingest`` for every file, counting successes and
           failures; a failing file is logged and does not stop the run.
    """
    logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")

    parser = argparse.ArgumentParser(description="批量导入音频文件到 composition_feature 表")
    parser.add_argument("--dsn", required=True, help="PostgreSQL DSN 连接串")
    parser.add_argument("--audio-dir", help="音频文件目录")
    parser.add_argument("--file-list", help="音频文件路径列表文件")
    parser.add_argument("--ext", default=".wav", help="音频文件扩展名(默认 .wav)")
    parser.add_argument("--batch-size", type=int, default=10, help="批次大小(默认 10)")
    parser.add_argument("--clear", action="store_true", help="导入前清空 composition_feature 和 dejavu_fingerprints 表数据(保留表结构)")
    args = parser.parse_args()

    # A step of 0 makes range() raise ValueError and a negative step yields an
    # empty range (nothing imported, no error), so reject both up front.
    if args.batch_size < 1:
        parser.error("--batch-size 必须为正整数")

    # Resolve the inputs BEFORE any destructive database work: if no source was
    # given or the file list cannot be read, the run must abort without having
    # truncated the tables.
    audio_files = discover_audio_files(args.audio_dir, args.file_list, args.ext)
    logger.info("发现 %d 个音频文件", len(audio_files))

    config = CompositionConfig(dsn=args.dsn)
    service = CompositionDedupService(config=config)

    if args.clear:
        # Imported lazily so psycopg is only required when --clear is used.
        import psycopg

        with psycopg.connect(args.dsn) as conn:
            with conn.cursor() as cur:
                cur.execute("TRUNCATE TABLE composition_feature, dejavu_fingerprints")
            conn.commit()
        logger.info("已清空 composition_feature 和 dejavu_fingerprints 表")

    success_count = 0
    fail_count = 0
    # The progress bar advances once per batch; files are still ingested one
    # at a time inside each batch.
    for start in tqdm(range(0, len(audio_files), args.batch_size), desc="导入进度"):
        batch = audio_files[start:start + args.batch_size]
        for song_id, audio_path in batch:
            try:
                service.ingest(song_id=int(song_id), audio_path=audio_path)
                success_count += 1
            except Exception as e:
                # Best-effort import: record the failure and continue.
                logger.error("导入失败: song_id=%s, path=%s, error=%s", song_id, audio_path, e)
                fail_count += 1

    logger.info("导入完成: 成功 %d, 失败 %d", success_count, fail_count)
# Run the importer only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()