Selaa lähdekoodia

fix(ml): `minScore` not being set correctly (#3916)

* fixed `minScore` not being set correctly

* apply to init

* don't send `enabled`

* fix eslint warning

* better error message
Mert 1 vuosi sitten
vanhempi
commit
df26e12db6

+ 2 - 2
machine-learning/app/models/facial_recognition.py

@@ -23,7 +23,7 @@ class FaceRecognizer(InferenceModel):
         cache_dir: Path | str | None = None,
         **model_kwargs: Any,
     ) -> None:
-        self.min_score = min_score
+        self.min_score = model_kwargs.pop("minScore", min_score)
         super().__init__(model_name, cache_dir, **model_kwargs)
 
     def _download(self, **model_kwargs: Any) -> None:
@@ -105,4 +105,4 @@ class FaceRecognizer(InferenceModel):
         return self.cache_dir.is_dir() and any(self.cache_dir.glob("*.onnx"))
 
     def configure(self, **model_kwargs: Any) -> None:
-        self.det_model.det_thresh = model_kwargs.get("min_score", self.det_model.det_thresh)
+        self.det_model.det_thresh = model_kwargs.pop("minScore", self.det_model.det_thresh)

+ 2 - 2
machine-learning/app/models/image_classification.py

@@ -22,7 +22,7 @@ class ImageClassifier(InferenceModel):
         cache_dir: Path | str | None = None,
         **model_kwargs: Any,
     ) -> None:
-        self.min_score = min_score
+        self.min_score = model_kwargs.pop("minScore", min_score)
         super().__init__(model_name, cache_dir, **model_kwargs)
 
     def _download(self, **model_kwargs: Any) -> None:
@@ -65,4 +65,4 @@ class ImageClassifier(InferenceModel):
         return tags
 
     def configure(self, **model_kwargs: Any) -> None:
-        self.min_score = model_kwargs.get("min_score", self.min_score)
+        self.min_score = model_kwargs.pop("minScore", self.min_score)

+ 4 - 1
server/src/infra/repositories/machine-learning.repository.ts

@@ -43,7 +43,10 @@ export class MachineLearningRepository implements IMachineLearningRepository {
 
   async getFormData(input: TextModelInput | VisionModelInput, config: ModelConfig): Promise<FormData> {
     const formData = new FormData();
-    const { modelName, modelType, ...options } = config;
+    const { enabled, modelName, modelType, ...options } = config;
+    if (!enabled) {
+      throw new Error(`${modelType} is not enabled`);
+    }
 
     formData.append('modelName', modelName);
     if (modelType) {