feat(ml): export clip models to ONNX and host models on Hugging Face (#4700)

* export clip models

* export to hf

refactored export code

* export mclip, general refactoring

cleanup

* updated conda deps

* do transforms with pillow and numpy, add tokenization config to export, general refactoring

* moved conda dockerfile, re-added poetry

* minor fixes

* updated link

* updated tests

* removed `requirements.txt` from workflow

* fixed mimalloc path

* removed torchvision

* cleaner np typing

* review suggestions

* update default model name

* update test
This commit is contained in:
Mert 2023-10-31 06:02:04 -04:00 committed by GitHub
parent 3212a47720
commit 87a0ba3db3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
29 changed files with 6192 additions and 2043 deletions

View file

@ -166,7 +166,6 @@ jobs:
- name: Install dependencies
run: |
poetry install --with dev
poetry run pip install --no-deps -r requirements.txt
- name: Lint with ruff
run: |
poetry run ruff check --format=github app

View file

@ -10,9 +10,8 @@ RUN poetry config installer.max-workers 10 && \
RUN python -m venv /opt/venv
ENV VIRTUAL_ENV="/opt/venv" PATH="/opt/venv/bin:${PATH}"
COPY poetry.lock pyproject.toml requirements.txt ./
COPY poetry.lock pyproject.toml ./
RUN poetry install --sync --no-interaction --no-ansi --no-root --only main
RUN pip install --no-deps -r requirements.txt
FROM python:3.11-slim-bookworm

View file

@ -1,5 +1,6 @@
import json
from typing import Any, Iterator, TypeAlias
from pathlib import Path
from typing import Any, Iterator
from unittest import mock
import numpy as np
@ -8,8 +9,7 @@ from fastapi.testclient import TestClient
from PIL import Image
from .main import app, init_state
ndarray: TypeAlias = np.ndarray[int, np.dtype[np.float32]]
from .schemas import ndarray_f32
@pytest.fixture
@ -18,13 +18,13 @@ def pil_image() -> Image.Image:
@pytest.fixture
def cv_image(pil_image: Image.Image) -> ndarray:
def cv_image(pil_image: Image.Image) -> ndarray_f32:
return np.asarray(pil_image)[:, :, ::-1] # PIL uses RGB while cv2 uses BGR
@pytest.fixture
def mock_get_model() -> Iterator[mock.Mock]:
with mock.patch("app.models.cache.InferenceModel.from_model_type", autospec=True) as mocked:
with mock.patch("app.models.cache.from_model_type", autospec=True) as mocked:
yield mocked
@ -37,3 +37,25 @@ def deployed_app() -> TestClient:
@pytest.fixture(scope="session")
def responses() -> dict[str, Any]:
return json.load(open("responses.json", "r"))
@pytest.fixture(scope="session")
def clip_model_cfg() -> dict[str, Any]:
return {
"embed_dim": 512,
"vision_cfg": {"image_size": 224, "layers": 12, "width": 768, "patch_size": 32},
"text_cfg": {"context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12},
}
@pytest.fixture(scope="session")
def clip_preprocess_cfg() -> dict[str, Any]:
return {
"size": [224, 224],
"mode": "RGB",
"mean": [0.48145466, 0.4578275, 0.40821073],
"std": [0.26862954, 0.26130258, 0.27577711],
"interpolation": "bicubic",
"resize_mode": "shortest",
"fill_color": 0,
}

View file

@ -1,3 +1,25 @@
from .clip import CLIPEncoder
from typing import Any
from app.schemas import ModelType
from .base import InferenceModel
from .clip import MCLIPEncoder, OpenCLIPEncoder, is_mclip, is_openclip
from .facial_recognition import FaceRecognizer
from .image_classification import ImageClassifier
def from_model_type(model_type: ModelType, model_name: str, **model_kwargs: Any) -> InferenceModel:
match model_type:
case ModelType.CLIP:
if is_openclip(model_name):
return OpenCLIPEncoder(model_name, **model_kwargs)
elif is_mclip(model_name):
return MCLIPEncoder(model_name, **model_kwargs)
else:
raise ValueError(f"Unknown CLIP model {model_name}")
case ModelType.FACIAL_RECOGNITION:
return FaceRecognizer(model_name, **model_kwargs)
case ModelType.IMAGE_CLASSIFICATION:
return ImageClassifier(model_name, **model_kwargs)
case _:
raise ValueError(f"Unknown model type {model_type}")

View file

@ -25,7 +25,7 @@ class InferenceModel(ABC):
) -> None:
self.model_name = model_name
self.loaded = False
self._cache_dir = Path(cache_dir) if cache_dir is not None else get_cache_dir(model_name, self.model_type)
self._cache_dir = Path(cache_dir) if cache_dir is not None else None
self.providers = model_kwargs.pop("providers", ["CPUExecutionProvider"])
# don't pre-allocate more memory than needed
self.provider_options = model_kwargs.pop(
@ -92,7 +92,7 @@ class InferenceModel(ABC):
@property
def cache_dir(self) -> Path:
return self._cache_dir
return self._cache_dir if self._cache_dir is not None else get_cache_dir(self.model_name, self.model_type)
@cache_dir.setter
def cache_dir(self, cache_dir: Path) -> None:

View file

@ -4,6 +4,8 @@ from aiocache.backends.memory import SimpleMemoryCache
from aiocache.lock import OptimisticLock
from aiocache.plugins import BasePlugin, TimingPlugin
from app.models import from_model_type
from ..schemas import ModelType
from .base import InferenceModel
@ -50,7 +52,7 @@ class ModelCache:
async with OptimisticLock(self.cache, key) as lock:
model = await self.cache.get(key)
if model is None:
model = InferenceModel.from_model_type(model_type, model_name, **model_kwargs)
model = from_model_type(model_type, model_name, **model_kwargs)
await lock.cas(model, ttl=self.ttl)
return model

View file

@ -1,23 +1,24 @@
import os
import zipfile
import json
from abc import abstractmethod
from functools import cached_property
from io import BytesIO
from pathlib import Path
from typing import Any, Literal
import numpy as np
import onnxruntime as ort
import torch
from clip_server.model.clip import BICUBIC, _convert_image_to_rgb
from clip_server.model.clip_onnx import _MODELS, _S3_BUCKET_V2, CLIPOnnxModel, download_model
from clip_server.model.pretrained_models import _VISUAL_MODEL_IMAGE_SIZE
from clip_server.model.tokenization import Tokenizer
from huggingface_hub import snapshot_download
from PIL import Image
from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor
from transformers import AutoTokenizer
from app.config import log
from app.models.transforms import crop, get_pil_resampling, normalize, resize, to_numpy
from app.schemas import ModelType, ndarray_f32, ndarray_i32, ndarray_i64
from ..config import log
from ..schemas import ModelType
from .base import InferenceModel
class CLIPEncoder(InferenceModel):
class BaseCLIPEncoder(InferenceModel):
_model_type = ModelType.CLIP
def __init__(
@ -27,48 +28,29 @@ class CLIPEncoder(InferenceModel):
mode: Literal["text", "vision"] | None = None,
**model_kwargs: Any,
) -> None:
if mode is not None and mode not in ("text", "vision"):
raise ValueError(f"Mode must be 'text', 'vision', or omitted; got '{mode}'")
if model_name not in _MODELS:
raise ValueError(f"Unknown model name {model_name}.")
self.mode = mode
super().__init__(model_name, cache_dir, **model_kwargs)
def _download(self) -> None:
models: tuple[tuple[str, str], tuple[str, str]] = _MODELS[self.model_name]
text_onnx_path = self.cache_dir / "textual.onnx"
vision_onnx_path = self.cache_dir / "visual.onnx"
if not text_onnx_path.is_file():
self._download_model(*models[0])
if not vision_onnx_path.is_file():
self._download_model(*models[1])
def _load(self) -> None:
if self.mode == "text" or self.mode is None:
log.debug(f"Loading clip text model '{self.model_name}'")
self.text_model = ort.InferenceSession(
self.cache_dir / "textual.onnx",
self.textual_path.as_posix(),
sess_options=self.sess_options,
providers=self.providers,
provider_options=self.provider_options,
)
self.text_outputs = [output.name for output in self.text_model.get_outputs()]
self.tokenizer = Tokenizer(self.model_name)
if self.mode == "vision" or self.mode is None:
log.debug(f"Loading clip vision model '{self.model_name}'")
self.vision_model = ort.InferenceSession(
self.cache_dir / "visual.onnx",
self.visual_path.as_posix(),
sess_options=self.sess_options,
providers=self.providers,
provider_options=self.provider_options,
)
self.vision_outputs = [output.name for output in self.vision_model.get_outputs()]
image_size = _VISUAL_MODEL_IMAGE_SIZE[CLIPOnnxModel.get_model_name(self.model_name)]
self.transform = _transform_pil_image(image_size)
def _predict(self, image_or_text: Image.Image | str) -> list[float]:
if isinstance(image_or_text, bytes):
@ -78,55 +60,163 @@ class CLIPEncoder(InferenceModel):
case Image.Image():
if self.mode == "text":
raise TypeError("Cannot encode image as text-only model")
pixel_values = self.transform(image_or_text)
assert isinstance(pixel_values, torch.Tensor)
pixel_values = torch.unsqueeze(pixel_values, 0).numpy()
outputs = self.vision_model.run(self.vision_outputs, {"pixel_values": pixel_values})
outputs = self.vision_model.run(None, self.transform(image_or_text))
case str():
if self.mode == "vision":
raise TypeError("Cannot encode text as vision-only model")
text_inputs: dict[str, torch.Tensor] = self.tokenizer(image_or_text)
inputs = {
"input_ids": text_inputs["input_ids"].int().numpy(),
"attention_mask": text_inputs["attention_mask"].int().numpy(),
}
outputs = self.text_model.run(self.text_outputs, inputs)
outputs = self.text_model.run(None, self.tokenize(image_or_text))
case _:
raise TypeError(f"Expected Image or str, but got: {type(image_or_text)}")
return outputs[0][0].tolist()
def _download_model(self, model_name: str, model_md5: str) -> bool:
# downloading logic is adapted from clip-server's CLIPOnnxModel class
download_model(
url=_S3_BUCKET_V2 + model_name,
target_folder=self.cache_dir.as_posix(),
md5sum=model_md5,
with_resume=True,
)
file = self.cache_dir / model_name.split("/")[1]
if file.suffix == ".zip":
with zipfile.ZipFile(file, "r") as zip_ref:
zip_ref.extractall(self.cache_dir)
os.remove(file)
return True
@abstractmethod
def tokenize(self, text: str) -> dict[str, ndarray_i32]:
pass
@abstractmethod
def transform(self, image: Image.Image) -> dict[str, ndarray_f32]:
pass
@property
def textual_dir(self) -> Path:
return self.cache_dir / "textual"
@property
def visual_dir(self) -> Path:
return self.cache_dir / "visual"
@property
def model_cfg_path(self) -> Path:
return self.cache_dir / "config.json"
@property
def textual_path(self) -> Path:
return self.textual_dir / "model.onnx"
@property
def visual_path(self) -> Path:
return self.visual_dir / "model.onnx"
@property
def preprocess_cfg_path(self) -> Path:
return self.visual_dir / "preprocess_cfg.json"
@property
def cached(self) -> bool:
return (self.cache_dir / "textual.onnx").is_file() and (self.cache_dir / "visual.onnx").is_file()
return self.textual_path.is_file() and self.visual_path.is_file()
# same as `_transform_blob` without `_blob2image`
def _transform_pil_image(n_px: int) -> Compose:
return Compose(
[
Resize(n_px, interpolation=BICUBIC),
CenterCrop(n_px),
_convert_image_to_rgb,
ToTensor(),
Normalize(
(0.48145466, 0.4578275, 0.40821073),
(0.26862954, 0.26130258, 0.27577711),
),
]
)
class OpenCLIPEncoder(BaseCLIPEncoder):
def __init__(
self,
model_name: str,
cache_dir: str | None = None,
mode: Literal["text", "vision"] | None = None,
**model_kwargs: Any,
) -> None:
super().__init__(_clean_model_name(model_name), cache_dir, mode, **model_kwargs)
def _download(self) -> None:
snapshot_download(
f"immich-app/{self.model_name}",
cache_dir=self.cache_dir,
local_dir=self.cache_dir,
local_dir_use_symlinks=False,
)
def _load(self) -> None:
super()._load()
self.tokenizer = AutoTokenizer.from_pretrained(self.textual_dir)
self.sequence_length = self.model_cfg["text_cfg"]["context_length"]
self.size = (
self.preprocess_cfg["size"][0] if type(self.preprocess_cfg["size"]) == list else self.preprocess_cfg["size"]
)
self.resampling = get_pil_resampling(self.preprocess_cfg["interpolation"])
self.mean = np.array(self.preprocess_cfg["mean"], dtype=np.float32)
self.std = np.array(self.preprocess_cfg["std"], dtype=np.float32)
def tokenize(self, text: str) -> dict[str, ndarray_i32]:
input_ids: ndarray_i64 = self.tokenizer(
text,
max_length=self.sequence_length,
return_tensors="np",
return_attention_mask=False,
padding="max_length",
truncation=True,
).input_ids
return {"text": input_ids.astype(np.int32)}
def transform(self, image: Image.Image) -> dict[str, ndarray_f32]:
image = resize(image, self.size)
image = crop(image, self.size)
image_np = to_numpy(image)
image_np = normalize(image_np, self.mean, self.std)
return {"image": np.expand_dims(image_np.transpose(2, 0, 1), 0)}
@cached_property
def model_cfg(self) -> dict[str, Any]:
return json.load(self.model_cfg_path.open())
@cached_property
def preprocess_cfg(self) -> dict[str, Any]:
return json.load(self.preprocess_cfg_path.open())
class MCLIPEncoder(OpenCLIPEncoder):
def tokenize(self, text: str) -> dict[str, ndarray_i32]:
tokens: dict[str, ndarray_i64] = self.tokenizer(text, return_tensors="np")
return {k: v.astype(np.int32) for k, v in tokens.items()}
_OPENCLIP_MODELS = {
"RN50__openai",
"RN50__yfcc15m",
"RN50__cc12m",
"RN101__openai",
"RN101__yfcc15m",
"RN50x4__openai",
"RN50x16__openai",
"RN50x64__openai",
"ViT-B-32__openai",
"ViT-B-32__laion2b_e16",
"ViT-B-32__laion400m_e31",
"ViT-B-32__laion400m_e32",
"ViT-B-32__laion2b-s34b-b79k",
"ViT-B-16__openai",
"ViT-B-16__laion400m_e31",
"ViT-B-16__laion400m_e32",
"ViT-B-16-plus-240__laion400m_e31",
"ViT-B-16-plus-240__laion400m_e32",
"ViT-L-14__openai",
"ViT-L-14__laion400m_e31",
"ViT-L-14__laion400m_e32",
"ViT-L-14__laion2b-s32b-b82k",
"ViT-L-14-336__openai",
"ViT-H-14__laion2b-s32b-b79k",
"ViT-g-14__laion2b-s12b-b42k",
}
_MCLIP_MODELS = {
"LABSE-Vit-L-14",
"XLM-Roberta-Large-Vit-B-32",
"XLM-Roberta-Large-Vit-B-16Plus",
"XLM-Roberta-Large-Vit-L-14",
}
def _clean_model_name(model_name: str) -> str:
return model_name.split("/")[-1].replace("::", "__")
def is_openclip(model_name: str) -> bool:
return _clean_model_name(model_name) in _OPENCLIP_MODELS
def is_mclip(model_name: str) -> bool:
return _clean_model_name(model_name) in _MCLIP_MODELS

View file

@ -9,7 +9,8 @@ from insightface.model_zoo import ArcFaceONNX, RetinaFace
from insightface.utils.face_align import norm_crop
from insightface.utils.storage import BASE_REPO_URL, download_file
from ..schemas import ModelType
from app.schemas import ModelType, ndarray_f32
from .base import InferenceModel
@ -68,7 +69,7 @@ class FaceRecognizer(InferenceModel):
)
self.rec_model.prepare(ctx_id=0)
def _predict(self, image: np.ndarray[int, np.dtype[Any]] | bytes) -> list[dict[str, Any]]:
def _predict(self, image: ndarray_f32 | bytes) -> list[dict[str, Any]]:
if isinstance(image, bytes):
image = cv2.imdecode(np.frombuffer(image, np.uint8), cv2.IMREAD_COLOR)
bboxes, kpss = self.det_model.detect(image)

View file

@ -0,0 +1,35 @@
import numpy as np
from PIL import Image
from app.schemas import ndarray_f32
_PIL_RESAMPLING_METHODS = {resampling.name.lower(): resampling for resampling in Image.Resampling}
def resize(img: Image.Image, size: int) -> Image.Image:
if img.width < img.height:
return img.resize((size, int((img.height / img.width) * size)), resample=Image.BICUBIC)
else:
return img.resize((int((img.width / img.height) * size), size), resample=Image.BICUBIC)
# https://stackoverflow.com/a/60883103
def crop(img: Image.Image, size: int) -> Image.Image:
left = int((img.size[0] / 2) - (size / 2))
upper = int((img.size[1] / 2) - (size / 2))
right = left + size
lower = upper + size
return img.crop((left, upper, right, lower))
def to_numpy(img: Image.Image) -> ndarray_f32:
return np.asarray(img.convert("RGB")).astype(np.float32) / 255.0
def normalize(img: ndarray_f32, mean: float | ndarray_f32, std: float | ndarray_f32) -> ndarray_f32:
return (img - mean) / std
def get_pil_resampling(resample: str) -> Image.Resampling:
return _PIL_RESAMPLING_METHODS[resample.lower()]

View file

@ -1,5 +1,7 @@
from enum import StrEnum
from typing import TypeAlias
import numpy as np
from pydantic import BaseModel
@ -31,3 +33,8 @@ class ModelType(StrEnum):
IMAGE_CLASSIFICATION = "image-classification"
CLIP = "clip"
FACIAL_RECOGNITION = "facial-recognition"
ndarray_f32: TypeAlias = np.ndarray[int, np.dtype[np.float32]]
ndarray_i64: TypeAlias = np.ndarray[int, np.dtype[np.int64]]
ndarray_i32: TypeAlias = np.ndarray[int, np.dtype[np.int32]]

View file

@ -1,7 +1,8 @@
import json
import pickle
from io import BytesIO
from typing import Any, TypeAlias
from pathlib import Path
from typing import Any, Callable
from unittest import mock
import cv2
@ -14,13 +15,11 @@ from pytest_mock import MockerFixture
from .config import settings
from .models.base import PicklableSessionOptions
from .models.cache import ModelCache
from .models.clip import CLIPEncoder
from .models.clip import OpenCLIPEncoder
from .models.facial_recognition import FaceRecognizer
from .models.image_classification import ImageClassifier
from .schemas import ModelType
ndarray: TypeAlias = np.ndarray[int, np.dtype[np.float32]]
class TestImageClassifier:
classifier_preds = [
@ -56,30 +55,50 @@ class TestImageClassifier:
class TestCLIP:
embedding = np.random.rand(512).astype(np.float32)
cache_dir = Path("test_cache")
def test_basic_image(self, pil_image: Image.Image, mocker: MockerFixture) -> None:
mocker.patch.object(CLIPEncoder, "download")
def test_basic_image(
self,
pil_image: Image.Image,
mocker: MockerFixture,
clip_model_cfg: dict[str, Any],
clip_preprocess_cfg: Callable[[Path], dict[str, Any]],
) -> None:
mocker.patch.object(OpenCLIPEncoder, "download")
mocker.patch.object(OpenCLIPEncoder, "model_cfg", clip_model_cfg)
mocker.patch.object(OpenCLIPEncoder, "preprocess_cfg", clip_preprocess_cfg)
mocker.patch("app.models.clip.AutoTokenizer.from_pretrained", autospec=True)
mocked = mocker.patch("app.models.clip.ort.InferenceSession", autospec=True)
mocked.return_value.run.return_value = [[self.embedding]]
clip_encoder = CLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", mode="vision")
assert clip_encoder.mode == "vision"
clip_encoder = OpenCLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", mode="vision")
embedding = clip_encoder.predict(pil_image)
assert clip_encoder.mode == "vision"
assert isinstance(embedding, list)
assert len(embedding) == 512
assert len(embedding) == clip_model_cfg["embed_dim"]
assert all([isinstance(num, float) for num in embedding])
clip_encoder.vision_model.run.assert_called_once()
def test_basic_text(self, mocker: MockerFixture) -> None:
mocker.patch.object(CLIPEncoder, "download")
def test_basic_text(
self,
mocker: MockerFixture,
clip_model_cfg: dict[str, Any],
clip_preprocess_cfg: Callable[[Path], dict[str, Any]],
) -> None:
mocker.patch.object(OpenCLIPEncoder, "download")
mocker.patch.object(OpenCLIPEncoder, "model_cfg", clip_model_cfg)
mocker.patch.object(OpenCLIPEncoder, "preprocess_cfg", clip_preprocess_cfg)
mocker.patch("app.models.clip.AutoTokenizer.from_pretrained", autospec=True)
mocked = mocker.patch("app.models.clip.ort.InferenceSession", autospec=True)
mocked.return_value.run.return_value = [[self.embedding]]
clip_encoder = CLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", mode="text")
assert clip_encoder.mode == "text"
clip_encoder = OpenCLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", mode="text")
embedding = clip_encoder.predict("test search query")
assert clip_encoder.mode == "text"
assert isinstance(embedding, list)
assert len(embedding) == 512
assert len(embedding) == clip_model_cfg["embed_dim"]
assert all([isinstance(num, float) for num in embedding])
clip_encoder.text_model.run.assert_called_once()

View file

@ -0,0 +1,21 @@
FROM mambaorg/micromamba:bookworm-slim as builder
ENV NODE_ENV=production \
TRANSFORMERS_CACHE=/cache \
PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1 \
PATH="/opt/venv/bin:$PATH" \
PYTHONPATH=/usr/src
COPY --chown=$MAMBA_USER:$MAMBA_USER conda-lock.yml /tmp/conda-lock.yml
RUN micromamba install -y -n base -f /tmp/conda-lock.yml && \
micromamba remove -y -n base cxx-compiler && \
micromamba clean --all --yes
WORKDIR /usr/src/app
COPY --chown=$MAMBA_USER:$MAMBA_USER start.sh .
COPY --chown=$MAMBA_USER:$MAMBA_USER app .
ENTRYPOINT ["/usr/local/bin/_entrypoint.sh"]
CMD ["./start.sh"]

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,15 @@
name: base
channels:
- conda-forge
platforms:
- linux-64
- linux-aarch64
dependencies:
- black
- conda-lock
- mypy
- pytest
- pytest-cov
- pytest-mock
- ruff
category: dev

View file

@ -0,0 +1,25 @@
name: base
channels:
- conda-forge
- nvidia
- pytorch-nightly
platforms:
- linux-64
dependencies:
- cxx-compiler
- onnx==1.*
- onnxruntime==1.*
- open-clip-torch==2.*
- orjson==3.*
- pip
- python==3.11.*
- pytorch
- rich==13.*
- safetensors==0.*
- setuptools==68.*
- torchvision
- transformers==4.*
- pip:
- multilingual-clip
- onnx-simplifier
category: main

View file

@ -0,0 +1,67 @@
import tempfile
import warnings
from pathlib import Path
import torch
from multilingual_clip.pt_multilingual_clip import MultilingualCLIP
from transformers import AutoTokenizer
from .openclip import OpenCLIPModelConfig
from .openclip import to_onnx as openclip_to_onnx
from .optimize import optimize
from .util import get_model_path
_MCLIP_TO_OPENCLIP = {
"M-CLIP/XLM-Roberta-Large-Vit-B-32": OpenCLIPModelConfig("ViT-B-32", "openai"),
"M-CLIP/XLM-Roberta-Large-Vit-B-16Plus": OpenCLIPModelConfig("ViT-B-16-plus-240", "laion400m_e32"),
"M-CLIP/LABSE-Vit-L-14": OpenCLIPModelConfig("ViT-L-14", "openai"),
"M-CLIP/XLM-Roberta-Large-Vit-L-14": OpenCLIPModelConfig("ViT-L-14", "openai"),
}
def to_onnx(
model_name: str,
output_dir_visual: Path | str,
output_dir_textual: Path | str,
) -> None:
textual_path = get_model_path(output_dir_textual)
with tempfile.TemporaryDirectory() as tmpdir:
model = MultilingualCLIP.from_pretrained(model_name, cache_dir=tmpdir)
AutoTokenizer.from_pretrained(model_name).save_pretrained(output_dir_textual)
for param in model.parameters():
param.requires_grad_(False)
export_text_encoder(model, textual_path)
openclip_to_onnx(_MCLIP_TO_OPENCLIP[model_name], output_dir_visual)
optimize(textual_path)
def export_text_encoder(model: MultilingualCLIP, output_path: Path | str) -> None:
output_path = Path(output_path)
def forward(self: MultilingualCLIP, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
embs = self.transformer(input_ids, attention_mask)[0]
embs = (embs * attention_mask.unsqueeze(2)).sum(dim=1) / attention_mask.sum(dim=1)[:, None]
embs = self.LinearTransformation(embs)
return torch.nn.functional.normalize(embs, dim=-1)
# unfortunately need to monkeypatch for tracing to work here
# otherwise it hits the 2GiB protobuf serialization limit
MultilingualCLIP.forward = forward
args = (torch.ones(1, 77, dtype=torch.int32), torch.ones(1, 77, dtype=torch.int32))
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
torch.onnx.export(
model,
args,
output_path.as_posix(),
input_names=["input_ids", "attention_mask"],
output_names=["text_embedding"],
opset_version=17,
dynamic_axes={
"input_ids": {0: "batch_size", 1: "sequence_length"},
"attention_mask": {0: "batch_size", 1: "sequence_length"},
},
)

View file

@ -0,0 +1,109 @@
import tempfile
import warnings
from dataclasses import dataclass, field
from pathlib import Path
import open_clip
import torch
from transformers import AutoTokenizer
from .optimize import optimize
from .util import get_model_path, save_config
@dataclass
class OpenCLIPModelConfig:
name: str
pretrained: str
image_size: int = field(init=False)
sequence_length: int = field(init=False)
def __post_init__(self) -> None:
open_clip_cfg = open_clip.get_model_config(self.name)
if open_clip_cfg is None:
raise ValueError(f"Unknown model {self.name}")
self.image_size = open_clip_cfg["vision_cfg"]["image_size"]
self.sequence_length = open_clip_cfg["text_cfg"]["context_length"]
def to_onnx(
model_cfg: OpenCLIPModelConfig,
output_dir_visual: Path | str | None = None,
output_dir_textual: Path | str | None = None,
) -> None:
with tempfile.TemporaryDirectory() as tmpdir:
model = open_clip.create_model(
model_cfg.name,
pretrained=model_cfg.pretrained,
jit=False,
cache_dir=tmpdir,
require_pretrained=True,
)
text_vision_cfg = open_clip.get_model_config(model_cfg.name)
for param in model.parameters():
param.requires_grad_(False)
if output_dir_visual is not None:
output_dir_visual = Path(output_dir_visual)
visual_path = get_model_path(output_dir_visual)
save_config(open_clip.get_model_preprocess_cfg(model), output_dir_visual / "preprocess_cfg.json")
save_config(text_vision_cfg, output_dir_visual.parent / "config.json")
export_image_encoder(model, model_cfg, visual_path)
optimize(visual_path)
if output_dir_textual is not None:
output_dir_textual = Path(output_dir_textual)
textual_path = get_model_path(output_dir_textual)
tokenizer_name = text_vision_cfg["text_cfg"].get("hf_tokenizer_name", "openai/clip-vit-base-patch32")
AutoTokenizer.from_pretrained(tokenizer_name).save_pretrained(output_dir_textual)
export_text_encoder(model, model_cfg, textual_path)
optimize(textual_path)
def export_image_encoder(model: open_clip.CLIP, model_cfg: OpenCLIPModelConfig, output_path: Path | str) -> None:
output_path = Path(output_path)
def encode_image(image: torch.Tensor) -> torch.Tensor:
return model.encode_image(image, normalize=True)
args = (torch.randn(1, 3, model_cfg.image_size, model_cfg.image_size),)
traced = torch.jit.trace(encode_image, args)
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
torch.onnx.export(
traced,
args,
output_path.as_posix(),
input_names=["image"],
output_names=["image_embedding"],
opset_version=17,
dynamic_axes={"image": {0: "batch_size"}},
)
def export_text_encoder(model: open_clip.CLIP, model_cfg: OpenCLIPModelConfig, output_path: Path | str) -> None:
output_path = Path(output_path)
def encode_text(text: torch.Tensor) -> torch.Tensor:
return model.encode_text(text, normalize=True)
args = (torch.ones(1, model_cfg.sequence_length, dtype=torch.int32),)
traced = torch.jit.trace(encode_text, args)
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
torch.onnx.export(
traced,
args,
output_path.as_posix(),
input_names=["text"],
output_names=["text_embedding"],
opset_version=17,
dynamic_axes={"text": {0: "batch_size"}},
)

View file

@ -0,0 +1,38 @@
from pathlib import Path
import onnx
import onnxruntime as ort
import onnxsim
def optimize_onnxsim(model_path: Path | str, output_path: Path | str) -> None:
model_path = Path(model_path)
output_path = Path(output_path)
model = onnx.load(model_path.as_posix())
model, check = onnxsim.simplify(model, skip_shape_inference=True)
assert check, "Simplified ONNX model could not be validated"
onnx.save(model, output_path.as_posix())
def optimize_ort(
model_path: Path | str,
output_path: Path | str,
level: ort.GraphOptimizationLevel = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC,
) -> None:
model_path = Path(model_path)
output_path = Path(output_path)
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = level
sess_options.optimized_model_filepath = output_path.as_posix()
ort.InferenceSession(model_path.as_posix(), providers=["CPUExecutionProvider"], sess_options=sess_options)
def optimize(model_path: Path | str) -> None:
model_path = Path(model_path)
optimize_ort(model_path, model_path)
# onnxsim serializes large models as a blob, which uses much more memory when loading the model at runtime
if not any(file.name.startswith("Constant") for file in model_path.parent.iterdir()):
optimize_onnxsim(model_path, model_path)

View file

@ -0,0 +1,15 @@
import json
from pathlib import Path
from typing import Any
def get_model_path(output_dir: Path | str) -> Path:
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
return output_dir / "model.onnx"
def save_config(config: Any, output_path: Path | str) -> None:
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
json.dump(config, output_path.open("w"))

View file

@ -0,0 +1,76 @@
import gc
import os
from pathlib import Path
from tempfile import TemporaryDirectory
from huggingface_hub import create_repo, login, upload_folder
from models import mclip, openclip
from rich.progress import Progress
models = [
"RN50::openai",
"RN50::yfcc15m",
"RN50::cc12m",
"RN101::openai",
"RN101::yfcc15m",
"RN50x4::openai",
"RN50x16::openai",
"RN50x64::openai",
"ViT-B-32::openai",
"ViT-B-32::laion2b_e16",
"ViT-B-32::laion400m_e31",
"ViT-B-32::laion400m_e32",
"ViT-B-32::laion2b-s34b-b79k",
"ViT-B-16::openai",
"ViT-B-16::laion400m_e31",
"ViT-B-16::laion400m_e32",
"ViT-B-16-plus-240::laion400m_e31",
"ViT-B-16-plus-240::laion400m_e32",
"ViT-L-14::openai",
"ViT-L-14::laion400m_e31",
"ViT-L-14::laion400m_e32",
"ViT-L-14::laion2b-s32b-b82k",
"ViT-L-14-336::openai",
"ViT-H-14::laion2b-s32b-b79k",
"ViT-g-14::laion2b-s12b-b42k",
"M-CLIP/LABSE-Vit-L-14",
"M-CLIP/XLM-Roberta-Large-Vit-B-32",
"M-CLIP/XLM-Roberta-Large-Vit-B-16Plus",
"M-CLIP/XLM-Roberta-Large-Vit-L-14",
]
login(token=os.environ["HF_AUTH_TOKEN"])
with Progress() as progress:
task1 = progress.add_task("[green]Exporting models...", total=len(models))
task2 = progress.add_task("[yellow]Uploading models...", total=len(models))
with TemporaryDirectory() as tmp:
tmpdir = Path(tmp)
for model in models:
model_name = model.split("/")[-1].replace("::", "__")
config_path = tmpdir / model_name / "config.json"
def upload() -> None:
progress.update(task2, description=f"[yellow]Uploading {model_name}")
repo_id = f"immich-app/{model_name}"
create_repo(repo_id, exist_ok=True)
upload_folder(repo_id=repo_id, folder_path=tmpdir / model_name)
progress.update(task2, advance=1)
def export() -> None:
progress.update(task1, description=f"[green]Exporting {model_name}")
visual_dir = tmpdir / model_name / "visual"
textual_dir = tmpdir / model_name / "textual"
if model.startswith("M-CLIP"):
mclip.to_onnx(model, visual_dir, textual_dir)
else:
name, _, pretrained = model_name.partition("__")
openclip.to_onnx(openclip.OpenCLIPModelConfig(name, pretrained), visual_dir, textual_dir)
progress.update(task1, advance=1)
gc.collect()
export()
upload()

View file

@ -1,11 +1,12 @@
from io import BytesIO
import json
from argparse import ArgumentParser
from io import BytesIO
from typing import Any
from locust import HttpUser, events, task
from locust.env import Environment
from PIL import Image
from argparse import ArgumentParser
byte_image = BytesIO()
@ -14,11 +15,21 @@ def _(parser: ArgumentParser) -> None:
parser.add_argument("--tag-model", type=str, default="microsoft/resnet-50")
parser.add_argument("--clip-model", type=str, default="ViT-B-32::openai")
parser.add_argument("--face-model", type=str, default="buffalo_l")
parser.add_argument("--tag-min-score", type=int, default=0.0,
help="Returns all tags at or above this score. The default returns all tags.")
parser.add_argument("--face-min-score", type=int, default=0.034,
help=("Returns all faces at or above this score. The default returns 1 face per request; "
"setting this to 0 blows up the number of faces to the thousands."))
parser.add_argument(
"--tag-min-score",
type=int,
default=0.0,
help="Returns all tags at or above this score. The default returns all tags.",
)
parser.add_argument(
"--face-min-score",
type=int,
default=0.034,
help=(
"Returns all faces at or above this score. The default returns 1 face per request; "
"setting this to 0 blows up the number of faces to the thousands."
),
)
parser.add_argument("--image-size", type=int, default=1000)
@ -62,7 +73,7 @@ class CLIPTextFormDataLoadTest(InferenceLoadTest):
("modelName", self.environment.parsed_options.clip_model),
("modelType", "clip"),
("options", json.dumps({"mode": "text"})),
("text", "test search query")
("text", "test search query"),
]
self.client.post("/predict", data=data)
@ -88,5 +99,5 @@ class RecognitionFormDataLoadTest(InferenceLoadTest):
("options", json.dumps({"minScore": self.environment.parsed_options.face_min_score})),
]
files = {"image": self.data}
self.client.post("/predict", data=data, files=files)

File diff suppressed because it is too large Load diff

View file

@ -9,8 +9,8 @@ packages = [{include = "app"}]
[tool.poetry.dependencies]
python = "^3.11"
torch = [
{markers = "platform_machine == 'arm64' or platform_machine == 'aarch64'", version = "=2.0.1", source = "pypi"},
{markers = "platform_machine == 'amd64' or platform_machine == 'x86_64'", version = "=2.0.1", source = "pytorch-cpu"}
{markers = "platform_machine == 'arm64' or platform_machine == 'aarch64'", version = "=2.1.0", source = "pypi"},
{markers = "platform_machine == 'amd64' or platform_machine == 'x86_64'", version = "=2.1.0", source = "pytorch-cpu"}
]
transformers = "^4.29.2"
onnxruntime = "^1.15.0"
@ -22,14 +22,9 @@ uvicorn = {extras = ["standard"], version = "^0.22.0"}
pydantic = "^1.10.8"
aiocache = "^0.12.1"
optimum = "^1.9.1"
torchvision = [
{markers = "platform_machine == 'arm64' or platform_machine == 'aarch64'", version = "=0.15.2", source = "pypi"},
{markers = "platform_machine == 'amd64' or platform_machine == 'x86_64'", version = "=0.15.2", source = "pytorch-cpu"}
]
rich = "^13.4.2"
ftfy = "^6.1.1"
setuptools = "^68.0.0"
open-clip-torch = "^2.20.0"
python-multipart = "^0.0.6"
orjson = "^3.9.5"
safetensors = "0.3.2"
@ -63,6 +58,7 @@ warn_redundant_casts = true
disallow_any_generics = true
check_untyped_defs = true
disallow_untyped_defs = true
ignore_missing_imports = true
[tool.pydantic-mypy]
init_forbid_extra = true
@ -70,30 +66,6 @@ init_typed = true
warn_required_dynamic_aliases = true
warn_untyped_fields = true
[[tool.mypy.overrides]]
module = [
"huggingface_hub",
"transformers",
"gunicorn",
"cv2",
"insightface.model_zoo",
"insightface.utils.face_align",
"insightface.utils.storage",
"onnxruntime",
"optimum",
"optimum.pipelines",
"optimum.onnxruntime",
"clip_server.model.clip",
"clip_server.model.clip_onnx",
"clip_server.model.pretrained_models",
"clip_server.model.tokenization",
"torchvision.transforms",
"aiocache.backends.memory",
"aiocache.lock",
"aiocache.plugins"
]
ignore_missing_imports = true
[tool.ruff]
line-length = 120
target-version = "py311"

View file

@ -1,2 +0,0 @@
# requirements to be installed with `--no-deps` flag
clip-server==0.8.*

View file

@ -193,7 +193,7 @@ describe(SmartInfoService.name, () => {
expect(machineMock.encodeImage).toHaveBeenCalledWith(
'http://immich-machine-learning:3003',
{ imagePath: 'path/to/resize.ext' },
{ enabled: true, modelName: 'ViT-B-32::openai' },
{ enabled: true, modelName: 'ViT-B-32__openai' },
);
expect(smartMock.upsert).toHaveBeenCalledWith({
assetId: 'asset-1',

View file

@ -67,7 +67,7 @@ export const defaults = Object.freeze<SystemConfig>({
},
clip: {
enabled: true,
modelName: 'ViT-B-32::openai',
modelName: 'ViT-B-32__openai',
},
facialRecognition: {
enabled: true,

View file

@ -68,7 +68,7 @@ const updatedConfig = Object.freeze<SystemConfig>({
},
clip: {
enabled: true,
modelName: 'ViT-B-32::openai',
modelName: 'ViT-B-32__openai',
},
facialRecognition: {
enabled: true,

View file

@ -140,9 +140,8 @@
isEdited={machineLearningConfig.clip.modelName !== savedConfig.clip.modelName}
>
<p slot="desc" class="immich-form-label pb-2 text-sm">
The name of a CLIP model listed <a
href="https://clip-as-service.jina.ai/user-guides/benchmark/#size-and-efficiency"><u>here</u></a
>. Note that you must re-run the 'Encode CLIP' job for all images upon changing a model.
The name of a CLIP model listed <a href="https://huggingface.co/immich-app"><u>here</u></a>. Note that
you must re-run the 'Encode CLIP' job for all images upon changing a model.
</p>
</SettingInputField>
</div>