@@ -2,11 +2,12 @@ import json
 import os
 from dataclasses import astuple, dataclass, asdict
 from io import UnsupportedOperation
+from logging import getLogger
 from mmap import mmap, PROT_READ, PROT_WRITE
 from typing import TypeVar, Generic, Callable, List
 
 import mmh3
-from zstandard import ZstdDecompressor, ZstdCompressor
+from zstandard import ZstdDecompressor, ZstdCompressor, ZstdError
 
 VERSION = 1
 METADATA_CONSTANT = b'mwmbl-tiny-search'
@@ -16,6 +17,9 @@ NUM_PAGES = 5_120_000
 PAGE_SIZE = 4096
 
 
+logger = getLogger(__name__)
+
+
 @dataclass
 class Document:
     title: str
@@ -32,6 +36,10 @@ class TokenizedDocument(Document):
 T = TypeVar('T')
 
 
+class PageError(Exception):
+    pass
+
+
 @dataclass
 class TinyIndexMetadata:
     version: int
@@ -64,7 +72,7 @@ def _get_page_data(compressor, page_size, data):
 def _pad_to_page_size(data: bytes, page_size: int):
     page_length = len(data)
     if page_length > page_size:
-        raise ValueError(f"Data is too big ({page_length}) for page size ({page_size})")
+        raise PageError(f"Data is too big ({page_length}) for page size ({page_size})")
     padding = b'\x00' * (page_size - page_length)
     page_data = data + padding
     return page_data
@@ -92,6 +100,7 @@ class TinyIndex(Generic[T]):
         self.page_size = metadata.page_size
         self.compressor = ZstdCompressor()
         self.decompressor = ZstdDecompressor()
+        logger.info(f"Loaded index with {self.num_pages} pages and {self.page_size} page size")
         self.index_file = None
         self.mmap = None
 
@@ -107,13 +116,14 @@ class TinyIndex(Generic[T]):
 
     def retrieve(self, key: str) -> List[T]:
         index = self.get_key_page_index(key)
+        logger.debug(f"Retrieving index {index}")
         return self.get_page(index)
 
     def get_key_page_index(self, key) -> int:
         key_hash = mmh3.hash(key, signed=False)
         return key_hash % self.num_pages
 
-    def get_page(self, i):
+    def get_page(self, i) -> list[T]:
         """
         Get the page at index i, decompress and deserialise it using JSON
         """
@@ -122,7 +132,12 @@ class TinyIndex(Generic[T]):
 
     def _get_page_tuples(self, i):
         page_data = self.mmap[i * self.page_size:(i + 1) * self.page_size]
-        decompressed_data = self.decompressor.decompress(page_data)
+        try:
+            decompressed_data = self.decompressor.decompress(page_data)
+        except ZstdError:
+            logger.exception(f"Error decompressing page data, content: {page_data}")
+            return []
+        # logger.debug(f"Decompressed data: {decompressed_data}")
         return json.loads(decompressed_data.decode('utf8'))
 
     def index(self, key: str, value: T):
@@ -131,7 +146,7 @@ class TinyIndex(Generic[T]):
         page_index = self.get_key_page_index(key)
         try:
             self.add_to_page(page_index, [value])
-        except ValueError:
+        except PageError:
             pass
 
     def add_to_page(self, page_index: int, values: list[T]):
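
For context, a minimal standalone sketch of the two failure modes this change separates; the mwmbl.tinysearchengine.indexer module path is an assumption based on the repository layout, and everything else uses only names that appear in the diff:

    # Sketch of the two error paths introduced above (module path assumed).
    from zstandard import ZstdDecompressor, ZstdError

    from mwmbl.tinysearchengine.indexer import PageError, _pad_to_page_size

    # 1. Writing: data larger than a page now raises PageError, which index()
    #    catches specifically, instead of the overly broad ValueError.
    try:
        _pad_to_page_size(b'x' * 5000, 4096)
    except PageError as e:
        print(e)  # Data is too big (5000) for page size (4096)

    # 2. Reading: a page holding invalid data (for example all zero bytes,
    #    as used for padding) is not a valid zstd frame, so decompression
    #    raises ZstdError; _get_page_tuples() now logs it and returns []
    #    instead of letting the exception propagate out of retrieve().
    try:
        ZstdDecompressor().decompress(b'\x00' * 4096)
    except ZstdError:
        print('corrupt or empty page -> get_page() returns []')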