
fix(ml): clear model cache on load error (#2951)

* clear model cache on load error

* updated caught exceptions
Mert committed 2 years ago
commit 47982641b2
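
In short, model loading moves out of the subclass constructors into an abstract load() hook, and the base __init__ wraps the first load() call in a try/except: if loading fails with an OSError or onnxruntime's InvalidProtobuf (typically a corrupt or partially downloaded cache), the cache directory is deleted and load() is retried. A minimal sketch of that pattern, with simplified names rather than the actual Immich classes:

    from pathlib import Path
    from shutil import rmtree


    class CorruptModelFile(Exception):
        """Stand-in for onnxruntime's InvalidProtobuf."""


    class Model:
        def __init__(self, cache_dir: Path) -> None:
            self.cache_dir = cache_dir
            try:
                self.load()
            except (OSError, CorruptModelFile):
                # Cache is unusable: wipe it and let load() re-download.
                rmtree(self.cache_dir, ignore_errors=True)
                self.load()

        def load(self) -> None:
            # Real subclasses open or download model weights under cache_dir here.
            ...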

machine-learning/app/models/base.py (+26 -5)

@@ -2,8 +2,11 @@ from __future__ import annotations
 
 from abc import ABC, abstractmethod
 from pathlib import Path
+from shutil import rmtree
 from typing import Any
 
+from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf
+
 from ..config import get_cache_dir
 from ..schemas import ModelType
 
@@ -12,10 +15,8 @@ class InferenceModel(ABC):
     _model_type: ModelType
 
     def __init__(
-        self,
-        model_name: str,
-        cache_dir: Path | None = None,
-    ):
+        self, model_name: str, cache_dir: Path | None = None, **model_kwargs
+    ) -> None:
         self.model_name = model_name
         self._cache_dir = (
             cache_dir
@@ -23,6 +24,16 @@ class InferenceModel(ABC):
             else get_cache_dir(model_name, self.model_type)
         )
 
+        try:
+            self.load(**model_kwargs)
+        except (OSError, InvalidProtobuf):
+            self.clear_cache()
+            self.load(**model_kwargs)
+
+    @abstractmethod
+    def load(self, **model_kwargs: Any) -> None:
+        ...
+
     @abstractmethod
     def predict(self, inputs: Any) -> Any:
         ...
@@ -36,7 +47,7 @@ class InferenceModel(ABC):
         return self._cache_dir
 
     @cache_dir.setter
-    def cache_dir(self, cache_dir: Path):
+    def cache_dir(self, cache_dir: Path) -> None:
         self._cache_dir = cache_dir
 
     @classmethod
@@ -50,3 +61,13 @@ class InferenceModel(ABC):
             raise ValueError(f"Unsupported model type: {model_type}")
 
         return subclasses[model_type](model_name, **model_kwargs)
+
+    def clear_cache(self) -> None:
+        if not self.cache_dir.exists():
+            return
+        elif not rmtree.avoids_symlink_attacks:
+            raise RuntimeError(
+                "Attempted to clear cache, but rmtree is not safe on this platform."
+            )
+
+        rmtree(self.cache_dir)
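
The RuntimeError branch in clear_cache() exists because shutil.rmtree only resists symlink attacks on platforms where shutil.rmtree.avoids_symlink_attacks is True (those with fd-based directory access functions). A small illustration of that check, separate from the commit itself:

    import tempfile
    from pathlib import Path
    from shutil import rmtree

    cache_dir = Path(tempfile.mkdtemp())  # stand-in for a model cache directory

    if cache_dir.exists():
        if rmtree.avoids_symlink_attacks:
            # Safe: rmtree cannot be redirected outside cache_dir via symlinks.
            rmtree(cache_dir)
        else:
            raise RuntimeError("rmtree is not symlink-safe on this platform; refusing to delete")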

machine-learning/app/models/clip.py (+2 -7)

@@ -1,4 +1,5 @@
 from pathlib import Path
+from typing import Any
 
 from PIL.Image import Image
 from sentence_transformers import SentenceTransformer
@@ -10,13 +11,7 @@ from .base import InferenceModel
 class CLIPSTEncoder(InferenceModel):
     _model_type = ModelType.CLIP
 
-    def __init__(
-        self,
-        model_name: str,
-        cache_dir: Path | None = None,
-        **model_kwargs,
-    ):
-        super().__init__(model_name, cache_dir)
+    def load(self, **model_kwargs: Any) -> None:
         self.model = SentenceTransformer(
             self.model_name,
             cache_folder=self.cache_dir.as_posix(),

machine-learning/app/models/facial_recognition.py (+6 -5)

@@ -18,21 +18,22 @@ class FaceRecognizer(InferenceModel):
         min_score: float = settings.min_face_score,
         cache_dir: Path | None = None,
         **model_kwargs,
-    ):
-        super().__init__(model_name, cache_dir)
+    ) -> None:
         self.min_score = min_score
-        model = FaceAnalysis(
+        super().__init__(model_name, cache_dir, **model_kwargs)
+
+    def load(self, **model_kwargs: Any) -> None:
+        self.model = FaceAnalysis(
             name=self.model_name,
             root=self.cache_dir.as_posix(),
             allowed_modules=["detection", "recognition"],
             **model_kwargs,
         )
-        model.prepare(
+        self.model.prepare(
             ctx_id=0,
             det_thresh=self.min_score,
             det_size=(640, 640),
         )
-        self.model = model
 
     def predict(self, image: cv2.Mat) -> list[dict[str, Any]]:
         height, width, _ = image.shape

machine-learning/app/models/image_classification.py (+4 -2)

@@ -1,4 +1,5 @@
 from pathlib import Path
+from typing import Any
 
 from PIL.Image import Image
 from transformers.pipelines import pipeline
@@ -17,10 +18,11 @@ class ImageClassifier(InferenceModel):
         min_score: float = settings.min_tag_score,
         cache_dir: Path | None = None,
         **model_kwargs,
-    ):
-        super().__init__(model_name, cache_dir)
+    ) -> None:
         self.min_score = min_score
+        super().__init__(model_name, cache_dir, **model_kwargs)
 
+    def load(self, **model_kwargs: Any) -> None:
         self.model = pipeline(
             self.model_type.value,
             self.model_name,
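
One detail the FaceRecognizer and ImageClassifier changes share: because the base __init__ now calls load(), any attribute load() depends on must be assigned before super().__init__() runs; in FaceRecognizer, self.min_score is set first precisely because load() passes it as det_thresh. A hypothetical subclass, not part of the commit, showing the same ordering (import paths assumed):

    from pathlib import Path
    from typing import Any

    from app.models.base import InferenceModel  # path assumed
    from app.schemas import ModelType           # path assumed


    class ExampleModel(InferenceModel):
        _model_type = ModelType.CLIP  # placeholder model type

        def __init__(
            self,
            model_name: str,
            threshold: float = 0.5,  # hypothetical extra parameter
            cache_dir: Path | None = None,
            **model_kwargs: Any,
        ) -> None:
            self.threshold = threshold  # set before super().__init__(), which calls load()
            super().__init__(model_name, cache_dir, **model_kwargs)

        def load(self, **model_kwargs: Any) -> None:
            # Open or download weights under self.cache_dir; self.threshold is
            # already available here.
            ...

        def predict(self, inputs: Any) -> Any:
            ...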