# base.py
  1. from __future__ import annotations
  2. import pickle
  3. from abc import ABC, abstractmethod
  4. from pathlib import Path
  5. from shutil import rmtree
  6. from typing import Any
  7. import onnxruntime as ort
  8. from ..config import get_cache_dir, log, settings
  9. from ..schemas import ModelType
  10. class InferenceModel(ABC):
  11. _model_type: ModelType
  12. def __init__(
  13. self,
  14. model_name: str,
  15. cache_dir: Path | str | None = None,
  16. inter_op_num_threads: int = settings.model_inter_op_threads,
  17. intra_op_num_threads: int = settings.model_intra_op_threads,
  18. **model_kwargs: Any,
  19. ) -> None:
  20. self.model_name = model_name
  21. self.loaded = False
  22. self._cache_dir = Path(cache_dir) if cache_dir is not None else None
  23. self.providers = model_kwargs.pop("providers", ["CPUExecutionProvider"])
  24. # don't pre-allocate more memory than needed
  25. self.provider_options = model_kwargs.pop(
  26. "provider_options", [{"arena_extend_strategy": "kSameAsRequested"}] * len(self.providers)
  27. )
  28. log.debug(
  29. (
  30. f"Setting '{self.model_name}' execution providers to {self.providers}"
  31. "in descending order of preference"
  32. ),
  33. )
  34. log.debug(f"Setting execution provider options to {self.provider_options}")
  35. self.sess_options = PicklableSessionOptions()
  36. # avoid thread contention between models
  37. if inter_op_num_threads > 1:
  38. self.sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
  39. log.debug(f"Setting execution_mode to {self.sess_options.execution_mode.name}")
  40. log.debug(f"Setting inter_op_num_threads to {inter_op_num_threads}")
  41. log.debug(f"Setting intra_op_num_threads to {intra_op_num_threads}")
  42. self.sess_options.inter_op_num_threads = inter_op_num_threads
  43. self.sess_options.intra_op_num_threads = intra_op_num_threads
  44. self.sess_options.enable_cpu_mem_arena = False
  45. def download(self) -> None:
  46. if not self.cached:
  47. log.info(
  48. (f"Downloading {self.model_type.replace('-', ' ')} model '{self.model_name}'." "This may take a while.")
  49. )
  50. self._download()
  51. def load(self) -> None:
  52. if self.loaded:
  53. return
  54. self.download()
  55. log.info(f"Loading {self.model_type.replace('-', ' ')} model '{self.model_name}'")
  56. self._load()
  57. self.loaded = True
  58. def predict(self, inputs: Any, **model_kwargs: Any) -> Any:
  59. self.load()
  60. if model_kwargs:
  61. self.configure(**model_kwargs)
  62. return self._predict(inputs)
  63. @abstractmethod
  64. def _predict(self, inputs: Any) -> Any:
  65. ...
  66. def configure(self, **model_kwargs: Any) -> None:
  67. pass
  68. @abstractmethod
  69. def _download(self) -> None:
  70. ...
  71. @abstractmethod
  72. def _load(self) -> None:
  73. ...
  74. @property
  75. def model_type(self) -> ModelType:
  76. return self._model_type
  77. @property
  78. def cache_dir(self) -> Path:
  79. return self._cache_dir if self._cache_dir is not None else get_cache_dir(self.model_name, self.model_type)
  80. @cache_dir.setter
  81. def cache_dir(self, cache_dir: Path) -> None:
  82. self._cache_dir = cache_dir
  83. @property
  84. def cached(self) -> bool:
  85. return self.cache_dir.exists() and any(self.cache_dir.iterdir())
  86. @classmethod
  87. def from_model_type(cls, model_type: ModelType, model_name: str, **model_kwargs: Any) -> InferenceModel:
  88. subclasses = {subclass._model_type: subclass for subclass in cls.__subclasses__()}
  89. if model_type not in subclasses:
  90. raise ValueError(f"Unsupported model type: {model_type}")
  91. return subclasses[model_type](model_name, **model_kwargs)
  92. def clear_cache(self) -> None:
  93. if not self.cache_dir.exists():
  94. log.warn(
  95. f"Attempted to clear cache for model '{self.model_name}' but cache directory does not exist.",
  96. )
  97. return
  98. if not rmtree.avoids_symlink_attacks:
  99. raise RuntimeError("Attempted to clear cache, but rmtree is not safe on this platform.")
  100. if self.cache_dir.is_dir():
  101. log.info(f"Cleared cache directory for model '{self.model_name}'.")
  102. rmtree(self.cache_dir)
  103. else:
  104. log.warn(
  105. (
  106. f"Encountered file instead of directory at cache path "
  107. f"for '{self.model_name}'. Removing file and replacing with a directory."
  108. ),
  109. )
  110. self.cache_dir.unlink()
  111. self.cache_dir.mkdir(parents=True, exist_ok=True)
  112. # HF deep copies configs, so we need to make session options picklable
  113. class PicklableSessionOptions(ort.SessionOptions):
  114. def __getstate__(self) -> bytes:
  115. return pickle.dumps([(attr, getattr(self, attr)) for attr in dir(self) if not callable(getattr(self, attr))])
  116. def __setstate__(self, state: Any) -> None:
  117. self.__init__() # type: ignore
  118. for attr, val in pickle.loads(state):
  119. setattr(self, attr, val)