Order results by Levenshtein distance to improve recall

This commit is contained in:
Daoud Clarke 2021-05-23 22:14:07 +01:00
parent 0e3069fdb3
commit d6cc81278f
2 changed files with 26 additions and 6 deletions

8
app.py
View file

@ -1,6 +1,7 @@
import sqlite3
from functools import lru_cache
import Levenshtein
import pandas as pd
from fastapi import FastAPI
@ -41,6 +42,10 @@ def complete_term(term):
return None
def order_results(query, results):
return sorted(results, key=lambda result: Levenshtein.distance(query, result[0]))
@app.get("/complete")
def complete(q: str):
terms = [x.lower() for x in q.split()]
@ -54,8 +59,9 @@ def complete(q: str):
if page is not None:
pages += page
ordered_results = order_results(q, pages)
results = [title.replace("\n", "") + '' +
url.replace("\n", "") for title, url in pages]
url.replace("\n", "") for title, url in ordered_results]
if len(results) == 0:
# print("No results")
return []

View file

@ -20,33 +20,47 @@ NUM_PAGES_FOR_STATS = 10
TEST_PAGE_SIZE = 512
TEST_NUM_PAGES = 1024
TEST_DATA_PATH = os.path.join(DATA_DIR, 'test-urls.zstd')
RECALL_AT_K = 3
def get_test_pages():
serializer = ZstdJsonSerializer()
with open(TEST_DATA_PATH, 'rb') as data_file:
data = serializer.deserialize(data_file.read())
return ((row['title'], row['url']) for row in data if row['title'] is not None)
return [(row['title'], row['url']) for row in data if row['title'] is not None]
def query_test():
titles_and_urls = get_test_pages()
print(f"Got {len(titles_and_urls)} titles and URLs")
client = TestClient(app)
start = datetime.now()
hits = 0
count = 0
for title, url in titles_and_urls:
print("Title", title, url)
result = client.get('/complete', params={'q': title})
assert result.status_code == 200
data = result.content.decode('utf8')
# print("Data", data, url, sep='\n')
data = result.json()
print("Data", data, url, sep='\n')
if title in data:
hit = False
if data:
for result in data[1][:RECALL_AT_K]:
if url in result:
hit = True
break
if hit:
hits += 1
count += 1
end = datetime.now()
print("Hits:", hits)
print(f"Hits: {hits} out of {count}")
print(f"Recall at {RECALL_AT_K}: {hits/count}")
print("Query time:", (end - start).total_seconds() / NUM_DOCUMENTS)