Refactor to allow LTR ranker
This commit is contained in:
parent
94287cec01
commit
229819e57e
5 changed files with 51 additions and 8 deletions
|
@ -7,7 +7,7 @@ import uvicorn
|
|||
from mwmbl.tinysearchengine import create_app
|
||||
from mwmbl.tinysearchengine.completer import Completer
|
||||
from mwmbl.tinysearchengine.indexer import TinyIndex, Document
|
||||
from mwmbl.tinysearchengine.rank import Ranker
|
||||
from mwmbl.tinysearchengine.rank import HeuristicRanker
|
||||
|
||||
logging.basicConfig()
|
||||
|
||||
|
@ -37,7 +37,7 @@ def main():
|
|||
completer = Completer(terms)
|
||||
|
||||
with TinyIndex(item_factory=Document, index_path=args.index) as tiny_index:
|
||||
ranker = Ranker(tiny_index, completer)
|
||||
ranker = HeuristicRanker(tiny_index, completer)
|
||||
|
||||
# Initialize FastApi instance
|
||||
app = create_app.create(ranker)
|
||||
|
|
|
@ -10,7 +10,7 @@ from starlette.middleware.cors import CORSMiddleware
|
|||
from mwmbl.tinysearchengine.completer import Completer
|
||||
from mwmbl.tinysearchengine.hn_top_domains_filtered import DOMAINS
|
||||
from mwmbl.tinysearchengine.indexer import TinyIndex, Document
|
||||
from mwmbl.tinysearchengine.rank import Ranker
|
||||
from mwmbl.tinysearchengine.rank import HeuristicRanker
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -18,7 +18,7 @@ logger = getLogger(__name__)
|
|||
SCORE_THRESHOLD = 0.25
|
||||
|
||||
|
||||
def create(ranker: Ranker):
|
||||
def create(ranker: HeuristicRanker):
|
||||
app = FastAPI()
|
||||
|
||||
# Allow CORS requests from any site
|
||||
|
|
|
@ -31,7 +31,7 @@ def get_match_features_as_series(item: Series):
|
|||
last_match_char_extract, match_length_extract, total_possible_match_length_extract = get_match_features(
|
||||
terms, item['extract'], True, False)
|
||||
last_match_char_url, match_length_url, total_possible_match_length_url = get_match_features(
|
||||
terms, item['title'], True, False)
|
||||
terms, item['url'], True, False)
|
||||
domain_score = get_domain_score(item['url'])
|
||||
return Series({
|
||||
'last_match_char_title': last_match_char_title,
|
||||
|
|
32
mwmbl/tinysearchengine/ltr_rank.py
Normal file
32
mwmbl/tinysearchengine/ltr_rank.py
Normal file
|
@ -0,0 +1,32 @@
|
|||
import numpy as np
|
||||
from pandas import DataFrame
|
||||
from sklearn.base import BaseEstimator
|
||||
|
||||
from mwmbl.tinysearchengine.completer import Completer
|
||||
from mwmbl.tinysearchengine.indexer import Document, TinyIndex
|
||||
from mwmbl.tinysearchengine.rank import Ranker
|
||||
|
||||
|
||||
class LTRRanker(Ranker):
    """Ranker that orders candidate pages with a learning-to-rank model.

    The model is handed a pandas ``DataFrame`` with one row per candidate
    page and the columns ``query``, ``url``, ``title``, ``extract`` and
    ``score``; pages are returned from highest to lowest predicted value.
    """

    def __init__(self, model: BaseEstimator, tiny_index: TinyIndex, completer: Completer):
        """Store the model and forward the index and completer to ``Ranker``.

        Args:
            model: Estimator whose ``predict`` is called with the feature
                frame built in ``order_results``.
                NOTE(review): presumably already fitted -- confirm at the
                call site.
            tiny_index: Index passed through to the base class.
            completer: Completer passed through to the base class.
        """
        super().__init__(tiny_index, completer)
        self.model = model

    def order_results(self, terms: list[str], pages: list[Document], is_complete: bool) -> list[Document]:
        """Return ``pages`` sorted by descending model prediction.

        Args:
            terms: Query terms; joined with single spaces to form the
                ``query`` column.
            pages: Candidate documents to rank.
            is_complete: Accepted for interface compatibility with the
                base class; not used by this ranker.

        Returns:
            A new list containing the same documents, best prediction first.
            An empty input yields an empty list without calling the model.
        """
        # Function-scope import keeps this block self-contained; the module
        # does not import logging at the top level.
        from logging import getLogger

        if not pages:
            return []

        query = ' '.join(terms)
        dataframe = DataFrame({
            'query': [query] * len(pages),
            'url': [page.url for page in pages],
            'title': [page.title for page in pages],
            'extract': [page.extract for page in pages],
            'score': [page.score for page in pages],
        })

        # Lazy %-formatting: the frame is only rendered when DEBUG is
        # enabled, instead of being printed to stdout on every query.
        getLogger(__name__).debug("Ordering results %s", dataframe)

        predictions = self.model.predict(dataframe)
        # argsort is ascending, so reverse it to put the best page first.
        indexes = np.argsort(predictions)[::-1]
        return [pages[i] for i in indexes]
|
|
@ -1,4 +1,5 @@
|
|||
import re
|
||||
from abc import abstractmethod
|
||||
from logging import getLogger
|
||||
from operator import itemgetter
|
||||
from pathlib import Path
|
||||
|
@ -70,7 +71,7 @@ def get_match_features(terms, result_string, is_complete, is_url):
|
|||
return last_match_char, match_length, total_possible_match_length
|
||||
|
||||
|
||||
def _order_results(terms: list[str], results: list[Document], is_complete: bool):
|
||||
def _order_results(terms: list[str], results: list[Document], is_complete: bool) -> list[Document]:
|
||||
if len(results) == 0:
|
||||
return []
|
||||
|
||||
|
@ -86,11 +87,15 @@ class Ranker:
|
|||
self.tiny_index = tiny_index
|
||||
self.completer = completer
|
||||
|
||||
@abstractmethod
|
||||
def order_results(self, terms, pages, is_complete):
|
||||
pass
|
||||
|
||||
def search(self, s: str):
|
||||
results, terms = self.get_results(s)
|
||||
|
||||
is_complete = s.endswith(' ')
|
||||
pattern = _get_query_regex(terms, is_complete)
|
||||
pattern = _get_query_regex(terms, is_complete, False)
|
||||
formatted_results = []
|
||||
for result in results:
|
||||
formatted_result = {}
|
||||
|
@ -137,5 +142,11 @@ class Ranker:
|
|||
pages.append(item)
|
||||
seen_items.add(item.title)
|
||||
|
||||
ordered_results = _order_results(terms, pages, is_complete)
|
||||
ordered_results = self.order_results(terms, pages, is_complete)
|
||||
return ordered_results, terms
|
||||
|
||||
|
||||
class HeuristicRanker(Ranker):
    """Ranker whose ordering is delegated to the module's ``_order_results``."""

    def order_results(self, terms, pages, is_complete):
        """Order ``pages`` for ``terms`` using ``_order_results``.

        All three arguments are forwarded unchanged and the helper's return
        value is handed back as-is.
        """
        ranked_pages = _order_results(terms, pages, is_complete)
        return ranked_pages
|
||||
|
||||
|
|
Loading…
Reference in a new issue