|
@@ -8,7 +8,7 @@ from collections import Counter
|
|
|
from dataclasses import dataclass, fields, asdict, astuple
|
|
|
from itertools import islice
|
|
|
from mmap import mmap, PROT_READ
|
|
|
-from typing import List, Iterator, TypeVar, Generic, Iterable
|
|
|
+from typing import List, Iterator, TypeVar, Generic, Iterable, Callable
|
|
|
from urllib.parse import unquote
|
|
|
|
|
|
import justext
|
|
@@ -57,15 +57,18 @@ class TokenizedDocument(Document):
|
|
|
tokens: List[str]
|
|
|
|
|
|
|
|
|
-class TinyIndexBase:
|
|
|
- def __init__(self, item_type: type, num_pages: int, page_size: int):
|
|
|
- self.item_type = item_type
|
|
|
+T = TypeVar('T')
|
|
|
+
|
|
|
+
|
|
|
+class TinyIndexBase(Generic[T]):
|
|
|
+ def __init__(self, item_factory: Callable[..., T], num_pages: int, page_size: int):
|
|
|
+ self.item_factory = item_factory
|
|
|
self.num_pages = num_pages
|
|
|
self.page_size = page_size
|
|
|
self.decompressor = ZstdDecompressor()
|
|
|
self.mmap = None
|
|
|
|
|
|
- def retrieve(self, key: str):
|
|
|
+ def retrieve(self, key: str) -> List[T]:
|
|
|
index = self._get_key_page_index(key)
|
|
|
page = self.get_page(index)
|
|
|
if page is None:
|
|
@@ -88,24 +91,24 @@ class TinyIndexBase:
|
|
|
return None
|
|
|
return json.loads(decompressed_data.decode('utf8'))
|
|
|
|
|
|
- def convert_items(self, items):
|
|
|
- converted = [self.item_type(*item) for item in items]
|
|
|
+ def convert_items(self, items) -> List[T]:
|
|
|
+ converted = [self.item_factory(*item) for item in items]
|
|
|
# print("Converted", items, converted)
|
|
|
return converted
|
|
|
|
|
|
|
|
|
-class TinyIndex(TinyIndexBase):
|
|
|
- def __init__(self, index_path, num_pages, page_size):
|
|
|
- super().__init__(Document, num_pages, page_size)
|
|
|
+class TinyIndex(TinyIndexBase[T]):
|
|
|
+ def __init__(self, item_factory: Callable[..., T], index_path, num_pages, page_size):
|
|
|
+ super().__init__(item_factory, num_pages, page_size)
|
|
|
# print("REtrieve path", index_path)
|
|
|
self.index_path = index_path
|
|
|
self.index_file = open(self.index_path, 'rb')
|
|
|
self.mmap = mmap(self.index_file.fileno(), 0, prot=PROT_READ)
|
|
|
|
|
|
|
|
|
-class TinyIndexer(TinyIndexBase):
|
|
|
- def __init__(self, item_type: type, index_path: str, num_pages: int, page_size: int):
|
|
|
- super().__init__(item_type, num_pages, page_size)
|
|
|
+class TinyIndexer(TinyIndexBase[T]):
|
|
|
+ def __init__(self, item_factory: Callable[..., T], index_path: str, num_pages: int, page_size: int):
|
|
|
+ super().__init__(item_factory, num_pages, page_size)
|
|
|
self.index_path = index_path
|
|
|
self.compressor = ZstdCompressor()
|
|
|
self.decompressor = ZstdDecompressor()
|
|
@@ -127,10 +130,10 @@ class TinyIndexer(TinyIndexBase):
|
|
|
# for token in document.tokens:
|
|
|
# self._index_document(document, token)
|
|
|
|
|
|
- def index(self, key: str, value):
|
|
|
+ def index(self, key: str, value: T):
|
|
|
# print("Index", value)
|
|
|
- assert type(value) == self.item_type, f"Can only index the specified type" \
|
|
|
- f" ({self.item_type.__name__})"
|
|
|
+ assert type(value) == self.item_factory, f"Can only index the specified type" \
|
|
|
+ f" ({self.item_factory.__name__})"
|
|
|
page_index = self._get_key_page_index(key)
|
|
|
current_page = self.get_page(page_index)
|
|
|
if current_page is None:
|