# indexer.py
import json
import os
from dataclasses import astuple, dataclass
from mmap import mmap, PROT_READ
from typing import Callable, Generic, List, TypeVar

import mmh3
from zstandard import ZstdCompressor, ZstdDecompressor, ZstdError
# Default index geometry: 25600 pages x 4096 bytes = 100 MiB index file.
NUM_PAGES = 25600
PAGE_SIZE = 4096
@dataclass
class Document:
    """A document record: display title, source URL and a text extract."""
    title: str
    url: str
    extract: str
@dataclass
class TokenizedDocument(Document):
    """A Document extended with its list of token strings."""
    tokens: List[str]
  18. T = TypeVar('T')
  19. class TinyIndexBase(Generic[T]):
  20. def __init__(self, item_factory: Callable[..., T], num_pages: int, page_size: int):
  21. self.item_factory = item_factory
  22. self.num_pages = num_pages
  23. self.page_size = page_size
  24. self.decompressor = ZstdDecompressor()
  25. self.mmap = None
  26. def retrieve(self, key: str) -> List[T]:
  27. index = self._get_key_page_index(key)
  28. page = self.get_page(index)
  29. if page is None:
  30. return []
  31. # print("REtrieve", self.index_path, page)
  32. return self.convert_items(page)
  33. def _get_key_page_index(self, key):
  34. key_hash = mmh3.hash(key, signed=False)
  35. return key_hash % self.num_pages
  36. def get_page(self, i):
  37. """
  38. Get the page at index i, decompress and deserialise it using JSON
  39. """
  40. page_data = self.mmap[i * self.page_size:(i + 1) * self.page_size]
  41. zeros = page_data.count(b'\x00\x00\x00\x00') * 4
  42. try:
  43. decompressed_data = self.decompressor.decompress(page_data)
  44. except ZstdError:
  45. return None
  46. results = json.loads(decompressed_data.decode('utf8'))
  47. # print(f"Num results: {len(results)}, num zeros: {zeros}")
  48. return results
  49. def convert_items(self, items) -> List[T]:
  50. converted = [self.item_factory(*item) for item in items]
  51. # print("Converted", items, converted)
  52. return converted
  53. class TinyIndex(TinyIndexBase[T]):
  54. def __init__(self, item_factory: Callable[..., T], index_path, num_pages, page_size):
  55. super().__init__(item_factory, num_pages, page_size)
  56. # print("REtrieve path", index_path)
  57. self.index_path = index_path
  58. self.index_file = open(self.index_path, 'rb')
  59. self.mmap = mmap(self.index_file.fileno(), 0, prot=PROT_READ)
  60. class TinyIndexer(TinyIndexBase[T]):
  61. def __init__(self, item_factory: Callable[..., T], index_path: str, num_pages: int, page_size: int):
  62. super().__init__(item_factory, num_pages, page_size)
  63. self.index_path = index_path
  64. self.compressor = ZstdCompressor()
  65. self.decompressor = ZstdDecompressor()
  66. self.index_file = None
  67. self.mmap = None
  68. def __enter__(self):
  69. self.create_if_not_exists()
  70. self.index_file = open(self.index_path, 'r+b')
  71. self.mmap = mmap(self.index_file.fileno(), 0)
  72. return self
  73. def __exit__(self, exc_type, exc_val, exc_tb):
  74. self.mmap.close()
  75. self.index_file.close()
  76. def index(self, key: str, value: T):
  77. # print("Index", value)
  78. assert type(value) == self.item_factory, f"Can only index the specified type" \
  79. f" ({self.item_factory.__name__})"
  80. page_index = self._get_key_page_index(key)
  81. current_page = self.get_page(page_index)
  82. if current_page is None:
  83. current_page = []
  84. value_tuple = astuple(value)
  85. # print("Value tuple", value_tuple)
  86. current_page.append(value_tuple)
  87. try:
  88. # print("Page", current_page)
  89. self._write_page(current_page, page_index)
  90. except ValueError:
  91. pass
  92. def _write_page(self, data, i):
  93. """
  94. Serialise the data using JSON, compress it and store it at index i.
  95. If the data is too big, it will raise a ValueError and not store anything
  96. """
  97. serialised_data = json.dumps(data)
  98. compressed_data = self.compressor.compress(serialised_data.encode('utf8'))
  99. page_length = len(compressed_data)
  100. if page_length > self.page_size:
  101. raise ValueError(f"Data is too big ({page_length}) for page size ({self.page_size})")
  102. padding = b'\x00' * (self.page_size - page_length)
  103. self.mmap[i * self.page_size:(i+1) * self.page_size] = compressed_data + padding
  104. def create_if_not_exists(self):
  105. if not os.path.isfile(self.index_path):
  106. file_length = self.num_pages * self.page_size
  107. with open(self.index_path, 'wb') as index_file:
  108. index_file.write(b'\x00' * file_length)