# enrich_songcentric_manifest_with_local_features.py (7.52 KB)
#!/usr/bin/env /usr/local/miniconda3/bin/python
from __future__ import annotations

import argparse
import hashlib
import importlib
import json
import wave
from pathlib import Path

# Make the parent of this script's directory importable (prepended to sys.path
# if absent) so the ``src.`` package import below resolves when the script is
# executed directly.
import sys

ROOT = Path(__file__).resolve().parents[1]
_root_entry = str(ROOT)
if _root_entry not in sys.path:
    sys.path.insert(0, _root_entry)

from src.engines.chromaprint_matcher import ChromaprintMatcher, load_audio_mono


def load_jsonl(path: Path):
    """Yield one parsed JSON value per non-blank line of a UTF-8 JSON Lines file.

    Records are split only on real line endings (``\\n``, ``\\r\\n``, ``\\r``).
    ``str.splitlines`` would also split on U+0085/U+2028/U+2029. Those
    characters are left unescaped inside strings by
    ``json.dumps(..., ensure_ascii=False)``, which is how ``main`` writes its
    manifest, so a record containing one would be torn in half.

    The file is opened when iteration starts and read line by line, in UTF-8
    regardless of the platform's locale encoding.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    with path.open('r', encoding='utf-8') as fh:
        for raw_line in fh:
            line = raw_line.strip()
            if line:
                yield json.loads(line)


def module_available(name: str) -> bool:
    """Report whether the module ``name`` can actually be imported.

    The module is really imported (its top-level code runs), and *any*
    ``Exception`` raised while doing so — not only ``ImportError`` — is treated
    as "unavailable".
    """
    try:
        importlib.import_module(name)
    except Exception:
        return False
    else:
        return True


def semantic_runtime_available() -> tuple[bool, list[str]]:
    """Probe whether the optional semantic-embedding stack is importable.

    Returns:
        ``(ok, missing)`` where ``missing`` lists, in probe order, which of
        ``torch``, ``torchaudio`` and ``transformers`` failed to import, and
        ``ok`` is True only when nothing is missing.
    """
    missing: list[str] = []
    for module_name in ('torch', 'torchaudio', 'transformers'):
        if not module_available(module_name):
            missing.append(module_name)
    return not missing, missing


def read_wav_stats(path: Path, start_ms: int, end_ms: int, *, energy_sample_limit: int = 4000) -> dict:
    """Read one time window of a PCM WAV file and summarise its raw bytes.

    ``start_ms``/``end_ms`` are converted to frame indices at the file's own
    sample rate. The start position is clamped to the end of the file, and a
    read that runs past EOF simply returns fewer bytes.

    Args:
        path: WAV file readable by the stdlib ``wave`` module.
        start_ms: Window start in milliseconds.
        end_ms: Window end in milliseconds; ``end_ms <= start_ms`` reads nothing.
        energy_sample_limit: Number of leading (channel-interleaved) samples of
            the window that contribute to ``energy``.

    Returns:
        A dict with:

        - ``digest``: SHA-256 hex of the window's raw frame bytes.
        - ``energy``: sum of absolute sample amplitudes over the first
          ``energy_sample_limit`` samples. 8-bit WAV is treated as unsigned and
          centred on 128; wider widths are treated as signed little-endian.
        - ``rate``, ``channels`` and ``bytes_read``.

    Raises:
        wave.Error: e.g. for a file the ``wave`` module cannot parse, or a
            negative ``start_ms`` (``setpos`` rejects negative positions).
    """
    with wave.open(str(path), 'rb') as wf:
        rate = wf.getframerate()
        sampwidth = wf.getsampwidth()
        n_channels = wf.getnchannels()
        start_frame = int(start_ms * rate / 1000)
        end_frame = int(end_ms * rate / 1000)
        wf.setpos(min(start_frame, wf.getnframes()))
        frames = wf.readframes(max(end_frame - start_frame, 0))
    digest = hashlib.sha256(frames).hexdigest()

    head = frames[: max(energy_sample_limit, 0) * sampwidth]
    if sampwidth == 1:
        # 8-bit PCM is unsigned with silence at 128.
        energy = sum(abs(b - 128) for b in head)
    else:
        # Step by the file's real sample width so 24/32-bit samples are decoded
        # whole rather than as misaligned 2-byte pairs.
        energy = sum(
            abs(int.from_bytes(head[i:i + sampwidth], 'little', signed=True))
            for i in range(0, len(head) - sampwidth + 1, sampwidth)
        )
    return {'digest': digest, 'energy': energy, 'rate': rate, 'channels': n_channels, 'bytes_read': len(frames)}


def extract_matcher_fingerprint(path: Path, start_ms: int, end_ms: int) -> dict | None:
    """Fingerprint one time window with the project's Chromaprint matcher.

    The audio is loaded as mono at the matcher's sample rate (the matcher is
    constructed with ``sr=16000``). It is then sliced to the sample range of
    ``[start_ms, end_ms)`` and hashed with ``extract_hashes``. The fingerprint
    is derived from a SHA-256 of the JSON-encoded first 128 hashes.

    NOTE(review): ``load_audio_mono`` is called with the whole file path on
    every call, so presumably the full file is decoded once per window — confirm
    whether that cost matters for long assets.

    Returns:
        A dict with ``fingerprint_value``, ``checksum`` and ``metadata_json``,
        or ``None`` if anything raises. This is a best-effort path; the caller
        falls back to a raw-byte fingerprint.
    """
    try:
        matcher = ChromaprintMatcher(sr=16000)
        sr = matcher.sr
        audio, _ = load_audio_mono(str(path), sr=sr)
        first, last = (int(ms * sr / 1000) for ms in (start_ms, end_ms))
        hashes = matcher.extract_hashes(audio[first:last])
        sha = hashlib.sha256(json.dumps(hashes[:128]).encode('utf-8')).hexdigest()
        result = {
            'fingerprint_value': sha[:32],
            'checksum': f'chromaprint:{sha[:16]}',
            'metadata_json': {'hash_count': len(hashes), 'hash_sample': hashes[:8]},
        }
    except Exception:
        return None
    return result


def build_semantic_feature(stats: dict, start_ms: int, end_ms: int, runtime_ok: bool, missing: list[str]) -> dict:
    """Describe the embedding feature record for one window.

    No embedding vector is computed here. When ``runtime_ok`` is true, a
    "runtime ready" placeholder record is returned. Otherwise a local
    wave-hash record is returned, carrying ``energy``/``rate``/``channels``
    from ``stats`` and the ``missing`` module list (stored by reference).

    Args:
        stats: Output of ``read_wav_stats``. Only ``digest`` is read when
            ``runtime_ok`` is true.
        start_ms: Window start, embedded in ``embedding_uri``.
        end_ms: Window end, embedded in ``embedding_uri``.
        runtime_ok: Whether the semantic runtime imports succeeded.
        missing: Names of modules that failed to import.
    """
    digest_prefix = stats['digest'][:16]
    window_tag = f'{digest_prefix}:{start_ms}:{end_ms}'

    if runtime_ok:
        model_name = 'semantic_runtime_ready_placeholder'
        model_version = 'awaiting_real_adapter'
        feature_set_name = 'semantic_runtime_ready_5s'
        embedding_uri = f'runtime-ready://{window_tag}'
        metadata = {'semantic_backend': 'runtime_ready_placeholder'}
    else:
        model_name = 'local_wavehash_embed'
        model_version = 'v1'
        feature_set_name = 'wavehash_embed_5s'
        embedding_uri = f'inline://{window_tag}'
        metadata = {
            'energy': stats['energy'],
            'rate': stats['rate'],
            'channels': stats['channels'],
            'semantic_backend': 'local_fallback',
            'runtime_missing': missing,
        }

    return {
        'feature_type': 'embedding',
        'model_name': model_name,
        'model_version': model_version,
        'feature_set_name': feature_set_name,
        'feature_schema_ver': 'v1',
        'embedding_dim': 8,
        'embedding_uri': embedding_uri,
        'vector_table_name': 'audio_embedding_vector_8_placeholder',
        'checksum': f'emb:{digest_prefix}',
        'metadata_json': metadata,
    }


def _fingerprint_feature(asset_path: Path, start_ms: int, end_ms: int, stats: dict) -> tuple[dict, bool]:
    """Build the fingerprint feature for one WAV window.

    Prefers the Chromaprint matcher. When ``extract_matcher_fingerprint``
    returns ``None``, falls back to the raw-PCM SHA-256 digest in ``stats``
    (output of ``read_wav_stats``).

    Returns:
        ``(feature, used_matcher)``, where ``used_matcher`` is True when the
        matcher fingerprint was used.
    """
    matcher_fp = extract_matcher_fingerprint(asset_path, start_ms, end_ms)
    if matcher_fp is not None:
        return {
            'feature_type': 'fingerprint',
            'model_name': 'chromaprint_matcher',
            'model_version': 'phase1_local',
            'feature_set_name': 'chromaprint_matcher_5s',
            'fingerprint_value': matcher_fp['fingerprint_value'],
            'checksum': matcher_fp['checksum'],
            'metadata_json': matcher_fp['metadata_json'],
        }, True
    return {
        'feature_type': 'fingerprint',
        'model_name': 'local_wavehash',
        'model_version': 'v1',
        'feature_set_name': 'wavehash_5s',
        'fingerprint_value': stats['digest'][:32],
        'checksum': f"fp:{stats['digest'][:16]}",
        'metadata_json': {'energy': stats['energy'], 'bytes_read': stats['bytes_read']},
    }, False


def main() -> int:
    """Enrich a song-centric JSONL manifest with locally computed window features.

    Every window of every row whose ``asset.storage_uri`` points at an existing
    ``.wav`` file gets two entries appended to ``window['features']``:

    - a fingerprint (Chromaprint matcher, or a raw-PCM SHA-256 fallback), and
    - an embedding descriptor from ``build_semantic_feature``.

    Every window gains a ``features`` list (empty if absent), even when its
    asset is skipped. The enriched manifest is written as UTF-8 JSON Lines. A
    summary report is printed to stdout and, with ``--report-output``, also
    written to that file.

    NOTE: features are appended, not replaced. Running this over an
    already-enriched manifest duplicates them.

    Returns:
        Process exit status: always 0. Errors propagate as exceptions.
    """
    parser = argparse.ArgumentParser(
        description='Add local fingerprint/embedding features to WAV windows of a JSONL manifest.'
    )
    parser.add_argument('--input-manifest', required=True, help='JSONL manifest to read.')
    parser.add_argument('--output-manifest', required=True, help='Where to write the enriched JSONL manifest.')
    parser.add_argument('--report-output', help='Optional path for the JSON summary report.')
    args = parser.parse_args()

    in_path = Path(args.input_manifest).resolve()
    out_path = Path(args.output_manifest).resolve()
    out_path.parent.mkdir(parents=True, exist_ok=True)
    report_path = Path(args.report_output).resolve() if args.report_output else None
    if report_path is not None:
        report_path.parent.mkdir(parents=True, exist_ok=True)

    runtime_ok, missing_runtime = semantic_runtime_available()

    rows = []
    feature_count = 0
    wav_windows_seen = 0
    matcher_fp_count = 0
    fallback_fp_count = 0
    semantic_runtime_ready_count = 0
    semantic_fallback_count = 0

    for row in load_jsonl(in_path):
        asset_path = Path(row['asset']['storage_uri'])
        # Only existing local WAV files can be read; the check is per asset,
        # not per window.
        is_local_wav = asset_path.suffix.lower() == '.wav' and asset_path.exists()
        for window in row.get('windows', []):
            features = window.setdefault('features', [])
            if not is_local_wav:
                continue
            start_ms, end_ms = window['start_ms'], window['end_ms']
            wav_windows_seen += 1
            stats = read_wav_stats(asset_path, start_ms, end_ms)

            fp, used_matcher = _fingerprint_feature(asset_path, start_ms, end_ms, stats)
            if used_matcher:
                matcher_fp_count += 1
            else:
                fallback_fp_count += 1

            emb = build_semantic_feature(stats, start_ms, end_ms, runtime_ok, missing_runtime)
            if runtime_ok:
                semantic_runtime_ready_count += 1
            else:
                semantic_fallback_count += 1

            features.extend([fp, emb])
            feature_count += 2
        rows.append(row)

    # Explicit UTF-8: ensure_ascii=False emits raw non-ASCII text, which the
    # platform default encoding (e.g. cp1252, ASCII) may be unable to encode.
    out_path.write_text(
        ''.join(json.dumps(r, ensure_ascii=False) + '\n' for r in rows),
        encoding='utf-8',
    )
    report = {
        'input_manifest': str(in_path),
        'output_manifest': str(out_path),
        'rows': len(rows),
        'wav_windows_seen': wav_windows_seen,
        'features_added': feature_count,
        'matcher_fingerprint_count': matcher_fp_count,
        'fallback_fingerprint_count': fallback_fp_count,
        'semantic_runtime_available': runtime_ok,
        'semantic_runtime_missing': missing_runtime,
        'semantic_runtime_ready_count': semantic_runtime_ready_count,
        'semantic_fallback_count': semantic_fallback_count,
    }
    if report_path is not None:
        report_path.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding='utf-8')
    print(json.dumps(report, ensure_ascii=False, indent=2))
    return 0


if __name__ == '__main__':
    # Use main()'s return value as the process exit status.
    sys.exit(main())