chore(ml): move to fastAPI (#2336)

Alex committed 2 years ago
commit 7e965cb6d4
3 changed files with 75 additions and 72 deletions
  1. machine-learning/Dockerfile (+9 -8)
  2. machine-learning/gunicorn.conf.py (+0 -29)
  3. machine-learning/src/main.py (+66 -35)
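
In short: each Flask route that pulled fields out of request.json becomes a FastAPI path operation with a Pydantic-validated body, and the gunicorn entrypoint gives way to uvicorn. A minimal sketch of that pattern, with the endpoint path and field name taken from the diff below and the handler body purely illustrative:

    # Sketch of the Flask -> FastAPI pattern only; not the actual application code.
    from fastapi import FastAPI
    from pydantic import BaseModel

    app = FastAPI()


    class MlRequestBody(BaseModel):
        thumbnailPath: str  # read as request.json['thumbnailPath'] under Flask


    @app.post("/image-classifier/tag-image", status_code=200)
    def image_classification(payload: MlRequestBody):
        # FastAPI validates the JSON body against MlRequestBody before calling the handler.
        return {"thumbnailPath": payload.thumbnailPath}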

+ 9 - 8
machine-learning/Dockerfile

@@ -1,14 +1,15 @@
 FROM python:3.10 as builder

 ENV PYTHONDONTWRITEBYTECODE=1 \
-    PYTHONUNBUFFERED=1 \
-    PIP_NO_CACHE_DIR=true
+  PYTHONUNBUFFERED=1 \
+  PIP_NO_CACHE_DIR=true

 RUN python -m venv /opt/venv
 RUN /opt/venv/bin/pip install --pre torch  -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
-RUN /opt/venv/bin/pip install transformers tqdm numpy scikit-learn scipy nltk sentencepiece flask Pillow gunicorn
+RUN /opt/venv/bin/pip install transformers tqdm numpy scikit-learn scipy nltk sentencepiece fastapi Pillow uvicorn[standard]
 RUN /opt/venv/bin/pip install --no-deps sentence-transformers

+
 FROM python:3.10-slim

 ENV NODE_ENV=production
@@ -16,12 +17,12 @@ ENV NODE_ENV=production
 COPY --from=builder /opt/venv /opt/venv

 ENV TRANSFORMERS_CACHE=/cache \
-    PYTHONDONTWRITEBYTECODE=1 \
-    PYTHONUNBUFFERED=1 \
-    PATH="/opt/venv/bin:$PATH"
+  PYTHONDONTWRITEBYTECODE=1 \
+  PYTHONUNBUFFERED=1 \
+  PATH="/opt/venv/bin:$PATH"

 WORKDIR /usr/src/app

 COPY . .
-
-CMD ["gunicorn", "src.main:server"]
+ENV PYTHONPATH=`pwd`
+CMD ["python", "main.py"]

+ 0 - 29
machine-learning/gunicorn.conf.py

@@ -1,29 +0,0 @@
-"""
-Gunicorn configuration options.
-https://docs.gunicorn.org/en/stable/settings.html
-"""
-import os
-
-
-# Set the bind address based on the env
-port = os.getenv("MACHINE_LEARNING_PORT") or "3003"
-listen_ip = os.getenv("MACHINE_LEARNING_IP") or "0.0.0.0"
-bind = [f"{listen_ip}:{port}"]
-
-# Preload the Flask app / models etc. before starting the server
-preload_app = True
-
-# Logging settings - log to stdout and set log level
-accesslog = "-"
-loglevel = os.getenv("MACHINE_LEARNING_LOG_LEVEL") or "info"
-
-# Worker settings
-# ----------------------
-# It is important these are chosen carefully as per
-# https://pythonspeed.com/articles/gunicorn-in-docker/
-# Otherwise we get workers failing to respond to heartbeat checks,
-# especially as requests take a long time to complete.
-workers = 2
-threads = 4
-worker_tmp_dir = "/dev/shm"
-timeout = 60
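
Note: the removed gunicorn settings map only loosely onto uvicorn. bind becomes host and port, loglevel becomes log_level, and the worker/thread tuning is dropped in favour of the single worker pinned in src/main.py. A hedged sketch of that mapping (illustrative, not part of the commit; it reuses the env var names the old config read, and the "src.main:app" module string is an assumption, the commit itself passes "main:app"):

    # Illustrative only; not part of the commit.
    import os
    import uvicorn

    host = os.getenv("MACHINE_LEARNING_IP", "0.0.0.0")           # was listen_ip / bind
    port = int(os.getenv("MACHINE_LEARNING_PORT", "3003"))       # was port / bind
    log_level = os.getenv("MACHINE_LEARNING_LOG_LEVEL", "info")  # was loglevel

    if __name__ == "__main__":
        # workers/threads/worker_tmp_dir/timeout have no direct equivalents here;
        # the new entrypoint runs a single uvicorn worker instead.
        uvicorn.run("src.main:app", host=host, port=port,
                    log_level=log_level, workers=1)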

+ 66 - 35
machine-learning/src/main.py

@@ -1,58 +1,77 @@
-import os
-from flask import Flask, request
 from transformers import pipeline
 from sentence_transformers import SentenceTransformer, util
 from PIL import Image
+from fastapi import FastAPI
+import uvicorn
+import os
+from pydantic import BaseModel
+
+
+class MlRequestBody(BaseModel):
+    thumbnailPath: str
+
+
+class ClipRequestBody(BaseModel):
+    text: str
+

 is_dev = os.getenv('NODE_ENV') == 'development'
 server_port = os.getenv('MACHINE_LEARNING_PORT', 3003)
 server_host = os.getenv('MACHINE_LEARNING_HOST', '0.0.0.0')

-classification_model = os.getenv('MACHINE_LEARNING_CLASSIFICATION_MODEL', 'microsoft/resnet-50')
+app = FastAPI()
+
+"""
+Model Initialization
+"""
+classification_model = os.getenv(
+    'MACHINE_LEARNING_CLASSIFICATION_MODEL', 'microsoft/resnet-50')
 object_model = os.getenv('MACHINE_LEARNING_OBJECT_MODEL', 'hustvl/yolos-tiny')
-clip_image_model = os.getenv('MACHINE_LEARNING_CLIP_IMAGE_MODEL', 'clip-ViT-B-32')
-clip_text_model = os.getenv('MACHINE_LEARNING_CLIP_TEXT_MODEL', 'clip-ViT-B-32')
+clip_image_model = os.getenv(
+    'MACHINE_LEARNING_CLIP_IMAGE_MODEL', 'clip-ViT-B-32')
+clip_text_model = os.getenv(
+    'MACHINE_LEARNING_CLIP_TEXT_MODEL', 'clip-ViT-B-32')

 _model_cache = {}
-def _get_model(model, task=None):
-  global _model_cache
-  key = '|'.join([model, str(task)])
-  if key not in _model_cache:
-    if task:
-      _model_cache[key] = pipeline(model=model, task=task)
-    else:
-      _model_cache[key] = SentenceTransformer(model)
-  return _model_cache[key]
-
-server = Flask(__name__)
-
-@server.route("/ping")
+
+
+@app.get("/")
+async def root():
+    return {"message": "Immich ML"}
+
+
+@app.get("/ping")
 def ping():
     return "pong"

-@server.route("/object-detection/detect-object", methods=['POST'])
-def object_detection():
+
+@app.post("/object-detection/detect-object", status_code=200)
+def object_detection(payload: MlRequestBody):
     model = _get_model(object_model, 'object-detection')
-    assetPath = request.json['thumbnailPath']
-    return run_engine(model, assetPath), 200
+    assetPath = payload.thumbnailPath
+    return run_engine(model, assetPath)

-@server.route("/image-classifier/tag-image", methods=['POST'])
-def image_classification():
+
+@app.post("/image-classifier/tag-image", status_code=200)
+def image_classification(payload: MlRequestBody):
     model = _get_model(classification_model, 'image-classification')
-    assetPath = request.json['thumbnailPath']
-    return run_engine(model, assetPath), 200
+    assetPath = payload.thumbnailPath
+    return run_engine(model, assetPath)
+

-@server.route("/sentence-transformer/encode-image", methods=['POST'])
-def clip_encode_image():
+@app.post("/sentence-transformer/encode-image", status_code=200)
+def clip_encode_image(payload: MlRequestBody):
     model = _get_model(clip_image_model)
-    assetPath = request.json['thumbnailPath']
-    return model.encode(Image.open(assetPath)).tolist(), 200
+    assetPath = payload.thumbnailPath
+    return model.encode(Image.open(assetPath)).tolist()

-@server.route("/sentence-transformer/encode-text", methods=['POST'])
-def clip_encode_text():
+
+@app.post("/sentence-transformer/encode-text", status_code=200)
+def clip_encode_text(payload: ClipRequestBody):
     model = _get_model(clip_text_model)
-    text = request.json['text']
-    return model.encode(text).tolist(), 200
+    text = payload.text
+    return model.encode(text).tolist()
+

 def run_engine(engine, path):
     result = []
@@ -69,5 +88,17 @@ def run_engine(engine, path):
     return result


+def _get_model(model, task=None):
+    global _model_cache
+    key = '|'.join([model, str(task)])
+    if key not in _model_cache:
+        if task:
+            _model_cache[key] = pipeline(model=model, task=task)
+        else:
+            _model_cache[key] = SentenceTransformer(model)
+    return _model_cache[key]
+
+
 if __name__ == "__main__":
-    server.run(debug=is_dev, host=server_host, port=server_port)
+    uvicorn.run("main:app", host=server_host,
+                port=int(server_port), reload=is_dev, workers=1)
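
For reference, the new endpoints take JSON bodies matching the Pydantic models above. A hedged client example, assuming the service is reachable on the default port 3003; the thumbnail path is hypothetical:

    # Hypothetical client calls; endpoint paths, field names and the default port 3003
    # come from the diff, the thumbnail location is made up for illustration.
    import requests

    base = "http://localhost:3003"

    print(requests.get(f"{base}/ping").json())  # -> "pong"

    tags = requests.post(
        f"{base}/image-classifier/tag-image",
        json={"thumbnailPath": "/path/to/some-thumbnail.jpg"},  # hypothetical path
    ).json()

    embedding = requests.post(
        f"{base}/sentence-transformer/encode-text",
        json={"text": "a photo of a dog"},
    ).json()

    print(tags, len(embedding))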