fix(ml): minScore not being set correctly (#3916)

* fixed `minScore` not being set correctly

* apply to init

* don't send `enabled`

* fix eslint warning

* better error message
This commit is contained in:
Mert 2023-08-30 04:16:00 -04:00 committed by GitHub
parent 343d89c032
commit df26e12db6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 8 additions and 5 deletions

View file

@ -23,7 +23,7 @@ class FaceRecognizer(InferenceModel):
cache_dir: Path | str | None = None,
**model_kwargs: Any,
) -> None:
self.min_score = min_score
self.min_score = model_kwargs.pop("minScore", min_score)
super().__init__(model_name, cache_dir, **model_kwargs)
def _download(self, **model_kwargs: Any) -> None:
@ -105,4 +105,4 @@ class FaceRecognizer(InferenceModel):
return self.cache_dir.is_dir() and any(self.cache_dir.glob("*.onnx"))
def configure(self, **model_kwargs: Any) -> None:
self.det_model.det_thresh = model_kwargs.get("min_score", self.det_model.det_thresh)
self.det_model.det_thresh = model_kwargs.pop("minScore", self.det_model.det_thresh)

View file

@ -22,7 +22,7 @@ class ImageClassifier(InferenceModel):
cache_dir: Path | str | None = None,
**model_kwargs: Any,
) -> None:
self.min_score = min_score
self.min_score = model_kwargs.pop("minScore", min_score)
super().__init__(model_name, cache_dir, **model_kwargs)
def _download(self, **model_kwargs: Any) -> None:
@ -65,4 +65,4 @@ class ImageClassifier(InferenceModel):
return tags
def configure(self, **model_kwargs: Any) -> None:
self.min_score = model_kwargs.get("min_score", self.min_score)
self.min_score = model_kwargs.pop("minScore", self.min_score)

View file

@ -43,7 +43,10 @@ export class MachineLearningRepository implements IMachineLearningRepository {
async getFormData(input: TextModelInput | VisionModelInput, config: ModelConfig): Promise<FormData> {
const formData = new FormData();
const { modelName, modelType, ...options } = config;
const { enabled, modelName, modelType, ...options } = config;
if (!enabled) {
throw new Error(`${modelType} is not enabled`);
}
formData.append('modelName', modelName);
if (modelType) {