create_app.py 3.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. import re
  2. from logging import getLogger
  3. from operator import itemgetter
  4. from typing import List
  5. import Levenshtein
  6. from fastapi import FastAPI
  7. from starlette.responses import RedirectResponse, FileResponse, HTMLResponse
  8. from starlette.staticfiles import StaticFiles
  9. from index import TinyIndex, Document
  # Module-level logger named after this module; search() logs the results it returns through it.
  10. logger = getLogger(__name__)
  # Score cutoff applied by order_results(): only results scoring strictly above this value are kept.
  11. SCORE_THRESHOLD = 0.25
  12. def create(tiny_index: TinyIndex):
  13. app = FastAPI()
  14. @app.get("/search")
  15. def search(s: str):
  16. results, terms = get_results(s)
  17. formatted_results = []
  18. for result in results:
  19. pattern = get_query_regex(terms)
  20. title = result.title
  21. matches = re.finditer(pattern, title, re.IGNORECASE)
  22. all_spans = [0] + sum((list(m.span()) for m in matches), []) + [len(title)]
  23. formatted_result = []
  24. for i in range(len(all_spans) - 1):
  25. is_bold = i % 2 == 1
  26. start = all_spans[i]
  27. end = all_spans[i + 1]
  28. formatted_result.append({'value': title[start:end], 'is_bold': is_bold})
  29. formatted_results.append({'title': formatted_result, 'url': result.url})
  30. logger.info("Return results: %r", formatted_results)
  31. return formatted_results
  32. def get_query_regex(terms):
  33. term_patterns = [rf'\b{term}\b' for term in terms]
  34. pattern = '|'.join(term_patterns)
  35. return pattern
  36. def score_result(terms, r):
  37. query_regex = get_query_regex(terms)
  38. matches = re.findall(query_regex, r, flags=re.IGNORECASE)
  39. match_strings = {x.lower() for x in matches}
  40. match_length = sum(len(x) for x in match_strings)
  41. num_words = len(re.findall(r'\b\w+\b', r))
  42. total_possible_match_length = sum(len(x) for x in terms)
  43. return (match_length + 1./num_words) / (total_possible_match_length + 1)
  44. def order_results(terms: list[str], results: list[Document]):
  45. results_and_scores = [(score_result(terms, result.title), result) for result in results]
  46. ordered_results = sorted(results_and_scores, key=itemgetter(0), reverse=True)
  47. filtered_results = [result for score, result in ordered_results if score > SCORE_THRESHOLD]
  48. # ordered_results = sorted(results, key=lambda result: score_result(terms, result.title), reverse=True)
  49. # print("Order results", query, ordered_results, sep='\n')
  50. return filtered_results
  51. @app.get("/complete")
  52. def complete(q: str):
  53. ordered_results, terms = get_results(q)
  54. results = [item.title.replace("\n", "") + ' — ' +
  55. item.url.replace("\n", "") for item in ordered_results]
  56. if len(results) == 0:
  57. # print("No results")
  58. return []
  59. # print("Results", results)
  60. return [q, results]
  61. def get_results(q):
  62. terms = [x.lower() for x in q.replace('.', ' ').split()]
  63. # completed = complete_term(terms[-1])
  64. # terms = terms[:-1] + [completed]
  65. pages = []
  66. for term in terms:
  67. items = tiny_index.retrieve(term)
  68. if items is not None:
  69. pages += [item for item in items if term in item.title.lower()]
  70. ordered_results = order_results(terms, pages)
  71. return ordered_results, terms
  72. @app.get('/')
  73. def index():
  74. return FileResponse('static/index.html')
  75. app.mount('/', StaticFiles(directory="static"), name="static")
  76. return app