indexer.py

import json
import os
from dataclasses import astuple, dataclass
from mmap import mmap, PROT_READ
from typing import TypeVar, Generic, Callable, List

import mmh3
from zstandard import ZstdCompressor, ZstdDecompressor, ZstdError

NUM_PAGES = 25600
PAGE_SIZE = 4096
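
# The index is one fixed-size file of NUM_PAGES * PAGE_SIZE bytes
# (25600 * 4096 = 100 MiB with the defaults above). Each key hashes to a
# single page, and each page holds a zstd-compressed JSON list of item
# tuples, zero-padded to PAGE_SIZE.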


@dataclass
class Document:
    title: str
    url: str
    extract: str


@dataclass
class TokenizedDocument(Document):
    tokens: List[str]


T = TypeVar('T')


class TinyIndexBase(Generic[T]):
    def __init__(self, item_factory: Callable[..., T], num_pages: int, page_size: int):
        self.item_factory = item_factory
        self.num_pages = num_pages
        self.page_size = page_size
        self.decompressor = ZstdDecompressor()
        self.mmap = None

    def retrieve(self, key: str) -> List[T]:
        index = self._get_key_page_index(key)
        page = self.get_page(index)
        if page is None:
            return []
        return self.convert_items(page)

    def _get_key_page_index(self, key):
        # The unsigned murmur3 hash, reduced modulo the page count, picks the page.
        key_hash = mmh3.hash(key, signed=False)
        return key_hash % self.num_pages

    def get_page(self, i):
        """
        Get the page at index i, decompress it and deserialise it using JSON.
        """
        page_data = self.mmap[i * self.page_size:(i + 1) * self.page_size]
        try:
            decompressed_data = self.decompressor.decompress(page_data)
        except ZstdError:
            # A page that has never been written is all zeros, which is not a
            # valid zstd frame, so treat it as empty.
            return None
        return json.loads(decompressed_data.decode('utf8'))

    def convert_items(self, items) -> List[T]:
        # Items are stored as flat JSON arrays, so rebuild them positionally.
        return [self.item_factory(*item) for item in items]


class TinyIndex(TinyIndexBase[T]):
    """Read-only view of an existing index file."""

    def __init__(self, item_factory: Callable[..., T], index_path, num_pages, page_size):
        super().__init__(item_factory, num_pages, page_size)
        self.index_path = index_path
        self.index_file = open(self.index_path, 'rb')
        self.mmap = mmap(self.index_file.fileno(), 0, prot=PROT_READ)
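
# Read-side usage sketch ('index.tinysearch' is a hypothetical file name):
#
#     tiny_index = TinyIndex(Document, 'index.tinysearch', NUM_PAGES, PAGE_SIZE)
#     matches = tiny_index.retrieve('python')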


class TinyIndexer(TinyIndexBase[T]):
    """Write side of the index; use as a context manager."""

    def __init__(self, item_factory: Callable[..., T], index_path: str, num_pages: int, page_size: int):
        super().__init__(item_factory, num_pages, page_size)
        self.index_path = index_path
        self.compressor = ZstdCompressor()
        self.index_file = None
        self.mmap = None

    def __enter__(self):
        self.create_if_not_exists()
        self.index_file = open(self.index_path, 'r+b')
        self.mmap = mmap(self.index_file.fileno(), 0)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.mmap.close()
        self.index_file.close()

    def index(self, key: str, value: T):
        assert type(value) == self.item_factory, f"Can only index the specified type" \
                                                 f" ({self.item_factory.__name__})"
        page_index = self._get_key_page_index(key)
        current_page = self.get_page(page_index)
        if current_page is None:
            current_page = []
        value_tuple = astuple(value)
        current_page.append(value_tuple)
        try:
            self._write_page(current_page, page_index)
        except ValueError:
            # The page is full; the new value is silently dropped.
            pass

    def _write_page(self, data, i):
        """
        Serialise the data using JSON, compress it and store it at index i.
        If the data is too big, raise a ValueError and store nothing.
        """
        serialised_data = json.dumps(data)
        compressed_data = self.compressor.compress(serialised_data.encode('utf8'))
        page_length = len(compressed_data)
        if page_length > self.page_size:
            raise ValueError(f"Data is too big ({page_length}) for page size ({self.page_size})")
        padding = b'\x00' * (self.page_size - page_length)
        self.mmap[i * self.page_size:(i + 1) * self.page_size] = compressed_data + padding

    def create_if_not_exists(self):
        # Pre-allocate the full file of zeroed pages on first use.
        if not os.path.isfile(self.index_path):
            file_length = self.num_pages * self.page_size
            with open(self.index_path, 'wb') as index_file:
                index_file.write(b'\x00' * file_length)
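

if __name__ == '__main__':
    # Minimal round-trip sketch: write one document under a couple of tokens,
    # then read it back. 'index.tinysearch' is a hypothetical file name; the
    # first run pre-allocates the full 100 MiB index file.
    with TinyIndexer(Document, 'index.tinysearch', NUM_PAGES, PAGE_SIZE) as indexer:
        doc = Document('Python', 'https://www.python.org/', 'Python is a programming language')
        for token in ['python', 'programming', 'language']:
            indexer.index(token, doc)

    tiny_index = TinyIndex(Document, 'index.tinysearch', NUM_PAGES, PAGE_SIZE)
    print(tiny_index.retrieve('python'))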