|
@@ -1,6 +1,5 @@
|
|
|
from __future__ import annotations
|
|
|
|
|
|
-import os
|
|
|
import pickle
|
|
|
from abc import ABC, abstractmethod
|
|
|
from pathlib import Path
|
|
@@ -11,7 +10,7 @@ from zipfile import BadZipFile
|
|
|
import onnxruntime as ort
|
|
|
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf # type: ignore
|
|
|
|
|
|
-from ..config import get_cache_dir, settings
|
|
|
+from ..config import get_cache_dir, log, settings
|
|
|
from ..schemas import ModelType
|
|
|
|
|
|
|
|
@@ -37,22 +36,41 @@ class InferenceModel(ABC):
|
|
|
self.provider_options = model_kwargs.pop(
|
|
|
"provider_options", [{"arena_extend_strategy": "kSameAsRequested"}] * len(self.providers)
|
|
|
)
|
|
|
+ log.debug(
|
|
|
+ (
|
|
|
+ f"Setting '{self.model_name}' execution providers to {self.providers}"
|
|
|
+ "in descending order of preference"
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ log.debug(f"Setting execution provider options to {self.provider_options}")
|
|
|
self.sess_options = PicklableSessionOptions()
|
|
|
# avoid thread contention between models
|
|
|
if inter_op_num_threads > 1:
|
|
|
self.sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
|
|
|
+
|
|
|
+ log.debug(f"Setting execution_mode to {self.sess_options.execution_mode.name}")
|
|
|
+ log.debug(f"Setting inter_op_num_threads to {inter_op_num_threads}")
|
|
|
+ log.debug(f"Setting intra_op_num_threads to {intra_op_num_threads}")
|
|
|
self.sess_options.inter_op_num_threads = inter_op_num_threads
|
|
|
self.sess_options.intra_op_num_threads = intra_op_num_threads
|
|
|
|
|
|
try:
|
|
|
loader(**model_kwargs)
|
|
|
except (OSError, InvalidProtobuf, BadZipFile):
|
|
|
+ log.warn(
|
|
|
+ (
|
|
|
+ f"Failed to load {self.model_type.replace('_', ' ')} model '{self.model_name}'."
|
|
|
+ "Clearing cache and retrying."
|
|
|
+ )
|
|
|
+ )
|
|
|
self.clear_cache()
|
|
|
loader(**model_kwargs)
|
|
|
|
|
|
def download(self, **model_kwargs: Any) -> None:
|
|
|
if not self.cached:
|
|
|
- print(f"Downloading {self.model_type.value.replace('_', ' ')} model. This may take a while...")
|
|
|
+ log.info(
|
|
|
+ (f"Downloading {self.model_type.replace('_', ' ')} model '{self.model_name}'." "This may take a while.")
|
|
|
+ )
|
|
|
self._download(**model_kwargs)
|
|
|
|
|
|
def load(self, **model_kwargs: Any) -> None:
|
|
@@ -62,7 +80,7 @@ class InferenceModel(ABC):
|
|
|
|
|
|
def predict(self, inputs: Any, **model_kwargs: Any) -> Any:
|
|
|
if not self._loaded:
|
|
|
- print(f"Loading {self.model_type.value.replace('_', ' ')} model...")
|
|
|
+ log.info(f"Loading {self.model_type.replace('_', ' ')} model '{self.model_name}'")
|
|
|
self.load()
|
|
|
if model_kwargs:
|
|
|
self.configure(**model_kwargs)
|
|
@@ -109,13 +127,23 @@ class InferenceModel(ABC):
|
|
|
|
|
|
def clear_cache(self) -> None:
|
|
|
if not self.cache_dir.exists():
|
|
|
+ log.warn(
|
|
|
+ f"Attempted to clear cache for model '{self.model_name}' but cache directory does not exist.",
|
|
|
+ )
|
|
|
return
|
|
|
if not rmtree.avoids_symlink_attacks:
|
|
|
raise RuntimeError("Attempted to clear cache, but rmtree is not safe on this platform.")
|
|
|
|
|
|
if self.cache_dir.is_dir():
|
|
|
+ log.info(f"Cleared cache directory for model '{self.model_name}'.")
|
|
|
rmtree(self.cache_dir)
|
|
|
else:
|
|
|
+ log.warn(
|
|
|
+ (
|
|
|
+ f"Encountered file instead of directory at cache path "
|
|
|
+ f"for '{self.model_name}'. Removing file and replacing with a directory."
|
|
|
+ ),
|
|
|
+ )
|
|
|
self.cache_dir.unlink()
|
|
|
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|