Implement new indexing approach

Daoud Clarke 2022-07-23 23:19:36 +01:00
parent a8a6c67239
commit 1bceeae3df
7 changed files with 131 additions and 25 deletions

View file

@@ -5,7 +5,7 @@ from logging import getLogger
 from pathlib import Path
 from time import sleep

-from mwmbl.indexer import historical
+from mwmbl.indexer import historical, index_batches
 from mwmbl.indexer.batch_cache import BatchCache
 from mwmbl.indexer.paths import INDEX_PATH, BATCH_DIR_NAME
 from mwmbl.indexer.preprocess import run_preprocessing
@@ -20,15 +20,11 @@ def run(data_path: str):
     batch_cache = BatchCache(Path(data_path) / BATCH_DIR_NAME)
     while True:
         try:
-            batch_cache.retrieve_batches()
+            batch_cache.retrieve_batches(1)
         except Exception:
             logger.exception("Error retrieving batches")
         try:
-            run_preprocessing(index_path)
+            index_batches.run(batch_cache, index_path)
         except Exception:
-            logger.exception("Error preprocessing")
-        try:
-            run_update(index_path)
-        except Exception:
-            logger.exception("Error running index update")
+            logger.exception("Error indexing batches")
         sleep(10)

View file

@@ -128,6 +128,17 @@ class URLDatabase:
         return [result[0] for result in results]

+    def get_url_scores(self, urls: list[str]) -> dict[str, float]:
+        sql = f"""
+        SELECT url, score FROM urls WHERE url IN %(urls)s
+        """
+
+        with self.connection.cursor() as cursor:
+            cursor.execute(sql, {'urls': tuple(urls)})
+            results = cursor.fetchall()
+
+        return {result[0]: result[1] for result in results}
+

 if __name__ == "__main__":
     with Database() as db:
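
The new get_url_scores method leans on psycopg2's adaptation of a Python tuple for the %(urls)s placeholder, which expands into a SQL IN list. A minimal usage sketch (the URLs are made up for illustration; an empty list is assumed not to occur, since IN () is not valid SQL):

from mwmbl.crawler.urls import URLDatabase
from mwmbl.database import Database

urls = [
    "https://en.wikipedia.org/wiki/Search_engine",  # hypothetical example URLs
    "https://example.com/some-page",
]

with Database() as db:
    url_db = URLDatabase(db.connection)
    scores = url_db.get_url_scores(urls)

for url in urls:
    # URLs missing from the table are simply absent from the dict; the new
    # indexer falls back to a default score of 1.0 in that case.
    print(url, scores.get(url, 1.0))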

View file

@@ -26,12 +26,13 @@ class BatchCache:
         os.makedirs(repo_path, exist_ok=True)
         self.path = repo_path

-    def get(self, num_batches) -> dict[str, HashedBatch]:
+    def get_cached(self, batch_urls: list[str]) -> dict[str, HashedBatch]:
         batches = {}
-        for path in os.listdir(self.path):
-            batch = HashedBatch.parse_file(path)
-            while len(batches) < num_batches:
-                batches[path] = batch
+        for url in batch_urls:
+            path = self.get_path_from_url(url)
+            data = gzip.GzipFile(path).read()
+            batch = HashedBatch.parse_raw(data)
+            batches[url] = batch
         return batches

     def retrieve_batches(self, num_thousand_batches=10):
@@ -43,7 +44,7 @@ class BatchCache:
             index_db = IndexDatabase(db.connection)

             for i in range(num_thousand_batches):
-                batches = index_db.get_batches_by_status(BatchStatus.REMOTE)
+                batches = index_db.get_batches_by_status(BatchStatus.REMOTE, 100)
                 print(f"Found {len(batches)} remote batches")
                 if len(batches) == 0:
                     return
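
get_cached assumes that each batch URL maps to a gzipped JSON file on disk via get_path_from_url, which this diff does not show. A rough sketch of the round trip it relies on, using a made-up path scheme in place of the real mapping:

import gzip
import os
from urllib.parse import quote

from mwmbl.crawler.batch import HashedBatch

CACHE_DIR = "/tmp/batch-cache"  # stand-in for BatchCache.path


def hypothetical_path_from_url(batch_url: str) -> str:
    # Placeholder for BatchCache.get_path_from_url; the real mapping is not in this diff.
    return os.path.join(CACHE_DIR, quote(batch_url, safe=""))


def write_cached(batch_url: str, batch: HashedBatch) -> None:
    # Store the batch as gzipped JSON so get_cached can read it back.
    os.makedirs(CACHE_DIR, exist_ok=True)
    with gzip.open(hypothetical_path_from_url(batch_url), "wt") as f:
        f.write(batch.json())


def read_cached(batch_url: str) -> HashedBatch:
    # Mirrors the body of get_cached for a single URL.
    data = gzip.GzipFile(hypothetical_path_from_url(batch_url)).read()
    return HashedBatch.parse_raw(data)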

View file

@@ -0,0 +1,81 @@
+"""
+Index batches that are stored locally.
+"""
+from collections import defaultdict
+from logging import getLogger
+from typing import Iterable
+
+import spacy
+
+from mwmbl.crawler.batch import HashedBatch
+from mwmbl.crawler.urls import URLDatabase
+from mwmbl.database import Database
+from mwmbl.indexer.batch_cache import BatchCache
+from mwmbl.indexer.index import tokenize_document
+from mwmbl.indexer.indexdb import BatchStatus, IndexDatabase
+from mwmbl.tinysearchengine.indexer import Document, TinyIndex
+
+logger = getLogger(__name__)
+
+
+def get_documents_from_batches(batches: Iterable[HashedBatch]) -> Iterable[tuple[str, str, str]]:
+    for batch in batches:
+        for item in batch.items:
+            if item.content is not None:
+                yield item.content.title, item.url, item.content.extract
+
+
+def run(batch_cache: BatchCache, index_path: str):
+    nlp = spacy.load("en_core_web_sm")
+
+    with Database() as db:
+        index_db = IndexDatabase(db.connection)
+
+        logger.info("Getting local batches")
+        batches = index_db.get_batches_by_status(BatchStatus.LOCAL)
+        logger.info(f"Got {len(batches)} batch urls")
+
+        batch_data = batch_cache.get_cached([batch.url for batch in batches])
+        logger.info(f"Got {len(batch_data)} cached batches")
+        document_tuples = list(get_documents_from_batches(batch_data.values()))
+        urls = [url for title, url, extract in document_tuples]
+        print(f"Got {len(urls)} document tuples")
+
+        url_db = URLDatabase(db.connection)
+        url_scores = url_db.get_url_scores(urls)
+        print(f"Got {len(url_scores)} scores")
+
+        documents = [Document(title, url, extract, url_scores.get(url, 1.0)) for title, url, extract in document_tuples]
+        page_documents = preprocess_documents(documents, index_path, nlp)
+        index_pages(index_path, page_documents)
+
+
+def index_pages(index_path, page_documents):
+    with TinyIndex(Document, index_path, 'w') as indexer:
+        for page, documents in page_documents.items():
+            new_documents = []
+            existing_documents = indexer.get_page(page)
+            seen_urls = set()
+            seen_titles = set()
+            sorted_documents = sorted(documents + existing_documents, key=lambda x: x.score)
+            for document in sorted_documents:
+                if document.title in seen_titles or document.url in seen_urls:
+                    continue
+                new_documents.append(document)
+                seen_urls.add(document.url)
+                seen_titles.add(document.title)
+            indexer.store_in_page(page, new_documents)
+            logger.debug(f"Wrote page {page} with {len(new_documents)} documents")
+
+
+def preprocess_documents(documents, index_path, nlp):
+    page_documents = defaultdict(list)
+    with TinyIndex(Document, index_path, 'w') as indexer:
+        for document in documents:
+            tokenized = tokenize_document(document.url, document.title, document.extract, document.score, nlp)
+            # logger.debug(f"Tokenized: {tokenized}")
+            page_indexes = [indexer.get_key_page_index(token) for token in tokenized.tokens]
+            for page in page_indexes:
+                page_documents[page].append(document)
+    print(f"Preprocessed for {len(page_documents)} pages")
+    return page_documents
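
index_pages merges the freshly indexed documents with whatever is already stored on each page, keeping at most one document per URL and per title after sorting by score (ascending, as written in this commit). A standalone sketch of that merge step, using a throwaway dataclass in place of the real Document:

from dataclasses import dataclass


@dataclass
class Doc:
    # Throwaway stand-in for mwmbl's Document (title, url, extract, score).
    title: str
    url: str
    extract: str
    score: float


def merge_page_documents(new_docs: list[Doc], existing_docs: list[Doc]) -> list[Doc]:
    # Same shape as the loop body in index_pages: sort by score and keep the
    # first document seen for each URL and each title.
    merged = []
    seen_urls = set()
    seen_titles = set()
    for doc in sorted(new_docs + existing_docs, key=lambda d: d.score):
        if doc.title in seen_titles or doc.url in seen_urls:
            continue
        merged.append(doc)
        seen_urls.add(doc.url)
        seen_titles.add(doc.title)
    return merged


# The duplicate URL collapses to a single entry; made-up data for illustration.
existing = [Doc("Rust", "https://rust-lang.org", "A systems language", 2.0)]
new = [Doc("Rust language", "https://rust-lang.org", "A systems language", 5.0),
       Doc("Python", "https://python.org", "A dynamic language", 3.0)]
print(merge_page_documents(new, existing))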

View file

@@ -77,13 +77,13 @@ class IndexDatabase:
         with self.connection.cursor() as cursor:
             execute_values(cursor, sql, data)

-    def get_batches_by_status(self, status: BatchStatus) -> list[BatchInfo]:
+    def get_batches_by_status(self, status: BatchStatus, num_batches=1000) -> list[BatchInfo]:
         sql = """
-        SELECT * FROM batches WHERE status = %(status)s LIMIT 1000
+        SELECT * FROM batches WHERE status = %(status)s LIMIT %(num_batches)s
         """

         with self.connection.cursor() as cursor:
-            cursor.execute(sql, {'status': status.value})
+            cursor.execute(sql, {'status': status.value, 'num_batches': num_batches})
             results = cursor.fetchall()

         return [BatchInfo(url, user_id_hash, status) for url, user_id_hash, status in results]
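
Parameterising the LIMIT works because psycopg2 adapts the integer value like any other query parameter. A short usage sketch (the batch counts here are arbitrary):

from mwmbl.database import Database
from mwmbl.indexer.indexdb import BatchStatus, IndexDatabase

with Database() as db:
    index_db = IndexDatabase(db.connection)
    # Fetch a small slice of remote batches, as the updated BatchCache now does.
    remote = index_db.get_batches_by_status(BatchStatus.REMOTE, 100)
    # The default of 1000 still applies when the argument is omitted.
    local = index_db.get_batches_by_status(BatchStatus.LOCAL)
    print(len(remote), len(local))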

View file

@@ -16,7 +16,7 @@ from mwmbl.tinysearchengine.completer import Completer
 from mwmbl.tinysearchengine.indexer import TinyIndex, Document, NUM_PAGES, PAGE_SIZE
 from mwmbl.tinysearchengine.rank import HeuristicRanker

-logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)


 def setup_args():
@@ -42,7 +42,7 @@ def run():
     if existing_index is None:
         print("Creating a new index")
-        TinyIndex.create(item_factory=Document, index_path=args.index, num_pages=NUM_PAGES, page_size=PAGE_SIZE)
+        TinyIndex.create(item_factory=Document, index_path=index_path, num_pages=NUM_PAGES, page_size=PAGE_SIZE)

     Process(target=background.run, args=(args.data,)).start()

View file

@@ -1,7 +1,7 @@
 import json
 import os
 from dataclasses import astuple, dataclass, asdict
-from io import UnsupportedOperation
+from io import UnsupportedOperation, BytesIO
 from logging import getLogger
 from mmap import mmap, PROT_READ, PROT_WRITE
 from typing import TypeVar, Generic, Callable, List
@@ -63,9 +63,21 @@ class TinyIndexMetadata:
         return TinyIndexMetadata(**values)


-def _get_page_data(compressor, page_size, data):
-    serialised_data = json.dumps(data)
-    compressed_data = compressor.compress(serialised_data.encode('utf8'))
+def _get_page_data(compressor: ZstdCompressor, page_size: int, items: list[T]):
+    bytes_io = BytesIO()
+    stream_writer = compressor.stream_writer(bytes_io, write_size=128)
+
+    num_fitting = 0
+    for i, item in enumerate(items):
+        serialised_data = json.dumps(item) + '\n'
+        stream_writer.write(serialised_data.encode('utf8'))
+        stream_writer.flush()
+        if len(bytes_io.getvalue()) > page_size:
+            break
+        num_fitting = i + 1
+
+    compressed_data = compressor.compress(json.dumps(items[:num_fitting]).encode('utf8'))
+    assert len(compressed_data) < page_size, "The data shouldn't get bigger"
     return _pad_to_page_size(compressed_data, page_size)
@@ -157,15 +169,20 @@ class TinyIndex(Generic[T]):
         current_page += value_tuples
         self._write_page(current_page, page_index)

-    def _write_page(self, data, i):
+    def store_in_page(self, page_index: int, values: list[T]):
+        value_tuples = [astuple(value) for value in values]
+        self._write_page(value_tuples, page_index)
+
+    def _write_page(self, data, i: int):
         """
         Serialise the data using JSON, compress it and store it at index i.
-        If the data is too big, it will raise a ValueError and not store anything
+        If the data is too big, it will store the first items in the list and discard the rest.
         """
         if self.mode != 'w':
             raise UnsupportedOperation("The file is open in read mode, you cannot write")
         page_data = _get_page_data(self.compressor, self.page_size, data)
+        logger.debug(f"Got page data of length {len(page_data)}")
         self.mmap[i * self.page_size:(i+1) * self.page_size] = page_data

     @staticmethod
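
The rewritten _get_page_data finds how many items fit in a page by streaming them through a compressor, flushing after each one, and then re-compressing just the prefix that fits. A standalone sketch of the same idea, assuming the zstandard package (this is not the commit's exact code: it uses a fresh compressor for the one-shot pass and an arbitrary padding byte in place of _pad_to_page_size):

import json
from io import BytesIO

from zstandard import ZstdCompressor


def count_fitting_items(items: list, page_size: int) -> int:
    # Stream items through a compressor, flushing after each one, and stop
    # as soon as the compressed output no longer fits in a page.
    bytes_io = BytesIO()
    writer = ZstdCompressor().stream_writer(bytes_io, write_size=128)
    num_fitting = 0
    for i, item in enumerate(items):
        writer.write((json.dumps(item) + '\n').encode('utf8'))
        writer.flush()
        if len(bytes_io.getvalue()) > page_size:
            break
        num_fitting = i + 1
    return num_fitting


def pack_page(items: list, page_size: int) -> bytes:
    # Compress only the prefix that fits, then pad up to the fixed page size.
    num_fitting = count_fitting_items(items, page_size)
    compressed = ZstdCompressor().compress(json.dumps(items[:num_fitting]).encode('utf8'))
    assert len(compressed) <= page_size
    return compressed + b'\x00' * (page_size - len(compressed))


# Example: pack as many of these generated tuples as fit into one 4 KiB page.
page = pack_page([[f"title {i}", f"https://example.com/{i}", "extract", 1.0]
                  for i in range(1000)], page_size=4096)
print(len(page))  # always exactly one page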