Browse files

Merge pull request #46 from mwmbl/refactor-for-evaluation

Refactor to enable easier evaluation
Daoud Clarke 3 years ago
parent
commit
82c46b50bc
4 changed files with 128 additions and 91 deletions
  1. analyse/performance.py (+1 -1)
  2. mwmbl/tinysearchengine/app.py (+4 -1)
  3. mwmbl/tinysearchengine/create_app.py (+4 -89)
  4. mwmbl/tinysearchengine/rank.py (+119 -0)
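
This PR moves the ranking and retrieval logic out of create_app.create into a new Ranker class, so the FastAPI factory only needs a ready-made ranker. A minimal sketch of the new wiring, assuming only the imports and call signatures shown in the diffs below; the index path and terms CSV path are placeholders, not values from the PR:

    import pandas as pd

    from mwmbl.tinysearchengine import create_app
    from mwmbl.tinysearchengine.completer import Completer
    from mwmbl.tinysearchengine.indexer import TinyIndex, NUM_PAGES, PAGE_SIZE, Document
    from mwmbl.tinysearchengine.rank import Ranker

    # Open the index and build the completer the same way app.py does;
    # the two paths below are illustrative placeholders.
    tiny_index = TinyIndex(Document, "index.tinysearch", NUM_PAGES, PAGE_SIZE)
    completer = Completer(pd.read_csv("terms.csv"))

    # Ranking now lives in Ranker; create_app.create only receives the finished ranker.
    ranker = Ranker(tiny_index, completer)
    app = create_app.create(ranker)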

analyse/performance.py (+1 -1)

@@ -36,7 +36,7 @@ def query_test():
     print(f"Got {len(titles_and_urls)} titles and URLs")
     tiny_index = TinyIndex(Document, TEST_INDEX_PATH, TEST_NUM_PAGES, TEST_PAGE_SIZE)
 
-    app = create_app.create(tiny_index)
+    app = create_app.create()
     client = TestClient(app)
 
     start = datetime.now()

mwmbl/tinysearchengine/app.py (+4 -1)

@@ -8,6 +8,7 @@ from mwmbl.tinysearchengine import create_app
 from mwmbl.tinysearchengine.completer import Completer
 from mwmbl.tinysearchengine.indexer import TinyIndex, NUM_PAGES, PAGE_SIZE, Document
 from mwmbl.tinysearchengine.config import parse_config_file
+from mwmbl.tinysearchengine.rank import Ranker
 
 logging.basicConfig()
 
@@ -35,8 +36,10 @@ def main():
     terms = pd.read_csv(config.terms_path)
     completer = Completer(terms)
 
+    ranker = Ranker(tiny_index, completer)
+
     # Initialize FastApi instance
-    app = create_app.create(tiny_index, completer)
+    app = create_app.create(ranker)
 
     # Initialize uvicorn server using global app instance and server config params
     uvicorn.run(app, **config.server_config.dict())

mwmbl/tinysearchengine/create_app.py (+4 -89)

@@ -10,6 +10,7 @@ from starlette.middleware.cors import CORSMiddleware
 from mwmbl.tinysearchengine.completer import Completer
 from mwmbl.tinysearchengine.hn_top_domains_filtered import DOMAINS
 from mwmbl.tinysearchengine.indexer import TinyIndex, Document
+from mwmbl.tinysearchengine.rank import Ranker
 
 logger = getLogger(__name__)
 
@@ -17,7 +18,7 @@ logger = getLogger(__name__)
 SCORE_THRESHOLD = 0.25
 
 
-def create(tiny_index: TinyIndex, completer: Completer):
+def create(ranker: Ranker):
     app = FastAPI()
 
     # Allow CORS requests from any site
@@ -29,96 +30,10 @@ def create(tiny_index: TinyIndex, completer: Completer):
 
     @app.get("/search")
     def search(s: str):
-        results, terms = get_results(s)
-
-        is_complete = s.endswith(' ')
-        pattern = get_query_regex(terms, is_complete)
-        formatted_results = []
-        for result in results:
-            formatted_result = {}
-            for content_type, content in [('title', result.title), ('extract', result.extract)]:
-                matches = re.finditer(pattern, content, re.IGNORECASE)
-                all_spans = [0] + sum((list(m.span()) for m in matches), []) + [len(content)]
-                content_result = []
-                for i in range(len(all_spans) - 1):
-                    is_bold = i % 2 == 1
-                    start = all_spans[i]
-                    end = all_spans[i + 1]
-                    content_result.append({'value': content[start:end], 'is_bold': is_bold})
-                formatted_result[content_type] = content_result
-            formatted_result['url'] = result.url
-            formatted_results.append(formatted_result)
-
-        logger.info("Return results: %r", formatted_results)
-        return formatted_results
-
-    def get_query_regex(terms, is_complete):
-        if not terms:
-            return ''
-
-        if is_complete:
-            term_patterns = [rf'\b{term}\b' for term in terms]
-        else:
-            term_patterns = [rf'\b{term}\b' for term in terms[:-1]] + [rf'\b{terms[-1]}']
-        pattern = '|'.join(term_patterns)
-        return pattern
-
-    def score_result(terms, result: Document, is_complete: bool):
-        domain = urlparse(result.url).netloc
-        domain_score = DOMAINS.get(domain, 0.0)
-
-        result_string = f"{result.title.strip()} {result.extract.strip()}"
-        query_regex = get_query_regex(terms, is_complete)
-        matches = list(re.finditer(query_regex, result_string, flags=re.IGNORECASE))
-        match_strings = {x.group(0).lower() for x in matches}
-        match_length = sum(len(x) for x in match_strings)
-
-        last_match_char = 1
-        seen_matches = set()
-        for match in matches:
-            value = match.group(0).lower()
-            if value not in seen_matches:
-                last_match_char = match.span()[1]
-                seen_matches.add(value)
-
-        total_possible_match_length = sum(len(x) for x in terms)
-        score = 0.1*domain_score + 0.9*(match_length + 1./last_match_char) / (total_possible_match_length + 1)
-        return score
-
-    def order_results(terms: list[str], results: list[Document], is_complete: bool):
-        results_and_scores = [(score_result(terms, result, is_complete), result) for result in results]
-        ordered_results = sorted(results_and_scores, key=itemgetter(0), reverse=True)
-        filtered_results = [result for score, result in ordered_results if score > SCORE_THRESHOLD]
-        return filtered_results
+        return ranker.search(s)
 
     @app.get("/complete")
     def complete(q: str):
-        ordered_results, terms = get_results(q)
-        results = [item.title.replace("\n", "") + ' — ' +
-                   item.url.replace("\n", "") for item in ordered_results]
-        if len(results) == 0:
-            return []
-        return [q, results]
-
-    def get_results(q):
-        terms = [x.lower() for x in q.replace('.', ' ').split()]
-        is_complete = q.endswith(' ')
-        if len(terms) > 0 and not is_complete:
-            retrieval_terms = terms[:-1] + completer.complete(terms[-1])
-        else:
-            retrieval_terms = terms
-
-        pages = []
-        seen_items = set()
-        for term in retrieval_terms:
-            items = tiny_index.retrieve(term)
-            if items is not None:
-                for item in items:
-                    if term in item.title.lower() or term in item.extract.lower():
-                        if item.title not in seen_items:
-                            pages.append(item)
-                            seen_items.add(item.title)
+        return ranker.complete(q)
 
-        ordered_results = order_results(terms, pages, is_complete)
-        return ordered_results, terms
     return app

mwmbl/tinysearchengine/rank.py (+119 -0)

@@ -0,0 +1,119 @@
+import re
+from logging import getLogger
+from operator import itemgetter
+from pathlib import Path
+from urllib.parse import urlparse
+
+from fastapi import FastAPI
+from starlette.middleware.cors import CORSMiddleware
+
+from mwmbl.tinysearchengine.completer import Completer
+from mwmbl.tinysearchengine.hn_top_domains_filtered import DOMAINS
+from mwmbl.tinysearchengine.indexer import TinyIndex, Document
+
+logger = getLogger(__name__)
+
+
+SCORE_THRESHOLD = 0.25
+
+
+def _get_query_regex(terms, is_complete):
+    if not terms:
+        return ''
+
+    if is_complete:
+        term_patterns = [rf'\b{term}\b' for term in terms]
+    else:
+        term_patterns = [rf'\b{term}\b' for term in terms[:-1]] + [rf'\b{terms[-1]}']
+    pattern = '|'.join(term_patterns)
+    return pattern
+
+
+def _score_result(terms, result: Document, is_complete: bool):
+    domain = urlparse(result.url).netloc
+    domain_score = DOMAINS.get(domain, 0.0)
+
+    result_string = f"{result.title.strip()} {result.extract.strip()}"
+    query_regex = _get_query_regex(terms, is_complete)
+    matches = list(re.finditer(query_regex, result_string, flags=re.IGNORECASE))
+    match_strings = {x.group(0).lower() for x in matches}
+    match_length = sum(len(x) for x in match_strings)
+
+    last_match_char = 1
+    seen_matches = set()
+    for match in matches:
+        value = match.group(0).lower()
+        if value not in seen_matches:
+            last_match_char = match.span()[1]
+            seen_matches.add(value)
+
+    total_possible_match_length = sum(len(x) for x in terms)
+    score = 0.1*domain_score + 0.9*(match_length + 1./last_match_char) / (total_possible_match_length + 1)
+    return score
+
+
+def _order_results(terms: list[str], results: list[Document], is_complete: bool):
+    results_and_scores = [(_score_result(terms, result, is_complete), result) for result in results]
+    ordered_results = sorted(results_and_scores, key=itemgetter(0), reverse=True)
+    filtered_results = [result for score, result in ordered_results if score > SCORE_THRESHOLD]
+    return filtered_results
+
+
+class Ranker:
+    def __init__(self, tiny_index: TinyIndex, completer: Completer):
+        self.tiny_index = tiny_index
+        self.completer = completer
+
+    def search(self, s: str):
+        results, terms = self._get_results(s)
+
+        is_complete = s.endswith(' ')
+        pattern = _get_query_regex(terms, is_complete)
+        formatted_results = []
+        for result in results:
+            formatted_result = {}
+            for content_type, content in [('title', result.title), ('extract', result.extract)]:
+                matches = re.finditer(pattern, content, re.IGNORECASE)
+                all_spans = [0] + sum((list(m.span()) for m in matches), []) + [len(content)]
+                content_result = []
+                for i in range(len(all_spans) - 1):
+                    is_bold = i % 2 == 1
+                    start = all_spans[i]
+                    end = all_spans[i + 1]
+                    content_result.append({'value': content[start:end], 'is_bold': is_bold})
+                formatted_result[content_type] = content_result
+            formatted_result['url'] = result.url
+            formatted_results.append(formatted_result)
+
+        logger.info("Return results: %r", formatted_results)
+        return formatted_results
+
+    def complete(self, q: str):
+        ordered_results, terms = self._get_results(q)
+        results = [item.title.replace("\n", "") + ' — ' +
+                   item.url.replace("\n", "") for item in ordered_results]
+        if len(results) == 0:
+            return []
+        return [q, results]
+
+    def _get_results(self, q):
+        terms = [x.lower() for x in q.replace('.', ' ').split()]
+        is_complete = q.endswith(' ')
+        if len(terms) > 0 and not is_complete:
+            retrieval_terms = terms[:-1] + self.completer.complete(terms[-1])
+        else:
+            retrieval_terms = terms
+
+        pages = []
+        seen_items = set()
+        for term in retrieval_terms:
+            items = self.tiny_index.retrieve(term)
+            if items is not None:
+                for item in items:
+                    if term in item.title.lower() or term in item.extract.lower():
+                        if item.title not in seen_items:
+                            pages.append(item)
+                            seen_items.add(item.title)
+
+        ordered_results = _order_results(terms, pages, is_complete)
+        return ordered_results, terms
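
Because search and complete now live on Ranker, evaluation code can drive the ranking logic directly instead of routing queries through a FastAPI TestClient, which is the stated aim of this PR. A hedged sketch of that kind of standalone evaluation loop; the paths and query strings are placeholders, and only the Ranker API itself comes from the diff above:

    import pandas as pd

    from mwmbl.tinysearchengine.completer import Completer
    from mwmbl.tinysearchengine.indexer import TinyIndex, NUM_PAGES, PAGE_SIZE, Document
    from mwmbl.tinysearchengine.rank import Ranker

    # Placeholder paths standing in for a test index and terms file.
    ranker = Ranker(TinyIndex(Document, "test-index.tinysearch", NUM_PAGES, PAGE_SIZE),
                    Completer(pd.read_csv("terms.csv")))

    # Queries are illustrative; a trailing space marks a query as complete.
    for query in ["python tutorial ", "rust asyn"]:
        results = ranker.search(query)   # ranked results with highlighted title/extract spans
        print(query, len(results))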