Refactor ranking to allow a learning-to-rank (LTR) ranker

This commit is contained in:
Daoud Clarke 2022-03-27 22:32:44 +01:00
parent 94287cec01
commit 229819e57e
5 changed files with 51 additions and 8 deletions

View file

@ -7,7 +7,7 @@ import uvicorn
from mwmbl.tinysearchengine import create_app
from mwmbl.tinysearchengine.completer import Completer
from mwmbl.tinysearchengine.indexer import TinyIndex, Document
from mwmbl.tinysearchengine.rank import Ranker
from mwmbl.tinysearchengine.rank import HeuristicRanker
logging.basicConfig()
@ -37,7 +37,7 @@ def main():
completer = Completer(terms)
with TinyIndex(item_factory=Document, index_path=args.index) as tiny_index:
ranker = Ranker(tiny_index, completer)
ranker = HeuristicRanker(tiny_index, completer)
# Initialize FastApi instance
app = create_app.create(ranker)

View file

@ -10,7 +10,7 @@ from starlette.middleware.cors import CORSMiddleware
from mwmbl.tinysearchengine.completer import Completer
from mwmbl.tinysearchengine.hn_top_domains_filtered import DOMAINS
from mwmbl.tinysearchengine.indexer import TinyIndex, Document
from mwmbl.tinysearchengine.rank import Ranker
from mwmbl.tinysearchengine.rank import HeuristicRanker
logger = getLogger(__name__)
@ -18,7 +18,7 @@ logger = getLogger(__name__)
SCORE_THRESHOLD = 0.25
def create(ranker: Ranker):
def create(ranker: HeuristicRanker):
app = FastAPI()
# Allow CORS requests from any site

View file

@ -31,7 +31,7 @@ def get_match_features_as_series(item: Series):
last_match_char_extract, match_length_extract, total_possible_match_length_extract = get_match_features(
terms, item['extract'], True, False)
last_match_char_url, match_length_url, total_possible_match_length_url = get_match_features(
terms, item['title'], True, False)
terms, item['url'], True, False)
domain_score = get_domain_score(item['url'])
return Series({
'last_match_char_title': last_match_char_title,

View file

@ -0,0 +1,32 @@
from logging import getLogger

import numpy as np
from pandas import DataFrame
from sklearn.base import BaseEstimator

from mwmbl.tinysearchengine.completer import Completer
from mwmbl.tinysearchengine.indexer import Document, TinyIndex
from mwmbl.tinysearchengine.rank import Ranker
# Module-level logger, following the getLogger(__name__) convention used by the
# other modules in this package.
logger = getLogger(__name__)


class LTRRanker(Ranker):
    """Ranker that orders candidate pages by the predictions of a fitted model.

    The model is handed a DataFrame with one row per candidate page and the
    columns ``query``, ``url``, ``title``, ``extract`` and ``score``; pages are
    returned in descending order of the value the model predicts for each row.
    """

    def __init__(self, model: BaseEstimator, tiny_index: TinyIndex, completer: Completer):
        """Store ``model`` and forward the index and completer to ``Ranker``.

        Args:
            model: Estimator whose ``predict`` accepts the DataFrame built in
                ``order_results`` and returns one value per row.
            tiny_index: Index passed through to the base class.
            completer: Completer passed through to the base class.
        """
        super().__init__(tiny_index, completer)
        self.model = model

    def order_results(self, terms: list[str], pages: list[Document], is_complete: bool) -> list[Document]:
        """Return ``pages`` sorted from highest to lowest model prediction.

        Args:
            terms: Query terms; joined with single spaces to form the ``query``
                column.
            pages: Candidate documents to rank.
            is_complete: Accepted for interface compatibility; not used here.

        Returns:
            A new list containing the same documents as ``pages``, best first.
            An empty input yields an empty list without calling the model.
        """
        if not pages:
            return []

        query = ' '.join(terms)
        dataframe = DataFrame({
            'query': [query] * len(pages),
            'url': [page.url for page in pages],
            'title': [page.title for page in pages],
            'extract': [page.extract for page in pages],
            'score': [page.score for page in pages],
        })

        # Debug-level logging with lazy formatting: the frame is only rendered
        # to text when debug output is enabled, instead of being printed to
        # stdout on every query.
        logger.debug("Ordering results %s", dataframe)

        predictions = self.model.predict(dataframe)
        # argsort is ascending, so reverse it to put the highest prediction first.
        indexes = np.argsort(predictions)[::-1]
        return [pages[i] for i in indexes]

View file

@ -1,4 +1,5 @@
import re
from abc import abstractmethod
from logging import getLogger
from operator import itemgetter
from pathlib import Path
@ -70,7 +71,7 @@ def get_match_features(terms, result_string, is_complete, is_url):
return last_match_char, match_length, total_possible_match_length
def _order_results(terms: list[str], results: list[Document], is_complete: bool):
def _order_results(terms: list[str], results: list[Document], is_complete: bool) -> list[Document]:
if len(results) == 0:
return []
@ -86,11 +87,15 @@ class Ranker:
self.tiny_index = tiny_index
self.completer = completer
@abstractmethod
def order_results(self, terms, pages, is_complete):
    """Return ``pages`` in ranked order; concrete rankers must override this.

    Args:
        terms: The query terms being searched for.
        pages: The candidate documents to order.
        is_complete: Whether the query is considered complete.

    Returns:
        The ordered list of documents (in overriding implementations).

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    # NOTE(review): the enclosing class header visible here is a plain
    # ``class Ranker:`` with no ABC base, in which case @abstractmethod alone
    # does not block instantiation. Raising here stops a missing override from
    # silently returning None to the code that iterates over the results.
    raise NotImplementedError(
        f"{type(self).__name__} must implement order_results()"
    )
def search(self, s: str):
results, terms = self.get_results(s)
is_complete = s.endswith(' ')
pattern = _get_query_regex(terms, is_complete)
pattern = _get_query_regex(terms, is_complete, False)
formatted_results = []
for result in results:
formatted_result = {}
@ -137,5 +142,11 @@ class Ranker:
pages.append(item)
seen_items.add(item.title)
ordered_results = _order_results(terms, pages, is_complete)
ordered_results = self.order_results(terms, pages, is_complete)
return ordered_results, terms
class HeuristicRanker(Ranker):
    """Ranker whose ordering is delegated to the module-level ``_order_results``."""

    def order_results(self, terms, pages, is_complete):
        """Order ``pages`` for ``terms`` by calling ``_order_results`` unchanged."""
        ranked = _order_results(terms, pages, is_complete)
        return ranked