123456789101112131415161718192021222324252627282930313233343536373839404142 |
- from pathlib import Path
- from typing import Any
- from huggingface_hub import snapshot_download
- from PIL.Image import Image
- from transformers.pipelines import pipeline
- from ..config import settings
- from ..schemas import ModelType
- from .base import InferenceModel
- class ImageClassifier(InferenceModel):
- _model_type = ModelType.IMAGE_CLASSIFICATION
- def __init__(
- self,
- model_name: str,
- min_score: float = settings.min_tag_score,
- cache_dir: Path | str | None = None,
- **model_kwargs: Any,
- ) -> None:
- self.min_score = min_score
- super().__init__(model_name, cache_dir, **model_kwargs)
- def _download(self, **model_kwargs: Any) -> None:
- snapshot_download(
- cache_dir=self.cache_dir, repo_id=self.model_name, allow_patterns=["*.bin", "*.json", "*.txt"]
- )
- def _load(self, **model_kwargs: Any) -> None:
- self.model = pipeline(
- self.model_type.value,
- self.model_name,
- model_kwargs={"cache_dir": self.cache_dir, **model_kwargs},
- )
- def _predict(self, image: Image) -> list[str]:
- predictions: list[dict[str, Any]] = self.model(image) # type: ignore
- tags = [tag for pred in predictions for tag in pred["label"].split(", ") if pred["score"] >= self.min_score]
- return tags
|