Improve typing of indexer

This commit is contained in:
Daoud Clarke 2021-06-13 21:41:19 +01:00
parent 0578f41a73
commit 896f782379
3 changed files with 22 additions and 19 deletions

4
app.py
View file

@ -1,7 +1,7 @@
import create_app
from index import TinyIndex, PAGE_SIZE, NUM_PAGES
from index import TinyIndex, PAGE_SIZE, NUM_PAGES, Document
from paths import INDEX_PATH
tiny_index = TinyIndex(INDEX_PATH, NUM_PAGES, PAGE_SIZE)
tiny_index = TinyIndex(Document, INDEX_PATH, NUM_PAGES, PAGE_SIZE)
app = create_app.create(tiny_index)

View file

@ -8,7 +8,7 @@ from collections import Counter
from dataclasses import dataclass, fields, asdict, astuple
from itertools import islice
from mmap import mmap, PROT_READ
from typing import List, Iterator, TypeVar, Generic, Iterable
from typing import List, Iterator, TypeVar, Generic, Iterable, Callable
from urllib.parse import unquote
import justext
@ -57,15 +57,18 @@ class TokenizedDocument(Document):
tokens: List[str]
class TinyIndexBase:
def __init__(self, item_type: type, num_pages: int, page_size: int):
self.item_type = item_type
T = TypeVar('T')
class TinyIndexBase(Generic[T]):
def __init__(self, item_factory: Callable[..., T], num_pages: int, page_size: int):
self.item_factory = item_factory
self.num_pages = num_pages
self.page_size = page_size
self.decompressor = ZstdDecompressor()
self.mmap = None
def retrieve(self, key: str):
def retrieve(self, key: str) -> List[T]:
index = self._get_key_page_index(key)
page = self.get_page(index)
if page is None:
@ -88,24 +91,24 @@ class TinyIndexBase:
return None
return json.loads(decompressed_data.decode('utf8'))
def convert_items(self, items):
converted = [self.item_type(*item) for item in items]
def convert_items(self, items) -> List[T]:
converted = [self.item_factory(*item) for item in items]
# print("Converted", items, converted)
return converted
class TinyIndex(TinyIndexBase):
def __init__(self, index_path, num_pages, page_size):
super().__init__(Document, num_pages, page_size)
class TinyIndex(TinyIndexBase[T]):
def __init__(self, item_factory: Callable[..., T], index_path, num_pages, page_size):
super().__init__(item_factory, num_pages, page_size)
# print("REtrieve path", index_path)
self.index_path = index_path
self.index_file = open(self.index_path, 'rb')
self.mmap = mmap(self.index_file.fileno(), 0, prot=PROT_READ)
class TinyIndexer(TinyIndexBase):
def __init__(self, item_type: type, index_path: str, num_pages: int, page_size: int):
super().__init__(item_type, num_pages, page_size)
class TinyIndexer(TinyIndexBase[T]):
def __init__(self, item_factory: Callable[..., T], index_path: str, num_pages: int, page_size: int):
super().__init__(item_factory, num_pages, page_size)
self.index_path = index_path
self.compressor = ZstdCompressor()
self.decompressor = ZstdDecompressor()
@ -127,10 +130,10 @@ class TinyIndexer(TinyIndexBase):
# for token in document.tokens:
# self._index_document(document, token)
def index(self, key: str, value):
def index(self, key: str, value: T):
# print("Index", value)
assert type(value) == self.item_type, f"Can only index the specified type" \
f" ({self.item_type.__name__})"
assert type(value) == self.item_factory, f"Can only index the specified type" \
f" ({self.item_factory.__name__})"
page_index = self._get_key_page_index(key)
current_page = self.get_page(page_index)
if current_page is None:

View file

@ -33,7 +33,7 @@ def get_test_pages():
def query_test():
titles_and_urls = get_test_pages()
print(f"Got {len(titles_and_urls)} titles and URLs")
tiny_index = TinyIndex(TEST_INDEX_PATH, TEST_NUM_PAGES, TEST_PAGE_SIZE)
tiny_index = TinyIndex(Document, TEST_INDEX_PATH, TEST_NUM_PAGES, TEST_PAGE_SIZE)
app = create_app.create(tiny_index)
client = TestClient(app)