
feat(ml): ARM NN acceleration

Fynn Petersen-Frey, 1 year ago
parent commit 5f6ad9e239

+ 3 - 2
machine-learning/export/env.yaml

@@ -20,6 +20,7 @@ dependencies:
  - torchvision
  - transformers==4.*
  - pip:
-    - multilingual-clip
-    - onnx-simplifier
+      - multilingual-clip
+      - onnx-simplifier
+      - tensorflow
category: main

+ 70 - 0
machine-learning/export/models/tfclip.py

@@ -0,0 +1,70 @@
+import tempfile
+from pathlib import Path
+
+import tensorflow as tf
+from transformers import TFCLIPModel
+
+from .util import ModelType, get_model_path
+
+
+class _CLIPWrapper(tf.Module):
+    def __init__(self, model_name: str):
+        super().__init__()
+        self.model = TFCLIPModel.from_pretrained(model_name)
+
+    @tf.function()
+    def encode_image(self, input):
+        return self.model.get_image_features(input)
+
+    @tf.function()
+    def encode_text(self, input):
+        return self.model.get_text_features(input)
+
+
+# The exported model signatures use a fixed batch size of 2 for these reasons:
+# 1. ARM-NN cannot handle dynamic batch sizes
+# 2. batch size 1 produces a larger TF-Lite model that uses ~50% more RAM
+# 3. batch size 2 is ~50% faster on GPU than batch size 1, while 4 or larger is no faster
+# 4. a batch size above 2 wastes computation when only a single image is processed
+BATCH_SIZE = 2
+
+SIGNATURE_TEXT = "encode_text"
+SIGNATURE_IMAGE = "encode_image"
+
+
+def to_tflite(
+    model_name: str,
+    output_path_image: Path | str | None,
+    output_path_text: Path | str | None,
+    context_length: int = 77,
+):
+    with tempfile.TemporaryDirectory() as tmpdir:
+        _export_temporary_tf_model(model_name, tmpdir, context_length)
+        if output_path_image is not None:
+            image_path = get_model_path(output_path_image, ModelType.TFLITE)
+            _export_tflite_model(tmpdir, SIGNATURE_IMAGE, image_path.as_posix())
+        if output_path_text is not None:
+            text_path = get_model_path(output_path_text, ModelType.TFLITE)
+            _export_tflite_model(tmpdir, SIGNATURE_TEXT, text_path.as_posix())
+
+
+def _export_temporary_tf_model(model_name: str, tmp_path: str, context_length: int):
+    wrapper = _CLIPWrapper(model_name)
+    conf = wrapper.model.config.vision_config
+    spec_visual = tf.TensorSpec(
+        shape=(BATCH_SIZE, conf.num_channels, conf.image_size, conf.image_size), dtype=tf.float32
+    )
+    encode_image = wrapper.encode_image.get_concrete_function(spec_visual)
+    spec_text = tf.TensorSpec(shape=(BATCH_SIZE, context_length), dtype=tf.int32)
+    encode_text = wrapper.encode_text.get_concrete_function(spec_text)
+    signatures = {SIGNATURE_TEXT: encode_text, SIGNATURE_IMAGE: encode_image}
+    tf.saved_model.save(wrapper, tmp_path, signatures)
+
+
+def _export_tflite_model(tmp_path: str, signature: str, output_path: str):
+    converter = tf.lite.TFLiteConverter.from_saved_model(tmp_path, signature_keys=[signature])
+    converter.optimizations = [tf.lite.Optimize.DEFAULT]
+    converter.target_spec.supported_types = [tf.float32]
+    tflite_model = converter.convert()
+    with open(output_path, "wb") as f:
+        f.write(tflite_model)
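
Since the exported signatures are fixed at BATCH_SIZE = 2, a caller with a single image has to pad the batch itself. Below is a minimal sketch of invoking the exported image encoder; the visual/model.tflite path and the 224x224 input size are illustrative assumptions (matching ViT-B/32), while the signature name and its "input" argument come from the wrapper above:

    import numpy as np
    import tensorflow as tf

    # Load the exported image encoder and look up its named signature.
    interpreter = tf.lite.Interpreter(model_path="visual/model.tflite")
    encode_image = interpreter.get_signature_runner("encode_image")

    # The signature was traced with a fixed batch of 2, so a single image is
    # duplicated to fill the batch; only the first output row is kept.
    image = np.random.rand(3, 224, 224).astype(np.float32)  # stand-in for a real preprocessed image
    batch = np.stack([image, image])
    outputs = encode_image(input=batch)  # dict of output tensors, shape (2, embed_dim)
    embedding = next(iter(outputs.values()))[0]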

+ 8 - 2
machine-learning/export/models/util.py

@@ -1,12 +1,18 @@
import json
+from enum import Enum
from pathlib import Path
from typing import Any


-def get_model_path(output_dir: Path | str) -> Path:
+class ModelType(Enum):
+    ONNX = "onnx"
+    TFLITE = "tflite"
+
+
+def get_model_path(output_dir: Path | str, model_type: ModelType = ModelType.ONNX) -> Path:
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
-    return output_dir / "model.onnx"
+    return output_dir / f"model.{model_type.value}"


def save_config(config: Any, output_path: Path | str) -> None:
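
With the new ModelType enum, the same helper now resolves both ONNX and TF-Lite output paths. A quick illustration (the out/visual directory is hypothetical):

    from models.util import ModelType, get_model_path

    get_model_path("out/visual")                    # -> out/visual/model.onnx (default)
    get_model_path("out/visual", ModelType.TFLITE)  # -> out/visual/model.tflite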

+ 5 - 2
machine-learning/export/run.py

@@ -4,7 +4,7 @@ from pathlib import Path
from tempfile import TemporaryDirectory

from huggingface_hub import create_repo, login, upload_folder
-from models import mclip, openclip
+from models import mclip, openclip, tfclip
from rich.progress import Progress

models = [
@@ -37,9 +37,10 @@ models = [
     "M-CLIP/XLM-Roberta-Large-Vit-B-32",
     "M-CLIP/XLM-Roberta-Large-Vit-B-32",
     "M-CLIP/XLM-Roberta-Large-Vit-B-16Plus",
     "M-CLIP/XLM-Roberta-Large-Vit-B-16Plus",
     "M-CLIP/XLM-Roberta-Large-Vit-L-14",
     "M-CLIP/XLM-Roberta-Large-Vit-L-14",
+    "openai/clip-vit-base-patch32",
]

-login(token=os.environ["HF_AUTH_TOKEN"])
+# login(token=os.environ["HF_AUTH_TOKEN"])

with Progress() as progress:
    task1 = progress.add_task("[green]Exporting models...", total=len(models))
@@ -65,6 +66,8 @@ with Progress() as progress:
                textual_dir = tmpdir / model_name / "textual"
                if model.startswith("M-CLIP"):
                    mclip.to_onnx(model, visual_dir, textual_dir)
+                elif "/" in model:
+                    tfclip.to_tflite(model, visual_dir.as_posix(), textual_dir.as_posix())
                else:
                    name, _, pretrained = model_name.partition("__")
                    openclip.to_onnx(openclip.OpenCLIPModelConfig(name, pretrained), visual_dir, textual_dir)
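
The point of the TF-Lite export is to let these encoders run through the ArmNN TF-Lite delegate on supported ARM devices, per the commit title. A rough sketch of wiring that up; the libarmnnDelegate.so library name and the "backends"/"logging-severity" option keys follow the ArmNN delegate documentation and should be treated as assumptions, not as code from this commit:

    import tensorflow as tf

    # Load the ArmNN delegate; GpuAcc is tried first, CpuAcc is the fallback.
    armnn_delegate = tf.lite.experimental.load_delegate(
        "libarmnnDelegate.so",
        options={"backends": "GpuAcc,CpuAcc", "logging-severity": "warning"},
    )

    interpreter = tf.lite.Interpreter(
        model_path="visual/model.tflite",
        experimental_delegates=[armnn_delegate],
    )
    encode_image = interpreter.get_signature_runner("encode_image")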