From 0d1e7d841c57c24dc6f4469b13ed4c8c81940198 Mon Sep 17 00:00:00 2001
From: Daoud Clarke <daoud.clarke@gmail.com>
Date: Tue, 19 Jul 2022 21:18:43 +0100
Subject: [PATCH] Implement a batch cache to store files locally before
 preprocessing

---
 mwmbl/background.py          | 10 ++++--
 mwmbl/indexer/batch_cache.py | 69 ++++++++++++++++++++++++++++++++++++
 mwmbl/indexer/paths.py       |  3 ++
 mwmbl/indexer/retrieve.py    | 68 -----------------------------------
 mwmbl/main.py                | 15 ++++----
 5 files changed, 86 insertions(+), 79 deletions(-)
 create mode 100644 mwmbl/indexer/batch_cache.py
 delete mode 100644 mwmbl/indexer/retrieve.py

diff --git a/mwmbl/background.py b/mwmbl/background.py
index 89f4829..3e6df0d 100644
--- a/mwmbl/background.py
+++ b/mwmbl/background.py
@@ -2,21 +2,25 @@
 Script that updates data in a background process.
 """
 from logging import getLogger
+from pathlib import Path
 from time import sleep
 
 from mwmbl.indexer import historical
+from mwmbl.indexer.batch_repo import BatchCache
+from mwmbl.indexer.paths import INDEX_PATH, BATCH_DIR_NAME
 from mwmbl.indexer.preprocess import run_preprocessing
-from mwmbl.indexer.retrieve import retrieve_batches
 from mwmbl.indexer.update_pages import run_update
 
 logger = getLogger(__name__)
 
 
-def run(index_path: str):
+def run(data_path: str):
     historical.run()
+    index_path = Path(data_path) / INDEX_PATH
+    batch_cache = BatchCache(Path(data_path) / BATCH_DIR_NAME)
     while True:
         try:
-            retrieve_batches()
+            batch_cache.retrieve_batches()
         except Exception:
             logger.exception("Error retrieving batches")
         try:
diff --git a/mwmbl/indexer/batch_cache.py b/mwmbl/indexer/batch_cache.py
new file mode 100644
index 0000000..c012212
--- /dev/null
+++ b/mwmbl/indexer/batch_cache.py
@@ -0,0 +1,69 @@
+"""
+Store for local batches.
+
+We store them in a directory on the local machine.
+"""
+import gzip
+import json
+import os
+from multiprocessing.pool import ThreadPool
+from tempfile import NamedTemporaryFile
+
+from pydantic import ValidationError
+
+from mwmbl.crawler.batch import HashedBatch
+from mwmbl.database import Database
+from mwmbl.indexer.indexdb import IndexDatabase, BatchStatus
+from mwmbl.retry import retry_requests
+
+
+class BatchCache:
+    num_threads = 8
+
+    def __init__(self, repo_path):
+        os.makedirs(repo_path, exist_ok=True)
+        self.path = repo_path
+
+    def store(self, batch: HashedBatch):
+        with NamedTemporaryFile(mode='w', dir=self.path, prefix='batch_', suffix='.json', delete=False) as output_file:
+            output_file.write(batch.json())
+
+    def get(self, num_batches) -> dict[str, HashedBatch]:
+        batches = {}
+        for path in os.listdir(self.path):
+            batch = HashedBatch.parse_file(path)
+            while len(batches) < num_batches:
+                batches[path] = batch
+        return batches
+
+    def retrieve_batches(self, num_thousand_batches=10):
+        with Database() as db:
+            index_db = IndexDatabase(db.connection)
+            index_db.create_tables()
+
+        with Database() as db:
+            index_db = IndexDatabase(db.connection)
+
+            for i in range(num_thousand_batches):
+                batches = index_db.get_batches_by_status(BatchStatus.REMOTE)
+                print(f"Found {len(batches)} remote batches")
+                if len(batches) == 0:
+                    return
+                urls = [batch.url for batch in batches]
+                pool = ThreadPool(self.num_threads)
+                results = pool.imap_unordered(self.retrieve_batch, urls)
+                for result in results:
+                    if result > 0:
+                        print("Processed batch with items:", result)
+                index_db.update_batch_status(urls, BatchStatus.LOCAL)
+
+    def retrieve_batch(self, url):
+        data = json.loads(gzip.decompress(retry_requests.get(url).content))
+        try:
+            batch = HashedBatch.parse_obj(data)
+        except ValidationError:
+            print("Failed to validate batch", data)
+            return 0
+        if len(batch.items) > 0:
+            self.store(batch)
+        return len(batch.items)
diff --git a/mwmbl/indexer/paths.py b/mwmbl/indexer/paths.py
index 5b02c41..b23998e 100644
--- a/mwmbl/indexer/paths.py
+++ b/mwmbl/indexer/paths.py
@@ -26,3 +26,6 @@ TOP_DOMAINS_JSON_PATH = TINYSEARCH_DATA_DIR / 'hn-top-domains.json'
 MWMBL_DATA_DIR = DATA_DIR / "mwmbl"
 CRAWL_GLOB = str(MWMBL_DATA_DIR / "b2") + "/*/*/*/*/*/*.json.gz"
 LINK_COUNT_PATH = MWMBL_DATA_DIR / 'crawl-counts.json'
+
+INDEX_NAME = 'index.tinysearch'
+BATCH_DIR_NAME = 'batches'
\ No newline at end of file
diff --git a/mwmbl/indexer/retrieve.py b/mwmbl/indexer/retrieve.py
deleted file mode 100644
index 6dda806..0000000
--- a/mwmbl/indexer/retrieve.py
+++ /dev/null
@@ -1,68 +0,0 @@
-"""
-Retrieve remote batches and store them in Postgres locally
-"""
-import gzip
-import json
-import traceback
-from multiprocessing.pool import ThreadPool
-from time import sleep
-
-from pydantic import ValidationError
-
-from mwmbl.crawler.app import create_historical_batch, queue_batch
-from mwmbl.crawler.batch import HashedBatch
-from mwmbl.database import Database
-from mwmbl.indexer.indexdb import IndexDatabase, BatchStatus
-from mwmbl.retry import retry_requests
-
-NUM_THREADS = 5
-
-
-def retrieve_batches():
-    with Database() as db:
-        index_db = IndexDatabase(db.connection)
-        index_db.create_tables()
-
-    with Database() as db:
-        index_db = IndexDatabase(db.connection)
-
-        for i in range(100):
-            batches = index_db.get_batches_by_status(BatchStatus.REMOTE)
-            print(f"Found {len(batches)} remote batches")
-            if len(batches) == 0:
-                return
-            urls = [batch.url for batch in batches]
-            pool = ThreadPool(NUM_THREADS)
-            results = pool.imap_unordered(retrieve_batch, urls)
-            for result in results:
-                if result > 0:
-                    print("Processed batch with items:", result)
-            index_db.update_batch_status(urls, BatchStatus.LOCAL)
-
-
-def retrieve_batch(url):
-    data = json.loads(gzip.decompress(retry_requests.get(url).content))
-    try:
-        batch = HashedBatch.parse_obj(data)
-    except ValidationError:
-        print("Failed to validate batch", data)
-        return 0
-    if len(batch.items) > 0:
-        print(f"Retrieved batch with {len(batch.items)} items")
-        create_historical_batch(batch)
-        queue_batch(batch)
-    return len(batch.items)
-
-
-def run():
-    while True:
-        try:
-            retrieve_batches()
-        except Exception as e:
-            print("Exception retrieving batch")
-            traceback.print_exception(type(e), e, e.__traceback__)
-        sleep(10)
-
-
-if __name__ == '__main__':
-    retrieve_batches()
diff --git a/mwmbl/main.py b/mwmbl/main.py
index baa8e9c..ac2fade 100644
--- a/mwmbl/main.py
+++ b/mwmbl/main.py
@@ -3,6 +3,7 @@ import logging
 import os
 import sys
 from multiprocessing import Process
+from pathlib import Path
 
 import uvicorn
 from fastapi import FastAPI
@@ -10,6 +11,7 @@ from fastapi import FastAPI
 from mwmbl import background
 from mwmbl.indexer import historical, retrieve, preprocess, update_pages
 from mwmbl.crawler.app import router as crawler_router
+from mwmbl.indexer.paths import INDEX_NAME
 from mwmbl.tinysearchengine import search
 from mwmbl.tinysearchengine.completer import Completer
 from mwmbl.tinysearchengine.indexer import TinyIndex, Document, NUM_PAGES, PAGE_SIZE
@@ -20,7 +22,7 @@ logging.basicConfig(stream=sys.stdout, level=logging.INFO)
 
 def setup_args():
     parser = argparse.ArgumentParser(description="mwmbl-tinysearchengine")
-    parser.add_argument("--index", help="Path to the tinysearchengine index file", default="/app/storage/index.tinysearch")
+    parser.add_argument("--data", help="Path to the tinysearchengine index file", default="/app/storage/")
     args = parser.parse_args()
     return args
 
@@ -28,8 +30,9 @@ def setup_args():
 def run():
     args = setup_args()
 
+    index_path = Path(args.data) / INDEX_NAME
     try:
-        existing_index = TinyIndex(item_factory=Document, index_path=args.index)
+        existing_index = TinyIndex(item_factory=Document, index_path=index_path)
         if existing_index.page_size != PAGE_SIZE or existing_index.num_pages != NUM_PAGES:
             print(f"Existing index page sizes ({existing_index.page_size}) and number of pages "
                   f"({existing_index.num_pages}) does not match - removing.")
@@ -42,15 +45,11 @@ def run():
         print("Creating a new index")
         TinyIndex.create(item_factory=Document, index_path=args.index, num_pages=NUM_PAGES, page_size=PAGE_SIZE)
 
-    Process(target=background.run, args=(args.index,)).start()
-    # Process(target=historical.run).start()
-    # Process(target=retrieve.run).start()
-    # Process(target=preprocess.run, args=(args.index,)).start()
-    # Process(target=update_pages.run, args=(args.index,)).start()
+    Process(target=background.run, args=(args.data,)).start()
 
     completer = Completer()
 
-    with TinyIndex(item_factory=Document, index_path=args.index) as tiny_index:
+    with TinyIndex(item_factory=Document, index_path=index_path) as tiny_index:
         ranker = HeuristicRanker(tiny_index, completer)
 
         # Initialize FastApi instance