فهرست منبع

Refactor to allow LTR ranker

Daoud Clarke 3 سال پیش
والد
کامیت
229819e57e

+ 2 - 2
mwmbl/tinysearchengine/app.py

@@ -7,7 +7,7 @@ import uvicorn
 from mwmbl.tinysearchengine import create_app
 from mwmbl.tinysearchengine import create_app
 from mwmbl.tinysearchengine.completer import Completer
 from mwmbl.tinysearchengine.completer import Completer
 from mwmbl.tinysearchengine.indexer import TinyIndex, Document
 from mwmbl.tinysearchengine.indexer import TinyIndex, Document
-from mwmbl.tinysearchengine.rank import Ranker
+from mwmbl.tinysearchengine.rank import HeuristicRanker
 
 
 logging.basicConfig()
 logging.basicConfig()
 
 
@@ -37,7 +37,7 @@ def main():
     completer = Completer(terms)
     completer = Completer(terms)
 
 
     with TinyIndex(item_factory=Document, index_path=args.index) as tiny_index:
     with TinyIndex(item_factory=Document, index_path=args.index) as tiny_index:
-        ranker = Ranker(tiny_index, completer)
+        ranker = HeuristicRanker(tiny_index, completer)
 
 
         # Initialize FastApi instance
         # Initialize FastApi instance
         app = create_app.create(ranker)
         app = create_app.create(ranker)

+ 2 - 2
mwmbl/tinysearchengine/create_app.py

@@ -10,7 +10,7 @@ from starlette.middleware.cors import CORSMiddleware
 from mwmbl.tinysearchengine.completer import Completer
 from mwmbl.tinysearchengine.completer import Completer
 from mwmbl.tinysearchengine.hn_top_domains_filtered import DOMAINS
 from mwmbl.tinysearchengine.hn_top_domains_filtered import DOMAINS
 from mwmbl.tinysearchengine.indexer import TinyIndex, Document
 from mwmbl.tinysearchengine.indexer import TinyIndex, Document
-from mwmbl.tinysearchengine.rank import Ranker
+from mwmbl.tinysearchengine.rank import HeuristicRanker
 
 
 logger = getLogger(__name__)
 logger = getLogger(__name__)
 
 
@@ -18,7 +18,7 @@ logger = getLogger(__name__)
 SCORE_THRESHOLD = 0.25
 SCORE_THRESHOLD = 0.25
 
 
 
 
-def create(ranker: Ranker):
+def create(ranker: HeuristicRanker):
     app = FastAPI()
     app = FastAPI()
     
     
     # Allow CORS requests from any site
     # Allow CORS requests from any site

+ 1 - 1
mwmbl/tinysearchengine/ltr.py

@@ -31,7 +31,7 @@ def get_match_features_as_series(item: Series):
     last_match_char_extract, match_length_extract, total_possible_match_length_extract = get_match_features(
     last_match_char_extract, match_length_extract, total_possible_match_length_extract = get_match_features(
         terms, item['extract'], True, False)
         terms, item['extract'], True, False)
     last_match_char_url, match_length_url, total_possible_match_length_url = get_match_features(
     last_match_char_url, match_length_url, total_possible_match_length_url = get_match_features(
-        terms, item['title'], True, False)
+        terms, item['url'], True, False)
     domain_score = get_domain_score(item['url'])
     domain_score = get_domain_score(item['url'])
     return Series({
     return Series({
         'last_match_char_title': last_match_char_title,
         'last_match_char_title': last_match_char_title,

+ 32 - 0
mwmbl/tinysearchengine/ltr_rank.py

@@ -0,0 +1,32 @@
+import numpy as np
+from pandas import DataFrame
+from sklearn.base import BaseEstimator
+
+from mwmbl.tinysearchengine.completer import Completer
+from mwmbl.tinysearchengine.indexer import Document, TinyIndex
+from mwmbl.tinysearchengine.rank import Ranker
+
+
+class LTRRanker(Ranker):
+    def __init__(self, model: BaseEstimator, tiny_index: TinyIndex, completer: Completer):
+        super().__init__(tiny_index, completer)
+        self.model = model
+
+    def order_results(self, terms, pages: list[Document], is_complete):
+        if len(pages) == 0:
+            return []
+
+        query = ' '.join(terms)
+        data = {
+            'query': [query] * len(pages),
+            'url': [page.url for page in pages],
+            'title': [page.title for page in pages],
+            'extract': [page.extract for page in pages],
+            'score': [page.score for page in pages],
+        }
+
+        dataframe = DataFrame(data)
+        print("Ordering results", dataframe)
+        predictions = self.model.predict(dataframe)
+        indexes = np.argsort(predictions)[::-1]
+        return [pages[i] for i in indexes]

+ 14 - 3
mwmbl/tinysearchengine/rank.py

@@ -1,4 +1,5 @@
 import re
 import re
+from abc import abstractmethod
 from logging import getLogger
 from logging import getLogger
 from operator import itemgetter
 from operator import itemgetter
 from pathlib import Path
 from pathlib import Path
@@ -70,7 +71,7 @@ def get_match_features(terms, result_string, is_complete, is_url):
     return last_match_char, match_length, total_possible_match_length
     return last_match_char, match_length, total_possible_match_length
 
 
 
 
-def _order_results(terms: list[str], results: list[Document], is_complete: bool):
+def _order_results(terms: list[str], results: list[Document], is_complete: bool) -> list[Document]:
     if len(results) == 0:
     if len(results) == 0:
         return []
         return []
 
 
@@ -86,11 +87,15 @@ class Ranker:
         self.tiny_index = tiny_index
         self.tiny_index = tiny_index
         self.completer = completer
         self.completer = completer
 
 
+    @abstractmethod
+    def order_results(self, terms, pages, is_complete):
+        pass
+
     def search(self, s: str):
     def search(self, s: str):
         results, terms = self.get_results(s)
         results, terms = self.get_results(s)
 
 
         is_complete = s.endswith(' ')
         is_complete = s.endswith(' ')
-        pattern = _get_query_regex(terms, is_complete)
+        pattern = _get_query_regex(terms, is_complete, False)
         formatted_results = []
         formatted_results = []
         for result in results:
         for result in results:
             formatted_result = {}
             formatted_result = {}
@@ -137,5 +142,11 @@ class Ranker:
                             pages.append(item)
                             pages.append(item)
                             seen_items.add(item.title)
                             seen_items.add(item.title)
 
 
-        ordered_results = _order_results(terms, pages, is_complete)
+        ordered_results = self.order_results(terms, pages, is_complete)
         return ordered_results, terms
         return ordered_results, terms
+
+
+class HeuristicRanker(Ranker):
+    def order_results(self, terms, pages, is_complete):
+        return _order_results(terms, pages, is_complete)
+