main.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. import os
  2. from typing import Any
  3. from cache import ModelCache
  4. from schemas import (
  5. EmbeddingResponse,
  6. FaceResponse,
  7. TagResponse,
  8. MessageResponse,
  9. TextModelRequest,
  10. TextResponse,
  11. VisionModelRequest,
  12. )
  13. import uvicorn
  14. from PIL import Image
  15. from fastapi import FastAPI, HTTPException
  16. from models import get_model, run_classification, run_facial_recognition
# --- Service configuration, read once at import time from environment -------
# Model identifiers; defaults are the upstream model choices.
classification_model = os.getenv(
    "MACHINE_LEARNING_CLASSIFICATION_MODEL", "microsoft/resnet-50"
)
clip_image_model = os.getenv("MACHINE_LEARNING_CLIP_IMAGE_MODEL", "clip-ViT-B-32")
clip_text_model = os.getenv("MACHINE_LEARNING_CLIP_TEXT_MODEL", "clip-ViT-B-32")
facial_recognition_model = os.getenv(
    "MACHINE_LEARNING_FACIAL_RECOGNITION_MODEL", "buffalo_l"
)
# Minimum classifier confidence for a label to be returned by tag-image.
min_tag_score = float(os.getenv("MACHINE_LEARNING_MIN_TAG_SCORE", 0.9))
# NOTE(review): comparison is case-sensitive — "True"/"TRUE" disable eager
# startup; confirm that is intended.
eager_startup = (
    os.getenv("MACHINE_LEARNING_EAGER_STARTUP", "true") == "true"
)  # loads all models at startup
# Cache time-to-live passed to ModelCache (presumably seconds — confirm in
# the ModelCache implementation).
model_ttl = int(os.getenv("MACHINE_LEARNING_MODEL_TTL", 300))
# Populated by startup_event(); every endpoint returns HTTP 500 while None.
_model_cache = None
app = FastAPI()
  32. @app.on_event("startup")
  33. async def startup_event() -> None:
  34. global _model_cache
  35. _model_cache = ModelCache(ttl=model_ttl, revalidate=True)
  36. models = [
  37. (classification_model, "image-classification"),
  38. (clip_image_model, "clip"),
  39. (clip_text_model, "clip"),
  40. (facial_recognition_model, "facial-recognition"),
  41. ]
  42. # Get all models
  43. for model_name, model_type in models:
  44. if eager_startup:
  45. await _model_cache.get_cached_model(model_name, model_type)
  46. else:
  47. get_model(model_name, model_type)
  48. @app.get("/", response_model=MessageResponse)
  49. async def root() -> dict[str, str]:
  50. return {"message": "Immich ML"}
  51. @app.get("/ping", response_model=TextResponse)
  52. def ping() -> str:
  53. return "pong"
  54. @app.post("/image-classifier/tag-image", response_model=TagResponse, status_code=200)
  55. async def image_classification(payload: VisionModelRequest) -> list[str]:
  56. if _model_cache is None:
  57. raise HTTPException(status_code=500, detail="Unable to load model.")
  58. model = await _model_cache.get_cached_model(
  59. classification_model, "image-classification"
  60. )
  61. labels = run_classification(model, payload.image_path, min_tag_score)
  62. return labels
  63. @app.post(
  64. "/sentence-transformer/encode-image",
  65. response_model=EmbeddingResponse,
  66. status_code=200,
  67. )
  68. async def clip_encode_image(payload: VisionModelRequest) -> list[float]:
  69. if _model_cache is None:
  70. raise HTTPException(status_code=500, detail="Unable to load model.")
  71. model = await _model_cache.get_cached_model(clip_image_model, "clip")
  72. image = Image.open(payload.image_path)
  73. embedding = model.encode(image).tolist()
  74. return embedding
  75. @app.post(
  76. "/sentence-transformer/encode-text",
  77. response_model=EmbeddingResponse,
  78. status_code=200,
  79. )
  80. async def clip_encode_text(payload: TextModelRequest) -> list[float]:
  81. if _model_cache is None:
  82. raise HTTPException(status_code=500, detail="Unable to load model.")
  83. model = await _model_cache.get_cached_model(clip_text_model, "clip")
  84. embedding = model.encode(payload.text).tolist()
  85. return embedding
  86. @app.post(
  87. "/facial-recognition/detect-faces", response_model=FaceResponse, status_code=200
  88. )
  89. async def facial_recognition(payload: VisionModelRequest) -> list[dict[str, Any]]:
  90. if _model_cache is None:
  91. raise HTTPException(status_code=500, detail="Unable to load model.")
  92. model = await _model_cache.get_cached_model(
  93. facial_recognition_model, "facial-recognition"
  94. )
  95. faces = run_facial_recognition(model, payload.image_path)
  96. return faces
  97. if __name__ == "__main__":
  98. host = os.getenv("MACHINE_LEARNING_HOST", "0.0.0.0")
  99. port = int(os.getenv("MACHINE_LEARNING_PORT", 3003))
  100. is_dev = os.getenv("NODE_ENV") == "development"
  101. uvicorn.run("main:app", host=host, port=port, reload=is_dev, workers=1)