# create_app.py (4.3 KB)
  1. import re
  2. from logging import getLogger
  3. from operator import itemgetter
  4. from fastapi import FastAPI
  5. from starlette.responses import FileResponse
  6. from starlette.staticfiles import StaticFiles
  7. from tinysearchengine.indexer import TinyIndex, Document
  # Module-level logger, named after this module.
  logger = getLogger(__name__)

  # Minimum relevance score: results are kept only when their score is strictly
  # greater than this value.
  SCORE_THRESHOLD = 0.25
  def create(tiny_index: TinyIndex):
      """Build the FastAPI application that serves search over ``tiny_index``.

      Registers three GET routes -- ``/search`` (results with highlight
      segments), ``/complete`` (suggestion list) and ``/`` (the static index
      page) -- and mounts the static file directory at ``/``.

      Args:
          tiny_index: Index queried once per distinct search term through
              ``tiny_index.retrieve(term)``. The call may return ``None`` or an
              iterable of items exposing ``title``, ``extract`` and ``url``.

      Returns:
          The configured ``FastAPI`` instance.
      """
      app = FastAPI()

      def get_query_regex(terms):
          """Compile a case-insensitive pattern matching any term as a whole word."""
          # Terms are raw user input: escape them so characters such as "+" or
          # "(" are matched literally instead of breaking (or rewriting) the
          # pattern. The look-arounds behave like \b for ordinary words, and
          # additionally let terms that begin or end with punctuation
          # (e.g. "c++") match.
          term_patterns = [rf'(?<!\w){re.escape(term)}(?!\w)' for term in terms]
          return re.compile('|'.join(term_patterns), re.IGNORECASE)

      def highlight(pattern, content):
          """Split ``content`` into segments, flagging the matched ones as bold."""
          # Boundaries alternate between the start and the end of a match, so
          # the odd-numbered segments are exactly the matched (bold) ones.
          boundaries = [0]
          for match in pattern.finditer(content):
              boundaries.extend(match.span())
          boundaries.append(len(content))
          return [
              {'value': content[start:end], 'is_bold': position % 2 == 1}
              for position, (start, end) in enumerate(zip(boundaries, boundaries[1:]))
          ]

      @app.get("/search")
      def search(s: str):
          """Return the ranked results for query ``s`` with highlight segments.

          Each result is a dict with ``title`` and ``extract`` (lists of
          ``{'value', 'is_bold'}`` segments) and ``url``.
          """
          results, terms = get_results(s)
          # The pattern depends only on the query, so build it once per request.
          pattern = get_query_regex(terms)
          formatted_results = [
              {
                  'title': highlight(pattern, result.title),
                  'extract': highlight(pattern, result.extract),
                  'url': result.url,
              }
              for result in results
          ]
          logger.info("Return results: %r", formatted_results)
          return formatted_results

      def score_result(terms, pattern, result: Document):
          """Score ``result`` against the query; a higher score is a better match.

          The score grows with the combined length of the distinct terms found
          in the title and extract, plus a bonus of at most 1 that is larger the
          earlier the last newly-seen term ends. It is normalised by the
          combined length of all query terms.
          """
          result_string = f"{result.title.strip()} {result.extract.strip()}"
          seen_matches = set()
          # Starting at 1 keeps the bonus term finite when nothing matches.
          last_match_char = 1
          for match in pattern.finditer(result_string):
              value = match.group(0).lower()
              if value not in seen_matches:
                  seen_matches.add(value)
                  last_match_char = match.end()
          match_length = sum(len(value) for value in seen_matches)
          total_possible_match_length = sum(len(term) for term in terms)
          return (match_length + 1. / last_match_char) / (total_possible_match_length + 1)

      def order_results(terms: list[str], results: list[Document]):
          """Sort ``results`` best-first, keeping those scoring above SCORE_THRESHOLD."""
          # Compile once here rather than once per scored document.
          pattern = get_query_regex(terms)
          results_and_scores = [
              (score_result(terms, pattern, result), result) for result in results
          ]
          ordered_results = sorted(results_and_scores, key=itemgetter(0), reverse=True)
          return [result for score, result in ordered_results if score > SCORE_THRESHOLD]

      @app.get("/complete")
      def complete(q: str):
          """Return suggestions for ``q`` as ``[q, [...]]``, or ``[]`` when there are none.

          Each suggestion is the result title and URL joined by an em dash, with
          newlines removed from both.
          """
          ordered_results, _ = get_results(q)
          results = [
              item.title.replace("\n", "") + ' — ' + item.url.replace("\n", "")
              for item in ordered_results
          ]
          if not results:
              return []
          return [q, results]

      # TODO: why does 'leek and potato soup' result not get returned for 'potato soup' query?
      def get_results(q):
          """Look up every term of ``q`` in the index; return ``(ranked_pages, terms)``."""
          # Full stops are treated as separators and terms are lower-cased.
          # dict.fromkeys drops repeated terms while keeping their order, so a
          # repeated word neither triggers a second index lookup nor inflates
          # the maximum achievable score used to normalise results.
          terms = list(dict.fromkeys(x.lower() for x in q.replace('.', ' ').split()))
          pages = []
          seen_items = set()
          for term in terms:
              items = tiny_index.retrieve(term)
              logger.debug("Items for %r: %r", term, items)
              if items is None:
                  continue
              for item in items:
                  # Keep only pages whose title or extract contains the term.
                  if term not in item.title.lower() and term not in item.extract.lower():
                      continue
                  # Pages are de-duplicated by title across all terms.
                  if item.title not in seen_items:
                      pages.append(item)
                      seen_items.add(item.title)
          return order_results(terms, pages), terms

      @app.get('/')
      def index():
          """Serve the static ``index.html`` page."""
          # The path is relative, so it resolves against the working directory.
          return FileResponse('tinysearchengine/static/index.html')

      # Registered after the explicit routes declared above.
      app.mount('/', StaticFiles(directory="tinysearchengine/static"), name="static")
      return app