Browse Source

Merge pull request #70 from mwmbl/reduce-new-batch-contention

Reduce new batch contention
Daoud Clarke 2 years ago
parent
commit
aa5878fd2f
6 changed files with 95 additions and 39 deletions
  1. 8 1
      mwmbl/background.py
  2. 13 4
      mwmbl/crawler/app.py
  3. 24 31
      mwmbl/crawler/urls.py
  4. 4 3
      mwmbl/main.py
  5. 39 0
      mwmbl/url_queue.py
  6. 7 0
      mwmbl/utils.py

+ 8 - 1
mwmbl/background.py

@@ -2,21 +2,28 @@
 Script that updates data in a background process.
 """
 from logging import getLogger
+from multiprocessing import Queue
 from pathlib import Path
 from time import sleep
 
 from mwmbl.indexer import index_batches, historical
 from mwmbl.indexer.batch_cache import BatchCache
 from mwmbl.indexer.paths import BATCH_DIR_NAME, INDEX_NAME
+from mwmbl.url_queue import update_url_queue, initialize_url_queue
 
 logger = getLogger(__name__)
 
 
-def run(data_path: str):
+def run(data_path: str, url_queue: Queue):
+    initialize_url_queue(url_queue)
     historical.run()
     index_path = Path(data_path) / INDEX_NAME
     batch_cache = BatchCache(Path(data_path) / BATCH_DIR_NAME)
     while True:
+        try:
+            update_url_queue(url_queue)
+        except Exception:
+            logger.exception("Error updating URL queue")
         try:
             batch_cache.retrieve_batches(num_batches=10000)
         except Exception:

+ 13 - 4
mwmbl/crawler/app.py

@@ -2,6 +2,8 @@ import gzip
 import hashlib
 import json
 from datetime import datetime, timezone, date
+from multiprocessing import Queue
+from queue import Empty
 from typing import Union
 from uuid import uuid4
 
@@ -10,7 +12,7 @@ import requests
 from fastapi import HTTPException, APIRouter
 
 from mwmbl.crawler.batch import Batch, NewBatchRequest, HashedBatch
-from mwmbl.crawler.urls import URLDatabase
+from mwmbl.crawler.urls import URLDatabase, FoundURL, URLStatus
 from mwmbl.database import Database
 from mwmbl.indexer.batch_cache import BatchCache
 from mwmbl.indexer.indexdb import IndexDatabase, BatchInfo, BatchStatus
@@ -43,7 +45,7 @@ def upload(data: bytes, name: str):
 last_batch = None
 
 
-def get_router(batch_cache: BatchCache):
+def get_router(batch_cache: BatchCache, url_queue: Queue):
     router = APIRouter(prefix="/crawler", tags=["crawler"])
 
     @router.on_event("startup")
@@ -104,12 +106,19 @@ def get_router(batch_cache: BatchCache):
         }
 
     @router.post('/batches/new')
-    def request_new_batch(batch_request: NewBatchRequest):
+    def request_new_batch(batch_request: NewBatchRequest) -> list[str]:
         user_id_hash = _get_user_id_hash(batch_request)
+        try:
+            urls = url_queue.get(block=False)
+        except Empty:
+            return []
 
+        found_urls = [FoundURL(url, user_id_hash, 0.0, URLStatus.ASSIGNED, datetime.utcnow()) for url in urls]
         with Database() as db:
             url_db = URLDatabase(db.connection)
-            return url_db.get_new_batch_for_user(user_id_hash)
+            url_db.update_found_urls(found_urls)
+
+        return urls
 
     @router.get('/batches/{date_str}/users/{public_user_id}')
     def get_batches_for_date_and_user(date_str, public_user_id):

+ 24 - 31
mwmbl/crawler/urls.py

@@ -1,20 +1,17 @@
 """
 Database storing info on URLs
 """
-import os
 from dataclasses import dataclass
 from datetime import datetime, timedelta
 from enum import Enum
-from typing import Iterable
 
-from psycopg2 import connect
 from psycopg2.extras import execute_values
 
 
 # Client has one hour to crawl a URL that has been assigned to them, or it will be reassigned
-from mwmbl.database import Database
+from mwmbl.utils import batch
 
-REASSIGN_MIN_HOURS = 1
+REASSIGN_MIN_HOURS = 5
 BATCH_SIZE = 100
 
 
@@ -23,6 +20,7 @@ class URLStatus(Enum):
     URL state update is idempotent and can only progress forwards.
     """
     NEW = 0                   # One user has identified this URL
+    QUEUED = 5                # The URL has been queued for crawling
     ASSIGNED = 10             # The crawler has given the URL to a user to crawl
     ERROR_TIMEOUT = 20        # Timeout while retrieving
     ERROR_404 = 30            # 404 response
@@ -31,15 +29,6 @@ class URLStatus(Enum):
     CRAWLED = 100             # At least one user has crawled the URL
 
 
-def batch(items: list, batch_size):
-    """
-    Adapted from https://stackoverflow.com/a/8290508
-    """
-    length = len(items)
-    for ndx in range(0, length, batch_size):
-        yield items[ndx:min(ndx + batch_size, length)]
-
-
 @dataclass
 class FoundURL:
     url: str
@@ -119,16 +108,16 @@ class URLDatabase:
 
                 execute_values(cursor, insert_sql, data)
 
-    def get_new_batch_for_user(self, user_id_hash: str):
+    def get_urls_for_crawling(self, num_urls: int):
         sql = f"""
-        UPDATE urls SET status = {URLStatus.ASSIGNED.value}, user_id_hash = %(user_id_hash)s, updated = %(now)s
+        UPDATE urls SET status = {URLStatus.QUEUED.value}, updated = %(now)s
         WHERE url IN (
           SELECT url FROM urls
-          WHERE status = {URLStatus.NEW.value} OR (
+          WHERE status IN ({URLStatus.NEW.value}) OR (
             status = {URLStatus.ASSIGNED.value} AND updated < %(min_updated_date)s
           )
           ORDER BY score DESC
-          LIMIT {BATCH_SIZE}
+          LIMIT %(num_urls)s
           FOR UPDATE SKIP LOCKED
         )
         RETURNING url
@@ -137,11 +126,27 @@ class URLDatabase:
         now = datetime.utcnow()
         min_updated_date = now - timedelta(hours=REASSIGN_MIN_HOURS)
         with self.connection.cursor() as cursor:
-            cursor.execute(sql, {'user_id_hash': user_id_hash, 'min_updated_date': min_updated_date, 'now': now})
+            cursor.execute(sql, {'min_updated_date': min_updated_date, 'now': now, 'num_urls': num_urls})
+            results = cursor.fetchall()
+
+        return [result[0] for result in results]
+
+    def get_urls(self, status: URLStatus, num_urls: int):
+        sql = f"""
+        SELECT url FROM urls
+        WHERE status = %(status)s
+        ORDER BY score DESC
+        LIMIT %(num_urls)s
+        """
+
+        with self.connection.cursor() as cursor:
+            cursor.execute(sql, {'status': status.value, 'num_urls': num_urls})
             results = cursor.fetchall()
 
         return [result[0] for result in results]
 
+
+
     def get_url_scores(self, urls: list[str]) -> dict[str, float]:
         sql = f"""
         SELECT url, score FROM urls WHERE url IN %(urls)s
@@ -155,15 +160,3 @@ class URLDatabase:
                 url_scores.update({result[0]: result[1] for result in results})
 
         return url_scores
-
-
-if __name__ == "__main__":
-    with Database() as db:
-        url_db = URLDatabase(db.connection)
-        url_db.create_tables()
-        # update_url_status(conn, [URLStatus("https://mwmbl.org", URLState.NEW, "test-user", datetime.now())])
-        # url_db.user_found_urls("Test user", ["a", "b", "c"], datetime.utcnow())
-        # url_db.user_found_urls("Another user", ["b", "c", "d"], datetime.utcnow())
-        # url_db.user_crawled_urls("Test user", ["c"], datetime.utcnow())
-        batch = url_db.get_new_batch_for_user('test user 4')
-        print("Batch", len(batch), batch)

+ 4 - 3
mwmbl/main.py

@@ -2,7 +2,7 @@ import argparse
 import logging
 import os
 import sys
-from multiprocessing import Process
+from multiprocessing import Process, Queue
 from pathlib import Path
 
 import uvicorn
@@ -45,7 +45,8 @@ def run():
         print("Creating a new index")
         TinyIndex.create(item_factory=Document, index_path=index_path, num_pages=NUM_PAGES, page_size=PAGE_SIZE)
 
-    Process(target=background.run, args=(args.data,)).start()
+    url_queue = Queue()
+    Process(target=background.run, args=(args.data, url_queue)).start()
 
     completer = Completer()
 
@@ -59,7 +60,7 @@ def run():
         app.include_router(search_router)
 
         batch_cache = BatchCache(Path(args.data) / BATCH_DIR_NAME)
-        crawler_router = crawler.get_router(batch_cache)
+        crawler_router = crawler.get_router(batch_cache, url_queue)
         app.include_router(crawler_router)
 
         # Initialize uvicorn server using global app instance and server config params

+ 39 - 0
mwmbl/url_queue.py

@@ -0,0 +1,39 @@
+from logging import getLogger
+from multiprocessing import Queue
+
+from mwmbl.crawler.urls import BATCH_SIZE, URLDatabase, URLStatus
+from mwmbl.database import Database
+from mwmbl.utils import batch
+
+
+logger = getLogger(__name__)
+
+
+MAX_QUEUE_SIZE = 5000
+
+
+def update_url_queue(url_queue: Queue):
+    current_size = url_queue.qsize()
+    if current_size >= MAX_QUEUE_SIZE:
+        logger.info(f"Skipping queue update, current size {current_size}")
+        return
+
+    num_urls_to_fetch = (MAX_QUEUE_SIZE - current_size) * BATCH_SIZE
+    with Database() as db:
+        url_db = URLDatabase(db.connection)
+        urls = url_db.get_urls_for_crawling(num_urls_to_fetch)
+        queue_batches(url_queue, urls)
+        logger.info(f"Queued {len(urls)} urls, current queue size: {url_queue.qsize()}")
+
+
+def initialize_url_queue(url_queue: Queue):
+    with Database() as db:
+        url_db = URLDatabase(db.connection)
+        urls = url_db.get_urls(URLStatus.QUEUED, MAX_QUEUE_SIZE * BATCH_SIZE)
+        queue_batches(url_queue, urls)
+        logger.info(f"Initialized URL queue with {len(urls)} urls, current queue size: {url_queue.qsize()}")
+
+
+def queue_batches(url_queue, urls):
+    for url_batch in batch(urls, BATCH_SIZE):
+        url_queue.put(url_batch, block=False)

+ 7 - 0
mwmbl/utils.py

@@ -0,0 +1,7 @@
+def batch(items: list, batch_size):
+    """
+    Adapted from https://stackoverflow.com/a/8290508
+    """
+    length = len(items)
+    for ndx in range(0, length, batch_size):
+        yield items[ndx:min(ndx + batch_size, length)]