diff --git a/app.py b/app.py
index 5614b98..7cd4ca7 100644
--- a/app.py
+++ b/app.py
@@ -1,6 +1,7 @@
 import sqlite3
 from functools import lru_cache
 
+import Levenshtein
 import pandas as pd
 
 from fastapi import FastAPI
@@ -41,6 +42,11 @@ def complete_term(term):
     return None
 
 
+def order_results(query, results):
+    """Sort (title, url) results by edit distance between the query and the title."""
+    return sorted(results, key=lambda result: Levenshtein.distance(query, result[0]))
+
+
 @app.get("/complete")
 def complete(q: str):
     terms = [x.lower() for x in q.split()]
@@ -54,8 +60,9 @@ def complete(q: str):
         if page is not None:
             pages += page
 
+    ordered_results = order_results(q, pages)
     results = [title.replace("\n", "") + ' — ' +
-               url.replace("\n", "") for title, url in pages]
+               url.replace("\n", "") for title, url in ordered_results]
     if len(results) == 0:
         # print("No results")
         return []
diff --git a/performance.py b/performance.py
index 6113bbb..9a944b7 100644
--- a/performance.py
+++ b/performance.py
@@ -20,33 +20,53 @@ NUM_PAGES_FOR_STATS = 10
 TEST_PAGE_SIZE = 512
 TEST_NUM_PAGES = 1024
 TEST_DATA_PATH = os.path.join(DATA_DIR, 'test-urls.zstd')
+RECALL_AT_K = 3
 
 
 def get_test_pages():
     serializer = ZstdJsonSerializer()
     with open(TEST_DATA_PATH, 'rb') as data_file:
         data = serializer.deserialize(data_file.read())
-        return ((row['title'], row['url']) for row in data if row['title'] is not None)
+        # Return a list rather than a generator so callers can take len() of it.
+        return [(row['title'], row['url']) for row in data if row['title'] is not None]
 
 
 def query_test():
     titles_and_urls = get_test_pages()
+    print(f"Got {len(titles_and_urls)} titles and URLs")
 
     client = TestClient(app)
 
     start = datetime.now()
     hits = 0
+    count = 0
     for title, url in titles_and_urls:
+        print("Title", title, url)
         result = client.get('/complete', params={'q': title})
         assert result.status_code == 200
-        data = result.content.decode('utf8')
-        # print("Data", data, url, sep='\n')
+        data = result.json()
+        print("Data", data, url, sep='\n')
 
-        if title in data:
+        # A hit means the expected URL shows up in one of the first
+        # RECALL_AT_K completions; data[1] is assumed to be the list of
+        # completion strings returned by /complete (TODO confirm).
+        hit = False
+        if data:
+            for completion in data[1][:RECALL_AT_K]:
+                if url in completion:
+                    hit = True
+                    break
+
+        if hit:
             hits += 1
 
+        count += 1
+
     end = datetime.now()
-    print("Hits:", hits)
+    print(f"Hits: {hits} out of {count}")
+    # Guard against an empty test set so the report does not divide by zero.
+    recall = hits / count if count else 0.0
+    print(f"Recall at {RECALL_AT_K}: {recall}")
     print("Query time:", (end - start).total_seconds() / NUM_DOCUMENTS)
 
 