diff --git a/machine-learning/app/config.py b/machine-learning/app/config.py index 8870b8c0e..9b89114b0 100644 --- a/machine-learning/app/config.py +++ b/machine-learning/app/config.py @@ -13,7 +13,7 @@ from .schemas import ModelType class Settings(BaseSettings): cache_folder: str = "/cache" - model_ttl: int = 0 + model_ttl: int = 300 host: str = "0.0.0.0" port: int = 3003 workers: int = 1 @@ -21,6 +21,7 @@ class Settings(BaseSettings): request_threads: int = os.cpu_count() or 4 model_inter_op_threads: int = 1 model_intra_op_threads: int = 2 + shutdown_poll_s: int = 10 class Config: env_prefix = "MACHINE_LEARNING_" diff --git a/machine-learning/app/main.py b/machine-learning/app/main.py index 375c14a9e..48de46ee8 100644 --- a/machine-learning/app/main.py +++ b/machine-learning/app/main.py @@ -1,6 +1,10 @@ import asyncio +import gc +import os +import sys import threading from concurrent.futures import ThreadPoolExecutor +import time from typing import Any from zipfile import BadZipFile @@ -35,6 +39,9 @@ def init_state() -> None: # asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code app.state.thread_pool = ThreadPoolExecutor(settings.request_threads) if settings.request_threads > 0 else None app.state.locks = {model_type: threading.Lock() for model_type in ModelType} + app.state.last_called = None + if settings.model_ttl > 0 and settings.shutdown_poll_s > 0: + asyncio.ensure_future(idle_shutdown_task()) log.info(f"Initialized request thread pool with {settings.request_threads} threads.") @@ -81,7 +88,7 @@ async def predict( async def run(model: InferenceModel, inputs: Any) -> Any: if app.state.thread_pool is None: return model.predict(inputs) - + app.state.last_called = time.time() return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict, inputs) @@ -113,3 +120,31 @@ async def load(model: InferenceModel) -> InferenceModel: else: await loop.run_in_executor(app.state.thread_pool, _load) return model + + +def predict(model: InferenceModel, inputs: Any) -> Any: + return model.predict(inputs) + + +async def idle_shutdown_task() -> None: + while True: + log.debug("Checking for inactivity...") + if app.state.last_called is not None and time.time() - app.state.last_called > settings.model_ttl: + log.debug("Shutting down due to inactivity.") + loop = asyncio.get_running_loop() + for task in asyncio.all_tasks(loop): + if task is not asyncio.current_task(): + try: + task.cancel() + except asyncio.CancelledError: + pass + sys.stderr.close() + sys.stdout.close() + sys.stdout = sys.stderr = open(os.devnull, "w") + try: + await app.state.model_cache.cache.clear() + gc.collect() + loop.stop() + except asyncio.CancelledError: + pass + await asyncio.sleep(settings.shutdown_poll_s)