Merge branch 'completion' into local

This commit is contained in:
Daoud Clarke 2022-08-13 10:08:37 +01:00
commit 6022d867a3
3 changed files with 40 additions and 24 deletions

View file

@@ -1,5 +1,6 @@
import logging
import sys
from itertools import islice
from mwmbl.indexer.paths import INDEX_PATH
from mwmbl.tinysearchengine.completer import Completer
@@ -9,14 +10,22 @@ from mwmbl.tinysearchengine.rank import HeuristicRanker
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
def clean(sequence):
    """Join the 'value' field of every element of *sequence* into one string."""
    parts = [element['value'] for element in sequence]
    return ''.join(parts)
def run():
with TinyIndex(Document, INDEX_PATH) as tiny_index:
completer = Completer()
ranker = HeuristicRanker(tiny_index, completer)
items = ranker.search('jasper fforde')
print()
if items:
for item in items:
print("Items", item)
for i, item in enumerate(islice(items, 10)):
print(f"{i + 1}. {item['url']}")
print(clean(item['title']))
print(clean(item['extract']))
print()
if __name__ == '__main__':

View file

@@ -1,7 +1,6 @@
import argparse
import logging
import os
import pickle
import sys
from multiprocessing import Process, Queue
from pathlib import Path
@@ -16,7 +15,7 @@ from mwmbl.indexer.paths import INDEX_NAME, BATCH_DIR_NAME
from mwmbl.tinysearchengine import search
from mwmbl.tinysearchengine.completer import Completer
from mwmbl.tinysearchengine.indexer import TinyIndex, Document, NUM_PAGES, PAGE_SIZE
from mwmbl.tinysearchengine.ltr_rank import LTRRanker
from mwmbl.tinysearchengine.rank import HeuristicRanker
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
@@ -59,9 +58,9 @@ def run():
completer = Completer()
with TinyIndex(item_factory=Document, index_path=index_path) as tiny_index:
# ranker = HeuristicRanker(tiny_index, completer)
model = pickle.load(open(MODEL_PATH, 'rb'))
ranker = LTRRanker(model, tiny_index, completer)
ranker = HeuristicRanker(tiny_index, completer)
# model = pickle.load(open(MODEL_PATH, 'rb'))
# ranker = LTRRanker(model, tiny_index, completer)
# Initialize FastApi instance
app = FastAPI()

View file

@@ -12,10 +12,12 @@ from mwmbl.tinysearchengine.indexer import TinyIndex, Document
logger = getLogger(__name__)
MATCH_SCORE_THRESHOLD = 0.0
SCORE_THRESHOLD = 0.0
LENGTH_PENALTY = 0.04
MATCH_EXPONENT = 2
DOMAIN_SCORE_SMOOTHING = 50
HTTPS_STRING = 'https://'
def _get_query_regex(terms, is_complete, is_url):
@@ -36,15 +38,20 @@ def _score_result(terms: list[str], result: Document, is_complete: bool):
features = get_features(terms, result.title, result.url, result.extract, result.score, is_complete)
length_penalty = math.e ** (-LENGTH_PENALTY * len(result.url))
score = (
4 * features['match_score_title']
+ features['match_score_extract'] +
2 * features['match_score_domain'] +
2 * features['match_score_domain_tokenized']
+ features['match_score_path']) * length_penalty * (features['domain_score'] + DOMAIN_SCORE_SMOOTHING) / 10
match_score = (4 * features['match_score_title'] + features['match_score_extract'] + 2 * features[
'match_score_domain'] + 2 * features['match_score_domain_tokenized'] + features['match_score_path'])
max_match_terms = max(features[f'match_terms_{name}']
for name in ['title', 'extract', 'domain', 'domain_tokenized', 'path'])
if max_match_terms <= len(terms) / 2:
return 0.0
if match_score > MATCH_SCORE_THRESHOLD:
return match_score * length_penalty * (features['domain_score'] + DOMAIN_SCORE_SMOOTHING) / 10
# best_match_score = max(features[f'match_score_{name}'] for name in ['title', 'extract', 'domain', 'domain_tokenized'])
# score = best_match_score * length_penalty * (features['domain_score'] + DOMAIN_SCORE_SMOOTHING)
return score
return 0.0
def score_match(last_match_char, match_length, total_possible_match_length):
@@ -108,7 +115,6 @@ def order_results(terms: list[str], results: list[Document], is_complete: bool)
if len(results) == 0:
return []
max_score = max(result.score for result in results)
results_and_scores = [(_score_result(terms, result, is_complete), result) for result in results]
ordered_results = sorted(results_and_scores, key=itemgetter(0), reverse=True)
filtered_results = [result for score, result in ordered_results if score > SCORE_THRESHOLD]
@@ -125,7 +131,7 @@ class Ranker:
pass
def search(self, s: str):
results, terms = self.get_results(s)
results, terms, _ = self.get_results(s)
is_complete = s.endswith(' ')
pattern = _get_query_regex(terms, is_complete, False)
@@ -149,19 +155,21 @@ class Ranker:
return formatted_results
def complete(self, q: str):
ordered_results, terms = self.get_results(q)
results = [item.title.replace("\n", "") + '' +
item.url.replace("\n", "") for item in ordered_results]
if len(results) == 0:
return []
return [q, results]
ordered_results, terms, completions = self.get_results(q)
filtered_completions = [c for c in completions if c != terms[-1]]
urls = [item.url[len(HTTPS_STRING):].rstrip('/') for item in ordered_results[:5]
if item.url.startswith(HTTPS_STRING) and all(term in item.url for term in terms)][:1]
completed = [' '.join(terms[:-1] + [t]) for t in filtered_completions]
return [q, urls + completed]
def get_results(self, q):
terms = [x.lower() for x in q.replace('.', ' ').split()]
is_complete = q.endswith(' ')
if len(terms) > 0 and not is_complete:
retrieval_terms = set(terms + self.completer.complete(terms[-1]))
completions = self.completer.complete(terms[-1])
retrieval_terms = set(terms + completions)
else:
completions = []
retrieval_terms = set(terms)
pages = []
@@ -176,7 +184,7 @@ class Ranker:
seen_items.add(item.title)
ordered_results = self.order_results(terms, pages, is_complete)
return ordered_results, terms
return ordered_results, terms, completions
class HeuristicRanker(Ranker):