main.py

import asyncio
import os
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from typing import Any

import cv2
import numpy as np
import uvicorn
from fastapi import Body, Depends, FastAPI
from PIL import Image

from app.models.base import InferenceModel

from .config import settings
from .models.cache import ModelCache
from .schemas import (
    EmbeddingResponse,
    FaceResponse,
    MessageResponse,
    ModelType,
    TagResponse,
    TextModelRequest,
    TextResponse,
)

app = FastAPI()


def init_state() -> None:
    app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0)
    # Model inference is blocking and CPU-bound; running it on the event loop would
    # stall every request, so blocking calls are dispatched to a thread pool instead
    app.state.thread_pool = ThreadPoolExecutor(settings.request_threads)


async def load_models() -> None:
    models: list[tuple[str, ModelType, dict[str, Any]]] = [
        (settings.classification_model, ModelType.IMAGE_CLASSIFICATION, {}),
        (settings.clip_image_model, ModelType.CLIP, {"mode": "vision"}),
        (settings.clip_text_model, ModelType.CLIP, {"mode": "text"}),
        (settings.facial_recognition_model, ModelType.FACIAL_RECOGNITION, {}),
    ]

    # Warm the model cache up front; with eager_startup enabled the models are
    # fully loaded now rather than lazily on first request
    for model_name, model_type, model_kwargs in models:
        await app.state.model_cache.get(model_name, model_type, eager=settings.eager_startup, **model_kwargs)


@app.on_event("startup")
async def startup_event() -> None:
    init_state()
    await load_models()


@app.on_event("shutdown")
async def shutdown_event() -> None:
    # shutdown() waits for in-flight work to finish by default (wait=True)
    app.state.thread_pool.shutdown()


def dep_pil_image(byte_image: bytes = Body(...)) -> Image.Image:
    # Decode the raw request body (sent as bytes) into a PIL image
    return Image.open(BytesIO(byte_image))


def dep_cv_image(byte_image: bytes = Body(...)) -> np.ndarray[int, np.dtype[Any]]:
    # Decode the raw request body into an OpenCV image (BGR channel order)
    byte_image_np = np.frombuffer(byte_image, np.uint8)
    return cv2.imdecode(byte_image_np, cv2.IMREAD_COLOR)
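

# cv2.imdecode returns None for bytes it cannot decode; a stricter variant of this
# dependency could fail fast with a 400 (a sketch only, assuming fastapi.HTTPException
# is imported; "dep_cv_image_strict" is an illustrative name, not part of the app):
#
#     def dep_cv_image_strict(byte_image: bytes = Body(...)) -> np.ndarray[int, np.dtype[Any]]:
#         image = cv2.imdecode(np.frombuffer(byte_image, np.uint8), cv2.IMREAD_COLOR)
#         if image is None:
#             raise HTTPException(status_code=400, detail="Invalid image")
#         return image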


@app.get("/", response_model=MessageResponse)
async def root() -> dict[str, str]:
    return {"message": "Immich ML"}


@app.get("/ping", response_model=TextResponse)
def ping() -> str:
    return "pong"


@app.post(
    "/image-classifier/tag-image",
    response_model=TagResponse,
    status_code=200,
)
async def image_classification(
    image: Image.Image = Depends(dep_pil_image),
) -> list[str]:
    model = await app.state.model_cache.get(settings.classification_model, ModelType.IMAGE_CLASSIFICATION)
    labels = await predict(model, image)
    return labels


@app.post(
    "/sentence-transformer/encode-image",
    response_model=EmbeddingResponse,
    status_code=200,
)
async def clip_encode_image(
    image: Image.Image = Depends(dep_pil_image),
) -> list[float]:
    model = await app.state.model_cache.get(settings.clip_image_model, ModelType.CLIP, mode="vision")
    embedding = await predict(model, image)
    return embedding


@app.post(
    "/sentence-transformer/encode-text",
    response_model=EmbeddingResponse,
    status_code=200,
)
async def clip_encode_text(payload: TextModelRequest) -> list[float]:
    model = await app.state.model_cache.get(settings.clip_text_model, ModelType.CLIP, mode="text")
    embedding = await predict(model, payload.text)
    return embedding
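

# Example JSON body for /sentence-transformer/encode-text (a sketch; the "text"
# field name is taken from payload.text above, the value is illustrative):
#
#     {"text": "a photo of a dog"}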


@app.post(
    "/facial-recognition/detect-faces",
    response_model=FaceResponse,
    status_code=200,
)
async def facial_recognition(
    image: cv2.Mat = Depends(dep_cv_image),
) -> list[dict[str, Any]]:
    model = await app.state.model_cache.get(settings.facial_recognition_model, ModelType.FACIAL_RECOGNITION)
    faces = await predict(model, image)
    return faces


async def predict(model: InferenceModel, inputs: Any) -> Any:
    # Run the blocking model call in the shared thread pool so the event loop
    # stays free to serve other requests
    return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict, inputs)


if __name__ == "__main__":
    is_dev = os.getenv("NODE_ENV") == "development"
    uvicorn.run(
        "app.main:app",
        host=settings.host,
        port=settings.port,
        reload=is_dev,
        workers=settings.workers,
    )
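

# A minimal client sketch: each image endpoint takes the raw image bytes as the
# request body. The URL below assumes the service listens on localhost:3003; the
# real host and port come from settings.
#
#     import httpx
#
#     with open("photo.jpg", "rb") as f:
#         resp = httpx.post("http://localhost:3003/image-classifier/tag-image", content=f.read())
#     print(resp.json())  # list of tags, per TagResponse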