class OnnxTextEncoder {
  // Path (inside the Flutter asset bundle) of the BPE vocabulary used by the
  // CLIP tokenizer.
  static const vocabFilePath = "assets/models/clip/bpe_simple_vocab_16e6.txt";
  final _logger = Logger("OnnxTextEncoder");
  final OnnxTextTokenizer _tokenizer = OnnxTextTokenizer();

  OnnxTextEncoder() {
    OrtEnv.instance.init();
    OrtEnv.instance.availableProviders().forEach((element) {
      _logger.info('onnx provider=$element');
    });
  }

  /// Loads the BPE vocabulary from the asset bundle and initializes the
  /// tokenizer.
  ///
  /// Must complete before [infer] is called (the tokenizer is stateful and
  /// shared across inferences).
  Future<void> init() async {
    final vocab = await rootBundle.loadString(vocabFilePath);
    await _tokenizer.init(vocab);
  }

  /// Releases the ONNX runtime environment.
  void release() {
    OrtEnv.instance.release();
  }

  /// Loads the CLIP text model from `args["textModelPath"]` and returns the
  /// native address of the created [OrtSession].
  ///
  /// The address — rather than the session object — is returned so the
  /// session can be re-attached from another isolate via
  /// [OrtSession.fromAddress] (see [infer]).
  ///
  /// Returns -1 when loading fails; callers must treat -1 as "no model
  /// loaded" and not pass it to [infer].
  Future<int> loadModel(Map args) async {
    final sessionOptions = OrtSessionOptions()
      ..setInterOpNumThreads(1)
      ..setIntraOpNumThreads(1)
      ..setSessionGraphOptimizationLevel(GraphOptimizationLevel.ortEnableAll);
    try {
      _logger.info("Loading text model");
      final bytes = File(args["textModelPath"]).readAsBytesSync();
      final session = OrtSession.fromBuffer(bytes, sessionOptions);
      _logger.info('text model loaded');
      return session.address;
    } catch (e, s) {
      _logger.severe('text model not loaded', e, s);
    } finally {
      // The options are only needed while the session is being created;
      // releasing them here fixes a native-memory leak (nothing else ever
      // released them after the session became address-owned).
      sessionOptions.release();
    }
    return -1;
  }

  /// Computes the L2-normalized CLIP text embedding (512 doubles) for
  /// `args["text"]`, using the session previously created by [loadModel] and
  /// identified by its native address `args["address"]`.
  Future<List<double>> infer(Map args) async {
    final text = args["text"];
    final address = args["address"] as int;
    final runOptions = OrtRunOptions();
    // CLIP uses a fixed context length of 77 tokens, hence the [1, 77] shape.
    final data = List.filled(1, Int32List.fromList(_tokenizer.tokenize(text)));
    final inputOrt = OrtValueTensor.createTensorWithDataList(data, [1, 77]);
    final inputs = {'input': inputOrt};
    // Re-attach to the session owned by the main isolate. Do NOT release the
    // session here — it is reused across calls; only [loadModel]'s owner may
    // dispose of it.
    final session = OrtSession.fromAddress(address);
    final outputs = session.run(runOptions, inputs);
    final embedding = (outputs[0]?.value as List<List<double>>)[0];

    // L2-normalize so that cosine similarity reduces to a dot product.
    // NOTE(review): the division step was reconstructed from partially
    // visible source — confirm it matches the original normalization block.
    double textNormalization = 0;
    for (int i = 0; i < 512; i++) {
      textNormalization += embedding[i] * embedding[i];
    }
    final norm = sqrt(textNormalization);
    for (int i = 0; i < 512; i++) {
      embedding[i] = embedding[i] / norm;
    }

    // Release the native output tensors — the original leaked them on every
    // inference (`.value` copies the data into Dart, so this is safe).
    for (final element in outputs) {
      element?.release();
    }
    inputOrt.release();
    runOptions.release();
    return embedding;
  }
}