Disable request batching by default (max_batch_size defaults to 1; the batcher is only used when max_batch_size > 1)

This commit is contained in:
mertalev 2023-10-21 23:27:30 -04:00
parent 57506aa1fe
commit bdf8c9f1a9
No known key found for this signature in database
GPG key ID: 9181CD92C0A1C5E3
2 changed files with 8 additions and 3 deletions

View file

@ -21,7 +21,7 @@ class Settings(BaseSettings):
request_threads: int = os.cpu_count() or 4
model_inter_op_threads: int = 1
model_intra_op_threads: int = 2
max_batch_size: int = 8
max_batch_size: int = 1
batch_timeout_s: float = 0.005
class Config:

View file

@ -76,8 +76,13 @@ async def predict(
model = await load(await app.state.model_cache.get(model_name, model_type, **kwargs))
model.configure(**kwargs)
batcher: Batcher = app.state.model_batcher.get(model_name, model_type, **kwargs)
outputs = await batcher.batch_process(element, run, model)
if settings.max_batch_size > 1:
batcher: Batcher = app.state.model_batcher.get(model_name, model_type, **kwargs)
outputs = await batcher.batch_process(element, run, model)
else:
outputs = await run(model, [element])
return ORJSONResponse(outputs)