Browse Source

New LTR model trained on more data

Daoud Clarke committed 3 years ago
parent
commit
c1d361c0a0

+ 23 - 0
analyse/search.py

@@ -0,0 +1,23 @@
+import logging
+import sys
+
+from mwmbl.indexer.paths import INDEX_PATH
+from mwmbl.tinysearchengine.completer import Completer
+from mwmbl.tinysearchengine.indexer import TinyIndex, Document
+from mwmbl.tinysearchengine.rank import HeuristicRanker
+
+logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
+
+
+def run():
+    with TinyIndex(Document, INDEX_PATH) as tiny_index:
+        completer = Completer()
+        ranker = HeuristicRanker(tiny_index, completer)
+        items = ranker.search('jasper fforde')
+        if items:
+            for item in items:
+                print("Items", item)
+
+
+if __name__ == '__main__':
+    run()

+ 8 - 3
mwmbl/main.py

@@ -1,13 +1,13 @@
 import argparse
 import logging
 import os
+import pickle
 import sys
 from multiprocessing import Process, Queue
 from pathlib import Path
 
 import uvicorn
 from fastapi import FastAPI
-
 from mwmbl import background
 from mwmbl.crawler import app as crawler
 from mwmbl.indexer.batch_cache import BatchCache
@@ -15,11 +15,14 @@ from mwmbl.indexer.paths import INDEX_NAME, BATCH_DIR_NAME
 from mwmbl.tinysearchengine import search
 from mwmbl.tinysearchengine.completer import Completer
 from mwmbl.tinysearchengine.indexer import TinyIndex, Document, NUM_PAGES, PAGE_SIZE
-from mwmbl.tinysearchengine.rank import HeuristicRanker
+from mwmbl.tinysearchengine.ltr_rank import LTRRanker
 
 logging.basicConfig(stream=sys.stdout, level=logging.INFO)
 
 
+MODEL_PATH = Path(__file__).parent / 'resources' / 'model.pickle'
+
+
 def setup_args():
     parser = argparse.ArgumentParser(description="mwmbl-tinysearchengine")
     parser.add_argument("--data", help="Path to the tinysearchengine index file", default="/app/storage/")
@@ -55,7 +58,9 @@ def run():
     completer = Completer()
 
     with TinyIndex(item_factory=Document, index_path=index_path) as tiny_index:
-        ranker = HeuristicRanker(tiny_index, completer)
+        # ranker = HeuristicRanker(tiny_index, completer)
+        model = pickle.load(open(MODEL_PATH, 'rb'))
+        ranker = LTRRanker(model, tiny_index, completer)
 
         # Initialize FastApi instance
         app = FastAPI()

BIN
mwmbl/resources/model.pickle


+ 1 - 0
mwmbl/tinysearchengine/ltr.py

@@ -37,6 +37,7 @@ def get_match_features_as_series(item: Series):
     features['num_terms'] = len(terms)
     features['num_chars'] = len(' '.join(terms))
     features['domain_score'] = get_domain_score(item['url'])
+    features['url_length'] = len(item['url'])
     features['item_score'] = item['score']
     return Series(features)
 

+ 9 - 4
mwmbl/tinysearchengine/rank.py

@@ -13,7 +13,8 @@ logger = getLogger(__name__)
 
 
 SCORE_THRESHOLD = 0.0
-LENGTH_PENALTY=0.01
+LENGTH_PENALTY = 0.01
+MATCH_EXPONENT = 1.5
 
 
 def _get_query_regex(terms, is_complete, is_url):
@@ -37,11 +38,13 @@ def _score_result(terms, result: Document, is_complete: bool):
     domain = parsed_url.netloc
     path = parsed_url.path
     string_scores = []
+    logger.debug(f"Item: {result}")
     for result_string, is_url in [(result.title, False), (result.extract, False), (domain, True), (domain, False), (path, True)]:
         last_match_char, match_length, total_possible_match_length = get_match_features(
             terms, result_string, is_complete, is_url)
 
         new_score = score_match(last_match_char, match_length, total_possible_match_length)
+        logger.debug(f"Item score: {new_score}, result {result_string}")
         string_scores.append(new_score)
     title_score, extract_score, domain_score, domain_split_score, path_score = string_scores
 
@@ -55,7 +58,8 @@ def _score_result(terms, result: Document, is_complete: bool):
 
 
 def score_match(last_match_char, match_length, total_possible_match_length):
-    return (match_length + 1. / last_match_char) / (total_possible_match_length + 1)
+    # return (match_length + 1. / last_match_char) / (total_possible_match_length + 1)
+    return MATCH_EXPONENT ** (match_length - total_possible_match_length) / last_match_char
 
 
 def get_domain_score(url):
@@ -66,6 +70,7 @@ def get_domain_score(url):
 
 def get_match_features(terms, result_string, is_complete, is_url):
     query_regex = _get_query_regex(terms, is_complete, is_url)
+    print("Result string", result_string)
     matches = list(re.finditer(query_regex, result_string, flags=re.IGNORECASE))
     match_strings = {x.group(0).lower() for x in matches}
     match_length = sum(len(x) for x in match_strings)
@@ -138,9 +143,9 @@ class Ranker:
         terms = [x.lower() for x in q.replace('.', ' ').split()]
         is_complete = q.endswith(' ')
         if len(terms) > 0 and not is_complete:
-            retrieval_terms = terms + self.completer.complete(terms[-1])
+            retrieval_terms = set(terms + self.completer.complete(terms[-1]))
         else:
-            retrieval_terms = terms
+            retrieval_terms = set(terms)
 
         pages = []
         seen_items = set()