openclip.py

import tempfile
import warnings
from dataclasses import dataclass, field
from pathlib import Path

import open_clip
import torch
from transformers import AutoTokenizer

from .optimize import optimize
from .util import get_model_path, save_config


@dataclass
class OpenCLIPModelConfig:
    """Resolves image size and context length for a named open_clip model."""

    name: str
    pretrained: str
    image_size: int = field(init=False)
    sequence_length: int = field(init=False)

    def __post_init__(self) -> None:
        open_clip_cfg = open_clip.get_model_config(self.name)
        if open_clip_cfg is None:
            raise ValueError(f"Unknown model {self.name}")
        self.image_size = open_clip_cfg["vision_cfg"]["image_size"]
        self.sequence_length = open_clip_cfg["text_cfg"]["context_length"]
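
# For illustration (not part of the original module): with a known open_clip
# model such as "ViT-B-32", the derived fields resolve from the model config,
# e.g. OpenCLIPModelConfig(name="ViT-B-32", pretrained="openai") yields
# image_size == 224 and sequence_length == 77.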


def to_onnx(
    model_cfg: OpenCLIPModelConfig,
    output_dir_visual: Path | str | None = None,
    output_dir_textual: Path | str | None = None,
) -> None:
    with tempfile.TemporaryDirectory() as tmpdir:
        # Download the pretrained weights into a temporary cache and load the model.
        model = open_clip.create_model(
            model_cfg.name,
            pretrained=model_cfg.pretrained,
            jit=False,
            cache_dir=tmpdir,
            require_pretrained=True,
        )
        text_vision_cfg = open_clip.get_model_config(model_cfg.name)

        # Freeze all parameters; the model is only exported, never trained.
        for param in model.parameters():
            param.requires_grad_(False)

        if output_dir_visual is not None:
            output_dir_visual = Path(output_dir_visual)
            visual_path = get_model_path(output_dir_visual)

            save_config(open_clip.get_model_preprocess_cfg(model), output_dir_visual / "preprocess_cfg.json")
            save_config(text_vision_cfg, output_dir_visual.parent / "config.json")
            export_image_encoder(model, model_cfg, visual_path)
            optimize(visual_path)

        if output_dir_textual is not None:
            output_dir_textual = Path(output_dir_textual)
            textual_path = get_model_path(output_dir_textual)

            # Save the matching tokenizer alongside the textual model.
            tokenizer_name = text_vision_cfg["text_cfg"].get("hf_tokenizer_name", "openai/clip-vit-base-patch32")
            AutoTokenizer.from_pretrained(tokenizer_name).save_pretrained(output_dir_textual)
            export_text_encoder(model, model_cfg, textual_path)
            optimize(textual_path)


def export_image_encoder(model: open_clip.CLIP, model_cfg: OpenCLIPModelConfig, output_path: Path | str) -> None:
    output_path = Path(output_path)

    def encode_image(image: torch.Tensor) -> torch.Tensor:
        output = model.encode_image(image, normalize=True)
        assert isinstance(output, torch.Tensor)
        return output

    # Trace the encoder with a dummy image, then export to ONNX with a dynamic batch axis.
    args = (torch.randn(1, 3, model_cfg.image_size, model_cfg.image_size),)
    traced = torch.jit.trace(encode_image, args)  # type: ignore[no-untyped-call]

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", UserWarning)
        torch.onnx.export(
            traced,
            args,
            output_path.as_posix(),
            input_names=["image"],
            output_names=["image_embedding"],
            opset_version=17,
            dynamic_axes={"image": {0: "batch_size"}},
        )


def export_text_encoder(model: open_clip.CLIP, model_cfg: OpenCLIPModelConfig, output_path: Path | str) -> None:
    output_path = Path(output_path)

    def encode_text(text: torch.Tensor) -> torch.Tensor:
        output = model.encode_text(text, normalize=True)
        assert isinstance(output, torch.Tensor)
        return output

    # Trace the encoder with dummy token IDs, then export to ONNX with a dynamic batch axis.
    args = (torch.ones(1, model_cfg.sequence_length, dtype=torch.int32),)
    traced = torch.jit.trace(encode_text, args)  # type: ignore[no-untyped-call]

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", UserWarning)
        torch.onnx.export(
            traced,
            args,
            output_path.as_posix(),
            input_names=["text"],
            output_names=["text_embedding"],
            opset_version=17,
            dynamic_axes={"text": {0: "batch_size"}},
        )
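
# --- Usage sketch (illustrative, not part of the original module) ---
# Assuming this file lives in a package where the relative imports
# (.optimize, .util) resolve, an export could be driven roughly like this;
# the output directories are hypothetical:
#
#   from pathlib import Path
#
#   cfg = OpenCLIPModelConfig(name="ViT-B-32", pretrained="openai")
#   to_onnx(
#       cfg,
#       output_dir_visual=Path("output/visual"),
#       output_dir_textual=Path("output/textual"),
#   )
#
# This writes the visual and textual ONNX graphs (plus the preprocessing
# config and tokenizer files) under the given directories and runs
# optimize() on each exported model.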