From 01c82eaf82a9a29c41fa08deefad3868b1a74295 Mon Sep 17 00:00:00 2001 From: vishnukvmd Date: Wed, 13 Dec 2023 14:42:02 +0530 Subject: [PATCH] Store framework name along with the model name --- lib/services/semantic_search/frameworks/ggml.dart | 5 +++++ lib/services/semantic_search/frameworks/ml_framework.dart | 3 +++ lib/services/semantic_search/frameworks/onnx/onnx.dart | 5 +++++ lib/services/semantic_search/semantic_search_service.dart | 4 ++-- 4 files changed, 15 insertions(+), 2 deletions(-) diff --git a/lib/services/semantic_search/frameworks/ggml.dart b/lib/services/semantic_search/frameworks/ggml.dart index e4903091c..eaf7b1871 100644 --- a/lib/services/semantic_search/frameworks/ggml.dart +++ b/lib/services/semantic_search/frameworks/ggml.dart @@ -10,6 +10,11 @@ class GGML extends MLFramework { final _computer = Computer.shared(); final _logger = Logger("GGML"); + + @override + String getFrameworkName() { + return "ggml"; + } @override String getImageModelRemotePath() { diff --git a/lib/services/semantic_search/frameworks/ml_framework.dart b/lib/services/semantic_search/frameworks/ml_framework.dart index bb722189e..73699e153 100644 --- a/lib/services/semantic_search/frameworks/ml_framework.dart +++ b/lib/services/semantic_search/frameworks/ml_framework.dart @@ -9,6 +9,9 @@ import "package:photos/core/network/network.dart"; abstract class MLFramework { final _logger = Logger("MLFramework"); + /// Returns the name of the framework + String getFrameworkName(); + /// Returns the path of the Image Model hosted remotely String getImageModelRemotePath(); diff --git a/lib/services/semantic_search/frameworks/onnx/onnx.dart b/lib/services/semantic_search/frameworks/onnx/onnx.dart index 26db4b4cf..9d8efc0d1 100644 --- a/lib/services/semantic_search/frameworks/onnx/onnx.dart +++ b/lib/services/semantic_search/frameworks/onnx/onnx.dart @@ -14,6 +14,11 @@ class ONNX extends MLFramework { final _clipText = OnnxTextEncoder(); int _textEncoderAddress = 0; + @override + String getFrameworkName() { + return "onnx"; + } + @override String getImageModelRemotePath() { return ""; diff --git a/lib/services/semantic_search/semantic_search_service.dart b/lib/services/semantic_search/semantic_search_service.dart index 9b8392ca1..9283943c9 100644 --- a/lib/services/semantic_search/semantic_search_service.dart +++ b/lib/services/semantic_search/semantic_search_service.dart @@ -25,7 +25,7 @@ class SemanticSearchService { SemanticSearchService._privateConstructor(); static final Computer _computer = Computer.shared(); - static const kModelName = "onnx-clip"; + static const kModelName = "clip"; static const kEmbeddingLength = 512; static const kScoreThreshold = 0.23; @@ -214,7 +214,7 @@ class SemanticSearchService { } final embedding = Embedding( fileID: file.uploadedFileID!, - model: kModelName, + model: _mlFramework.getFrameworkName() + "-" + kModelName, embedding: result, ); await EmbeddingStore.instance.storeEmbedding(