Browse Source

Make order_results public

Daoud Clarke 3 years ago
parent
commit
87d8b40cad
2 changed files with 12 additions and 9 deletions
  1. +10 -7
      mwmbl/tinysearchengine/ltr_rank.py
  2. +2 -2
      mwmbl/tinysearchengine/rank.py

+ 10 - 7
mwmbl/tinysearchengine/ltr_rank.py

@@ -4,29 +4,32 @@ from sklearn.base import BaseEstimator
 
 from mwmbl.tinysearchengine.completer import Completer
 from mwmbl.tinysearchengine.indexer import Document, TinyIndex
-from mwmbl.tinysearchengine.rank import Ranker
+from mwmbl.tinysearchengine.rank import Ranker, order_results
 
 
 class LTRRanker(Ranker):
     def __init__(self, model: BaseEstimator, tiny_index: TinyIndex, completer: Completer):
         super().__init__(tiny_index, completer)
         self.model = model
+        self.top_n = 20
 
     def order_results(self, terms, pages: list[Document], is_complete):
         if len(pages) == 0:
             return []
 
+        top_pages = order_results(terms, pages, is_complete)[:self.top_n]
+
         query = ' '.join(terms)
         data = {
-            'query': [query] * len(pages),
-            'url': [page.url for page in pages],
-            'title': [page.title for page in pages],
-            'extract': [page.extract for page in pages],
-            'score': [page.score for page in pages],
+            'query': [query] * len(top_pages),
+            'url': [page.url for page in top_pages],
+            'title': [page.title for page in top_pages],
+            'extract': [page.extract for page in top_pages],
+            'score': [page.score for page in top_pages],
         }
 
         dataframe = DataFrame(data)
         print("Ordering results", dataframe)
         predictions = self.model.predict(dataframe)
         indexes = np.argsort(predictions)[::-1]
-        return [pages[i] for i in indexes]
+        return [top_pages[i] for i in indexes]

+ 2 - 2
mwmbl/tinysearchengine/rank.py

@@ -71,7 +71,7 @@ def get_match_features(terms, result_string, is_complete, is_url):
     return last_match_char, match_length, total_possible_match_length
 
 
-def _order_results(terms: list[str], results: list[Document], is_complete: bool) -> list[Document]:
+def order_results(terms: list[str], results: list[Document], is_complete: bool) -> list[Document]:
     if len(results) == 0:
         return []
 
@@ -148,5 +148,5 @@ class Ranker:
 
 class HeuristicRanker(Ranker):
     def order_results(self, terms, pages, is_complete):
-        return _order_results(terms, pages, is_complete)
+        return order_results(terms, pages, is_complete)