  1. """
  2. Create a search index
  3. """
  4. import json
  5. import os
  6. from abc import ABC, abstractmethod
  7. from collections import Counter
  8. from dataclasses import dataclass, fields, asdict, astuple
  9. from itertools import islice
  10. from mmap import mmap, PROT_READ
  11. from typing import List, Iterator, TypeVar, Generic, Iterable, Callable
  12. from urllib.parse import unquote
  13. import justext
  14. import mmh3
  15. import pandas as pd
  16. from zstandard import ZstdCompressor, ZstdDecompressor, ZstdError
  17. # NUM_PAGES = 8192
  18. # PAGE_SIZE = 512
  19. NUM_PAGES = 25600
  20. PAGE_SIZE = 4096
  21. NUM_INITIAL_TOKENS = 50
  22. HTTP_START = 'http://'
  23. HTTPS_START = 'https://'
  24. BATCH_SIZE = 100
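# The index file is a fixed-size, on-disk hash table: NUM_PAGES pages of
# PAGE_SIZE bytes each, i.e. 25600 * 4096 bytes = 100 MiB in total. Each key is
# hashed to a single page (different keys can share a page), and each page
# stores a zstd-compressed JSON list of document tuples.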


def is_content_token(nlp, token):
    lexeme = nlp.vocab[token.orth]
    return (lexeme.is_alpha or lexeme.is_digit) and not token.is_stop


def tokenize(nlp, cleaned_text):
    tokens = nlp.tokenizer(cleaned_text)
    content_tokens = [token for token in tokens[:NUM_INITIAL_TOKENS]
                      if is_content_token(nlp, token)]
    lowered = {nlp.vocab[token.orth].text.lower() for token in content_tokens}
    return lowered


def clean(content):
    text = justext.justext(content, justext.get_stoplist("English"))
    pars = [par.text for par in text if not par.is_boilerplate]
    cleaned_text = ' '.join(pars)
    return cleaned_text


@dataclass
class Document:
    title: str
    url: str
    extract: str


@dataclass
class TokenizedDocument(Document):
    tokens: List[str]


T = TypeVar('T')


class TinyIndexBase(Generic[T]):
    def __init__(self, item_factory: Callable[..., T], num_pages: int, page_size: int):
        self.item_factory = item_factory
        self.num_pages = num_pages
        self.page_size = page_size
        self.decompressor = ZstdDecompressor()
        self.mmap = None

    def retrieve(self, key: str) -> List[T]:
        index = self._get_key_page_index(key)
        page = self.get_page(index)
        if page is None:
            return []
        # print("REtrieve", self.index_path, page)
        return self.convert_items(page)

    def _get_key_page_index(self, key):
        key_hash = mmh3.hash(key, signed=False)
        return key_hash % self.num_pages

    def get_page(self, i):
        """
        Get the page at index i, decompress and deserialise it using JSON
        """
        page_data = self.mmap[i * self.page_size:(i + 1) * self.page_size]
        zeros = page_data.count(b'\x00\x00\x00\x00') * 4
        try:
            decompressed_data = self.decompressor.decompress(page_data)
        except ZstdError:
            return None
        results = json.loads(decompressed_data.decode('utf8'))
        # print(f"Num results: {len(results)}, num zeros: {zeros}")
        return results

    def convert_items(self, items) -> List[T]:
        converted = [self.item_factory(*item) for item in items]
        # print("Converted", items, converted)
        return converted


class TinyIndex(TinyIndexBase[T]):
    def __init__(self, item_factory: Callable[..., T], index_path, num_pages, page_size):
        super().__init__(item_factory, num_pages, page_size)
        # print("REtrieve path", index_path)
        self.index_path = index_path
        self.index_file = open(self.index_path, 'rb')
        self.mmap = mmap(self.index_file.fileno(), 0, prot=PROT_READ)
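
# Usage sketch for reading an existing index; the file name 'index.tinysearch'
# is a placeholder:
#
#     tiny_index = TinyIndex(Document, 'index.tinysearch', NUM_PAGES, PAGE_SIZE)
#     documents = tiny_index.retrieve('python')  # list of Document for the term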


class TinyIndexer(TinyIndexBase[T]):
    def __init__(self, item_factory: Callable[..., T], index_path: str, num_pages: int, page_size: int):
        super().__init__(item_factory, num_pages, page_size)
        self.index_path = index_path
        self.compressor = ZstdCompressor()
        self.decompressor = ZstdDecompressor()
        self.index_file = None
        self.mmap = None

    def __enter__(self):
        self.create_if_not_exists()
        self.index_file = open(self.index_path, 'r+b')
        self.mmap = mmap(self.index_file.fileno(), 0)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.mmap.close()
        self.index_file.close()

    def index(self, key: str, value: T):
        # print("Index", value)
        assert type(value) == self.item_factory, f"Can only index the specified type" \
                                                 f" ({self.item_factory.__name__})"
        page_index = self._get_key_page_index(key)
        current_page = self.get_page(page_index)
        if current_page is None:
            current_page = []
        value_tuple = astuple(value)
        # print("Value tuple", value_tuple)
        current_page.append(value_tuple)
        try:
            # print("Page", current_page)
            self._write_page(current_page, page_index)
        except ValueError:
            pass

    def _write_page(self, data, i):
        """
        Serialise the data using JSON, compress it and store it at index i.
        If the data is too big, it will raise a ValueError and not store anything
        """
        serialised_data = json.dumps(data)
        compressed_data = self.compressor.compress(serialised_data.encode('utf8'))
        page_length = len(compressed_data)
        if page_length > self.page_size:
            raise ValueError(f"Data is too big ({page_length}) for page size ({self.page_size})")
        padding = b'\x00' * (self.page_size - page_length)
        self.mmap[i * self.page_size:(i + 1) * self.page_size] = compressed_data + padding

    def create_if_not_exists(self):
        if not os.path.isfile(self.index_path):
            file_length = self.num_pages * self.page_size
            with open(self.index_path, 'wb') as index_file:
                index_file.write(b'\x00' * file_length)


def prepare_url_for_tokenizing(url: str):
    if url.startswith(HTTP_START):
        url = url[len(HTTP_START):]
    elif url.startswith(HTTPS_START):
        url = url[len(HTTPS_START):]
    for c in '/._':
        if c in url:
            url = url.replace(c, ' ')
    return url
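
# For example, prepare_url_for_tokenizing(
#     'https://en.wikipedia.org/wiki/Python_(programming_language)')
# returns 'en wikipedia org wiki Python (programming language)'.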


def get_pages(nlp, titles_urls_and_extracts) -> Iterable[TokenizedDocument]:
    for i, (title_cleaned, url, extract) in enumerate(titles_urls_and_extracts):
        title_tokens = tokenize(nlp, title_cleaned)
        prepared_url = prepare_url_for_tokenizing(unquote(url))
        url_tokens = tokenize(nlp, prepared_url)
        extract_tokens = tokenize(nlp, extract)
        print("Extract tokens", extract_tokens)
        tokens = title_tokens | url_tokens | extract_tokens
        yield TokenizedDocument(tokens=list(tokens), url=url, title=title_cleaned, extract=extract)

        if i % 1000 == 0:
            print("Processed", i)


def grouper(n: int, iterator: Iterator):
    while True:
        chunk = tuple(islice(iterator, n))
        if not chunk:
            return
        yield chunk
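
# Usage sketch: grouper(BATCH_SIZE, iter(documents)) yields tuples of up to
# BATCH_SIZE documents at a time; the final chunk may be shorter.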


def index_titles_urls_and_extracts(indexer: TinyIndexer, nlp, titles_urls_and_extracts, terms_path):
    indexer.create_if_not_exists()
    terms = Counter()
    pages = get_pages(nlp, titles_urls_and_extracts)
    for page in pages:
        for token in page.tokens:
            indexer.index(token, Document(url=page.url, title=page.title, extract=page.extract))
        terms.update([t.lower() for t in page.tokens])

    term_df = pd.DataFrame({
        'term': terms.keys(),
        'count': terms.values(),
    })
    term_df.to_csv(terms_path)
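

if __name__ == '__main__':
    # Minimal end-to-end sketch. The spaCy model name, the example document and
    # the output paths below are illustrative assumptions, not taken from the
    # rest of this module. Note that create_if_not_exists allocates a
    # NUM_PAGES * PAGE_SIZE (100 MiB) file on first use.
    import spacy

    nlp = spacy.load('en_core_web_sm')  # assumes this model is installed
    titles_urls_and_extracts = [
        ('Python (programming language)',
         'https://en.wikipedia.org/wiki/Python_(programming_language)',
         'Python is a high-level, general-purpose programming language.'),
    ]
    with TinyIndexer(Document, 'index.tinysearch', NUM_PAGES, PAGE_SIZE) as indexer:
        index_titles_urls_and_extracts(indexer, nlp, titles_urls_and_extracts, 'terms.csv')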