Implement new indexing approach
This commit is contained in:
parent a8a6c67239
commit 1bceeae3df

7 changed files with 131 additions and 25 deletions
@@ -5,7 +5,7 @@ from logging import getLogger
 from pathlib import Path
 from time import sleep
 
-from mwmbl.indexer import historical
+from mwmbl.indexer import historical, index_batches
 from mwmbl.indexer.batch_cache import BatchCache
 from mwmbl.indexer.paths import INDEX_PATH, BATCH_DIR_NAME
 from mwmbl.indexer.preprocess import run_preprocessing
@@ -20,15 +20,11 @@ def run(data_path: str):
     batch_cache = BatchCache(Path(data_path) / BATCH_DIR_NAME)
     while True:
         try:
-            batch_cache.retrieve_batches()
+            batch_cache.retrieve_batches(1)
         except Exception:
             logger.exception("Error retrieving batches")
         try:
-            run_preprocessing(index_path)
+            index_batches.run(batch_cache, index_path)
         except Exception:
-            logger.exception("Error preprocessing")
-        try:
-            run_update(index_path)
-        except Exception:
-            logger.exception("Error running index update")
+            logger.exception("Error indexing batches")
         sleep(10)
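The background loop now alternates between pulling remote batches into the local cache and indexing whatever is cached, instead of running the old preprocessing and update steps. As a hypothetical orientation sketch (not part of the commit), one pass of the loop without the retry and sleep handling could look as follows; how index_path is derived from the data path is not visible in this hunk, so the line marked as an assumption is a guess based on the INDEX_PATH import:

from pathlib import Path

from mwmbl.indexer import index_batches
from mwmbl.indexer.batch_cache import BatchCache
from mwmbl.indexer.paths import BATCH_DIR_NAME, INDEX_PATH


def run_once(data_path: str):
    # One pass of the loop above, without the try/except and sleep(10).
    batch_cache = BatchCache(Path(data_path) / BATCH_DIR_NAME)
    index_path = str(Path(data_path) / INDEX_PATH)  # assumption: derivation not shown in the diff
    batch_cache.retrieve_batches(1)                 # fetch a block of pending remote batches
    index_batches.run(batch_cache, index_path)      # tokenize them and write them into the index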
@@ -128,6 +128,17 @@ class URLDatabase:
 
         return [result[0] for result in results]
+
+    def get_url_scores(self, urls: list[str]) -> dict[str, float]:
+        sql = f"""
+        SELECT url, score FROM urls WHERE url IN %(urls)s
+        """
+
+        with self.connection.cursor() as cursor:
+            cursor.execute(sql, {'urls': tuple(urls)})
+            results = cursor.fetchall()
+
+        return {result[0]: result[1] for result in results}
 
 
 if __name__ == "__main__":
     with Database() as db:
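The new get_url_scores method relies on psycopg2 adapting a Python tuple bound to %(urls)s into a parenthesised value list for the IN clause, and it only returns rows that exist, so callers fall back to a default for unknown URLs. A hypothetical usage sketch (the URLs and the 1.0 default are illustrative, mirroring index_batches.run below):

from mwmbl.crawler.urls import URLDatabase
from mwmbl.database import Database

urls = ["https://example.com/", "https://example.org/missing"]

with Database() as db:
    url_db = URLDatabase(db.connection)
    scores = url_db.get_url_scores(urls)   # {url: score} for the URLs present in the table

for url in urls:
    print(url, scores.get(url, 1.0))       # unknown URLs fall back to a default score

Note that an empty urls list would render as an invalid IN () clause, so callers are expected to pass at least one URL.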
@@ -26,12 +26,13 @@ class BatchCache:
         os.makedirs(repo_path, exist_ok=True)
         self.path = repo_path
 
-    def get(self, num_batches) -> dict[str, HashedBatch]:
+    def get_cached(self, batch_urls: list[str]) -> dict[str, HashedBatch]:
         batches = {}
-        for path in os.listdir(self.path):
-            batch = HashedBatch.parse_file(path)
-            while len(batches) < num_batches:
-                batches[path] = batch
+        for url in batch_urls:
+            path = self.get_path_from_url(url)
+            data = gzip.GzipFile(path).read()
+            batch = HashedBatch.parse_raw(data)
+            batches[url] = batch
         return batches
 
     def retrieve_batches(self, num_thousand_batches=10):
@@ -43,7 +44,7 @@ class BatchCache:
         index_db = IndexDatabase(db.connection)
 
         for i in range(num_thousand_batches):
-            batches = index_db.get_batches_by_status(BatchStatus.REMOTE)
+            batches = index_db.get_batches_by_status(BatchStatus.REMOTE, 100)
             print(f"Found {len(batches)} remote batches")
             if len(batches) == 0:
                 return
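The old get method, which filled the cache dictionary from a directory listing, is replaced by get_cached, which looks up each requested batch by URL and reads its gzipped JSON from disk. A hypothetical end-to-end sketch (the data directory is made up, and get_path_from_url is an existing BatchCache helper that is called in the diff but not shown here):

from pathlib import Path

from mwmbl.database import Database
from mwmbl.indexer.batch_cache import BatchCache
from mwmbl.indexer.indexdb import BatchStatus, IndexDatabase
from mwmbl.indexer.paths import BATCH_DIR_NAME

batch_cache = BatchCache(Path("./devdata") / BATCH_DIR_NAME)   # hypothetical data directory
batch_cache.retrieve_batches(num_thousand_batches=1)           # fetch pending remote batches into the cache

with Database() as db:
    local_batches = IndexDatabase(db.connection).get_batches_by_status(BatchStatus.LOCAL)

cached = batch_cache.get_cached([batch.url for batch in local_batches])
print(f"Loaded {len(cached)} cached batches")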
mwmbl/indexer/index_batches.py (new file, 81 lines)
@@ -0,0 +1,81 @@
+"""
+Index batches that are stored locally.
+"""
+from collections import defaultdict
+from logging import getLogger
+from typing import Iterable
+
+import spacy
+
+from mwmbl.crawler.batch import HashedBatch
+from mwmbl.crawler.urls import URLDatabase
+from mwmbl.database import Database
+from mwmbl.indexer.batch_cache import BatchCache
+from mwmbl.indexer.index import tokenize_document
+from mwmbl.indexer.indexdb import BatchStatus, IndexDatabase
+from mwmbl.tinysearchengine.indexer import Document, TinyIndex
+
+logger = getLogger(__name__)
+
+
+def get_documents_from_batches(batches: Iterable[HashedBatch]) -> Iterable[tuple[str, str, str]]:
+    for batch in batches:
+        for item in batch.items:
+            if item.content is not None:
+                yield item.content.title, item.url, item.content.extract
+
+
+def run(batch_cache: BatchCache, index_path: str):
+    nlp = spacy.load("en_core_web_sm")
+    with Database() as db:
+        index_db = IndexDatabase(db.connection)
+
+        logger.info("Getting local batches")
+        batches = index_db.get_batches_by_status(BatchStatus.LOCAL)
+        logger.info(f"Got {len(batches)} batch urls")
+        batch_data = batch_cache.get_cached([batch.url for batch in batches])
+        logger.info(f"Got {len(batch_data)} cached batches")
+
+        document_tuples = list(get_documents_from_batches(batch_data.values()))
+        urls = [url for title, url, extract in document_tuples]
+
+        print(f"Got {len(urls)} document tuples")
+        url_db = URLDatabase(db.connection)
+        url_scores = url_db.get_url_scores(urls)
+
+        print(f"Got {len(url_scores)} scores")
+        documents = [Document(title, url, extract, url_scores.get(url, 1.0)) for title, url, extract in document_tuples]
+
+        page_documents = preprocess_documents(documents, index_path, nlp)
+        index_pages(index_path, page_documents)
+
+
+def index_pages(index_path, page_documents):
+    with TinyIndex(Document, index_path, 'w') as indexer:
+        for page, documents in page_documents.items():
+            new_documents = []
+            existing_documents = indexer.get_page(page)
+            seen_urls = set()
+            seen_titles = set()
+            sorted_documents = sorted(documents + existing_documents, key=lambda x: x.score)
+            for document in sorted_documents:
+                if document.title in seen_titles or document.url in seen_urls:
+                    continue
+                new_documents.append(document)
+                seen_urls.add(document.url)
+                seen_titles.add(document.title)
+            indexer.store_in_page(page, new_documents)
+            logger.debug(f"Wrote page {page} with {len(new_documents)} documents")
+
+
+def preprocess_documents(documents, index_path, nlp):
+    page_documents = defaultdict(list)
+    with TinyIndex(Document, index_path, 'w') as indexer:
+        for document in documents:
+            tokenized = tokenize_document(document.url, document.title, document.extract, document.score, nlp)
+            # logger.debug(f"Tokenized: {tokenized}")
+            page_indexes = [indexer.get_key_page_index(token) for token in tokenized.tokens]
+            for page in page_indexes:
+                page_documents[page].append(document)
+    print(f"Preprocessed for {len(page_documents)} pages")
+    return page_documents
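The interesting part of the new module is the merge in index_pages: for each page, incoming documents are combined with the documents already stored on that page, sorted by score (ascending, as written), and deduplicated by URL and by title before the page is rewritten. A self-contained sketch of just that step, using a stand-in dataclass instead of the repo's Document:

from dataclasses import dataclass


@dataclass
class Doc:  # stand-in for mwmbl.tinysearchengine.indexer.Document
    title: str
    url: str
    extract: str
    score: float


existing = [Doc("Example", "https://example.com/", "old extract", 2.0)]
incoming = [Doc("Example", "https://example.com/", "new extract", 1.0),
            Doc("Other page", "https://example.org/", "other extract", 3.0)]

seen_urls, seen_titles, merged = set(), set(), []
for doc in sorted(incoming + existing, key=lambda d: d.score):
    if doc.title in seen_titles or doc.url in seen_urls:
        continue  # first occurrence in sort order wins; later duplicates are dropped
    merged.append(doc)
    seen_urls.add(doc.url)
    seen_titles.add(doc.title)

print([(d.url, d.score) for d in merged])
# [('https://example.com/', 1.0), ('https://example.org/', 3.0)]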
@@ -77,13 +77,13 @@ class IndexDatabase:
         with self.connection.cursor() as cursor:
             execute_values(cursor, sql, data)
 
-    def get_batches_by_status(self, status: BatchStatus) -> list[BatchInfo]:
+    def get_batches_by_status(self, status: BatchStatus, num_batches=1000) -> list[BatchInfo]:
         sql = """
-        SELECT * FROM batches WHERE status = %(status)s LIMIT 1000
+        SELECT * FROM batches WHERE status = %(status)s LIMIT %(num_batches)s
         """
 
         with self.connection.cursor() as cursor:
-            cursor.execute(sql, {'status': status.value})
+            cursor.execute(sql, {'status': status.value, 'num_batches': num_batches})
             results = cursor.fetchall()
             return [BatchInfo(url, user_id_hash, status) for url, user_id_hash, status in results]
 
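LIMIT in PostgreSQL accepts a bound parameter, so passing num_batches through the query parameters is valid, and the default of 1000 preserves the old behaviour. A hypothetical sketch reassembling the two call sites from this commit for comparison:

from mwmbl.database import Database
from mwmbl.indexer.indexdb import BatchStatus, IndexDatabase

with Database() as db:
    index_db = IndexDatabase(db.connection)
    # As in BatchCache.retrieve_batches: pull remote batches 100 at a time.
    remote_batches = index_db.get_batches_by_status(BatchStatus.REMOTE, 100)
    # As in index_batches.run: the default num_batches of 1000 still applies.
    local_batches = index_db.get_batches_by_status(BatchStatus.LOCAL)
    print(len(remote_batches), len(local_batches))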
@@ -16,7 +16,7 @@ from mwmbl.tinysearchengine.completer import Completer
 from mwmbl.tinysearchengine.indexer import TinyIndex, Document, NUM_PAGES, PAGE_SIZE
 from mwmbl.tinysearchengine.rank import HeuristicRanker
 
-logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
 
 
 def setup_args():
@@ -42,7 +42,7 @@ def run():
 
     if existing_index is None:
         print("Creating a new index")
-        TinyIndex.create(item_factory=Document, index_path=args.index, num_pages=NUM_PAGES, page_size=PAGE_SIZE)
+        TinyIndex.create(item_factory=Document, index_path=index_path, num_pages=NUM_PAGES, page_size=PAGE_SIZE)
 
     Process(target=background.run, args=(args.data,)).start()
@@ -1,7 +1,7 @@
 import json
 import os
 from dataclasses import astuple, dataclass, asdict
-from io import UnsupportedOperation
+from io import UnsupportedOperation, BytesIO
 from logging import getLogger
 from mmap import mmap, PROT_READ, PROT_WRITE
 from typing import TypeVar, Generic, Callable, List
@@ -63,9 +63,21 @@ class TinyIndexMetadata:
         return TinyIndexMetadata(**values)
 
 
-def _get_page_data(compressor, page_size, data):
-    serialised_data = json.dumps(data)
-    compressed_data = compressor.compress(serialised_data.encode('utf8'))
+def _get_page_data(compressor: ZstdCompressor, page_size: int, items: list[T]):
+    bytes_io = BytesIO()
+    stream_writer = compressor.stream_writer(bytes_io, write_size=128)
+
+    num_fitting = 0
+    for i, item in enumerate(items):
+        serialised_data = json.dumps(item) + '\n'
+        stream_writer.write(serialised_data.encode('utf8'))
+        stream_writer.flush()
+        if len(bytes_io.getvalue()) > page_size:
+            break
+        num_fitting = i + 1
+
+    compressed_data = compressor.compress(json.dumps(items[:num_fitting]).encode('utf8'))
     assert len(compressed_data) < page_size, "The data shouldn't get bigger"
     return _pad_to_page_size(compressed_data, page_size)
 
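The rewritten _get_page_data no longer compresses the whole payload in one go; instead it streams items one at a time through a zstd stream writer to estimate how many of them fit within a page, then compresses just that prefix. A standalone sketch of the idea using the zstandard package; the page size and items are made up, and a fresh compressor is used for the final pass, a small deviation from the diff, where the same instance is reused:

import json
from io import BytesIO

from zstandard import ZstdCompressor

PAGE_SIZE = 4096  # illustrative page size
items = [[f"title {i}", f"https://example.com/{i}", "some extract text", 1.0] for i in range(1000)]

probe = ZstdCompressor()
buffer = BytesIO()
writer = probe.stream_writer(buffer, write_size=128)

num_fitting = 0
for i, item in enumerate(items):
    writer.write((json.dumps(item) + "\n").encode("utf8"))
    writer.flush()                          # force compressed bytes out so the size can be measured
    if len(buffer.getvalue()) > PAGE_SIZE:
        break                               # this item pushed the page over its size budget
    num_fitting = i + 1

# Compress only the prefix that fits; this is what would be padded and stored.
page_bytes = ZstdCompressor().compress(json.dumps(items[:num_fitting]).encode("utf8"))
print(num_fitting, len(page_bytes))

Because flushing after every item makes the streamed estimate pessimistic, compressing the chosen prefix in one shot should come out smaller, which is what the assert in the diff relies on.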
@@ -157,15 +169,20 @@ class TinyIndex(Generic[T]):
         current_page += value_tuples
         self._write_page(current_page, page_index)
 
-    def _write_page(self, data, i):
+    def store_in_page(self, page_index: int, values: list[T]):
+        value_tuples = [astuple(value) for value in values]
+        self._write_page(value_tuples, page_index)
+
+    def _write_page(self, data, i: int):
         """
         Serialise the data using JSON, compress it and store it at index i.
-        If the data is too big, it will raise a ValueError and not store anything
+        If the data is too big, it will store the first items in the list and discard the rest.
         """
         if self.mode != 'w':
             raise UnsupportedOperation("The file is open in read mode, you cannot write")
 
         page_data = _get_page_data(self.compressor, self.page_size, data)
+        logger.debug(f"Got page data of length {len(page_data)}")
         self.mmap[i * self.page_size:(i+1) * self.page_size] = page_data
 
     @staticmethod
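store_in_page gives index_batches.index_pages a public way to overwrite a whole page with already-assembled documents, while _write_page now truncates oversized pages via _get_page_data instead of raising. A hypothetical sketch (the index path is made up; TinyIndex, Document, get_page and store_in_page are used exactly as in the diffs above):

from mwmbl.tinysearchengine.indexer import Document, TinyIndex

INDEX_PATH = "devdata/index.tinysearch"  # hypothetical path to an existing index file

with TinyIndex(Document, INDEX_PATH, 'w') as indexer:
    documents = indexer.get_page(0)
    documents.append(Document("Example", "https://example.com/", "An example page", 1.0))
    # Rewrites page 0; if the list no longer fits, only a prefix of it is kept.
    indexer.store_in_page(0, documents)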