normalize_business_export.py 4.11 KB
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import csv
import json
from pathlib import Path
from typing import Iterable


def load_rows(path: Path) -> list[dict]:
    """Load export rows from a CSV or JSONL file.

    The format is chosen from the file suffix (case-insensitive):

    * ``.csv``   -- parsed with ``csv.DictReader``; each row is a dict keyed
      by the header line, with every value a string.
    * ``.jsonl`` -- one JSON document per line; blank lines are skipped.

    Files are decoded as UTF-8. A leading byte-order mark, if present, is
    dropped so it cannot leak into the first CSV header name or break
    parsing of the first JSONL line.

    Raises:
        ValueError: if the suffix is neither ``.csv`` nor ``.jsonl``.
        json.JSONDecodeError: if a non-blank JSONL line is not valid JSON.
    """
    suffix = path.suffix.lower()
    if suffix == '.csv':
        # newline='' lets the csv module handle embedded newlines itself.
        with path.open(newline='', encoding='utf-8-sig') as f:
            return list(csv.DictReader(f))
    if suffix == '.jsonl':
        # Iterate the file instead of using str.splitlines(): splitlines()
        # also breaks on characters such as U+2028/U+2029/U+0085, which may
        # appear unescaped inside JSON strings and would cut a record in two.
        with path.open(encoding='utf-8-sig') as f:
            return [json.loads(line) for line in f if line.strip()]
    raise ValueError(f'unsupported input format: {path}')


def load_mapping(path: Path) -> dict[int, dict]:
    """Load the type-to-rule mapping from a JSON file.

    The file must contain an object with a ``mappings`` list; every item in
    that list must have a ``type`` entry convertible with ``int()``. The
    result maps that integer type to the full item dict.

    The file is decoded as UTF-8 regardless of the platform locale, and a
    leading byte-order mark is tolerated (``json.loads`` rejects text that
    starts with one).

    Raises:
        KeyError: if ``mappings`` or an item's ``type`` is missing.
        ValueError: if the file is not valid JSON or a ``type`` is not an
            integer value.
    """
    data = json.loads(path.read_text(encoding='utf-8-sig'))
    return {int(item['type']): item for item in data['mappings']}


def parse_bool(value):
    """Interpret *value* as a tri-state boolean.

    Real booleans and ``None`` are returned unchanged. Anything else is
    converted with ``str()``, stripped and lower-cased, then matched against
    the recognised spellings: ``true``/``1``/``yes`` give ``True`` and
    ``false``/``0``/``no`` give ``False``. Unrecognised text gives ``None``.
    """
    if value is None or isinstance(value, bool):
        return value
    spellings = {
        'true': True,
        '1': True,
        'yes': True,
        'false': False,
        '0': False,
        'no': False,
    }
    token = str(value).strip().lower()
    # dict.get yields None for anything that is not a known spelling.
    return spellings.get(token)


def parse_float(value):
    """Convert *value* to ``float``, returning ``None`` when that is impossible.

    ``None`` and the empty string are treated as "no value". Any other input
    is passed to ``float()``; if the conversion fails the result is ``None``
    rather than an exception.
    """
    if value in (None, ''):
        return None
    try:
        return float(value)
    except (TypeError, ValueError, OverflowError):
        # ValueError: unparseable text. TypeError: a non-scalar such as a
        # list or dict decoded from JSONL. OverflowError: an integer too
        # large to represent as a float.
        return None


def normalize_row(row: dict, mapping: dict[int, dict], source_dataset: str, default_split: str) -> dict:
    """Build one manifest row from a raw export row.

    The row's ``type`` is converted to ``int`` and looked up in *mapping*;
    types without an entry fall back to a rule with role ``excluded``,
    bucket ``unknown`` and ``trainable`` false. Values present (and truthy)
    on the row win over the rule / the supplied defaults for ``role``,
    ``split``, ``source_dataset`` and ``bucket``. ``trainable`` always comes
    from the rule.

    Raises:
        KeyError: if ``type``, ``song_id``, ``asset_id`` or ``audio_path`` is
            missing from the row, or if the row has no role and the matched
            rule has no ``role`` entry.
        ValueError / TypeError: if ``type``, ``sample_rate`` or ``bitrate``
            cannot be converted with ``int()``.
    """
    record = dict(row)
    type_code = int(record['type'])
    fallback_rule = {'role': 'excluded', 'default_bucket': 'unknown', 'trainable': False}
    rule = mapping.get(type_code, fallback_rule)
    field = record.get

    def optional_int(key):
        # Missing, None and '' all mean "not provided"; anything else must
        # convert cleanly or int() raises.
        raw = field(key)
        if raw is None or raw == '':
            return None
        return int(record[key])

    # Key order below is the key order of every emitted JSON object.
    return {
        'song_id': record['song_id'],
        'asset_id': record['asset_id'],
        'type': type_code,
        'role': field('role') or rule['role'],
        'split': field('split') or default_split,
        'audio_path': record['audio_path'],
        'source_dataset': field('source_dataset') or source_dataset,
        'title': field('title'),
        'artist': field('artist'),
        'album_id': field('album_id'),
        'bucket': field('bucket') or rule.get('default_bucket'),
        'offset_sec': parse_float(field('offset_sec')),
        'duration_sec': parse_float(field('duration_sec')),
        'sample_rate': optional_int('sample_rate'),
        'bitrate': optional_int('bitrate'),
        'license': field('license'),
        'is_lossless': parse_bool(field('is_lossless')),
        'trainable': bool(rule.get('trainable', False)),
    }


def emit_jsonl(rows: Iterable[dict], output: Path) -> None:
    """Write *rows* to *output* as JSON Lines, one object per line.

    Missing parent directories are created. Any existing file at *output* is
    overwritten. Non-ASCII characters are written literally rather than as
    ``\\uXXXX`` escapes, so the file is always encoded as UTF-8 instead of
    the platform's locale encoding, which might be unable to represent them.
    """
    output.parent.mkdir(parents=True, exist_ok=True)
    with output.open('w', encoding='utf-8') as f:
        for row in rows:
            f.write(json.dumps(row, ensure_ascii=False) + '\n')


def main() -> None:
    """Command-line entry point: normalize an export and print a summary.

    Parses the CLI options, loads the input rows and the type mapping,
    normalizes every row, writes the result as JSONL and finally prints a
    JSON summary (row counts, output path, distinct roles and buckets) to
    standard output.

    Relative ``--input``, ``--mapping`` and ``--output`` paths are resolved
    against the directory one level above this script's directory
    (presumably the repository root), not the current working directory.
    """
    cli = argparse.ArgumentParser(description='Normalize business CSV/JSONL export into manifest-ready JSONL rows')
    cli.add_argument('--input', required=True, help='Input CSV or JSONL export')
    cli.add_argument('--mapping', default='configs/manifests/business_type_role_mapping.json')
    cli.add_argument('--source-dataset', default='internal_catalog')
    cli.add_argument('--default-split', default='holdout')
    cli.add_argument('--output', required=True, help='Output JSONL path')
    opts = cli.parse_args()

    base_dir = Path(__file__).resolve().parents[1]

    def anchored(raw: str) -> Path:
        # Absolute paths pass through untouched; relative ones hang off base_dir.
        candidate = Path(raw)
        if candidate.is_absolute():
            return candidate
        return (base_dir / candidate).resolve()

    input_path = anchored(opts.input)
    mapping_path = anchored(opts.mapping)
    output_path = anchored(opts.output)

    source_rows = load_rows(input_path)
    type_rules = load_mapping(mapping_path)
    manifest_rows = [
        normalize_row(entry, type_rules, opts.source_dataset, opts.default_split)
        for entry in source_rows
    ]
    emit_jsonl(manifest_rows, output_path)

    report = {
        'input_rows': len(source_rows),
        'output_rows': len(manifest_rows),
        'output': str(output_path),
        'roles': sorted({entry['role'] for entry in manifest_rows}),
        # Rows with an empty/None bucket are left out of the bucket listing.
        'buckets': sorted({entry['bucket'] for entry in manifest_rows if entry.get('bucket')}),
    }
    print(json.dumps(report, ensure_ascii=False, indent=2))


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()