Prevent deadlock when inserting URLs
parent 1457cba2c2
commit 955d650cf4
4 changed files with 40 additions and 41 deletions
@@ -6,5 +6,5 @@ from mwmbl.indexer.retrieve import retrieve_batches

def run():
    historical.run()
    # historical.run()
    retrieve_batches()

@@ -14,7 +14,7 @@ import requests
from fastapi import HTTPException, APIRouter

from mwmbl.crawler.batch import Batch, NewBatchRequest, HashedBatch
from mwmbl.crawler.urls import URLDatabase, FoundURL
from mwmbl.crawler.urls import URLDatabase, FoundURL, URLStatus
from mwmbl.database import Database
from mwmbl.hn_top_domains_filtered import DOMAINS

@@ -123,10 +123,15 @@ def create_historical_batch(batch: HashedBatch):
    Update the database state of URL crawling for old data
    """
    user_id_hash = batch.user_id_hash
    batch_datetime = datetime(1970, 1, 1, tzinfo=timezone.utc) + timedelta(seconds=batch.timestamp)
    batch_datetime = get_datetime_from_timestamp(batch.timestamp)
    record_urls_in_database(batch, user_id_hash, batch_datetime)


def get_datetime_from_timestamp(timestamp: int) -> datetime:
    batch_datetime = datetime(1970, 1, 1, tzinfo=timezone.utc) + timedelta(seconds=timestamp)
    return batch_datetime


def record_urls_in_database(batch: Union[Batch, HashedBatch], user_id_hash: str, timestamp: datetime):
    with Database() as db:
        url_db = URLDatabase(db.connection)
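The new get_datetime_from_timestamp helper only centralises the epoch-seconds conversion that was previously written inline. As a quick sanity check (not part of the diff), it agrees with datetime.fromtimestamp evaluated in UTC:

# Standalone check of the helper's behaviour; the timestamp value is arbitrary.
from datetime import datetime, timedelta, timezone

def get_datetime_from_timestamp(timestamp: int) -> datetime:
    return datetime(1970, 1, 1, tzinfo=timezone.utc) + timedelta(seconds=timestamp)

ts = 1660000000
assert get_datetime_from_timestamp(ts) == datetime.fromtimestamp(ts, tz=timezone.utc)
print(get_datetime_from_timestamp(ts))  # 2022-08-08 23:06:40+00:00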
@@ -144,12 +149,14 @@ def record_urls_in_database(batch: Union[Batch, HashedBatch], user_id_hash: str,
            domain = f'{parsed_link.scheme}://{parsed_link.netloc}/'
            url_scores[domain] += SCORE_FOR_ROOT_PATH

        found_urls = [FoundURL(url, user_id_hash, score, item.timestamp) for url, score in url_scores.items()]
        batch_datetime = get_datetime_from_timestamp(batch.timestamp)
        found_urls = [FoundURL(url, user_id_hash, score, URLStatus.NEW, batch_datetime) for url, score in url_scores.items()]
        if len(found_urls) > 0:
            url_db.update_found_urls(found_urls)

        crawled_urls = [item.url for item in batch.items]
        url_db.user_crawled_urls(user_id_hash, crawled_urls, timestamp)
        crawled_urls = [FoundURL(item.url, user_id_hash, 0.0, URLStatus.CRAWLED, batch_datetime)
                        for item in batch.items]
        url_db.update_found_urls(crawled_urls)

    # TODO:
    # - test this code

@@ -22,7 +22,6 @@ class URLStatus(Enum):
    URL state update is idempotent and can only progress forwards.
    """
    NEW = 0  # One user has identified this URL
    CONFIRMED = 1  # A different user has identified the same URL
    ASSIGNED = 2  # The crawler has given the URL to a user to crawl
    CRAWLED = 3  # At least one user has crawled the URL

@@ -32,6 +31,7 @@ class FoundURL:
    url: str
    user_id_hash: str
    score: float
    status: URLStatus
    timestamp: datetime

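The URLStatus values above are deliberately ordered so that a URL's state only moves forwards (NEW=0, CONFIRMED=1, ASSIGNED=2, CRAWLED=3), and FoundURL now carries that status alongside the score and timestamp. The rewritten upsert in the next hunk appears to lean on that ordering: GREATEST(urls.status, excluded.status) keeps whichever state is further along, so replaying an older batch cannot move a URL backwards. A small illustration of the same rule in Python (not part of the diff):

# Merging two observations of the same URL keeps the more advanced
# (larger-valued) status, the Python equivalent of GREATEST in the SQL below.
from enum import Enum

class URLStatus(Enum):
    NEW = 0
    CONFIRMED = 1
    ASSIGNED = 2
    CRAWLED = 3

def merge_status(current: URLStatus, incoming: URLStatus) -> URLStatus:
    return max(current, incoming, key=lambda status: status.value)

assert merge_status(URLStatus.CRAWLED, URLStatus.NEW) is URLStatus.CRAWLED
assert merge_status(URLStatus.NEW, URLStatus.ASSIGNED) is URLStatus.ASSIGNED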
@@ -54,49 +54,43 @@ class URLDatabase:
            cursor.execute(sql)

    def update_found_urls(self, found_urls: list[FoundURL]):
        sql = f"""
        if len(found_urls) == 0:
            return

        get_urls_sql = """
        SELECT url FROM urls
        WHERE url in %(urls)s
        FOR UPDATE SKIP LOCKED
        """

        insert_sql = f"""
        INSERT INTO urls (url, status, user_id_hash, score, updated) values %s
        ON CONFLICT (url) DO UPDATE SET
          status = CASE
            WHEN excluded.status={URLStatus.NEW.value}
            WHEN urls.status={URLStatus.CRAWLED.value} THEN {URLStatus.CRAWLED.value}
            WHEN urls.status={URLStatus.CONFIRMED.value} THEN {URLStatus.CONFIRMED.value}
            WHEN urls.status={URLStatus.ASSIGNED.value} THEN {URLStatus.ASSIGNED.value}
            WHEN urls.status={URLStatus.NEW.value}
              AND excluded.user_id_hash != urls.user_id_hash
              THEN {URLStatus.CONFIRMED.value}
            ELSE {URLStatus.NEW.value}
          END,
          user_id_hash=excluded.user_id_hash,
          status = GREATEST(urls.status, excluded.status),
          user_id_hash = CASE
            WHEN urls.status={URLStatus.ASSIGNED.value} THEN urls.user_id_hash ELSE excluded.user_id_hash
            WHEN urls.status > excluded.status THEN urls.user_id_hash ELSE excluded.user_id_hash
          END,
          score=urls.score + excluded.score,
          updated=excluded.updated
          score = urls.score + excluded.score,
          updated = CASE
            WHEN urls.status={URLStatus.ASSIGNED.value} THEN urls.updated ELSE excluded.updated
            WHEN urls.status > excluded.status THEN urls.updated ELSE excluded.updated
          END
        """

        data = [(found_url.url, URLStatus.NEW.value, found_url.user_id_hash, found_url.score, found_url.timestamp)
                for found_url in found_urls]
        urls_to_insert = [x.url for x in found_urls]
        assert len(urls_to_insert) == len(set(urls_to_insert))

        with self.connection.cursor() as cursor:
            execute_values(cursor, sql, data)
        with self.connection as connection:
            with connection.cursor() as cursor:
                cursor.execute(get_urls_sql, {'urls': tuple(urls_to_insert)})
                locked_urls = {x[0] for x in cursor.fetchall()}
                if len(locked_urls) != len(urls_to_insert):
                    print(f"Only got {len(locked_urls)} instead of {len(urls_to_insert)}")

    def user_crawled_urls(self, user_id_hash: str, urls: list[str], timestamp: datetime):
        sql = f"""
        INSERT INTO urls (url, status, user_id_hash, updated) values %s
        ON CONFLICT (url) DO UPDATE SET
          status=excluded.status,
          user_id_hash=excluded.user_id_hash,
          updated=excluded.updated
        """
                data = [
                    (found_url.url, found_url.status.value, found_url.user_id_hash, found_url.score, found_url.timestamp)
                    for found_url in found_urls if found_url.url in locked_urls]

        data = [(url, URLStatus.CRAWLED.value, user_id_hash, timestamp) for url in urls]

        with self.connection.cursor() as cursor:
            execute_values(cursor, sql, data)
                execute_values(cursor, insert_sql, data)

    def get_new_batch_for_user(self, user_id_hash: str):
        sql = f"""

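Taken together, the new update_found_urls claims its row locks explicitly before the upsert: SELECT ... FOR UPDATE SKIP LOCKED locks the URL rows this batch wants to touch, and any URL it could not claim (because another writer currently holds the row, or because the URL has no row yet) is dropped from the batch; the method logs when the counts differ. Because neither writer ever waits on rows the other already holds, two concurrent batch inserts cannot deadlock on each other. A stripped-down sketch of that pattern, assuming a psycopg2 connection and the urls table from this diff; the function name and the plain tuple rows are illustrative, not the project's actual API:

# Sketch only: lock-then-upsert with SKIP LOCKED using psycopg2. Rows that are
# already locked by a concurrent writer are skipped rather than waited on, so
# two writers cannot end up waiting on each other's row locks (no deadlock).
from psycopg2.extras import execute_values

def upsert_urls(connection, rows):
    """rows: list of (url, status, user_id_hash, score, updated) tuples."""
    if not rows:
        return
    urls = tuple(row[0] for row in rows)
    with connection:  # one transaction: commit on success, rollback on error
        with connection.cursor() as cursor:
            # Take the row locks up front, skipping rows another writer holds.
            cursor.execute(
                "SELECT url FROM urls WHERE url IN %(urls)s FOR UPDATE SKIP LOCKED",
                {"urls": urls},
            )
            locked = {url for (url,) in cursor.fetchall()}
            # Mirror the diff: only rows we managed to lock are written; URLs
            # with no existing row are also absent from locked, so they are
            # skipped here as well.
            data = [row for row in rows if row[0] in locked]
            if not data:
                return
            execute_values(
                cursor,
                """
                INSERT INTO urls (url, status, user_id_hash, score, updated) VALUES %s
                ON CONFLICT (url) DO UPDATE SET
                  status = GREATEST(urls.status, excluded.status),
                  score = urls.score + excluded.score,
                  updated = excluded.updated
                """,
                data,
            )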
@@ -9,9 +9,7 @@ class Database:

    def __enter__(self):
        self.connection = connect(os.environ["DATABASE_URL"])
        self.connection.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.connection.__exit__(exc_type, exc_val, exc_tb)
        self.connection.close()
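One related detail: the Database wrapper no longer forwards __enter__ and __exit__ to the underlying connection. In psycopg2, using the connection as a context manager scopes a transaction (commit on clean exit, rollback on exception) without closing the connection, and update_found_urls above now opens that transaction itself with "with self.connection"; Database.__exit__ is left to just close the connection. A minimal illustration of that psycopg2 behaviour, assuming DATABASE_URL points at a reachable database:

# Sketch, not project code: "with connection:" scopes a transaction in
# psycopg2 (commit on success, rollback on exception) but does NOT close the
# connection; closing still has to be done explicitly.
import os
from psycopg2 import connect

connection = connect(os.environ["DATABASE_URL"])
try:
    with connection:  # transaction committed when this block exits cleanly
        with connection.cursor() as cursor:
            cursor.execute("SELECT 1")
            print(cursor.fetchone())
    # The connection is still open here and can run further transactions.
finally:
    connection.close()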