main.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. import os
  2. import io
  3. from typing import Any
  4. from cache import ModelCache
  5. from schemas import (
  6. EmbeddingResponse,
  7. FaceResponse,
  8. TagResponse,
  9. MessageResponse,
  10. TextModelRequest,
  11. TextResponse,
  12. )
  13. import uvicorn
  14. from PIL import Image
  15. from fastapi import FastAPI, HTTPException, Depends, Body
  16. from models import get_model, run_classification, run_facial_recognition
  17. from config import settings
  18. _model_cache = None
  19. app = FastAPI()
  20. @app.on_event("startup")
  21. async def startup_event() -> None:
  22. global _model_cache
  23. _model_cache = ModelCache(ttl=settings.model_ttl, revalidate=True)
  24. models = [
  25. (settings.classification_model, "image-classification"),
  26. (settings.clip_image_model, "clip"),
  27. (settings.clip_text_model, "clip"),
  28. (settings.facial_recognition_model, "facial-recognition"),
  29. ]
  30. # Get all models
  31. for model_name, model_type in models:
  32. if settings.eager_startup:
  33. await _model_cache.get_cached_model(model_name, model_type)
  34. else:
  35. get_model(model_name, model_type)
  36. def dep_model_cache():
  37. if _model_cache is None:
  38. raise HTTPException(status_code=500, detail="Unable to load model.")
  39. def dep_input_image(image: bytes = Body(...)) -> Image:
  40. return Image.open(io.BytesIO(image))
  41. @app.get("/", response_model=MessageResponse)
  42. async def root() -> dict[str, str]:
  43. return {"message": "Immich ML"}
  44. @app.get("/ping", response_model=TextResponse)
  45. def ping() -> str:
  46. return "pong"
  47. @app.post(
  48. "/image-classifier/tag-image",
  49. response_model=TagResponse,
  50. status_code=200,
  51. dependencies=[Depends(dep_model_cache)],
  52. )
  53. async def image_classification(
  54. image: Image = Depends(dep_input_image)
  55. ) -> list[str]:
  56. try:
  57. model = await _model_cache.get_cached_model(
  58. settings.classification_model, "image-classification"
  59. )
  60. labels = run_classification(model, image, settings.min_tag_score)
  61. except Exception as ex:
  62. raise HTTPException(status_code=500, detail=str(ex))
  63. else:
  64. return labels
  65. @app.post(
  66. "/sentence-transformer/encode-image",
  67. response_model=EmbeddingResponse,
  68. status_code=200,
  69. dependencies=[Depends(dep_model_cache)],
  70. )
  71. async def clip_encode_image(
  72. image: Image = Depends(dep_input_image)
  73. ) -> list[float]:
  74. model = await _model_cache.get_cached_model(settings.clip_image_model, "clip")
  75. embedding = model.encode(image).tolist()
  76. return embedding
  77. @app.post(
  78. "/sentence-transformer/encode-text",
  79. response_model=EmbeddingResponse,
  80. status_code=200,
  81. dependencies=[Depends(dep_model_cache)],
  82. )
  83. async def clip_encode_text(
  84. payload: TextModelRequest
  85. ) -> list[float]:
  86. model = await _model_cache.get_cached_model(settings.clip_text_model, "clip")
  87. embedding = model.encode(payload.text).tolist()
  88. return embedding
  89. @app.post(
  90. "/facial-recognition/detect-faces",
  91. response_model=FaceResponse,
  92. status_code=200,
  93. dependencies=[Depends(dep_model_cache)],
  94. )
  95. async def facial_recognition(
  96. image: bytes = Body(...),
  97. ) -> list[dict[str, Any]]:
  98. model = await _model_cache.get_cached_model(
  99. settings.facial_recognition_model, "facial-recognition"
  100. )
  101. faces = run_facial_recognition(model, image)
  102. return faces
  103. if __name__ == "__main__":
  104. is_dev = os.getenv("NODE_ENV") == "development"
  105. uvicorn.run(
  106. "main:app",
  107. host=settings.host,
  108. port=settings.port,
  109. reload=is_dev,
  110. workers=settings.workers,
  111. )