dynamic batching

mertalev 2023-07-11 19:52:48 -04:00
parent 1280c6aeea
commit 4b15607fd7
3 changed files with 103 additions and 6 deletions


@@ -31,11 +31,11 @@ services:
build:
context: ../machine-learning
dockerfile: Dockerfile
command: python main.py
entrypoint: ../entrypoint.sh
ports:
- 3003:3003
volumes:
- ../machine-learning/app:/usr/src/app
- ../machine-learning:/usr/src
- model-cache:/cache
env_file:
- .env


@@ -31,7 +31,7 @@ app = FastAPI()
def init_state() -> None:
app.state.last_called = None
app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0)
app.state.thread_pool = ThreadPoolExecutor()
app.state.thread_pool = ThreadPoolExecutor(max_workers=os.cpu_count() or 4)
@app.on_event("startup")
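For context on the sizing change above: the thread pool is the executor that blocking inference calls are dispatched to from async code, so bounding `max_workers` caps how many such calls run concurrently. A minimal sketch of that pattern, assuming a synchronous `model.predict` function (the names here are illustrative, not the project's actual handler):

import asyncio
from concurrent.futures import ThreadPoolExecutor

pool = ThreadPoolExecutor(max_workers=4)

async def predict_async(model, inputs: list) -> list:
    # run_in_executor keeps the event loop responsive while the blocking
    # model.predict call runs on one of the pool's worker threads
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(pool, model.predict, inputs)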


@@ -1,3 +1,6 @@
import asyncio
import time
from functools import wraps
from typing import Any
from aiocache.backends.memory import SimpleMemoryCache
@@ -47,9 +50,9 @@ class ModelCache:
"""
key = self.cache.build_key(model_name, model_type.value)
async with OptimisticLock(self.cache, key) as lock:
model = await self.cache.get(key)
if model is None:
model = await self.cache.get(key)
if model is None:
async with OptimisticLock(self.cache, key) as lock:
model = InferenceModel.from_model_type(model_type, model_name, **model_kwargs)
await lock.cas(model, ttl=self.ttl)
return model
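Since the +/- markers were lost in this view: after this change the cache is checked first, and the optimistic lock is only taken to load and store the model on a miss. Reconstructed from the hunk above (the enclosing method signature is not shown in the diff), the lookup reads roughly as:

    key = self.cache.build_key(model_name, model_type.value)
    model = await self.cache.get(key)
    if model is None:
        # only lock and load the model when it is not already cached
        async with OptimisticLock(self.cache, key) as lock:
            model = InferenceModel.from_model_type(model_type, model_name, **model_kwargs)
            await lock.cas(model, ttl=self.ttl)
    return model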
@@ -95,3 +98,97 @@ class RevalidationPlugin(BasePlugin):
key = client.build_key(key, namespace)
if val is not None and key in client._handlers:
await client.expire(key, client.ttl)


def batched(max_size: int = 16, timeout_s: float = 0.01):
    """
    Batches concurrent async calls into lists of up to `max_size` elements, flushing a
    batch when it is full or after `timeout_s` seconds, whichever comes first.
    Callers pass a single element as the only argument; the wrapped callable receives a
    list of elements and must return a list of results of the same length.
    Inspired by Ray's @serve.batch decorator.
    """

    def decorator_factory(func):
        # Shared state for all callers of the decorated function.
        queue = asyncio.Queue(maxsize=max_size)
        lock = asyncio.Lock()
        output = None
        element_id = 0
        processing = {}
        processed = {}
        batch = []

        @wraps(func)
        async def decorator(element):
            nonlocal element_id
            nonlocal batch
            nonlocal output

            # Register this element under a unique id and enqueue it for batching.
            cur_idx = element_id
            processing[cur_idx] = element
            element_id += 1
            await queue.put(cur_idx)

            while cur_idx not in processed:
                start = time.monotonic()
                async with lock:
                    # Drain the queue until the batch is full or the window expires,
                    # yielding to the event loop while waiting for more elements.
                    batch_ids = []
                    while len(batch) < max_size and time.monotonic() - start < timeout_s:
                        try:
                            cur = queue.get_nowait()
                            batch_ids.append(cur)
                            batch.append(processing.pop(cur))
                        except asyncio.QueueEmpty:
                            await asyncio.sleep(0)
                    if batch:  # guard: don't call the wrapped function with an empty batch
                        output = await func(batch)
                        batch = []
                        # Fan the results back out to their callers by id.
                        for i, id in enumerate(batch_ids):
                            processed[id] = output[i]
            return processed.pop(cur_idx)

        return decorator

    return decorator_factory
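As a usage sketch (the embed_texts function and its parameters are hypothetical, not part of this commit): each caller awaits the decorated function with a single element, and concurrent callers are transparently grouped into one list passed to the wrapped function.

@batched(max_size=8, timeout_s=0.005)
async def embed_texts(texts: list) -> list:
    # hypothetical wrapped function: receives the accumulated batch and
    # must return one result per input, in the same order
    return [len(t) for t in texts]

async def handle_request(text: str) -> int:
    # callers pass one element; results are fanned back out per caller
    return await embed_texts(text)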


def batched_method(max_size: int = 16, timeout_s: float = 0.001):
    """
    Method variant of `batched`: batches concurrent async calls into lists of up to
    `max_size` elements, flushing a batch when it is full or after `timeout_s` seconds,
    whichever comes first.
    Callers pass a single element; the wrapped coroutine method receives `self` and a
    list of elements, and must return a list of results of the same length.
    Inspired by Ray's @serve.batch decorator.
    """

    def decorator_factory(func):
        # Shared state for all callers of the decorated method (shared across instances).
        queue = asyncio.Queue(maxsize=max_size)
        lock = asyncio.Lock()
        output = None
        element_id = 0
        processing = {}
        processed = {}
        batch = []

        @wraps(func)
        async def decorator(self, element):
            nonlocal element_id
            nonlocal batch
            nonlocal output

            # Register this element under a unique id and enqueue it for batching.
            cur_idx = element_id
            processing[cur_idx] = element
            element_id += 1
            await queue.put(cur_idx)

            while cur_idx not in processed:
                start = time.monotonic()
                async with lock:
                    # Drain the queue until the batch is full or the window expires,
                    # yielding to the event loop while waiting for more elements.
                    batch_ids = []
                    while len(batch) < max_size and time.monotonic() - start < timeout_s:
                        try:
                            cur = queue.get_nowait()
                            batch_ids.append(cur)
                            batch.append(processing.pop(cur))
                        except asyncio.QueueEmpty:
                            await asyncio.sleep(0)
                    if batch:  # guard: don't call the wrapped method with an empty batch
                        output = await func(self, batch)
                        batch = []
                        # Fan the results back out to their callers by id.
                        for i, id in enumerate(batch_ids):
                            processed[id] = output[i]
            return processed.pop(cur_idx)

        return decorator

    return decorator_factory
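And a sketch of the method variant, again hypothetical rather than code from this commit; note that the queue, lock, and batch state live in the decorator's closure, so they are shared across all instances of the class:

class Embedder:
    @batched_method(max_size=8)
    async def predict(self, inputs: list) -> list:
        # hypothetical wrapped method: receives the accumulated batch
        return [s.upper() for s in inputs]

async def handle_request(embedder: Embedder, text: str) -> str:
    # each caller passes one element; concurrent calls share a batch
    return await embedder.predict(text)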