# image_classification.py
  1. from io import BytesIO
  2. from pathlib import Path
  3. from typing import Any
  4. from huggingface_hub import snapshot_download
  5. from optimum.onnxruntime import ORTModelForImageClassification
  6. from optimum.pipelines import pipeline
  7. from PIL import Image
  8. from transformers import AutoImageProcessor
  9. from ..config import log
  10. from ..schemas import ModelType
  11. from .base import InferenceModel
  12. class ImageClassifier(InferenceModel):
  13. _model_type = ModelType.IMAGE_CLASSIFICATION
  14. def __init__(
  15. self,
  16. model_name: str,
  17. min_score: float = 0.9,
  18. cache_dir: Path | str | None = None,
  19. **model_kwargs: Any,
  20. ) -> None:
  21. self.min_score = model_kwargs.pop("minScore", min_score)
  22. super().__init__(model_name, cache_dir, **model_kwargs)
  23. def _download(self) -> None:
  24. snapshot_download(
  25. cache_dir=self.cache_dir,
  26. repo_id=self.model_name,
  27. allow_patterns=["*.bin", "*.json", "*.txt"],
  28. local_dir=self.cache_dir,
  29. local_dir_use_symlinks=True,
  30. )
  31. def _load(self) -> None:
  32. processor = AutoImageProcessor.from_pretrained(self.cache_dir, cache_dir=self.cache_dir)
  33. model_path = self.cache_dir / "model.onnx"
  34. model_kwargs = {
  35. "cache_dir": self.cache_dir,
  36. "provider": self.providers[0],
  37. "provider_options": self.provider_options[0],
  38. "session_options": self.sess_options,
  39. }
  40. if model_path.exists():
  41. model = ORTModelForImageClassification.from_pretrained(self.cache_dir, **model_kwargs)
  42. self.model = pipeline(self.model_type.value, model, feature_extractor=processor)
  43. else:
  44. log.info(
  45. (
  46. f"ONNX model not found in cache directory for '{self.model_name}'."
  47. "Exporting optimized model for future use."
  48. ),
  49. )
  50. self.sess_options.optimized_model_filepath = model_path.as_posix()
  51. self.model = pipeline(
  52. self.model_type.value,
  53. self.model_name,
  54. model_kwargs=model_kwargs,
  55. feature_extractor=processor,
  56. )
  57. def _predict(self, image: Image.Image | bytes) -> list[str]:
  58. if isinstance(image, bytes):
  59. image = Image.open(BytesIO(image))
  60. predictions: list[dict[str, Any]] = self.model(image) # type: ignore
  61. tags = [tag for pred in predictions for tag in pred["label"].split(", ") if pred["score"] >= self.min_score]
  62. return tags
  63. def configure(self, **model_kwargs: Any) -> None:
  64. self.min_score = model_kwargs.pop("minScore", self.min_score)