_job_common.py 7.67 KB
from __future__ import annotations

import json
import os
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import psycopg

# Accepted schema names: an ASCII letter or underscore, followed by ASCII
# letters, digits or underscores (a conservative subset of unquoted SQL
# identifiers, safe to interpolate without quoting).
# NOTE: in Python's `re`, `$` also matches just before a trailing newline,
# so `SCHEMA_RE.match('public\n')` succeeds; use `SCHEMA_RE.fullmatch(...)`
# when the entire string must conform.
SCHEMA_RE = re.compile(r'^[A-Za-z_][A-Za-z0-9_]*$')


@dataclass
class JobContext:
    """Everything a worker needs about one feature extraction job.

    Combines a ``feature_extraction_job`` row with its ``feature_set_registry``
    row and that feature set's ``model_registry`` row, as loaded by
    ``fetch_job_context``. JSON columns that are NULL in the database are
    represented as ``{}``.
    """

    # --- from feature_extraction_job ---
    extraction_job_id: int
    feature_set_id: int
    # Usually '<scope_type>:<scope_value>' (e.g. 'reference_set:<set_name>');
    # values without ':' are treated as scope_type 'unknown' by parse_target_scope.
    target_scope: str
    job_status: str
    shard_key: str | None
    job_metadata: dict[str, Any]  # metadata_json column; {} when NULL
    # --- from feature_set_registry ---
    feature_name: str
    feature_level: str
    extraction_granularity: str
    window_sec: float | None  # converted to float when not NULL
    hop_sec: float | None  # converted to float when not NULL
    embedding_dim: int | None
    distance_metric: str
    feature_config: dict[str, Any]  # config_json column; {} when NULL
    # --- from model_registry (joined via feature_set_registry.model_id) ---
    model_id: int
    model_name: str
    model_version: str
    model_family: str
    input_sample_rate: int | None
    output_embedding_dim: int | None
    model_metadata: dict[str, Any]  # metadata_json column; {} when NULL


def require_env(name: str, default: str | None = None) -> str:
    """Return environment variable *name*, falling back to *default*.

    An unset variable uses *default*; a variable that is set to the empty
    string is returned as-is (and therefore rejected) even when a default is
    given.

    Raises:
        SystemExit: if the resolved value is ``None`` or ``''``.
    """
    resolved = os.environ.get(name, default)
    if resolved is not None and resolved != '':
        return resolved
    raise SystemExit(f'missing required env: {name}')


def validate_schema(schema: str) -> str:
    """Return *schema* unchanged if it is a safe, unquoted schema identifier.

    The name is later interpolated directly into SQL (``SET search_path``),
    so the whole string must match ``SCHEMA_RE``: an ASCII letter or
    underscore followed by ASCII letters, digits or underscores.

    Raises:
        SystemExit: if *schema* does not match.
    """
    # fullmatch, not match: with a `$`-anchored pattern, `match` also accepts
    # a trailing newline (e.g. 'public\n'), which would leak into the SQL text.
    if SCHEMA_RE.fullmatch(schema) is None:
        raise SystemExit(f'invalid schema name: {schema}')
    return schema


def ensure_output_parent(path: str | None) -> Path | None:
    """Create the parent directory of *path* and return it as a ``Path``.

    Returns ``None`` (and touches nothing) when *path* is ``None`` or empty.
    Missing intermediate directories are created; existing ones are fine.
    The file itself is not created.
    """
    if path:
        target = Path(path)
        target.parent.mkdir(parents=True, exist_ok=True)
        return target
    return None


def connect(dsn: str, schema: str, *, autocommit: bool = True) -> psycopg.Connection:
    """Open a psycopg connection whose search_path is ``<schema>, public``.

    Args:
        dsn: connection string passed to ``psycopg.connect``.
        schema: schema to put first on the search_path; validated with
            ``validate_schema`` before it is interpolated into SQL.
        autocommit: passed through to ``psycopg.connect``.

    Returns:
        The open connection. The caller owns it and is responsible for
        closing it.

    Raises:
        SystemExit: if *schema* is not a valid identifier (checked before any
            connection is opened).
        psycopg.Error: if connecting or setting the search_path fails; the
            connection is closed before the error propagates.
    """
    # Validate first so an invalid name never opens (and leaks) a connection.
    safe_schema = validate_schema(schema)
    conn = psycopg.connect(dsn, autocommit=autocommit)
    try:
        conn.execute(f'SET search_path TO {safe_schema}, public;')
        if not autocommit:
            # Outside autocommit the SET runs inside an implicit transaction;
            # a later rollback by the caller would silently undo it. Commit so
            # the search_path persists for the whole session.
            conn.commit()
    except BaseException:
        conn.close()
        raise
    return conn


def fetch_job_context(conn: psycopg.Connection, extraction_job_id: int) -> JobContext:
    """Load one extraction job together with its feature set and model.

    Joins ``feature_extraction_job`` -> ``feature_set_registry`` ->
    ``model_registry`` (inner joins) for *extraction_job_id*. Numeric columns
    are coerced to ``int``/``float``, with NULLs kept as ``None``. NULL JSON
    columns become ``{}``.

    Raises:
        SystemExit: if no row comes back (job missing, or its feature set or
            model row missing because of the inner joins).
    """
    record = conn.execute(
        """
        SELECT
            fej.extraction_job_id,
            fej.feature_set_id,
            fej.target_scope,
            fej.job_status,
            fej.shard_key,
            fej.metadata_json,
            fs.feature_name,
            fs.feature_level,
            fs.extraction_granularity,
            fs.window_sec,
            fs.hop_sec,
            fs.embedding_dim,
            fs.distance_metric,
            fs.config_json,
            mr.model_id,
            mr.model_name,
            mr.model_version,
            mr.model_family,
            mr.input_sample_rate,
            mr.output_embedding_dim,
            mr.metadata_json
        FROM feature_extraction_job fej
        JOIN feature_set_registry fs ON fs.feature_set_id = fej.feature_set_id
        JOIN model_registry mr ON mr.model_id = fs.model_id
        WHERE fej.extraction_job_id = %s
        LIMIT 1;
        """,
        (extraction_job_id,),
    ).fetchone()
    if not record:
        raise SystemExit(f'feature_extraction_job not found: {extraction_job_id}')

    def _optional(cast, value):
        # Apply *cast* unless the column was NULL.
        return None if value is None else cast(value)

    # Column order mirrors the SELECT list above (21 columns).
    (
        job_id, set_id, scope, status, shard, job_meta,
        feat_name, feat_level, granularity, window, hop, emb_dim, metric, feat_cfg,
        mdl_id, mdl_name, mdl_version, mdl_family, sample_rate, out_dim, mdl_meta,
    ) = record

    return JobContext(
        extraction_job_id=int(job_id),
        feature_set_id=int(set_id),
        target_scope=scope,
        job_status=status,
        shard_key=shard,
        job_metadata=job_meta or {},
        feature_name=feat_name,
        feature_level=feat_level,
        extraction_granularity=granularity,
        window_sec=_optional(float, window),
        hop_sec=_optional(float, hop),
        embedding_dim=_optional(int, emb_dim),
        distance_metric=metric,
        feature_config=feat_cfg or {},
        model_id=int(mdl_id),
        model_name=mdl_name,
        model_version=mdl_version,
        model_family=mdl_family,
        input_sample_rate=_optional(int, sample_rate),
        output_embedding_dim=_optional(int, out_dim),
        model_metadata=mdl_meta or {},
    )


def parse_target_scope(target_scope: str) -> tuple[str, str]:
    """Split a ``'<scope_type>:<scope_value>'`` string at its first colon.

    Only the first colon separates, so the value may itself contain colons.
    A string with no colon yields ``('unknown', target_scope)``.
    """
    head, separator, tail = target_scope.partition(':')
    if not separator:
        return 'unknown', target_scope
    return head, tail


def resolve_scope_summary(conn: psycopg.Connection, target_scope: str) -> dict[str, Any]:
    """Summarize what a job's ``target_scope`` covers.

    For ``reference_set:<set_name>`` scopes, looks the set up by ``set_name``
    and counts:

    * its distinct member recordings,
    * distinct assets of those recordings with ``ingest_status = 'ready'``,
    * distinct audio windows of those recordings with ``active_for_index``.

    Any other scope type is not queried and reports zero counts.

    Returns:
        A dict with ``scope_type``, ``scope_value`` and the three counts. For
        reference sets it also has ``reference_set_id`` and
        ``reference_set_name``.

    Raises:
        SystemExit: if a ``reference_set`` scope names no registered set.
    """
    scope_type, scope_value = parse_target_scope(target_scope)
    if scope_type != 'reference_set':
        return {
            'scope_type': scope_type,
            'scope_value': scope_value,
            'recording_count': 0,
            'ready_asset_count': 0,
            'active_window_count': 0,
        }

    # One correlated subquery per count. LEFT JOINing recording_asset and
    # audio_window side by side on the same recording would multiply rows
    # (assets x windows per recording) before DISTINCT collapsed them; these
    # subqueries give the same counts without that fan-out.
    row = conn.execute(
        """
        SELECT
            rs.reference_set_id,
            rs.set_name,
            (
                SELECT count(DISTINCT rsm.recording_id)
                FROM reference_set_member rsm
                WHERE rsm.reference_set_id = rs.reference_set_id
            ) AS recording_count,
            (
                SELECT count(DISTINCT ra.asset_id)
                FROM reference_set_member rsm
                JOIN recording_asset ra ON ra.recording_id = rsm.recording_id
                WHERE rsm.reference_set_id = rs.reference_set_id
                  AND ra.ingest_status = 'ready'
            ) AS ready_asset_count,
            (
                SELECT count(DISTINCT aw.window_id)
                FROM reference_set_member rsm
                JOIN audio_window aw ON aw.recording_id = rsm.recording_id
                WHERE rsm.reference_set_id = rs.reference_set_id
                  AND aw.active_for_index
            ) AS active_window_count
        FROM reference_set_registry rs
        WHERE rs.set_name = %s
        LIMIT 1;
        """,
        (scope_value,),
    ).fetchone()
    if not row:
        raise SystemExit(f'reference set not found for target_scope={target_scope}')
    return {
        'scope_type': scope_type,
        'scope_value': scope_value,
        'reference_set_id': int(row[0]),
        'reference_set_name': row[1],
        'recording_count': int(row[2]),
        'ready_asset_count': int(row[3]),
        'active_window_count': int(row[4]),
    }


def update_job_status(
    conn: psycopg.Connection,
    extraction_job_id: int,
    *,
    status: str,
    input_count: int | None = None,
    output_count: int | None = None,
    log_uri: str | None = None,
    metadata_patch: dict[str, Any] | None = None,
    set_started_at: bool = False,
    set_finished_at: bool = False,
) -> dict[str, Any]:
    """Update one ``feature_extraction_job`` row and return its new state.

    * ``job_status`` is always set to *status*.
    * ``input_count``, ``output_count`` and ``log_uri`` change only when the
      corresponding argument is not ``None``.
    * ``set_started_at`` stamps ``started_at`` with ``NOW()`` only if it is
      still NULL. ``set_finished_at`` overwrites ``finished_at`` with ``NOW()``.
    * *metadata_patch* is serialized to JSON and merged into ``metadata_json``
      with jsonb ``||`` (a NULL column is treated as ``{}``).

    Returns:
        The updated row as a dict; timestamps are ISO-8601 strings or ``None``.

    Raises:
        SystemExit: if no row has *extraction_job_id*.
    """
    params = (
        status,
        input_count,
        output_count,
        log_uri,
        set_started_at,
        set_finished_at,
        json.dumps(metadata_patch or {}, ensure_ascii=False),
        extraction_job_id,
    )
    updated = conn.execute(
        """
        UPDATE feature_extraction_job
        SET job_status = %s,
            input_count = COALESCE(%s, input_count),
            output_count = COALESCE(%s, output_count),
            log_uri = COALESCE(%s, log_uri),
            started_at = CASE
                WHEN %s THEN COALESCE(started_at, NOW())
                ELSE started_at
            END,
            finished_at = CASE
                WHEN %s THEN NOW()
                ELSE finished_at
            END,
            metadata_json = COALESCE(metadata_json, '{}'::jsonb) || %s::jsonb
        WHERE extraction_job_id = %s
        RETURNING extraction_job_id, job_status, input_count, output_count, started_at, finished_at, log_uri, metadata_json;
        """,
        params,
    ).fetchone()
    if not updated:
        raise SystemExit(f'failed to update feature_extraction_job={extraction_job_id}')

    def _as_int(value):
        return None if value is None else int(value)

    def _as_iso(moment):
        return None if moment is None else moment.isoformat()

    # Order mirrors the RETURNING list above.
    job_id, job_status, n_in, n_out, started, finished, log_location, metadata = updated
    return {
        'extraction_job_id': int(job_id),
        'job_status': job_status,
        'input_count': _as_int(n_in),
        'output_count': _as_int(n_out),
        'started_at': _as_iso(started),
        'finished_at': _as_iso(finished),
        'log_uri': log_location,
        'metadata_json': metadata or {},
    }


def emit_payload(payload: dict[str, Any], output: str | None) -> None:
    """Pretty-print *payload* as JSON on stdout, optionally saving a copy.

    The JSON is rendered once (2-space indent, non-ASCII characters kept
    as-is). When *output* is a non-empty path, that exact text is first
    written there as UTF-8, with parent directories created as needed.
    """
    rendered = json.dumps(payload, ensure_ascii=False, indent=2)
    # ensure_output_parent returns None for a missing/empty path, so this
    # doubles as the "was an output file requested?" check.
    destination = ensure_output_parent(output)
    if destination is not None:
        destination.write_text(rendered, encoding='utf-8')
    print(rendered)