dynamic batching
parent 1280c6aeea
commit 4b15607fd7
3 changed files with 103 additions and 6 deletions
@@ -31,11 +31,11 @@ services:
     build:
       context: ../machine-learning
       dockerfile: Dockerfile
-    command: python main.py
+    entrypoint: ../entrypoint.sh
     ports:
       - 3003:3003
     volumes:
-      - ../machine-learning/app:/usr/src/app
+      - ../machine-learning:/usr/src
       - model-cache:/cache
     env_file:
       - .env

@@ -31,7 +31,7 @@ app = FastAPI()
 def init_state() -> None:
     app.state.last_called = None
     app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0)
-    app.state.thread_pool = ThreadPoolExecutor()
+    app.state.thread_pool = ThreadPoolExecutor(max_workers=os.cpu_count() or 4)


 @app.on_event("startup")
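
The only functional change above is sizing the thread pool to the CPU count. As context, here is a minimal sketch, not part of this diff, of how such a pool is typically consumed from async code via run_in_executor so blocking inference work does not stall the event loop (predict and handle_request are made-up names):

```python
import asyncio
import os
from concurrent.futures import ThreadPoolExecutor

thread_pool = ThreadPoolExecutor(max_workers=os.cpu_count() or 4)


def predict(texts: list[str]) -> list[int]:
    # Stand-in for a blocking, CPU-bound inference call.
    return [len(t) for t in texts]


async def handle_request(texts: list[str]) -> list[int]:
    # Offload the blocking call to the pool so the event loop stays responsive.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(thread_pool, predict, texts)


print(asyncio.run(handle_request(["a", "bb"])))  # [1, 2]
```
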
@@ -1,3 +1,6 @@
+import asyncio
+import time
+from functools import wraps
 from typing import Any

 from aiocache.backends.memory import SimpleMemoryCache

@@ -47,9 +50,9 @@ class ModelCache:
         """

         key = self.cache.build_key(model_name, model_type.value)
-        async with OptimisticLock(self.cache, key) as lock:
-            model = await self.cache.get(key)
-            if model is None:
+        model = await self.cache.get(key)
+        if model is None:
+            async with OptimisticLock(self.cache, key) as lock:
                 model = InferenceModel.from_model_type(model_type, model_name, **model_kwargs)
                 await lock.cas(model, ttl=self.ttl)
         return model
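
With this reordering, a cache hit returns without ever touching the optimistic lock; the lock is only taken on a miss, immediately before the expensive model construction. A small self-contained sketch of the same check-then-lock shape using aiocache directly (the key and the stand-in load are illustrative, not taken from the commit):

```python
import asyncio

from aiocache.backends.memory import SimpleMemoryCache
from aiocache.lock import OptimisticLock

cache = SimpleMemoryCache()


async def get_or_load(key: str) -> str:
    value = await cache.get(key)
    if value is None:
        # Pay for the lock and the expensive load only on a miss; cas() writes
        # the value back unless another coroutine got there first.
        async with OptimisticLock(cache, key) as lock:
            value = f"loaded:{key}"  # stand-in for InferenceModel construction
            await lock.cas(value, ttl=60)
    return value


print(asyncio.run(get_or_load("clip")))  # loaded:clip
```
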
@@ -95,3 +98,97 @@ class RevalidationPlugin(BasePlugin):
             key = client.build_key(key, namespace)
         if val is not None and key in client._handlers:
             await client.expire(key, client.ttl)
+
+
+def batched(max_size: int = 16, timeout_s: float = 0.01):
+    """
+    Batches async calls into lists until the list reaches length `max_size` or `timeout_s` seconds pass, whichever comes first.
+    Calls should pass an element as their only argument.
+    Callables should take a list as their only argument and return a list of the same length.
+    Inspired by Ray's @serve.batch decorator.
+    """
+
+    def decorator_factory(func):
+        queue = asyncio.Queue(maxsize=max_size)
+        lock = asyncio.Lock()
+        output = None
+        element_id = 0
+        processing = {}
+        processed = {}
+        batch = []
+
+        @wraps(func)
+        async def decorator(element):
+            nonlocal element_id
+            nonlocal batch
+            nonlocal output
+
+            cur_idx = element_id
+            processing[cur_idx] = element
+            element_id += 1
+            await queue.put(cur_idx)
+            while cur_idx not in processed:
+                start = time.monotonic()
+                async with lock:
+                    batch_ids = []
+                    while len(batch) < max_size and time.monotonic() - start < timeout_s:
+                        try:
+                            cur = queue.get_nowait()
+                            batch_ids.append(cur)
+                            batch.append(processing.pop(cur))
+                        except asyncio.QueueEmpty:
+                            await asyncio.sleep(0)
+                    output = await func(batch)
+                    batch = []
+                    for i, id in enumerate(batch_ids):
+                        processed[id] = output[i]
+            return processed.pop(cur_idx)
+        return decorator
+    return decorator_factory
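
A usage sketch for the decorator above (embed_all and its body are hypothetical; Python 3.10+ is assumed so the Queue and Lock created at decoration time bind lazily to the running loop). Concurrent callers each pass a single element, the wrapper accumulates them into a batch, and every caller receives only its own result:

```python
import asyncio


@batched(max_size=8, timeout_s=0.005)
async def embed_all(inputs: list[str]) -> list[int]:
    # Receives the accumulated batch and must return one result per element.
    return [len(x) for x in inputs]


async def main() -> None:
    # Ten concurrent single-element calls are served by a few batched calls.
    results = await asyncio.gather(*(embed_all(f"item-{i}") for i in range(10)))
    print(results)  # one result per caller, in call order


asyncio.run(main())
```
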
+
+
+def batched_method(max_size: int = 16, timeout_s: float = 0.001):
+    """
+    Batches async calls into lists until the list reaches length `max_size` or `timeout_s` seconds pass, whichever comes first.
+    Calls should pass an element as their only argument.
+    Callables should take a list as their only argument and return a list of the same length.
+    Inspired by Ray's @serve.batch decorator.
+    """
+
+    def decorator_factory(func):
+        queue = asyncio.Queue(maxsize=max_size)
+        lock = asyncio.Lock()
+        output = None
+        element_id = 0
+        processing = {}
+        processed = {}
+        batch = []
+
+        @wraps(func)
+        async def decorator(self, element):
+            nonlocal element_id
+            nonlocal batch
+            nonlocal output
+
+            cur_idx = element_id
+            processing[cur_idx] = element
+            element_id += 1
+            await queue.put(cur_idx)
+            while cur_idx not in processed:
+                start = time.monotonic()
+                async with lock:
+                    batch_ids = []
+                    while len(batch) < max_size and time.monotonic() - start < timeout_s:
+                        try:
+                            cur = queue.get_nowait()
+                            batch_ids.append(cur)
+                            batch.append(processing.pop(cur))
+                        except asyncio.QueueEmpty:
+                            await asyncio.sleep(0)
+                    output = await func(self, batch)
+                    batch = []
+                    for i, id in enumerate(batch_ids):
+                        processed[id] = output[i]
+            return processed.pop(cur_idx)
+        return decorator
+    return decorator_factory
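
The method variant simply threads self through to the wrapped callable. A sketch of where it is meant to sit (Embedder and its predict body are hypothetical):

```python
import asyncio


class Embedder:
    @batched_method(max_size=4, timeout_s=0.002)
    async def predict(self, inputs: list[str]) -> list[int]:
        # Sees the whole accumulated batch, returns one result per element.
        return [len(x) for x in inputs]


async def demo() -> None:
    embedder = Embedder()
    results = await asyncio.gather(*(embedder.predict(w) for w in ("a", "bb", "ccc")))
    print(results)  # [1, 2, 3]


asyncio.run(demo())
```

Note that the batching state lives in the decorator's closure, so it is shared by every instance of the class, and the wrapped function receives the self of whichever caller happens to flush the batch.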