image_classification.py

from io import BytesIO
from pathlib import Path
from typing import Any

from huggingface_hub import snapshot_download
from optimum.onnxruntime import ORTModelForImageClassification
from optimum.pipelines import pipeline
from PIL import Image
from transformers import AutoImageProcessor

from ..schemas import ModelType
from .base import InferenceModel


class ImageClassifier(InferenceModel):
    _model_type = ModelType.IMAGE_CLASSIFICATION

    def __init__(
        self,
        model_name: str,
        min_score: float = 0.9,
        cache_dir: Path | str | None = None,
        **model_kwargs: Any,
    ) -> None:
        self.min_score = min_score
        super().__init__(model_name, cache_dir, **model_kwargs)

    def _download(self, **model_kwargs: Any) -> None:
        # Fetch the model weights, config, and text assets from the Hugging Face
        # Hub into the local cache directory.
        snapshot_download(
            cache_dir=self.cache_dir,
            repo_id=self.model_name,
            allow_patterns=["*.bin", "*.json", "*.txt"],
            local_dir=self.cache_dir,
            local_dir_use_symlinks=True,
        )

    def _load(self, **model_kwargs: Any) -> None:
        processor = AutoImageProcessor.from_pretrained(self.cache_dir)
        model_kwargs |= {
            "cache_dir": self.cache_dir,
            "provider": self.providers[0],
            "provider_options": self.provider_options[0],
            "session_options": self.sess_options,
        }
        model_path = self.cache_dir / "model.onnx"
        if model_path.exists():
            # Reuse the ONNX model that was exported and optimized on a previous load.
            model = ORTModelForImageClassification.from_pretrained(self.cache_dir, **model_kwargs)
            self.model = pipeline(self.model_type.value, model, feature_extractor=processor)
        else:
            # No ONNX model cached yet: have ONNX Runtime write the optimized graph
            # to model_path so the next load can take the fast path above.
            self.sess_options.optimized_model_filepath = model_path.as_posix()
            self.model = pipeline(
                self.model_type.value,
                self.model_name,
                model_kwargs=model_kwargs,
                feature_extractor=processor,
            )

    def _predict(self, image: Image.Image | bytes) -> list[str]:
        if isinstance(image, bytes):
            image = Image.open(BytesIO(image))
        predictions: list[dict[str, Any]] = self.model(image)  # type: ignore
        # Split multi-part labels (e.g. "tabby, tabby cat") into individual tags
        # and keep only predictions at or above the confidence threshold.
        tags = [tag for pred in predictions for tag in pred["label"].split(", ") if pred["score"] >= self.min_score]
        return tags

    def configure(self, **model_kwargs: Any) -> None:
        # Allow the confidence threshold to be adjusted at runtime.
        self.min_score = model_kwargs.get("min_score", self.min_score)
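
A minimal usage sketch, not part of the original file: it assumes the InferenceModel base class (not shown here) downloads and loads the model during construction, and it calls _predict directly since that is the only prediction entry point visible in this module. The repository id, cache path, and image file are placeholder values chosen for illustration.

if __name__ == "__main__":
    # Placeholder model repo and cache directory, chosen only for this example.
    classifier = ImageClassifier(
        "google/vit-base-patch16-224",
        min_score=0.8,
        cache_dir="/tmp/image-classification-cache",
    )
    with open("example.jpg", "rb") as f:  # any JPEG/PNG on disk
        image_bytes = f.read()
    # The base class may expose a public predict() wrapper; _predict is used
    # here only because it is what this module defines.
    print(classifier._predict(image_bytes))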