Merge pull request #46 from mwmbl/refactor-for-evaluation
Refactor to enable easier evaluation
This commit is contained in:
commit
82c46b50bc
4 changed files with 128 additions and 91 deletions
|
@ -36,7 +36,7 @@ def query_test():
|
|||
print(f"Got {len(titles_and_urls)} titles and URLs")
|
||||
tiny_index = TinyIndex(Document, TEST_INDEX_PATH, TEST_NUM_PAGES, TEST_PAGE_SIZE)
|
||||
|
||||
app = create_app.create(tiny_index)
|
||||
app = create_app.create()
|
||||
client = TestClient(app)
|
||||
|
||||
start = datetime.now()
|
||||
|
|
|
@ -8,6 +8,7 @@ from mwmbl.tinysearchengine import create_app
|
|||
from mwmbl.tinysearchengine.completer import Completer
|
||||
from mwmbl.tinysearchengine.indexer import TinyIndex, NUM_PAGES, PAGE_SIZE, Document
|
||||
from mwmbl.tinysearchengine.config import parse_config_file
|
||||
from mwmbl.tinysearchengine.rank import Ranker
|
||||
|
||||
logging.basicConfig()
|
||||
|
||||
|
@ -35,8 +36,10 @@ def main():
|
|||
terms = pd.read_csv(config.terms_path)
|
||||
completer = Completer(terms)
|
||||
|
||||
ranker = Ranker(tiny_index, completer)
|
||||
|
||||
# Initialize FastApi instance
|
||||
app = create_app.create(tiny_index, completer)
|
||||
app = create_app.create(ranker)
|
||||
|
||||
# Initialize uvicorn server using global app instance and server config params
|
||||
uvicorn.run(app, **config.server_config.dict())
|
||||
|
|
|
@ -10,6 +10,7 @@ from starlette.middleware.cors import CORSMiddleware
|
|||
from mwmbl.tinysearchengine.completer import Completer
|
||||
from mwmbl.tinysearchengine.hn_top_domains_filtered import DOMAINS
|
||||
from mwmbl.tinysearchengine.indexer import TinyIndex, Document
|
||||
from mwmbl.tinysearchengine.rank import Ranker
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -17,7 +18,7 @@ logger = getLogger(__name__)
|
|||
SCORE_THRESHOLD = 0.25
|
||||
|
||||
|
||||
def create(tiny_index: TinyIndex, completer: Completer):
|
||||
def create(ranker: Ranker):
|
||||
app = FastAPI()
|
||||
|
||||
# Allow CORS requests from any site
|
||||
|
@ -29,96 +30,10 @@ def create(tiny_index: TinyIndex, completer: Completer):
|
|||
|
||||
@app.get("/search")
|
||||
def search(s: str):
|
||||
results, terms = get_results(s)
|
||||
|
||||
is_complete = s.endswith(' ')
|
||||
pattern = get_query_regex(terms, is_complete)
|
||||
formatted_results = []
|
||||
for result in results:
|
||||
formatted_result = {}
|
||||
for content_type, content in [('title', result.title), ('extract', result.extract)]:
|
||||
matches = re.finditer(pattern, content, re.IGNORECASE)
|
||||
all_spans = [0] + sum((list(m.span()) for m in matches), []) + [len(content)]
|
||||
content_result = []
|
||||
for i in range(len(all_spans) - 1):
|
||||
is_bold = i % 2 == 1
|
||||
start = all_spans[i]
|
||||
end = all_spans[i + 1]
|
||||
content_result.append({'value': content[start:end], 'is_bold': is_bold})
|
||||
formatted_result[content_type] = content_result
|
||||
formatted_result['url'] = result.url
|
||||
formatted_results.append(formatted_result)
|
||||
|
||||
logger.info("Return results: %r", formatted_results)
|
||||
return formatted_results
|
||||
|
||||
def get_query_regex(terms, is_complete):
|
||||
if not terms:
|
||||
return ''
|
||||
|
||||
if is_complete:
|
||||
term_patterns = [rf'\b{term}\b' for term in terms]
|
||||
else:
|
||||
term_patterns = [rf'\b{term}\b' for term in terms[:-1]] + [rf'\b{terms[-1]}']
|
||||
pattern = '|'.join(term_patterns)
|
||||
return pattern
|
||||
|
||||
def score_result(terms, result: Document, is_complete: bool):
|
||||
domain = urlparse(result.url).netloc
|
||||
domain_score = DOMAINS.get(domain, 0.0)
|
||||
|
||||
result_string = f"{result.title.strip()} {result.extract.strip()}"
|
||||
query_regex = get_query_regex(terms, is_complete)
|
||||
matches = list(re.finditer(query_regex, result_string, flags=re.IGNORECASE))
|
||||
match_strings = {x.group(0).lower() for x in matches}
|
||||
match_length = sum(len(x) for x in match_strings)
|
||||
|
||||
last_match_char = 1
|
||||
seen_matches = set()
|
||||
for match in matches:
|
||||
value = match.group(0).lower()
|
||||
if value not in seen_matches:
|
||||
last_match_char = match.span()[1]
|
||||
seen_matches.add(value)
|
||||
|
||||
total_possible_match_length = sum(len(x) for x in terms)
|
||||
score = 0.1*domain_score + 0.9*(match_length + 1./last_match_char) / (total_possible_match_length + 1)
|
||||
return score
|
||||
|
||||
def order_results(terms: list[str], results: list[Document], is_complete: bool):
|
||||
results_and_scores = [(score_result(terms, result, is_complete), result) for result in results]
|
||||
ordered_results = sorted(results_and_scores, key=itemgetter(0), reverse=True)
|
||||
filtered_results = [result for score, result in ordered_results if score > SCORE_THRESHOLD]
|
||||
return filtered_results
|
||||
return ranker.search(s)
|
||||
|
||||
@app.get("/complete")
|
||||
def complete(q: str):
|
||||
ordered_results, terms = get_results(q)
|
||||
results = [item.title.replace("\n", "") + ' — ' +
|
||||
item.url.replace("\n", "") for item in ordered_results]
|
||||
if len(results) == 0:
|
||||
return []
|
||||
return [q, results]
|
||||
return ranker.complete(q)
|
||||
|
||||
def get_results(q):
|
||||
terms = [x.lower() for x in q.replace('.', ' ').split()]
|
||||
is_complete = q.endswith(' ')
|
||||
if len(terms) > 0 and not is_complete:
|
||||
retrieval_terms = terms[:-1] + completer.complete(terms[-1])
|
||||
else:
|
||||
retrieval_terms = terms
|
||||
|
||||
pages = []
|
||||
seen_items = set()
|
||||
for term in retrieval_terms:
|
||||
items = tiny_index.retrieve(term)
|
||||
if items is not None:
|
||||
for item in items:
|
||||
if term in item.title.lower() or term in item.extract.lower():
|
||||
if item.title not in seen_items:
|
||||
pages.append(item)
|
||||
seen_items.add(item.title)
|
||||
|
||||
ordered_results = order_results(terms, pages, is_complete)
|
||||
return ordered_results, terms
|
||||
return app
|
||||
|
|
119
mwmbl/tinysearchengine/rank.py
Normal file
119
mwmbl/tinysearchengine/rank.py
Normal file
|
@ -0,0 +1,119 @@
|
|||
import re
|
||||
from logging import getLogger
|
||||
from operator import itemgetter
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from fastapi import FastAPI
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
from mwmbl.tinysearchengine.completer import Completer
|
||||
from mwmbl.tinysearchengine.hn_top_domains_filtered import DOMAINS
|
||||
from mwmbl.tinysearchengine.indexer import TinyIndex, Document
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
SCORE_THRESHOLD = 0.25
|
||||
|
||||
|
||||
def _get_query_regex(terms, is_complete):
|
||||
if not terms:
|
||||
return ''
|
||||
|
||||
if is_complete:
|
||||
term_patterns = [rf'\b{term}\b' for term in terms]
|
||||
else:
|
||||
term_patterns = [rf'\b{term}\b' for term in terms[:-1]] + [rf'\b{terms[-1]}']
|
||||
pattern = '|'.join(term_patterns)
|
||||
return pattern
|
||||
|
||||
|
||||
def _score_result(terms, result: Document, is_complete: bool):
    """Score *result* against the query *terms*; higher is better.

    The score blends a domain reputation component (looked up in DOMAINS
    by netloc) with how much distinct query text is matched in the
    title+extract and where the last new match ends.
    """
    host = urlparse(result.url).netloc
    domain_score = DOMAINS.get(host, 0.0)

    haystack = f"{result.title.strip()} {result.extract.strip()}"
    pattern = _get_query_regex(terms, is_complete)
    found = list(re.finditer(pattern, haystack, flags=re.IGNORECASE))

    # Each distinct matched string counts once towards the matched length.
    distinct_matches = {m.group(0).lower() for m in found}
    match_length = sum(map(len, distinct_matches))

    # Track where the last *new* (previously unseen) match ends; earlier
    # is better because of the 1/last_match_char term below.
    last_match_char = 1
    seen = set()
    for m in found:
        text = m.group(0).lower()
        if text not in seen:
            seen.add(text)
            last_match_char = m.span()[1]

    total_possible_match_length = sum(len(term) for term in terms)
    # +1 in the denominator guards against empty-term queries.
    score = (0.1 * domain_score
             + 0.9 * (match_length + 1. / last_match_char)
             / (total_possible_match_length + 1))
    return score
|
||||
|
||||
|
||||
def _order_results(terms: list[str], results: list[Document], is_complete: bool):
    """Sort *results* by descending score, dropping any at or below
    SCORE_THRESHOLD."""
    scored = [(_score_result(terms, doc, is_complete), doc) for doc in results]
    scored.sort(key=itemgetter(0), reverse=True)
    return [doc for score, doc in scored if score > SCORE_THRESHOLD]
|
||||
|
||||
|
||||
class Ranker:
    """Retrieves candidate pages from a TinyIndex and ranks/formats them
    for the search and completion endpoints."""

    def __init__(self, tiny_index: TinyIndex, completer: Completer):
        self.tiny_index = tiny_index
        self.completer = completer

    def search(self, s: str):
        """Run query *s* and return display-ready results.

        Each result dict has 'title' and 'extract' split into segments
        ({'value', 'is_bold'}) so matched query text can be highlighted,
        plus the page 'url'.
        """
        documents, terms = self._get_results(s)

        is_complete = s.endswith(' ')
        pattern = _get_query_regex(terms, is_complete)

        formatted_results = []
        for document in documents:
            entry = {}
            for content_type, content in [('title', document.title),
                                          ('extract', document.extract)]:
                # Flatten match spans into an ordered list of boundaries;
                # the gaps alternate plain / bold starting with plain.
                boundaries = [0]
                for m in re.finditer(pattern, content, re.IGNORECASE):
                    boundaries.extend(m.span())
                boundaries.append(len(content))

                segments = []
                for i in range(len(boundaries) - 1):
                    segments.append({
                        'value': content[boundaries[i]:boundaries[i + 1]],
                        'is_bold': i % 2 == 1,
                    })
                entry[content_type] = segments
            entry['url'] = document.url
            formatted_results.append(entry)

        logger.info("Return results: %r", formatted_results)
        return formatted_results

    def complete(self, q: str):
        """Return completion suggestions for partial query *q* in the
        OpenSearch-style shape [q, ["title — url", ...]], or [] when
        nothing matched."""
        ordered_results, terms = self._get_results(q)
        suggestions = [
            result.title.replace("\n", "") + ' — ' + result.url.replace("\n", "")
            for result in ordered_results
        ]
        if not suggestions:
            return []
        return [q, suggestions]

    def _get_results(self, q):
        """Tokenise *q*, retrieve candidate pages per term, deduplicate by
        title, and return (ordered_results, terms)."""
        terms = [token.lower() for token in q.replace('.', ' ').split()]
        is_complete = q.endswith(' ')

        # A trailing partial word is expanded with completions unless the
        # query ends in a space (the user finished typing the word).
        if terms and not is_complete:
            retrieval_terms = terms[:-1] + self.completer.complete(terms[-1])
        else:
            retrieval_terms = terms

        pages = []
        seen_titles = set()
        for term in retrieval_terms:
            items = self.tiny_index.retrieve(term)
            if items is None:
                continue
            for item in items:
                term_present = (term in item.title.lower()
                                or term in item.extract.lower())
                if term_present and item.title not in seen_titles:
                    pages.append(item)
                    seen_titles.add(item.title)

        return _order_results(terms, pages, is_complete), terms
|
Loading…
Reference in a new issue