2021-12-17 21:31:26 +00:00
|
|
|
import re
|
2021-12-16 21:36:01 +00:00
|
|
|
from logging import getLogger
|
2021-12-18 22:35:59 +00:00
|
|
|
from operator import itemgetter
|
2021-06-05 21:22:31 +00:00
|
|
|
|
|
|
|
from fastapi import FastAPI
|
2021-12-19 21:09:00 +00:00
|
|
|
from starlette.responses import FileResponse
|
2021-06-05 21:22:31 +00:00
|
|
|
from starlette.staticfiles import StaticFiles
|
|
|
|
|
|
|
|
from index import TinyIndex, Document
|
|
|
|
|
2021-12-16 21:36:01 +00:00
|
|
|
# Module-level logger, named after this module.
logger = getLogger(name=__name__)
|
|
|
|
|
|
|
|
|
2021-12-18 22:35:59 +00:00
|
|
|
# Minimum relevance score: results must score strictly above this to be returned.
SCORE_THRESHOLD: float = 0.25
|
|
|
|
|
|
|
|
|
2021-06-05 21:22:31 +00:00
|
|
|
def create(tiny_index: TinyIndex):
    """Build the FastAPI search application backed by *tiny_index*.

    Routes registered:
      * ``GET /search?s=...``   -- ranked results split into highlight segments.
      * ``GET /complete?q=...`` -- ``[query, ["<title> — <url>", ...]]`` suggestions.
      * ``GET /``               -- serves ``static/index.html``.
    Other paths fall through to ``StaticFiles`` mounted at ``/`` over the
    ``static`` directory.

    :param tiny_index: index whose ``retrieve(term)`` is called once per
        query term; it may return ``None`` when a term has no entries.
    :return: the configured ``FastAPI`` application.
    """
    app = FastAPI()

    @app.get("/search")
    def search(s: str):
        """Return ranked results for query *s* with match highlighting.

        Each result is ``{'title': [segment, ...], 'url': url}``. The
        segments cover ``"<title> - <extract>"`` in order and alternate
        between non-matching text (``is_bold`` False) and text matching a
        query term (``is_bold`` True); segments may be empty.
        """
        results, terms = get_results(s)

        # The pattern depends only on the query, so compile it once per request
        # rather than once per result.
        pattern = re.compile(get_query_regex(terms), re.IGNORECASE)

        formatted_results = []
        for result in results:
            title_and_extract = f"{result.title} - {result.extract}"
            # Boundaries are 0, the start/end of every match, then the end of
            # the text, so even-indexed gaps are plain text and odd-indexed
            # gaps are matches.
            match_bounds = [pos for m in pattern.finditer(title_and_extract) for pos in m.span()]
            all_spans = [0] + match_bounds + [len(title_and_extract)]

            formatted_result = [
                {'value': title_and_extract[start:end], 'is_bold': i % 2 == 1}
                for i, (start, end) in enumerate(zip(all_spans, all_spans[1:]))
            ]
            formatted_results.append({'title': formatted_result, 'url': result.url})

        logger.info("Return results: %r", formatted_results)
        return formatted_results

    def get_query_regex(terms):
        """Return a regex source matching any of *terms* as a whole word.

        Terms come straight from the user's query, so they are escaped:
        unescaped, input such as ``c++`` or ``(`` raises ``re.error`` and
        nested quantifiers like ``(a+)+$`` can backtrack catastrophically.
        """
        return '|'.join(rf'\b{re.escape(term)}\b' for term in terms)

    def score_result(terms, result: Document):
        """Score *result* against the lower-cased query *terms*; higher is better.

        The numerator is the total length of the distinct terms found
        (case-insensitively, as whole words) in the title and extract, plus
        a ``1/num_words`` bonus that favours shorter texts. The denominator
        is the total length of the distinct terms plus one.
        """
        logger.debug("Score result %r", result)
        result_string = f"{result.title} {result.extract}"
        query_regex = get_query_regex(terms)
        matches = re.findall(query_regex, result_string, flags=re.IGNORECASE)
        match_strings = {x.lower() for x in matches}
        match_length = sum(len(x) for x in match_strings)

        # A text with no word characters (e.g. a punctuation-only title found
        # via a punctuation-only term) would otherwise divide by zero.
        num_words = max(len(re.findall(r'\b\w+\b', result_string)), 1)
        # Count each distinct term once, matching how match_length counts
        # distinct matches; otherwise repeating a word in the query lowers the
        # score of a perfect match and can push it below SCORE_THRESHOLD.
        total_possible_match_length = sum(len(x) for x in set(terms))
        return (match_length + 1. / num_words) / (total_possible_match_length + 1)

    def order_results(terms: list[str], results: list[Document]):
        """Return *results* by descending score, keeping only scores above SCORE_THRESHOLD."""
        results_and_scores = [(score_result(terms, result), result) for result in results]
        ordered_results = sorted(results_and_scores, key=itemgetter(0), reverse=True)
        return [result for score, result in ordered_results if score > SCORE_THRESHOLD]

    @app.get("/complete")
    def complete(q: str):
        """Return ``[q, ["<title> — <url>", ...]]`` suggestions for query *q*.

        Newlines are stripped from titles and URLs. When nothing matches, an
        empty list is returned rather than ``[q, []]``.
        """
        ordered_results, terms = get_results(q)
        results = [item.title.replace("\n", "") + ' — ' + item.url.replace("\n", "")
                   for item in ordered_results]
        if not results:
            return []
        return [q, results]

    # TODO: why does 'leek and potato soup' result not get returned for 'potato soup' query?
    def get_results(q):
        """Retrieve, filter, deduplicate and rank index entries for query *q*.

        The query is lower-cased and split on whitespace, with ``.`` also
        treated as a separator. For each term, entries from
        ``tiny_index.retrieve`` are kept only if the term occurs as a
        substring of their title or extract. Entries are deduplicated by
        title across all terms.

        :return: ``(ordered_results, terms)``, where *ordered_results* come
            from ``order_results``.
        """
        terms = [x.lower() for x in q.replace('.', ' ').split()]

        pages = []
        seen_items = set()
        for term in terms:
            items = tiny_index.retrieve(term)
            logger.debug("Items for term %r: %r", term, items)
            if items is None:
                continue
            for item in items:
                if term in item.title.lower() or term in item.extract.lower():
                    if item.title not in seen_items:
                        pages.append(item)
                        seen_items.add(item.title)

        ordered_results = order_results(terms, pages)
        return ordered_results, terms

    @app.get('/')
    def index():
        """Serve the front-end page."""
        return FileResponse('static/index.html')

    # Mounted last so the routes registered above are matched first.
    app.mount('/', StaticFiles(directory="static"), name="static")
    return app
|