main.py

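"""FastAPI entrypoint for the Immich machine-learning service: models are
loaded on demand, cached in memory, and used to serve predictions over HTTP."""
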
import asyncio
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from zipfile import BadZipFile

import orjson
from fastapi import FastAPI, Form, HTTPException, UploadFile
from fastapi.responses import ORJSONResponse
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile
from starlette.formparsers import MultiPartParser

from app.models.base import InferenceModel

from .config import log, settings
from .models.cache import ModelCache
from .schemas import (
    MessageResponse,
    ModelType,
    TextResponse,
)

MultiPartParser.max_file_size = 2**24  # spools to disk if payload is 16 MiB or larger

app = FastAPI()
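

# Shared resources live on app.state: the model cache, an optional thread pool
# for blocking inference, and one lock per model type.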
def init_state() -> None:
    app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0)
    log.info(
        "Created in-memory cache with unloading "
        f"{f'after {settings.model_ttl}s of inactivity' if settings.model_ttl > 0 else 'disabled'}."
    )
    # asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code
    app.state.thread_pool = ThreadPoolExecutor(settings.request_threads) if settings.request_threads > 0 else None
    app.state.locks = {model_type: threading.Lock() for model_type in ModelType}
    log.info(f"Initialized request thread pool with {settings.request_threads} threads.")
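

# Build app state at startup so the cache, pool, and locks exist before the
# first request arrives.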
@app.on_event("startup")
async def startup_event() -> None:
    init_state()


@app.get("/", response_model=MessageResponse)
async def root() -> dict[str, str]:
    return {"message": "Immich ML"}


@app.get("/ping", response_model=TextResponse)
def ping() -> str:
    return "pong"
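

# Single inference endpoint: accepts either an uploaded image or a text field,
# along with the model name/type and a JSON string of model options.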
@app.post("/predict")
async def predict(
    model_name: str = Form(alias="modelName"),
    model_type: ModelType = Form(alias="modelType"),
    options: str = Form(default="{}"),
    text: str | None = Form(default=None),
    image: UploadFile | None = None,
) -> Any:
    if image is not None:
        inputs: str | bytes = await image.read()
    elif text is not None:
        inputs = text
    else:
        raise HTTPException(400, "Either image or text must be provided")
    try:
        kwargs = orjson.loads(options)
    except orjson.JSONDecodeError:
        raise HTTPException(400, f"Invalid options JSON: {options}")

    model = await load(await app.state.model_cache.get(model_name, model_type, **kwargs))
    model.configure(**kwargs)
    outputs = await run(model, inputs)
    return ORJSONResponse(outputs)
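

# Offload blocking inference to the thread pool when one is configured;
# otherwise run it inline.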
async def run(model: InferenceModel, inputs: Any) -> Any:
    if app.state.thread_pool is None:
        return model.predict(inputs)
    return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict, inputs)
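

# Load the model if needed, retrying once with a cleared cache when the files
# on disk are corrupt or missing.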
async def load(model: InferenceModel) -> InferenceModel:
    if model.loaded:
        return model

    def _load() -> None:
        # Only one load per model type at a time; pool threads can race, while
        # the pool-less path runs on the event loop and needs no lock.
        with app.state.locks[model.model_type]:
            model.load()

    loop = asyncio.get_running_loop()
    try:
        if app.state.thread_pool is None:
            model.load()
        else:
            await loop.run_in_executor(app.state.thread_pool, _load)
        return model
    except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile):
        log.warning(
            f"Failed to load {model.model_type.replace('_', ' ')} model '{model.model_name}'. "
            "Clearing cache and retrying."
        )
        model.clear_cache()
        if app.state.thread_pool is None:
            model.load()
        else:
            await loop.run_in_executor(app.state.thread_pool, _load)
        return model