restart process on inactivity

This commit is contained in:
mertalev 2023-11-12 19:12:58 -05:00
parent 069a32dcdb
commit c6659c7fa3
No known key found for this signature in database
GPG key ID: 9181CD92C0A1C5E3
2 changed files with 38 additions and 2 deletions

View file

@ -13,7 +13,7 @@ from .schemas import ModelType
class Settings(BaseSettings): class Settings(BaseSettings):
cache_folder: str = "/cache" cache_folder: str = "/cache"
model_ttl: int = 0 model_ttl: int = 300
host: str = "0.0.0.0" host: str = "0.0.0.0"
port: int = 3003 port: int = 3003
workers: int = 1 workers: int = 1
@ -21,6 +21,7 @@ class Settings(BaseSettings):
request_threads: int = os.cpu_count() or 4 request_threads: int = os.cpu_count() or 4
model_inter_op_threads: int = 1 model_inter_op_threads: int = 1
model_intra_op_threads: int = 2 model_intra_op_threads: int = 2
shutdown_poll_s: int = 10
class Config: class Config:
env_prefix = "MACHINE_LEARNING_" env_prefix = "MACHINE_LEARNING_"

View file

@ -1,6 +1,10 @@
import asyncio import asyncio
import gc
import os
import sys
import threading import threading
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
import time
from typing import Any from typing import Any
from zipfile import BadZipFile from zipfile import BadZipFile
@ -35,6 +39,9 @@ def init_state() -> None:
# asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code # asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code
app.state.thread_pool = ThreadPoolExecutor(settings.request_threads) if settings.request_threads > 0 else None app.state.thread_pool = ThreadPoolExecutor(settings.request_threads) if settings.request_threads > 0 else None
app.state.locks = {model_type: threading.Lock() for model_type in ModelType} app.state.locks = {model_type: threading.Lock() for model_type in ModelType}
app.state.last_called = None
if settings.model_ttl > 0 and settings.shutdown_poll_s > 0:
asyncio.ensure_future(idle_shutdown_task())
log.info(f"Initialized request thread pool with {settings.request_threads} threads.") log.info(f"Initialized request thread pool with {settings.request_threads} threads.")
@ -81,7 +88,7 @@ async def predict(
async def run(model: InferenceModel, inputs: Any) -> Any:
    """Run ``model.predict(inputs)`` and record when the call happened.

    The prediction runs inline when ``app.state.thread_pool`` is ``None``;
    otherwise it is dispatched to that executor so the blocking call does not
    stall the event loop.

    Args:
        model: Object exposing a blocking ``predict(inputs)`` method.
        inputs: Payload forwarded unchanged to ``model.predict``.

    Returns:
        Whatever ``model.predict(inputs)`` returns.
    """
    # Record activity BEFORE choosing an execution path. Previously the
    # timestamp was only written on the thread-pool path, so with the pool
    # disabled `last_called` stayed None and the inactivity check (which
    # ignores a None timestamp) could never trigger.
    app.state.last_called = time.time()

    if app.state.thread_pool is None:
        return model.predict(inputs)

    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(app.state.thread_pool, model.predict, inputs)
@ -113,3 +120,31 @@ async def load(model: InferenceModel) -> InferenceModel:
else: else:
await loop.run_in_executor(app.state.thread_pool, _load) await loop.run_in_executor(app.state.thread_pool, _load)
return model return model
def predict(model: InferenceModel, inputs: Any) -> Any:
    """Synchronously delegate to ``model.predict`` and hand back its result unchanged."""
    # NOTE(review): the diff context shows an ``async def predict(`` defined
    # earlier in this module; this later definition rebinds the module-level
    # name ``predict`` -- confirm that is intended.
    result = model.predict(inputs)
    return result
async def idle_shutdown_task() -> None:
    """Stop the event loop once no call has been recorded for ``settings.model_ttl`` seconds.

    Polls ``app.state.last_called`` every ``settings.shutdown_poll_s`` seconds.
    Nothing happens while ``last_called`` is ``None`` (no call recorded yet).
    Once the idle threshold is exceeded, every other task on the running loop
    is cancelled, stdout/stderr are silenced, the model cache is cleared and
    the loop is stopped. The coroutine then returns instead of polling again.
    """
    while True:
        log.debug("Checking for inactivity...")
        last_called = app.state.last_called
        if last_called is not None and time.time() - last_called > settings.model_ttl:
            break
        await asyncio.sleep(settings.shutdown_poll_s)

    log.debug("Shutting down due to inactivity.")
    loop = asyncio.get_running_loop()
    this_task = asyncio.current_task()
    for task in asyncio.all_tasks(loop):
        # Task.cancel() only *requests* cancellation; CancelledError is raised
        # inside the cancelled task, never here, so no try/except is needed.
        if task is not this_task:
            task.cancel()

    # Close the real streams and point both names at the null device.
    # NOTE(review): presumably this hides the noise produced by the tasks
    # cancelled above -- confirm. The devnull handle is deliberately left open
    # for the remainder of the process.
    sys.stderr.close()
    sys.stdout.close()
    sys.stdout = sys.stderr = open(os.devnull, "w")

    try:
        await app.state.model_cache.cache.clear()
    except asyncio.CancelledError:
        # Best effort: a cancellation while clearing the cache must not
        # prevent the shutdown from completing.
        pass
    finally:
        # Always stop the loop, even if clearing the cache failed or was
        # cancelled; otherwise the process would linger with every task
        # cancelled and the shutdown sequence would re-run on the next poll.
        gc.collect()
        loop.stop()