Forráskód Böngészése

feat(ml)!: switch image classification and CLIP models to ONNX (#3809)

Mert 1 éve
szülő
commit
165b91b068

+ 1 - 0
.github/workflows/test.yml

@@ -171,6 +171,7 @@ jobs:
       - name: Install dependencies
         run: |
           poetry install --with dev
+          poetry run pip install --no-deps -r requirements.txt
       - name: Lint with ruff
         run: |
           poetry run ruff check --format=github app

+ 2 - 1
machine-learning/Dockerfile

@@ -10,8 +10,9 @@ RUN poetry config installer.max-workers 10 && \
 RUN python -m venv /opt/venv
 ENV VIRTUAL_ENV="/opt/venv" PATH="/opt/venv/bin:${PATH}"
 
-COPY poetry.lock pyproject.toml ./
+COPY poetry.lock pyproject.toml requirements.txt ./
 RUN poetry install --sync --no-interaction --no-ansi --no-root --only main
+RUN pip install --no-deps -r requirements.txt
 
 FROM python:3.11.4-slim-bullseye@sha256:91d194f58f50594cda71dcd2e8fdefd90e7ecc57d07823813b67c8521e565dcd
 

+ 10 - 3
machine-learning/app/config.py

@@ -1,3 +1,4 @@
+import os
 from pathlib import Path
 
 from pydantic import BaseSettings
@@ -8,8 +9,8 @@ from .schemas import ModelType
 class Settings(BaseSettings):
     cache_folder: str = "/cache"
     classification_model: str = "microsoft/resnet-50"
-    clip_image_model: str = "clip-ViT-B-32"
-    clip_text_model: str = "clip-ViT-B-32"
+    clip_image_model: str = "ViT-B-32::openai"
+    clip_text_model: str = "ViT-B-32::openai"
     facial_recognition_model: str = "buffalo_l"
     min_tag_score: float = 0.9
     eager_startup: bool = False
@@ -19,14 +20,20 @@ class Settings(BaseSettings):
     workers: int = 1
     min_face_score: float = 0.7
     test_full: bool = False
+    request_threads: int = os.cpu_count() or 4
+    model_inter_op_threads: int = 1
+    model_intra_op_threads: int = 2
 
     class Config:
         env_prefix = "MACHINE_LEARNING_"
         case_sensitive = False
 
 
+_clean_name = str.maketrans(":\\/", "___", ".")
+
+
 def get_cache_dir(model_name: str, model_type: ModelType) -> Path:
-    return Path(settings.cache_folder, model_type.value, model_name)
+    return Path(settings.cache_folder) / model_type.value / model_name.translate(_clean_name)
 
 
 settings = Settings()

+ 29 - 14
machine-learning/app/main.py

@@ -1,4 +1,6 @@
+import asyncio
 import os
+from concurrent.futures import ThreadPoolExecutor
 from io import BytesIO
 from typing import Any
 
@@ -8,6 +10,8 @@ import uvicorn
 from fastapi import Body, Depends, FastAPI
 from PIL import Image
 
+from app.models.base import InferenceModel
+
 from .config import settings
 from .models.cache import ModelCache
 from .schemas import (
@@ -25,19 +29,21 @@ app = FastAPI()
 
 def init_state() -> None:
     app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0)
+    # asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code
+    app.state.thread_pool = ThreadPoolExecutor(settings.request_threads)
 
 
 async def load_models() -> None:
-    models = [
-        (settings.classification_model, ModelType.IMAGE_CLASSIFICATION),
-        (settings.clip_image_model, ModelType.CLIP),
-        (settings.clip_text_model, ModelType.CLIP),
-        (settings.facial_recognition_model, ModelType.FACIAL_RECOGNITION),
+    models: list[tuple[str, ModelType, dict[str, Any]]] = [
+        (settings.classification_model, ModelType.IMAGE_CLASSIFICATION, {}),
+        (settings.clip_image_model, ModelType.CLIP, {"mode": "vision"}),
+        (settings.clip_text_model, ModelType.CLIP, {"mode": "text"}),
+        (settings.facial_recognition_model, ModelType.FACIAL_RECOGNITION, {}),
     ]
 
     # Get all models
-    for model_name, model_type in models:
-        await app.state.model_cache.get(model_name, model_type, eager=settings.eager_startup)
+    for model_name, model_type, model_kwargs in models:
+        await app.state.model_cache.get(model_name, model_type, eager=settings.eager_startup, **model_kwargs)
 
 
 @app.on_event("startup")
@@ -46,11 +52,16 @@ async def startup_event() -> None:
     await load_models()
 
 
+@app.on_event("shutdown")
+async def shutdown_event() -> None:
+    app.state.thread_pool.shutdown()
+
+
 def dep_pil_image(byte_image: bytes = Body(...)) -> Image.Image:
     return Image.open(BytesIO(byte_image))
 
 
-def dep_cv_image(byte_image: bytes = Body(...)) -> cv2.Mat:
+def dep_cv_image(byte_image: bytes = Body(...)) -> np.ndarray[int, np.dtype[Any]]:
     byte_image_np = np.frombuffer(byte_image, np.uint8)
     return cv2.imdecode(byte_image_np, cv2.IMREAD_COLOR)
 
@@ -74,7 +85,7 @@ async def image_classification(
     image: Image.Image = Depends(dep_pil_image),
 ) -> list[str]:
     model = await app.state.model_cache.get(settings.classification_model, ModelType.IMAGE_CLASSIFICATION)
-    labels = model.predict(image)
+    labels = await predict(model, image)
     return labels
 
 
@@ -86,8 +97,8 @@ async def image_classification(
 async def clip_encode_image(
     image: Image.Image = Depends(dep_pil_image),
 ) -> list[float]:
-    model = await app.state.model_cache.get(settings.clip_image_model, ModelType.CLIP)
-    embedding = model.predict(image)
+    model = await app.state.model_cache.get(settings.clip_image_model, ModelType.CLIP, mode="vision")
+    embedding = await predict(model, image)
     return embedding
 
 
@@ -97,8 +108,8 @@ async def clip_encode_image(
     status_code=200,
 )
 async def clip_encode_text(payload: TextModelRequest) -> list[float]:
-    model = await app.state.model_cache.get(settings.clip_text_model, ModelType.CLIP)
-    embedding = model.predict(payload.text)
+    model = await app.state.model_cache.get(settings.clip_text_model, ModelType.CLIP, mode="text")
+    embedding = await predict(model, payload.text)
     return embedding
 
 
@@ -111,10 +122,14 @@ async def facial_recognition(
     image: cv2.Mat = Depends(dep_cv_image),
 ) -> list[dict[str, Any]]:
     model = await app.state.model_cache.get(settings.facial_recognition_model, ModelType.FACIAL_RECOGNITION)
-    faces = model.predict(image)
+    faces = await predict(model, image)
     return faces
 
 
+async def predict(model: InferenceModel, inputs: Any) -> Any:
+    return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict, inputs)
+
+
 if __name__ == "__main__":
     is_dev = os.getenv("NODE_ENV") == "development"
     uvicorn.run(

+ 1 - 1
machine-learning/app/models/__init__.py

@@ -1,3 +1,3 @@
-from .clip import CLIPSTEncoder
+from .clip import CLIPEncoder
 from .facial_recognition import FaceRecognizer
 from .image_classification import ImageClassifier

+ 37 - 2
machine-learning/app/models/base.py

@@ -1,14 +1,17 @@
 from __future__ import annotations
 
+import os
+import pickle
 from abc import ABC, abstractmethod
 from pathlib import Path
 from shutil import rmtree
 from typing import Any
 from zipfile import BadZipFile
 
+import onnxruntime as ort
 from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf  # type: ignore
 
-from ..config import get_cache_dir
+from ..config import get_cache_dir, settings
 from ..schemas import ModelType
 
 
@@ -16,12 +19,31 @@ class InferenceModel(ABC):
     _model_type: ModelType
 
     def __init__(
-        self, model_name: str, cache_dir: Path | str | None = None, eager: bool = True, **model_kwargs: Any
+        self,
+        model_name: str,
+        cache_dir: Path | str | None = None,
+        eager: bool = True,
+        inter_op_num_threads: int = settings.model_inter_op_threads,
+        intra_op_num_threads: int = settings.model_intra_op_threads,
+        **model_kwargs: Any,
     ) -> None:
         self.model_name = model_name
         self._loaded = False
         self._cache_dir = Path(cache_dir) if cache_dir is not None else get_cache_dir(model_name, self.model_type)
         loader = self.load if eager else self.download
+
+        self.providers = model_kwargs.pop("providers", ["CPUExecutionProvider"])
+        #  don't pre-allocate more memory than needed
+        self.provider_options = model_kwargs.pop(
+            "provider_options", [{"arena_extend_strategy": "kSameAsRequested"}] * len(self.providers)
+        )
+        self.sess_options = PicklableSessionOptions()
+        # avoid thread contention between models
+        if inter_op_num_threads > 1:
+            self.sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
+        self.sess_options.inter_op_num_threads = inter_op_num_threads
+        self.sess_options.intra_op_num_threads = intra_op_num_threads
+
         try:
             loader(**model_kwargs)
         except (OSError, InvalidProtobuf, BadZipFile):
@@ -30,6 +52,7 @@ class InferenceModel(ABC):
 
     def download(self, **model_kwargs: Any) -> None:
         if not self.cached:
+            print(f"Downloading {self.model_type.value.replace('_', ' ')} model. This may take a while...")
             self._download(**model_kwargs)
 
     def load(self, **model_kwargs: Any) -> None:
@@ -39,6 +62,7 @@ class InferenceModel(ABC):
 
     def predict(self, inputs: Any) -> Any:
         if not self._loaded:
+            print(f"Loading {self.model_type.value.replace('_', ' ')} model...")
             self.load()
         return self._predict(inputs)
 
@@ -89,3 +113,14 @@ class InferenceModel(ABC):
         else:
             self.cache_dir.unlink()
         self.cache_dir.mkdir(parents=True, exist_ok=True)
+
+
+# HF deep copies configs, so we need to make session options picklable
+class PicklableSessionOptions(ort.SessionOptions):
+    def __getstate__(self) -> bytes:
+        return pickle.dumps([(attr, getattr(self, attr)) for attr in dir(self) if not callable(getattr(self, attr))])
+
+    def __setstate__(self, state: Any) -> None:
+        self.__init__()  # type: ignore
+        for attr, val in pickle.loads(state):
+            setattr(self, attr, val)

+ 1 - 1
machine-learning/app/models/cache.py

@@ -46,7 +46,7 @@ class ModelCache:
             model: The requested model.
         """
 
-        key = self.cache.build_key(model_name, model_type.value)
+        key = f"{model_name}{model_type.value}{model_kwargs.get('mode', '')}"
         async with OptimisticLock(self.cache, key) as lock:
             model = await self.cache.get(key)
             if model is None:

+ 127 - 17
machine-learning/app/models/clip.py

@@ -1,31 +1,141 @@
-from typing import Any
+import os
+import zipfile
+from typing import Any, Literal
 
+import onnxruntime as ort
+import torch
+from clip_server.model.clip import BICUBIC, _convert_image_to_rgb
+from clip_server.model.clip_onnx import _MODELS, _S3_BUCKET_V2, CLIPOnnxModel, download_model
+from clip_server.model.pretrained_models import _VISUAL_MODEL_IMAGE_SIZE
+from clip_server.model.tokenization import Tokenizer
 from PIL.Image import Image
-from sentence_transformers import SentenceTransformer
-from sentence_transformers.util import snapshot_download
+from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor
 
 from ..schemas import ModelType
 from .base import InferenceModel
 
+_ST_TO_JINA_MODEL_NAME = {
+    "clip-ViT-B-16": "ViT-B-16::openai",
+    "clip-ViT-B-32": "ViT-B-32::openai",
+    "clip-ViT-B-32-multilingual-v1": "M-CLIP/XLM-Roberta-Large-Vit-B-32",
+    "clip-ViT-L-14": "ViT-L-14::openai",
+}
 
-class CLIPSTEncoder(InferenceModel):
+
+class CLIPEncoder(InferenceModel):
     _model_type = ModelType.CLIP
 
+    def __init__(
+        self,
+        model_name: str,
+        cache_dir: str | None = None,
+        mode: Literal["text", "vision"] | None = None,
+        **model_kwargs: Any,
+    ) -> None:
+        if mode is not None and mode not in ("text", "vision"):
+            raise ValueError(f"Mode must be 'text', 'vision', or omitted; got '{mode}'")
+        if "vit-b" not in model_name.lower():
+            raise ValueError(f"Only ViT-B models are currently supported; got '{model_name}'")
+        self.mode = mode
+        jina_model_name = self._get_jina_model_name(model_name)
+        super().__init__(jina_model_name, cache_dir, **model_kwargs)
+
     def _download(self, **model_kwargs: Any) -> None:
-        repo_id = self.model_name if "/" in self.model_name else f"sentence-transformers/{self.model_name}"
-        snapshot_download(
-            cache_dir=self.cache_dir,
-            repo_id=repo_id,
-            library_name="sentence-transformers",
-            ignore_files=["flax_model.msgpack", "rust_model.ot", "tf_model.h5"],
-        )
+        models: tuple[tuple[str, str], tuple[str, str]] = _MODELS[self.model_name]
+        text_onnx_path = self.cache_dir / "textual.onnx"
+        vision_onnx_path = self.cache_dir / "visual.onnx"
+
+        if not text_onnx_path.is_file():
+            self._download_model(*models[0])
+
+        if not vision_onnx_path.is_file():
+            self._download_model(*models[1])
 
     def _load(self, **model_kwargs: Any) -> None:
-        self.model = SentenceTransformer(
-            self.model_name,
-            cache_folder=self.cache_dir.as_posix(),
-            **model_kwargs,
-        )
+        if self.mode == "text" or self.mode is None:
+            self.text_model = ort.InferenceSession(
+                self.cache_dir / "textual.onnx",
+                sess_options=self.sess_options,
+                providers=self.providers,
+                provider_options=self.provider_options,
+            )
+            self.text_outputs = [output.name for output in self.text_model.get_outputs()]
+            self.tokenizer = Tokenizer(self.model_name)
+
+        if self.mode == "vision" or self.mode is None:
+            self.vision_model = ort.InferenceSession(
+                self.cache_dir / "visual.onnx",
+                sess_options=self.sess_options,
+                providers=self.providers,
+                provider_options=self.provider_options,
+            )
+            self.vision_outputs = [output.name for output in self.vision_model.get_outputs()]
+
+            image_size = _VISUAL_MODEL_IMAGE_SIZE[CLIPOnnxModel.get_model_name(self.model_name)]
+            self.transform = _transform_pil_image(image_size)
 
     def _predict(self, image_or_text: Image | str) -> list[float]:
-        return self.model.encode(image_or_text).tolist()
+        match image_or_text:
+            case Image():
+                if self.mode == "text":
+                    raise TypeError("Cannot encode image as text-only model")
+                pixel_values = self.transform(image_or_text)
+                assert isinstance(pixel_values, torch.Tensor)
+                pixel_values = torch.unsqueeze(pixel_values, 0).numpy()
+                outputs = self.vision_model.run(self.vision_outputs, {"pixel_values": pixel_values})
+            case str():
+                if self.mode == "vision":
+                    raise TypeError("Cannot encode text as vision-only model")
+                text_inputs: dict[str, torch.Tensor] = self.tokenizer(image_or_text)
+                inputs = {
+                    "input_ids": text_inputs["input_ids"].int().numpy(),
+                    "attention_mask": text_inputs["attention_mask"].int().numpy(),
+                }
+                outputs = self.text_model.run(self.text_outputs, inputs)
+            case _:
+                raise TypeError(f"Expected Image or str, but got: {type(image_or_text)}")
+
+        return outputs[0][0].tolist()
+
+    def _get_jina_model_name(self, model_name: str) -> str:
+        if model_name in _MODELS:
+            return model_name
+        elif model_name in _ST_TO_JINA_MODEL_NAME:
+            print(
+                (f"Warning: Sentence-Transformer model names such as '{model_name}' are no longer supported."),
+                (f"Using '{_ST_TO_JINA_MODEL_NAME[model_name]}' instead as it is the best match for '{model_name}'."),
+            )
+            return _ST_TO_JINA_MODEL_NAME[model_name]
+        else:
+            raise ValueError(f"Unknown model name {model_name}.")
+
+    def _download_model(self, model_name: str, model_md5: str) -> bool:
+        # downloading logic is adapted from clip-server's CLIPOnnxModel class
+        download_model(
+            url=_S3_BUCKET_V2 + model_name,
+            target_folder=self.cache_dir.as_posix(),
+            md5sum=model_md5,
+            with_resume=True,
+        )
+        file = self.cache_dir / model_name.split("/")[1]
+        if file.suffix == ".zip":
+            with zipfile.ZipFile(file, "r") as zip_ref:
+                zip_ref.extractall(self.cache_dir)
+            os.remove(file)
+        return True
+
+
+# same as `_transform_blob` without `_blob2image`
+def _transform_pil_image(n_px: int) -> Compose:
+    return Compose(
+        [
+            Resize(n_px, interpolation=BICUBIC),
+            CenterCrop(n_px),
+            _convert_image_to_rgb,
+            ToTensor(),
+            Normalize(
+                (0.48145466, 0.4578275, 0.40821073),
+                (0.26862954, 0.26130258, 0.27577711),
+            ),
+        ]
+    )

+ 21 - 4
machine-learning/app/models/facial_recognition.py

@@ -4,6 +4,7 @@ from typing import Any
 
 import cv2
 import numpy as np
+import onnxruntime as ort
 from insightface.model_zoo import ArcFaceONNX, RetinaFace
 from insightface.utils.face_align import norm_crop
 from insightface.utils.storage import BASE_REPO_URL, download_file
@@ -42,15 +43,31 @@ class FaceRecognizer(InferenceModel):
             rec_file = next(self.cache_dir.glob("w600k_*.onnx"))
         except StopIteration:
             raise FileNotFoundError("Facial recognition models not found in cache directory")
-        self.det_model = RetinaFace(det_file.as_posix())
-        self.rec_model = ArcFaceONNX(rec_file.as_posix())
+
+        self.det_model = RetinaFace(
+            session=ort.InferenceSession(
+                det_file.as_posix(),
+                sess_options=self.sess_options,
+                providers=self.providers,
+                provider_options=self.provider_options,
+            ),
+        )
+        self.rec_model = ArcFaceONNX(
+            rec_file.as_posix(),
+            session=ort.InferenceSession(
+                rec_file.as_posix(),
+                sess_options=self.sess_options,
+                providers=self.providers,
+                provider_options=self.provider_options,
+            ),
+        )
 
         self.det_model.prepare(
-            ctx_id=-1,
+            ctx_id=0,
             det_thresh=self.min_score,
             input_size=(640, 640),
         )
-        self.rec_model.prepare(ctx_id=-1)
+        self.rec_model.prepare(ctx_id=0)
 
     def _predict(self, image: cv2.Mat) -> list[dict[str, Any]]:
         bboxes, kpss = self.det_model.detect(image)

+ 28 - 7
machine-learning/app/models/image_classification.py

@@ -2,8 +2,10 @@ from pathlib import Path
 from typing import Any
 
 from huggingface_hub import snapshot_download
+from optimum.onnxruntime import ORTModelForImageClassification
+from optimum.pipelines import pipeline
 from PIL.Image import Image
-from transformers.pipelines import pipeline
+from transformers import AutoImageProcessor
 
 from ..config import settings
 from ..schemas import ModelType
@@ -25,15 +27,34 @@ class ImageClassifier(InferenceModel):
 
     def _download(self, **model_kwargs: Any) -> None:
         snapshot_download(
-            cache_dir=self.cache_dir, repo_id=self.model_name, allow_patterns=["*.bin", "*.json", "*.txt"]
+            cache_dir=self.cache_dir,
+            repo_id=self.model_name,
+            allow_patterns=["*.bin", "*.json", "*.txt"],
+            local_dir=self.cache_dir,
+            local_dir_use_symlinks=True,
         )
 
     def _load(self, **model_kwargs: Any) -> None:
-        self.model = pipeline(
-            self.model_type.value,
-            self.model_name,
-            model_kwargs={"cache_dir": self.cache_dir, **model_kwargs},
-        )
+        processor = AutoImageProcessor.from_pretrained(self.cache_dir)
+        model_kwargs |= {
+            "cache_dir": self.cache_dir,
+            "provider": self.providers[0],
+            "provider_options": self.provider_options[0],
+            "session_options": self.sess_options,
+        }
+        model_path = self.cache_dir / "model.onnx"
+
+        if model_path.exists():
+            model = ORTModelForImageClassification.from_pretrained(self.cache_dir, **model_kwargs)
+            self.model = pipeline(self.model_type.value, model, feature_extractor=processor)
+        else:
+            self.sess_options.optimized_model_filepath = model_path.as_posix()
+            self.model = pipeline(
+                self.model_type.value,
+                self.model_name,
+                model_kwargs=model_kwargs,
+                feature_extractor=processor,
+            )
 
     def _predict(self, image: Image) -> list[str]:
         predictions: list[dict[str, Any]] = self.model(image)  # type: ignore

+ 34 - 19
machine-learning/app/test_main.py

@@ -1,17 +1,20 @@
+import pickle
 from io import BytesIO
 from typing import TypeAlias
 from unittest import mock
 
 import cv2
 import numpy as np
+import onnxruntime as ort
 import pytest
 from fastapi.testclient import TestClient
 from PIL import Image
 from pytest_mock import MockerFixture
 
 from .config import settings
+from .models.base import PicklableSessionOptions
 from .models.cache import ModelCache
-from .models.clip import CLIPSTEncoder
+from .models.clip import CLIPEncoder
 from .models.facial_recognition import FaceRecognizer
 from .models.image_classification import ImageClassifier
 from .schemas import ModelType
@@ -72,45 +75,47 @@ class TestCLIP:
     embedding = np.random.rand(512).astype(np.float32)
 
     def test_eager_init(self, mocker: MockerFixture) -> None:
-        mocker.patch.object(CLIPSTEncoder, "download")
-        mock_load = mocker.patch.object(CLIPSTEncoder, "load")
-        clip_model = CLIPSTEncoder("test_model_name", cache_dir="test_cache", eager=True, test_arg="test_arg")
+        mocker.patch.object(CLIPEncoder, "download")
+        mock_load = mocker.patch.object(CLIPEncoder, "load")
+        clip_model = CLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", eager=True, test_arg="test_arg")
 
-        assert clip_model.model_name == "test_model_name"
+        assert clip_model.model_name == "ViT-B-32::openai"
         mock_load.assert_called_once_with(test_arg="test_arg")
 
     def test_lazy_init(self, mocker: MockerFixture) -> None:
-        mock_download = mocker.patch.object(CLIPSTEncoder, "download")
-        mock_load = mocker.patch.object(CLIPSTEncoder, "load")
-        clip_model = CLIPSTEncoder("test_model_name", cache_dir="test_cache", eager=False, test_arg="test_arg")
+        mock_download = mocker.patch.object(CLIPEncoder, "download")
+        mock_load = mocker.patch.object(CLIPEncoder, "load")
+        clip_model = CLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", eager=False, test_arg="test_arg")
 
-        assert clip_model.model_name == "test_model_name"
+        assert clip_model.model_name == "ViT-B-32::openai"
         mock_download.assert_called_once_with(test_arg="test_arg")
         mock_load.assert_not_called()
 
     def test_basic_image(self, pil_image: Image.Image, mocker: MockerFixture) -> None:
-        mocker.patch.object(CLIPSTEncoder, "load")
-        clip_encoder = CLIPSTEncoder("test_model_name", cache_dir="test_cache")
-        clip_encoder.model = mock.Mock()
-        clip_encoder.model.encode.return_value = self.embedding
+        mocker.patch.object(CLIPEncoder, "download")
+        mocked = mocker.patch("app.models.clip.ort.InferenceSession", autospec=True)
+        mocked.return_value.run.return_value = [[self.embedding]]
+        clip_encoder = CLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", mode="vision")
+        assert clip_encoder.mode == "vision"
         embedding = clip_encoder.predict(pil_image)
 
         assert isinstance(embedding, list)
         assert len(embedding) == 512
         assert all([isinstance(num, float) for num in embedding])
-        clip_encoder.model.encode.assert_called_once()
+        clip_encoder.vision_model.run.assert_called_once()
 
     def test_basic_text(self, mocker: MockerFixture) -> None:
-        mocker.patch.object(CLIPSTEncoder, "load")
-        clip_encoder = CLIPSTEncoder("test_model_name", cache_dir="test_cache")
-        clip_encoder.model = mock.Mock()
-        clip_encoder.model.encode.return_value = self.embedding
+        mocker.patch.object(CLIPEncoder, "download")
+        mocked = mocker.patch("app.models.clip.ort.InferenceSession", autospec=True)
+        mocked.return_value.run.return_value = [[self.embedding]]
+        clip_encoder = CLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", mode="text")
+        assert clip_encoder.mode == "text"
         embedding = clip_encoder.predict("test search query")
 
         assert isinstance(embedding, list)
         assert len(embedding) == 512
         assert all([isinstance(num, float) for num in embedding])
-        clip_encoder.model.encode.assert_called_once()
+        clip_encoder.text_model.run.assert_called_once()
 
 
 class TestFaceRecognition:
@@ -254,3 +259,13 @@ class TestEndpoints:
             headers=headers,
         )
         assert response.status_code == 200
+
+
+def test_sess_options() -> None:
+    sess_options = PicklableSessionOptions()
+    sess_options.intra_op_num_threads = 1
+    sess_options.inter_op_num_threads = 1
+    pickled = pickle.dumps(sess_options)
+    unpickled = pickle.loads(pickled)
+    assert unpickled.intra_op_num_threads == 1
+    assert unpickled.inter_op_num_threads == 1

A különbségek nem kerülnek megjelenítésre, a fájl túl nagy
+ 710 - 168
machine-learning/poetry.lock


+ 19 - 4
machine-learning/pyproject.toml

@@ -13,7 +13,6 @@ torch = [
     {markers = "platform_machine == 'amd64' or platform_machine == 'x86_64'", version = "=2.0.1", source = "pytorch-cpu"}
 ]
 transformers = "^4.29.2"
-sentence-transformers = "^2.2.2"
 onnxruntime = "^1.15.0"
 insightface = "^0.7.3"
 opencv-python-headless = "^4.7.0.72"
@@ -22,6 +21,15 @@ fastapi = "^0.95.2"
 uvicorn = {extras = ["standard"], version = "^0.22.0"}
 pydantic = "^1.10.8"
 aiocache = "^0.12.1"
+optimum = "^1.9.1"
+torchvision = [
+    {markers = "platform_machine == 'arm64' or platform_machine == 'aarch64'", version = "=0.15.2", source = "pypi"},
+    {markers = "platform_machine == 'amd64' or platform_machine == 'x86_64'", version = "=0.15.2", source = "pytorch-cpu"}
+]
+rich = "^13.4.2"
+ftfy = "^6.1.1"
+setuptools = "^68.0.0"
+open-clip-torch = "^2.20.0"
 
 [tool.poetry.group.dev.dependencies]
 mypy = "^1.3.0"
@@ -62,13 +70,20 @@ warn_untyped_fields = true
 [[tool.mypy.overrides]]
 module = [
     "huggingface_hub",
-    "transformers.pipelines",
+    "transformers",
     "cv2",
     "insightface.model_zoo",
     "insightface.utils.face_align",
     "insightface.utils.storage",
-    "sentence_transformers",
-    "sentence_transformers.util",
+    "onnxruntime",
+    "optimum",
+    "optimum.pipelines",
+    "optimum.onnxruntime",
+    "clip_server.model.clip",
+    "clip_server.model.clip_onnx",
+    "clip_server.model.pretrained_models",
+    "clip_server.model.tokenization",
+    "torchvision.transforms",
     "aiocache.backends.memory",
     "aiocache.lock",
     "aiocache.plugins"

+ 2 - 0
machine-learning/requirements.txt

@@ -0,0 +1,2 @@
+# requirements to be installed with `--no-deps` flag
+clip-server==0.8.*

Nem az összes módosított fájl került megjelenítésre, mert túl sok fájl változott