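"""CLIP text and vision encoders backed by ONNX Runtime.

Defines a shared BaseCLIPEncoder plus OpenCLIP and multilingual M-CLIP
variants, along with helpers for normalizing model names and detecting which
family a given model belongs to.
"""
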
import json
from abc import abstractmethod
from functools import cached_property
from io import BytesIO
from pathlib import Path
from typing import Any, Literal

import numpy as np
import onnxruntime as ort
from huggingface_hub import snapshot_download
from PIL import Image
from transformers import AutoTokenizer

from app.config import log
from app.models.transforms import crop, get_pil_resampling, normalize, resize, to_numpy
from app.schemas import ModelType, ndarray_f32, ndarray_i32, ndarray_i64

from .base import InferenceModel


class BaseCLIPEncoder(InferenceModel):
    _model_type = ModelType.CLIP

    def __init__(
        self,
        model_name: str,
        cache_dir: str | None = None,
        mode: Literal["text", "vision"] | None = None,
        **model_kwargs: Any,
    ) -> None:
        self.mode = mode
        super().__init__(model_name, cache_dir, **model_kwargs)

    def _load(self) -> None:
        if self.mode == "text" or self.mode is None:
            log.debug(f"Loading clip text model '{self.model_name}'")
            self.text_model = ort.InferenceSession(
                self.textual_path.as_posix(),
                sess_options=self.sess_options,
                providers=self.providers,
                provider_options=self.provider_options,
            )

        if self.mode == "vision" or self.mode is None:
            log.debug(f"Loading clip vision model '{self.model_name}'")
            self.vision_model = ort.InferenceSession(
                self.visual_path.as_posix(),
                sess_options=self.sess_options,
                providers=self.providers,
                provider_options=self.provider_options,
            )
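
    # Accepts raw image bytes (decoded via PIL), a PIL image, or a text string,
    # and returns the first embedding from the corresponding ONNX session as a
    # flat list of floats.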
    def _predict(self, image_or_text: Image.Image | str | bytes) -> list[float]:
        if isinstance(image_or_text, bytes):
            image_or_text = Image.open(BytesIO(image_or_text))

        match image_or_text:
            case Image.Image():
                if self.mode == "text":
                    raise TypeError("Cannot encode image as text-only model")
                outputs = self.vision_model.run(None, self.transform(image_or_text))
            case str():
                if self.mode == "vision":
                    raise TypeError("Cannot encode text as vision-only model")
                outputs = self.text_model.run(None, self.tokenize(image_or_text))
            case _:
                raise TypeError(f"Expected Image or str, but got: {type(image_or_text)}")

        return outputs[0][0].tolist()

    @abstractmethod
    def tokenize(self, text: str) -> dict[str, ndarray_i32]:
        pass

    @abstractmethod
    def transform(self, image: Image.Image) -> dict[str, ndarray_f32]:
        pass

    @property
    def textual_dir(self) -> Path:
        return self.cache_dir / "textual"

    @property
    def visual_dir(self) -> Path:
        return self.cache_dir / "visual"

    @property
    def model_cfg_path(self) -> Path:
        return self.cache_dir / "config.json"

    @property
    def textual_path(self) -> Path:
        return self.textual_dir / "model.onnx"

    @property
    def visual_path(self) -> Path:
        return self.visual_dir / "model.onnx"

    @property
    def preprocess_cfg_path(self) -> Path:
        return self.visual_dir / "preprocess_cfg.json"

    @property
    def cached(self) -> bool:
        return self.textual_path.is_file() and self.visual_path.is_file()


class OpenCLIPEncoder(BaseCLIPEncoder):
    def __init__(
        self,
        model_name: str,
        cache_dir: str | None = None,
        mode: Literal["text", "vision"] | None = None,
        **model_kwargs: Any,
    ) -> None:
        super().__init__(_clean_model_name(model_name), cache_dir, mode, **model_kwargs)

    def _download(self) -> None:
        snapshot_download(
            f"immich-app/{self.model_name}",
            cache_dir=self.cache_dir,
            local_dir=self.cache_dir,
            local_dir_use_symlinks=False,
        )

    def _load(self) -> None:
        super()._load()
        self.tokenizer = AutoTokenizer.from_pretrained(self.textual_dir)
        self.sequence_length = self.model_cfg["text_cfg"]["context_length"]
        self.size = (
            self.preprocess_cfg["size"][0]
            if isinstance(self.preprocess_cfg["size"], list)
            else self.preprocess_cfg["size"]
        )
        self.resampling = get_pil_resampling(self.preprocess_cfg["interpolation"])
        self.mean = np.array(self.preprocess_cfg["mean"], dtype=np.float32)
        self.std = np.array(self.preprocess_cfg["std"], dtype=np.float32)

    def tokenize(self, text: str) -> dict[str, ndarray_i32]:
        input_ids: ndarray_i64 = self.tokenizer(
            text,
            max_length=self.sequence_length,
            return_tensors="np",
            return_attention_mask=False,
            padding="max_length",
            truncation=True,
        ).input_ids
        return {"text": input_ids.astype(np.int32)}
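
    # Vision preprocessing follows the exported preprocess_cfg: resize, center
    # crop, per-channel normalization, then NCHW float32 with a batch dim of 1.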
    def transform(self, image: Image.Image) -> dict[str, ndarray_f32]:
        image = resize(image, self.size)
        image = crop(image, self.size)
        image_np = to_numpy(image)
        image_np = normalize(image_np, self.mean, self.std)
        return {"image": np.expand_dims(image_np.transpose(2, 0, 1), 0)}

    @cached_property
    def model_cfg(self) -> dict[str, Any]:
        return json.load(self.model_cfg_path.open())

    @cached_property
    def preprocess_cfg(self) -> dict[str, Any]:
        return json.load(self.preprocess_cfg_path.open())


class MCLIPEncoder(OpenCLIPEncoder):
    def tokenize(self, text: str) -> dict[str, ndarray_i32]:
        # M-CLIP text towers consume the full tokenizer output (input_ids and
        # attention_mask) rather than a single fixed-length "text" tensor.
        tokens: dict[str, ndarray_i64] = self.tokenizer(text, return_tensors="np")
        return {k: v.astype(np.int32) for k, v in tokens.items()}


_OPENCLIP_MODELS = {
    "RN50__openai",
    "RN50__yfcc15m",
    "RN50__cc12m",
    "RN101__openai",
    "RN101__yfcc15m",
    "RN50x4__openai",
    "RN50x16__openai",
    "RN50x64__openai",
    "ViT-B-32__openai",
    "ViT-B-32__laion2b_e16",
    "ViT-B-32__laion400m_e31",
    "ViT-B-32__laion400m_e32",
    "ViT-B-32__laion2b-s34b-b79k",
    "ViT-B-16__openai",
    "ViT-B-16__laion400m_e31",
    "ViT-B-16__laion400m_e32",
    "ViT-B-16-plus-240__laion400m_e31",
    "ViT-B-16-plus-240__laion400m_e32",
    "ViT-L-14__openai",
    "ViT-L-14__laion400m_e31",
    "ViT-L-14__laion400m_e32",
    "ViT-L-14__laion2b-s32b-b82k",
    "ViT-L-14-336__openai",
    "ViT-H-14__laion2b-s32b-b79k",
    "ViT-g-14__laion2b-s12b-b42k",
}


_MCLIP_MODELS = {
    "LABSE-Vit-L-14",
    "XLM-Roberta-Large-Vit-B-32",
    "XLM-Roberta-Large-Vit-B-16Plus",
    "XLM-Roberta-Large-Vit-L-14",
}


def _clean_model_name(model_name: str) -> str:
    return model_name.split("/")[-1].replace("::", "__")


def is_openclip(model_name: str) -> bool:
    return _clean_model_name(model_name) in _OPENCLIP_MODELS


def is_mclip(model_name: str) -> bool:
    return _clean_model_name(model_name) in _MCLIP_MODELS
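

# Usage sketch (illustrative, not part of this module): pick an encoder class
# from the model name and let the InferenceModel base class (see .base) handle
# downloading, loading, and the public predict entry point. The cache_dir value
# and the exact predict signature below are assumptions.
#
#   name = "M-CLIP/XLM-Roberta-Large-Vit-B-32"
#   encoder_cls = MCLIPEncoder if is_mclip(name) else OpenCLIPEncoder
#   encoder = encoder_cls(name, cache_dir="/cache/clip", mode="text")
#   embedding = encoder.predict("a photo of a dog")  # list[float]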
|