main.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
import os
from io import BytesIO
from typing import Any

import cv2
import numpy as np
import uvicorn
from fastapi import Body, Depends, FastAPI, HTTPException
from PIL import Image, UnidentifiedImageError

from .config import settings
from .models.base import InferenceModel
from .models.cache import ModelCache
from .schemas import (
    EmbeddingResponse,
    FaceResponse,
    MessageResponse,
    ModelType,
    TagResponse,
    TextModelRequest,
    TextResponse,
)
  21. app = FastAPI()
  22. def init_state() -> None:
  23. app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=True)
  24. async def load_models() -> None:
  25. models = [
  26. (settings.classification_model, ModelType.IMAGE_CLASSIFICATION),
  27. (settings.clip_image_model, ModelType.CLIP),
  28. (settings.clip_text_model, ModelType.CLIP),
  29. (settings.facial_recognition_model, ModelType.FACIAL_RECOGNITION),
  30. ]
  31. # Get all models
  32. for model_name, model_type in models:
  33. if settings.eager_startup:
  34. await app.state.model_cache.get(model_name, model_type)
  35. else:
  36. InferenceModel.from_model_type(model_type, model_name)
  37. @app.on_event("startup")
  38. async def startup_event() -> None:
  39. init_state()
  40. await load_models()
  41. def dep_pil_image(byte_image: bytes = Body(...)) -> Image.Image:
  42. return Image.open(BytesIO(byte_image))
  43. def dep_cv_image(byte_image: bytes = Body(...)) -> cv2.Mat:
  44. byte_image_np = np.frombuffer(byte_image, np.uint8)
  45. return cv2.imdecode(byte_image_np, cv2.IMREAD_COLOR)
  46. @app.get("/", response_model=MessageResponse)
  47. async def root() -> dict[str, str]:
  48. return {"message": "Immich ML"}
  49. @app.get("/ping", response_model=TextResponse)
  50. def ping() -> str:
  51. return "pong"
  52. @app.post(
  53. "/image-classifier/tag-image",
  54. response_model=TagResponse,
  55. status_code=200,
  56. )
  57. async def image_classification(
  58. image: Image.Image = Depends(dep_pil_image),
  59. ) -> list[str]:
  60. model = await app.state.model_cache.get(settings.classification_model, ModelType.IMAGE_CLASSIFICATION)
  61. labels = model.predict(image)
  62. return labels
  63. @app.post(
  64. "/sentence-transformer/encode-image",
  65. response_model=EmbeddingResponse,
  66. status_code=200,
  67. )
  68. async def clip_encode_image(
  69. image: Image.Image = Depends(dep_pil_image),
  70. ) -> list[float]:
  71. model = await app.state.model_cache.get(settings.clip_image_model, ModelType.CLIP)
  72. embedding = model.predict(image)
  73. return embedding
  74. @app.post(
  75. "/sentence-transformer/encode-text",
  76. response_model=EmbeddingResponse,
  77. status_code=200,
  78. )
  79. async def clip_encode_text(payload: TextModelRequest) -> list[float]:
  80. model = await app.state.model_cache.get(settings.clip_text_model, ModelType.CLIP)
  81. embedding = model.predict(payload.text)
  82. return embedding
  83. @app.post(
  84. "/facial-recognition/detect-faces",
  85. response_model=FaceResponse,
  86. status_code=200,
  87. )
  88. async def facial_recognition(
  89. image: cv2.Mat = Depends(dep_cv_image),
  90. ) -> list[dict[str, Any]]:
  91. model = await app.state.model_cache.get(settings.facial_recognition_model, ModelType.FACIAL_RECOGNITION)
  92. faces = model.predict(image)
  93. return faces
  94. if __name__ == "__main__":
  95. is_dev = os.getenv("NODE_ENV") == "development"
  96. uvicorn.run(
  97. "app.main:app",
  98. host=settings.host,
  99. port=settings.port,
  100. reload=is_dev,
  101. workers=settings.workers,
  102. )