ソースを参照

Store framework name along with the model name

vishnukvmd 1 年間 前
コミット
01c82eaf82

+ 5 - 0
lib/services/semantic_search/frameworks/ggml.dart

@@ -10,6 +10,11 @@ class GGML extends MLFramework {
 
   final _computer = Computer.shared();
   final _logger = Logger("GGML");
+  
+  @override
+  String getFrameworkName() {
+    return "ggml";
+  }
 
   @override
   String getImageModelRemotePath() {

+ 3 - 0
lib/services/semantic_search/frameworks/ml_framework.dart

@@ -9,6 +9,9 @@ import "package:photos/core/network/network.dart";
 abstract class MLFramework {
   final _logger = Logger("MLFramework");
 
+  /// Returns the name of the framework
+  String getFrameworkName();
+
   /// Returns the path of the Image Model hosted remotely
   String getImageModelRemotePath();
 

+ 5 - 0
lib/services/semantic_search/frameworks/onnx/onnx.dart

@@ -14,6 +14,11 @@ class ONNX extends MLFramework {
   final _clipText = OnnxTextEncoder();
   int _textEncoderAddress = 0;
 
+  @override
+  String getFrameworkName() {
+    return "onnx";
+  }
+
   @override
   String getImageModelRemotePath() {
     return "";

+ 2 - 2
lib/services/semantic_search/semantic_search_service.dart

@@ -25,7 +25,7 @@ class SemanticSearchService {
       SemanticSearchService._privateConstructor();
   static final Computer _computer = Computer.shared();
 
-  static const kModelName = "onnx-clip";
+  static const kModelName = "clip";
   static const kEmbeddingLength = 512;
   static const kScoreThreshold = 0.23;
 
@@ -214,7 +214,7 @@ class SemanticSearchService {
       }
       final embedding = Embedding(
         fileID: file.uploadedFileID!,
-        model: kModelName,
+        model: _mlFramework.getFrameworkName() + "-" + kModelName,
         embedding: result,
       );
       await EmbeddingStore.instance.storeEmbedding(