base.py 1.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. from __future__ import annotations
  2. from abc import abstractmethod, ABC
  3. from pathlib import Path
  4. from typing import Any
  5. from ..config import get_cache_dir
  6. from ..schemas import ModelType
  7. class InferenceModel(ABC):
  8. _model_type: ModelType
  9. def __init__(
  10. self,
  11. model_name: str,
  12. cache_dir: Path | None = None,
  13. ):
  14. self.model_name = model_name
  15. self._cache_dir = (
  16. cache_dir
  17. if cache_dir is not None
  18. else get_cache_dir(model_name, self.model_type)
  19. )
  20. @abstractmethod
  21. def predict(self, inputs: Any) -> Any:
  22. ...
  23. @property
  24. def model_type(self) -> ModelType:
  25. return self._model_type
  26. @property
  27. def cache_dir(self) -> Path:
  28. return self._cache_dir
  29. @cache_dir.setter
  30. def cache_dir(self, cache_dir: Path):
  31. self._cache_dir = cache_dir
  32. @classmethod
  33. def from_model_type(
  34. cls, model_type: ModelType, model_name, **model_kwargs
  35. ) -> InferenceModel:
  36. subclasses = {
  37. subclass._model_type: subclass for subclass in cls.__subclasses__()
  38. }
  39. if model_type not in subclasses:
  40. raise ValueError(f"Unsupported model type: {model_type}")
  41. return subclasses[model_type](model_name, **model_kwargs)