create_app.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. import re
  2. from logging import getLogger
  3. from typing import List
  4. import Levenshtein
  5. from fastapi import FastAPI
  6. from starlette.responses import RedirectResponse, FileResponse, HTMLResponse
  7. from starlette.staticfiles import StaticFiles
  8. from index import TinyIndex, Document
  9. logger = getLogger(__name__)
  10. def create(tiny_index: TinyIndex):
  11. app = FastAPI()
  12. @app.get("/search")
  13. def search(s: str):
  14. results, terms = get_results(s)
  15. formatted_results = []
  16. for result in results:
  17. pattern = get_query_regex(terms)
  18. title = result.title
  19. matches = re.finditer(pattern, title, re.IGNORECASE)
  20. all_spans = [0] + sum((list(m.span()) for m in matches), []) + [len(title)]
  21. formatted_result = []
  22. for i in range(len(all_spans) - 1):
  23. is_bold = i % 2 == 1
  24. start = all_spans[i]
  25. end = all_spans[i + 1]
  26. formatted_result.append({'value': title[start:end], 'is_bold': is_bold})
  27. formatted_results.append({'title': formatted_result, 'url': result.url})
  28. logger.info("Return results: %r", formatted_results)
  29. return formatted_results
  30. def get_query_regex(terms):
  31. term_patterns = [rf'\b{term}\b' for term in terms]
  32. pattern = '|'.join(term_patterns)
  33. return pattern
  34. def score_result(terms, r):
  35. query_regex = get_query_regex(terms)
  36. matches = re.findall(query_regex, r, flags=re.IGNORECASE)
  37. match_strings = {x.lower() for x in matches}
  38. match_length = sum(len(x) for x in match_strings)
  39. num_words = len(re.findall(r'\b\w+\b', r))
  40. return match_length + 1./num_words
  41. def order_results(terms: list[str], results: list[Document]):
  42. ordered_results = sorted(results, key=lambda result: score_result(terms, result.title), reverse=True)
  43. # print("Order results", query, ordered_results, sep='\n')
  44. return ordered_results
  45. @app.get("/complete")
  46. def complete(q: str):
  47. ordered_results, terms = get_results(q)
  48. results = [item.title.replace("\n", "") + ' — ' +
  49. item.url.replace("\n", "") for item in ordered_results]
  50. if len(results) == 0:
  51. # print("No results")
  52. return []
  53. # print("Results", results)
  54. return [q, results]
  55. def get_results(q):
  56. terms = [x.lower() for x in q.replace('.', ' ').split()]
  57. # completed = complete_term(terms[-1])
  58. # terms = terms[:-1] + [completed]
  59. pages = []
  60. for term in terms:
  61. items = tiny_index.retrieve(term)
  62. if items is not None:
  63. pages += [item for item in items if term in item.title.lower()]
  64. ordered_results = order_results(terms, pages)
  65. return ordered_results, terms
  66. @app.get('/')
  67. def index():
  68. return FileResponse('static/index.html')
  69. app.mount('/', StaticFiles(directory="static"), name="static")
  70. return app