Implement retrieval
This commit is contained in:
parent
acc2a9194e
commit
fdb5cbbf3c
3 changed files with 28 additions and 28 deletions
29
app.py
29
app.py
|
@ -7,9 +7,11 @@ from fastapi import FastAPI
|
|||
from fastapi.staticfiles import StaticFiles
|
||||
from starlette.responses import FileResponse, RedirectResponse
|
||||
|
||||
from index import TinyIndex, PAGE_SIZE, NUM_PAGES
|
||||
from paths import INDEX_PATH
|
||||
|
||||
app = FastAPI()
|
||||
tiny_index = TinyIndex(INDEX_PATH, NUM_PAGES, PAGE_SIZE)
|
||||
|
||||
|
||||
@app.get("/search")
|
||||
|
@ -43,28 +45,21 @@ def complete_term(term):
|
|||
def complete(q: str):
|
||||
terms = [x.lower() for x in q.split()]
|
||||
|
||||
completed = complete_term(terms[-1])
|
||||
terms = terms[:-1] + [completed]
|
||||
# completed = complete_term(terms[-1])
|
||||
# terms = terms[:-1] + [completed]
|
||||
|
||||
con = sqlite3.connect(INDEX_PATH)
|
||||
in_part = ','.join('?'*len(terms))
|
||||
query = f"""
|
||||
SELECT title, url
|
||||
FROM terms INNER JOIN pages
|
||||
ON terms.page_id = pages.id
|
||||
WHERE term IN ({in_part})
|
||||
GROUP BY title, url
|
||||
ORDER BY count(*) DESC, length(title)
|
||||
LIMIT 5
|
||||
"""
|
||||
|
||||
data = con.execute(query, terms).fetchall()
|
||||
pages = []
|
||||
for term in terms:
|
||||
page = tiny_index.retrieve(term)
|
||||
if page is not None:
|
||||
pages += page
|
||||
|
||||
results = [title.replace("\n", "") + ' — ' +
|
||||
url.replace("\n", "") for title, url in data]
|
||||
url.replace("\n", "") for title, url in pages]
|
||||
if len(results) == 0:
|
||||
# print("No results")
|
||||
return []
|
||||
# print("Results", results_list)
|
||||
# print("Results", results)
|
||||
return [q, results]
|
||||
|
||||
|
||||
|
|
19
index.py
19
index.py
|
@ -8,7 +8,7 @@ import sqlite3
|
|||
from dataclasses import dataclass
|
||||
from glob import glob
|
||||
from itertools import chain, count, islice
|
||||
from mmap import mmap
|
||||
from mmap import mmap, PROT_READ
|
||||
from typing import List, Iterator
|
||||
from urllib.parse import unquote
|
||||
|
||||
|
@ -68,6 +68,14 @@ class TinyIndexBase:
|
|||
self.decompressor = ZstdDecompressor()
|
||||
self.mmap = None
|
||||
|
||||
def retrieve(self, token):
|
||||
index = self._get_token_page_index(token)
|
||||
return self._get_page(index)
|
||||
|
||||
def _get_token_page_index(self, token):
|
||||
token_hash = mmh3.hash(token, signed=False)
|
||||
return token_hash % self.num_pages
|
||||
|
||||
def _get_page(self, i):
|
||||
"""
|
||||
Get the page at index i, decompress and deserialise it using JSON
|
||||
|
@ -84,8 +92,8 @@ class TinyIndex(TinyIndexBase):
|
|||
def __init__(self, index_path, num_pages, page_size):
|
||||
super().__init__(num_pages, page_size)
|
||||
self.index_path = index_path
|
||||
self.index_file = None
|
||||
self.mmap = None
|
||||
self.index_file = open(self.index_path, 'rb')
|
||||
self.mmap = mmap(self.index_file.fileno(), 0, prot=PROT_READ)
|
||||
|
||||
|
||||
class TinyIndexer(TinyIndexBase):
|
||||
|
@ -95,6 +103,7 @@ class TinyIndexer(TinyIndexBase):
|
|||
self.compressor = ZstdCompressor()
|
||||
self.decompressor = ZstdDecompressor()
|
||||
self.index_file = None
|
||||
self.mmap = None
|
||||
|
||||
def __enter__(self):
|
||||
self.create_if_not_exists()
|
||||
|
@ -122,10 +131,6 @@ class TinyIndexer(TinyIndexBase):
|
|||
except ValueError:
|
||||
pass
|
||||
|
||||
def _get_token_page_index(self, token):
|
||||
token_hash = mmh3.hash(token, signed=False)
|
||||
return token_hash % self.num_pages
|
||||
|
||||
def _write_page(self, data, i):
|
||||
"""
|
||||
Serialise the data using JSON, compress it and store it at index i.
|
||||
|
|
|
@ -15,7 +15,7 @@ from paths import TEST_INDEX_PATH
|
|||
from wiki import get_wiki_titles_and_urls
|
||||
|
||||
|
||||
NUM_DOCUMENTS = 10000
|
||||
NUM_DOCUMENTS = 500
|
||||
|
||||
|
||||
def query_test():
|
||||
|
@ -29,9 +29,9 @@ def query_test():
|
|||
result = client.get('/complete', params={'q': title})
|
||||
assert result.status_code == 200
|
||||
data = result.content.decode('utf8')
|
||||
# data = json.dumps(complete(title))
|
||||
# print("Data", data, url, sep='\n')
|
||||
|
||||
if url in data:
|
||||
if title in data:
|
||||
hits += 1
|
||||
|
||||
end = datetime.now()
|
||||
|
@ -61,7 +61,7 @@ def performance_test():
|
|||
print("Index size", index_size)
|
||||
# print("Num tokens", indexer.get_num_tokens())
|
||||
|
||||
# query_test()
|
||||
query_test()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Add table
Reference in a new issue