Make get_results() public for learning to rank
This commit is contained in:
parent
ee5ca6bcf6
commit
2d334074af
1 changed file with 6 additions and 5 deletions
|
@@ -22,9 +22,9 @@ def _get_query_regex(terms, is_complete):
|
|||
return ''
|
||||
|
||||
if is_complete:
|
||||
term_patterns = [rf'\b{term}\b' for term in terms]
|
||||
term_patterns = [rf'\b{re.escape(term)}\b' for term in terms]
|
||||
else:
|
||||
term_patterns = [rf'\b{term}\b' for term in terms[:-1]] + [rf'\b{terms[-1]}']
|
||||
term_patterns = [rf'\b{re.escape(term)}\b' for term in terms[:-1]] + [rf'\b{re.escape(terms[-1])}']
|
||||
pattern = '|'.join(term_patterns)
|
||||
return pattern
|
||||
|
||||
|
@@ -35,6 +35,7 @@ def _score_result(terms, result: Document, is_complete: bool, max_score: float):
|
|||
|
||||
result_string = f"{result.title.strip()} {result.extract.strip()}"
|
||||
query_regex = _get_query_regex(terms, is_complete)
|
||||
print("Query regex", query_regex)
|
||||
matches = list(re.finditer(query_regex, result_string, flags=re.IGNORECASE))
|
||||
match_strings = {x.group(0).lower() for x in matches}
|
||||
match_length = sum(len(x) for x in match_strings)
|
||||
|
@@ -72,7 +73,7 @@ class Ranker:
|
|||
self.completer = completer
|
||||
|
||||
def search(self, s: str):
|
||||
results, terms = self._get_results(s)
|
||||
results, terms = self.get_results(s)
|
||||
|
||||
is_complete = s.endswith(' ')
|
||||
pattern = _get_query_regex(terms, is_complete)
|
||||
|
@@ -96,14 +97,14 @@ class Ranker:
|
|||
return formatted_results
|
||||
|
||||
def complete(self, q: str):
|
||||
ordered_results, terms = self._get_results(q)
|
||||
ordered_results, terms = self.get_results(q)
|
||||
results = [item.title.replace("\n", "") + ' — ' +
|
||||
item.url.replace("\n", "") for item in ordered_results]
|
||||
if len(results) == 0:
|
||||
return []
|
||||
return [q, results]
|
||||
|
||||
def _get_results(self, q):
|
||||
def get_results(self, q):
|
||||
terms = [x.lower() for x in q.replace('.', ' ').split()]
|
||||
is_complete = q.endswith(' ')
|
||||
if len(terms) > 0 and not is_complete:
|
||||
|
|
Loading…
Add table
Reference in a new issue