mclip.py 2.5 KB

import tempfile
import warnings
from pathlib import Path

import torch
from multilingual_clip.pt_multilingual_clip import MultilingualCLIP
from transformers import AutoTokenizer

from .openclip import OpenCLIPModelConfig
from .openclip import to_onnx as openclip_to_onnx
from .optimize import optimize
from .util import get_model_path
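
# Maps each supported M-CLIP text model to the OpenCLIP visual model that is exported
# alongside it (see the openclip_to_onnx call in to_onnx below).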
_MCLIP_TO_OPENCLIP = {
    "M-CLIP/XLM-Roberta-Large-Vit-B-32": OpenCLIPModelConfig("ViT-B-32", "openai"),
    "M-CLIP/XLM-Roberta-Large-Vit-B-16Plus": OpenCLIPModelConfig("ViT-B-16-plus-240", "laion400m_e32"),
    "M-CLIP/LABSE-Vit-L-14": OpenCLIPModelConfig("ViT-L-14", "openai"),
    "M-CLIP/XLM-Roberta-Large-Vit-L-14": OpenCLIPModelConfig("ViT-L-14", "openai"),
}


def to_onnx(
    model_name: str,
    output_dir_visual: Path | str,
    output_dir_textual: Path | str,
) -> None:
    textual_path = get_model_path(output_dir_textual)
    with tempfile.TemporaryDirectory() as tmpdir:
        model = MultilingualCLIP.from_pretrained(model_name, cache_dir=tmpdir)
        AutoTokenizer.from_pretrained(model_name).save_pretrained(output_dir_textual)

        for param in model.parameters():
            param.requires_grad_(False)

        export_text_encoder(model, textual_path)
        openclip_to_onnx(_MCLIP_TO_OPENCLIP[model_name], output_dir_visual)
        optimize(textual_path)


def export_text_encoder(model: MultilingualCLIP, output_path: Path | str) -> None:
    output_path = Path(output_path)
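
    # Replacement forward pass for export: mean-pool the transformer's token embeddings
    # with the attention mask, project through the model's LinearTransformation head, and
    # L2-normalize so the exported graph emits ready-to-use text embeddings.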
    def forward(self: MultilingualCLIP, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        embs = self.transformer(input_ids, attention_mask)[0]
        embs = (embs * attention_mask.unsqueeze(2)).sum(dim=1) / attention_mask.sum(dim=1)[:, None]
        embs = self.LinearTransformation(embs)
        return torch.nn.functional.normalize(embs, dim=-1)

    # unfortunately need to monkeypatch for tracing to work here
    # otherwise it hits the 2GiB protobuf serialization limit
    MultilingualCLIP.forward = forward
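
    # Dummy int32 inputs (batch of 1, sequence length 77) used only to trace the export;
    # the dynamic axes below let the exported model accept other batch sizes and lengths.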
    args = (torch.ones(1, 77, dtype=torch.int32), torch.ones(1, 77, dtype=torch.int32))
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", UserWarning)
        torch.onnx.export(
            model,
            args,
            output_path.as_posix(),
            input_names=["input_ids", "attention_mask"],
            output_names=["text_embedding"],
            opset_version=17,
            dynamic_axes={
                "input_ids": {0: "batch_size", 1: "sequence_length"},
                "attention_mask": {0: "batch_size", 1: "sequence_length"},
            },
        )
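

# Illustrative usage (a sketch, not part of this module; the import path and the output
# directories below are assumptions, not values defined here):
#
#     from models.mclip import to_onnx  # assuming this file lives in a `models` package
#
#     to_onnx(
#         "M-CLIP/XLM-Roberta-Large-Vit-B-32",
#         output_dir_visual="exported/visual",
#         output_dir_textual="exported/textual",
#     )
#
# This writes the ONNX visual encoder (via openclip_to_onnx) to the first directory, and
# the tokenizer files plus the optimized ONNX text encoder to the second.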