# mwmbl/index.py
"""
Create a search index
"""
2021-04-12 17:37:33 +00:00
import json
import os
from abc import ABC, abstractmethod
2021-05-30 20:30:34 +00:00
from collections import Counter
from dataclasses import dataclass, fields, asdict, astuple
2021-05-19 20:48:03 +00:00
from itertools import islice
2021-04-12 20:26:41 +00:00
from mmap import mmap, PROT_READ
2021-06-13 20:41:19 +00:00
from typing import List, Iterator, TypeVar, Generic, Iterable, Callable
2021-03-23 22:03:48 +00:00
from urllib.parse import unquote
2021-03-13 20:54:15 +00:00
import justext
2021-04-12 17:37:33 +00:00
import mmh3
2021-05-30 20:30:34 +00:00
import pandas as pd
2021-04-12 17:37:33 +00:00
from zstandard import ZstdCompressor, ZstdDecompressor, ZstdError
2021-03-13 20:54:15 +00:00
# Index geometry: the on-disk index is NUM_PAGES fixed-size pages of
# PAGE_SIZE bytes each (raised from the original 8192 pages x 512 bytes).
# NUM_PAGES = 8192
# PAGE_SIZE = 512
NUM_PAGES = 25600
PAGE_SIZE = 4096


# Only the first NUM_INITIAL_TOKENS tokens of a text are considered when
# tokenizing (see tokenize()).
NUM_INITIAL_TOKENS = 50


# URL scheme prefixes stripped before a URL is tokenized.
HTTP_START = 'http://'
HTTPS_START = 'https://'

BATCH_SIZE = 100
def is_content_token(nlp, token):
    """Return True for tokens that carry content: alphabetic or numeric
    lexemes that are not stopwords."""
    if token.is_stop:
        return False
    lex = nlp.vocab[token.orth]
    return lex.is_alpha or lex.is_digit
def tokenize(nlp, cleaned_text):
    """Return the lower-cased set of content tokens among the first
    NUM_INITIAL_TOKENS tokens of cleaned_text."""
    doc_tokens = nlp.tokenizer(cleaned_text)
    return {
        nlp.vocab[tok.orth].text.lower()
        for tok in doc_tokens[:NUM_INITIAL_TOKENS]
        if is_content_token(nlp, tok)
    }
def clean(content):
    """Extract the non-boilerplate paragraph text from an HTML page using
    jusText, joined into a single space-separated string."""
    paragraphs = justext.justext(content, justext.get_stoplist("English"))
    return ' '.join(par.text for par in paragraphs if not par.is_boilerplate)
@dataclass
class Document:
    """An indexed page as stored in (and retrieved from) the index.

    Field order matters: documents are serialised with astuple() and
    rebuilt positionally by TinyIndexBase.convert_items().
    """
    title: str    # page title (cleaned text)
    url: str      # original page URL
    extract: str  # text extract from the page body
@dataclass
class TokenizedDocument(Document):
    """A Document together with the tokens under which it should be indexed."""
    tokens: List[str]  # union of title, URL and extract tokens
T = TypeVar('T')


class TinyIndexBase(Generic[T]):
    """
    Shared read logic for a fixed-geometry, page-based index file.

    The index is `num_pages` pages of `page_size` bytes each. A key is
    hashed (murmur3) to select its page; each page holds a zstd-compressed
    JSON list of item tuples which are rebuilt via `item_factory`.
    """

    def __init__(self, item_factory: Callable[..., T], num_pages: int, page_size: int):
        self.item_factory = item_factory
        self.num_pages = num_pages
        self.page_size = page_size
        self.decompressor = ZstdDecompressor()
        # Subclasses are responsible for opening the file and setting the map.
        self.mmap = None

    def retrieve(self, key: str) -> List[T]:
        """Return all items stored under `key`, or an empty list if none."""
        index = self._get_key_page_index(key)
        page = self.get_page(index)
        if page is None:
            return []
        return self.convert_items(page)

    def _get_key_page_index(self, key) -> int:
        # Unsigned hash so the modulo always yields a valid page index.
        key_hash = mmh3.hash(key, signed=False)
        return key_hash % self.num_pages

    def get_page(self, i):
        """
        Get the page at index i, decompress and deserialise it using JSON.

        Returns None when the page holds no valid compressed data (e.g. it
        is still all zero padding from create_if_not_exists).
        """
        page_data = self.mmap[i * self.page_size:(i + 1) * self.page_size]
        try:
            decompressed_data = self.decompressor.decompress(page_data)
        except ZstdError:
            return None
        return json.loads(decompressed_data.decode('utf8'))

    def convert_items(self, items) -> List[T]:
        """Rebuild `item_factory` instances from raw JSON item tuples."""
        return [self.item_factory(*item) for item in items]
class TinyIndex(TinyIndexBase[T]):
    """
    Read-only view of an index file, memory-mapped for retrieval.

    Usable as a context manager (mirroring TinyIndexer) so the file handle
    and mapping — previously leaked — can be released deterministically.
    """

    def __init__(self, item_factory: Callable[..., T], index_path, num_pages, page_size):
        super().__init__(item_factory, num_pages, page_size)
        self.index_path = index_path
        self.index_file = open(self.index_path, 'rb')
        self.mmap = mmap(self.index_file.fileno(), 0, prot=PROT_READ)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def close(self):
        """Release the memory map and the underlying file handle."""
        self.mmap.close()
        self.index_file.close()
class TinyIndexer(TinyIndexBase[T]):
    """
    Writable index: memory-maps the index file (creating a zero-filled one
    if needed) and appends items to the page selected by their key.

    Must be used as a context manager; the map is only opened in __enter__.
    """

    def __init__(self, item_factory: Callable[..., T], index_path: str, num_pages: int, page_size: int):
        super().__init__(item_factory, num_pages, page_size)
        self.index_path = index_path
        self.compressor = ZstdCompressor()
        # Note: the base class already provides self.decompressor.
        self.index_file = None
        self.mmap = None

    def __enter__(self):
        self.create_if_not_exists()
        self.index_file = open(self.index_path, 'r+b')
        self.mmap = mmap(self.index_file.fileno(), 0)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.mmap.close()
        self.index_file.close()

    def index(self, key: str, value: T):
        """
        Append `value` to the page that `key` hashes to.

        If the page would overflow its fixed size, the value is silently
        dropped (best-effort indexing).
        """
        # isinstance (rather than exact type equality) also accepts
        # subclasses of the indexed type; still raises AssertionError.
        assert isinstance(value, self.item_factory), f"Can only index the specified type" \
            f" ({self.item_factory.__name__})"
        page_index = self._get_key_page_index(key)
        current_page = self.get_page(page_index)
        if current_page is None:
            current_page = []
        current_page.append(astuple(value))
        try:
            self._write_page(current_page, page_index)
        except ValueError:
            # Deliberate: a page that is already full drops the new item.
            pass

    def _write_page(self, data, i):
        """
        Serialise the data using JSON, compress it and store it at index i.

        If the data is too big, raises ValueError and stores nothing.
        """
        serialised_data = json.dumps(data)
        compressed_data = self.compressor.compress(serialised_data.encode('utf8'))
        page_length = len(compressed_data)
        if page_length > self.page_size:
            raise ValueError(f"Data is too big ({page_length}) for page size ({self.page_size})")
        padding = b'\x00' * (self.page_size - page_length)
        self.mmap[i * self.page_size:(i + 1) * self.page_size] = compressed_data + padding

    def create_if_not_exists(self):
        """Create a zero-filled index file of the full size if absent."""
        if not os.path.isfile(self.index_path):
            file_length = self.num_pages * self.page_size
            with open(self.index_path, 'wb') as index_file:
                index_file.write(b'\x00' * file_length)
def prepare_url_for_tokenizing(url: str):
    """
    Strip the URL scheme (http/https, at most one) and blank out separator
    characters so the remainder can be tokenized as plain text.
    """
    for prefix in ('http://', 'https://'):
        if url.startswith(prefix):
            url = url[len(prefix):]
            break
    # '/', '.' and '_' all become spaces so domain and path parts split
    # into individual words.
    return url.translate(str.maketrans('/._', '   '))
def get_pages(nlp, titles_urls_and_extracts) -> Iterable[TokenizedDocument]:
    """
    Turn (cleaned_title, url, extract) triples into TokenizedDocuments.

    A document's tokens are the union of its title, URL and extract tokens
    (the URL is unquoted and split on separators first). Prints a progress
    line every 1000 items. The noisy per-item debug print of the extract
    tokens has been removed.
    """
    for i, (title_cleaned, url, extract) in enumerate(titles_urls_and_extracts):
        title_tokens = tokenize(nlp, title_cleaned)
        prepared_url = prepare_url_for_tokenizing(unquote(url))
        url_tokens = tokenize(nlp, prepared_url)
        extract_tokens = tokenize(nlp, extract)
        tokens = title_tokens | url_tokens | extract_tokens
        yield TokenizedDocument(tokens=list(tokens), url=url, title=title_cleaned, extract=extract)

        if i % 1000 == 0:
            print("Processed", i)
def grouper(n: int, iterator: Iterator):
    """
    Yield tuples of up to n consecutive items from iterator; the final
    tuple may be shorter. Stops once the iterator is exhausted.
    """
    chunk = tuple(islice(iterator, n))
    while chunk:
        yield chunk
        chunk = tuple(islice(iterator, n))
def index_titles_urls_and_extracts(indexer: TinyIndexer, nlp, titles_urls_and_extracts, terms_path):
    """
    Index every page under each of its tokens and write the term counts
    out as a CSV file at terms_path.
    """
    indexer.create_if_not_exists()
    term_counts = Counter()
    for page in get_pages(nlp, titles_urls_and_extracts):
        # One Document per page is enough: the same record is stored under
        # every token.
        document = Document(url=page.url, title=page.title, extract=page.extract)
        for token in page.tokens:
            indexer.index(token, document)
        term_counts.update(t.lower() for t in page.tokens)
    term_df = pd.DataFrame({
        'term': term_counts.keys(),
        'count': term_counts.values(),
    })
    term_df.to_csv(terms_path)