# base.py
  1. from __future__ import annotations
  2. import pickle
  3. from abc import ABC, abstractmethod
  4. from pathlib import Path
  5. from shutil import rmtree
  6. from typing import Any
  7. from zipfile import BadZipFile
  8. import onnxruntime as ort
  9. from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile # type: ignore
  10. from ..config import get_cache_dir, log, settings
  11. from ..schemas import ModelType
  12. class InferenceModel(ABC):
  13. _model_type: ModelType
  14. def __init__(
  15. self,
  16. model_name: str,
  17. cache_dir: Path | str | None = None,
  18. eager: bool = True,
  19. inter_op_num_threads: int = settings.model_inter_op_threads,
  20. intra_op_num_threads: int = settings.model_intra_op_threads,
  21. **model_kwargs: Any,
  22. ) -> None:
  23. self.model_name = model_name
  24. self._loaded = False
  25. self._cache_dir = Path(cache_dir) if cache_dir is not None else get_cache_dir(model_name, self.model_type)
  26. loader = self.load if eager else self.download
  27. self.providers = model_kwargs.pop("providers", ["CPUExecutionProvider"])
  28. # don't pre-allocate more memory than needed
  29. self.provider_options = model_kwargs.pop(
  30. "provider_options", [{"arena_extend_strategy": "kSameAsRequested"}] * len(self.providers)
  31. )
  32. log.debug(
  33. (
  34. f"Setting '{self.model_name}' execution providers to {self.providers}"
  35. "in descending order of preference"
  36. ),
  37. )
  38. log.debug(f"Setting execution provider options to {self.provider_options}")
  39. self.sess_options = PicklableSessionOptions()
  40. # avoid thread contention between models
  41. if inter_op_num_threads > 1:
  42. self.sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
  43. log.debug(f"Setting execution_mode to {self.sess_options.execution_mode.name}")
  44. log.debug(f"Setting inter_op_num_threads to {inter_op_num_threads}")
  45. log.debug(f"Setting intra_op_num_threads to {intra_op_num_threads}")
  46. self.sess_options.inter_op_num_threads = inter_op_num_threads
  47. self.sess_options.intra_op_num_threads = intra_op_num_threads
  48. self.sess_options.enable_cpu_mem_arena = False
  49. try:
  50. loader(**model_kwargs)
  51. except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile):
  52. log.warn(
  53. (
  54. f"Failed to load {self.model_type.replace('_', ' ')} model '{self.model_name}'."
  55. "Clearing cache and retrying."
  56. )
  57. )
  58. self.clear_cache()
  59. loader(**model_kwargs)
  60. def download(self, **model_kwargs: Any) -> None:
  61. if not self.cached:
  62. log.info(
  63. (f"Downloading {self.model_type.replace('_', ' ')} model '{self.model_name}'." "This may take a while.")
  64. )
  65. self._download(**model_kwargs)
  66. def load(self, **model_kwargs: Any) -> None:
  67. self.download(**model_kwargs)
  68. self._load(**model_kwargs)
  69. self._loaded = True
  70. def predict(self, inputs: Any, **model_kwargs: Any) -> Any:
  71. if not self._loaded:
  72. log.info(f"Loading {self.model_type.replace('_', ' ')} model '{self.model_name}'")
  73. self.load()
  74. if model_kwargs:
  75. self.configure(**model_kwargs)
  76. return self._predict(inputs)
  77. @abstractmethod
  78. def _predict(self, inputs: Any) -> Any:
  79. ...
  80. def configure(self, **model_kwargs: Any) -> None:
  81. pass
  82. @abstractmethod
  83. def _download(self, **model_kwargs: Any) -> None:
  84. ...
  85. @abstractmethod
  86. def _load(self, **model_kwargs: Any) -> None:
  87. ...
  88. @property
  89. def model_type(self) -> ModelType:
  90. return self._model_type
  91. @property
  92. def cache_dir(self) -> Path:
  93. return self._cache_dir
  94. @cache_dir.setter
  95. def cache_dir(self, cache_dir: Path) -> None:
  96. self._cache_dir = cache_dir
  97. @property
  98. def cached(self) -> bool:
  99. return self.cache_dir.exists() and any(self.cache_dir.iterdir())
  100. @classmethod
  101. def from_model_type(cls, model_type: ModelType, model_name: str, **model_kwargs: Any) -> InferenceModel:
  102. subclasses = {subclass._model_type: subclass for subclass in cls.__subclasses__()}
  103. if model_type not in subclasses:
  104. raise ValueError(f"Unsupported model type: {model_type}")
  105. return subclasses[model_type](model_name, **model_kwargs)
  106. def clear_cache(self) -> None:
  107. if not self.cache_dir.exists():
  108. log.warn(
  109. f"Attempted to clear cache for model '{self.model_name}' but cache directory does not exist.",
  110. )
  111. return
  112. if not rmtree.avoids_symlink_attacks:
  113. raise RuntimeError("Attempted to clear cache, but rmtree is not safe on this platform.")
  114. if self.cache_dir.is_dir():
  115. log.info(f"Cleared cache directory for model '{self.model_name}'.")
  116. rmtree(self.cache_dir)
  117. else:
  118. log.warn(
  119. (
  120. f"Encountered file instead of directory at cache path "
  121. f"for '{self.model_name}'. Removing file and replacing with a directory."
  122. ),
  123. )
  124. self.cache_dir.unlink()
  125. self.cache_dir.mkdir(parents=True, exist_ok=True)
  126. # HF deep copies configs, so we need to make session options picklable
  127. class PicklableSessionOptions(ort.SessionOptions):
  128. def __getstate__(self) -> bytes:
  129. return pickle.dumps([(attr, getattr(self, attr)) for attr in dir(self) if not callable(getattr(self, attr))])
  130. def __setstate__(self, state: Any) -> None:
  131. self.__init__() # type: ignore
  132. for attr, val in pickle.loads(state):
  133. setattr(self, attr, val)