clip.py

import os
import zipfile
from io import BytesIO
from typing import Any, Literal

import onnxruntime as ort
import torch
from clip_server.model.clip import BICUBIC, _convert_image_to_rgb
from clip_server.model.clip_onnx import _MODELS, _S3_BUCKET_V2, CLIPOnnxModel, download_model
from clip_server.model.pretrained_models import _VISUAL_MODEL_IMAGE_SIZE
from clip_server.model.tokenization import Tokenizer
from PIL import Image
from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor

from ..config import log
from ..schemas import ModelType
from .base import InferenceModel

# Map Sentence-Transformers CLIP model names to their closest Jina clip-server equivalents.
_ST_TO_JINA_MODEL_NAME = {
    "clip-ViT-B-16": "ViT-B-16::openai",
    "clip-ViT-B-32": "ViT-B-32::openai",
    "clip-ViT-B-32-multilingual-v1": "M-CLIP/XLM-Roberta-Large-Vit-B-32",
    "clip-ViT-L-14": "ViT-L-14::openai",
}


class CLIPEncoder(InferenceModel):
    _model_type = ModelType.CLIP

    def __init__(
        self,
        model_name: str,
        cache_dir: str | None = None,
        mode: Literal["text", "vision"] | None = None,
        **model_kwargs: Any,
    ) -> None:
        if mode is not None and mode not in ("text", "vision"):
            raise ValueError(f"Mode must be 'text', 'vision', or omitted; got '{mode}'")
        if "vit-b" not in model_name.lower():
            raise ValueError(f"Only ViT-B models are currently supported; got '{model_name}'")
        self.mode = mode
        jina_model_name = self._get_jina_model_name(model_name)
        super().__init__(jina_model_name, cache_dir, **model_kwargs)

    def _download(self, **model_kwargs: Any) -> None:
        # Each entry in clip-server's _MODELS holds a (textual, visual) pair of (file name, md5).
        models: tuple[tuple[str, str], tuple[str, str]] = _MODELS[self.model_name]
        text_onnx_path = self.cache_dir / "textual.onnx"
        vision_onnx_path = self.cache_dir / "visual.onnx"

        if not text_onnx_path.is_file():
            self._download_model(*models[0])

        if not vision_onnx_path.is_file():
            self._download_model(*models[1])

    def _load(self, **model_kwargs: Any) -> None:
        # Only create the ONNX Runtime sessions required by the configured mode.
        if self.mode == "text" or self.mode is None:
            self.text_model = ort.InferenceSession(
                (self.cache_dir / "textual.onnx").as_posix(),
                sess_options=self.sess_options,
                providers=self.providers,
                provider_options=self.provider_options,
            )
            self.text_outputs = [output.name for output in self.text_model.get_outputs()]
            self.tokenizer = Tokenizer(self.model_name)

        if self.mode == "vision" or self.mode is None:
            self.vision_model = ort.InferenceSession(
                (self.cache_dir / "visual.onnx").as_posix(),
                sess_options=self.sess_options,
                providers=self.providers,
                provider_options=self.provider_options,
            )
            self.vision_outputs = [output.name for output in self.vision_model.get_outputs()]

            image_size = _VISUAL_MODEL_IMAGE_SIZE[CLIPOnnxModel.get_model_name(self.model_name)]
            self.transform = _transform_pil_image(image_size)

    def _predict(self, image_or_text: Image.Image | bytes | str) -> list[float]:
        # Raw bytes are decoded into a PIL image before dispatching on type.
        if isinstance(image_or_text, bytes):
            image_or_text = Image.open(BytesIO(image_or_text))

        match image_or_text:
            case Image.Image():
                if self.mode == "text":
                    raise TypeError("Cannot encode image as text-only model")
                pixel_values = self.transform(image_or_text)
                assert isinstance(pixel_values, torch.Tensor)
                pixel_values = torch.unsqueeze(pixel_values, 0).numpy()
                outputs = self.vision_model.run(self.vision_outputs, {"pixel_values": pixel_values})
            case str():
                if self.mode == "vision":
                    raise TypeError("Cannot encode text as vision-only model")
                text_inputs: dict[str, torch.Tensor] = self.tokenizer(image_or_text)
                inputs = {
                    "input_ids": text_inputs["input_ids"].int().numpy(),
                    "attention_mask": text_inputs["attention_mask"].int().numpy(),
                }
                outputs = self.text_model.run(self.text_outputs, inputs)
            case _:
                raise TypeError(f"Expected Image, bytes or str, but got: {type(image_or_text)}")

        return outputs[0][0].tolist()

    def _get_jina_model_name(self, model_name: str) -> str:
        if model_name in _MODELS:
            return model_name
        elif model_name in _ST_TO_JINA_MODEL_NAME:
            log.warning(
                f"Sentence-Transformer models like '{model_name}' are not supported. "
                f"Using '{_ST_TO_JINA_MODEL_NAME[model_name]}' instead as it is the best match for '{model_name}'."
            )
            return _ST_TO_JINA_MODEL_NAME[model_name]
        else:
            raise ValueError(f"Unknown model name {model_name}.")

    def _download_model(self, model_name: str, model_md5: str) -> bool:
        # downloading logic is adapted from clip-server's CLIPOnnxModel class
        download_model(
            url=_S3_BUCKET_V2 + model_name,
            target_folder=self.cache_dir.as_posix(),
            md5sum=model_md5,
            with_resume=True,
        )
        file = self.cache_dir / model_name.split("/")[1]
        if file.suffix == ".zip":
            with zipfile.ZipFile(file, "r") as zip_ref:
                zip_ref.extractall(self.cache_dir)
            os.remove(file)
        return True


# same as `_transform_blob` without `_blob2image`
def _transform_pil_image(n_px: int) -> Compose:
    return Compose(
        [
            Resize(n_px, interpolation=BICUBIC),
            CenterCrop(n_px),
            _convert_image_to_rgb,
            ToTensor(),
            Normalize(
                (0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711),
            ),
        ]
    )
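

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how this encoder might be driven, assuming the
# surrounding application package is importable and that `InferenceModel`
# exposes public `load()`/`predict()` wrappers around `_load()`/`_predict()`
# (those method names and the cache path are assumptions, not confirmed here):
#
#     encoder = CLIPEncoder("ViT-B-32::openai", cache_dir="/cache/clip")
#     encoder.load()                                   # fetches textual.onnx / visual.onnx on first use
#     text_embedding = encoder.predict("a photo of a dog")
#     with open("dog.jpg", "rb") as f:
#         image_embedding = encoder.predict(f.read())  # bytes are decoded to a PIL image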