From d8ecefaea55b62d90d0c746b53f442f04abd112a Mon Sep 17 00:00:00 2001 From: Mert <101130780+mertalev@users.noreply.github.com> Date: Tue, 10 Oct 2023 13:26:30 -0400 Subject: [PATCH] chore(ml): removed vit-b check and st warning (#4422) --- machine-learning/app/models/clip.py | 28 +++------------------------- 1 file changed, 3 insertions(+), 25 deletions(-) diff --git a/machine-learning/app/models/clip.py b/machine-learning/app/models/clip.py index 7d7bc5220..c1ce1801b 100644 --- a/machine-learning/app/models/clip.py +++ b/machine-learning/app/models/clip.py @@ -16,13 +16,6 @@ from ..config import log from ..schemas import ModelType from .base import InferenceModel -_ST_TO_JINA_MODEL_NAME = { - "clip-ViT-B-16": "ViT-B-16::openai", - "clip-ViT-B-32": "ViT-B-32::openai", - "clip-ViT-B-32-multilingual-v1": "M-CLIP/XLM-Roberta-Large-Vit-B-32", - "clip-ViT-L-14": "ViT-L-14::openai", -} - class CLIPEncoder(InferenceModel): _model_type = ModelType.CLIP @@ -36,11 +29,10 @@ class CLIPEncoder(InferenceModel): ) -> None: if mode is not None and mode not in ("text", "vision"): raise ValueError(f"Mode must be 'text', 'vision', or omitted; got '{mode}'") - if "vit-b" not in model_name.lower(): - raise ValueError(f"Only ViT-B models are currently supported; got '{model_name}'") + if model_name not in _MODELS: + raise ValueError(f"Unknown model name {model_name}.") self.mode = mode - jina_model_name = self._get_jina_model_name(model_name) - super().__init__(jina_model_name, cache_dir, **model_kwargs) + super().__init__(model_name, cache_dir, **model_kwargs) def _download(self) -> None: models: tuple[tuple[str, str], tuple[str, str]] = _MODELS[self.model_name] @@ -104,20 +96,6 @@ class CLIPEncoder(InferenceModel): return outputs[0][0].tolist() - def _get_jina_model_name(self, model_name: str) -> str: - if model_name in _MODELS: - return model_name - elif model_name in _ST_TO_JINA_MODEL_NAME: - log.warn( - ( - f"Sentence-Transformer models like '{model_name}' are not supported." - f"Using '{_ST_TO_JINA_MODEL_NAME[model_name]}' instead as it is the best match for '{model_name}'." - ), - ) - return _ST_TO_JINA_MODEL_NAME[model_name] - else: - raise ValueError(f"Unknown model name {model_name}.") - def _download_model(self, model_name: str, model_md5: str) -> bool: # downloading logic is adapted from clip-server's CLIPOnnxModel class download_model(