base.py 2.7 KB

from __future__ import annotations

from abc import ABC, abstractmethod
from pathlib import Path
from shutil import rmtree
from typing import Any
from zipfile import BadZipFile

from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf  # type: ignore

from ..config import get_cache_dir
from ..schemas import ModelType


class InferenceModel(ABC):
    """Base class for inference models: handles caching, downloading, and lazy loading."""

    _model_type: ModelType

    def __init__(
        self, model_name: str, cache_dir: Path | str | None = None, eager: bool = True, **model_kwargs: Any
    ) -> None:
        self.model_name = model_name
        self._loaded = False
        self._cache_dir = Path(cache_dir) if cache_dir is not None else get_cache_dir(model_name, self.model_type)

        # Load eagerly (or only download if eager=False); if the cached files are
        # corrupt, clear the cache and retry once.
        loader = self.load if eager else self.download
        try:
            loader(**model_kwargs)
        except (OSError, InvalidProtobuf, BadZipFile):
            self.clear_cache()
            loader(**model_kwargs)

    def download(self, **model_kwargs: Any) -> None:
        if not self.cached:
            self._download(**model_kwargs)

    def load(self, **model_kwargs: Any) -> None:
        self.download(**model_kwargs)
        self._load(**model_kwargs)
        self._loaded = True

    def predict(self, inputs: Any) -> Any:
        # Lazily load the model on first use.
        if not self._loaded:
            self.load()
        return self._predict(inputs)

    @abstractmethod
    def _predict(self, inputs: Any) -> Any:
        ...

    @abstractmethod
    def _download(self, **model_kwargs: Any) -> None:
        ...

    @abstractmethod
    def _load(self, **model_kwargs: Any) -> None:
        ...

    @property
    def model_type(self) -> ModelType:
        return self._model_type

    @property
    def cache_dir(self) -> Path:
        return self._cache_dir

    @cache_dir.setter
    def cache_dir(self, cache_dir: Path) -> None:
        self._cache_dir = cache_dir

    @property
    def cached(self) -> bool:
        return self.cache_dir.exists() and any(self.cache_dir.iterdir())

    @classmethod
    def from_model_type(cls, model_type: ModelType, model_name: str, **model_kwargs: Any) -> InferenceModel:
        # Resolve the concrete subclass registered for this model type.
        subclasses = {subclass._model_type: subclass for subclass in cls.__subclasses__()}
        if model_type not in subclasses:
            raise ValueError(f"Unsupported model type: {model_type}")
        return subclasses[model_type](model_name, **model_kwargs)

    def clear_cache(self) -> None:
        if not self.cache_dir.exists():
            return
        if not rmtree.avoids_symlink_attacks:
            raise RuntimeError("Attempted to clear cache, but rmtree is not safe on this platform.")

        if self.cache_dir.is_dir():
            rmtree(self.cache_dir)
        else:
            self.cache_dir.unlink()
        self.cache_dir.mkdir(parents=True, exist_ok=True)
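Below is a minimal usage sketch showing how a concrete model would plug into this base class. The ClipSketchModel name, the ModelType.CLIP member, and the model name string are illustrative assumptions and are not part of base.py; real subclasses would implement _download, _load, and _predict with actual download and inference-session logic.

class ClipSketchModel(InferenceModel):
    # Hypothetical subclass for illustration; assumes ModelType defines a CLIP member.
    _model_type = ModelType.CLIP

    def _download(self, **model_kwargs: Any) -> None:
        ...  # fetch model files into self.cache_dir

    def _load(self, **model_kwargs: Any) -> None:
        ...  # build an inference session from the cached files

    def _predict(self, inputs: Any) -> Any:
        ...  # run the session and return its output


# from_model_type() discovers subclasses via cls.__subclasses__(), so importing
# the subclass is enough for dispatch by ModelType to work:
model = InferenceModel.from_model_type(ModelType.CLIP, "example-model", eager=False)
result = model.predict("example input")  # load() runs lazily on the first predict()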