diff --git a/mwmbl/crawler/urls.py b/mwmbl/crawler/urls.py
index bad7616..3e50ff8 100644
--- a/mwmbl/crawler/urls.py
+++ b/mwmbl/crawler/urls.py
@@ -128,9 +128,9 @@ class URLDatabase:
         updated = [FoundURL(*result) for result in results]
         return updated
 
-    def get_urls(self, status: URLStatus, num_urls: int):
+    def get_urls(self, status: URLStatus, num_urls: int) -> list[FoundURL]:
         sql = f"""
-        SELECT url FROM urls
+        SELECT url, status, user_id_hash, score, updated FROM urls
         WHERE status = %(status)s
         ORDER BY score DESC
         LIMIT %(num_urls)s
@@ -140,7 +140,7 @@ class URLDatabase:
             cursor.execute(sql, {'status': status.value, 'num_urls': num_urls})
             results = cursor.fetchall()
 
-        return [result[0] for result in results]
+        return [FoundURL(url, user_id_hash, score, status, updated) for url, status, user_id_hash, score, updated in results]
 
     def get_url_scores(self, urls: list[str]) -> dict[str, float]:
         sql = f"""
diff --git a/mwmbl/url_queue.py b/mwmbl/url_queue.py
index 714cb10..ab0f1bc 100644
--- a/mwmbl/url_queue.py
+++ b/mwmbl/url_queue.py
@@ -25,6 +25,7 @@ MAX_URLS_PER_CORE_DOMAIN = 1000
 MAX_URLS_PER_TOP_DOMAIN = 100
 MAX_URLS_PER_OTHER_DOMAIN = 5
 MAX_OTHER_DOMAINS = 10000
+INITIALIZE_URLS = 10000
 
 
 random = Random(1)
@@ -41,13 +42,15 @@ class URLQueue:
         self._other_urls = defaultdict(dict)
         self._top_urls = defaultdict(dict)
         self._min_top_domains = min_top_domains
+        assert min_top_domains > 0, "Need a minimum greater than 0 to prevent a never-ending loop"
 
     def initialize(self):
+        logger.info(f"Initializing URL queue")
         with Database() as db:
             url_db = URLDatabase(db.connection)
-            urls = url_db.get_urls(URLStatus.NEW, MAX_QUEUE_SIZE * BATCH_SIZE)
-            self._queue_urls(urls)
-            logger.info(f"Initialized URL queue with {len(urls)} urls, current queue size: {self.num_queued_batches}")
+            found_urls = url_db.get_urls(URLStatus.NEW, INITIALIZE_URLS)
+            self._process_found_urls(found_urls)
+            logger.info(f"Initialized URL queue with {len(found_urls)} urls, current queue size: {self.num_queued_batches}")
 
     def update(self):
         num_processed = 0
@@ -70,7 +73,7 @@
 
         self._sort_urls(valid_urls)
         logger.info(f"Queue size: {self.num_queued_batches}")
-        while self.num_queued_batches < MAX_QUEUE_SIZE and len(self._top_urls) > self._min_top_domains:
+        while self.num_queued_batches < MAX_QUEUE_SIZE and len(self._top_urls) >= self._min_top_domains:
             total_top_urls = sum(len(urls) for urls in self._top_urls.values())
             logger.info(f"Total top URLs stored: {total_top_urls}")
 
@@ -95,8 +98,8 @@
         _sort_and_limit_urls(self._other_urls, MAX_OTHER_URLS)
 
         # Keep only the top "other" domains, ranked by the top item for that domain
-        top_other_urls = sorted(self._other_urls.items(), key=lambda x: x[1][0].score, reverse=True)[:MAX_OTHER_DOMAINS]
-        self._other_urls = defaultdict(list, dict(top_other_urls))
+        top_other_urls = sorted(self._other_urls.items(), key=lambda x: next(iter(x[1].values())), reverse=True)[:MAX_OTHER_DOMAINS]
+        self._other_urls = defaultdict(dict, dict(top_other_urls))
 
     def _batch_urls(self):
         urls = []
diff --git a/test/test_url_queue.py b/test/test_url_queue.py
index d6c15a2..0829c16 100644
--- a/test/test_url_queue.py
+++ b/test/test_url_queue.py
@@ -9,7 +9,7 @@ def test_url_queue_empties():
     new_item_queue = Queue()
     queued_batches = Queue()
 
-    url_queue = URLQueue(new_item_queue, queued_batches, min_top_domains=0)
+    url_queue = URLQueue(new_item_queue, queued_batches, min_top_domains=1)
     new_item_queue.put([FoundURL("https://google.com", "123", 10.0, URLStatus.NEW.value, datetime(2023, 1, 19))])
 
     url_queue.update()
@@ -17,3 +17,21 @@ def test_url_queue_empties():
     items = queued_batches.get(block=False)
     assert items == ["https://google.com"]
+
+
+def test_url_queue_multiple_puts():
+    new_item_queue = Queue()
+    queued_batches = Queue()
+
+    url_queue = URLQueue(new_item_queue, queued_batches, min_top_domains=1)
+    new_item_queue.put([FoundURL("https://google.com", "123", 10.0, URLStatus.NEW.value, datetime(2023, 1, 19))])
+    url_queue.update()
+
+    new_item_queue.put([FoundURL("https://www.supermemo.com", "124", 10.0, URLStatus.NEW.value, datetime(2023, 1, 20))])
+    url_queue.update()
+
+    items = queued_batches.get(block=False)
+    assert items == ["https://google.com"]
+
+    items_2 = queued_batches.get(block=False)
+    assert items_2 == ["https://www.supermemo.com"]
 
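
A note on the crawler/urls.py change: get_urls now selects full rows and returns FoundURL objects instead of bare URL strings, which is what lets initialize() feed its results through the same _process_found_urls path that update() uses. The subtle part is that the SELECT column order (url, status, user_id_hash, score, updated) differs from FoundURL's field order, so the comprehension reorders the arguments. A minimal sketch of that round-trip, assuming FoundURL's fields are (url, user_id_hash, score, status, updated) as the test constructor calls imply; the row data and plain-string status are illustrative only:

from dataclasses import dataclass
from datetime import datetime


@dataclass
class FoundURL:
    # Assumed field order, matching the constructor calls in the tests.
    url: str
    user_id_hash: str
    score: float
    status: str
    updated: datetime


# One fetched row, in SELECT order: url, status, user_id_hash, score, updated.
rows = [("https://example.com/", "NEW", "123", 10.0, datetime(2023, 1, 19))]

# The comprehension unpacks in SELECT order but constructs in field order.
found = [FoundURL(url, user_id_hash, score, status, updated)
         for url, status, user_id_hash, score, updated in rows]
assert found[0].user_id_hash == "123" and found[0].score == 10.0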
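
On the url_queue.py loop change: flipping > to >= lets batching continue down to exactly _min_top_domains domains, and the new assert exists because >= makes min_top_domains=0 dangerous; len(self._top_urls) >= 0 is always true, so once the top-URL store empties the loop makes no progress and never terminates (which is also why test_url_queue_empties now passes min_top_domains=1). A simplified, hypothetical model of the batching loop follows; it is not mwmbl's actual implementation, and the structure and limits are illustrative:

from collections import defaultdict

MAX_QUEUE_SIZE = 4096  # illustrative stand-in for the real constant


def drain(top_urls: dict[str, dict[str, float]], min_top_domains: int) -> list[str]:
    # Guard mirrors the commit: with min_top_domains == 0 the >= condition
    # below is always true, so an emptied queue would loop forever.
    assert min_top_domains > 0, "Need a minimum greater than 0 to prevent a never-ending loop"
    batched: list[str] = []
    while len(batched) < MAX_QUEUE_SIZE and len(top_urls) >= min_top_domains:
        for domain in list(top_urls):
            urls = top_urls[domain]
            url, _score = max(urls.items(), key=lambda item: item[1])
            del urls[url]  # take the best URL from each domain per pass
            batched.append(url)
            if not urls:
                del top_urls[domain]  # emptied domains shrink len(top_urls)
    return batched


queue = defaultdict(dict, {"example.com": {"https://example.com/": 10.0}})
print(drain(queue, min_top_domains=1))  # ['https://example.com/']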
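
Finally, the _other_urls ranking fix follows from the container change visible in __init__: each domain now maps to a dict of url -> score rather than a list of FoundURLs, so the old key function x[1][0].score would fail (a dict lookup by key 0, on values that are floats with no .score attribute). The replacement next(iter(x[1].values())) reads the first score in the per-domain dict, which is assumed to be the highest because _sort_and_limit_urls has just sorted it, and the rebuilt defaultdict must default to dict rather than list. A small illustration under those assumptions, with made-up domains and scores:

from collections import defaultdict

MAX_OTHER_DOMAINS = 2  # shrunk from 10000 for the example

# Each inner dict is assumed pre-sorted by descending score.
other_urls = defaultdict(dict, {
    "a.com": {"https://a.com/1": 9.0, "https://a.com/2": 1.0},
    "b.com": {"https://b.com/1": 5.0},
    "c.com": {"https://c.com/1": 7.0},
})

# Rank domains by their best (first) URL score and keep only the top few.
top_other_urls = sorted(other_urls.items(),
                        key=lambda x: next(iter(x[1].values())),
                        reverse=True)[:MAX_OTHER_DOMAINS]
other_urls = defaultdict(dict, dict(top_other_urls))  # dict default, not list
assert list(other_urls) == ["a.com", "c.com"]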