Merge pull request #98 from mwmbl/rishabh-fix-trim-data
Fix trimming page size logic while adding to a page
commit e9dfd40ecb
2 changed files with 134 additions and 13 deletions
@@ -65,19 +65,39 @@ class TinyIndexMetadata:
         values = json.loads(data[constant_length:].decode('utf8'))
         return TinyIndexMetadata(**values)
 
+# Find the optimal amount of data that fits onto a page
+# We do this by leveraging binary search to quickly find the index where:
+# - index+1 cannot fit onto a page
+# - <=index can fit on a page
+def _binary_search_fitting_size(compressor: ZstdCompressor, page_size: int, items:list[T], lo:int, hi:int):
+    # Base case: our binary search has gone too far
+    if lo > hi:
+        return -1, None
+    # Check the midpoint to see if it will fit onto a page
+    mid = (lo+hi)//2
+    compressed_data = compressor.compress(json.dumps(items[:mid]).encode('utf8'))
+    size = len(compressed_data)
+    if size > page_size:
+        # We cannot fit this much data into a page
+        # Reduce the hi boundary, and try again
+        return _binary_search_fitting_size(compressor, page_size, items, lo, mid-1)
+    else:
+        # We can fit this data into a page, but maybe we can fit more data
+        # Try to see if we have a better match
+        potential_target, potential_data = _binary_search_fitting_size(compressor, page_size, items, mid+1, hi)
+        if potential_target != -1:
+            # We found a larger index that can still fit onto a page, so use that
+            return potential_target, potential_data
+        else:
+            # No better match, use our index
+            return mid, compressed_data
+
+def _trim_items_to_page(compressor: ZstdCompressor, page_size: int, items:list[T]):
+    # Find max number of items that fit on a page
+    return _binary_search_fitting_size(compressor, page_size, items, 0, len(items))
+
 def _get_page_data(compressor: ZstdCompressor, page_size: int, items: list[T]):
-    bytes_io = BytesIO()
-    stream_writer = compressor.stream_writer(bytes_io, write_size=128)
-
-    num_fitting = 0
-    for i, item in enumerate(items):
-        serialised_data = json.dumps(item) + '\n'
-        stream_writer.write(serialised_data.encode('utf8'))
-        stream_writer.flush()
-        if len(bytes_io.getvalue()) > page_size:
-            break
-        num_fitting = i + 1
+    num_fitting, serialised_data = _trim_items_to_page(compressor, page_size, items)
 
     compressed_data = compressor.compress(json.dumps(items[:num_fitting]).encode('utf8'))
     assert len(compressed_data) < page_size, "The data shouldn't get bigger"
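As a quick illustration of what the new helpers above do, the sketch below runs the same prefix-fitting binary search outside the indexer: it finds the largest prefix of `items` whose zstd-compressed JSON stays within `page_size`, returning `(-1, None)` when nothing fits, just like `_binary_search_fitting_size`. It is a minimal standalone sketch, not part of the mwmbl code: the name `largest_fitting_prefix`, the toy items, and the 4096-byte page size are all illustrative.

import json

from zstandard import ZstdCompressor


def largest_fitting_prefix(compressor, page_size, items):
    # Iterative equivalent of the recursive search in the hunk above: keep the
    # largest prefix count whose compressed JSON is <= page_size, or return
    # (-1, None) if not even the empty prefix fits.
    lo, hi = 0, len(items)
    best_count, best_data = -1, None
    while lo <= hi:
        mid = (lo + hi) // 2
        compressed = compressor.compress(json.dumps(items[:mid]).encode('utf8'))
        if len(compressed) > page_size:
            hi = mid - 1                      # too big: try a shorter prefix
        else:
            best_count, best_data = mid, compressed
            lo = mid + 1                      # fits: try a longer prefix
    return best_count, best_data


if __name__ == '__main__':
    compressor = ZstdCompressor()
    items = [['url{}'.format(i), 'title{}'.format(i), 'extract{}'.format(i), float(i)]
             for i in range(1000)]
    count, data = largest_fitting_prefix(compressor, 4096, items)
    print(count, len(data))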
@@ -1,8 +1,9 @@
 from pathlib import Path
 from tempfile import TemporaryDirectory
 
-from mwmbl.tinysearchengine.indexer import Document, TinyIndex
-
+from mwmbl.tinysearchengine.indexer import Document, TinyIndex, _binary_search_fitting_size, astuple, _trim_items_to_page, _get_page_data, _pad_to_page_size
+from zstandard import ZstdDecompressor, ZstdCompressor, ZstdError
+import json
 
 def test_create_index():
     num_pages = 10
@@ -14,3 +15,103 @@ def test_create_index():
         for i in range(num_pages):
             page = indexer.get_page(i)
             assert page == []
+
+def test_binary_search_fitting_size_all_fit():
+    items = [1,2,3,4,5,6,7,8,9]
+    compressor = ZstdCompressor()
+    page_size = 4096
+    count_fit, data = _binary_search_fitting_size(compressor,page_size,items,0,len(items))
+
+    # We should fit everything
+    assert count_fit == len(items)
+
+def test_binary_search_fitting_size_subset_fit():
+    items = [1,2,3,4,5,6,7,8,9]
+    compressor = ZstdCompressor()
+    page_size = 15
+    count_fit, data = _binary_search_fitting_size(compressor,page_size,items,0,len(items))
+
+    # We should not fit everything
+    assert count_fit < len(items)
+
+def test_binary_search_fitting_size_none_fit():
+    items = [1,2,3,4,5,6,7,8,9]
+    compressor = ZstdCompressor()
+    page_size = 5
+    count_fit, data = _binary_search_fitting_size(compressor,page_size,items,0,len(items))
+
+    # We should not fit anything
+    assert count_fit == -1
+    assert data is None
+
+def test_get_page_data_single_doc():
+    document1 = Document(title='title1',url='url1',extract='extract1',score=1.0)
+    documents = [document1]
+    items = [astuple(value) for value in documents]
+
+    compressor = ZstdCompressor()
+    page_size = 4096
+
+    # Trim data
+    num_fitting,trimmed_data = _trim_items_to_page(compressor,4096,items)
+
+    # We should be able to fit the 1 item into a page
+    assert num_fitting == 1
+
+    # Compare the trimmed data to the actual data we're persisting
+    # We need to pad the trimmed data, then it should be equal to the data we persist
+    padded_trimmed_data = _pad_to_page_size(trimmed_data, page_size)
+    serialized_data = _get_page_data(compressor,page_size,items)
+    assert serialized_data == padded_trimmed_data
+
+
+def test_get_page_data_many_docs_all_fit():
+    # Build giant documents item
+    documents = []
+    documents_len = 500
+    page_size = 4096
+    for x in range(documents_len):
+        txt = 'text{}'.format(x)
+        document = Document(title=txt,url=txt,extract=txt,score=x)
+        documents.append(document)
+    items = [astuple(value) for value in documents]
+
+    # Trim the items
+    compressor = ZstdCompressor()
+    num_fitting,trimmed_data = _trim_items_to_page(compressor,page_size,items)
+
+    # We should be able to fit all items
+    assert num_fitting == documents_len
+
+    # Compare the trimmed data to the actual data we're persisting
+    # We need to pad the trimmed data, then it should be equal to the data we persist
+    serialized_data = _get_page_data(compressor,page_size,items)
+    padded_trimmed_data = _pad_to_page_size(trimmed_data, page_size)
+
+    assert serialized_data == padded_trimmed_data
+
+def test_get_page_data_many_docs_subset_fit():
+    # Build giant documents item
+    documents = []
+    documents_len = 5000
+    page_size = 4096
+    for x in range(documents_len):
+        txt = 'text{}'.format(x)
+        document = Document(title=txt,url=txt,extract=txt,score=x)
+        documents.append(document)
+    items = [astuple(value) for value in documents]
+
+    # Trim the items
+    compressor = ZstdCompressor()
+    num_fitting,trimmed_data = _trim_items_to_page(compressor,page_size,items)
+
+    # We should be able to fit a subset of the items onto the page
+    assert num_fitting > 1
+    assert num_fitting < documents_len
+
+    # Compare the trimmed data to the actual data we're persisting
+    # We need to pad the trimmed data, then it should be equal to the data we persist
+    serialized_data = _get_page_data(compressor,page_size,items)
+    padded_trimmed_data = _pad_to_page_size(trimmed_data, page_size)
+
+    assert serialized_data == padded_trimmed_data
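For context on the behaviour these tests pin down, the snippet below contrasts the two ways of counting how many items fit that appear in the diff above: the old streamed estimate removed from `_get_page_data` (write one item at a time and flush), and a one-shot compression of the counted prefix, which is what actually gets stored. The two sizes are computed differently, so the streamed count and the stored size can disagree, which is presumably why the trimming logic was reworked to check the one-shot size directly via the binary search. This is an illustrative sketch only: the sample documents are made up, and only `json` and `zstandard` are required.

import json
from io import BytesIO

from zstandard import ZstdCompressor

PAGE_SIZE = 4096

# Made-up items, shaped like the (title, url, extract, score) tuples used in the tests
items = [['title{}'.format(i), 'url{}'.format(i), 'extract{}'.format(i), float(i)]
         for i in range(5000)]

# Old-style estimate: stream items one at a time, flushing after each write,
# and count how many were written before the stream exceeded the page size.
bytes_io = BytesIO()
stream_writer = ZstdCompressor().stream_writer(bytes_io, write_size=128)
num_fitting = 0
for i, item in enumerate(items):
    stream_writer.write((json.dumps(item) + '\n').encode('utf8'))
    stream_writer.flush()
    if len(bytes_io.getvalue()) > PAGE_SIZE:
        break
    num_fitting = i + 1

# One-shot compression of that prefix, as done before padding and persisting.
one_shot = ZstdCompressor().compress(json.dumps(items[:num_fitting]).encode('utf8'))
print('streamed count:', num_fitting)
print('streamed size:', len(bytes_io.getvalue()), 'one-shot size:', len(one_shot))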