# clip.py
  1. from pathlib import Path
  2. from PIL.Image import Image
  3. from sentence_transformers import SentenceTransformer
  4. from ..schemas import ModelType
  5. from .base import InferenceModel
  6. class CLIPSTEncoder(InferenceModel):
  7. _model_type = ModelType.CLIP
  8. def __init__(
  9. self,
  10. model_name: str,
  11. cache_dir: Path | None = None,
  12. **model_kwargs,
  13. ):
  14. super().__init__(model_name, cache_dir)
  15. self.model = SentenceTransformer(
  16. self.model_name,
  17. cache_folder=self.cache_dir.as_posix(),
  18. **model_kwargs,
  19. )
  20. def predict(self, image_or_text: Image | str) -> list[float]:
  21. return self.model.encode(image_or_text).tolist()
# Stubs to allow different behavior between the two in the future,
# and to handle loading separate image and text CLIP models.
class CLIPSTVisionEncoder(CLIPSTEncoder):
    """Image-encoding variant of the CLIP encoder.

    Currently behaves identically to CLIPSTEncoder; only the reported
    model type differs so callers can route vision workloads to it.
    """

    _model_type = ModelType.CLIP_VISION
class CLIPSTTextEncoder(CLIPSTEncoder):
    """Text-encoding variant of the CLIP encoder.

    Currently behaves identically to CLIPSTEncoder; only the reported
    model type differs so callers can route text workloads to it.
    """

    _model_type = ModelType.CLIP_TEXT