minhash_lsh.py
2.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
"""Small in-memory MinHash LSH index for incremental lyric lookup."""
from __future__ import annotations
import hashlib
from collections import defaultdict
from dataclasses import dataclass
# Largest unsigned 64-bit value. No 8-byte digest read as an unsigned integer
# can exceed it, so it serves as the "nothing seen yet" starting point for
# each MinHash slot.
_MAX_HASH = 2**64 - 1
@dataclass(frozen=True)
class MinHashConfig:
    """Immutable hashing and banding parameters for ``MinHashLSH``.

    Validation is deferred to ``rows_per_band`` so that constructing a config
    never raises; any code path that needs the banding layout will.
    """

    # Number of MinHash values in each signature.
    num_perm: int = 96
    # Number of LSH bands the signature is split into.
    bands: int = 24
    # Offset added to each signature position index when building the
    # blake2b personalisation string for that position.
    seed: int = 17

    @property
    def rows_per_band(self) -> int:
        """Return how many signature values fall in each band.

        Returns:
            ``num_perm // bands``.

        Raises:
            ValueError: if ``bands`` or ``num_perm`` is not positive, or if
                ``num_perm`` is not an exact multiple of ``bands``.
        """
        # Check bands first: a zero value would otherwise surface as a
        # ZeroDivisionError from the modulo below instead of a ValueError.
        if self.bands <= 0:
            raise ValueError("bands must be a positive integer")
        # Zero or negative num_perm would give empty or negative-width bands.
        if self.num_perm <= 0:
            raise ValueError("num_perm must be a positive integer")
        if self.num_perm % self.bands != 0:
            raise ValueError("num_perm must be divisible by bands")
        return self.num_perm // self.bands
class MinHashLSH:
    """In-memory MinHash signature builder plus a banded LSH bucket index.

    A signature of ``config.num_perm`` values is cut into ``config.bands``
    consecutive bands of ``config.rows_per_band`` values. Two signatures
    become query candidates for each other when at least one band (at the
    same band position) matches exactly.
    """

    def __init__(self, config: MinHashConfig | None = None) -> None:
        """Create an empty index.

        Args:
            config: Hashing/banding parameters. ``MinHashConfig()`` defaults
                are used when omitted.
        """
        self.config = config if config is not None else MinHashConfig()
        # (band index, band values) -> ids of records added with exactly
        # those values in that band.
        self._buckets: dict[tuple[int, tuple[int, ...]], set[str]] = defaultdict(set)

    def signature(self, tokens: set[str]) -> tuple[int, ...]:
        """Return the MinHash signature of ``tokens``.

        Position ``i`` holds the minimum, over all tokens, of the big-endian
        integer value of an 8-byte blake2b digest of the UTF-8 token,
        personalised with ``"lyr{seed + i:05d}"``. An empty token collection
        yields ``_MAX_HASH`` in every position.

        Args:
            tokens: Token strings to hash. Any iterable of ``str`` works;
                duplicates do not affect the result.

        Returns:
            A tuple of ``config.num_perm`` integers.
        """
        num_perm = self.config.num_perm
        # The per-position salts depend only on the config, so build them once
        # rather than once per token. blake2b accepts at most 16 bytes of
        # personalisation, hence the slice.
        persons = [
            f"lyr{self.config.seed + idx:05d}".encode("ascii")[:16]
            for idx in range(num_perm)
        ]
        signature = [_MAX_HASH] * num_perm
        for token in tokens:
            encoded = token.encode("utf-8")
            for idx, person in enumerate(persons):
                digest = hashlib.blake2b(encoded, digest_size=8, person=person).digest()
                value = int.from_bytes(digest, "big")
                if value < signature[idx]:
                    signature[idx] = value
        return tuple(signature)

    def add(self, record_id: str, signature: tuple[int, ...]) -> None:
        """Index ``record_id`` under every band of ``signature``.

        Adding the same id again with another signature adds more bucket
        entries; earlier entries are not removed.

        Raises:
            ValueError: if ``signature`` does not have ``config.num_perm``
                values or the config's banding is invalid. Nothing is
                indexed in that case.
        """
        for key in self._band_keys(signature):
            self._buckets[key].add(record_id)

    def query(self, signature: tuple[int, ...]) -> set[str]:
        """Return ids of records sharing at least one band with ``signature``.

        The index is not modified.

        Raises:
            ValueError: if ``signature`` does not have ``config.num_perm``
                values or the config's banding is invalid.
        """
        candidates: set[str] = set()
        for key in self._band_keys(signature):
            # .get() on the defaultdict avoids creating empty buckets on a miss.
            bucket = self._buckets.get(key)
            if bucket:
                candidates |= bucket
        return candidates

    def _band_keys(self, signature: tuple[int, ...]) -> list[tuple[int, tuple[int, ...]]]:
        """Split ``signature`` into ``(band index, band values)`` bucket keys.

        The band index is part of the key so identical values in different
        band positions do not collide. Band values are converted to tuples so
        any integer sequence (e.g. a list) produces hashable keys.

        Raises:
            ValueError: if the config's banding is invalid or ``signature``
                has a length other than ``config.num_perm``. Without this
                check a short signature would silently yield truncated or
                empty bands that collide with unrelated records.
        """
        rows = self.config.rows_per_band
        expected = self.config.num_perm
        if len(signature) != expected:
            raise ValueError(
                f"signature has {len(signature)} values; expected {expected}"
            )
        return [
            (band, tuple(signature[band * rows : (band + 1) * rows]))
            for band in range(self.config.bands)
        ]