Make order_results public

This commit is contained in:
Daoud Clarke 2022-05-06 23:15:50 +01:00
parent 229819e57e
commit 87d8b40cad
2 changed files with 12 additions and 9 deletions

View file

@ -4,29 +4,32 @@ from sklearn.base import BaseEstimator
from mwmbl.tinysearchengine.completer import Completer
from mwmbl.tinysearchengine.indexer import Document, TinyIndex
from mwmbl.tinysearchengine.rank import Ranker
from mwmbl.tinysearchengine.rank import Ranker, order_results
class LTRRanker(Ranker):
def __init__(self, model: BaseEstimator, tiny_index: TinyIndex, completer: Completer):
super().__init__(tiny_index, completer)
self.model = model
self.top_n = 20
def order_results(self, terms, pages: list[Document], is_complete):
if len(pages) == 0:
return []
top_pages = order_results(terms, pages, is_complete)[:self.top_n]
query = ' '.join(terms)
data = {
'query': [query] * len(pages),
'url': [page.url for page in pages],
'title': [page.title for page in pages],
'extract': [page.extract for page in pages],
'score': [page.score for page in pages],
'query': [query] * len(top_pages),
'url': [page.url for page in top_pages],
'title': [page.title for page in top_pages],
'extract': [page.extract for page in top_pages],
'score': [page.score for page in top_pages],
}
dataframe = DataFrame(data)
print("Ordering results", dataframe)
predictions = self.model.predict(dataframe)
indexes = np.argsort(predictions)[::-1]
return [pages[i] for i in indexes]
return [top_pages[i] for i in indexes]

View file

@ -71,7 +71,7 @@ def get_match_features(terms, result_string, is_complete, is_url):
return last_match_char, match_length, total_possible_match_length
def _order_results(terms: list[str], results: list[Document], is_complete: bool) -> list[Document]:
def order_results(terms: list[str], results: list[Document], is_complete: bool) -> list[Document]:
if len(results) == 0:
return []
@ -148,5 +148,5 @@ class Ranker:
class HeuristicRanker(Ranker):
def order_results(self, terms, pages, is_complete):
return _order_results(terms, pages, is_complete)
return order_results(terms, pages, is_complete)