import os import zipfile from io import BytesIO from typing import Any, Literal import onnxruntime as ort import torch from clip_server.model.clip import BICUBIC, _convert_image_to_rgb from clip_server.model.clip_onnx import _MODELS, _S3_BUCKET_V2, CLIPOnnxModel, download_model from clip_server.model.pretrained_models import _VISUAL_MODEL_IMAGE_SIZE from clip_server.model.tokenization import Tokenizer from PIL import Image from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor from ..config import log from ..schemas import ModelType from .base import InferenceModel class CLIPEncoder(InferenceModel): _model_type = ModelType.CLIP def __init__( self, model_name: str, cache_dir: str | None = None, mode: Literal["text", "vision"] | None = None, **model_kwargs: Any, ) -> None: if mode is not None and mode not in ("text", "vision"): raise ValueError(f"Mode must be 'text', 'vision', or omitted; got '{mode}'") if model_name not in _MODELS: raise ValueError(f"Unknown model name {model_name}.") self.mode = mode super().__init__(model_name, cache_dir, **model_kwargs) def _download(self) -> None: models: tuple[tuple[str, str], tuple[str, str]] = _MODELS[self.model_name] text_onnx_path = self.cache_dir / "textual.onnx" vision_onnx_path = self.cache_dir / "visual.onnx" if not text_onnx_path.is_file(): self._download_model(*models[0]) if not vision_onnx_path.is_file(): self._download_model(*models[1]) def _load(self) -> None: if self.mode == "text" or self.mode is None: log.debug(f"Loading clip text model '{self.model_name}'") self.text_model = ort.InferenceSession( self.cache_dir / "textual.onnx", sess_options=self.sess_options, providers=self.providers, provider_options=self.provider_options, ) self.text_outputs = [output.name for output in self.text_model.get_outputs()] self.tokenizer = Tokenizer(self.model_name) if self.mode == "vision" or self.mode is None: log.debug(f"Loading clip vision model '{self.model_name}'") self.vision_model = ort.InferenceSession( self.cache_dir / "visual.onnx", sess_options=self.sess_options, providers=self.providers, provider_options=self.provider_options, ) self.vision_outputs = [output.name for output in self.vision_model.get_outputs()] image_size = _VISUAL_MODEL_IMAGE_SIZE[CLIPOnnxModel.get_model_name(self.model_name)] self.transform = _transform_pil_image(image_size) def _predict(self, image_or_text: Image.Image | str) -> list[float]: if isinstance(image_or_text, bytes): image_or_text = Image.open(BytesIO(image_or_text)) match image_or_text: case Image.Image(): if self.mode == "text": raise TypeError("Cannot encode image as text-only model") pixel_values = self.transform(image_or_text) assert isinstance(pixel_values, torch.Tensor) pixel_values = torch.unsqueeze(pixel_values, 0).numpy() outputs = self.vision_model.run(self.vision_outputs, {"pixel_values": pixel_values}) case str(): if self.mode == "vision": raise TypeError("Cannot encode text as vision-only model") text_inputs: dict[str, torch.Tensor] = self.tokenizer(image_or_text) inputs = { "input_ids": text_inputs["input_ids"].int().numpy(), "attention_mask": text_inputs["attention_mask"].int().numpy(), } outputs = self.text_model.run(self.text_outputs, inputs) case _: raise TypeError(f"Expected Image or str, but got: {type(image_or_text)}") return outputs[0][0].tolist() def _download_model(self, model_name: str, model_md5: str) -> bool: # downloading logic is adapted from clip-server's CLIPOnnxModel class download_model( url=_S3_BUCKET_V2 + model_name, target_folder=self.cache_dir.as_posix(), md5sum=model_md5, with_resume=True, ) file = self.cache_dir / model_name.split("/")[1] if file.suffix == ".zip": with zipfile.ZipFile(file, "r") as zip_ref: zip_ref.extractall(self.cache_dir) os.remove(file) return True @property def cached(self) -> bool: return (self.cache_dir / "textual.onnx").is_file() and (self.cache_dir / "visual.onnx").is_file() # same as `_transform_blob` without `_blob2image` def _transform_pil_image(n_px: int) -> Compose: return Compose( [ Resize(n_px, interpolation=BICUBIC), CenterCrop(n_px), _convert_image_to_rgb, ToTensor(), Normalize( (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711), ), ] )