export_workspace_music20_embeddings_jsonl.py
2.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#!/usr/bin/env /usr/local/miniconda3/bin/python
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
# Directory two levels above this file (the parent of the directory this script sits in).
# It is prepended to sys.path so the `scripts.local_music20_acr` import resolves when
# this file is executed directly rather than as part of a package.
ROOT = Path(__file__).resolve().parents[1]
_root_entry = str(ROOT)
if _root_entry not in sys.path:
    sys.path.insert(0, _root_entry)
del _root_entry
from scripts.local_music20_acr import REFERENCE_TYPE, SUPPORTED_QUERY_TYPES, embed_chroma, first_file
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for the music20 embedding export.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which matches the previous zero-argument behaviour.

    Returns:
        Namespace with ``downloads_dir`` (str), ``song_limit`` (int),
        ``duration`` (float), ``sr`` (int) and ``out_dir`` (str).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--downloads-dir', default='/workspace/downloads')
    # Maximum number of songs (that have a reference file) to export.
    ap.add_argument('--song-limit', type=int, default=20)
    # duration and sr are forwarded unchanged to embed_chroma.
    ap.add_argument('--duration', type=float, default=8.0)
    ap.add_argument('--sr', type=int, default=22050)
    ap.add_argument('--out-dir', default='data/pgvector_eval/music20')
    return ap.parse_args(argv)
def _write_reference_rows(ref_f, downloads_dir: Path, song_limit: int, sr: int, duration: float) -> list[str]:
    """Write one JSONL reference row per song directory; return the song ids written.

    Song directories directly under ``downloads_dir`` are visited in sorted order. A
    directory contributes a row only when ``first_file`` returns a truthy value for its
    ``type_<REFERENCE_TYPE>`` subdirectory. Stops after ``song_limit`` rows.
    ``ref_f`` is an open text file; one JSON object per line is appended to it.
    """
    song_ids: list[str] = []
    for song_dir in sorted(p for p in downloads_dir.iterdir() if p.is_dir()):
        if len(song_ids) >= song_limit:
            break  # limit reached: no need to probe the remaining directories
        ref = first_file(song_dir / f'type_{REFERENCE_TYPE}')
        if not ref:
            continue
        row = {
            'song_id': song_dir.name,
            'audio_path': str(ref),
            'type': 'reference',
            'embedding': embed_chroma(str(ref), sr, duration).tolist(),
        }
        ref_f.write(json.dumps(row, ensure_ascii=False) + '\n')
        song_ids.append(song_dir.name)
    return song_ids


def _write_query_rows(qry_f, downloads_dir: Path, song_ids: list[str], sr: int, duration: float) -> int:
    """Write JSONL query rows for every supported query type; return the total row count.

    For each type in ``SUPPORTED_QUERY_TYPES`` and each song id (sorted), a row is
    written when ``first_file`` returns a truthy value for ``type_<query_type>``;
    songs without one are skipped. Prints ``query_type=<t> rows=<n>`` per type.
    """
    ordered_ids = sorted(song_ids)  # loop-invariant, so sort once
    total = 0
    for query_type in SUPPORTED_QUERY_TYPES:
        kept = 0
        for song_id in ordered_ids:
            qry = first_file(downloads_dir / song_id / f'type_{query_type}')
            if not qry:
                continue
            row = {
                'song_id': song_id,
                'audio_path': str(qry),
                'query_type': query_type,
                'embedding': embed_chroma(str(qry), sr, duration).tolist(),
            }
            qry_f.write(json.dumps(row, ensure_ascii=False) + '\n')
            kept += 1
        print(f'query_type={query_type} rows={kept}')
        total += kept
    return total


def main():
    """Export reference and query chroma embeddings as JSONL files.

    Writes ``reference_embeddings.jsonl`` and ``query_embeddings.jsonl`` into
    ``--out-dir`` (created if missing). It then prints a JSON summary with the row
    counts and the resolved output directory. Query rows are produced only for songs
    that received a reference row.

    Raises:
        SystemExit: if ``--downloads-dir`` is not an existing directory. The check runs
            before the outputs are opened, because opening them in ``'w'`` mode
            truncates any previous export.
    """
    args = parse_args()
    downloads_dir = Path(args.downloads_dir)
    if not downloads_dir.is_dir():
        raise SystemExit(f'downloads dir is not a directory: {downloads_dir}')
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    ref_path = out_dir / 'reference_embeddings.jsonl'
    qry_path = out_dir / 'query_embeddings.jsonl'
    # The with-block closes (and flushes) both files even if embedding fails mid-run.
    with ref_path.open('w', encoding='utf-8') as ref_f, qry_path.open('w', encoding='utf-8') as qry_f:
        song_ids = _write_reference_rows(ref_f, downloads_dir, args.song_limit, args.sr, args.duration)
        qry_count = _write_query_rows(qry_f, downloads_dir, song_ids, args.sr, args.duration)
    summary = {'reference_rows': len(song_ids), 'query_rows': qry_count, 'out_dir': str(out_dir.resolve())}
    print(json.dumps(summary, ensure_ascii=False, indent=2))
# Run the export only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()