clip.py 604 B

12345678910111213141516171819202122
  1. from pathlib import Path
  2. from typing import Any
  3. from PIL.Image import Image
  4. from sentence_transformers import SentenceTransformer
  5. from ..schemas import ModelType
  6. from .base import InferenceModel
  class CLIPSTEncoder(InferenceModel):
      """CLIP encoder that delegates to a ``SentenceTransformer`` model.

      ``load`` reads ``self.model_name`` and ``self.cache_dir``; these are
      presumably set up by ``InferenceModel`` (not visible here -- confirm
      against the base class).
      """

      _model_type = ModelType.CLIP

      def load(self, **model_kwargs: Any) -> None:
          """Create the underlying model and store it on ``self.model``.

          Args:
              **model_kwargs: Extra keyword arguments passed through unchanged
                  to the ``SentenceTransformer`` constructor.
          """
          model_name = self.model_name
          # The constructor is handed the cache location as a POSIX-style string.
          cache_folder = self.cache_dir.as_posix()
          self.model = SentenceTransformer(model_name, cache_folder=cache_folder, **model_kwargs)

      def predict(self, image_or_text: Image | str) -> list[float]:
          """Encode a single image or text and return the result as a list.

          Args:
              image_or_text: A PIL image or a string to encode.

          Returns:
              The output of ``self.model.encode`` converted with ``.tolist()``.
          """
          # NOTE(review): relies on ``encode`` returning an object with
          # ``.tolist()`` (presumably a NumPy array) -- confirm for the
          # kwargs used in ``load``.
          embedding = self.model.encode(image_or_text)
          return embedding.tolist()