smoke_songcentric_schema_live.py
5.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
#!/usr/bin/env /usr/local/miniconda3/bin/python
from __future__ import annotations
import argparse
import json
from pathlib import Path
import psycopg
from psycopg.rows import dict_row
def quote_ident(name: str) -> str:
    """Return *name* as a double-quoted SQL identifier.

    Every double quote inside *name* is doubled, then the result is wrapped
    in a pair of double quotes.
    """
    escaped = '""'.join(name.split('"'))
    return f'"{escaped}"'
def main() -> int:
    """Run a live smoke test of the song-centric schema and write a JSON report.

    Command-line options:
        --dsn     connection string passed to ``psycopg.connect`` (required).
        --schema  schema that is dropped and recreated for the test.
        --sql     schema SQL file, resolved against ``/workspace``.
        --output  JSON report path, resolved against ``/workspace``.

    On a single connection the function:

    1. drops the target schema (``cascade``), recreates it and puts it first
       on ``search_path``;
    2. executes the schema SQL file;
    3. inserts one song, one asset, one window, one fingerprint feature, one
       embedding feature and one reference-set membership;
    4. reads the embedding's lineage back through a
       feature -> window -> asset -> song join;
    5. counts the rows of the four tables it touched.

    The inserted ids, the counts and the lineage row are written to the
    output file as UTF-8 JSON and echoed to stdout.

    Returns:
        ``0`` on success. Any failure propagates as an exception.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dsn', required=True)
    parser.add_argument('--schema', default='acr_songcentric_test')
    parser.add_argument('--sql', default='acr-engine/sql/acr_pg_schema_songcentric_v1.sql')
    parser.add_argument('--output', default='acr-engine/data/pgvector_eval/music20/songcentric_schema_smoke_report.json')
    args = parser.parse_args()

    workspace = Path('/workspace')
    sql_path = workspace / args.sql
    output_path = workspace / args.output
    output_path.parent.mkdir(parents=True, exist_ok=True)

    schema = args.schema
    qschema = quote_ident(schema)

    # An absolute --sql outside the workspace cannot be made relative; report
    # the full path in that case instead of failing with ValueError.
    try:
        reported_sql_path = str(sql_path.relative_to(workspace))
    except ValueError:
        reported_sql_path = str(sql_path)
    report: dict = {'schema': schema, 'sql_path': reported_sql_path}

    with psycopg.connect(args.dsn, row_factory=dict_row) as conn:
        # Destructive by design: start from an empty schema on every run.
        conn.execute(f'drop schema if exists {qschema} cascade')
        conn.execute(f'create schema {qschema}')
        conn.execute(f'set search_path to {qschema}, public')
        # Read the schema file as UTF-8 explicitly rather than relying on the
        # locale's default encoding.
        conn.execute(sql_path.read_text(encoding='utf-8'))

        song_id = conn.execute(
            """
            insert into media_entity (entity_type, biz_key, title, artist_name)
            values ('song', 'song-9001', 'Smoke Song', 'Smoke Artist')
            returning entity_id
            """
        ).fetchone()['entity_id']

        asset_id = conn.execute(
            """
            insert into audio_object (
                object_type, song_id, source_type, storage_uri, storage_scheme,
                checksum, codec, sample_rate, channels, duration_ms
            ) values (
                'asset', %s, 'official', 's3://bucket/smoke-song.wav', 's3',
                'sha256:smoke-asset', 'wav', 44100, 2, 180000
            ) returning object_id
            """,
            (song_id,),
        ).fetchone()['object_id']

        # The window is a child of the asset via parent_object_id.
        window_id = conn.execute(
            """
            insert into audio_object (
                object_type, song_id, parent_object_id, start_ms, end_ms, duration_ms
            ) values ('window', %s, %s, 30000, 35000, 5000)
            returning object_id
            """,
            (song_id, asset_id),
        ).fetchone()['object_id']

        fingerprint_id = conn.execute(
            """
            insert into feature_fact (
                feature_type, object_id, song_id, model_name, model_version,
                feature_set_name, fingerprint_value
            ) values (
                'fingerprint', %s, %s, 'chromaprint', 'phase1', 'chromaprint_5s', 'fp-smoke'
            ) returning feature_id
            """,
            (window_id, song_id),
        ).fetchone()['feature_id']

        embedding_id = conn.execute(
            """
            insert into feature_fact (
                feature_type, object_id, song_id, model_name, model_version,
                feature_set_name, embedding_dim, embedding_uri, vector_table_name
            ) values (
                'embedding', %s, %s, 'mert', 'v1-95m',
                'mert_5s_hop2.5_meanpool', 768, 's3://bucket/smoke-song-win.npy', 'audio_embedding_vector_768'
            ) returning feature_id
            """,
            (window_id, song_id),
        ).fetchone()['feature_id']

        membership_id = conn.execute(
            """
            insert into set_membership (
                set_type, set_name, member_type, member_id, song_id, priority
            ) values (
                'reference_set', 'phase1_hot_reference_v1', 'asset', %s, %s, 100
            ) returning membership_id
            """,
            (asset_id, song_id),
        ).fetchone()['membership_id']

        # Walk the embedding feature back to its window, asset and song.
        # fetchone() yields None when the join matches nothing, which ends up
        # as null in the report.
        lineage = conn.execute(
            """
            select ff.feature_id,
                   ff.feature_type,
                   ff.model_name,
                   ff.model_version,
                   ff.feature_set_name,
                   win.object_id as window_id,
                   ast.object_id as asset_id,
                   song.entity_id as song_id,
                   song.title,
                   song.artist_name
            from feature_fact ff
            join audio_object win
              on win.object_id = ff.object_id
             and win.object_type = 'window'
            join audio_object ast
              on ast.object_id = win.parent_object_id
             and ast.object_type = 'asset'
            join media_entity song
              on song.entity_id = ff.song_id
             and song.entity_type = 'song'
            where ff.feature_id = %s
            """,
            (embedding_id,),
        ).fetchone()

        # Table names come from this fixed tuple, never from user input.
        counts = {
            table: conn.execute(f'select count(*) as c from {table}').fetchone()['c']
            for table in ('media_entity', 'audio_object', 'feature_fact', 'set_membership')
        }

        report.update(
            inserted={
                'song_id': song_id,
                'asset_id': asset_id,
                'window_id': window_id,
                'fingerprint_feature_id': fingerprint_id,
                'embedding_feature_id': embedding_id,
                'membership_id': membership_id,
            },
            counts=counts,
            embedding_lineage=lineage,
        )
        conn.commit()

    # Serialise once. ensure_ascii=False may emit non-ASCII characters, so the
    # file is written as UTF-8 explicitly instead of in the locale encoding.
    rendered = json.dumps(report, ensure_ascii=False, indent=2)
    output_path.write_text(rendered, encoding='utf-8')
    print(rendered)
    return 0
# Script entry point: run the smoke test and use its return value as the
# process exit status.
if __name__ == '__main__':
    exit_status = main()
    raise SystemExit(exit_status)