image_classification.py 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. from pathlib import Path
  2. from typing import Any
  3. from huggingface_hub import snapshot_download
  4. from PIL.Image import Image
  5. from transformers.pipelines import pipeline
  6. from ..config import settings
  7. from ..schemas import ModelType
  8. from .base import InferenceModel
  9. class ImageClassifier(InferenceModel):
  10. _model_type = ModelType.IMAGE_CLASSIFICATION
  11. def __init__(
  12. self,
  13. model_name: str,
  14. min_score: float = settings.min_tag_score,
  15. cache_dir: Path | str | None = None,
  16. **model_kwargs: Any,
  17. ) -> None:
  18. self.min_score = min_score
  19. super().__init__(model_name, cache_dir, **model_kwargs)
  20. def _download(self, **model_kwargs: Any) -> None:
  21. snapshot_download(
  22. cache_dir=self.cache_dir, repo_id=self.model_name, allow_patterns=["*.bin", "*.json", "*.txt"]
  23. )
  24. def _load(self, **model_kwargs: Any) -> None:
  25. self.model = pipeline(
  26. self.model_type.value,
  27. self.model_name,
  28. model_kwargs={"cache_dir": self.cache_dir, **model_kwargs},
  29. )
  30. def _predict(self, image: Image) -> list[str]:
  31. predictions: list[dict[str, Any]] = self.model(image) # type: ignore
  32. tags = [tag for pred in predictions for tag in pred["label"].split(", ") if pred["score"] >= self.min_score]
  33. return tags