Browse Source

New LTR model trained on more data

Daoud Clarke 3 years ago
parent
commit
c1d361c0a0

+ 23 - 0
analyse/search.py

@@ -0,0 +1,23 @@
+import logging
+import sys
+
+from mwmbl.indexer.paths import INDEX_PATH
+from mwmbl.tinysearchengine.completer import Completer
+from mwmbl.tinysearchengine.indexer import TinyIndex, Document
+from mwmbl.tinysearchengine.rank import HeuristicRanker
+
+logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
+
+
+def run():
+    with TinyIndex(Document, INDEX_PATH) as tiny_index:
+        completer = Completer()
+        ranker = HeuristicRanker(tiny_index, completer)
+        items = ranker.search('jasper fforde')
+        if items:
+            for item in items:
+                print("Items", item)
+
+
+if __name__ == '__main__':
+    run()

+ 8 - 3
mwmbl/main.py

@@ -1,13 +1,13 @@
 import argparse
 import logging
 import os
+import pickle
 import sys
 from multiprocessing import Process, Queue
 from pathlib import Path
 
 import uvicorn
 from fastapi import FastAPI
-
 from mwmbl import background
 from mwmbl.crawler import app as crawler
 from mwmbl.indexer.batch_cache import BatchCache
@@ -15,11 +15,14 @@ from mwmbl.indexer.paths import INDEX_NAME, BATCH_DIR_NAME
 from mwmbl.tinysearchengine import search
 from mwmbl.tinysearchengine.completer import Completer
 from mwmbl.tinysearchengine.indexer import TinyIndex, Document, NUM_PAGES, PAGE_SIZE
-from mwmbl.tinysearchengine.rank import HeuristicRanker
+from mwmbl.tinysearchengine.ltr_rank import LTRRanker
 
 logging.basicConfig(stream=sys.stdout, level=logging.INFO)
 
 
+MODEL_PATH = Path(__file__).parent / 'resources' / 'model.pickle'
+
+
 def setup_args():
     parser = argparse.ArgumentParser(description="mwmbl-tinysearchengine")
     parser.add_argument("--data", help="Path to the tinysearchengine index file", default="/app/storage/")
@@ -55,7 +58,9 @@ def run():
     completer = Completer()
 
     with TinyIndex(item_factory=Document, index_path=index_path) as tiny_index:
-        ranker = HeuristicRanker(tiny_index, completer)
+        # ranker = HeuristicRanker(tiny_index, completer)
+        model = pickle.load(open(MODEL_PATH, 'rb'))
+        ranker = LTRRanker(model, tiny_index, completer)
 
         # Initialize FastApi instance
         app = FastAPI()

BIN
mwmbl/resources/model.pickle


+ 1 - 0
mwmbl/tinysearchengine/ltr.py

@@ -37,6 +37,7 @@ def get_match_features_as_series(item: Series):
     features['num_terms'] = len(terms)
     features['num_chars'] = len(' '.join(terms))
     features['domain_score'] = get_domain_score(item['url'])
+    features['url_length'] = len(item['url'])
     features['item_score'] = item['score']
     return Series(features)
 

+ 9 - 4
mwmbl/tinysearchengine/rank.py

@@ -13,7 +13,8 @@ logger = getLogger(__name__)
 
 
 SCORE_THRESHOLD = 0.0
-LENGTH_PENALTY=0.01
+LENGTH_PENALTY = 0.01
+MATCH_EXPONENT = 1.5
 
 
 def _get_query_regex(terms, is_complete, is_url):
@@ -37,11 +38,13 @@ def _score_result(terms, result: Document, is_complete: bool):
     domain = parsed_url.netloc
     path = parsed_url.path
     string_scores = []
+    logger.debug(f"Item: {result}")
     for result_string, is_url in [(result.title, False), (result.extract, False), (domain, True), (domain, False), (path, True)]:
         last_match_char, match_length, total_possible_match_length = get_match_features(
             terms, result_string, is_complete, is_url)
 
         new_score = score_match(last_match_char, match_length, total_possible_match_length)
+        logger.debug(f"Item score: {new_score}, result {result_string}")
         string_scores.append(new_score)
     title_score, extract_score, domain_score, domain_split_score, path_score = string_scores
 
@@ -55,7 +58,8 @@ def _score_result(terms, result: Document, is_complete: bool):
 
 
 def score_match(last_match_char, match_length, total_possible_match_length):
-    return (match_length + 1. / last_match_char) / (total_possible_match_length + 1)
+    # return (match_length + 1. / last_match_char) / (total_possible_match_length + 1)
+    return MATCH_EXPONENT ** (match_length - total_possible_match_length) / last_match_char
 
 
 def get_domain_score(url):
@@ -66,6 +70,7 @@ def get_domain_score(url):
 
 def get_match_features(terms, result_string, is_complete, is_url):
     query_regex = _get_query_regex(terms, is_complete, is_url)
+    print("Result string", result_string)
     matches = list(re.finditer(query_regex, result_string, flags=re.IGNORECASE))
     match_strings = {x.group(0).lower() for x in matches}
     match_length = sum(len(x) for x in match_strings)
@@ -138,9 +143,9 @@ class Ranker:
         terms = [x.lower() for x in q.replace('.', ' ').split()]
         is_complete = q.endswith(' ')
         if len(terms) > 0 and not is_complete:
-            retrieval_terms = terms + self.completer.complete(terms[-1])
+            retrieval_terms = set(terms + self.completer.complete(terms[-1]))
         else:
-            retrieval_terms = terms
+            retrieval_terms = set(terms)
 
         pages = []
         seen_items = set()