main.py
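"""Entrypoint for the Immich machine-learning service.

Exposes a FastAPI app whose /predict endpoint loads inference models through a
TTL-based cache and runs them in a thread pool.
"""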

import asyncio
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Any

import orjson
import uvicorn
from fastapi import FastAPI, Form, HTTPException, UploadFile
from fastapi.responses import ORJSONResponse
from starlette.formparsers import MultiPartParser

from app.models.base import InferenceModel

from .config import settings
from .models.cache import ModelCache
from .schemas import (
    MessageResponse,
    ModelType,
    TextResponse,
)

MultiPartParser.max_file_size = 2**24  # spools to disk if payload is 16 MiB or larger

app = FastAPI()


def init_state() -> None:
    app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0)
    # model inference is blocking, so it runs in a thread pool to keep the event loop responsive
    app.state.thread_pool = ThreadPoolExecutor(settings.request_threads)


@app.on_event("startup")
async def startup_event() -> None:
    init_state()

@app.get("/", response_model=MessageResponse)
async def root() -> dict[str, str]:
    return {"message": "Immich ML"}


@app.get("/ping", response_model=TextResponse)
def ping() -> str:
    return "pong"

@app.post("/predict")
async def predict(
    model_name: str = Form(alias="modelName"),
    model_type: ModelType = Form(alias="modelType"),
    options: str = Form(default="{}"),
    text: str | None = Form(default=None),
    image: UploadFile | None = None,
) -> Any:
    if image is not None:
        inputs: str | bytes = await image.read()
    elif text is not None:
        inputs = text
    else:
        raise HTTPException(400, "Either image or text must be provided")

    model: InferenceModel = await app.state.model_cache.get(model_name, model_type, **orjson.loads(options))
    outputs = await run(model, inputs)
    return ORJSONResponse(outputs)


async def run(model: InferenceModel, inputs: Any) -> Any:
    return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict, inputs)
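
# Example request against /predict (a sketch: the model name and the "clip"
# ModelType value are assumptions, not defined in this file):
#
#   curl -X POST http://<host>:<port>/predict \
#     -F modelName=ViT-B-32__openai \
#     -F modelType=clip \
#     -F text="a photo of a dog"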

if __name__ == "__main__":
    is_dev = os.getenv("NODE_ENV") == "development"
    uvicorn.run(
        "app.main:app",
        host=settings.host,
        port=settings.port,
        reload=is_dev,
        workers=settings.workers,
    )
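
# Launch sketch (assumes the repository root is the working directory so that
# the "app" package resolves; NODE_ENV=development enables auto-reload):
#
#   NODE_ENV=development python -m app.main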