base.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. from __future__ import annotations
  2. import os
  3. import pickle
  4. from abc import ABC, abstractmethod
  5. from pathlib import Path
  6. from shutil import rmtree
  7. from typing import Any
  8. from zipfile import BadZipFile
  9. import onnxruntime as ort
  10. from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf # type: ignore
  11. from ..config import get_cache_dir, settings
  12. from ..schemas import ModelType
  13. class InferenceModel(ABC):
  14. _model_type: ModelType
  15. def __init__(
  16. self,
  17. model_name: str,
  18. cache_dir: Path | str | None = None,
  19. eager: bool = True,
  20. inter_op_num_threads: int = settings.model_inter_op_threads,
  21. intra_op_num_threads: int = settings.model_intra_op_threads,
  22. **model_kwargs: Any,
  23. ) -> None:
  24. self.model_name = model_name
  25. self._loaded = False
  26. self._cache_dir = Path(cache_dir) if cache_dir is not None else get_cache_dir(model_name, self.model_type)
  27. loader = self.load if eager else self.download
  28. self.providers = model_kwargs.pop("providers", ["CPUExecutionProvider"])
  29. # don't pre-allocate more memory than needed
  30. self.provider_options = model_kwargs.pop(
  31. "provider_options", [{"arena_extend_strategy": "kSameAsRequested"}] * len(self.providers)
  32. )
  33. self.sess_options = PicklableSessionOptions()
  34. # avoid thread contention between models
  35. if inter_op_num_threads > 1:
  36. self.sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
  37. self.sess_options.inter_op_num_threads = inter_op_num_threads
  38. self.sess_options.intra_op_num_threads = intra_op_num_threads
  39. try:
  40. loader(**model_kwargs)
  41. except (OSError, InvalidProtobuf, BadZipFile):
  42. self.clear_cache()
  43. loader(**model_kwargs)
  44. def download(self, **model_kwargs: Any) -> None:
  45. if not self.cached:
  46. print(f"Downloading {self.model_type.value.replace('_', ' ')} model. This may take a while...")
  47. self._download(**model_kwargs)
  48. def load(self, **model_kwargs: Any) -> None:
  49. self.download(**model_kwargs)
  50. self._load(**model_kwargs)
  51. self._loaded = True
  52. def predict(self, inputs: Any, **model_kwargs: Any) -> Any:
  53. if not self._loaded:
  54. print(f"Loading {self.model_type.value.replace('_', ' ')} model...")
  55. self.load()
  56. if model_kwargs:
  57. self.configure(**model_kwargs)
  58. return self._predict(inputs)
  59. @abstractmethod
  60. def _predict(self, inputs: Any) -> Any:
  61. ...
  62. def configure(self, **model_kwargs: Any) -> None:
  63. pass
  64. @abstractmethod
  65. def _download(self, **model_kwargs: Any) -> None:
  66. ...
  67. @abstractmethod
  68. def _load(self, **model_kwargs: Any) -> None:
  69. ...
  70. @property
  71. def model_type(self) -> ModelType:
  72. return self._model_type
  73. @property
  74. def cache_dir(self) -> Path:
  75. return self._cache_dir
  76. @cache_dir.setter
  77. def cache_dir(self, cache_dir: Path) -> None:
  78. self._cache_dir = cache_dir
  79. @property
  80. def cached(self) -> bool:
  81. return self.cache_dir.exists() and any(self.cache_dir.iterdir())
  82. @classmethod
  83. def from_model_type(cls, model_type: ModelType, model_name: str, **model_kwargs: Any) -> InferenceModel:
  84. subclasses = {subclass._model_type: subclass for subclass in cls.__subclasses__()}
  85. if model_type not in subclasses:
  86. raise ValueError(f"Unsupported model type: {model_type}")
  87. return subclasses[model_type](model_name, **model_kwargs)
  88. def clear_cache(self) -> None:
  89. if not self.cache_dir.exists():
  90. return
  91. if not rmtree.avoids_symlink_attacks:
  92. raise RuntimeError("Attempted to clear cache, but rmtree is not safe on this platform.")
  93. if self.cache_dir.is_dir():
  94. rmtree(self.cache_dir)
  95. else:
  96. self.cache_dir.unlink()
  97. self.cache_dir.mkdir(parents=True, exist_ok=True)
  98. # HF deep copies configs, so we need to make session options picklable
  99. class PicklableSessionOptions(ort.SessionOptions):
  100. def __getstate__(self) -> bytes:
  101. return pickle.dumps([(attr, getattr(self, attr)) for attr in dir(self) if not callable(getattr(self, attr))])
  102. def __setstate__(self, state: Any) -> None:
  103. self.__init__() # type: ignore
  104. for attr, val in pickle.loads(state):
  105. setattr(self, attr, val)