Improve handling of incomplete words:
- Correctly generate regex for incomplete vs complete words - Return more than one top word from completer - Correctly handle no terms
This commit is contained in:
parent
7d829bc319
commit
fe6ace93e6
2 changed files with 34 additions and 19 deletions
|
@ -6,21 +6,26 @@ from pandas import DataFrame
|
|||
|
||||
|
||||
class Completer:
|
||||
def __init__(self, terms: DataFrame):
|
||||
def __init__(self, terms: DataFrame, num_matches: int = 3):
|
||||
terms_dict = terms.sort_values('term').set_index('term')['count'].to_dict()
|
||||
self.terms = list(terms_dict.keys())
|
||||
self.counts = list(terms_dict.values())
|
||||
self.num_matches = num_matches
|
||||
print("Terms", self.terms[:100], self.counts[:100])
|
||||
|
||||
def complete(self, term):
|
||||
def complete(self, term) -> list[str]:
|
||||
term_length = len(term)
|
||||
start = bisect_left(self.terms, term, key=lambda x: x[:term_length])
|
||||
end = bisect_right(self.terms, term, key=lambda x: x[:term_length])
|
||||
start_index = bisect_left(self.terms, term, key=lambda x: x[:term_length])
|
||||
end_index = bisect_right(self.terms, term, key=lambda x: x[:term_length])
|
||||
|
||||
matching_terms = zip(self.counts[start:end], self.terms[start:end])
|
||||
top_count, top_term = max(matching_terms)
|
||||
print("Top term", top_term, top_count)
|
||||
return top_term
|
||||
matching_terms = zip(self.counts[start_index:end_index], self.terms[start_index:end_index])
|
||||
top_terms = sorted(matching_terms, reverse=True)[:self.num_matches]
|
||||
print("Top terms, counts", top_terms)
|
||||
if not top_terms:
|
||||
return []
|
||||
|
||||
counts, terms = zip(*top_terms)
|
||||
return list(terms)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -31,7 +31,8 @@ def create(tiny_index: TinyIndex, completer: Completer):
|
|||
def search(s: str):
|
||||
results, terms = get_results(s)
|
||||
|
||||
pattern = get_query_regex(terms)
|
||||
is_complete = s.endswith(' ')
|
||||
pattern = get_query_regex(terms, is_complete)
|
||||
formatted_results = []
|
||||
for result in results:
|
||||
formatted_result = {}
|
||||
|
@ -51,17 +52,23 @@ def create(tiny_index: TinyIndex, completer: Completer):
|
|||
logger.info("Return results: %r", formatted_results)
|
||||
return formatted_results
|
||||
|
||||
def get_query_regex(terms):
|
||||
term_patterns = [rf'\b{term}\b' for term in terms]
|
||||
def get_query_regex(terms, is_complete):
|
||||
if not terms:
|
||||
return ''
|
||||
|
||||
if is_complete:
|
||||
term_patterns = [rf'\b{term}\b' for term in terms]
|
||||
else:
|
||||
term_patterns = [rf'\b{term}\b' for term in terms[:-1]] + [rf'\b{terms[-1]}']
|
||||
pattern = '|'.join(term_patterns)
|
||||
return pattern
|
||||
|
||||
def score_result(terms, result: Document):
|
||||
def score_result(terms, result: Document, is_complete: bool):
|
||||
domain = urlparse(result.url).netloc
|
||||
domain_score = DOMAINS.get(domain, 0.0)
|
||||
|
||||
result_string = f"{result.title.strip()} {result.extract.strip()}"
|
||||
query_regex = get_query_regex(terms)
|
||||
query_regex = get_query_regex(terms, is_complete)
|
||||
matches = list(re.finditer(query_regex, result_string, flags=re.IGNORECASE))
|
||||
match_strings = {x.group(0).lower() for x in matches}
|
||||
match_length = sum(len(x) for x in match_strings)
|
||||
|
@ -78,8 +85,8 @@ def create(tiny_index: TinyIndex, completer: Completer):
|
|||
score = 0.1*domain_score + 0.9*(match_length + 1./last_match_char) / (total_possible_match_length + 1)
|
||||
return score
|
||||
|
||||
def order_results(terms: list[str], results: list[Document]):
|
||||
results_and_scores = [(score_result(terms, result), result) for result in results]
|
||||
def order_results(terms: list[str], results: list[Document], is_complete: bool):
|
||||
results_and_scores = [(score_result(terms, result, is_complete), result) for result in results]
|
||||
ordered_results = sorted(results_and_scores, key=itemgetter(0), reverse=True)
|
||||
filtered_results = [result for score, result in ordered_results if score > SCORE_THRESHOLD]
|
||||
return filtered_results
|
||||
|
@ -95,12 +102,15 @@ def create(tiny_index: TinyIndex, completer: Completer):
|
|||
|
||||
def get_results(q):
|
||||
terms = [x.lower() for x in q.replace('.', ' ').split()]
|
||||
if not q.endswith(' '):
|
||||
terms[-1] = completer.complete(terms[-1])
|
||||
is_complete = q.endswith(' ')
|
||||
if len(terms) > 0 and not is_complete:
|
||||
retrieval_terms = terms[:-1] + completer.complete(terms[-1])
|
||||
else:
|
||||
retrieval_terms = terms
|
||||
|
||||
pages = []
|
||||
seen_items = set()
|
||||
for term in terms:
|
||||
for term in retrieval_terms:
|
||||
items = tiny_index.retrieve(term)
|
||||
if items is not None:
|
||||
for item in items:
|
||||
|
@ -109,6 +119,6 @@ def create(tiny_index: TinyIndex, completer: Completer):
|
|||
pages.append(item)
|
||||
seen_items.add(item.title)
|
||||
|
||||
ordered_results = order_results(terms, pages)
|
||||
ordered_results = order_results(terms, pages, is_complete)
|
||||
return ordered_results, terms
|
||||
return app
|
||||
|
|
Loading…
Add table
Reference in a new issue