Prevent deadlock when inserting URLs

Daoud Clarke 2022-06-28 22:34:46 +01:00
parent 1457cba2c2
commit 955d650cf4
4 changed files with 40 additions and 41 deletions
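
Background on the failure mode: a batch upsert like INSERT INTO urls ... ON CONFLICT (url) DO UPDATE locks each conflicting row as it works through the VALUES list, and holds those locks until the surrounding transaction commits. Two workers writing overlapping URL sets in different orders can therefore each end up waiting on a row the other already holds, and PostgreSQL aborts one of them. A schematic of the interleaving (illustrative only, not code from this commit):

# Illustrative interleaving, not code from this commit: two concurrent
# batch upserts touching the same two rows, in opposite orders.
#
#   worker A                               worker B
#   --------                               --------
#   upsert 'a.example' -> locks row A      upsert 'b.example' -> locks row B
#   upsert 'b.example' -> waits on B       upsert 'a.example' -> waits on A
#
# Neither transaction can commit, so neither releases its lock;
# PostgreSQL detects the cycle and aborts one transaction with
# "ERROR: deadlock detected".

The fix below acquires all row locks up front with SELECT ... FOR UPDATE SKIP LOCKED and keeps each write in its own short transaction.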

@@ -6,5 +6,5 @@ from mwmbl.indexer.retrieve import retrieve_batches
 def run():
-    historical.run()
+    # historical.run()
     retrieve_batches()

@@ -14,7 +14,7 @@ import requests
 from fastapi import HTTPException, APIRouter
 from mwmbl.crawler.batch import Batch, NewBatchRequest, HashedBatch
-from mwmbl.crawler.urls import URLDatabase, FoundURL
+from mwmbl.crawler.urls import URLDatabase, FoundURL, URLStatus
 from mwmbl.database import Database
 from mwmbl.hn_top_domains_filtered import DOMAINS

@@ -123,10 +123,15 @@ def create_historical_batch(batch: HashedBatch):
     Update the database state of URL crawling for old data
     """
     user_id_hash = batch.user_id_hash
-    batch_datetime = datetime(1970, 1, 1, tzinfo=timezone.utc) + timedelta(seconds=batch.timestamp)
+    batch_datetime = get_datetime_from_timestamp(batch.timestamp)
     record_urls_in_database(batch, user_id_hash, batch_datetime)

+def get_datetime_from_timestamp(timestamp: int) -> datetime:
+    batch_datetime = datetime(1970, 1, 1, tzinfo=timezone.utc) + timedelta(seconds=timestamp)
+    return batch_datetime

 def record_urls_in_database(batch: Union[Batch, HashedBatch], user_id_hash: str, timestamp: datetime):
     with Database() as db:
         url_db = URLDatabase(db.connection)

@@ -144,12 +149,14 @@ def record_urls_in_database(batch: Union[Batch, HashedBatch], user_id_hash: str, timestamp: datetime):
                 domain = f'{parsed_link.scheme}://{parsed_link.netloc}/'
                 url_scores[domain] += SCORE_FOR_ROOT_PATH

-        found_urls = [FoundURL(url, user_id_hash, score, item.timestamp) for url, score in url_scores.items()]
+        batch_datetime = get_datetime_from_timestamp(batch.timestamp)
+        found_urls = [FoundURL(url, user_id_hash, score, URLStatus.NEW, batch_datetime) for url, score in url_scores.items()]
         if len(found_urls) > 0:
             url_db.update_found_urls(found_urls)

-        crawled_urls = [item.url for item in batch.items]
-        url_db.user_crawled_urls(user_id_hash, crawled_urls, timestamp)
+        crawled_urls = [FoundURL(item.url, user_id_hash, 0.0, URLStatus.CRAWLED, batch_datetime)
+                        for item in batch.items]
+        url_db.update_found_urls(crawled_urls)

 # TODO:
 # - test this code
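
With user_crawled_urls gone (see the mwmbl.crawler.urls hunks below), discovered links and crawled pages now funnel through the single update_found_urls path, distinguished only by their URLStatus. A minimal sketch of the resulting call pattern; url_db and the literal values are illustrative, and the FoundURL field order is the one from the updated dataclass:

# Sketch only: record_urls_in_database now feeds both kinds of URL
# through one code path. url_db and the values are illustrative.
from datetime import datetime, timezone

batch_datetime = datetime(2022, 6, 28, 21, 34, tzinfo=timezone.utc)

# Links discovered in the batch arrive as NEW, carrying their link score.
found_urls = [
    FoundURL("https://example.com/", "user-hash", 5.0, URLStatus.NEW, batch_datetime),
]
# Pages the user actually fetched arrive as CRAWLED with a zero score.
crawled_urls = [
    FoundURL("https://example.com/page", "user-hash", 0.0, URLStatus.CRAWLED, batch_datetime),
]

url_db.update_found_urls(found_urls)
url_db.update_found_urls(crawled_urls)

Routing both kinds of write through one method leaves a single locking discipline to get right, which is what makes the fix in update_found_urls sufficient.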

@@ -22,7 +22,6 @@ class URLStatus(Enum):
     URL state update is idempotent and can only progress forwards.
     """
     NEW = 0  # One user has identified this URL
-    CONFIRMED = 1  # A different user has identified the same URL
     ASSIGNED = 2  # The crawler has given the URL to a user to crawl
     CRAWLED = 3  # At least one user has crawled the URL
@@ -32,6 +31,7 @@ class FoundURL:
     url: str
     user_id_hash: str
     score: float
+    status: URLStatus
     timestamp: datetime
@@ -54,49 +54,43 @@ class URLDatabase:
         cursor.execute(sql)

     def update_found_urls(self, found_urls: list[FoundURL]):
-        sql = f"""
+        if len(found_urls) == 0:
+            return
+
+        get_urls_sql = """
+          SELECT url FROM urls
+          WHERE url in %(urls)s
+          FOR UPDATE SKIP LOCKED
+        """
+
+        insert_sql = f"""
         INSERT INTO urls (url, status, user_id_hash, score, updated) values %s
         ON CONFLICT (url) DO UPDATE SET
-          status = CASE
-            WHEN excluded.status={URLStatus.NEW.value}
-            WHEN urls.status={URLStatus.CRAWLED.value} THEN {URLStatus.CRAWLED.value}
-            WHEN urls.status={URLStatus.CONFIRMED.value} THEN {URLStatus.CONFIRMED.value}
-            WHEN urls.status={URLStatus.ASSIGNED.value} THEN {URLStatus.ASSIGNED.value}
-            WHEN urls.status={URLStatus.NEW.value}
-              AND excluded.user_id_hash != urls.user_id_hash
-              THEN {URLStatus.CONFIRMED.value}
-            ELSE {URLStatus.NEW.value}
-          END,
-          user_id_hash=excluded.user_id_hash,
+          status = GREATEST(urls.status, excluded.status),
+          user_id_hash = CASE
-            WHEN urls.status={URLStatus.ASSIGNED.value} THEN urls.user_id_hash ELSE excluded.user_id_hash
+            WHEN urls.status > excluded.status THEN urls.user_id_hash ELSE excluded.user_id_hash
           END,
-          score=urls.score + excluded.score,
-          updated=excluded.updated
+          score = urls.score + excluded.score,
+          updated = CASE
-            WHEN urls.status={URLStatus.ASSIGNED.value} THEN urls.updated ELSE excluded.updated
+            WHEN urls.status > excluded.status THEN urls.updated ELSE excluded.updated
+          END
         """
-        data = [(found_url.url, URLStatus.NEW.value, found_url.user_id_hash, found_url.score, found_url.timestamp)
-                for found_url in found_urls]
+        urls_to_insert = [x.url for x in found_urls]
+        assert len(urls_to_insert) == len(set(urls_to_insert))

-        with self.connection.cursor() as cursor:
-            execute_values(cursor, sql, data)
+        with self.connection as connection:
+            with connection.cursor() as cursor:
+                cursor.execute(get_urls_sql, {'urls': tuple(urls_to_insert)})
+                locked_urls = {x[0] for x in cursor.fetchall()}
+
+                if len(locked_urls) != len(urls_to_insert):
+                    print(f"Only got {len(locked_urls)} instead of {len(urls_to_insert)}")
+
+                data = [
+                    (found_url.url, found_url.status.value, found_url.user_id_hash, found_url.score, found_url.timestamp)
+                    for found_url in found_urls if found_url.url in locked_urls]
+
+                execute_values(cursor, insert_sql, data)

-    def user_crawled_urls(self, user_id_hash: str, urls: list[str], timestamp: datetime):
-        sql = f"""
-        INSERT INTO urls (url, status, user_id_hash, updated) values %s
-        ON CONFLICT (url) DO UPDATE SET
-          status=excluded.status,
-          user_id_hash=excluded.user_id_hash,
-          updated=excluded.updated
-        """
-        data = [(url, URLStatus.CRAWLED.value, user_id_hash, timestamp) for url in urls]
-
-        with self.connection.cursor() as cursor:
-            execute_values(cursor, sql, data)

     def get_new_batch_for_user(self, user_id_hash: str):
         sql = f"""

@@ -9,9 +9,7 @@ class Database:
     def __enter__(self):
         self.connection = connect(os.environ["DATABASE_URL"])
-        self.connection.__enter__()
         return self

     def __exit__(self, exc_type, exc_val, exc_tb):
-        self.connection.__exit__(exc_type, exc_val, exc_tb)
+        self.connection.close()
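
This Database change is the other half of the fix: previously __enter__ also entered the connection's own context manager, so a single transaction spanned the whole "with Database()" block and every row lock taken inside it was held until the block ended. Now the transaction is opened per-operation by "with self.connection" inside update_found_urls, and __exit__ only closes the connection. A sketch of the resulting scopes, assuming psycopg2 semantics ("with connection" commits one transaction on exit and does not close the connection); found_urls and crawled_urls are illustrative:

# Sketch of the new transaction scopes; found_urls/crawled_urls illustrative.
with Database() as db:                      # opens a connection, no transaction yet
    url_db = URLDatabase(db.connection)
    url_db.update_found_urls(found_urls)    # txn 1: lock rows, upsert, commit
    url_db.update_found_urls(crawled_urls)  # txn 2: lock rows, upsert, commit
# __exit__ just closes the connection. Before this commit, one transaction
# opened in __enter__ would have held all row locks from the first upsert
# until this point, widening the deadlock window.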