diff --git a/mwmbl/background.py b/mwmbl/background.py
index 7ff2edb..8e085d3 100644
--- a/mwmbl/background.py
+++ b/mwmbl/background.py
@@ -5,7 +5,7 @@ from logging import getLogger
 from pathlib import Path
 from time import sleep
 
-from mwmbl.indexer import historical
+from mwmbl.indexer import historical, index_batches
 from mwmbl.indexer.batch_cache import BatchCache
 from mwmbl.indexer.paths import INDEX_PATH, BATCH_DIR_NAME
 from mwmbl.indexer.preprocess import run_preprocessing
@@ -20,15 +20,11 @@ def run(data_path: str):
     batch_cache = BatchCache(Path(data_path) / BATCH_DIR_NAME)
     while True:
         try:
-            batch_cache.retrieve_batches()
+            batch_cache.retrieve_batches(1)
         except Exception:
             logger.exception("Error retrieving batches")
         try:
-            run_preprocessing(index_path)
+            index_batches.run(batch_cache, index_path)
         except Exception:
-            logger.exception("Error preprocessing")
-        try:
-            run_update(index_path)
-        except Exception:
-            logger.exception("Error running index update")
+            logger.exception("Error indexing batches")
         sleep(10)
diff --git a/mwmbl/crawler/urls.py b/mwmbl/crawler/urls.py
index 291d246..d97809b 100644
--- a/mwmbl/crawler/urls.py
+++ b/mwmbl/crawler/urls.py
@@ -128,6 +128,17 @@ class URLDatabase:
 
         return [result[0] for result in results]
 
+    def get_url_scores(self, urls: list[str]) -> dict[str, float]:
+        sql = f"""
+        SELECT url, score FROM urls WHERE url IN %(urls)s
+        """
+
+        with self.connection.cursor() as cursor:
+            cursor.execute(sql, {'urls': tuple(urls)})
+            results = cursor.fetchall()
+
+        return {result[0]: result[1] for result in results}
+
 
 if __name__ == "__main__":
     with Database() as db:
diff --git a/mwmbl/indexer/batch_cache.py b/mwmbl/indexer/batch_cache.py
index f811578..ed1d811 100644
--- a/mwmbl/indexer/batch_cache.py
+++ b/mwmbl/indexer/batch_cache.py
@@ -26,12 +26,13 @@
         os.makedirs(repo_path, exist_ok=True)
         self.path = repo_path
 
-    def get(self, num_batches) -> dict[str, HashedBatch]:
+    def get_cached(self, batch_urls: list[str]) -> dict[str, HashedBatch]:
         batches = {}
-        for path in os.listdir(self.path):
-            batch = HashedBatch.parse_file(path)
-            while len(batches) < num_batches:
-                batches[path] = batch
+        for url in batch_urls:
+            path = self.get_path_from_url(url)
+            data = gzip.GzipFile(path).read()
+            batch = HashedBatch.parse_raw(data)
+            batches[url] = batch
         return batches
 
     def retrieve_batches(self, num_thousand_batches=10):
@@ -43,7 +44,7 @@
             index_db = IndexDatabase(db.connection)
 
             for i in range(num_thousand_batches):
-                batches = index_db.get_batches_by_status(BatchStatus.REMOTE)
+                batches = index_db.get_batches_by_status(BatchStatus.REMOTE, 100)
                 print(f"Found {len(batches)} remote batches")
                 if len(batches) == 0:
                     return
diff --git a/mwmbl/indexer/index_batches.py b/mwmbl/indexer/index_batches.py
new file mode 100644
index 0000000..a1991e2
--- /dev/null
+++ b/mwmbl/indexer/index_batches.py
@@ -0,0 +1,81 @@
+"""
+Index batches that are stored locally.
+"""
+from collections import defaultdict
+from logging import getLogger
+from typing import Iterable
+
+import spacy
+
+from mwmbl.crawler.batch import HashedBatch
+from mwmbl.crawler.urls import URLDatabase
+from mwmbl.database import Database
+from mwmbl.indexer.batch_cache import BatchCache
+from mwmbl.indexer.index import tokenize_document
+from mwmbl.indexer.indexdb import BatchStatus, IndexDatabase
+from mwmbl.tinysearchengine.indexer import Document, TinyIndex
+
+logger = getLogger(__name__)
+
+
+def get_documents_from_batches(batches: Iterable[HashedBatch]) -> Iterable[tuple[str, str, str]]:
+    for batch in batches:
+        for item in batch.items:
+            if item.content is not None:
+                yield item.content.title, item.url, item.content.extract
+
+
+def run(batch_cache: BatchCache, index_path: str):
+    nlp = spacy.load("en_core_web_sm")
+    with Database() as db:
+        index_db = IndexDatabase(db.connection)
+
+        logger.info("Getting local batches")
+        batches = index_db.get_batches_by_status(BatchStatus.LOCAL)
+        logger.info(f"Got {len(batches)} batch urls")
+        batch_data = batch_cache.get_cached([batch.url for batch in batches])
+        logger.info(f"Got {len(batch_data)} cached batches")
+
+        document_tuples = list(get_documents_from_batches(batch_data.values()))
+        urls = [url for title, url, extract in document_tuples]
+
+        print(f"Got {len(urls)} document tuples")
+        url_db = URLDatabase(db.connection)
+        url_scores = url_db.get_url_scores(urls)
+
+        print(f"Got {len(url_scores)} scores")
+        documents = [Document(title, url, extract, url_scores.get(url, 1.0)) for title, url, extract in document_tuples]
+
+        page_documents = preprocess_documents(documents, index_path, nlp)
+        index_pages(index_path, page_documents)
+
+
+def index_pages(index_path, page_documents):
+    with TinyIndex(Document, index_path, 'w') as indexer:
+        for page, documents in page_documents.items():
+            new_documents = []
+            existing_documents = indexer.get_page(page)
+            seen_urls = set()
+            seen_titles = set()
+            sorted_documents = sorted(documents + existing_documents, key=lambda x: x.score)
+            for document in sorted_documents:
+                if document.title in seen_titles or document.url in seen_urls:
+                    continue
+                new_documents.append(document)
+                seen_urls.add(document.url)
+                seen_titles.add(document.title)
+            indexer.store_in_page(page, new_documents)
+            logger.debug(f"Wrote page {page} with {len(new_documents)} documents")
+
+
+def preprocess_documents(documents, index_path, nlp):
+    page_documents = defaultdict(list)
+    with TinyIndex(Document, index_path, 'w') as indexer:
+        for document in documents:
+            tokenized = tokenize_document(document.url, document.title, document.extract, document.score, nlp)
+            # logger.debug(f"Tokenized: {tokenized}")
+            page_indexes = [indexer.get_key_page_index(token) for token in tokenized.tokens]
+            for page in page_indexes:
+                page_documents[page].append(document)
+    print(f"Preprocessed for {len(page_documents)} pages")
+    return page_documents
diff --git a/mwmbl/indexer/indexdb.py b/mwmbl/indexer/indexdb.py
index efc4fc1..8dd43e8 100644
--- a/mwmbl/indexer/indexdb.py
+++ b/mwmbl/indexer/indexdb.py
@@ -77,13 +77,13 @@ class IndexDatabase:
         with self.connection.cursor() as cursor:
             execute_values(cursor, sql, data)
 
-    def get_batches_by_status(self, status: BatchStatus) -> list[BatchInfo]:
+    def get_batches_by_status(self, status: BatchStatus, num_batches=1000) -> list[BatchInfo]:
         sql = """
-        SELECT * FROM batches WHERE status = %(status)s LIMIT 1000
+        SELECT * FROM batches WHERE status = %(status)s LIMIT %(num_batches)s
        """
 
        with self.connection.cursor() as cursor:
-            cursor.execute(sql, {'status': status.value})
+            cursor.execute(sql, {'status': status.value, 'num_batches': num_batches})
            results = cursor.fetchall()
 
        return [BatchInfo(url, user_id_hash, status) for url, user_id_hash, status in results]
diff --git a/mwmbl/main.py b/mwmbl/main.py
index 9896cc9..e07e055 100644
--- a/mwmbl/main.py
+++ b/mwmbl/main.py
@@ -16,7 +16,7 @@ from mwmbl.tinysearchengine.completer import Completer
 from mwmbl.tinysearchengine.indexer import TinyIndex, Document, NUM_PAGES, PAGE_SIZE
 from mwmbl.tinysearchengine.rank import HeuristicRanker
 
-logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
 
 
 def setup_args():
@@ -42,7 +42,7 @@ def run():
 
     if existing_index is None:
         print("Creating a new index")
-        TinyIndex.create(item_factory=Document, index_path=args.index, num_pages=NUM_PAGES, page_size=PAGE_SIZE)
+        TinyIndex.create(item_factory=Document, index_path=index_path, num_pages=NUM_PAGES, page_size=PAGE_SIZE)
 
     Process(target=background.run, args=(args.data,)).start()
 
diff --git a/mwmbl/tinysearchengine/indexer.py b/mwmbl/tinysearchengine/indexer.py
index a9aa9f4..df026d0 100644
--- a/mwmbl/tinysearchengine/indexer.py
+++ b/mwmbl/tinysearchengine/indexer.py
@@ -1,7 +1,7 @@
 import json
 import os
 from dataclasses import astuple, dataclass, asdict
-from io import UnsupportedOperation
+from io import UnsupportedOperation, BytesIO
 from logging import getLogger
 from mmap import mmap, PROT_READ, PROT_WRITE
 from typing import TypeVar, Generic, Callable, List
@@ -63,9 +63,21 @@
         return TinyIndexMetadata(**values)
 
 
-def _get_page_data(compressor, page_size, data):
-    serialised_data = json.dumps(data)
-    compressed_data = compressor.compress(serialised_data.encode('utf8'))
+def _get_page_data(compressor: ZstdCompressor, page_size: int, items: list[T]):
+    bytes_io = BytesIO()
+    stream_writer = compressor.stream_writer(bytes_io, write_size=128)
+
+    num_fitting = 0
+    for i, item in enumerate(items):
+        serialised_data = json.dumps(item) + '\n'
+        stream_writer.write(serialised_data.encode('utf8'))
+        stream_writer.flush()
+        if len(bytes_io.getvalue()) > page_size:
+            break
+        num_fitting = i + 1
+
+    compressed_data = compressor.compress(json.dumps(items[:num_fitting]).encode('utf8'))
+    assert len(compressed_data) < page_size, "The data shouldn't get bigger"
     return _pad_to_page_size(compressed_data, page_size)
 
 
@@ -157,15 +169,20 @@ class TinyIndex(Generic[T]):
             current_page += value_tuples
         self._write_page(current_page, page_index)
 
-    def _write_page(self, data, i):
+    def store_in_page(self, page_index: int, values: list[T]):
+        value_tuples = [astuple(value) for value in values]
+        self._write_page(value_tuples, page_index)
+
+    def _write_page(self, data, i: int):
         """
         Serialise the data using JSON, compress it and store it at index i.
-        If the data is too big, it will raise a ValueError and not store anything
+        If the data is too big, it will store the first items in the list and discard the rest.
         """
         if self.mode != 'w':
             raise UnsupportedOperation("The file is open in read mode, you cannot write")
 
         page_data = _get_page_data(self.compressor, self.page_size, data)
+        logger.debug(f"Got page data of length {len(page_data)}")
         self.mmap[i * self.page_size:(i+1) * self.page_size] = page_data
 
     @staticmethod
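
For context, a minimal sketch of how the two new entry points fit together outside the background loop, mirroring the updated run() in background.py. The "./devdata" directory below is hypothetical and used only for illustration; the real locations come from the constants in mwmbl.indexer.paths.

    from pathlib import Path

    from mwmbl.indexer import index_batches
    from mwmbl.indexer.batch_cache import BatchCache
    from mwmbl.indexer.paths import INDEX_PATH, BATCH_DIR_NAME

    # Hypothetical data directory for this sketch only.
    data_path = Path("./devdata")
    batch_cache = BatchCache(data_path / BATCH_DIR_NAME)

    # One pass of retrieve_batches now pulls up to 100 REMOTE batches
    # (get_batches_by_status is called with a limit of 100) and caches them on disk.
    batch_cache.retrieve_batches(1)

    # Tokenize the locally cached batches, look up URL scores via
    # URLDatabase.get_url_scores, and merge the resulting documents into the
    # index pages, deduplicating by URL and title.
    index_batches.run(batch_cache, str(data_path / INDEX_PATH))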