feat(ml)!: switch image classification and CLIP models to ONNX (#3809)

Mert, 1 year ago
commit 165b91b068

+ 1 - 0
.github/workflows/test.yml

@@ -171,6 +171,7 @@ jobs:
       - name: Install dependencies
         run: |
           poetry install --with dev
+          poetry run pip install --no-deps -r requirements.txt
       - name: Lint with ruff
         run: |
           poetry run ruff check --format=github app

+ 2 - 1
machine-learning/Dockerfile

@@ -10,8 +10,9 @@ RUN poetry config installer.max-workers 10 && \
 RUN python -m venv /opt/venv
 ENV VIRTUAL_ENV="/opt/venv" PATH="/opt/venv/bin:${PATH}"
 
-COPY poetry.lock pyproject.toml ./
+COPY poetry.lock pyproject.toml requirements.txt ./
 RUN poetry install --sync --no-interaction --no-ansi --no-root --only main
+RUN pip install --no-deps -r requirements.txt
 
 FROM python:3.11.4-slim-bullseye@sha256:91d194f58f50594cda71dcd2e8fdefd90e7ecc57d07823813b67c8521e565dcd
 

+ 10 - 3
machine-learning/app/config.py

@@ -1,3 +1,4 @@
+import os
 from pathlib import Path
 
 from pydantic import BaseSettings
@@ -8,8 +9,8 @@ from .schemas import ModelType
 class Settings(BaseSettings):
     cache_folder: str = "/cache"
     classification_model: str = "microsoft/resnet-50"
-    clip_image_model: str = "clip-ViT-B-32"
-    clip_text_model: str = "clip-ViT-B-32"
+    clip_image_model: str = "ViT-B-32::openai"
+    clip_text_model: str = "ViT-B-32::openai"
     facial_recognition_model: str = "buffalo_l"
     min_tag_score: float = 0.9
     eager_startup: bool = False
@@ -19,14 +20,20 @@ class Settings(BaseSettings):
     workers: int = 1
     min_face_score: float = 0.7
     test_full: bool = False
+    request_threads: int = os.cpu_count() or 4
+    model_inter_op_threads: int = 1
+    model_intra_op_threads: int = 2
 
     class Config:
         env_prefix = "MACHINE_LEARNING_"
         case_sensitive = False
 
 
+_clean_name = str.maketrans(":\\/", "___", ".")
+
+
 def get_cache_dir(model_name: str, model_type: ModelType) -> Path:
-    return Path(settings.cache_folder, model_type.value, model_name)
+    return Path(settings.cache_folder) / model_type.value / model_name.translate(_clean_name)
 
 
 settings = Settings()
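
Model names like "ViT-B-32::openai" contain characters that are awkward in filesystem paths, so get_cache_dir now sanitizes them through the translation table above. A minimal sketch of the resulting layout, assuming the default /cache folder (paths shown are illustrative):

    from pathlib import Path

    # ':', '\' and '/' map to '_', '.' is dropped -- mirrors _clean_name above
    table = str.maketrans(":\\/", "___", ".")
    print(Path("/cache") / "clip" / "ViT-B-32::openai".translate(table))
    # -> /cache/clip/ViT-B-32__openai

The new thread settings can be overridden like any other option via the MACHINE_LEARNING_ prefix, e.g. MACHINE_LEARNING_REQUEST_THREADS.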

+ 29 - 14
machine-learning/app/main.py

@@ -1,4 +1,6 @@
+import asyncio
 import os
+from concurrent.futures import ThreadPoolExecutor
 from io import BytesIO
 from typing import Any
 
@@ -8,6 +10,8 @@ import uvicorn
 from fastapi import Body, Depends, FastAPI
 from PIL import Image
 
+from app.models.base import InferenceModel
+
 from .config import settings
 from .models.cache import ModelCache
 from .schemas import (
@@ -25,19 +29,21 @@ app = FastAPI()
 
 
 def init_state() -> None:
     app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0)
+    # asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code
+    app.state.thread_pool = ThreadPoolExecutor(settings.request_threads)
 
 
 async def load_models() -> None:
-    models = [
-        (settings.classification_model, ModelType.IMAGE_CLASSIFICATION),
-        (settings.clip_image_model, ModelType.CLIP),
-        (settings.clip_text_model, ModelType.CLIP),
-        (settings.facial_recognition_model, ModelType.FACIAL_RECOGNITION),
+    models: list[tuple[str, ModelType, dict[str, Any]]] = [
+        (settings.classification_model, ModelType.IMAGE_CLASSIFICATION, {}),
+        (settings.clip_image_model, ModelType.CLIP, {"mode": "vision"}),
+        (settings.clip_text_model, ModelType.CLIP, {"mode": "text"}),
+        (settings.facial_recognition_model, ModelType.FACIAL_RECOGNITION, {}),
     ]
 
     # Get all models
-    for model_name, model_type in models:
-        await app.state.model_cache.get(model_name, model_type, eager=settings.eager_startup)
+    for model_name, model_type, model_kwargs in models:
+        await app.state.model_cache.get(model_name, model_type, eager=settings.eager_startup, **model_kwargs)
 
 
 @app.on_event("startup")
@@ -46,11 +52,16 @@ async def startup_event() -> None:
     await load_models()
 
 
+@app.on_event("shutdown")
+async def shutdown_event() -> None:
+    app.state.thread_pool.shutdown()
+
+
 def dep_pil_image(byte_image: bytes = Body(...)) -> Image.Image:
     return Image.open(BytesIO(byte_image))
 
 
-def dep_cv_image(byte_image: bytes = Body(...)) -> cv2.Mat:
+def dep_cv_image(byte_image: bytes = Body(...)) -> np.ndarray[int, np.dtype[Any]]:
     byte_image_np = np.frombuffer(byte_image, np.uint8)
     return cv2.imdecode(byte_image_np, cv2.IMREAD_COLOR)
 
@@ -74,7 +85,7 @@ async def image_classification(
     image: Image.Image = Depends(dep_pil_image),
 ) -> list[str]:
     model = await app.state.model_cache.get(settings.classification_model, ModelType.IMAGE_CLASSIFICATION)
-    labels = model.predict(image)
+    labels = await predict(model, image)
     return labels
 
 
@@ -86,8 +97,8 @@ async def image_classification(
 async def clip_encode_image(
     image: Image.Image = Depends(dep_pil_image),
 ) -> list[float]:
-    model = await app.state.model_cache.get(settings.clip_image_model, ModelType.CLIP)
-    embedding = model.predict(image)
+    model = await app.state.model_cache.get(settings.clip_image_model, ModelType.CLIP, mode="vision")
+    embedding = await predict(model, image)
     return embedding
 
 
@@ -97,8 +108,8 @@ async def clip_encode_image(
     status_code=200,
 )
 async def clip_encode_text(payload: TextModelRequest) -> list[float]:
-    model = await app.state.model_cache.get(settings.clip_text_model, ModelType.CLIP)
-    embedding = model.predict(payload.text)
+    model = await app.state.model_cache.get(settings.clip_text_model, ModelType.CLIP, mode="text")
+    embedding = await predict(model, payload.text)
     return embedding
 
 
@@ -111,10 +122,14 @@ async def facial_recognition(
     image: cv2.Mat = Depends(dep_cv_image),
 ) -> list[dict[str, Any]]:
     model = await app.state.model_cache.get(settings.facial_recognition_model, ModelType.FACIAL_RECOGNITION)
-    faces = model.predict(image)
+    faces = await predict(model, image)
     return faces
 
 
+async def predict(model: InferenceModel, inputs: Any) -> Any:
+    return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict, inputs)
+
+
 if __name__ == "__main__":
     is_dev = os.getenv("NODE_ENV") == "development"
     uvicorn.run(
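
The new predict helper keeps the event loop responsive by pushing the blocking model.predict call onto the thread pool created in init_state. A minimal, self-contained sketch of the same pattern, with a placeholder blocking_inference function standing in for a synchronous model call:

    import asyncio
    from concurrent.futures import ThreadPoolExecutor

    pool = ThreadPoolExecutor(4)

    def blocking_inference(x: int) -> int:  # placeholder for a synchronous model call
        return x * x

    async def predict(x: int) -> int:
        # offload the blocking call so the event loop can keep serving requests
        return await asyncio.get_running_loop().run_in_executor(pool, blocking_inference, x)

    print(asyncio.run(predict(3)))  # 9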

+ 1 - 1
machine-learning/app/models/__init__.py

@@ -1,3 +1,3 @@
-from .clip import CLIPSTEncoder
+from .clip import CLIPEncoder
 from .facial_recognition import FaceRecognizer
 from .image_classification import ImageClassifier

+ 37 - 2
machine-learning/app/models/base.py

@@ -1,14 +1,17 @@
 from __future__ import annotations
 
+import os
+import pickle
 from abc import ABC, abstractmethod
 from pathlib import Path
 from shutil import rmtree
 from typing import Any
 from zipfile import BadZipFile
 
+import onnxruntime as ort
 from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf  # type: ignore
 
-from ..config import get_cache_dir
+from ..config import get_cache_dir, settings
 from ..schemas import ModelType
 
 
@@ -16,12 +19,31 @@ class InferenceModel(ABC):
     _model_type: ModelType
 
     def __init__(
-        self, model_name: str, cache_dir: Path | str | None = None, eager: bool = True, **model_kwargs: Any
+        self,
+        model_name: str,
+        cache_dir: Path | str | None = None,
+        eager: bool = True,
+        inter_op_num_threads: int = settings.model_inter_op_threads,
+        intra_op_num_threads: int = settings.model_intra_op_threads,
+        **model_kwargs: Any,
     ) -> None:
         self.model_name = model_name
         self._loaded = False
         self._cache_dir = Path(cache_dir) if cache_dir is not None else get_cache_dir(model_name, self.model_type)
         loader = self.load if eager else self.download
+
+        self.providers = model_kwargs.pop("providers", ["CPUExecutionProvider"])
+        #  don't pre-allocate more memory than needed
+        self.provider_options = model_kwargs.pop(
+            "provider_options", [{"arena_extend_strategy": "kSameAsRequested"}] * len(self.providers)
+        )
+        self.sess_options = PicklableSessionOptions()
+        # avoid thread contention between models
+        if inter_op_num_threads > 1:
+            self.sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
+        self.sess_options.inter_op_num_threads = inter_op_num_threads
+        self.sess_options.intra_op_num_threads = intra_op_num_threads
+
         try:
             loader(**model_kwargs)
         except (OSError, InvalidProtobuf, BadZipFile):
@@ -30,6 +52,7 @@ class InferenceModel(ABC):
 
 
     def download(self, **model_kwargs: Any) -> None:
         if not self.cached:
+            print(f"Downloading {self.model_type.value.replace('_', ' ')} model. This may take a while...")
             self._download(**model_kwargs)
 
     def load(self, **model_kwargs: Any) -> None:
@@ -39,6 +62,7 @@ class InferenceModel(ABC):
 
 
     def predict(self, inputs: Any) -> Any:
         if not self._loaded:
+            print(f"Loading {self.model_type.value.replace('_', ' ')} model...")
             self.load()
         return self._predict(inputs)
 
@@ -89,3 +113,14 @@ class InferenceModel(ABC):
         else:
             self.cache_dir.unlink()
         self.cache_dir.mkdir(parents=True, exist_ok=True)
+
+
+# HF deep copies configs, so we need to make session options picklable
+class PicklableSessionOptions(ort.SessionOptions):
+    def __getstate__(self) -> bytes:
+        return pickle.dumps([(attr, getattr(self, attr)) for attr in dir(self) if not callable(getattr(self, attr))])
+
+    def __setstate__(self, state: Any) -> None:
+        self.__init__()  # type: ignore
+        for attr, val in pickle.loads(state):
+            setattr(self, attr, val)

+ 1 - 1
machine-learning/app/models/cache.py

@@ -46,7 +46,7 @@ class ModelCache:
             model: The requested model.
         """
 
-        key = self.cache.build_key(model_name, model_type.value)
+        key = f"{model_name}{model_type.value}{model_kwargs.get('mode', '')}"
         async with OptimisticLock(self.cache, key) as lock:
             model = await self.cache.get(key)
             if model is None:

+ 127 - 17
machine-learning/app/models/clip.py

@@ -1,31 +1,141 @@
-from typing import Any
+import os
+import zipfile
+from typing import Any, Literal
 
+import onnxruntime as ort
+import torch
+from clip_server.model.clip import BICUBIC, _convert_image_to_rgb
+from clip_server.model.clip_onnx import _MODELS, _S3_BUCKET_V2, CLIPOnnxModel, download_model
+from clip_server.model.pretrained_models import _VISUAL_MODEL_IMAGE_SIZE
+from clip_server.model.tokenization import Tokenizer
 from PIL.Image import Image
-from sentence_transformers import SentenceTransformer
-from sentence_transformers.util import snapshot_download
+from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor
 
 from ..schemas import ModelType
 from .base import InferenceModel
 
+_ST_TO_JINA_MODEL_NAME = {
+    "clip-ViT-B-16": "ViT-B-16::openai",
+    "clip-ViT-B-32": "ViT-B-32::openai",
+    "clip-ViT-B-32-multilingual-v1": "M-CLIP/XLM-Roberta-Large-Vit-B-32",
+    "clip-ViT-L-14": "ViT-L-14::openai",
+}
 
-class CLIPSTEncoder(InferenceModel):
+
+class CLIPEncoder(InferenceModel):
     _model_type = ModelType.CLIP
 
+    def __init__(
+        self,
+        model_name: str,
+        cache_dir: str | None = None,
+        mode: Literal["text", "vision"] | None = None,
+        **model_kwargs: Any,
+    ) -> None:
+        if mode is not None and mode not in ("text", "vision"):
+            raise ValueError(f"Mode must be 'text', 'vision', or omitted; got '{mode}'")
+        if "vit-b" not in model_name.lower():
+            raise ValueError(f"Only ViT-B models are currently supported; got '{model_name}'")
+        self.mode = mode
+        jina_model_name = self._get_jina_model_name(model_name)
+        super().__init__(jina_model_name, cache_dir, **model_kwargs)
+
     def _download(self, **model_kwargs: Any) -> None:
-        repo_id = self.model_name if "/" in self.model_name else f"sentence-transformers/{self.model_name}"
-        snapshot_download(
-            cache_dir=self.cache_dir,
-            repo_id=repo_id,
-            library_name="sentence-transformers",
-            ignore_files=["flax_model.msgpack", "rust_model.ot", "tf_model.h5"],
-        )
+        models: tuple[tuple[str, str], tuple[str, str]] = _MODELS[self.model_name]
+        text_onnx_path = self.cache_dir / "textual.onnx"
+        vision_onnx_path = self.cache_dir / "visual.onnx"
+
+        if not text_onnx_path.is_file():
+            self._download_model(*models[0])
+
+        if not vision_onnx_path.is_file():
+            self._download_model(*models[1])
 
     def _load(self, **model_kwargs: Any) -> None:
-        self.model = SentenceTransformer(
-            self.model_name,
-            cache_folder=self.cache_dir.as_posix(),
-            **model_kwargs,
-        )
+        if self.mode == "text" or self.mode is None:
+            self.text_model = ort.InferenceSession(
+                self.cache_dir / "textual.onnx",
+                sess_options=self.sess_options,
+                providers=self.providers,
+                provider_options=self.provider_options,
+            )
+            self.text_outputs = [output.name for output in self.text_model.get_outputs()]
+            self.tokenizer = Tokenizer(self.model_name)
+
+        if self.mode == "vision" or self.mode is None:
+            self.vision_model = ort.InferenceSession(
+                self.cache_dir / "visual.onnx",
+                sess_options=self.sess_options,
+                providers=self.providers,
+                provider_options=self.provider_options,
+            )
+            self.vision_outputs = [output.name for output in self.vision_model.get_outputs()]
+
+            image_size = _VISUAL_MODEL_IMAGE_SIZE[CLIPOnnxModel.get_model_name(self.model_name)]
+            self.transform = _transform_pil_image(image_size)
 
     def _predict(self, image_or_text: Image | str) -> list[float]:
-        return self.model.encode(image_or_text).tolist()
+        match image_or_text:
+            case Image():
+                if self.mode == "text":
+                    raise TypeError("Cannot encode image as text-only model")
+                pixel_values = self.transform(image_or_text)
+                assert isinstance(pixel_values, torch.Tensor)
+                pixel_values = torch.unsqueeze(pixel_values, 0).numpy()
+                outputs = self.vision_model.run(self.vision_outputs, {"pixel_values": pixel_values})
+            case str():
+                if self.mode == "vision":
+                    raise TypeError("Cannot encode text as vision-only model")
+                text_inputs: dict[str, torch.Tensor] = self.tokenizer(image_or_text)
+                inputs = {
+                    "input_ids": text_inputs["input_ids"].int().numpy(),
+                    "attention_mask": text_inputs["attention_mask"].int().numpy(),
+                }
+                outputs = self.text_model.run(self.text_outputs, inputs)
+            case _:
+                raise TypeError(f"Expected Image or str, but got: {type(image_or_text)}")
+
+        return outputs[0][0].tolist()
+
+    def _get_jina_model_name(self, model_name: str) -> str:
+        if model_name in _MODELS:
+            return model_name
+        elif model_name in _ST_TO_JINA_MODEL_NAME:
+            print(
+                (f"Warning: Sentence-Transformer model names such as '{model_name}' are no longer supported."),
+                (f"Using '{_ST_TO_JINA_MODEL_NAME[model_name]}' instead as it is the best match for '{model_name}'."),
+            )
+            return _ST_TO_JINA_MODEL_NAME[model_name]
+        else:
+            raise ValueError(f"Unknown model name {model_name}.")
+
+    def _download_model(self, model_name: str, model_md5: str) -> bool:
+        # downloading logic is adapted from clip-server's CLIPOnnxModel class
+        download_model(
+            url=_S3_BUCKET_V2 + model_name,
+            target_folder=self.cache_dir.as_posix(),
+            md5sum=model_md5,
+            with_resume=True,
+        )
+        file = self.cache_dir / model_name.split("/")[1]
+        if file.suffix == ".zip":
+            with zipfile.ZipFile(file, "r") as zip_ref:
+                zip_ref.extractall(self.cache_dir)
+            os.remove(file)
+        return True
+
+
+# same as `_transform_blob` without `_blob2image`
+def _transform_pil_image(n_px: int) -> Compose:
+    return Compose(
+        [
+            Resize(n_px, interpolation=BICUBIC),
+            CenterCrop(n_px),
+            _convert_image_to_rgb,
+            ToTensor(),
+            Normalize(
+                (0.48145466, 0.4578275, 0.40821073),
+                (0.26862954, 0.26130258, 0.27577711),
+            ),
+        ]
+    )
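
Taken together, the encoder now loads only the ONNX session it needs for its mode and rejects inputs of the wrong kind. A rough usage sketch, assuming the model files are already cached (or that network access is available to download them); photo.jpg is a hypothetical input file:

    from PIL import Image

    from app.models.clip import CLIPEncoder

    vision = CLIPEncoder("ViT-B-32::openai", mode="vision")
    text = CLIPEncoder("ViT-B-32::openai", mode="text")

    image_embedding = vision.predict(Image.open("photo.jpg"))  # hypothetical file
    text_embedding = text.predict("a photo of a dog")
    assert len(image_embedding) == len(text_embedding) == 512  # ViT-B-32 embedding size

    vision.predict("some text")  # raises TypeError: a vision-only model can't encode text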

+ 21 - 4
machine-learning/app/models/facial_recognition.py

@@ -4,6 +4,7 @@ from typing import Any
 
 
 import cv2
 import numpy as np
+import onnxruntime as ort
 from insightface.model_zoo import ArcFaceONNX, RetinaFace
 from insightface.utils.face_align import norm_crop
 from insightface.utils.storage import BASE_REPO_URL, download_file
@@ -42,15 +43,31 @@ class FaceRecognizer(InferenceModel):
             rec_file = next(self.cache_dir.glob("w600k_*.onnx"))
         except StopIteration:
             raise FileNotFoundError("Facial recognition models not found in cache directory")
-        self.det_model = RetinaFace(det_file.as_posix())
-        self.rec_model = ArcFaceONNX(rec_file.as_posix())
+
+        self.det_model = RetinaFace(
+            session=ort.InferenceSession(
+                det_file.as_posix(),
+                sess_options=self.sess_options,
+                providers=self.providers,
+                provider_options=self.provider_options,
+            ),
+        )
+        self.rec_model = ArcFaceONNX(
+            rec_file.as_posix(),
+            session=ort.InferenceSession(
+                rec_file.as_posix(),
+                sess_options=self.sess_options,
+                providers=self.providers,
+                provider_options=self.provider_options,
+            ),
+        )
 
         self.det_model.prepare(
-            ctx_id=-1,
+            ctx_id=0,
             det_thresh=self.min_score,
             input_size=(640, 640),
         )
-        self.rec_model.prepare(ctx_id=-1)
+        self.rec_model.prepare(ctx_id=0)
 
     def _predict(self, image: cv2.Mat) -> list[dict[str, Any]]:
         bboxes, kpss = self.det_model.detect(image)

+ 28 - 7
machine-learning/app/models/image_classification.py

@@ -2,8 +2,10 @@ from pathlib import Path
 from typing import Any
 
 from huggingface_hub import snapshot_download
+from optimum.onnxruntime import ORTModelForImageClassification
+from optimum.pipelines import pipeline
 from PIL.Image import Image
-from transformers.pipelines import pipeline
+from transformers import AutoImageProcessor
 
 from ..config import settings
 from ..schemas import ModelType
@@ -25,15 +27,34 @@ class ImageClassifier(InferenceModel):
 
 
     def _download(self, **model_kwargs: Any) -> None:
         snapshot_download(
-            cache_dir=self.cache_dir, repo_id=self.model_name, allow_patterns=["*.bin", "*.json", "*.txt"]
+            cache_dir=self.cache_dir,
+            repo_id=self.model_name,
+            allow_patterns=["*.bin", "*.json", "*.txt"],
+            local_dir=self.cache_dir,
+            local_dir_use_symlinks=True,
         )
 
     def _load(self, **model_kwargs: Any) -> None:
-        self.model = pipeline(
-            self.model_type.value,
-            self.model_name,
-            model_kwargs={"cache_dir": self.cache_dir, **model_kwargs},
-        )
+        processor = AutoImageProcessor.from_pretrained(self.cache_dir)
+        model_kwargs |= {
+            "cache_dir": self.cache_dir,
+            "provider": self.providers[0],
+            "provider_options": self.provider_options[0],
+            "session_options": self.sess_options,
+        }
+        model_path = self.cache_dir / "model.onnx"
+
+        if model_path.exists():
+            model = ORTModelForImageClassification.from_pretrained(self.cache_dir, **model_kwargs)
+            self.model = pipeline(self.model_type.value, model, feature_extractor=processor)
+        else:
+            self.sess_options.optimized_model_filepath = model_path.as_posix()
+            self.model = pipeline(
+                self.model_type.value,
+                self.model_name,
+                model_kwargs=model_kwargs,
+                feature_extractor=processor,
+            )
 
     def _predict(self, image: Image) -> list[str]:
         predictions: list[dict[str, Any]] = self.model(image)  # type: ignore
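
The classifier now produces and reuses an ONNX-Runtime-optimized graph: on the first load (no model.onnx yet) the pipeline is built from the downloaded weights with optimized_model_filepath pointed at model.onnx, and later loads go straight through ORTModelForImageClassification. A minimal sketch of the underlying ONNX Runtime option, using a hypothetical some_model.onnx:

    import onnxruntime as ort

    sess_options = ort.SessionOptions()
    # persist the graph-optimized model; later sessions can load optimized.onnx
    # directly instead of repeating the optimization pass
    sess_options.optimized_model_filepath = "optimized.onnx"
    session = ort.InferenceSession("some_model.onnx", sess_options=sess_options)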

+ 34 - 19
machine-learning/app/test_main.py

@@ -1,17 +1,20 @@
+import pickle
 from io import BytesIO
 from typing import TypeAlias
 from unittest import mock
 
 import cv2
 import numpy as np
+import onnxruntime as ort
 import pytest
 from fastapi.testclient import TestClient
 from PIL import Image
 from pytest_mock import MockerFixture
 
 from .config import settings
+from .models.base import PicklableSessionOptions
 from .models.cache import ModelCache
-from .models.clip import CLIPSTEncoder
+from .models.clip import CLIPEncoder
 from .models.facial_recognition import FaceRecognizer
 from .models.image_classification import ImageClassifier
 from .schemas import ModelType
@@ -72,45 +75,47 @@ class TestCLIP:
     embedding = np.random.rand(512).astype(np.float32)
 
     def test_eager_init(self, mocker: MockerFixture) -> None:
-        mocker.patch.object(CLIPSTEncoder, "download")
-        mock_load = mocker.patch.object(CLIPSTEncoder, "load")
-        clip_model = CLIPSTEncoder("test_model_name", cache_dir="test_cache", eager=True, test_arg="test_arg")
+        mocker.patch.object(CLIPEncoder, "download")
+        mock_load = mocker.patch.object(CLIPEncoder, "load")
+        clip_model = CLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", eager=True, test_arg="test_arg")
 
-        assert clip_model.model_name == "test_model_name"
+        assert clip_model.model_name == "ViT-B-32::openai"
         mock_load.assert_called_once_with(test_arg="test_arg")
 
     def test_lazy_init(self, mocker: MockerFixture) -> None:
-        mock_download = mocker.patch.object(CLIPSTEncoder, "download")
-        mock_load = mocker.patch.object(CLIPSTEncoder, "load")
-        clip_model = CLIPSTEncoder("test_model_name", cache_dir="test_cache", eager=False, test_arg="test_arg")
+        mock_download = mocker.patch.object(CLIPEncoder, "download")
+        mock_load = mocker.patch.object(CLIPEncoder, "load")
+        clip_model = CLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", eager=False, test_arg="test_arg")
 
-        assert clip_model.model_name == "test_model_name"
+        assert clip_model.model_name == "ViT-B-32::openai"
         mock_download.assert_called_once_with(test_arg="test_arg")
         mock_load.assert_not_called()
 
     def test_basic_image(self, pil_image: Image.Image, mocker: MockerFixture) -> None:
-        mocker.patch.object(CLIPSTEncoder, "load")
-        clip_encoder = CLIPSTEncoder("test_model_name", cache_dir="test_cache")
-        clip_encoder.model = mock.Mock()
-        clip_encoder.model.encode.return_value = self.embedding
+        mocker.patch.object(CLIPEncoder, "download")
+        mocked = mocker.patch("app.models.clip.ort.InferenceSession", autospec=True)
+        mocked.return_value.run.return_value = [[self.embedding]]
+        clip_encoder = CLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", mode="vision")
+        assert clip_encoder.mode == "vision"
         embedding = clip_encoder.predict(pil_image)
 
         assert isinstance(embedding, list)
         assert len(embedding) == 512
         assert all([isinstance(num, float) for num in embedding])
-        clip_encoder.model.encode.assert_called_once()
+        clip_encoder.vision_model.run.assert_called_once()
 
     def test_basic_text(self, mocker: MockerFixture) -> None:
-        mocker.patch.object(CLIPSTEncoder, "load")
-        clip_encoder = CLIPSTEncoder("test_model_name", cache_dir="test_cache")
-        clip_encoder.model = mock.Mock()
-        clip_encoder.model.encode.return_value = self.embedding
+        mocker.patch.object(CLIPEncoder, "download")
+        mocked = mocker.patch("app.models.clip.ort.InferenceSession", autospec=True)
+        mocked.return_value.run.return_value = [[self.embedding]]
+        clip_encoder = CLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", mode="text")
+        assert clip_encoder.mode == "text"
         embedding = clip_encoder.predict("test search query")
 
         assert isinstance(embedding, list)
         assert len(embedding) == 512
         assert all([isinstance(num, float) for num in embedding])
-        clip_encoder.model.encode.assert_called_once()
+        clip_encoder.text_model.run.assert_called_once()
 
 
 class TestFaceRecognition:
@@ -254,3 +259,13 @@ class TestEndpoints:
             headers=headers,
         )
         assert response.status_code == 200
+
+
+def test_sess_options() -> None:
+    sess_options = PicklableSessionOptions()
+    sess_options.intra_op_num_threads = 1
+    sess_options.inter_op_num_threads = 1
+    pickled = pickle.dumps(sess_options)
+    unpickled = pickle.loads(pickled)
+    assert unpickled.intra_op_num_threads == 1
+    assert unpickled.inter_op_num_threads == 1

File diff suppressed because it is too large
+ 710 - 168
machine-learning/poetry.lock


+ 19 - 4
machine-learning/pyproject.toml

@@ -13,7 +13,6 @@ torch = [
     {markers = "platform_machine == 'amd64' or platform_machine == 'x86_64'", version = "=2.0.1", source = "pytorch-cpu"}
 ]
 transformers = "^4.29.2"
-sentence-transformers = "^2.2.2"
 onnxruntime = "^1.15.0"
 insightface = "^0.7.3"
 opencv-python-headless = "^4.7.0.72"
@@ -22,6 +21,15 @@ fastapi = "^0.95.2"
 uvicorn = {extras = ["standard"], version = "^0.22.0"}
 pydantic = "^1.10.8"
 aiocache = "^0.12.1"
+optimum = "^1.9.1"
+torchvision = [
+    {markers = "platform_machine == 'arm64' or platform_machine == 'aarch64'", version = "=0.15.2", source = "pypi"},
+    {markers = "platform_machine == 'amd64' or platform_machine == 'x86_64'", version = "=0.15.2", source = "pytorch-cpu"}
+]
+rich = "^13.4.2"
+ftfy = "^6.1.1"
+setuptools = "^68.0.0"
+open-clip-torch = "^2.20.0"
 
 [tool.poetry.group.dev.dependencies]
 mypy = "^1.3.0"
@@ -62,13 +70,20 @@ warn_untyped_fields = true
 [[tool.mypy.overrides]]
 module = [
     "huggingface_hub",
-    "transformers.pipelines",
+    "transformers",
     "cv2",
     "insightface.model_zoo",
     "insightface.utils.face_align",
     "insightface.utils.storage",
-    "sentence_transformers",
-    "sentence_transformers.util",
+    "onnxruntime",
+    "optimum",
+    "optimum.pipelines",
+    "optimum.onnxruntime",
+    "clip_server.model.clip",
+    "clip_server.model.clip_onnx",
+    "clip_server.model.pretrained_models",
+    "clip_server.model.tokenization",
+    "torchvision.transforms",
     "aiocache.backends.memory",
     "aiocache.lock",
     "aiocache.plugins"

+ 2 - 0
machine-learning/requirements.txt

@@ -0,0 +1,2 @@
+# requirements to be installed with `--no-deps` flag
+clip-server==0.8.*

Some files were not shown because too many files changed in this diff