|
@@ -5,6 +5,7 @@ import os
|
|
|
from dataclasses import dataclass
|
|
|
from datetime import datetime, timedelta
|
|
|
from enum import Enum
|
|
|
+from typing import Iterable
|
|
|
|
|
|
from psycopg2 import connect
|
|
|
from psycopg2.extras import execute_values
|
|
@@ -30,6 +31,15 @@ class URLStatus(Enum):
|
|
|
CRAWLED = 100 # At least one user has crawled the URL
|
|
|
|
|
|
|
|
|
+def batch(items: list, batch_size: int) -> Iterable[list]:
|
|
|
+ """
|
|
|
+    Yield successive batch_size-sized slices of items. Adapted from https://stackoverflow.com/a/8290508
|
|
|
+ """
|
|
|
+ length = len(items)
|
|
|
+ for ndx in range(0, length, batch_size):
|
|
|
+ yield items[ndx:min(ndx + batch_size, length)]
|
|
|
+
|
|
|
+
|
|
|
@dataclass
|
|
|
class FoundURL:
|
|
|
url: str
|
|
@@ -137,11 +147,14 @@ class URLDatabase:
|
|
|
SELECT url, score FROM urls WHERE url IN %(urls)s
|
|
|
"""
|
|
|
|
|
|
- with self.connection.cursor() as cursor:
|
|
|
- cursor.execute(sql, {'urls': tuple(urls)})
|
|
|
- results = cursor.fetchall()
|
|
|
+ url_scores = {}
|
|
|
+ for url_batch in batch(urls, 10000):
|
|
|
+ with self.connection.cursor() as cursor:
|
|
|
+ cursor.execute(sql, {'urls': tuple(url_batch)})
|
|
|
+ results = cursor.fetchall()
|
|
|
+ url_scores.update({result[0]: result[1] for result in results})
|
|
|
|
|
|
- return {result[0]: result[1] for result in results}
|
|
|
+ return url_scores
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|