image_classification.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. from pathlib import Path
  2. from typing import Any
  3. from huggingface_hub import snapshot_download
  4. from optimum.onnxruntime import ORTModelForImageClassification
  5. from optimum.pipelines import pipeline
  6. from PIL.Image import Image
  7. from transformers import AutoImageProcessor
  8. from ..config import settings
  9. from ..schemas import ModelType
  10. from .base import InferenceModel
  11. class ImageClassifier(InferenceModel):
  12. _model_type = ModelType.IMAGE_CLASSIFICATION
  13. def __init__(
  14. self,
  15. model_name: str,
  16. min_score: float = settings.min_tag_score,
  17. cache_dir: Path | str | None = None,
  18. **model_kwargs: Any,
  19. ) -> None:
  20. self.min_score = min_score
  21. super().__init__(model_name, cache_dir, **model_kwargs)
  22. def _download(self, **model_kwargs: Any) -> None:
  23. snapshot_download(
  24. cache_dir=self.cache_dir,
  25. repo_id=self.model_name,
  26. allow_patterns=["*.bin", "*.json", "*.txt"],
  27. local_dir=self.cache_dir,
  28. local_dir_use_symlinks=True,
  29. )
  30. def _load(self, **model_kwargs: Any) -> None:
  31. processor = AutoImageProcessor.from_pretrained(self.cache_dir)
  32. model_kwargs |= {
  33. "cache_dir": self.cache_dir,
  34. "provider": self.providers[0],
  35. "provider_options": self.provider_options[0],
  36. "session_options": self.sess_options,
  37. }
  38. model_path = self.cache_dir / "model.onnx"
  39. if model_path.exists():
  40. model = ORTModelForImageClassification.from_pretrained(self.cache_dir, **model_kwargs)
  41. self.model = pipeline(self.model_type.value, model, feature_extractor=processor)
  42. else:
  43. self.sess_options.optimized_model_filepath = model_path.as_posix()
  44. self.model = pipeline(
  45. self.model_type.value,
  46. self.model_name,
  47. model_kwargs=model_kwargs,
  48. feature_extractor=processor,
  49. )
  50. def _predict(self, image: Image) -> list[str]:
  51. predictions: list[dict[str, Any]] = self.model(image) # type: ignore
  52. tags = [tag for pred in predictions for tag in pred["label"].split(", ") if pred["score"] >= self.min_score]
  53. return tags