
Implement new indexing approach

Daoud Clarke, 3 years ago
commit 1bceeae3df

+ 4 - 8
mwmbl/background.py

@@ -5,7 +5,7 @@ from logging import getLogger
 from pathlib import Path
 from time import sleep
 
-from mwmbl.indexer import historical
+from mwmbl.indexer import historical, index_batches
 from mwmbl.indexer.batch_cache import BatchCache
 from mwmbl.indexer.paths import INDEX_PATH, BATCH_DIR_NAME
 from mwmbl.indexer.preprocess import run_preprocessing
@@ -20,15 +20,11 @@ def run(data_path: str):
     batch_cache = BatchCache(Path(data_path) / BATCH_DIR_NAME)
     while True:
         try:
-            batch_cache.retrieve_batches()
+            batch_cache.retrieve_batches(1)
         except Exception:
             logger.exception("Error retrieving batches")
         try:
-            run_preprocessing(index_path)
+            index_batches.run(batch_cache, index_path)
         except Exception:
-            logger.exception("Error preprocessing")
-        try:
-            run_update(index_path)
-        except Exception:
-            logger.exception("Error running index update")
+            logger.exception("Error indexing batches")
         sleep(10)

+ 11 - 0
mwmbl/crawler/urls.py

@@ -128,6 +128,17 @@ class URLDatabase:
 
         return [result[0] for result in results]
 
+    def get_url_scores(self, urls: list[str]) -> dict[str, float]:
+        sql = f"""
+        SELECT url, score FROM urls WHERE url IN %(urls)s
+        """
+
+        with self.connection.cursor() as cursor:
+            cursor.execute(sql, {'urls': tuple(urls)})
+            results = cursor.fetchall()
+
+        return {result[0]: result[1] for result in results}
+
 
 if __name__ == "__main__":
     with Database() as db:
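
The new URLDatabase.get_url_scores method maps each URL to its stored score in a single query. A minimal usage sketch, assuming the Database context manager from mwmbl.database and a non-empty URL list (psycopg2 expands the tuple for the IN clause, so an empty list would fail); the URL is a placeholder:

    from mwmbl.crawler.urls import URLDatabase
    from mwmbl.database import Database

    with Database() as db:
        url_db = URLDatabase(db.connection)
        # URLs not present in the table are simply absent from the result,
        # so callers fall back to a default score (index_batches uses 1.0).
        scores = url_db.get_url_scores(["https://example.com/"])
        print(scores.get("https://example.com/", 1.0))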

+ 7 - 6
mwmbl/indexer/batch_cache.py

@@ -26,12 +26,13 @@ class BatchCache:
         os.makedirs(repo_path, exist_ok=True)
         self.path = repo_path
 
-    def get(self, num_batches) -> dict[str, HashedBatch]:
+    def get_cached(self, batch_urls: list[str]) -> dict[str, HashedBatch]:
         batches = {}
-        for path in os.listdir(self.path):
-            batch = HashedBatch.parse_file(path)
-            while len(batches) < num_batches:
-                batches[path] = batch
+        for url in batch_urls:
+            path = self.get_path_from_url(url)
+            data = gzip.GzipFile(path).read()
+            batch = HashedBatch.parse_raw(data)
+            batches[url] = batch
         return batches
 
     def retrieve_batches(self, num_thousand_batches=10):
@@ -43,7 +44,7 @@ class BatchCache:
             index_db = IndexDatabase(db.connection)
 
             for i in range(num_thousand_batches):
-                batches = index_db.get_batches_by_status(BatchStatus.REMOTE)
+                batches = index_db.get_batches_by_status(BatchStatus.REMOTE, 100)
                 print(f"Found {len(batches)} remote batches")
                 if len(batches) == 0:
                     return
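
get_cached now takes batch URLs recorded in the database and parses the corresponding gzipped files from the local cache, while retrieve_batches pulls remote batches in chunks of 100. A rough sketch mirroring how index_batches.run consumes the cache in this commit; the data directory is a placeholder:

    from pathlib import Path
    from mwmbl.database import Database
    from mwmbl.indexer.batch_cache import BatchCache
    from mwmbl.indexer.indexdb import BatchStatus, IndexDatabase
    from mwmbl.indexer.paths import BATCH_DIR_NAME

    batch_cache = BatchCache(Path("/srv/mwmbl-data") / BATCH_DIR_NAME)  # placeholder path
    batch_cache.retrieve_batches(num_thousand_batches=1)

    with Database() as db:
        index_db = IndexDatabase(db.connection)
        local_batches = index_db.get_batches_by_status(BatchStatus.LOCAL)
        # Keys are the original batch URLs, values are parsed HashedBatch objects.
        cached = batch_cache.get_cached([batch.url for batch in local_batches])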

+ 81 - 0
mwmbl/indexer/index_batches.py

@@ -0,0 +1,81 @@
+"""
+Index batches that are stored locally.
+"""
+from collections import defaultdict
+from logging import getLogger
+from typing import Iterable
+
+import spacy
+
+from mwmbl.crawler.batch import HashedBatch
+from mwmbl.crawler.urls import URLDatabase
+from mwmbl.database import Database
+from mwmbl.indexer.batch_cache import BatchCache
+from mwmbl.indexer.index import tokenize_document
+from mwmbl.indexer.indexdb import BatchStatus, IndexDatabase
+from mwmbl.tinysearchengine.indexer import Document, TinyIndex
+
+logger = getLogger(__name__)
+
+
+def get_documents_from_batches(batches: Iterable[HashedBatch]) -> Iterable[tuple[str, str, str]]:
+    for batch in batches:
+        for item in batch.items:
+            if item.content is not None:
+                yield item.content.title, item.url, item.content.extract
+
+
+def run(batch_cache: BatchCache, index_path: str):
+    nlp = spacy.load("en_core_web_sm")
+    with Database() as db:
+        index_db = IndexDatabase(db.connection)
+
+        logger.info("Getting local batches")
+        batches = index_db.get_batches_by_status(BatchStatus.LOCAL)
+        logger.info(f"Got {len(batches)} batch urls")
+        batch_data = batch_cache.get_cached([batch.url for batch in batches])
+        logger.info(f"Got {len(batch_data)} cached batches")
+
+        document_tuples = list(get_documents_from_batches(batch_data.values()))
+        urls = [url for title, url, extract in document_tuples]
+
+        print(f"Got {len(urls)} document tuples")
+        url_db = URLDatabase(db.connection)
+        url_scores = url_db.get_url_scores(urls)
+
+        print(f"Got {len(url_scores)} scores")
+        documents = [Document(title, url, extract, url_scores.get(url, 1.0)) for title, url, extract in document_tuples]
+
+        page_documents = preprocess_documents(documents, index_path, nlp)
+        index_pages(index_path, page_documents)
+
+
+def index_pages(index_path, page_documents):
+    with TinyIndex(Document, index_path, 'w') as indexer:
+        for page, documents in page_documents.items():
+            new_documents = []
+            existing_documents = indexer.get_page(page)
+            seen_urls = set()
+            seen_titles = set()
+            sorted_documents = sorted(documents + existing_documents, key=lambda x: x.score)
+            for document in sorted_documents:
+                if document.title in seen_titles or document.url in seen_urls:
+                    continue
+                new_documents.append(document)
+                seen_urls.add(document.url)
+                seen_titles.add(document.title)
+            indexer.store_in_page(page, new_documents)
+            logger.debug(f"Wrote page {page} with {len(new_documents)} documents")
+
+
+def preprocess_documents(documents, index_path, nlp):
+    page_documents = defaultdict(list)
+    with TinyIndex(Document, index_path, 'w') as indexer:
+        for document in documents:
+            tokenized = tokenize_document(document.url, document.title, document.extract, document.score, nlp)
+            # logger.debug(f"Tokenized: {tokenized}")
+            page_indexes = [indexer.get_key_page_index(token) for token in tokenized.tokens]
+            for page in page_indexes:
+                page_documents[page].append(document)
+    print(f"Preprocessed for {len(page_documents)} pages")
+    return page_documents
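
The new module's run() is now the single background indexing step: it loads locally cached batches, attaches URL scores, groups documents by index page, and merges them into existing pages with deduplication by URL and title. A hedged sketch of invoking it directly, assuming as placeholders that the data directory and the index location under data_path / INDEX_PATH match what background.run passes in:

    from pathlib import Path
    from mwmbl.indexer import index_batches
    from mwmbl.indexer.batch_cache import BatchCache
    from mwmbl.indexer.paths import INDEX_PATH, BATCH_DIR_NAME

    data_path = Path("/srv/mwmbl-data")  # placeholder
    batch_cache = BatchCache(data_path / BATCH_DIR_NAME)
    # Requires an existing index file (see TinyIndex.create in main.py)
    # and the en_core_web_sm spaCy model to be installed.
    index_batches.run(batch_cache, str(data_path / INDEX_PATH))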

+ 3 - 3
mwmbl/indexer/indexdb.py

@@ -77,13 +77,13 @@ class IndexDatabase:
         with self.connection.cursor() as cursor:
             execute_values(cursor, sql, data)
 
-    def get_batches_by_status(self, status: BatchStatus) -> list[BatchInfo]:
+    def get_batches_by_status(self, status: BatchStatus, num_batches=1000) -> list[BatchInfo]:
         sql = """
-        SELECT * FROM batches WHERE status = %(status)s LIMIT 1000
+        SELECT * FROM batches WHERE status = %(status)s LIMIT %(num_batches)s
         """
 
         with self.connection.cursor() as cursor:
-            cursor.execute(sql, {'status': status.value})
+            cursor.execute(sql, {'status': status.value, 'num_batches': num_batches})
             results = cursor.fetchall()
             return [BatchInfo(url, user_id_hash, status) for url, user_id_hash, status in results]
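
get_batches_by_status now takes an optional limit instead of the hard-coded 1000, which is what lets retrieve_batches work through remote batches in chunks of 100. A small sketch, assuming an open database connection:

    from mwmbl.database import Database
    from mwmbl.indexer.indexdb import BatchStatus, IndexDatabase

    with Database() as db:
        index_db = IndexDatabase(db.connection)
        # Fetch at most 100 remote batch records; the default limit stays at 1000.
        remote = index_db.get_batches_by_status(BatchStatus.REMOTE, 100)
        print(len(remote))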
 

+ 2 - 2
mwmbl/main.py

@@ -16,7 +16,7 @@ from mwmbl.tinysearchengine.completer import Completer
 from mwmbl.tinysearchengine.indexer import TinyIndex, Document, NUM_PAGES, PAGE_SIZE
 from mwmbl.tinysearchengine.rank import HeuristicRanker
 
-logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
 
 
 def setup_args():
@@ -42,7 +42,7 @@ def run():
 
     if existing_index is None:
         print("Creating a new index")
-        TinyIndex.create(item_factory=Document, index_path=args.index, num_pages=NUM_PAGES, page_size=PAGE_SIZE)
+        TinyIndex.create(item_factory=Document, index_path=index_path, num_pages=NUM_PAGES, page_size=PAGE_SIZE)
 
     Process(target=background.run, args=(args.data,)).start()
 

+ 23 - 6
mwmbl/tinysearchengine/indexer.py

@@ -1,7 +1,7 @@
 import json
 import os
 from dataclasses import astuple, dataclass, asdict
-from io import UnsupportedOperation
+from io import UnsupportedOperation, BytesIO
 from logging import getLogger
 from mmap import mmap, PROT_READ, PROT_WRITE
 from typing import TypeVar, Generic, Callable, List
@@ -63,9 +63,21 @@ class TinyIndexMetadata:
         return TinyIndexMetadata(**values)
 
 
-def _get_page_data(compressor, page_size, data):
-    serialised_data = json.dumps(data)
-    compressed_data = compressor.compress(serialised_data.encode('utf8'))
+def _get_page_data(compressor: ZstdCompressor, page_size: int, items: list[T]):
+    bytes_io = BytesIO()
+    stream_writer = compressor.stream_writer(bytes_io, write_size=128)
+
+    num_fitting = 0
+    for i, item in enumerate(items):
+        serialised_data = json.dumps(item) + '\n'
+        stream_writer.write(serialised_data.encode('utf8'))
+        stream_writer.flush()
+        if len(bytes_io.getvalue()) > page_size:
+            break
+        num_fitting = i + 1
+
+    compressed_data = compressor.compress(json.dumps(items[:num_fitting]).encode('utf8'))
+    assert len(compressed_data) < page_size, "The data shouldn't get bigger"
     return _pad_to_page_size(compressed_data, page_size)
 
 
@@ -157,15 +169,20 @@ class TinyIndex(Generic[T]):
         current_page += value_tuples
         self._write_page(current_page, page_index)
 
-    def _write_page(self, data, i):
+    def store_in_page(self, page_index: int, values: list[T]):
+        value_tuples = [astuple(value) for value in values]
+        self._write_page(value_tuples, page_index)
+
+    def _write_page(self, data, i: int):
         """
         Serialise the data using JSON, compress it and store it at index i.
-        If the data is too big, it will raise a ValueError and not store anything
+        If the data is too big, it will store the first items in the list and discard the rest.
         """
         if self.mode != 'w':
             raise UnsupportedOperation("The file is open in read mode, you cannot write")
 
         page_data = _get_page_data(self.compressor, self.page_size, data)
+        logger.debug(f"Got page data of length {len(page_data)}")
         self.mmap[i * self.page_size:(i+1) * self.page_size] = page_data
 
     @staticmethod
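
store_in_page writes a prepared document list straight to a page, and _get_page_data now trims the list to whatever compresses within page_size instead of refusing to store anything. A minimal sketch of the new write path, assuming a throwaway index file as a placeholder:

    from mwmbl.tinysearchengine.indexer import Document, TinyIndex, NUM_PAGES, PAGE_SIZE

    index_path = "/tmp/test-index.tinysearch"  # placeholder
    TinyIndex.create(item_factory=Document, index_path=index_path, num_pages=NUM_PAGES, page_size=PAGE_SIZE)

    docs = [Document("Example title", "https://example.com/", "Example extract", 1.0)]
    with TinyIndex(Document, index_path, 'w') as indexer:
        # If the documents don't all fit in one compressed page, the trailing
        # ones are now dropped rather than raising an error.
        indexer.store_in_page(0, docs)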