Order results by Levenshtein distance to improve recall
This commit is contained in:
parent
0e3069fdb3
commit
d6cc81278f
2 changed files with 26 additions and 6 deletions
8
app.py
8
app.py
|
@ -1,6 +1,7 @@
|
|||
import sqlite3
|
||||
from functools import lru_cache
|
||||
|
||||
import Levenshtein
|
||||
import pandas as pd
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
@ -41,6 +42,10 @@ def complete_term(term):
|
|||
return None
|
||||
|
||||
|
||||
def order_results(query, results):
|
||||
return sorted(results, key=lambda result: Levenshtein.distance(query, result[0]))
|
||||
|
||||
|
||||
@app.get("/complete")
|
||||
def complete(q: str):
|
||||
terms = [x.lower() for x in q.split()]
|
||||
|
@ -54,8 +59,9 @@ def complete(q: str):
|
|||
if page is not None:
|
||||
pages += page
|
||||
|
||||
ordered_results = order_results(q, pages)
|
||||
results = [title.replace("\n", "") + ' — ' +
|
||||
url.replace("\n", "") for title, url in pages]
|
||||
url.replace("\n", "") for title, url in ordered_results]
|
||||
if len(results) == 0:
|
||||
# print("No results")
|
||||
return []
|
||||
|
|
|
@ -20,33 +20,47 @@ NUM_PAGES_FOR_STATS = 10
|
|||
TEST_PAGE_SIZE = 512
|
||||
TEST_NUM_PAGES = 1024
|
||||
TEST_DATA_PATH = os.path.join(DATA_DIR, 'test-urls.zstd')
|
||||
RECALL_AT_K = 3
|
||||
|
||||
|
||||
def get_test_pages():
|
||||
serializer = ZstdJsonSerializer()
|
||||
with open(TEST_DATA_PATH, 'rb') as data_file:
|
||||
data = serializer.deserialize(data_file.read())
|
||||
return ((row['title'], row['url']) for row in data if row['title'] is not None)
|
||||
return [(row['title'], row['url']) for row in data if row['title'] is not None]
|
||||
|
||||
|
||||
def query_test():
|
||||
titles_and_urls = get_test_pages()
|
||||
print(f"Got {len(titles_and_urls)} titles and URLs")
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
start = datetime.now()
|
||||
hits = 0
|
||||
count = 0
|
||||
for title, url in titles_and_urls:
|
||||
print("Title", title, url)
|
||||
result = client.get('/complete', params={'q': title})
|
||||
assert result.status_code == 200
|
||||
data = result.content.decode('utf8')
|
||||
# print("Data", data, url, sep='\n')
|
||||
data = result.json()
|
||||
print("Data", data, url, sep='\n')
|
||||
|
||||
if title in data:
|
||||
hit = False
|
||||
if data:
|
||||
for result in data[1][:RECALL_AT_K]:
|
||||
if url in result:
|
||||
hit = True
|
||||
break
|
||||
|
||||
if hit:
|
||||
hits += 1
|
||||
|
||||
count += 1
|
||||
|
||||
end = datetime.now()
|
||||
print("Hits:", hits)
|
||||
print(f"Hits: {hits} out of {count}")
|
||||
print(f"Recall at {RECALL_AT_K}: {hits/count}")
|
||||
print("Query time:", (end - start).total_seconds() / NUM_DOCUMENTS)
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue