clip.py
import json
from abc import abstractmethod
from functools import cached_property
from io import BytesIO
from pathlib import Path
from typing import Any, Literal

import numpy as np
import onnxruntime as ort
from huggingface_hub import snapshot_download
from PIL import Image
from transformers import AutoTokenizer

from app.config import log
from app.models.transforms import crop, get_pil_resampling, normalize, resize, to_numpy
from app.schemas import ModelType, ndarray_f32, ndarray_i32, ndarray_i64

from .base import InferenceModel


class BaseCLIPEncoder(InferenceModel):
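    """ONNX Runtime wrapper around the textual and visual halves of a CLIP model.

    Subclasses supply the tokenizer and image preprocessing; this base class loads the
    ONNX sessions and routes each prediction to the matching sub-model.
    """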
    _model_type = ModelType.CLIP

    def __init__(
        self,
        model_name: str,
        cache_dir: str | None = None,
        mode: Literal["text", "vision"] | None = None,
        **model_kwargs: Any,
    ) -> None:
        self.mode = mode
        super().__init__(model_name, cache_dir, **model_kwargs)

    def _load(self) -> None:
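        # mode=None loads both the textual and visual ONNX sessions.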
        if self.mode == "text" or self.mode is None:
            log.debug(f"Loading clip text model '{self.model_name}'")
            self.text_model = ort.InferenceSession(
                self.textual_path.as_posix(),
                sess_options=self.sess_options,
                providers=self.providers,
                provider_options=self.provider_options,
            )

        if self.mode == "vision" or self.mode is None:
            log.debug(f"Loading clip vision model '{self.model_name}'")
            self.vision_model = ort.InferenceSession(
                self.visual_path.as_posix(),
                sess_options=self.sess_options,
                providers=self.providers,
                provider_options=self.provider_options,
            )

    def _predict(self, image_or_text: Image.Image | bytes | str) -> list[float]:
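        # Raw image bytes are decoded with PIL before dispatching on the input type.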
        if isinstance(image_or_text, bytes):
            image_or_text = Image.open(BytesIO(image_or_text))

        match image_or_text:
            case Image.Image():
                if self.mode == "text":
                    raise TypeError("Cannot encode image as text-only model")
                outputs = self.vision_model.run(None, self.transform(image_or_text))
            case str():
                if self.mode == "vision":
                    raise TypeError("Cannot encode text as vision-only model")
                outputs = self.text_model.run(None, self.tokenize(image_or_text))
            case _:
                raise TypeError(f"Expected Image or str, but got: {type(image_or_text)}")

        return outputs[0][0].tolist()

    @abstractmethod
    def tokenize(self, text: str) -> dict[str, ndarray_i32]:
        pass

    @abstractmethod
    def transform(self, image: Image.Image) -> dict[str, ndarray_f32]:
        pass

    @property
    def textual_dir(self) -> Path:
        return self.cache_dir / "textual"

    @property
    def visual_dir(self) -> Path:
        return self.cache_dir / "visual"

    @property
    def model_cfg_path(self) -> Path:
        return self.cache_dir / "config.json"

    @property
    def textual_path(self) -> Path:
        return self.textual_dir / "model.onnx"

    @property
    def visual_path(self) -> Path:
        return self.visual_dir / "model.onnx"

    @property
    def preprocess_cfg_path(self) -> Path:
        return self.visual_dir / "preprocess_cfg.json"

    @property
    def cached(self) -> bool:
        return self.textual_path.is_file() and self.visual_path.is_file()


class OpenCLIPEncoder(BaseCLIPEncoder):
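    """Encoder for OpenCLIP checkpoints exported to ONNX and downloaded from the immich-app Hugging Face namespace."""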
    def __init__(
        self,
        model_name: str,
        cache_dir: str | None = None,
        mode: Literal["text", "vision"] | None = None,
        **model_kwargs: Any,
    ) -> None:
        super().__init__(_clean_model_name(model_name), cache_dir, mode, **model_kwargs)

    def _download(self) -> None:
        snapshot_download(
            f"immich-app/{self.model_name}",
            cache_dir=self.cache_dir,
            local_dir=self.cache_dir,
            local_dir_use_symlinks=False,
        )

    def _load(self) -> None:
        super()._load()
        self.tokenizer = AutoTokenizer.from_pretrained(self.textual_dir)
        self.sequence_length = self.model_cfg["text_cfg"]["context_length"]
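        # preprocess_cfg["size"] may be a single int or a list; take the first element when it is a list.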
        self.size = (
            self.preprocess_cfg["size"][0]
            if isinstance(self.preprocess_cfg["size"], list)
            else self.preprocess_cfg["size"]
        )
        self.resampling = get_pil_resampling(self.preprocess_cfg["interpolation"])
        self.mean = np.array(self.preprocess_cfg["mean"], dtype=np.float32)
        self.std = np.array(self.preprocess_cfg["std"], dtype=np.float32)

    def tokenize(self, text: str) -> dict[str, ndarray_i32]:
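        # Pad/truncate to the model's context length and cast to int32 for the ONNX text input.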
        input_ids: ndarray_i64 = self.tokenizer(
            text,
            max_length=self.sequence_length,
            return_tensors="np",
            return_attention_mask=False,
            padding="max_length",
            truncation=True,
        ).input_ids
        return {"text": input_ids.astype(np.int32)}

    def transform(self, image: Image.Image) -> dict[str, ndarray_f32]:
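        # Resize, crop, and normalize, then reorder HWC -> CHW and add a batch dimension.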
        image = resize(image, self.size)
        image = crop(image, self.size)
        image_np = to_numpy(image)
        image_np = normalize(image_np, self.mean, self.std)
        return {"image": np.expand_dims(image_np.transpose(2, 0, 1), 0)}

    @cached_property
    def model_cfg(self) -> dict[str, Any]:
        return json.load(self.model_cfg_path.open())

    @cached_property
    def preprocess_cfg(self) -> dict[str, Any]:
        return json.load(self.preprocess_cfg_path.open())


class MCLIPEncoder(OpenCLIPEncoder):
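    """Multilingual CLIP encoder; only tokenization differs from OpenCLIPEncoder."""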
    def tokenize(self, text: str) -> dict[str, ndarray_i32]:
        tokens: dict[str, ndarray_i64] = self.tokenizer(text, return_tensors="np")
        return {k: v.astype(np.int32) for k, v in tokens.items()}
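

# Recognized checkpoint names, in the cleaned form produced by _clean_model_name.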
_OPENCLIP_MODELS = {
    "RN50__openai",
    "RN50__yfcc15m",
    "RN50__cc12m",
    "RN101__openai",
    "RN101__yfcc15m",
    "RN50x4__openai",
    "RN50x16__openai",
    "RN50x64__openai",
    "ViT-B-32__openai",
    "ViT-B-32__laion2b_e16",
    "ViT-B-32__laion400m_e31",
    "ViT-B-32__laion400m_e32",
    "ViT-B-32__laion2b-s34b-b79k",
    "ViT-B-16__openai",
    "ViT-B-16__laion400m_e31",
    "ViT-B-16__laion400m_e32",
    "ViT-B-16-plus-240__laion400m_e31",
    "ViT-B-16-plus-240__laion400m_e32",
    "ViT-L-14__openai",
    "ViT-L-14__laion400m_e31",
    "ViT-L-14__laion400m_e32",
    "ViT-L-14__laion2b-s32b-b82k",
    "ViT-L-14-336__openai",
    "ViT-H-14__laion2b-s32b-b79k",
    "ViT-g-14__laion2b-s12b-b42k",
}

_MCLIP_MODELS = {
    "LABSE-Vit-L-14",
    "XLM-Roberta-Large-Vit-B-32",
    "XLM-Roberta-Large-Vit-B-16Plus",
    "XLM-Roberta-Large-Vit-L-14",
}


def _clean_model_name(model_name: str) -> str:
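    # Strip any namespace prefix and normalize "::" to the "__" separator used above.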
    return model_name.split("/")[-1].replace("::", "__")


def is_openclip(model_name: str) -> bool:
    return _clean_model_name(model_name) in _OPENCLIP_MODELS


def is_mclip(model_name: str) -> bool:
    return _clean_model_name(model_name) in _MCLIP_MODELS