import json
import os
from dataclasses import dataclass, asdict, field
from enum import IntEnum
from io import UnsupportedOperation
from logging import getLogger
from mmap import mmap, PROT_READ, PROT_WRITE
from typing import TypeVar, Generic, Callable, List, Optional

import mmh3
from zstandard import ZstdDecompressor, ZstdCompressor, ZstdError

VERSION = 1
METADATA_CONSTANT = b'mwmbl-tiny-search'
METADATA_SIZE = 4096

PAGE_SIZE = 4096

logger = getLogger(__name__)

def astuple(dc):
    """
    Convert a dataclass instance to a tuple - trailing values that are None are truncated.
    """
    value = tuple(dc.__dict__.values())
    # Guard against an empty tuple so an all-None instance doesn't raise IndexError
    while value and value[-1] is None:
        value = value[:-1]
    return value
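
# Illustrative example: for a Document (defined below) whose optional term and
# state fields are left unset, astuple truncates the trailing Nones, keeping
# the serialised JSON compact:
#
#     astuple(Document('Example', 'https://example.com', 'text', 1.0))
#     # -> ('Example', 'https://example.com', 'text', 1.0)
#
# dataclasses.astuple would instead include the two trailing None values.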

class DocumentState(IntEnum):
    CURATED = 1


@dataclass
class Document:
    title: str
    url: str
    extract: str
    score: float
    term: Optional[str] = None
    state: Optional[int] = None


@dataclass
class TokenizedDocument(Document):
    tokens: List[str] = field(default_factory=list)


T = TypeVar('T')


class PageError(Exception):
    pass

@dataclass
class TinyIndexMetadata:
    version: int
    page_size: int
    num_pages: int
    item_factory: str

    def to_bytes(self) -> bytes:
        metadata_bytes = METADATA_CONSTANT + json.dumps(asdict(self)).encode('utf8')
        assert len(metadata_bytes) <= METADATA_SIZE, "Metadata must fit in the metadata page"
        return metadata_bytes

    @staticmethod
    def from_bytes(data: bytes) -> 'TinyIndexMetadata':
        constant_length = len(METADATA_CONSTANT)
        metadata_constant = data[:constant_length]
        if metadata_constant != METADATA_CONSTANT:
            raise ValueError("This doesn't seem to be an index file")
        values = json.loads(data[constant_length:].decode('utf8'))
        return TinyIndexMetadata(**values)
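
# Illustrative round trip (the num_pages value here is arbitrary): the metadata
# page is the constant marker followed by a JSON payload, zero-padded to
# METADATA_SIZE when the index file is created.
#
#     meta = TinyIndexMetadata(VERSION, PAGE_SIZE, 25600, 'Document')
#     data = meta.to_bytes()  # b'mwmbl-tiny-search{"version": 1, ...}'
#     assert TinyIndexMetadata.from_bytes(data) == meta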

# Find the optimal amount of data that fits onto a page.
# We use binary search to quickly find the index where:
#  - index+1 items cannot fit onto a page
#  - <= index items can fit onto a page
# This relies on the compressed size growing (weakly) monotonically with the
# number of items, which holds in practice when appending JSON-serialised items.
def _binary_search_fitting_size(compressor: ZstdCompressor, page_size: int, items: list[T], lo: int, hi: int):
    # Base case: our binary search has gone too far
    if lo > hi:
        return -1, None
    # Check whether the midpoint will fit onto a page
    mid = (lo + hi) // 2
    compressed_data = compressor.compress(json.dumps(items[:mid]).encode('utf8'))
    size = len(compressed_data)
    if size > page_size:
        # We cannot fit this much data into a page;
        # reduce the hi boundary and try again
        return _binary_search_fitting_size(compressor, page_size, items, lo, mid - 1)
    # We can fit this data into a page, but maybe we can fit more data;
    # see whether there is a better match
    potential_target, potential_data = _binary_search_fitting_size(compressor, page_size, items, mid + 1, hi)
    if potential_target != -1:
        # We found a larger index that still fits onto a page, so use that
        return potential_target, potential_data
    # No better match, use our index
    return mid, compressed_data

def _trim_items_to_page(compressor: ZstdCompressor, page_size: int, items: list[T]):
    # Find the maximum number of items that fit on a page
    return _binary_search_fitting_size(compressor, page_size, items, 0, len(items))
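
# Worked example (illustrative): with page_size=4096 and 1,000 items, the
# search probes mid=500, then 750 or 250, and so on, halving the range each
# time. It returns the largest n for which compress(json.dumps(items[:n]))
# fits in 4096 bytes, together with that compressed payload - O(log n)
# compressions instead of compressing every prefix.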

def _get_page_data(compressor: ZstdCompressor, page_size: int, items: list[T]):
    # The binary search has already compressed the largest fitting prefix of
    # items, so reuse that payload rather than compressing it a second time
    num_fitting, serialised_data = _trim_items_to_page(compressor, page_size, items)
    assert serialised_data is not None and len(serialised_data) <= page_size
    return _pad_to_page_size(serialised_data, page_size)

def _pad_to_page_size(data: bytes, page_size: int):
    page_length = len(data)
    if page_length > page_size:
        raise PageError(f"Data is too big ({page_length}) for page size ({page_size})")
    padding = b'\x00' * (page_size - page_length)
    return data + padding
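
# Example (illustrative): fixed-size pages mean page i always starts at byte
# offset METADATA_SIZE + i * page_size, so pages can be addressed directly.
#
#     _pad_to_page_size(b'abc', 8)    # -> b'abc\x00\x00\x00\x00\x00'
#     _pad_to_page_size(b'x' * 9, 8)  # raises PageError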

class TinyIndex(Generic[T]):
    def __init__(self, item_factory: Callable[..., T], index_path, mode='r'):
        if mode not in {'r', 'w'}:
            raise ValueError(f"Mode should be one of 'r' or 'w', got {mode}")

        with open(index_path, 'rb') as index_file:
            metadata_page = index_file.read(METADATA_SIZE)

        metadata_bytes = metadata_page.rstrip(b'\x00')
        metadata = TinyIndexMetadata.from_bytes(metadata_bytes)
        if metadata.item_factory != item_factory.__name__:
            raise ValueError(f"Metadata item factory '{metadata.item_factory}' in the index "
                             f"does not match the passed item factory: '{item_factory.__name__}'")

        self.item_factory = item_factory
        self.index_path = index_path
        self.mode = mode

        self.num_pages = metadata.num_pages
        self.page_size = metadata.page_size
        self.compressor = ZstdCompressor()
        self.decompressor = ZstdDecompressor()
        logger.info(f"Loaded index with {self.num_pages} pages and page size {self.page_size}")
        self.index_file = None
        self.mmap = None

    def __enter__(self):
        self.index_file = open(self.index_path, 'r+b')
        prot = PROT_READ if self.mode == 'r' else PROT_READ | PROT_WRITE
        self.mmap = mmap(self.index_file.fileno(), 0, prot=prot)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.mmap.close()
        self.index_file.close()

    def retrieve(self, key: str) -> List[T]:
        index = self.get_key_page_index(key)
        logger.debug(f"Retrieving index {index}")
        return self.get_page(index)

    def get_key_page_index(self, key) -> int:
        key_hash = mmh3.hash(key, signed=False)
        return key_hash % self.num_pages
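
    # Illustrative: mmh3.hash(key, signed=False) is a deterministic unsigned
    # 32-bit hash, so a given key always maps to the same page; distinct keys
    # can collide on the same page, which is why pages hold lists of items.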

    def get_page(self, i) -> list[T]:
        """
        Get the page at index i, decompress and deserialise it using JSON.
        """
        results = self._get_page_tuples(i)
        return [self.item_factory(*item) for item in results]

    def _get_page_tuples(self, i):
        page_data = self.mmap[i * self.page_size + METADATA_SIZE:(i + 1) * self.page_size + METADATA_SIZE]
        try:
            decompressed_data = self.decompressor.decompress(page_data)
        except ZstdError:
            logger.exception(f"Error decompressing page data, content: {page_data}")
            return []
        return json.loads(decompressed_data.decode('utf8'))

    def store_in_page(self, page_index: int, values: list[T]):
        value_tuples = [astuple(value) for value in values]
        self._write_page(value_tuples, page_index)

    def _write_page(self, data, i: int):
        """
        Serialise the data using JSON, compress it and store it at index i.
        If the data is too big, it will store the first items in the list and discard the rest.
        """
        if self.mode != 'w':
            raise UnsupportedOperation("The file is open in read mode, you cannot write")

        page_data = _get_page_data(self.compressor, self.page_size, data)
        logger.debug(f"Got page data of length {len(page_data)}")
        self.mmap[i * self.page_size + METADATA_SIZE:(i + 1) * self.page_size + METADATA_SIZE] = page_data

    @staticmethod
    def create(item_factory: Callable[..., T], index_path: str, num_pages: int, page_size: int):
        if os.path.isfile(index_path):
            raise FileExistsError(f"Index file '{index_path}' already exists")

        metadata = TinyIndexMetadata(VERSION, page_size, num_pages, item_factory.__name__)
        metadata_bytes = metadata.to_bytes()
        metadata_padded = _pad_to_page_size(metadata_bytes, METADATA_SIZE)

        compressor = ZstdCompressor()
        page_bytes = _get_page_data(compressor, page_size, [])

        with open(index_path, 'wb') as index_file:
            index_file.write(metadata_padded)
            for i in range(num_pages):
                index_file.write(page_bytes)

        return TinyIndex(item_factory, index_path=index_path)
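
# Minimal usage sketch (not part of the index API; the file name and sizes
# here are arbitrary). Create a small index, store a document on the page its
# term hashes to, then read it back.
if __name__ == '__main__':
    example_path = 'example-index.tinysearch'
    if not os.path.isfile(example_path):
        TinyIndex.create(Document, example_path, num_pages=128, page_size=PAGE_SIZE)
    with TinyIndex(Document, example_path, mode='w') as indexer:
        page_index = indexer.get_key_page_index('example')
        indexer.store_in_page(page_index, [Document('Example', 'https://example.com', 'An example page', 1.0)])
    with TinyIndex(Document, example_path) as searcher:
        print(searcher.retrieve('example'))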