main.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  import asyncio
  import os
  from io import BytesIO
  from typing import Any

  import cv2
  import numpy as np
  import uvicorn
  from fastapi import Body, Depends, FastAPI, HTTPException
  from PIL import Image, UnidentifiedImageError

  from .config import settings
  from .models.base import InferenceModel
  from .models.cache import ModelCache
  from .schemas import (
      EmbeddingResponse,
      FaceResponse,
      MessageResponse,
      ModelType,
      TagResponse,
      TextModelRequest,
      TextResponse,
  )
  app = FastAPI()


  @app.on_event("startup")
  async def startup_event() -> None:
      """Create the model cache and fetch every configured model.

      Stores on ``app.state``:
        * ``model_cache`` - a ``ModelCache`` using ``settings.model_ttl``;
        * ``clip_vision_type`` / ``clip_text_type`` - the ``ModelType`` used
          for CLIP image and text requests. Both are ``ModelType.CLIP`` when
          the image and text settings name the same model; otherwise they are
          ``CLIP_VISION`` and ``CLIP_TEXT``.

      With ``settings.eager_startup`` each model is loaded through the cache.
      Otherwise it is only instantiated via ``InferenceModel.from_model_type``
      and the instance is discarded. Presumably this prepares or downloads the
      model files ahead of the first request (TODO confirm in models.base).
      """
      app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=True)
      same_clip = settings.clip_image_model == settings.clip_text_model
      app.state.clip_vision_type = ModelType.CLIP if same_clip else ModelType.CLIP_VISION
      app.state.clip_text_type = ModelType.CLIP if same_clip else ModelType.CLIP_TEXT
      models = [
          (settings.classification_model, ModelType.IMAGE_CLASSIFICATION),
          (settings.clip_image_model, app.state.clip_vision_type),
          (settings.clip_text_model, app.state.clip_text_type),
          (settings.facial_recognition_model, ModelType.FACIAL_RECOGNITION),
      ]
      # With a shared CLIP model the image and text entries are the identical
      # (name, ModelType.CLIP) pair. dict.fromkeys drops repeated pairs while
      # keeping first-seen order, so each model is fetched only once.
      for model_name, model_type in dict.fromkeys(models):
          if settings.eager_startup:
              await app.state.model_cache.get(model_name, model_type)
          else:
              InferenceModel.from_model_type(model_type, model_name)
  def dep_pil_image(byte_image: bytes = Body(...)) -> Image.Image:
      """FastAPI dependency: open the raw request body as a PIL image.

      ``Image.open`` only identifies the format here. Pixel data is decoded
      lazily, when the image is first processed.

      Raises:
          HTTPException: 400 if PIL cannot identify the bytes as an image.
              Previously this surfaced as an unhandled error (HTTP 500).
      """
      try:
          return Image.open(BytesIO(byte_image))
      except UnidentifiedImageError as err:
          raise HTTPException(status_code=400, detail="Invalid image") from err
  def dep_cv_image(byte_image: bytes = Body(...)) -> cv2.Mat:
      """FastAPI dependency: decode the raw request body with OpenCV.

      The bytes are decoded with ``cv2.IMREAD_COLOR``.

      Raises:
          HTTPException: 400 if the body is empty or cannot be decoded.
      """
      # An empty body is rejected up front rather than handing OpenCV an empty
      # buffer. For undecodable data, cv2.imdecode returns None instead of
      # raising, so that case is checked explicitly. Otherwise None would flow
      # on to the face model.
      image = (
          cv2.imdecode(np.frombuffer(byte_image, np.uint8), cv2.IMREAD_COLOR)
          if byte_image
          else None
      )
      if image is None:
          raise HTTPException(status_code=400, detail="Invalid image")
      return image
  @app.get("/", response_model=MessageResponse)
  async def root() -> dict[str, str]:
      # Service banner; the mapping is serialized through MessageResponse.
      return dict(message="Immich ML")
  @app.get(
      "/ping",
      response_model=TextResponse,
  )
  def ping() -> str:
      # Liveness check: always answers with the fixed text below.
      reply = "pong"
      return reply
  @app.post(
      "/image-classifier/tag-image",
      response_model=TagResponse,
      status_code=200,
  )
  async def image_classification(
      image: Image.Image = Depends(dep_pil_image),
  ) -> list[str]:
      # Tag the uploaded image with the configured classification model and
      # return the resulting labels.
      model = await app.state.model_cache.get(
          settings.classification_model, ModelType.IMAGE_CLASSIFICATION
      )
      # model.predict is synchronous. Calling it directly in this coroutine
      # would block the event loop, and with it every other request, until
      # inference finishes, so it runs in a worker thread instead.
      # NOTE(review): assumes predict is safe to call concurrently from worker
      # threads; confirm for each model backend.
      labels = await asyncio.to_thread(model.predict, image)
      return labels
  @app.post(
      "/sentence-transformer/encode-image",
      response_model=EmbeddingResponse,
      status_code=200,
  )
  async def clip_encode_image(
      image: Image.Image = Depends(dep_pil_image),
  ) -> list[float]:
      # Embed the uploaded image with the configured CLIP image model. The
      # model type was chosen at startup as CLIP or CLIP_VISION.
      model = await app.state.model_cache.get(
          settings.clip_image_model, app.state.clip_vision_type
      )
      # model.predict is synchronous; run it in a worker thread so inference
      # does not block the event loop for other requests.
      # NOTE(review): assumes predict is safe to call concurrently from worker
      # threads; confirm for each model backend.
      embedding = await asyncio.to_thread(model.predict, image)
      return embedding
  @app.post(
      "/sentence-transformer/encode-text",
      response_model=EmbeddingResponse,
      status_code=200,
  )
  async def clip_encode_text(payload: TextModelRequest) -> list[float]:
      # Embed payload.text with the configured CLIP text model. The model
      # type was chosen at startup as CLIP or CLIP_TEXT.
      model = await app.state.model_cache.get(
          settings.clip_text_model, app.state.clip_text_type
      )
      # model.predict is synchronous; run it in a worker thread so inference
      # does not block the event loop for other requests.
      # NOTE(review): assumes predict is safe to call concurrently from worker
      # threads; confirm for each model backend.
      embedding = await asyncio.to_thread(model.predict, payload.text)
      return embedding
  @app.post(
      "/facial-recognition/detect-faces",
      response_model=FaceResponse,
      status_code=200,
  )
  async def facial_recognition(
      image: cv2.Mat = Depends(dep_cv_image),
  ) -> list[dict[str, Any]]:
      # Run the configured facial-recognition model on the decoded OpenCV
      # image and return the list of dicts it produces (presumably one per
      # detected face; see the model implementation).
      model = await app.state.model_cache.get(
          settings.facial_recognition_model, ModelType.FACIAL_RECOGNITION
      )
      # model.predict is synchronous; run it in a worker thread so inference
      # does not block the event loop for other requests.
      # NOTE(review): assumes predict is safe to call concurrently from worker
      # threads; confirm for each model backend.
      faces = await asyncio.to_thread(model.predict, image)
      return faces
  if __name__ == "__main__":
      # Direct launch: hot reload is enabled only when NODE_ENV is
      # "development". The app is passed as an import string, which uvicorn
      # requires for reload/workers.
      uvicorn.run(
          "app.main:app",
          host=settings.host,
          port=settings.port,
          reload=os.getenv("NODE_ENV") == "development",
          workers=settings.workers,
      )