image_classification.py 1.1 KB

123456789101112131415161718192021222324252627282930313233343536
  1. from pathlib import Path
  2. from typing import Any
  3. from PIL.Image import Image
  4. from transformers.pipelines import pipeline
  5. from ..config import settings
  6. from ..schemas import ModelType
  7. from .base import InferenceModel
  8. class ImageClassifier(InferenceModel):
  9. _model_type = ModelType.IMAGE_CLASSIFICATION
  10. def __init__(
  11. self,
  12. model_name: str,
  13. min_score: float = settings.min_tag_score,
  14. cache_dir: Path | str | None = None,
  15. **model_kwargs: Any,
  16. ) -> None:
  17. self.min_score = min_score
  18. super().__init__(model_name, cache_dir, **model_kwargs)
  19. def load(self, **model_kwargs: Any) -> None:
  20. self.model = pipeline(
  21. self.model_type.value,
  22. self.model_name,
  23. model_kwargs={"cache_dir": self.cache_dir, **model_kwargs},
  24. )
  25. def predict(self, image: Image) -> list[str]:
  26. predictions: list[dict[str, Any]] = self.model(image) # type: ignore
  27. tags = [tag for pred in predictions for tag in pred["label"].split(", ") if pred["score"] >= self.min_score]
  28. return tags