import os
import zipfile
from io import BytesIO
from typing import Any, Literal

import onnxruntime as ort
import torch
from clip_server.model.clip import BICUBIC, _convert_image_to_rgb
from clip_server.model.clip_onnx import _MODELS, _S3_BUCKET_V2, CLIPOnnxModel, download_model
from clip_server.model.pretrained_models import _VISUAL_MODEL_IMAGE_SIZE
from clip_server.model.tokenization import Tokenizer
from PIL import Image
from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor

from ..schemas import ModelType
from .base import InferenceModel
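
# Maps legacy Sentence-Transformers CLIP model names to their Jina clip-server equivalents
# (used as a fallback in _get_jina_model_name below).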
_ST_TO_JINA_MODEL_NAME = {
    "clip-ViT-B-16": "ViT-B-16::openai",
    "clip-ViT-B-32": "ViT-B-32::openai",
    "clip-ViT-B-32-multilingual-v1": "M-CLIP/XLM-Roberta-Large-Vit-B-32",
    "clip-ViT-L-14": "ViT-L-14::openai",
}
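

# CLIP encoder backed by Jina clip-server ONNX artifacts. The text and vision towers are
# separate onnxruntime sessions; `mode` selects which of them to load and serve.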
class CLIPEncoder(InferenceModel):
    _model_type = ModelType.CLIP

    def __init__(
        self,
        model_name: str,
        cache_dir: str | None = None,
        mode: Literal["text", "vision"] | None = None,
        **model_kwargs: Any,
    ) -> None:
        if mode is not None and mode not in ("text", "vision"):
            raise ValueError(f"Mode must be 'text', 'vision', or omitted; got '{mode}'")
        if "vit-b" not in model_name.lower():
            raise ValueError(f"Only ViT-B models are currently supported; got '{model_name}'")
        self.mode = mode
        jina_model_name = self._get_jina_model_name(model_name)
        super().__init__(jina_model_name, cache_dir, **model_kwargs)
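
    # Fetch the textual and visual ONNX files for this model if they are not already cached.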
    def _download(self, **model_kwargs: Any) -> None:
        models: tuple[tuple[str, str], tuple[str, str]] = _MODELS[self.model_name]
        text_onnx_path = self.cache_dir / "textual.onnx"
        vision_onnx_path = self.cache_dir / "visual.onnx"

        if not text_onnx_path.is_file():
            self._download_model(*models[0])

        if not vision_onnx_path.is_file():
            self._download_model(*models[1])
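
    # Create onnxruntime sessions for the requested mode; when mode is None, both towers are loaded.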
    def _load(self, **model_kwargs: Any) -> None:
        if self.mode == "text" or self.mode is None:
            self.text_model = ort.InferenceSession(
                self.cache_dir / "textual.onnx",
                sess_options=self.sess_options,
                providers=self.providers,
                provider_options=self.provider_options,
            )
            self.text_outputs = [output.name for output in self.text_model.get_outputs()]
            self.tokenizer = Tokenizer(self.model_name)

        if self.mode == "vision" or self.mode is None:
            self.vision_model = ort.InferenceSession(
                self.cache_dir / "visual.onnx",
                sess_options=self.sess_options,
                providers=self.providers,
                provider_options=self.provider_options,
            )
            self.vision_outputs = [output.name for output in self.vision_model.get_outputs()]

            image_size = _VISUAL_MODEL_IMAGE_SIZE[CLIPOnnxModel.get_model_name(self.model_name)]
            self.transform = _transform_pil_image(image_size)
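
    # Encode a single input: PIL images (or raw image bytes) go through the vision tower,
    # strings through the text tower.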
    def _predict(self, image_or_text: Image.Image | str | bytes) -> list[float]:
        if isinstance(image_or_text, bytes):
            image_or_text = Image.open(BytesIO(image_or_text))

        match image_or_text:
            case Image.Image():
                if self.mode == "text":
                    raise TypeError("Cannot encode image as text-only model")
                pixel_values = self.transform(image_or_text)
                assert isinstance(pixel_values, torch.Tensor)
                pixel_values = torch.unsqueeze(pixel_values, 0).numpy()
                outputs = self.vision_model.run(self.vision_outputs, {"pixel_values": pixel_values})
            case str():
                if self.mode == "vision":
                    raise TypeError("Cannot encode text as vision-only model")
                text_inputs: dict[str, torch.Tensor] = self.tokenizer(image_or_text)
                inputs = {
                    "input_ids": text_inputs["input_ids"].int().numpy(),
                    "attention_mask": text_inputs["attention_mask"].int().numpy(),
                }
                outputs = self.text_model.run(self.text_outputs, inputs)
            case _:
                raise TypeError(f"Expected Image or str, but got: {type(image_or_text)}")

        return outputs[0][0].tolist()
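
    # Resolve a user-supplied model name to a clip-server model name, falling back to the
    # legacy Sentence-Transformers mapping where possible.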
    def _get_jina_model_name(self, model_name: str) -> str:
        if model_name in _MODELS:
            return model_name
        elif model_name in _ST_TO_JINA_MODEL_NAME:
            print(
                f"Warning: Sentence-Transformer model names such as '{model_name}' are no longer supported. "
                f"Using '{_ST_TO_JINA_MODEL_NAME[model_name]}' instead as it is the best match for '{model_name}'."
            )
            return _ST_TO_JINA_MODEL_NAME[model_name]
        else:
            raise ValueError(f"Unknown model name {model_name}.")
    def _download_model(self, model_name: str, model_md5: str) -> bool:
        # downloading logic is adapted from clip-server's CLIPOnnxModel class
        download_model(
            url=_S3_BUCKET_V2 + model_name,
            target_folder=self.cache_dir.as_posix(),
            md5sum=model_md5,
            with_resume=True,
        )

        file = self.cache_dir / model_name.split("/")[1]
        if file.suffix == ".zip":
            with zipfile.ZipFile(file, "r") as zip_ref:
                zip_ref.extractall(self.cache_dir)
            os.remove(file)

        return True


# same as `_transform_blob` without `_blob2image`
def _transform_pil_image(n_px: int) -> Compose:
    return Compose(
        [
            Resize(n_px, interpolation=BICUBIC),
            CenterCrop(n_px),
            _convert_image_to_rgb,
            ToTensor(),
            Normalize(
                (0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711),
            ),
        ]
    )
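

# Minimal usage sketch. In normal operation these hooks are assumed to be driven by the
# InferenceModel base class (which also resolves cache_dir, session options, and providers);
# the direct _download/_load/_predict calls and the model name below are illustrative only
# and require clip_server, onnxruntime, and network access for the first download.
if __name__ == "__main__":
    encoder = CLIPEncoder("ViT-B-32::openai", cache_dir=None, mode="text")
    encoder._download()  # fetch textual.onnx / visual.onnx into the cache directory
    encoder._load()  # build the onnxruntime session(s) for the chosen mode
    embedding = encoder._predict("a photo of a cat")
    print(len(embedding))  # embedding dimensionality of the text tower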