@@ -1,6 +1,7 @@
 import tempfile
 import warnings
 from dataclasses import dataclass, field
+from math import e
 from pathlib import Path

 import open_clip
@@ -69,10 +70,12 @@ def export_image_encoder(model: open_clip.CLIP, model_cfg: OpenCLIPModelConfig,
     output_path = Path(output_path)

     def encode_image(image: torch.Tensor) -> torch.Tensor:
-        return model.encode_image(image, normalize=True)
+        output = model.encode_image(image, normalize=True)
+        assert isinstance(output, torch.Tensor)
+        return output

     args = (torch.randn(1, 3, model_cfg.image_size, model_cfg.image_size),)
-    traced = torch.jit.trace(encode_image, args)
+    traced = torch.jit.trace(encode_image, args)  # type: ignore[no-untyped-call]

     with warnings.catch_warnings():
         warnings.simplefilter("ignore", UserWarning)
@@ -91,10 +94,12 @@ def export_text_encoder(model: open_clip.CLIP, model_cfg: OpenCLIPModelConfig, o
     output_path = Path(output_path)

     def encode_text(text: torch.Tensor) -> torch.Tensor:
-        return model.encode_text(text, normalize=True)
+        output = model.encode_text(text, normalize=True)
+        assert isinstance(output, torch.Tensor)
+        return output

     args = (torch.ones(1, model_cfg.sequence_length, dtype=torch.int32),)
-    traced = torch.jit.trace(encode_text, args)
+    traced = torch.jit.trace(encode_text, args)  # type: ignore[no-untyped-call]

     with warnings.catch_warnings():
         warnings.simplefilter("ignore", UserWarning)
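
For reference, here is a minimal, self-contained sketch of the pattern both hunks apply; it is not the project's actual export code. The diff suggests `model.encode_image` is not annotated precisely upstream, so a typed wrapper narrows the return value with `assert isinstance(...)` before handing it to `torch.jit.trace` (which mypy sees as untyped, hence the `type: ignore[no-untyped-call]`). The model name "ViT-B-32", the `pretrained` tag, and the output filename are arbitrary choices for this example.

```python
import torch
import open_clip

# Load an arbitrary example model; any open_clip model name works here.
model, _, _ = open_clip.create_model_and_transforms("ViT-B-32", pretrained="openai")
model.eval()


def encode_image(image: torch.Tensor) -> torch.Tensor:
    # The upstream return annotation is imprecise; the assert narrows it
    # to torch.Tensor so the wrapper's own signature type-checks.
    output = model.encode_image(image, normalize=True)
    assert isinstance(output, torch.Tensor)
    return output


# Trace the wrapper with a dummy batch (ViT-B-32 expects 224x224 input).
args = (torch.randn(1, 3, 224, 224),)
traced = torch.jit.trace(encode_image, args)  # type: ignore[no-untyped-call]
torch.jit.save(traced, "image_encoder.pt")
```

The `assert` runs once during tracing and never appears in the traced graph, so it costs nothing at inference time; its only job is type narrowing, which is why `encode_image` and `encode_text` get the identical three-line treatment. The saved artifact can later be reloaded with `torch.jit.load("image_encoder.pt")` without importing open_clip at all.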