create_app.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. import re
  2. from logging import getLogger
  3. from operator import itemgetter
  4. from fastapi import FastAPI
  5. from starlette.responses import FileResponse
  6. from starlette.staticfiles import StaticFiles
  7. from index import TinyIndex, Document
# Module-level logger, named after this module; used inside `create`.
logger = getLogger(__name__)

# Minimum relevance score: `order_results` keeps only results whose score is
# strictly greater than this value.
SCORE_THRESHOLD = 0.25
  10. def create(tiny_index: TinyIndex):
  11. app = FastAPI()
  12. @app.get("/search")
  13. def search(s: str):
  14. results, terms = get_results(s)
  15. formatted_results = []
  16. for result in results:
  17. pattern = get_query_regex(terms)
  18. title_and_extract = f"{result.title} - {result.extract}"
  19. matches = re.finditer(pattern, title_and_extract, re.IGNORECASE)
  20. all_spans = [0] + sum((list(m.span()) for m in matches), []) + [len(title_and_extract)]
  21. formatted_result = []
  22. title_length = len(result.title)
  23. for i in range(len(all_spans) - 1):
  24. is_bold = i % 2 == 1
  25. start = all_spans[i]
  26. end = all_spans[i + 1]
  27. formatted_result.append({'value': title_and_extract[start:end], 'is_bold': is_bold})
  28. formatted_results.append({'title': formatted_result, 'url': result.url})
  29. logger.info("Return results: %r", formatted_results)
  30. return formatted_results
  31. def get_query_regex(terms):
  32. term_patterns = [rf'\b{term}\b' for term in terms]
  33. pattern = '|'.join(term_patterns)
  34. return pattern
  35. def score_result(terms, result: Document):
  36. print("Score result", result)
  37. result_string = f"{result.title} {result.extract}"
  38. query_regex = get_query_regex(terms)
  39. matches = re.findall(query_regex, result_string, flags=re.IGNORECASE)
  40. match_strings = {x.lower() for x in matches}
  41. match_length = sum(len(x) for x in match_strings)
  42. num_words = len(re.findall(r'\b\w+\b', result_string))
  43. total_possible_match_length = sum(len(x) for x in terms)
  44. return (match_length + 1./num_words) / (total_possible_match_length + 1)
  45. def order_results(terms: list[str], results: list[Document]):
  46. results_and_scores = [(score_result(terms, result), result) for result in results]
  47. ordered_results = sorted(results_and_scores, key=itemgetter(0), reverse=True)
  48. # print("Ordered results", ordered_results)
  49. filtered_results = [result for score, result in ordered_results if score > SCORE_THRESHOLD]
  50. return filtered_results
  51. @app.get("/complete")
  52. def complete(q: str):
  53. ordered_results, terms = get_results(q)
  54. results = [item.title.replace("\n", "") + ' — ' +
  55. item.url.replace("\n", "") for item in ordered_results]
  56. if len(results) == 0:
  57. # print("No results")
  58. return []
  59. # print("Results", results)
  60. return [q, results]
  61. # TODO: why does 'leek and potato soup' result not get returned for 'potato soup' query?
  62. def get_results(q):
  63. terms = [x.lower() for x in q.replace('.', ' ').split()]
  64. # completed = complete_term(terms[-1])
  65. # terms = terms[:-1] + [completed]
  66. pages = []
  67. seen_items = set()
  68. for term in terms:
  69. items = tiny_index.retrieve(term)
  70. print("Items", items)
  71. if items is not None:
  72. for item in items:
  73. if term in item.title.lower() or term in item.extract.lower():
  74. if item.title not in seen_items:
  75. pages.append(item)
  76. seen_items.add(item.title)
  77. ordered_results = order_results(terms, pages)
  78. return ordered_results, terms
  79. @app.get('/')
  80. def index():
  81. return FileResponse('static/index.html')
  82. app.mount('/', StaticFiles(directory="static"), name="static")
  83. return app