From a2f5674bbbb7d404aff4c91c8ac2c1842282f4f4 Mon Sep 17 00:00:00 2001 From: Mert <101130780+mertalev@users.noreply.github.com> Date: Sat, 24 Jun 2023 23:18:09 -0400 Subject: [PATCH] refactor(ml): modularization and styling (#2835) * basic refactor and styling * removed batching * module entrypoint * removed unused imports * model superclass, model cache now in app state * fixed cache dir and enforced abstract method --------- Co-authored-by: Alex Tran --- machine-learning/Dockerfile | 4 +- machine-learning/app/__init__.py | 0 machine-learning/app/config.py | 11 +- machine-learning/app/main.py | 102 +++++++-------- machine-learning/app/models.py | 119 ------------------ machine-learning/app/models/__init__.py | 3 + machine-learning/app/models/base.py | 52 ++++++++ machine-learning/app/{ => models}/cache.py | 26 ++-- machine-learning/app/models/clip.py | 37 ++++++ .../app/models/facial_recognition.py | 59 +++++++++ .../app/models/image_classification.py | 40 ++++++ machine-learning/app/schemas.py | 10 ++ 12 files changed, 281 insertions(+), 182 deletions(-) create mode 100644 machine-learning/app/__init__.py delete mode 100644 machine-learning/app/models.py create mode 100644 machine-learning/app/models/__init__.py create mode 100644 machine-learning/app/models/base.py rename machine-learning/app/{ => models}/cache.py (80%) create mode 100644 machine-learning/app/models/clip.py create mode 100644 machine-learning/app/models/facial_recognition.py create mode 100644 machine-learning/app/models/image_classification.py diff --git a/machine-learning/Dockerfile b/machine-learning/Dockerfile index 25d18b59a..e5d5f22c6 100644 --- a/machine-learning/Dockerfile +++ b/machine-learning/Dockerfile @@ -21,8 +21,8 @@ ENV NODE_ENV=production \ PYTHONDONTWRITEBYTECODE=1 \ PYTHONUNBUFFERED=1 \ PATH="/opt/venv/bin:$PATH" \ - PYTHONPATH=`pwd` + PYTHONPATH=/usr/src COPY --from=builder /opt/venv /opt/venv COPY app . -ENTRYPOINT ["python", "main.py"] +ENTRYPOINT ["python", "-m", "app.main"] diff --git a/machine-learning/app/__init__.py b/machine-learning/app/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/machine-learning/app/config.py b/machine-learning/app/config.py index f6ce64e75..3845cf1f3 100644 --- a/machine-learning/app/config.py +++ b/machine-learning/app/config.py @@ -1,5 +1,10 @@ +from pathlib import Path + from pydantic import BaseSettings +from .schemas import ModelType + + class Settings(BaseSettings): cache_folder: str = "/cache" classification_model: str = "microsoft/resnet-50" @@ -15,8 +20,12 @@ class Settings(BaseSettings): min_face_score: float = 0.7 class Config(BaseSettings.Config): - env_prefix = 'MACHINE_LEARNING_' + env_prefix = "MACHINE_LEARNING_" case_sensitive = False +def get_cache_dir(model_name: str, model_type: ModelType) -> Path: + return Path(settings.cache_folder, model_type.value, model_name) + + settings = Settings() diff --git a/machine-learning/app/main.py b/machine-learning/app/main.py index cb65caaa4..49436977b 100644 --- a/machine-learning/app/main.py +++ b/machine-learning/app/main.py @@ -1,52 +1,58 @@ import os -import io +from io import BytesIO from typing import Any -from cache import ModelCache -from schemas import ( +import cv2 +import numpy as np +import uvicorn +from fastapi import Body, Depends, FastAPI +from PIL import Image + +from .config import settings +from .models.base import InferenceModel +from .models.cache import ModelCache +from .schemas import ( EmbeddingResponse, FaceResponse, - TagResponse, MessageResponse, + ModelType, + TagResponse, TextModelRequest, TextResponse, ) -import uvicorn -from PIL import Image -from fastapi import FastAPI, HTTPException, Depends, Body -from models import get_model, run_classification, run_facial_recognition -from config import settings - -_model_cache = None app = FastAPI() @app.on_event("startup") async def startup_event() -> None: - global _model_cache - _model_cache = ModelCache(ttl=settings.model_ttl, revalidate=True) + app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=True) + same_clip = settings.clip_image_model == settings.clip_text_model + app.state.clip_vision_type = ModelType.CLIP if same_clip else ModelType.CLIP_VISION + app.state.clip_text_type = ModelType.CLIP if same_clip else ModelType.CLIP_TEXT models = [ - (settings.classification_model, "image-classification"), - (settings.clip_image_model, "clip"), - (settings.clip_text_model, "clip"), - (settings.facial_recognition_model, "facial-recognition"), + (settings.classification_model, ModelType.IMAGE_CLASSIFICATION), + (settings.clip_image_model, app.state.clip_vision_type), + (settings.clip_text_model, app.state.clip_text_type), + (settings.facial_recognition_model, ModelType.FACIAL_RECOGNITION), ] # Get all models for model_name, model_type in models: if settings.eager_startup: - await _model_cache.get_cached_model(model_name, model_type) + await app.state.model_cache.get(model_name, model_type) else: - get_model(model_name, model_type) + InferenceModel.from_model_type(model_type, model_name) -def dep_model_cache(): - if _model_cache is None: - raise HTTPException(status_code=500, detail="Unable to load model.") +def dep_pil_image(byte_image: bytes = Body(...)) -> Image.Image: + return Image.open(BytesIO(byte_image)) + + +def dep_cv_image(byte_image: bytes = Body(...)) -> cv2.Mat: + byte_image_np = np.frombuffer(byte_image, np.uint8) + return cv2.imdecode(byte_image_np, cv2.IMREAD_COLOR) -def dep_input_image(image: bytes = Body(...)) -> Image: - return Image.open(io.BytesIO(image)) @app.get("/", response_model=MessageResponse) async def root() -> dict[str, str]: @@ -62,33 +68,29 @@ def ping() -> str: "/image-classifier/tag-image", response_model=TagResponse, status_code=200, - dependencies=[Depends(dep_model_cache)], ) async def image_classification( - image: Image = Depends(dep_input_image) + image: Image.Image = Depends(dep_pil_image), ) -> list[str]: - try: - model = await _model_cache.get_cached_model( - settings.classification_model, "image-classification" - ) - labels = run_classification(model, image, settings.min_tag_score) - except Exception as ex: - raise HTTPException(status_code=500, detail=str(ex)) - else: - return labels + model = await app.state.model_cache.get( + settings.classification_model, ModelType.IMAGE_CLASSIFICATION + ) + labels = model.predict(image) + return labels @app.post( "/sentence-transformer/encode-image", response_model=EmbeddingResponse, status_code=200, - dependencies=[Depends(dep_model_cache)], ) async def clip_encode_image( - image: Image = Depends(dep_input_image) + image: Image.Image = Depends(dep_pil_image), ) -> list[float]: - model = await _model_cache.get_cached_model(settings.clip_image_model, "clip") - embedding = model.encode(image).tolist() + model = await app.state.model_cache.get( + settings.clip_image_model, app.state.clip_vision_type + ) + embedding = model.predict(image) return embedding @@ -96,13 +98,12 @@ async def clip_encode_image( "/sentence-transformer/encode-text", response_model=EmbeddingResponse, status_code=200, - dependencies=[Depends(dep_model_cache)], ) -async def clip_encode_text( - payload: TextModelRequest -) -> list[float]: - model = await _model_cache.get_cached_model(settings.clip_text_model, "clip") - embedding = model.encode(payload.text).tolist() +async def clip_encode_text(payload: TextModelRequest) -> list[float]: + model = await app.state.model_cache.get( + settings.clip_text_model, app.state.clip_text_type + ) + embedding = model.predict(payload.text) return embedding @@ -110,22 +111,21 @@ async def clip_encode_text( "/facial-recognition/detect-faces", response_model=FaceResponse, status_code=200, - dependencies=[Depends(dep_model_cache)], ) async def facial_recognition( - image: bytes = Body(...), + image: cv2.Mat = Depends(dep_cv_image), ) -> list[dict[str, Any]]: - model = await _model_cache.get_cached_model( - settings.facial_recognition_model, "facial-recognition" + model = await app.state.model_cache.get( + settings.facial_recognition_model, ModelType.FACIAL_RECOGNITION ) - faces = run_facial_recognition(model, image) + faces = model.predict(image) return faces if __name__ == "__main__": is_dev = os.getenv("NODE_ENV") == "development" uvicorn.run( - "main:app", + "app.main:app", host=settings.host, port=settings.port, reload=is_dev, diff --git a/machine-learning/app/models.py b/machine-learning/app/models.py deleted file mode 100644 index 04bd3b70b..000000000 --- a/machine-learning/app/models.py +++ /dev/null @@ -1,119 +0,0 @@ -import torch -from insightface.app import FaceAnalysis -from pathlib import Path - -from transformers import pipeline, Pipeline -from sentence_transformers import SentenceTransformer -from typing import Any, BinaryIO -import cv2 as cv -import numpy as np -from PIL import Image -from config import settings - -device = "cuda" if torch.cuda.is_available() else "cpu" - - -def get_model(model_name: str, model_type: str, **model_kwargs): - """ - Instantiates the specified model. - - Args: - model_name: Name of model in the model hub used for the task. - model_type: Model type or task, which determines which model zoo is used. - `facial-recognition` uses Insightface, while all other models use the HF Model Hub. - - Options: - `image-classification`, `clip`,`facial-recognition`, `tokenizer`, `processor` - - Returns: - model: The requested model. - """ - - cache_dir = _get_cache_dir(model_name, model_type) - match model_type: - case "facial-recognition": - model = _load_facial_recognition( - model_name, cache_dir=cache_dir, **model_kwargs - ) - case "clip": - model = SentenceTransformer( - model_name, cache_folder=cache_dir, **model_kwargs - ) - case _: - model = pipeline( - model_type, - model_name, - model_kwargs={"cache_dir": cache_dir, **model_kwargs}, - ) - - return model - - -def run_classification( - model: Pipeline, image: Image, min_score: float | None = None -): - predictions: list[dict[str, Any]] = model(image) # type: ignore - result = { - tag - for pred in predictions - for tag in pred["label"].split(", ") - if min_score is None or pred["score"] >= min_score - } - - return list(result) - - -def run_facial_recognition( - model: FaceAnalysis, image: bytes -) -> list[dict[str, Any]]: - file_bytes = np.frombuffer(image, dtype=np.uint8) - img = cv.imdecode(file_bytes, cv.IMREAD_COLOR) - height, width, _ = img.shape - results = [] - faces = model.get(img) - - for face in faces: - x1, y1, x2, y2 = face.bbox - - results.append( - { - "imageWidth": width, - "imageHeight": height, - "boundingBox": { - "x1": round(x1), - "y1": round(y1), - "x2": round(x2), - "y2": round(y2), - }, - "score": face.det_score.item(), - "embedding": face.normed_embedding.tolist(), - } - ) - return results - - -def _load_facial_recognition( - model_name: str, - min_face_score: float | None = None, - cache_dir: Path | str | None = None, - **model_kwargs, -): - if cache_dir is None: - cache_dir = _get_cache_dir(model_name, "facial-recognition") - if isinstance(cache_dir, Path): - cache_dir = cache_dir.as_posix() - if min_face_score is None: - min_face_score = settings.min_face_score - - model = FaceAnalysis( - name=model_name, - root=cache_dir, - allowed_modules=["detection", "recognition"], - **model_kwargs, - ) - model.prepare(ctx_id=0, det_thresh=min_face_score, det_size=(640, 640)) - return model - - -def _get_cache_dir(model_name: str, model_type: str) -> Path: - return Path(settings.cache_folder, device, model_type, model_name) diff --git a/machine-learning/app/models/__init__.py b/machine-learning/app/models/__init__.py new file mode 100644 index 000000000..b64613505 --- /dev/null +++ b/machine-learning/app/models/__init__.py @@ -0,0 +1,3 @@ +from .clip import CLIPSTTextEncoder, CLIPSTVisionEncoder +from .facial_recognition import FaceRecognizer +from .image_classification import ImageClassifier diff --git a/machine-learning/app/models/base.py b/machine-learning/app/models/base.py new file mode 100644 index 000000000..122f3627e --- /dev/null +++ b/machine-learning/app/models/base.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from abc import abstractmethod, ABC +from pathlib import Path +from typing import Any + +from ..config import get_cache_dir +from ..schemas import ModelType + + +class InferenceModel(ABC): + _model_type: ModelType + + def __init__( + self, + model_name: str, + cache_dir: Path | None = None, + ): + self.model_name = model_name + self._cache_dir = ( + cache_dir + if cache_dir is not None + else get_cache_dir(model_name, self.model_type) + ) + + @abstractmethod + def predict(self, inputs: Any) -> Any: + ... + + @property + def model_type(self) -> ModelType: + return self._model_type + + @property + def cache_dir(self) -> Path: + return self._cache_dir + + @cache_dir.setter + def cache_dir(self, cache_dir: Path): + self._cache_dir = cache_dir + + @classmethod + def from_model_type( + cls, model_type: ModelType, model_name, **model_kwargs + ) -> InferenceModel: + subclasses = { + subclass._model_type: subclass for subclass in cls.__subclasses__() + } + if model_type not in subclasses: + raise ValueError(f"Unsupported model type: {model_type}") + + return subclasses[model_type](model_name, **model_kwargs) diff --git a/machine-learning/app/cache.py b/machine-learning/app/models/cache.py similarity index 80% rename from machine-learning/app/cache.py rename to machine-learning/app/models/cache.py index 8ff5a435d..bcf959fe7 100644 --- a/machine-learning/app/cache.py +++ b/machine-learning/app/models/cache.py @@ -1,8 +1,11 @@ -from aiocache.plugins import TimingPlugin, BasePlugin +import asyncio + from aiocache.backends.memory import SimpleMemoryCache from aiocache.lock import OptimisticLock -from typing import Any -from models import get_model +from aiocache.plugins import BasePlugin, TimingPlugin + +from ..schemas import ModelType +from .base import InferenceModel class ModelCache: @@ -10,7 +13,7 @@ class ModelCache: def __init__( self, - ttl: int | None = None, + ttl: float | None = None, revalidate: bool = False, timeout: int | None = None, profiling: bool = False, @@ -35,9 +38,9 @@ class ModelCache: ttl=ttl, timeout=timeout, plugins=plugins, namespace=None ) - async def get_cached_model( - self, model_name: str, model_type: str, **model_kwargs - ) -> Any: + async def get( + self, model_name: str, model_type: ModelType, **model_kwargs + ) -> InferenceModel: """ Args: model_name: Name of model in the model hub used for the task. @@ -47,11 +50,16 @@ class ModelCache: model: The requested model. """ - key = self.cache.build_key(model_name, model_type) + key = self.cache.build_key(model_name, model_type.value) model = await self.cache.get(key) if model is None: async with OptimisticLock(self.cache, key) as lock: - model = get_model(model_name, model_type, **model_kwargs) + model = await asyncio.get_running_loop().run_in_executor( + None, + lambda: InferenceModel.from_model_type( + model_type, model_name, **model_kwargs + ), + ) await lock.cas(model, ttl=self.ttl) return model diff --git a/machine-learning/app/models/clip.py b/machine-learning/app/models/clip.py new file mode 100644 index 000000000..51731f790 --- /dev/null +++ b/machine-learning/app/models/clip.py @@ -0,0 +1,37 @@ +from pathlib import Path + +from PIL.Image import Image +from sentence_transformers import SentenceTransformer + +from ..schemas import ModelType +from .base import InferenceModel + + +class CLIPSTEncoder(InferenceModel): + _model_type = ModelType.CLIP + + def __init__( + self, + model_name: str, + cache_dir: Path | None = None, + **model_kwargs, + ): + super().__init__(model_name, cache_dir) + self.model = SentenceTransformer( + self.model_name, + cache_folder=self.cache_dir.as_posix(), + **model_kwargs, + ) + + def predict(self, image_or_text: Image | str) -> list[float]: + return self.model.encode(image_or_text).tolist() + + +# stubs to allow different behavior between the two in the future +# and handle loading different image and text clip models +class CLIPSTVisionEncoder(CLIPSTEncoder): + _model_type = ModelType.CLIP_VISION + + +class CLIPSTTextEncoder(CLIPSTEncoder): + _model_type = ModelType.CLIP_TEXT diff --git a/machine-learning/app/models/facial_recognition.py b/machine-learning/app/models/facial_recognition.py new file mode 100644 index 000000000..ff993c172 --- /dev/null +++ b/machine-learning/app/models/facial_recognition.py @@ -0,0 +1,59 @@ +from pathlib import Path +from typing import Any + +import cv2 +from insightface.app import FaceAnalysis + +from ..config import settings +from ..schemas import ModelType +from .base import InferenceModel + + +class FaceRecognizer(InferenceModel): + _model_type = ModelType.FACIAL_RECOGNITION + + def __init__( + self, + model_name: str, + min_score: float = settings.min_face_score, + cache_dir: Path | None = None, + **model_kwargs, + ): + super().__init__(model_name, cache_dir) + self.min_score = min_score + model = FaceAnalysis( + name=self.model_name, + root=self.cache_dir.as_posix(), + allowed_modules=["detection", "recognition"], + **model_kwargs, + ) + model.prepare( + ctx_id=0, + det_thresh=self.min_score, + det_size=(640, 640), + ) + self.model = model + + def predict(self, image: cv2.Mat) -> list[dict[str, Any]]: + height, width, _ = image.shape + results = [] + faces = self.model.get(image) + + for face in faces: + x1, y1, x2, y2 = face.bbox + + results.append( + { + "imageWidth": width, + "imageHeight": height, + "boundingBox": { + "x1": round(x1), + "y1": round(y1), + "x2": round(x2), + "y2": round(y2), + }, + "score": face.det_score.item(), + "embedding": face.normed_embedding.tolist(), + } + ) + return results diff --git a/machine-learning/app/models/image_classification.py b/machine-learning/app/models/image_classification.py new file mode 100644 index 000000000..adb55181d --- /dev/null +++ b/machine-learning/app/models/image_classification.py @@ -0,0 +1,40 @@ +from pathlib import Path + +from PIL.Image import Image +from transformers.pipelines import pipeline + +from ..config import settings +from ..schemas import ModelType +from .base import InferenceModel + + +class ImageClassifier(InferenceModel): + _model_type = ModelType.IMAGE_CLASSIFICATION + + def __init__( + self, + model_name: str, + min_score: float = settings.min_tag_score, + cache_dir: Path | None = None, + **model_kwargs, + ): + super().__init__(model_name, cache_dir) + self.min_score = min_score + + self.model = pipeline( + self.model_type.value, + self.model_name, + model_kwargs={"cache_dir": self.cache_dir, **model_kwargs}, + ) + + def predict(self, image: Image) -> list[str]: + predictions = self.model(image) + tags = list( + { + tag + for pred in predictions + for tag in pred["label"].split(", ") + if pred["score"] >= self.min_score + } + ) + return tags diff --git a/machine-learning/app/schemas.py b/machine-learning/app/schemas.py index ed58e4ea3..db6b7b50b 100644 --- a/machine-learning/app/schemas.py +++ b/machine-learning/app/schemas.py @@ -1,3 +1,5 @@ +from enum import Enum + from pydantic import BaseModel @@ -54,3 +56,11 @@ class Face(BaseModel): class FaceResponse(BaseModel): __root__: list[Face] + + +class ModelType(Enum): + IMAGE_CLASSIFICATION = "image-classification" + CLIP = "clip" + CLIP_VISION = "clip-vision" + CLIP_TEXT = "clip-text" + FACIAL_RECOGNITION = "facial-recognition"