clip.py

import os
import zipfile
from io import BytesIO
from typing import Any, Literal

import onnxruntime as ort
import torch
from clip_server.model.clip import BICUBIC, _convert_image_to_rgb
from clip_server.model.clip_onnx import _MODELS, _S3_BUCKET_V2, CLIPOnnxModel, download_model
from clip_server.model.pretrained_models import _VISUAL_MODEL_IMAGE_SIZE
from clip_server.model.tokenization import Tokenizer
from PIL import Image
from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor

from ..config import log
from ..schemas import ModelType
from .base import InferenceModel


class CLIPEncoder(InferenceModel):
    _model_type = ModelType.CLIP

    def __init__(
        self,
        model_name: str,
        cache_dir: str | None = None,
        mode: Literal["text", "vision"] | None = None,
        **model_kwargs: Any,
    ) -> None:
        if mode is not None and mode not in ("text", "vision"):
            raise ValueError(f"Mode must be 'text', 'vision', or omitted; got '{mode}'")
        if model_name not in _MODELS:
            raise ValueError(f"Unknown model name {model_name}.")
        self.mode = mode
        super().__init__(model_name, cache_dir, **model_kwargs)

    def _download(self) -> None:
        models: tuple[tuple[str, str], tuple[str, str]] = _MODELS[self.model_name]
        text_onnx_path = self.cache_dir / "textual.onnx"
        vision_onnx_path = self.cache_dir / "visual.onnx"

        if not text_onnx_path.is_file():
            self._download_model(*models[0])

        if not vision_onnx_path.is_file():
            self._download_model(*models[1])

    def _load(self) -> None:
        if self.mode == "text" or self.mode is None:
            log.debug(f"Loading clip text model '{self.model_name}'")

            self.text_model = ort.InferenceSession(
                self.cache_dir / "textual.onnx",
                sess_options=self.sess_options,
                providers=self.providers,
                provider_options=self.provider_options,
            )
            self.text_outputs = [output.name for output in self.text_model.get_outputs()]
            self.tokenizer = Tokenizer(self.model_name)

        if self.mode == "vision" or self.mode is None:
            log.debug(f"Loading clip vision model '{self.model_name}'")

            self.vision_model = ort.InferenceSession(
                self.cache_dir / "visual.onnx",
                sess_options=self.sess_options,
                providers=self.providers,
                provider_options=self.provider_options,
            )
            self.vision_outputs = [output.name for output in self.vision_model.get_outputs()]

            image_size = _VISUAL_MODEL_IMAGE_SIZE[CLIPOnnxModel.get_model_name(self.model_name)]
            self.transform = _transform_pil_image(image_size)

    def _predict(self, image_or_text: bytes | Image.Image | str) -> list[float]:
        if isinstance(image_or_text, bytes):
            image_or_text = Image.open(BytesIO(image_or_text))

        match image_or_text:
            case Image.Image():
                if self.mode == "text":
                    raise TypeError("Cannot encode image as text-only model")

                pixel_values = self.transform(image_or_text)
                assert isinstance(pixel_values, torch.Tensor)
                pixel_values = torch.unsqueeze(pixel_values, 0).numpy()
                outputs = self.vision_model.run(self.vision_outputs, {"pixel_values": pixel_values})
            case str():
                if self.mode == "vision":
                    raise TypeError("Cannot encode text as vision-only model")

                text_inputs: dict[str, torch.Tensor] = self.tokenizer(image_or_text)
                inputs = {
                    "input_ids": text_inputs["input_ids"].int().numpy(),
                    "attention_mask": text_inputs["attention_mask"].int().numpy(),
                }
                outputs = self.text_model.run(self.text_outputs, inputs)
            case _:
                raise TypeError(f"Expected Image or str, but got: {type(image_or_text)}")

        return outputs[0][0].tolist()

    def _download_model(self, model_name: str, model_md5: str) -> bool:
        # downloading logic is adapted from clip-server's CLIPOnnxModel class
        download_model(
            url=_S3_BUCKET_V2 + model_name,
            target_folder=self.cache_dir.as_posix(),
            md5sum=model_md5,
            with_resume=True,
        )

        file = self.cache_dir / model_name.split("/")[1]
        if file.suffix == ".zip":
            with zipfile.ZipFile(file, "r") as zip_ref:
                zip_ref.extractall(self.cache_dir)
            os.remove(file)

        return True

    @property
    def cached(self) -> bool:
        return (self.cache_dir / "textual.onnx").is_file() and (self.cache_dir / "visual.onnx").is_file()


# same as `_transform_blob` without `_blob2image`
def _transform_pil_image(n_px: int) -> Compose:
    return Compose(
        [
            Resize(n_px, interpolation=BICUBIC),
            CenterCrop(n_px),
            _convert_image_to_rgb,
            ToTensor(),
            Normalize(
                (0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711),
            ),
        ]
    )
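
# Minimal usage sketch (illustrative only, not part of the module's runtime code).
# Assumptions: the model name "ViT-B-32::openai" is a valid _MODELS key, the cache
# path is arbitrary, and the InferenceModel base class supplies `sess_options`,
# `providers`, and `provider_options` before `_load` runs; in the application these
# methods are normally driven through the base class rather than called directly.
#
#   encoder = CLIPEncoder("ViT-B-32::openai", cache_dir="/cache/clip", mode="vision")
#   encoder._download()                                   # fetch textual.onnx / visual.onnx if missing
#   encoder._load()                                       # build ONNX Runtime session(s) and the transform
#   embedding = encoder._predict(Image.open("photo.jpg"))  # list[float] image embedding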