@@ -7,7 +7,9 @@ import gzip
 import json
 import os
 from multiprocessing.pool import ThreadPool
+from pathlib import Path
 from tempfile import NamedTemporaryFile
+from urllib.parse import urlparse
 
 from pydantic import ValidationError
 
@@ -24,10 +26,6 @@ class BatchCache:
         os.makedirs(repo_path, exist_ok=True)
         self.path = repo_path
 
-    def store(self, batch: HashedBatch):
-        with NamedTemporaryFile(mode='w', dir=self.path, prefix='batch_', suffix='.json', delete=False) as output_file:
-            output_file.write(batch.json())
-
     def get(self, num_batches) -> dict[str, HashedBatch]:
         batches = {}
         for path in os.listdir(self.path):
@@ -65,5 +63,17 @@ class BatchCache:
             print("Failed to validate batch", data)
             return 0
         if len(batch.items) > 0:
-            self.store(batch)
+            self.store(batch, url)
         return len(batch.items)
+
+    def store(self, batch, url):
+        path = self.get_path_from_url(url)
+        print("Path", path)
+        os.makedirs(path.parent, exist_ok=True)
+        with open(path, 'wb') as output_file:
+            data = gzip.compress(batch.json().encode('utf8'))
+            output_file.write(data)
+
+    def get_path_from_url(self, url) -> Path:
+        url_path = urlparse(url).path
+        return Path(self.path) / url_path.lstrip('/')
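
For context, a minimal sketch of how a batch written by the new store method could be read back. The cache directory and batch URL below are hypothetical examples, not values from the change; only gzip, batch.json() and the get_path_from_url logic come from the diff itself.

    import gzip
    import json
    from pathlib import Path
    from urllib.parse import urlparse

    # Hypothetical values for illustration; in BatchCache these come from self.path and the batch URL.
    cache_path = "/tmp/batch-cache"
    url = "https://example.com/batches/2022-02-01/abc123.json"

    # Mirrors get_path_from_url: the URL's path component is re-rooted under the cache directory.
    stored_path = Path(cache_path) / urlparse(url).path.lstrip('/')

    # Reverses store: the file holds gzip-compressed JSON produced by batch.json().
    with open(stored_path, 'rb') as input_file:
        batch_data = json.loads(gzip.decompress(input_file.read()))

The apparent design choice is that a batch's cache location is a pure function of its URL, so a previously fetched batch can be located again without a separate index; batches whose URLs share the same path would map to the same file and overwrite each other.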