create_app.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. import re
  2. from logging import getLogger
  3. from operator import itemgetter
  4. from pathlib import Path
  5. from fastapi import FastAPI
  6. from starlette.responses import FileResponse
  7. from starlette.staticfiles import StaticFiles
  8. from tinysearchengine.indexer import TinyIndex, Document
  9. logger = getLogger(__name__)
  10. STATIC_FILES_PATH = Path(__file__).parent / 'static'
  11. SCORE_THRESHOLD = 0.25
  12. def create(tiny_index: TinyIndex):
  13. app = FastAPI()
  14. @app.get("/search")
  15. def search(s: str):
  16. results, terms = get_results(s)
  17. formatted_results = []
  18. for result in results:
  19. pattern = get_query_regex(terms)
  20. formatted_result = {}
  21. for content_type, content in [('title', result.title), ('extract', result.extract)]:
  22. matches = re.finditer(pattern, content, re.IGNORECASE)
  23. all_spans = [0] + sum((list(m.span()) for m in matches), []) + [len(content)]
  24. content_result = []
  25. for i in range(len(all_spans) - 1):
  26. is_bold = i % 2 == 1
  27. start = all_spans[i]
  28. end = all_spans[i + 1]
  29. content_result.append({'value': content[start:end], 'is_bold': is_bold})
  30. formatted_result[content_type] = content_result
  31. formatted_result['url'] = result.url
  32. formatted_results.append(formatted_result)
  33. logger.info("Return results: %r", formatted_results)
  34. return formatted_results
  35. def get_query_regex(terms):
  36. term_patterns = [rf'\b{term}\b' for term in terms]
  37. pattern = '|'.join(term_patterns)
  38. return pattern
  39. def score_result(terms, result: Document):
  40. result_string = f"{result.title.strip()} {result.extract.strip()}"
  41. query_regex = get_query_regex(terms)
  42. matches = list(re.finditer(query_regex, result_string, flags=re.IGNORECASE))
  43. match_strings = {x.group(0).lower() for x in matches}
  44. match_length = sum(len(x) for x in match_strings)
  45. last_match_char = 1
  46. seen_matches = set()
  47. for match in matches:
  48. value = match.group(0).lower()
  49. if value not in seen_matches:
  50. last_match_char = match.span()[1]
  51. seen_matches.add(value)
  52. # num_words = len(re.findall(r'\b\w+\b', result_string))
  53. total_possible_match_length = sum(len(x) for x in terms)
  54. score = (match_length + 1./last_match_char) / (total_possible_match_length + 1)
  55. # print("Score result", match_length, last_match_char, score, result.title)
  56. return score
  57. def order_results(terms: list[str], results: list[Document]):
  58. results_and_scores = [(score_result(terms, result), result) for result in results]
  59. ordered_results = sorted(results_and_scores, key=itemgetter(0), reverse=True)
  60. # print("Ordered results", ordered_results)
  61. filtered_results = [result for score, result in ordered_results if score > SCORE_THRESHOLD]
  62. return filtered_results
  63. @app.get("/complete")
  64. def complete(q: str):
  65. ordered_results, terms = get_results(q)
  66. results = [item.title.replace("\n", "") + ' — ' +
  67. item.url.replace("\n", "") for item in ordered_results]
  68. if len(results) == 0:
  69. # print("No results")
  70. return []
  71. # print("Results", results)
  72. return [q, results]
  73. # TODO: why does 'leek and potato soup' result not get returned for 'potato soup' query?
  74. def get_results(q):
  75. terms = [x.lower() for x in q.replace('.', ' ').split()]
  76. # completed = complete_term(terms[-1])
  77. # terms = terms[:-1] + [completed]
  78. pages = []
  79. seen_items = set()
  80. for term in terms:
  81. items = tiny_index.retrieve(term)
  82. print("Items", items)
  83. if items is not None:
  84. for item in items:
  85. if term in item.title.lower() or term in item.extract.lower():
  86. if item.title not in seen_items:
  87. pages.append(item)
  88. seen_items.add(item.title)
  89. ordered_results = order_results(terms, pages)
  90. return ordered_results, terms
  91. @app.get('/')
  92. def index():
  93. return FileResponse(STATIC_FILES_PATH / 'index.html')
  94. app.mount('/', StaticFiles(directory=STATIC_FILES_PATH), name="static")
  95. return app