# tfclip.py
import tempfile
from pathlib import Path

import tensorflow as tf
from transformers import TFCLIPModel

from .util import ModelType, get_model_path
  6. class _CLIPWrapper(tf.Module):
  7. def __init__(self, model_name: str):
  8. super(_CLIPWrapper)
  9. self.model = TFCLIPModel.from_pretrained(model_name)
  10. @tf.function()
  11. def encode_image(self, input):
  12. return self.model.get_image_features(input)
  13. @tf.function()
  14. def encode_text(self, input):
  15. return self.model.get_text_features(input)
  16. # exported model signatures use batch size 2 because of the following reasons:
  17. # 1. ARM-NN cannot use dynamic batch sizes
  18. # 2. batch size 1 creates a larger TF-Lite model that uses a lot (50%) more RAM
  19. # 3. batch size 2 is ~50% faster on GPU than 1 while 4 (or larger) are not faster
  20. # 4. batch size >2 wastes more computation if only a single image is processed
  21. BATCH_SIZE = 2
  22. SIGNATURE_TEXT = "encode_text"
  23. SIGNATURE_IMAGE = "encode_image"
  24. def to_tflite(
  25. model_name,
  26. output_path_image: Path | str | None,
  27. output_path_text: Path | str | None,
  28. context_length: int = 77,
  29. ):
  30. with tempfile.TemporaryDirectory() as tmpdir:
  31. _export_temporary_tf_model(model_name, tmpdir, context_length)
  32. if output_path_image is not None:
  33. image_path = get_model_path(output_path_image, ModelType.TFLITE)
  34. _export_tflite_model(tmpdir, SIGNATURE_IMAGE, image_path.as_posix())
  35. if output_path_text is not None:
  36. text_path = get_model_path(output_path_text, ModelType.TFLITE)
  37. _export_tflite_model(tmpdir, SIGNATURE_TEXT, text_path.as_posix())
  38. def _export_temporary_tf_model(model_name, tmp_path: str, context_length: int):
  39. wrapper = _CLIPWrapper(model_name)
  40. conf = wrapper.model.config.vision_config
  41. spec_visual = tf.TensorSpec(
  42. shape=(BATCH_SIZE, conf.num_channels, conf.image_size, conf.image_size), dtype=tf.float32
  43. )
  44. encode_image = wrapper.encode_image.get_concrete_function(spec_visual)
  45. spec_text = tf.TensorSpec(shape=(BATCH_SIZE, context_length), dtype=tf.int32)
  46. encode_text = wrapper.encode_text.get_concrete_function(spec_text)
  47. signatures = {"encode_text": encode_text, "encode_image": encode_image}
  48. tf.saved_model.save(wrapper, tmp_path, signatures)
  49. def _export_tflite_model(tmp_path: str, signature: str, output_path: str):
  50. converter = tf.lite.TFLiteConverter.from_saved_model(tmp_path, signature_keys=[signature])
  51. converter.optimizations = [tf.lite.Optimize.DEFAULT]
  52. converter.target_spec.supported_types = [tf.float32]
  53. tflite_model = converter.convert()
  54. with open(output_path, "wb") as f:
  55. f.write(tflite_model)