image_classification.py 1.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. from pathlib import Path
  2. from PIL.Image import Image
  3. from transformers.pipelines import pipeline
  4. from ..config import settings
  5. from ..schemas import ModelType
  6. from .base import InferenceModel
  class ImageClassifier(InferenceModel):
      """Tag images using a ``transformers`` image-classification pipeline."""

      _model_type = ModelType.IMAGE_CLASSIFICATION

      def __init__(
          self,
          model_name: str,
          min_score: float = settings.min_tag_score,
          cache_dir: Path | None = None,
          **model_kwargs,
      ):
          """Build the underlying ``transformers`` pipeline.

          Args:
              model_name: Model identifier passed to ``transformers.pipeline``.
              min_score: Minimum prediction score for a label to be kept as a
                  tag. The default is read from ``settings.min_tag_score``
                  once, when this module is imported.
              cache_dir: Passed to the base class; the resulting
                  ``self.cache_dir`` is forwarded to the pipeline's
                  ``model_kwargs``.
              **model_kwargs: Extra entries merged into the pipeline's
                  ``model_kwargs`` alongside ``cache_dir``.
          """
          super().__init__(model_name, cache_dir)
          self.min_score = min_score
          self.model = pipeline(
              self.model_type.value,
              self.model_name,
              model_kwargs={"cache_dir": self.cache_dir, **model_kwargs},
          )

      def predict(self, image: Image) -> list[str]:
          """Return the de-duplicated tags predicted for ``image``.

          Each pipeline prediction is read as a mapping with ``"label"`` and
          ``"score"`` keys. Only predictions whose score is at least
          ``self.min_score`` contribute. Their labels are split on ``", "``
          (presumably a label can list several comma-separated synonyms), and
          each piece becomes a tag.

          Tags are returned in the order they are first seen in the pipeline
          output, so the result is deterministic across runs.
          """
          predictions = self.model(image)
          # dict.fromkeys de-duplicates while keeping first-seen order.
          # The previous set-based version returned tags in an order that
          # depended on per-process string hash randomization.
          # The score check sits on the outer loop because it depends only
          # on the prediction, not on the individual tag.
          return list(
              dict.fromkeys(
                  tag
                  for pred in predictions
                  if pred["score"] >= self.min_score
                  for tag in pred["label"].split(", ")
              )
          )