Make order_results public
This commit is contained in:
parent
229819e57e
commit
87d8b40cad
2 changed files with 12 additions and 9 deletions
|
@ -4,29 +4,32 @@ from sklearn.base import BaseEstimator
|
|||
|
||||
from mwmbl.tinysearchengine.completer import Completer
|
||||
from mwmbl.tinysearchengine.indexer import Document, TinyIndex
|
||||
from mwmbl.tinysearchengine.rank import Ranker
|
||||
from mwmbl.tinysearchengine.rank import Ranker, order_results
|
||||
|
||||
|
||||
class LTRRanker(Ranker):
|
||||
def __init__(self, model: BaseEstimator, tiny_index: TinyIndex, completer: Completer):
|
||||
super().__init__(tiny_index, completer)
|
||||
self.model = model
|
||||
self.top_n = 20
|
||||
|
||||
def order_results(self, terms, pages: list[Document], is_complete):
|
||||
if len(pages) == 0:
|
||||
return []
|
||||
|
||||
top_pages = order_results(terms, pages, is_complete)[:self.top_n]
|
||||
|
||||
query = ' '.join(terms)
|
||||
data = {
|
||||
'query': [query] * len(pages),
|
||||
'url': [page.url for page in pages],
|
||||
'title': [page.title for page in pages],
|
||||
'extract': [page.extract for page in pages],
|
||||
'score': [page.score for page in pages],
|
||||
'query': [query] * len(top_pages),
|
||||
'url': [page.url for page in top_pages],
|
||||
'title': [page.title for page in top_pages],
|
||||
'extract': [page.extract for page in top_pages],
|
||||
'score': [page.score for page in top_pages],
|
||||
}
|
||||
|
||||
dataframe = DataFrame(data)
|
||||
print("Ordering results", dataframe)
|
||||
predictions = self.model.predict(dataframe)
|
||||
indexes = np.argsort(predictions)[::-1]
|
||||
return [pages[i] for i in indexes]
|
||||
return [top_pages[i] for i in indexes]
|
||||
|
|
|
@ -71,7 +71,7 @@ def get_match_features(terms, result_string, is_complete, is_url):
|
|||
return last_match_char, match_length, total_possible_match_length
|
||||
|
||||
|
||||
def _order_results(terms: list[str], results: list[Document], is_complete: bool) -> list[Document]:
|
||||
def order_results(terms: list[str], results: list[Document], is_complete: bool) -> list[Document]:
|
||||
if len(results) == 0:
|
||||
return []
|
||||
|
||||
|
@ -148,5 +148,5 @@ class Ranker:
|
|||
|
||||
class HeuristicRanker(Ranker):
|
||||
def order_results(self, terms, pages, is_complete):
|
||||
return _order_results(terms, pages, is_complete)
|
||||
return order_results(terms, pages, is_complete)
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue