# minhash_lsh.py (2.06 KB)
"""Small in-memory MinHash LSH index for incremental lyric lookup."""

from __future__ import annotations

import hashlib
from collections import defaultdict
from dataclasses import dataclass


# Largest unsigned 64-bit integer (2**64 - 1). Signature slots start at this
# value before any token hash has been folded in.
_MAX_HASH = 0xFFFF_FFFF_FFFF_FFFF


@dataclass(frozen=True)
class MinHashConfig:
    """Immutable parameters for MinHash signatures and LSH banding."""

    # Number of hash functions, i.e. the length of every signature.
    num_perm: int = 96
    # Number of bands the signature is split into; must divide ``num_perm``.
    bands: int = 24
    # Offset added to each hash-function index when deriving its salt.
    seed: int = 17

    @property
    def rows_per_band(self) -> int:
        """Return how many signature values make up one band.

        Raises:
            ValueError: if ``num_perm`` or ``bands`` is not positive, or if
                ``num_perm`` is not a multiple of ``bands``.
        """
        # Reject non-positive values explicitly: bands == 0 would otherwise
        # surface as ZeroDivisionError, a negative bands would yield a negative
        # row count (and therefore no band keys at all), and num_perm == 0
        # would produce zero-width bands in which every record collides.
        if self.bands <= 0 or self.num_perm <= 0:
            raise ValueError("num_perm and bands must be positive")
        rows, remainder = divmod(self.num_perm, self.bands)
        if remainder:
            raise ValueError("num_perm must be divisible by bands")
        return rows


class MinHashLSH:
    """In-memory banded LSH index over MinHash signatures.

    Each signature is cut into ``config.bands`` consecutive slices of
    ``config.rows_per_band`` values. A record is stored under every
    ``(band index, slice)`` key, and a query returns the ids that share at
    least one such key with the queried signature.
    """

    def __init__(self, config: MinHashConfig | None = None) -> None:
        """Create an empty index; ``config`` defaults to ``MinHashConfig()``."""
        self.config = config or MinHashConfig()
        # (band index, band slice) -> ids of the records added under that key.
        self._buckets: dict[tuple[int, tuple[int, ...]], set[str]] = defaultdict(set)

    def signature(self, tokens: set[str]) -> tuple[int, ...]:
        """Return the MinHash signature of ``tokens``.

        Slot ``i`` holds the smallest 64-bit BLAKE2b digest over all tokens,
        using the salt derived from ``config.seed + i``. An empty (or falsy)
        ``tokens`` yields a signature whose slots are all ``_MAX_HASH``.
        """
        num_perm = self.config.num_perm
        # Encode each token once instead of once per hash function.
        encoded_tokens = [token.encode("utf-8") for token in tokens or ()]
        if not encoded_tokens:
            return (_MAX_HASH,) * num_perm

        blake2b = hashlib.blake2b
        return tuple(
            min(
                int.from_bytes(blake2b(data, digest_size=8, person=salt).digest(), "big")
                for data in encoded_tokens
            )
            for salt in self._person_salts()
        )

    def add(self, record_id: str, signature: tuple[int, ...]) -> None:
        """Register ``record_id`` under every band key of ``signature``.

        Raises:
            ValueError: if ``signature`` does not have ``config.num_perm`` values.
        """
        for key in self._band_keys(signature):
            self._buckets[key].add(record_id)

    def query(self, signature: tuple[int, ...]) -> set[str]:
        """Return the ids sharing at least one band key with ``signature``.

        Raises:
            ValueError: if ``signature`` does not have ``config.num_perm`` values.
        """
        candidates: set[str] = set()
        for key in self._band_keys(signature):
            # .get() rather than indexing, so a lookup never inserts an empty
            # bucket into the defaultdict.
            candidates.update(self._buckets.get(key, ()))
        return candidates

    def _person_salts(self) -> list[bytes]:
        """Build the BLAKE2b ``person`` salt for each hash function, once per call."""
        limit = hashlib.blake2b.PERSON_SIZE
        seed = self.config.seed
        salts = []
        for idx in range(self.config.num_perm):
            tag = f"lyr{seed + idx:05d}".encode("ascii")
            if len(tag) > limit:
                # Truncating would cut off the trailing digits, which are the
                # ones that differ between hash functions, making several of
                # them identical. Hash the full tag down to the limit instead.
                tag = hashlib.blake2b(tag, digest_size=limit).digest()
            salts.append(tag)
        return salts

    def _band_keys(self, signature: tuple[int, ...]) -> list[tuple[int, tuple[int, ...]]]:
        """Split ``signature`` into one hashable ``(band index, slice)`` key per band.

        Raises:
            ValueError: if ``signature`` does not have ``config.num_perm`` values.
        """
        rows = self.config.rows_per_band
        # Coerce so that list signatures still give hashable slices.
        values = tuple(signature)
        expected = self.config.num_perm
        if len(values) != expected:
            # A short signature would yield empty slices that collide with every
            # other short signature; a long one would be silently truncated.
            raise ValueError(f"signature has {len(values)} values, expected {expected}")
        return [
            (band, values[band * rows : (band + 1) * rows])
            for band in range(self.config.bands)
        ]