浏览代码

Use the cached address to retrieve the encoder

vishnukvmd 1 年之前
父节点
当前提交
1c0ff2f10c

+ 4 - 1
lib/services/semantic_search/frameworks/onnx/onnx.dart

@@ -12,6 +12,7 @@ class ONNX extends MLFramework {
   final _logger = Logger("ONNX");
   final _clipImage = OnnxImageEncoder();
   final _clipText = OnnxTextEncoder();
+  int _textEncoderAddress = 0;
 
   @override
   String getImageModelRemotePath() {
@@ -41,7 +42,8 @@ class ONNX extends MLFramework {
   @override
   Future<void> loadTextModel(String path) async {
     final startTime = DateTime.now();
-    await _computer.compute(
+    await _clipText.init();
+    _textEncoderAddress = await _computer.compute(
       _clipText.loadModel,
       param: {
         "textModelPath": path,
@@ -83,6 +85,7 @@ class ONNX extends MLFramework {
         _clipText.infer,
         param: {
           "text": text,
+          "address": _textEncoderAddress,
         },
         taskName: "createTextEmbedding",
       ) as List<double>;

+ 21 - 19
lib/services/semantic_search/frameworks/onnx/onnx_text_encoder.dart

@@ -2,56 +2,59 @@ import "dart:io";
 import "dart:math";
 import "dart:typed_data";
 
+import "package:flutter/services.dart";
 import "package:logging/logging.dart";
 import "package:onnxruntime/onnxruntime.dart";
 import "package:photos/services/semantic_search/frameworks/onnx/onnx_text_tokenizer.dart";
 
 class OnnxTextEncoder {
-  static const vocabFilePath = "assets/clip/bpe_simple_vocab_16e6.txt";
-  final _logger = Logger("CLIPTextEncoder");
-  OrtSessionOptions? _sessionOptions;
-  OrtSession? _session;
+  static const vocabFilePath = "assets/models/clip/bpe_simple_vocab_16e6.txt";
+  final _logger = Logger("OnnxTextEncoder");
+  final OnnxTextTokenizer _tokenizer = OnnxTextTokenizer();
 
   OnnxTextEncoder() {
     OrtEnv.instance.init();
     OrtEnv.instance.availableProviders().forEach((element) {
-      print('onnx provider=$element');
+      _logger.info('onnx provider=$element');
     });
   }
 
+  Future<void> init() async {
+    final vocab = await rootBundle.loadString(vocabFilePath);
+    await _tokenizer.init(vocab);
+  }
+
   release() {
-    _sessionOptions?.release();
-    _sessionOptions = null;
-    _session?.release();
-    _session = null;
     OrtEnv.instance.release();
   }
 
-  Future<void> loadModel(Map args) async {
-    _sessionOptions = OrtSessionOptions()
+  Future<int> loadModel(Map args) async {
+    final sessionOptions = OrtSessionOptions()
       ..setInterOpNumThreads(1)
       ..setIntraOpNumThreads(1)
       ..setSessionGraphOptimizationLevel(GraphOptimizationLevel.ortEnableAll);
-
     try {
+      _logger.info("Loading text model");
       final bytes = File(args["textModelPath"]).readAsBytesSync();
-      _session = OrtSession.fromBuffer(bytes, _sessionOptions!);
+      final session = OrtSession.fromBuffer(bytes, sessionOptions);
       _logger.info('text model loaded');
+      return session.address;
     } catch (e, s) {
       _logger.severe('text model not loaded', e, s);
     }
+    return -1;
   }
 
   Future<List<double>> infer(Map args) async {
     final text = args["text"];
+    final address = args["address"] as int;
     final runOptions = OrtRunOptions();
-    final tokenizer = OnnxTextTokenizer(vocabFilePath);
-    await tokenizer.init();
-    final data = List.filled(1, Int32List.fromList(tokenizer.tokenize(text)));
+    final data = List.filled(1, Int32List.fromList(_tokenizer.tokenize(text)));
     final inputOrt = OrtValueTensor.createTensorWithDataList(data, [1, 77]);
     final inputs = {'input': inputOrt};
-    final outputs = _session?.run(runOptions, inputs);
-    final embedding = (outputs?[0]?.value as List<List<double>>)[0];
+    final session = OrtSession.fromAddress(address);
+    final outputs = session.run(runOptions, inputs);
+    final embedding = (outputs[0]?.value as List<List<double>>)[0];
     double textNormalization = 0;
     for (int i = 0; i < 512; i++) {
       textNormalization += embedding[i] * embedding[i];
@@ -63,7 +66,6 @@ class OnnxTextEncoder {
 
     inputOrt.release();
     runOptions.release();
-    _session?.release();
     return (embedding);
   }
 }