base.py

from __future__ import annotations

import os
import pickle
from abc import ABC, abstractmethod
from pathlib import Path
from shutil import rmtree
from typing import Any
from zipfile import BadZipFile

import onnxruntime as ort
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf  # type: ignore

from ..config import get_cache_dir, settings
from ..schemas import ModelType


class InferenceModel(ABC):
    _model_type: ModelType

    def __init__(
        self,
        model_name: str,
        cache_dir: Path | str | None = None,
        eager: bool = True,
        inter_op_num_threads: int = settings.model_inter_op_threads,
        intra_op_num_threads: int = settings.model_intra_op_threads,
        **model_kwargs: Any,
    ) -> None:
        self.model_name = model_name
        self._loaded = False
        self._cache_dir = Path(cache_dir) if cache_dir is not None else get_cache_dir(model_name, self.model_type)
        loader = self.load if eager else self.download

        self.providers = model_kwargs.pop("providers", ["CPUExecutionProvider"])
        # don't pre-allocate more memory than needed
        self.provider_options = model_kwargs.pop(
            "provider_options", [{"arena_extend_strategy": "kSameAsRequested"}] * len(self.providers)
        )

        self.sess_options = PicklableSessionOptions()
        # avoid thread contention between models
        if inter_op_num_threads > 1:
            self.sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
        self.sess_options.inter_op_num_threads = inter_op_num_threads
        self.sess_options.intra_op_num_threads = intra_op_num_threads

        try:
            loader(**model_kwargs)
        except (OSError, InvalidProtobuf, BadZipFile):
            self.clear_cache()
            loader(**model_kwargs)

    def download(self, **model_kwargs: Any) -> None:
        if not self.cached:
            print(f"Downloading {self.model_type.value.replace('_', ' ')} model. This may take a while...")
            self._download(**model_kwargs)

    def load(self, **model_kwargs: Any) -> None:
        self.download(**model_kwargs)
        self._load(**model_kwargs)
        self._loaded = True

    def predict(self, inputs: Any) -> Any:
        if not self._loaded:
            print(f"Loading {self.model_type.value.replace('_', ' ')} model...")
            self.load()
        return self._predict(inputs)

    @abstractmethod
    def _predict(self, inputs: Any) -> Any:
        ...

    @abstractmethod
    def _download(self, **model_kwargs: Any) -> None:
        ...

    @abstractmethod
    def _load(self, **model_kwargs: Any) -> None:
        ...

    @property
    def model_type(self) -> ModelType:
        return self._model_type

    @property
    def cache_dir(self) -> Path:
        return self._cache_dir

    @cache_dir.setter
    def cache_dir(self, cache_dir: Path) -> None:
        self._cache_dir = cache_dir

    @property
    def cached(self) -> bool:
        return self.cache_dir.exists() and any(self.cache_dir.iterdir())

    @classmethod
    def from_model_type(cls, model_type: ModelType, model_name: str, **model_kwargs: Any) -> InferenceModel:
        subclasses = {subclass._model_type: subclass for subclass in cls.__subclasses__()}
        if model_type not in subclasses:
            raise ValueError(f"Unsupported model type: {model_type}")
        return subclasses[model_type](model_name, **model_kwargs)

    def clear_cache(self) -> None:
        if not self.cache_dir.exists():
            return
        if not rmtree.avoids_symlink_attacks:
            raise RuntimeError("Attempted to clear cache, but rmtree is not safe on this platform.")

        if self.cache_dir.is_dir():
            rmtree(self.cache_dir)
        else:
            self.cache_dir.unlink()
        self.cache_dir.mkdir(parents=True, exist_ok=True)


# HF deep copies configs, so we need to make session options picklable
class PicklableSessionOptions(ort.SessionOptions):
    def __getstate__(self) -> bytes:
        return pickle.dumps([(attr, getattr(self, attr)) for attr in dir(self) if not callable(getattr(self, attr))])

    def __setstate__(self, state: Any) -> None:
        self.__init__()  # type: ignore
        for attr, val in pickle.loads(state):
            setattr(self, attr, val)