From e03e379ccf08f30860085d51a8ffa3bf4e4f3d15 Mon Sep 17 00:00:00 2001 From: Daoud Clarke Date: Wed, 9 Feb 2022 22:43:47 +0000 Subject: [PATCH] Refactor to enable easier evaluation --- analyse/performance.py | 2 +- mwmbl/tinysearchengine/app.py | 5 +- mwmbl/tinysearchengine/create_app.py | 93 +-------------------- mwmbl/tinysearchengine/rank.py | 119 +++++++++++++++++++++++++++ 4 files changed, 128 insertions(+), 91 deletions(-) create mode 100644 mwmbl/tinysearchengine/rank.py diff --git a/analyse/performance.py b/analyse/performance.py index 4a675d4..0bac7f9 100644 --- a/analyse/performance.py +++ b/analyse/performance.py @@ -36,7 +36,7 @@ def query_test(): print(f"Got {len(titles_and_urls)} titles and URLs") tiny_index = TinyIndex(Document, TEST_INDEX_PATH, TEST_NUM_PAGES, TEST_PAGE_SIZE) - app = create_app.create(tiny_index) + app = create_app.create() client = TestClient(app) start = datetime.now() diff --git a/mwmbl/tinysearchengine/app.py b/mwmbl/tinysearchengine/app.py index d0cde72..dda7702 100644 --- a/mwmbl/tinysearchengine/app.py +++ b/mwmbl/tinysearchengine/app.py @@ -8,6 +8,7 @@ from mwmbl.tinysearchengine import create_app from mwmbl.tinysearchengine.completer import Completer from mwmbl.tinysearchengine.indexer import TinyIndex, NUM_PAGES, PAGE_SIZE, Document from mwmbl.tinysearchengine.config import parse_config_file +from mwmbl.tinysearchengine.rank import Ranker logging.basicConfig() @@ -35,8 +36,10 @@ def main(): terms = pd.read_csv(config.terms_path) completer = Completer(terms) + ranker = Ranker(tiny_index, completer) + # Initialize FastApi instance - app = create_app.create(tiny_index, completer) + app = create_app.create(ranker) # Initialize uvicorn server using global app instance and server config params uvicorn.run(app, **config.server_config.dict()) diff --git a/mwmbl/tinysearchengine/create_app.py b/mwmbl/tinysearchengine/create_app.py index 8e54df1..5d08e30 100644 --- a/mwmbl/tinysearchengine/create_app.py +++ b/mwmbl/tinysearchengine/create_app.py @@ -10,6 +10,7 @@ from starlette.middleware.cors import CORSMiddleware from mwmbl.tinysearchengine.completer import Completer from mwmbl.tinysearchengine.hn_top_domains_filtered import DOMAINS from mwmbl.tinysearchengine.indexer import TinyIndex, Document +from mwmbl.tinysearchengine.rank import Ranker logger = getLogger(__name__) @@ -17,7 +18,7 @@ logger = getLogger(__name__) SCORE_THRESHOLD = 0.25 -def create(tiny_index: TinyIndex, completer: Completer): +def create(ranker: Ranker): app = FastAPI() # Allow CORS requests from any site @@ -29,96 +30,10 @@ def create(tiny_index: TinyIndex, completer: Completer): @app.get("/search") def search(s: str): - results, terms = get_results(s) - - is_complete = s.endswith(' ') - pattern = get_query_regex(terms, is_complete) - formatted_results = [] - for result in results: - formatted_result = {} - for content_type, content in [('title', result.title), ('extract', result.extract)]: - matches = re.finditer(pattern, content, re.IGNORECASE) - all_spans = [0] + sum((list(m.span()) for m in matches), []) + [len(content)] - content_result = [] - for i in range(len(all_spans) - 1): - is_bold = i % 2 == 1 - start = all_spans[i] - end = all_spans[i + 1] - content_result.append({'value': content[start:end], 'is_bold': is_bold}) - formatted_result[content_type] = content_result - formatted_result['url'] = result.url - formatted_results.append(formatted_result) - - logger.info("Return results: %r", formatted_results) - return formatted_results - - def get_query_regex(terms, is_complete): 
- if not terms: - return '' - - if is_complete: - term_patterns = [rf'\b{term}\b' for term in terms] - else: - term_patterns = [rf'\b{term}\b' for term in terms[:-1]] + [rf'\b{terms[-1]}'] - pattern = '|'.join(term_patterns) - return pattern - - def score_result(terms, result: Document, is_complete: bool): - domain = urlparse(result.url).netloc - domain_score = DOMAINS.get(domain, 0.0) - - result_string = f"{result.title.strip()} {result.extract.strip()}" - query_regex = get_query_regex(terms, is_complete) - matches = list(re.finditer(query_regex, result_string, flags=re.IGNORECASE)) - match_strings = {x.group(0).lower() for x in matches} - match_length = sum(len(x) for x in match_strings) - - last_match_char = 1 - seen_matches = set() - for match in matches: - value = match.group(0).lower() - if value not in seen_matches: - last_match_char = match.span()[1] - seen_matches.add(value) - - total_possible_match_length = sum(len(x) for x in terms) - score = 0.1*domain_score + 0.9*(match_length + 1./last_match_char) / (total_possible_match_length + 1) - return score - - def order_results(terms: list[str], results: list[Document], is_complete: bool): - results_and_scores = [(score_result(terms, result, is_complete), result) for result in results] - ordered_results = sorted(results_and_scores, key=itemgetter(0), reverse=True) - filtered_results = [result for score, result in ordered_results if score > SCORE_THRESHOLD] - return filtered_results + return ranker.search(s) @app.get("/complete") def complete(q: str): - ordered_results, terms = get_results(q) - results = [item.title.replace("\n", "") + ' — ' + - item.url.replace("\n", "") for item in ordered_results] - if len(results) == 0: - return [] - return [q, results] + return ranker.complete(q) - def get_results(q): - terms = [x.lower() for x in q.replace('.', ' ').split()] - is_complete = q.endswith(' ') - if len(terms) > 0 and not is_complete: - retrieval_terms = terms[:-1] + completer.complete(terms[-1]) - else: - retrieval_terms = terms - - pages = [] - seen_items = set() - for term in retrieval_terms: - items = tiny_index.retrieve(term) - if items is not None: - for item in items: - if term in item.title.lower() or term in item.extract.lower(): - if item.title not in seen_items: - pages.append(item) - seen_items.add(item.title) - - ordered_results = order_results(terms, pages, is_complete) - return ordered_results, terms return app diff --git a/mwmbl/tinysearchengine/rank.py b/mwmbl/tinysearchengine/rank.py new file mode 100644 index 0000000..8ef4942 --- /dev/null +++ b/mwmbl/tinysearchengine/rank.py @@ -0,0 +1,119 @@ +import re +from logging import getLogger +from operator import itemgetter +from pathlib import Path +from urllib.parse import urlparse + +from fastapi import FastAPI +from starlette.middleware.cors import CORSMiddleware + +from mwmbl.tinysearchengine.completer import Completer +from mwmbl.tinysearchengine.hn_top_domains_filtered import DOMAINS +from mwmbl.tinysearchengine.indexer import TinyIndex, Document + +logger = getLogger(__name__) + + +SCORE_THRESHOLD = 0.25 + + +def _get_query_regex(terms, is_complete): + if not terms: + return '' + + if is_complete: + term_patterns = [rf'\b{term}\b' for term in terms] + else: + term_patterns = [rf'\b{term}\b' for term in terms[:-1]] + [rf'\b{terms[-1]}'] + pattern = '|'.join(term_patterns) + return pattern + + +def _score_result(terms, result: Document, is_complete: bool): + domain = urlparse(result.url).netloc + domain_score = DOMAINS.get(domain, 0.0) + + result_string = 
f"{result.title.strip()} {result.extract.strip()}" + query_regex = _get_query_regex(terms, is_complete) + matches = list(re.finditer(query_regex, result_string, flags=re.IGNORECASE)) + match_strings = {x.group(0).lower() for x in matches} + match_length = sum(len(x) for x in match_strings) + + last_match_char = 1 + seen_matches = set() + for match in matches: + value = match.group(0).lower() + if value not in seen_matches: + last_match_char = match.span()[1] + seen_matches.add(value) + + total_possible_match_length = sum(len(x) for x in terms) + score = 0.1*domain_score + 0.9*(match_length + 1./last_match_char) / (total_possible_match_length + 1) + return score + + +def _order_results(terms: list[str], results: list[Document], is_complete: bool): + results_and_scores = [(_score_result(terms, result, is_complete), result) for result in results] + ordered_results = sorted(results_and_scores, key=itemgetter(0), reverse=True) + filtered_results = [result for score, result in ordered_results if score > SCORE_THRESHOLD] + return filtered_results + + +class Ranker: + def __init__(self, tiny_index: TinyIndex, completer: Completer): + self.tiny_index = tiny_index + self.completer = completer + + def search(self, s: str): + results, terms = self._get_results(s) + + is_complete = s.endswith(' ') + pattern = _get_query_regex(terms, is_complete) + formatted_results = [] + for result in results: + formatted_result = {} + for content_type, content in [('title', result.title), ('extract', result.extract)]: + matches = re.finditer(pattern, content, re.IGNORECASE) + all_spans = [0] + sum((list(m.span()) for m in matches), []) + [len(content)] + content_result = [] + for i in range(len(all_spans) - 1): + is_bold = i % 2 == 1 + start = all_spans[i] + end = all_spans[i + 1] + content_result.append({'value': content[start:end], 'is_bold': is_bold}) + formatted_result[content_type] = content_result + formatted_result['url'] = result.url + formatted_results.append(formatted_result) + + logger.info("Return results: %r", formatted_results) + return formatted_results + + def complete(self, q: str): + ordered_results, terms = self._get_results(q) + results = [item.title.replace("\n", "") + ' — ' + + item.url.replace("\n", "") for item in ordered_results] + if len(results) == 0: + return [] + return [q, results] + + def _get_results(self, q): + terms = [x.lower() for x in q.replace('.', ' ').split()] + is_complete = q.endswith(' ') + if len(terms) > 0 and not is_complete: + retrieval_terms = terms[:-1] + self.completer.complete(terms[-1]) + else: + retrieval_terms = terms + + pages = [] + seen_items = set() + for term in retrieval_terms: + items = self.tiny_index.retrieve(term) + if items is not None: + for item in items: + if term in item.title.lower() or term in item.extract.lower(): + if item.title not in seen_items: + pages.append(item) + seen_items.add(item.title) + + ordered_results = _order_results(terms, pages, is_complete) + return ordered_results, terms