indexer.py

import json
import os
from dataclasses import dataclass, asdict, field
from enum import IntEnum
from io import UnsupportedOperation
from logging import getLogger
from mmap import mmap, PROT_READ, PROT_WRITE
from typing import TypeVar, Generic, Callable, List, Optional

import mmh3
from zstandard import ZstdDecompressor, ZstdCompressor, ZstdError

VERSION = 1
METADATA_CONSTANT = b'mwmbl-tiny-search'
METADATA_SIZE = 4096

PAGE_SIZE = 4096

logger = getLogger(__name__)


def astuple(dc):
    """
    Convert a dataclass instance to a tuple - trailing values that are None are truncated.
    """
    value = tuple(dc.__dict__.values())
    while value and value[-1] is None:
        value = value[:-1]
    return value
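

# For example, astuple(Document('title', 'url', 'extract', 1.0)) yields
# ('title', 'url', 'extract', 1.0): the trailing None values of the optional
# term and state fields are dropped before serialisation, saving page space.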


class DocumentState(IntEnum):
    CURATED = 1


@dataclass
class Document:
    title: str
    url: str
    extract: str
    score: float
    term: Optional[str] = None
    state: Optional[int] = None


@dataclass
class TokenizedDocument(Document):
    tokens: List[str] = field(default_factory=list)


T = TypeVar('T')


class PageError(Exception):
    pass


@dataclass
class TinyIndexMetadata:
    version: int
    page_size: int
    num_pages: int
    item_factory: str

    def to_bytes(self) -> bytes:
        metadata_bytes = METADATA_CONSTANT + json.dumps(asdict(self)).encode('utf8')
        assert len(metadata_bytes) <= METADATA_SIZE
        return metadata_bytes

    @staticmethod
    def from_bytes(data: bytes):
        constant_length = len(METADATA_CONSTANT)
        metadata_constant = data[:constant_length]
        if metadata_constant != METADATA_CONSTANT:
            raise ValueError("This doesn't seem to be an index file")
        values = json.loads(data[constant_length:].decode('utf8'))
        return TinyIndexMetadata(**values)
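

# On-disk layout, as written by TinyIndex.create below: one METADATA_SIZE-byte
# header holding METADATA_CONSTANT plus the JSON-encoded metadata, zero-padded,
# followed by num_pages fixed-size pages of zstd-compressed JSON:
#
#   | metadata (METADATA_SIZE bytes) | page 0 | page 1 | ... | page N-1 |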


# Find the optimal amount of data that fits onto a page
# We do this by leveraging binary search to quickly find the index where:
# - index+1 cannot fit onto a page
# - <=index can fit on a page
def _binary_search_fitting_size(compressor: ZstdCompressor, page_size: int, items: list[T], lo: int, hi: int):
    # Base case: our binary search has gone too far
    if lo > hi:
        return -1, None
    # Check the midpoint to see if it will fit onto a page
    mid = (lo + hi) // 2
    compressed_data = compressor.compress(json.dumps(items[:mid]).encode('utf8'))
    size = len(compressed_data)
    if size > page_size:
        # We cannot fit this much data into a page
        # Reduce the hi boundary, and try again
        return _binary_search_fitting_size(compressor, page_size, items, lo, mid - 1)
    else:
        # We can fit this data into a page, but maybe we can fit more data
        # Try to see if we have a better match
        potential_target, potential_data = _binary_search_fitting_size(compressor, page_size, items, mid + 1, hi)
        if potential_target != -1:
            # We found a larger index that can still fit onto a page, so use that
            return potential_target, potential_data
        else:
            # No better match, use our index
            return mid, compressed_data


def _trim_items_to_page(compressor: ZstdCompressor, page_size: int, items: list[T]):
    # Find max number of items that fit on a page
    return _binary_search_fitting_size(compressor, page_size, items, 0, len(items))
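

# Illustration with hypothetical numbers: for a 4096-byte page,
# _trim_items_to_page(compressor, 4096, items) on a list of 1200 tuples might
# return (1000, data), meaning only the first 1000 items still fit in the page
# once serialised and compressed; callers discard the rest.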


def _get_page_data(compressor: ZstdCompressor, page_size: int, items: list[T]):
    num_fitting, serialised_data = _trim_items_to_page(compressor, page_size, items)
    compressed_data = compressor.compress(json.dumps(items[:num_fitting]).encode('utf8'))
    assert len(compressed_data) <= page_size, "The data shouldn't get bigger"
    return _pad_to_page_size(compressed_data, page_size)


def _pad_to_page_size(data: bytes, page_size: int):
    page_length = len(data)
    if page_length > page_size:
        raise PageError(f"Data is too big ({page_length}) for page size ({page_size})")
    padding = b'\x00' * (page_size - page_length)
    page_data = data + padding
    return page_data


class TinyIndex(Generic[T]):
    def __init__(self, item_factory: Callable[..., T], index_path, mode='r'):
        if mode not in {'r', 'w'}:
            raise ValueError(f"Mode should be one of 'r' or 'w', got {mode}")

        with open(index_path, 'rb') as index_file:
            metadata_page = index_file.read(METADATA_SIZE)

        metadata_bytes = metadata_page.rstrip(b'\x00')
        metadata = TinyIndexMetadata.from_bytes(metadata_bytes)
        if metadata.item_factory != item_factory.__name__:
            raise ValueError(f"Metadata item factory '{metadata.item_factory}' in the index "
                             f"does not match the passed item factory: '{item_factory.__name__}'")

        self.item_factory = item_factory
        self.index_path = index_path
        self.mode = mode

        self.num_pages = metadata.num_pages
        self.page_size = metadata.page_size
        self.compressor = ZstdCompressor()
        self.decompressor = ZstdDecompressor()
        logger.info(f"Loaded index with {self.num_pages} pages and {self.page_size} page size")
        self.index_file = None
        self.mmap = None

    def __enter__(self):
        self.index_file = open(self.index_path, 'r+b')
        prot = PROT_READ if self.mode == 'r' else PROT_READ | PROT_WRITE
        self.mmap = mmap(self.index_file.fileno(), 0, prot=prot)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.mmap.close()
        self.index_file.close()

    def retrieve(self, key: str) -> List[T]:
        index = self.get_key_page_index(key)
        logger.debug(f"Retrieving index {index}")
        return self.get_page(index)

    def get_key_page_index(self, key) -> int:
        key_hash = mmh3.hash(key, signed=False)
        return key_hash % self.num_pages

    def get_page(self, i) -> list[T]:
        """
        Get the page at index i, decompress and deserialise it using JSON
        """
        results = self._get_page_tuples(i)
        return [self.item_factory(*item) for item in results]

    def _get_page_tuples(self, i):
        page_data = self.mmap[i * self.page_size + METADATA_SIZE:(i + 1) * self.page_size + METADATA_SIZE]
        try:
            decompressed_data = self.decompressor.decompress(page_data)
        except ZstdError:
            logger.exception(f"Error decompressing page data, content: {page_data}")
            return []
        # logger.debug(f"Decompressed data: {decompressed_data}")
        return json.loads(decompressed_data.decode('utf8'))

    def store_in_page(self, page_index: int, values: list[T]):
        value_tuples = [astuple(value) for value in values]
        self._write_page(value_tuples, page_index)

    def _write_page(self, data, i: int):
        """
        Serialise the data using JSON, compress it and store it at index i.
        If the data is too big, it will store the first items in the list and discard the rest.
        """
        if self.mode != 'w':
            raise UnsupportedOperation("The file is open in read mode, you cannot write")

        page_data = _get_page_data(self.compressor, self.page_size, data)
        logger.debug(f"Got page data of length {len(page_data)}")
        self.mmap[i * self.page_size + METADATA_SIZE:(i + 1) * self.page_size + METADATA_SIZE] = page_data

    @staticmethod
    def create(item_factory: Callable[..., T], index_path: str, num_pages: int, page_size: int):
        if os.path.isfile(index_path):
            raise FileExistsError(f"Index file '{index_path}' already exists")

        metadata = TinyIndexMetadata(VERSION, page_size, num_pages, item_factory.__name__)
        metadata_bytes = metadata.to_bytes()
        metadata_padded = _pad_to_page_size(metadata_bytes, METADATA_SIZE)

        compressor = ZstdCompressor()
        page_bytes = _get_page_data(compressor, page_size, [])

        with open(index_path, 'wb') as index_file:
            index_file.write(metadata_padded)
            for i in range(num_pages):
                index_file.write(page_bytes)

        return TinyIndex(item_factory, index_path=index_path)
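

# Minimal usage sketch (illustrative only, not part of the original module;
# the path and sizes here are arbitrary): create a small index, write one
# document to the page its term hashes to, then read it back.
if __name__ == '__main__':
    demo_path = 'demo-index.tinysearch'
    TinyIndex.create(Document, demo_path, num_pages=16, page_size=PAGE_SIZE)
    with TinyIndex(Document, demo_path, mode='w') as indexer:
        page_index = indexer.get_key_page_index('python')
        indexer.store_in_page(page_index, [
            Document('Python', 'https://python.org', 'A programming language', 1.0),
        ])
    with TinyIndex(Document, demo_path) as searcher:
        print(searcher.retrieve('python'))
    os.remove(demo_path)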