# main.py
  1. import asyncio
  2. import gc
  3. import os
  4. import sys
  5. import threading
  6. import time
  7. from concurrent.futures import ThreadPoolExecutor
  8. from typing import Any
  9. from zipfile import BadZipFile
  10. import orjson
  11. from fastapi import FastAPI, Form, HTTPException, UploadFile
  12. from fastapi.responses import ORJSONResponse
  13. from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile
  14. from starlette.formparsers import MultiPartParser
  15. from app.models.base import InferenceModel
  16. from .config import log, settings
  17. from .models.cache import ModelCache
  18. from .schemas import (
  19. MessageResponse,
  20. ModelType,
  21. TextResponse,
  22. )
# 2**26 bytes = 64 MiB: Starlette keeps smaller multipart bodies in memory.
MultiPartParser.max_file_size = 2**26  # spools to disk if payload is 64 MiB or larger
app = FastAPI()
  25. def init_state() -> None:
  26. app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0)
  27. log.info(
  28. (
  29. "Created in-memory cache with unloading "
  30. f"{f'after {settings.model_ttl}s of inactivity' if settings.model_ttl > 0 else 'disabled'}."
  31. )
  32. )
  33. # asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code
  34. app.state.thread_pool = ThreadPoolExecutor(settings.request_threads) if settings.request_threads > 0 else None
  35. app.state.lock = threading.Lock()
  36. app.state.last_called = None
  37. if settings.model_ttl > 0 and settings.model_ttl_poll_s > 0:
  38. asyncio.ensure_future(idle_shutdown_task())
  39. log.info(f"Initialized request thread pool with {settings.request_threads} threads.")
@app.on_event("startup")
async def startup_event() -> None:
    """Build application state once the server starts.

    NOTE(review): ``@app.on_event`` is deprecated in newer FastAPI releases in
    favor of lifespan handlers — confirm against the pinned FastAPI version.
    """
    init_state()
  43. @app.get("/", response_model=MessageResponse)
  44. async def root() -> dict[str, str]:
  45. return {"message": "Immich ML"}
@app.get("/ping", response_model=TextResponse)
def ping() -> str:
    """Liveness probe; returns a constant string."""
    return "pong"
  49. @app.post("/predict")
  50. async def predict(
  51. model_name: str = Form(alias="modelName"),
  52. model_type: ModelType = Form(alias="modelType"),
  53. options: str = Form(default="{}"),
  54. text: str | None = Form(default=None),
  55. image: UploadFile | None = None,
  56. ) -> Any:
  57. if image is not None:
  58. inputs: str | bytes = await image.read()
  59. elif text is not None:
  60. inputs = text
  61. else:
  62. raise HTTPException(400, "Either image or text must be provided")
  63. try:
  64. kwargs = orjson.loads(options)
  65. except orjson.JSONDecodeError:
  66. raise HTTPException(400, f"Invalid options JSON: {options}")
  67. model = await load(await app.state.model_cache.get(model_name, model_type, **kwargs))
  68. model.configure(**kwargs)
  69. outputs = await run(model, inputs)
  70. return ORJSONResponse(outputs)
  71. async def run(model: InferenceModel, inputs: Any) -> Any:
  72. app.state.last_called = time.time()
  73. if app.state.thread_pool is None:
  74. return model.predict(inputs)
  75. return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict, inputs)
  76. async def load(model: InferenceModel) -> InferenceModel:
  77. if model.loaded:
  78. return model
  79. def _load() -> None:
  80. with app.state.lock:
  81. model.load()
  82. loop = asyncio.get_running_loop()
  83. try:
  84. if app.state.thread_pool is None:
  85. model.load()
  86. else:
  87. await loop.run_in_executor(app.state.thread_pool, _load)
  88. return model
  89. except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile):
  90. log.warn(
  91. (
  92. f"Failed to load {model.model_type.replace('_', ' ')} model '{model.model_name}'."
  93. "Clearing cache and retrying."
  94. )
  95. )
  96. model.clear_cache()
  97. if app.state.thread_pool is None:
  98. model.load()
  99. else:
  100. await loop.run_in_executor(app.state.thread_pool, _load)
  101. return model
async def idle_shutdown_task() -> None:
    """Poll for inactivity and stop the process once the model TTL has elapsed.

    Scheduled from init_state() when both ``model_ttl`` and ``model_ttl_poll_s``
    are positive. ``app.state.last_called`` is stamped by run() on each request.
    """
    while True:
        log.debug("Checking for inactivity...")
        if app.state.last_called is not None and time.time() - app.state.last_called > settings.model_ttl:
            log.info("Shutting down due to inactivity.")
            loop = asyncio.get_running_loop()
            # Cancel every other task so the event loop can wind down cleanly.
            for task in asyncio.all_tasks(loop):
                if task is not asyncio.current_task():
                    try:
                        task.cancel()
                    except asyncio.CancelledError:
                        pass
            # Close the real stdio streams and point stdout/stderr at /dev/null,
            # presumably to suppress noise from the tasks cancelled above —
            # TODO(review): confirm this is still needed with the current server.
            sys.stderr.close()
            sys.stdout.close()
            sys.stdout = sys.stderr = open(os.devnull, "w")
            try:
                # Drop cached models, collect garbage, then stop the loop so the
                # process exits; CancelledError here is expected and swallowed.
                await app.state.model_cache.cache.clear()
                gc.collect()
                loop.stop()
            except asyncio.CancelledError:
                pass
        await asyncio.sleep(settings.model_ttl_poll_s)