clip.py

import json
from abc import abstractmethod
from functools import cached_property
from io import BytesIO
from pathlib import Path
from typing import Any, Literal

import numpy as np
import onnxruntime as ort
from PIL import Image
from transformers import AutoTokenizer

from app.config import clean_name, log
from app.models.transforms import crop, get_pil_resampling, normalize, resize, to_numpy
from app.schemas import ModelType, ndarray_f32, ndarray_i32, ndarray_i64

from .base import InferenceModel


class BaseCLIPEncoder(InferenceModel):
    _model_type = ModelType.CLIP

    def __init__(
        self,
        model_name: str,
        cache_dir: str | None = None,
        mode: Literal["text", "vision"] | None = None,
        **model_kwargs: Any,
    ) -> None:
        self.mode = mode
        super().__init__(model_name, cache_dir, **model_kwargs)

    def _load(self) -> None:
        # Load the textual and/or visual ONNX sessions depending on the requested mode;
        # a mode of None loads both so the encoder can handle images and text.
        if self.mode == "text" or self.mode is None:
            log.debug(f"Loading clip text model '{self.model_name}'")
            self.text_model = ort.InferenceSession(
                self.textual_path.as_posix(),
                sess_options=self.sess_options,
                providers=self.providers,
                provider_options=self.provider_options,
            )

        if self.mode == "vision" or self.mode is None:
            log.debug(f"Loading clip vision model '{self.model_name}'")
            self.vision_model = ort.InferenceSession(
                self.visual_path.as_posix(),
                sess_options=self.sess_options,
                providers=self.providers,
                provider_options=self.provider_options,
            )

    def _predict(self, image_or_text: Image.Image | str | bytes) -> ndarray_f32:
        # Raw bytes are decoded into a PIL image before dispatching on the input type.
        if isinstance(image_or_text, bytes):
            image_or_text = Image.open(BytesIO(image_or_text))

        match image_or_text:
            case Image.Image():
                if self.mode == "text":
                    raise TypeError("Cannot encode image as text-only model")
                outputs: ndarray_f32 = self.vision_model.run(None, self.transform(image_or_text))[0][0]
            case str():
                if self.mode == "vision":
                    raise TypeError("Cannot encode text as vision-only model")
                outputs = self.text_model.run(None, self.tokenize(image_or_text))[0][0]
            case _:
                raise TypeError(f"Expected Image or str, but got: {type(image_or_text)}")

        return outputs

    @abstractmethod
    def tokenize(self, text: str) -> dict[str, ndarray_i32]:
        pass

    @abstractmethod
    def transform(self, image: Image.Image) -> dict[str, ndarray_f32]:
        pass

    @property
    def textual_dir(self) -> Path:
        return self.cache_dir / "textual"

    @property
    def visual_dir(self) -> Path:
        return self.cache_dir / "visual"

    @property
    def model_cfg_path(self) -> Path:
        return self.cache_dir / "config.json"

    @property
    def textual_path(self) -> Path:
        return self.textual_dir / "model.onnx"

    @property
    def visual_path(self) -> Path:
        return self.visual_dir / "model.onnx"

    @property
    def preprocess_cfg_path(self) -> Path:
        return self.visual_dir / "preprocess_cfg.json"

    @property
    def cached(self) -> bool:
        return self.textual_path.is_file() and self.visual_path.is_file()
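

# Expected cache layout, as implied by the path properties above (an illustrative
# sketch only; these files are produced by the model download step elsewhere, not
# by this module):
#
#   <cache_dir>/
#       config.json                # OpenCLIP model config ("text_cfg", ...)
#       textual/
#           model.onnx             # textual encoder session
#           (tokenizer files)      # read by AutoTokenizer.from_pretrained
#       visual/
#           model.onnx             # visual encoder session
#           preprocess_cfg.json    # "size", "interpolation", "mean", "std"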


class OpenCLIPEncoder(BaseCLIPEncoder):
    def __init__(
        self,
        model_name: str,
        cache_dir: str | None = None,
        mode: Literal["text", "vision"] | None = None,
        **model_kwargs: Any,
    ) -> None:
        super().__init__(clean_name(model_name), cache_dir, mode, **model_kwargs)

    def _load(self) -> None:
        super()._load()
        self.tokenizer = AutoTokenizer.from_pretrained(self.textual_dir)
        self.sequence_length = self.model_cfg["text_cfg"]["context_length"]

        # The preprocessing size may be given as an int or as a [height, width] list.
        size = self.preprocess_cfg["size"]
        self.size = size[0] if isinstance(size, list) else size

        self.resampling = get_pil_resampling(self.preprocess_cfg["interpolation"])
        self.mean = np.array(self.preprocess_cfg["mean"], dtype=np.float32)
        self.std = np.array(self.preprocess_cfg["std"], dtype=np.float32)

    def tokenize(self, text: str) -> dict[str, ndarray_i32]:
        # Pad/truncate to the model's context length and cast to int32 for the ONNX input.
        input_ids: ndarray_i64 = self.tokenizer(
            text,
            max_length=self.sequence_length,
            return_tensors="np",
            return_attention_mask=False,
            padding="max_length",
            truncation=True,
        ).input_ids
        return {"text": input_ids.astype(np.int32)}

    def transform(self, image: Image.Image) -> dict[str, ndarray_f32]:
        # Resize and crop to the model input size, normalize, then convert
        # HWC -> CHW and add a batch dimension.
        image = resize(image, self.size)
        image = crop(image, self.size)
        image_np = to_numpy(image)
        image_np = normalize(image_np, self.mean, self.std)
        return {"image": np.expand_dims(image_np.transpose(2, 0, 1), 0)}

    @cached_property
    def model_cfg(self) -> dict[str, Any]:
        model_cfg: dict[str, Any] = json.load(self.model_cfg_path.open())
        return model_cfg

    @cached_property
    def preprocess_cfg(self) -> dict[str, Any]:
        preprocess_cfg: dict[str, Any] = json.load(self.preprocess_cfg_path.open())
        return preprocess_cfg


class MCLIPEncoder(OpenCLIPEncoder):
    def tokenize(self, text: str) -> dict[str, ndarray_i32]:
        # Multilingual CLIP text models take the tokenizer's full output
        # (e.g. input_ids and attention_mask) rather than a single fixed-length
        # "text" tensor.
        tokens: dict[str, ndarray_i64] = self.tokenizer(text, return_tensors="np")
        return {k: v.astype(np.int32) for k, v in tokens.items()}
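

# Usage sketch (illustrative, not part of the original module). It assumes the
# InferenceModel base class exposes load/predict wrappers around _load/_predict,
# and that an OpenCLIP model has already been downloaded into the cache
# directory; the model name and paths below are placeholder assumptions.
#
#     encoder = OpenCLIPEncoder("ViT-B-32__openai", cache_dir="/cache/clip")
#     encoder.load()
#     image_embedding = encoder.predict(Image.open("photo.jpg"))   # ndarray_f32 embedding
#     text_embedding = encoder.predict("a photo of a dog")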