Quellcode durchsuchen

Merge pull request #69 from mwmbl/reduce-contention-for-client-queries

Reduce contention for client queries
Daoud Clarke vor 3 Jahren
Ursprung
Commit
a54e093cf1

+ 10 - 12
mwmbl/background.py

@@ -2,29 +2,27 @@
 Script that updates data in a background process.
 """
 from logging import getLogger
+from pathlib import Path
 from time import sleep
 
-from mwmbl.indexer import historical
-from mwmbl.indexer.preprocess import run_preprocessing
-from mwmbl.indexer.retrieve import retrieve_batches
-from mwmbl.indexer.update_pages import run_update
+from mwmbl.indexer import index_batches, historical
+from mwmbl.indexer.batch_cache import BatchCache
+from mwmbl.indexer.paths import BATCH_DIR_NAME, INDEX_NAME
 
 logger = getLogger(__name__)
 
 
-def run(index_path: str):
+def run(data_path: str):
     historical.run()
+    index_path = Path(data_path) / INDEX_NAME
+    batch_cache = BatchCache(Path(data_path) / BATCH_DIR_NAME)
     while True:
         try:
-            retrieve_batches()
+            batch_cache.retrieve_batches(num_batches=10000)
         except Exception:
             logger.exception("Error retrieving batches")
         try:
-            run_preprocessing(index_path)
+            index_batches.run(batch_cache, index_path)
         except Exception:
-            logger.exception("Error preprocessing")
-        try:
-            run_update(index_path)
-        except Exception:
-            logger.exception("Error running index update")
+            logger.exception("Error indexing batches")
         sleep(10)

+ 106 - 170
mwmbl/crawler/app.py

@@ -1,10 +1,8 @@
 import gzip
 import hashlib
 import json
-from collections import defaultdict
-from datetime import datetime, timezone, timedelta, date
+from datetime import datetime, timezone, date
 from typing import Union
-from urllib.parse import urlparse
 from uuid import uuid4
 
 import boto3
@@ -12,9 +10,9 @@ import requests
 from fastapi import HTTPException, APIRouter
 
 from mwmbl.crawler.batch import Batch, NewBatchRequest, HashedBatch
-from mwmbl.crawler.urls import URLDatabase, FoundURL, URLStatus
+from mwmbl.crawler.urls import URLDatabase
 from mwmbl.database import Database
-from mwmbl.hn_top_domains_filtered import DOMAINS
+from mwmbl.indexer.batch_cache import BatchCache
 from mwmbl.indexer.indexdb import IndexDatabase, BatchInfo, BatchStatus
 from mwmbl.settings import (
     ENDPOINT_URL,
@@ -25,16 +23,9 @@ from mwmbl.settings import (
     USER_ID_LENGTH,
     VERSION,
     PUBLIC_URL_PREFIX,
-    UNKNOWN_DOMAIN_MULTIPLIER,
-    SCORE_FOR_SAME_DOMAIN,
-    SCORE_FOR_DIFFERENT_DOMAIN,
-    SCORE_FOR_ROOT_PATH,
     PUBLIC_USER_ID_LENGTH,
     FILE_NAME_SUFFIX,
     DATE_REGEX)
-from mwmbl.tinysearchengine.indexer import Document
-
-router = APIRouter(prefix="/crawler", tags=["crawler"])
 
 
 def get_bucket(name):
@@ -52,150 +43,111 @@ def upload(data: bytes, name: str):
 last_batch = None
 
 
-@router.on_event("startup")
-async def on_startup():
-    with Database() as db:
-        url_db = URLDatabase(db.connection)
-        return url_db.create_tables()
-
-
-@router.post('/batches/')
-def create_batch(batch: Batch):
-    if len(batch.items) > MAX_BATCH_SIZE:
-        raise HTTPException(400, f"Batch size too large (maximum {MAX_BATCH_SIZE}), got {len(batch.items)}")
-
-    if len(batch.user_id) != USER_ID_LENGTH:
-        raise HTTPException(400, f"User ID length is incorrect, should be {USER_ID_LENGTH} characters")
-
-    if len(batch.items) == 0:
-        return {
-            'status': 'ok',
-        }
-
-    user_id_hash = _get_user_id_hash(batch)
-
-    now = datetime.now(timezone.utc)
-    seconds = (now - datetime(now.year, now.month, now.day, tzinfo=timezone.utc)).seconds
+def get_router(batch_cache: BatchCache):
+    router = APIRouter(prefix="/crawler", tags=["crawler"])
 
-    # How to pad a string with zeros: https://stackoverflow.com/a/39402910
-    # Maximum seconds in a day is 60*60*24 = 86400, so 5 digits is enough
-    padded_seconds = str(seconds).zfill(5)
+    @router.on_event("startup")
+    async def on_startup():
+        with Database() as db:
+            url_db = URLDatabase(db.connection)
+            return url_db.create_tables()
 
-    # See discussion here: https://stackoverflow.com/a/13484764
-    uid = str(uuid4())[:8]
-    filename = f'1/{VERSION}/{now.date()}/1/{user_id_hash}/{padded_seconds}__{uid}.json.gz'
+    @router.post('/batches/')
+    def create_batch(batch: Batch):
+        if len(batch.items) > MAX_BATCH_SIZE:
+            raise HTTPException(400, f"Batch size too large (maximum {MAX_BATCH_SIZE}), got {len(batch.items)}")
 
-    # Using an approach from https://stackoverflow.com/a/30476450
-    epoch_time = (now - datetime(1970, 1, 1, tzinfo=timezone.utc)).total_seconds()
-    hashed_batch = HashedBatch(user_id_hash=user_id_hash, timestamp=epoch_time, items=batch.items)
-    data = gzip.compress(hashed_batch.json().encode('utf8'))
-    upload(data, filename)
+        if len(batch.user_id) != USER_ID_LENGTH:
+            raise HTTPException(400, f"User ID length is incorrect, should be {USER_ID_LENGTH} characters")
 
-    record_urls_in_database(batch, user_id_hash, now)
-    queue_batch(hashed_batch)
+        if len(batch.items) == 0:
+            return {
+                'status': 'ok',
+            }
 
-    global last_batch
-    last_batch = hashed_batch
+        user_id_hash = _get_user_id_hash(batch)
 
-    # Record the batch as being local so that we don't retrieve it again when the server restarts
-    batch_url = f'{PUBLIC_URL_PREFIX}{filename}'
-    infos = [BatchInfo(batch_url, user_id_hash, BatchStatus.LOCAL)]
+        now = datetime.now(timezone.utc)
+        seconds = (now - datetime(now.year, now.month, now.day, tzinfo=timezone.utc)).seconds
 
-    with Database() as db:
-        index_db = IndexDatabase(db.connection)
-        index_db.record_batches(infos)
+        # How to pad a string with zeros: https://stackoverflow.com/a/39402910
+        # Maximum seconds in a day is 60*60*24 = 86400, so 5 digits is enough
+        padded_seconds = str(seconds).zfill(5)
 
-    return {
-        'status': 'ok',
-        'public_user_id': user_id_hash,
-        'url': batch_url,
-    }
+        # See discussion here: https://stackoverflow.com/a/13484764
+        uid = str(uuid4())[:8]
+        filename = f'1/{VERSION}/{now.date()}/1/{user_id_hash}/{padded_seconds}__{uid}.json.gz'
 
+        # Using an approach from https://stackoverflow.com/a/30476450
+        epoch_time = (now - datetime(1970, 1, 1, tzinfo=timezone.utc)).total_seconds()
+        hashed_batch = HashedBatch(user_id_hash=user_id_hash, timestamp=epoch_time, items=batch.items)
+        data = gzip.compress(hashed_batch.json().encode('utf8'))
+        upload(data, filename)
 
-def _get_user_id_hash(batch: Union[Batch, NewBatchRequest]):
-    return hashlib.sha3_256(batch.user_id.encode('utf8')).hexdigest()
-
-
-@router.post('/batches/new')
-def request_new_batch(batch_request: NewBatchRequest):
-    user_id_hash = _get_user_id_hash(batch_request)
-
-    with Database() as db:
-        url_db = URLDatabase(db.connection)
-        return url_db.get_new_batch_for_user(user_id_hash)
+        global last_batch
+        last_batch = hashed_batch
 
+        batch_url = f'{PUBLIC_URL_PREFIX}{filename}'
+        batch_cache.store(hashed_batch, batch_url)
 
-@router.post('/batches/historical')
-def create_historical_batch(batch: HashedBatch):
-    """
-    Update the database state of URL crawling for old data
-    """
-    user_id_hash = batch.user_id_hash
-    batch_datetime = get_datetime_from_timestamp(batch.timestamp)
-    record_urls_in_database(batch, user_id_hash, batch_datetime)
+        # Record the batch as being local so that we don't retrieve it again when the server restarts
+        infos = [BatchInfo(batch_url, user_id_hash, BatchStatus.LOCAL)]
 
+        with Database() as db:
+            index_db = IndexDatabase(db.connection)
+            index_db.record_batches(infos)
 
-def get_datetime_from_timestamp(timestamp: int) -> datetime:
-    batch_datetime = datetime(1970, 1, 1, tzinfo=timezone.utc) + timedelta(seconds=timestamp)
-    return batch_datetime
-
-
-def record_urls_in_database(batch: Union[Batch, HashedBatch], user_id_hash: str, timestamp: datetime):
-    with Database() as db:
-        url_db = URLDatabase(db.connection)
-        url_scores = defaultdict(float)
-        for item in batch.items:
-            if item.content is not None:
-                crawled_page_domain = urlparse(item.url).netloc
-                score_multiplier = 1 if crawled_page_domain in DOMAINS else UNKNOWN_DOMAIN_MULTIPLIER
-                for link in item.content.links:
-                    parsed_link = urlparse(link)
-                    score = SCORE_FOR_SAME_DOMAIN if parsed_link.netloc == crawled_page_domain else SCORE_FOR_DIFFERENT_DOMAIN
-                    url_scores[link] += score * score_multiplier
-                    domain = f'{parsed_link.scheme}://{parsed_link.netloc}/'
-                    url_scores[domain] += SCORE_FOR_ROOT_PATH * score_multiplier
-
-        found_urls = [FoundURL(url, user_id_hash, score, URLStatus.NEW, timestamp) for url, score in url_scores.items()]
-        if len(found_urls) > 0:
-            url_db.update_found_urls(found_urls)
+        return {
+            'status': 'ok',
+            'public_user_id': user_id_hash,
+            'url': batch_url,
+        }
 
-        crawled_urls = [FoundURL(item.url, user_id_hash, 0.0, URLStatus.CRAWLED, timestamp)
-                        for item in batch.items]
-        url_db.update_found_urls(crawled_urls)
+    @router.post('/batches/new')
+    def request_new_batch(batch_request: NewBatchRequest):
+        user_id_hash = _get_user_id_hash(batch_request)
+
+        with Database() as db:
+            url_db = URLDatabase(db.connection)
+            return url_db.get_new_batch_for_user(user_id_hash)
+
+    @router.get('/batches/{date_str}/users/{public_user_id}')
+    def get_batches_for_date_and_user(date_str, public_user_id):
+        check_date_str(date_str)
+        check_public_user_id(public_user_id)
+        prefix = f'1/{VERSION}/{date_str}/1/{public_user_id}/'
+        return get_batch_ids_for_prefix(prefix)
+
+    @router.get('/batches/{date_str}/users/{public_user_id}/batch/{batch_id}')
+    def get_batch_from_id(date_str, public_user_id, batch_id):
+        url = get_batch_url(batch_id, date_str, public_user_id)
+        data = json.loads(gzip.decompress(requests.get(url).content))
+        return {
+            'url': url,
+            'batch': data,
+        }
 
+    @router.get('/latest-batch', response_model=list[HashedBatch])
+    def get_latest_batch():
+        return [] if last_batch is None else [last_batch]
 
-def get_batches_for_date(date_str):
-    check_date_str(date_str)
-    prefix = f'1/{VERSION}/{date_str}/1/'
-    cache_filename = prefix + 'batches.json.gz'
-    cache_url = PUBLIC_URL_PREFIX + cache_filename
-    try:
-        cached_batches = json.loads(gzip.decompress(requests.get(cache_url).content))
-        print(f"Got cached batches for {date_str}")
-        return cached_batches
-    except gzip.BadGzipFile:
-        pass
-
-    batches = get_batches_for_prefix(prefix)
-    result = {'batch_urls': [f'{PUBLIC_URL_PREFIX}{batch}' for batch in sorted(batches)]}
-    if date_str != str(date.today()):
-        # Don't cache data from today since it may change
-        data = gzip.compress(json.dumps(result).encode('utf8'))
-        upload(data, cache_filename)
-    print(f"Cached batches for {date_str} in {PUBLIC_URL_PREFIX}{cache_filename}")
-    return result
+    @router.get('/batches/{date_str}/users')
+    def get_user_id_hashes_for_date(date_str: str):
+        check_date_str(date_str)
+        prefix = f'1/{VERSION}/{date_str}/1/'
+        return get_subfolders(prefix)
 
+    @router.get('/')
+    def status():
+        return {
+            'status': 'ok'
+        }
 
-def get_user_id_hash_from_url(url):
-    return url.split('/')[9]
+    return router
 
 
-@router.get('/batches/{date_str}/users/{public_user_id}')
-def get_batches_for_date_and_user(date_str, public_user_id):
-    check_date_str(date_str)
-    check_public_user_id(public_user_id)
-    prefix = f'1/{VERSION}/{date_str}/1/{public_user_id}/'
-    return get_batch_ids_for_prefix(prefix)
+def _get_user_id_hash(batch: Union[Batch, NewBatchRequest]):
+    return hashlib.sha3_256(batch.user_id.encode('utf8')).hexdigest()
 
 
 def check_public_user_id(public_user_id):
@@ -203,16 +155,6 @@ def check_public_user_id(public_user_id):
         raise HTTPException(400, f"Incorrect public user ID length, should be {PUBLIC_USER_ID_LENGTH}")
 
 
-@router.get('/batches/{date_str}/users/{public_user_id}/batch/{batch_id}')
-def get_batch_from_id(date_str, public_user_id, batch_id):
-    url = get_batch_url(batch_id, date_str, public_user_id)
-    data = json.loads(gzip.decompress(requests.get(url).content))
-    return {
-        'url': url,
-        'batch': data,
-    }
-
-
 def get_batch_url(batch_id, date_str, public_user_id):
     check_date_str(date_str)
     check_public_user_id(public_user_id)
@@ -220,11 +162,6 @@ def get_batch_url(batch_id, date_str, public_user_id):
     return url
 
 
-@router.get('/latest-batch', response_model=list[HashedBatch])
-def get_latest_batch():
-    return [] if last_batch is None else [last_batch]
-
-
 def get_batch_id_from_file_name(file_name: str):
     assert file_name.endswith(FILE_NAME_SUFFIX)
     return file_name[:-len(FILE_NAME_SUFFIX)]
@@ -246,13 +183,6 @@ def get_batches_for_prefix(prefix):
     return filenames
 
 
-@router.get('/batches/{date_str}/users')
-def get_user_id_hashes_for_date(date_str: str):
-    check_date_str(date_str)
-    prefix = f'1/{VERSION}/{date_str}/1/'
-    return get_subfolders(prefix)
-
-
 def check_date_str(date_str):
     if not DATE_REGEX.match(date_str):
         raise HTTPException(400, f"Incorrect date format, should be YYYY-MM-DD")
@@ -268,17 +198,23 @@ def get_subfolders(prefix):
     return item_keys
 
 
-@router.get('/')
-def status():
-    return {
-        'status': 'ok'
-    }
-
+def get_batches_for_date(date_str):
+    check_date_str(date_str)
+    prefix = f'1/{VERSION}/{date_str}/1/'
+    cache_filename = prefix + 'batches.json.gz'
+    cache_url = PUBLIC_URL_PREFIX + cache_filename
+    try:
+        cached_batches = json.loads(gzip.decompress(requests.get(cache_url).content))
+        print(f"Got cached batches for {date_str}")
+        return cached_batches
+    except gzip.BadGzipFile:
+        pass
 
-def queue_batch(batch: HashedBatch):
-    # TODO: get the score from the URLs database
-    documents = [Document(item.content.title, item.url, item.content.extract, 1)
-                 for item in batch.items if item.content is not None]
-    with Database() as db:
-        index_db = IndexDatabase(db.connection)
-        index_db.queue_documents(documents)
+    batches = get_batches_for_prefix(prefix)
+    result = {'batch_urls': [f'{PUBLIC_URL_PREFIX}{batch}' for batch in sorted(batches)]}
+    if date_str != str(date.today()):
+        # Don't cache data from today since it may change
+        data = gzip.compress(json.dumps(result).encode('utf8'))
+        upload(data, cache_filename)
+    print(f"Cached batches for {date_str} in {PUBLIC_URL_PREFIX}{cache_filename}")
+    return result

+ 31 - 3
mwmbl/crawler/urls.py

@@ -5,6 +5,7 @@ import os
 from dataclasses import dataclass
 from datetime import datetime, timedelta
 from enum import Enum
+from typing import Iterable
 
 from psycopg2 import connect
 from psycopg2.extras import execute_values
@@ -21,9 +22,22 @@ class URLStatus(Enum):
     """
     URL state update is idempotent and can only progress forwards.
     """
-    NEW = 0         # One user has identified this URL
-    ASSIGNED = 2    # The crawler has given the URL to a user to crawl
-    CRAWLED = 3     # At least one user has crawled the URL
+    NEW = 0                   # One user has identified this URL
+    ASSIGNED = 10             # The crawler has given the URL to a user to crawl
+    ERROR_TIMEOUT = 20        # Timeout while retrieving
+    ERROR_404 = 30            # 404 response
+    ERROR_OTHER = 40          # Some other error
+    ERROR_ROBOTS_DENIED = 50  # Robots disallow this page
+    CRAWLED = 100             # At least one user has crawled the URL
+
+
+def batch(items: list, batch_size):
+    """
+    Adapted from https://stackoverflow.com/a/8290508
+    """
+    length = len(items)
+    for ndx in range(0, length, batch_size):
+        yield items[ndx:min(ndx + batch_size, length)]
 
 
 @dataclass
@@ -128,6 +142,20 @@ class URLDatabase:
 
         return [result[0] for result in results]
 
+    def get_url_scores(self, urls: list[str]) -> dict[str, float]:
+        sql = f"""
+        SELECT url, score FROM urls WHERE url IN %(urls)s
+        """
+
+        url_scores = {}
+        for url_batch in batch(urls, 10000):
+            with self.connection.cursor() as cursor:
+                cursor.execute(sql, {'urls': tuple(url_batch)})
+                results = cursor.fetchall()
+                url_scores.update({result[0]: result[1] for result in results})
+
+        return url_scores
+
 
 if __name__ == "__main__":
     with Database() as db:

+ 79 - 0
mwmbl/indexer/batch_cache.py

@@ -0,0 +1,79 @@
+"""
+Store for local batches.
+
+We store them in a directory on the local machine.
+"""
+import gzip
+import json
+import os
+from multiprocessing.pool import ThreadPool
+from pathlib import Path
+from tempfile import NamedTemporaryFile
+from urllib.parse import urlparse
+
+from pydantic import ValidationError
+
+from mwmbl.crawler.batch import HashedBatch
+from mwmbl.database import Database
+from mwmbl.indexer.indexdb import IndexDatabase, BatchStatus
+from mwmbl.retry import retry_requests
+
+
+class BatchCache:
+    num_threads = 20
+
+    def __init__(self, repo_path):
+        os.makedirs(repo_path, exist_ok=True)
+        self.path = repo_path
+
+    def get_cached(self, batch_urls: list[str]) -> dict[str, HashedBatch]:
+        batches = {}
+        for url in batch_urls:
+            path = self.get_path_from_url(url)
+            data = gzip.GzipFile(path).read()
+            batch = HashedBatch.parse_raw(data)
+            batches[url] = batch
+        return batches
+
+    def retrieve_batches(self, num_batches):
+        with Database() as db:
+            index_db = IndexDatabase(db.connection)
+            index_db.create_tables()
+
+        with Database() as db:
+            index_db = IndexDatabase(db.connection)
+            batches = index_db.get_batches_by_status(BatchStatus.REMOTE, num_batches)
+            print(f"Found {len(batches)} remote batches")
+            if len(batches) == 0:
+                return
+            urls = [batch.url for batch in batches]
+            pool = ThreadPool(self.num_threads)
+            results = pool.imap_unordered(self.retrieve_batch, urls)
+            total_processed = 0
+            for result in results:
+                total_processed += result
+            print("Processed batches with items:", total_processed)
+            index_db.update_batch_status(urls, BatchStatus.LOCAL)
+
+    def retrieve_batch(self, url):
+        data = json.loads(gzip.decompress(retry_requests.get(url).content))
+        try:
+            batch = HashedBatch.parse_obj(data)
+        except ValidationError:
+            print("Failed to validate batch", data)
+            return 0
+        if len(batch.items) > 0:
+            self.store(batch, url)
+        return len(batch.items)
+
+    def store(self, batch, url):
+        path = self.get_path_from_url(url)
+        print(f"Storing local batch at {path}")
+        os.makedirs(path.parent, exist_ok=True)
+        with open(path, 'wb') as output_file:
+            data = gzip.compress(batch.json().encode('utf8'))
+            output_file.write(data)
+
+    def get_path_from_url(self, url) -> Path:
+        url_path = urlparse(url).path
+        return Path(self.path) / url_path.lstrip('/')

+ 6 - 2
mwmbl/indexer/historical.py

@@ -1,10 +1,10 @@
 from datetime import date, timedelta
 
-from mwmbl.crawler.app import get_batches_for_date, get_user_id_hash_from_url
+from mwmbl.crawler.app import get_batches_for_date
 from mwmbl.database import Database
 from mwmbl.indexer.indexdb import BatchInfo, BatchStatus, IndexDatabase
 
-DAYS = 10
+DAYS = 20
 
 
 def run():
@@ -20,5 +20,9 @@ def run():
             index_db.record_batches(infos)
 
 
+def get_user_id_hash_from_url(url):
+    return url.split('/')[9]
+
+
 if __name__ == '__main__':
     run()

+ 150 - 0
mwmbl/indexer/index_batches.py

@@ -0,0 +1,150 @@
+"""
+Index batches that are stored locally.
+"""
+from collections import defaultdict
+from datetime import datetime, timezone, timedelta
+from logging import getLogger
+from typing import Iterable
+from urllib.parse import urlparse
+
+import spacy
+
+from mwmbl.crawler.batch import HashedBatch, Item
+from mwmbl.crawler.urls import URLDatabase, URLStatus, FoundURL
+from mwmbl.database import Database
+from mwmbl.hn_top_domains_filtered import DOMAINS
+from mwmbl.indexer.batch_cache import BatchCache
+from mwmbl.indexer.index import tokenize_document
+from mwmbl.indexer.indexdb import BatchStatus, IndexDatabase
+from mwmbl.settings import UNKNOWN_DOMAIN_MULTIPLIER, SCORE_FOR_SAME_DOMAIN, SCORE_FOR_DIFFERENT_DOMAIN, \
+    SCORE_FOR_ROOT_PATH
+from mwmbl.tinysearchengine.indexer import Document, TinyIndex
+
+logger = getLogger(__name__)
+
+
+def get_documents_from_batches(batches: Iterable[HashedBatch]) -> Iterable[tuple[str, str, str]]:
+    for batch in batches:
+        for item in batch.items:
+            if item.content is not None:
+                yield item.content.title, item.url, item.content.extract
+
+
+def run(batch_cache: BatchCache, index_path: str):
+    nlp = spacy.load("en_core_web_sm")
+    with Database() as db:
+        index_db = IndexDatabase(db.connection)
+
+        logger.info("Getting local batches")
+        batches = index_db.get_batches_by_status(BatchStatus.LOCAL, 10000)
+        logger.info(f"Got {len(batches)} batch urls")
+        if len(batches) == 0:
+            return
+
+        batch_data = batch_cache.get_cached([batch.url for batch in batches])
+        logger.info(f"Got {len(batch_data)} cached batches")
+
+        record_urls_in_database(batch_data.values())
+
+        document_tuples = list(get_documents_from_batches(batch_data.values()))
+        urls = [url for title, url, extract in document_tuples]
+
+        logger.info(f"Got {len(urls)} document tuples")
+
+        url_db = URLDatabase(db.connection)
+        url_scores = url_db.get_url_scores(urls)
+
+        logger.info(f"Got {len(url_scores)} scores")
+        documents = [Document(title, url, extract, url_scores.get(url, 1.0)) for title, url, extract in document_tuples]
+
+        page_documents = preprocess_documents(documents, index_path, nlp)
+        index_pages(index_path, page_documents)
+        logger.info("Indexed pages")
+        index_db.update_batch_status([batch.url for batch in batches], BatchStatus.INDEXED)
+
+
+def index_pages(index_path, page_documents):
+    with TinyIndex(Document, index_path, 'w') as indexer:
+        for page, documents in page_documents.items():
+            new_documents = []
+            existing_documents = indexer.get_page(page)
+            seen_urls = set()
+            seen_titles = set()
+            sorted_documents = sorted(documents + existing_documents, key=lambda x: x.score)
+            for document in sorted_documents:
+                if document.title in seen_titles or document.url in seen_urls:
+                    continue
+                new_documents.append(document)
+                seen_urls.add(document.url)
+                seen_titles.add(document.title)
+            indexer.store_in_page(page, new_documents)
+            logger.debug(f"Wrote page {page} with {len(new_documents)} documents")
+
+
+def preprocess_documents(documents, index_path, nlp):
+    page_documents = defaultdict(list)
+    with TinyIndex(Document, index_path, 'w') as indexer:
+        for document in documents:
+            tokenized = tokenize_document(document.url, document.title, document.extract, document.score, nlp)
+            # logger.debug(f"Tokenized: {tokenized}")
+            page_indexes = [indexer.get_key_page_index(token) for token in tokenized.tokens]
+            for page in page_indexes:
+                page_documents[page].append(document)
+    print(f"Preprocessed for {len(page_documents)} pages")
+    return page_documents
+
+
+def get_url_error_status(item: Item):
+    if item.status == 404:
+        return URLStatus.ERROR_404
+    if item.error is not None:
+        if item.error.name == 'AbortError':
+            return URLStatus.ERROR_TIMEOUT
+        elif item.error.name == 'RobotsDenied':
+            return URLStatus.ERROR_ROBOTS_DENIED
+    return URLStatus.ERROR_OTHER
+
+
+def record_urls_in_database(batches: Iterable[HashedBatch]):
+    with Database() as db:
+        url_db = URLDatabase(db.connection)
+        url_scores = defaultdict(float)
+        url_users = {}
+        url_timestamps = {}
+        url_statuses = defaultdict(lambda: URLStatus.NEW)
+        for batch in batches:
+            for item in batch.items:
+                timestamp = get_datetime_from_timestamp(item.timestamp / 1000.0)
+                url_timestamps[item.url] = timestamp
+                url_users[item.url] = batch.user_id_hash
+                if item.content is None:
+                    url_statuses[item.url] = get_url_error_status(item)
+                else:
+                    url_statuses[item.url] = URLStatus.CRAWLED
+                    crawled_page_domain = urlparse(item.url).netloc
+                    score_multiplier = 1 if crawled_page_domain in DOMAINS else UNKNOWN_DOMAIN_MULTIPLIER
+                    for link in item.content.links:
+                        parsed_link = urlparse(link)
+                        score = SCORE_FOR_SAME_DOMAIN if parsed_link.netloc == crawled_page_domain else SCORE_FOR_DIFFERENT_DOMAIN
+                        url_scores[link] += score * score_multiplier
+                        url_users[link] = batch.user_id_hash
+                        url_timestamps[link] = timestamp
+                        domain = f'{parsed_link.scheme}://{parsed_link.netloc}/'
+                        url_scores[domain] += SCORE_FOR_ROOT_PATH * score_multiplier
+                        url_users[domain] = batch.user_id_hash
+                        url_timestamps[domain] = timestamp
+
+        found_urls = [FoundURL(url, url_users[url], url_scores[url], url_statuses[url], url_timestamps[url])
+                      for url in url_scores.keys() | url_statuses.keys()]
+
+        url_db.update_found_urls(found_urls)
+
+
+def get_datetime_from_timestamp(timestamp: float) -> datetime:
+    batch_datetime = datetime(1970, 1, 1, tzinfo=timezone.utc) + timedelta(seconds=timestamp)
+    return batch_datetime
+
+
+# TODO: clean unicode at some point
+def clean_unicode(s: str) -> str:
+    return s.encode('utf-8', 'ignore').decode('utf-8')

+ 4 - 120
mwmbl/indexer/indexdb.py

@@ -6,17 +6,11 @@ from enum import Enum
 
 from psycopg2.extras import execute_values
 
-from mwmbl.tinysearchengine.indexer import Document
-
 
 class BatchStatus(Enum):
     REMOTE = 0    # The batch only exists in long term storage
     LOCAL = 1     # We have a copy of the batch locally in Postgresql
-
-
-class DocumentStatus(Enum):
-    NEW = 0
-    PREPROCESSING = 1
+    INDEXED = 2
 
 
 @dataclass
@@ -39,32 +33,8 @@ class IndexDatabase:
         )
         """
 
-        documents_sql = """
-        CREATE TABLE IF NOT EXISTS documents (
-            url VARCHAR PRIMARY KEY,
-            title VARCHAR NOT NULL,
-            extract VARCHAR NOT NULL,
-            score FLOAT NOT NULL,
-            status INT NOT NULL
-        )
-        """
-
-        document_pages_sql = """
-        CREATE TABLE IF NOT EXISTS document_pages (
-            url VARCHAR NOT NULL,
-            page INT NOT NULL
-        ) 
-        """
-
-        document_pages_index_sql = """
-        CREATE INDEX IF NOT EXISTS document_pages_page_index ON document_pages (page)
-        """
-
         with self.connection.cursor() as cursor:
             cursor.execute(batches_sql)
-            cursor.execute(documents_sql)
-            cursor.execute(document_pages_sql)
-            cursor.execute(document_pages_index_sql)
 
     def record_batches(self, batch_infos: list[BatchInfo]):
         sql = """
@@ -77,13 +47,13 @@ class IndexDatabase:
         with self.connection.cursor() as cursor:
             execute_values(cursor, sql, data)
 
-    def get_batches_by_status(self, status: BatchStatus) -> list[BatchInfo]:
+    def get_batches_by_status(self, status: BatchStatus, num_batches=1000) -> list[BatchInfo]:
         sql = """
-        SELECT * FROM batches WHERE status = %(status)s LIMIT 1000
+        SELECT * FROM batches WHERE status = %(status)s LIMIT %(num_batches)s
         """
 
         with self.connection.cursor() as cursor:
-            cursor.execute(sql, {'status': status.value})
+            cursor.execute(sql, {'status': status.value, 'num_batches': num_batches})
             results = cursor.fetchall()
             return [BatchInfo(url, user_id_hash, status) for url, user_id_hash, status in results]
 
@@ -95,89 +65,3 @@ class IndexDatabase:
 
         with self.connection.cursor() as cursor:
             cursor.execute(sql, {'status': status.value, 'urls': tuple(batch_urls)})
-
-    def queue_documents(self, documents: list[Document]):
-        sql = """
-        INSERT INTO documents (url, title, extract, score, status)
-        VALUES %s
-        ON CONFLICT (url) DO NOTHING
-        """
-
-        sorted_documents = sorted(documents, key=lambda x: x.url)
-        data = [(document.url, clean_unicode(document.title), clean_unicode(document.extract),
-                 document.score, DocumentStatus.NEW.value)
-                for document in sorted_documents]
-
-        print("Queueing documents", len(data))
-        with self.connection.cursor() as cursor:
-            execute_values(cursor, sql, data)
-
-    def get_documents_for_preprocessing(self):
-        sql = f"""
-        UPDATE documents SET status = {DocumentStatus.PREPROCESSING.value}
-        WHERE url IN (
-            SELECT url FROM documents
-            WHERE status = {DocumentStatus.NEW.value}
-            LIMIT 1000
-            FOR UPDATE SKIP LOCKED
-        )
-        RETURNING url, title, extract, score
-        """
-
-        with self.connection.cursor() as cursor:
-            cursor.execute(sql)
-            results = cursor.fetchall()
-            return [Document(title, url, extract, score) for url, title, extract, score in results]
-
-    def clear_documents_for_preprocessing(self) -> int:
-        sql = f"""
-        DELETE FROM documents WHERE status = {DocumentStatus.PREPROCESSING.value}
-        """
-
-        with self.connection.cursor() as cursor:
-            cursor.execute(sql)
-            return cursor.rowcount
-
-    def queue_documents_for_page(self, urls_and_page_indexes: list[tuple[str, int]]):
-        sql = """
-        INSERT INTO document_pages (url, page) values %s
-        """
-
-        print(f"Queuing {len(urls_and_page_indexes)} urls and page indexes")
-        with self.connection.cursor() as cursor:
-            execute_values(cursor, sql, urls_and_page_indexes)
-
-    def get_queued_documents_for_page(self, page_index: int) -> list[Document]:
-        sql = """
-        SELECT d.url, title, extract, score
-        FROM document_pages p INNER JOIN documents d ON p.url = d.url
-        WHERE p.page = %(page_index)s
-        """
-
-        with self.connection.cursor() as cursor:
-            cursor.execute(sql, {'page_index': page_index})
-            results = cursor.fetchall()
-            return [Document(title, url, extract, score) for url, title, extract, score in results]
-
-    def get_queued_pages(self) -> list[int]:
-        sql = """
-        SELECT DISTINCT page FROM document_pages ORDER BY page
-        """
-
-        with self.connection.cursor() as cursor:
-            cursor.execute(sql)
-            results = cursor.fetchall()
-            return [x[0] for x in results]
-
-    def clear_queued_documents_for_page(self, page_index: int):
-        sql = """
-        DELETE FROM document_pages
-        WHERE page = %(page_index)s
-        """
-
-        with self.connection.cursor() as cursor:
-            cursor.execute(sql, {'page_index': page_index})
-
-
-def clean_unicode(s: str) -> str:
-    return s.encode('utf-8', 'ignore').decode('utf-8')

+ 3 - 0
mwmbl/indexer/paths.py

@@ -26,3 +26,6 @@ TOP_DOMAINS_JSON_PATH = TINYSEARCH_DATA_DIR / 'hn-top-domains.json'
 MWMBL_DATA_DIR = DATA_DIR / "mwmbl"
 CRAWL_GLOB = str(MWMBL_DATA_DIR / "b2") + "/*/*/*/*/*/*.json.gz"
 LINK_COUNT_PATH = MWMBL_DATA_DIR / 'crawl-counts.json'
+
+INDEX_NAME = 'index.tinysearch'
+BATCH_DIR_NAME = 'batches'

+ 0 - 48
mwmbl/indexer/preprocess.py

@@ -1,48 +0,0 @@
-"""
-Preprocess local documents for indexing.
-"""
-import traceback
-from logging import getLogger
-from time import sleep
-
-import spacy
-
-from mwmbl.database import Database
-from mwmbl.indexer.indexdb import IndexDatabase
-from mwmbl.indexer.index import tokenize_document
-from mwmbl.tinysearchengine.indexer import TinyIndex, Document
-
-
-logger = getLogger(__name__)
-
-
-def run(index_path):
-    while True:
-        try:
-            run_preprocessing(index_path)
-        except Exception as e:
-            print("Exception preprocessing")
-            traceback.print_exception(type(e), e, e.__traceback__)
-            sleep(10)
-
-
-def run_preprocessing(index_path):
-    nlp = spacy.load("en_core_web_sm")
-    with Database() as db:
-        index_db = IndexDatabase(db.connection)
-        for i in range(100):
-            documents = index_db.get_documents_for_preprocessing()
-            print(f"Got {len(documents)} documents for preprocessing")
-            if len(documents) == 0:
-                break
-            with TinyIndex(Document, index_path, 'w') as indexer:
-                for document in documents:
-                    tokenized = tokenize_document(document.url, document.title, document.extract, 1, nlp)
-                    logger.debug(f"Tokenized: {tokenized}")
-                    page_indexes = [indexer.get_key_page_index(token) for token in tokenized.tokens]
-                    logger.debug(f"Page indexes: {page_indexes}")
-                    index_db.queue_documents_for_page([(tokenized.url, i) for i in page_indexes])
-
-
-if __name__ == '__main__':
-    run('data/index.tinysearch')

+ 0 - 68
mwmbl/indexer/retrieve.py

@@ -1,68 +0,0 @@
-"""
-Retrieve remote batches and store them in Postgres locally
-"""
-import gzip
-import json
-import traceback
-from multiprocessing.pool import ThreadPool
-from time import sleep
-
-from pydantic import ValidationError
-
-from mwmbl.crawler.app import create_historical_batch, queue_batch
-from mwmbl.crawler.batch import HashedBatch
-from mwmbl.database import Database
-from mwmbl.indexer.indexdb import IndexDatabase, BatchStatus
-from mwmbl.retry import retry_requests
-
-NUM_THREADS = 5
-
-
-def retrieve_batches():
-    with Database() as db:
-        index_db = IndexDatabase(db.connection)
-        index_db.create_tables()
-
-    with Database() as db:
-        index_db = IndexDatabase(db.connection)
-
-        for i in range(100):
-            batches = index_db.get_batches_by_status(BatchStatus.REMOTE)
-            print(f"Found {len(batches)} remote batches")
-            if len(batches) == 0:
-                return
-            urls = [batch.url for batch in batches]
-            pool = ThreadPool(NUM_THREADS)
-            results = pool.imap_unordered(retrieve_batch, urls)
-            for result in results:
-                if result > 0:
-                    print("Processed batch with items:", result)
-            index_db.update_batch_status(urls, BatchStatus.LOCAL)
-
-
-def retrieve_batch(url):
-    data = json.loads(gzip.decompress(retry_requests.get(url).content))
-    try:
-        batch = HashedBatch.parse_obj(data)
-    except ValidationError:
-        print("Failed to validate batch", data)
-        return 0
-    if len(batch.items) > 0:
-        print(f"Retrieved batch with {len(batch.items)} items")
-        create_historical_batch(batch)
-        queue_batch(batch)
-    return len(batch.items)
-
-
-def run():
-    while True:
-        try:
-            retrieve_batches()
-        except Exception as e:
-            print("Exception retrieving batch")
-            traceback.print_exception(type(e), e, e.__traceback__)
-        sleep(10)
-
-
-if __name__ == '__main__':
-    retrieve_batches()

+ 0 - 54
mwmbl/indexer/update_pages.py

@@ -1,54 +0,0 @@
-"""
-Iterate over each page in the index and update it based on what is in the index database.
-"""
-import traceback
-from time import sleep
-
-from mwmbl.database import Database
-from mwmbl.indexer.indexdb import IndexDatabase
-from mwmbl.tinysearchengine.indexer import TinyIndex, Document, PageError
-
-
-def run_update(index_path):
-    with Database() as db:
-        index_db = IndexDatabase(db.connection)
-        index_db.create_tables()
-
-    with TinyIndex(Document, index_path, 'w') as indexer:
-        with Database() as db:
-            index_db = IndexDatabase(db.connection)
-            pages_to_process = index_db.get_queued_pages()
-            print(f"Got {len(pages_to_process)} pages to process")
-            for i in pages_to_process:
-                documents = index_db.get_queued_documents_for_page(i)
-                print(f"Documents queued for page {i}: {len(documents)}")
-                if len(documents) > 0:
-                    for j in range(20):
-                        try:
-                            indexer.add_to_page(i, documents)
-                            break
-                        except PageError:
-                            documents = documents[:len(documents)//2]
-                            if len(documents) == 0:
-                                print("No more space")
-                                break
-                            print(f"Not enough space, adding {len(documents)}")
-                index_db.clear_queued_documents_for_page(i)
-            # All preprocessed documents should now have been indexed
-            # Clear documents that have now been preprocessed and indexed
-            num_cleared = index_db.clear_documents_for_preprocessing()
-            print(f"Indexed {num_cleared} documents")
-
-
-def run(index_path):
-    while True:
-        try:
-            run_update(index_path)
-        except Exception as e:
-            print("Exception updating pages in index")
-            traceback.print_exception(type(e), e, e.__traceback__)
-            sleep(10)
-
-
-if __name__ == '__main__':
-    run_update('data/index.tinysearch')

+ 14 - 12
mwmbl/main.py

@@ -3,13 +3,15 @@ import logging
 import os
 import sys
 from multiprocessing import Process
+from pathlib import Path
 
 import uvicorn
 from fastapi import FastAPI
 
 from mwmbl import background
-from mwmbl.indexer import historical, retrieve, preprocess, update_pages
-from mwmbl.crawler.app import router as crawler_router
+from mwmbl.crawler import app as crawler
+from mwmbl.indexer.batch_cache import BatchCache
+from mwmbl.indexer.paths import INDEX_NAME, BATCH_DIR_NAME
 from mwmbl.tinysearchengine import search
 from mwmbl.tinysearchengine.completer import Completer
 from mwmbl.tinysearchengine.indexer import TinyIndex, Document, NUM_PAGES, PAGE_SIZE
@@ -20,7 +22,7 @@ logging.basicConfig(stream=sys.stdout, level=logging.INFO)
 
 def setup_args():
     parser = argparse.ArgumentParser(description="mwmbl-tinysearchengine")
-    parser.add_argument("--index", help="Path to the tinysearchengine index file", default="/app/storage/index.tinysearch")
+    parser.add_argument("--data", help="Path to the tinysearchengine index file", default="/app/storage/")
     args = parser.parse_args()
     return args
 
@@ -28,29 +30,26 @@ def setup_args():
 def run():
     args = setup_args()
 
+    index_path = Path(args.data) / INDEX_NAME
     try:
-        existing_index = TinyIndex(item_factory=Document, index_path=args.index)
+        existing_index = TinyIndex(item_factory=Document, index_path=index_path)
         if existing_index.page_size != PAGE_SIZE or existing_index.num_pages != NUM_PAGES:
             print(f"Existing index page sizes ({existing_index.page_size}) and number of pages "
                   f"({existing_index.num_pages}) does not match - removing.")
-            os.remove(args.index)
+            os.remove(index_path)
             existing_index = None
     except FileNotFoundError:
         existing_index = None
 
     if existing_index is None:
         print("Creating a new index")
-        TinyIndex.create(item_factory=Document, index_path=args.index, num_pages=NUM_PAGES, page_size=PAGE_SIZE)
+        TinyIndex.create(item_factory=Document, index_path=index_path, num_pages=NUM_PAGES, page_size=PAGE_SIZE)
 
-    Process(target=background.run, args=(args.index,)).start()
-    # Process(target=historical.run).start()
-    # Process(target=retrieve.run).start()
-    # Process(target=preprocess.run, args=(args.index,)).start()
-    # Process(target=update_pages.run, args=(args.index,)).start()
+    Process(target=background.run, args=(args.data,)).start()
 
     completer = Completer()
 
-    with TinyIndex(item_factory=Document, index_path=args.index) as tiny_index:
+    with TinyIndex(item_factory=Document, index_path=index_path) as tiny_index:
         ranker = HeuristicRanker(tiny_index, completer)
 
         # Initialize FastApi instance
@@ -58,6 +57,9 @@ def run():
 
         search_router = search.create_router(ranker)
         app.include_router(search_router)
+
+        batch_cache = BatchCache(Path(args.data) / BATCH_DIR_NAME)
+        crawler_router = crawler.get_router(batch_cache)
         app.include_router(crawler_router)
 
         # Initialize uvicorn server using global app instance and server config params

+ 24 - 7
mwmbl/tinysearchengine/indexer.py

@@ -1,7 +1,7 @@
 import json
 import os
 from dataclasses import astuple, dataclass, asdict
-from io import UnsupportedOperation
+from io import UnsupportedOperation, BytesIO
 from logging import getLogger
 from mmap import mmap, PROT_READ, PROT_WRITE
 from typing import TypeVar, Generic, Callable, List
@@ -13,7 +13,7 @@ VERSION = 1
 METADATA_CONSTANT = b'mwmbl-tiny-search'
 METADATA_SIZE = 4096
 
-NUM_PAGES = 5_120_000
+NUM_PAGES = 10_240_000
 PAGE_SIZE = 4096
 
 
@@ -63,9 +63,21 @@ class TinyIndexMetadata:
         return TinyIndexMetadata(**values)
 
 
-def _get_page_data(compressor, page_size, data):
-    serialised_data = json.dumps(data)
-    compressed_data = compressor.compress(serialised_data.encode('utf8'))
+def _get_page_data(compressor: ZstdCompressor, page_size: int, items: list[T]):
+    bytes_io = BytesIO()
+    stream_writer = compressor.stream_writer(bytes_io, write_size=128)
+
+    num_fitting = 0
+    for i, item in enumerate(items):
+        serialised_data = json.dumps(item) + '\n'
+        stream_writer.write(serialised_data.encode('utf8'))
+        stream_writer.flush()
+        if len(bytes_io.getvalue()) > page_size:
+            break
+        num_fitting = i + 1
+
+    compressed_data = compressor.compress(json.dumps(items[:num_fitting]).encode('utf8'))
+    assert len(compressed_data) < page_size, "The data shouldn't get bigger"
     return _pad_to_page_size(compressed_data, page_size)
 
 
@@ -157,15 +169,20 @@ class TinyIndex(Generic[T]):
         current_page += value_tuples
         self._write_page(current_page, page_index)
 
-    def _write_page(self, data, i):
+    def store_in_page(self, page_index: int, values: list[T]):
+        value_tuples = [astuple(value) for value in values]
+        self._write_page(value_tuples, page_index)
+
+    def _write_page(self, data, i: int):
         """
         Serialise the data using JSON, compress it and store it at index i.
-        If the data is too big, it will raise a ValueError and not store anything
+        If the data is too big, it will store the first items in the list and discard the rest.
         """
         if self.mode != 'w':
             raise UnsupportedOperation("The file is open in read mode, you cannot write")
 
         page_data = _get_page_data(self.compressor, self.page_size, data)
+        logger.debug(f"Got page data of length {len(page_data)}")
         self.mmap[i * self.page_size:(i+1) * self.page_size] = page_data
 
     @staticmethod

+ 1 - 10
test/test_indexdb.py

@@ -1,13 +1,4 @@
-from mwmbl.database import Database
-from mwmbl.indexer.indexdb import IndexDatabase, clean_unicode
-from mwmbl.tinysearchengine.indexer import Document
-
-
-def test_bad_unicode_encoding():
-    bad_doc = Document('Good title', 'https://goodurl.com', 'Bad extract text \ud83c', 1.0)
-    with Database() as db:
-        index_db = IndexDatabase(db.connection)
-        index_db.queue_documents([bad_doc])
+from mwmbl.indexer.index_batches import clean_unicode
 
 
 def test_clean_unicode():