diff --git a/lib/services/semantic_search/frameworks/onnx/onnx.dart b/lib/services/semantic_search/frameworks/onnx/onnx.dart index 41f8bb3a5..061de2c3a 100644 --- a/lib/services/semantic_search/frameworks/onnx/onnx.dart +++ b/lib/services/semantic_search/frameworks/onnx/onnx.dart @@ -35,6 +35,7 @@ class ONNX extends MLFramework { @override Future loadImageModel(String path) async { final startTime = DateTime.now(); + await _clipImage.init(); _imageEncoderAddress = await _computer.compute( _clipImage.loadModel, param: { @@ -50,7 +51,9 @@ class ONNX extends MLFramework { @override Future loadTextModel(String path) async { final startTime = DateTime.now(); - await _clipText.init(); + await _computer.compute(_clipText.init); + // Doing this from main isolate since `rootBundle` cannot be accessed outside it + await _clipText.initTokenizer(); _textEncoderAddress = await _computer.compute( _clipText.loadModel, param: { diff --git a/lib/services/semantic_search/frameworks/onnx/onnx_image_encoder.dart b/lib/services/semantic_search/frameworks/onnx/onnx_image_encoder.dart index 842fce7ba..671f02aa1 100644 --- a/lib/services/semantic_search/frameworks/onnx/onnx_image_encoder.dart +++ b/lib/services/semantic_search/frameworks/onnx/onnx_image_encoder.dart @@ -9,7 +9,7 @@ import "package:onnxruntime/onnxruntime.dart"; class OnnxImageEncoder { final _logger = Logger("OnnxImageEncoder"); - OnnxImageEncoder() { + Future init() async { OrtEnv.instance.init(); } @@ -96,14 +96,14 @@ class OnnxImageEncoder { final int ny = rgb.height; final int inputSize = 3 * nx * ny; final inputImage = List.filled(inputSize, 0.toDouble()); - + const int nx2 = 224; const int ny2 = 224; const int totalSize = 3 * nx2 * ny2; // Load image into List inputImage for (int y = 0; y < ny; y++) { - for (int x = 0; x < nx; x ++) { + for (int x = 0; x < nx; x++) { final int i = 3 * (y * nx + x); inputImage[i] = rgb.getPixel(x, y).r.toDouble(); inputImage[i + 1] = rgb.getPixel(x, y).g.toDouble(); @@ -121,7 +121,7 @@ class OnnxImageEncoder { final std = [0.26862954, 0.26130258, 0.27577711]; for (int y = 0; y < ny3; y++) { - for (int x = 0; x < nx3; x ++) { + for (int x = 0; x < nx3; x++) { for (int c = 0; c < 3; c++) { //linear interpolation final double sx = (x + 0.5) * scale - 0.5; @@ -152,18 +152,17 @@ class OnnxImageEncoder { final double v = v0 * (1 - dy) + v1 * dy; final int v2 = min(max(v.round(), 0), 255); - + final int i = 3 * (y * nx3 + x) + c; result[i] = ((v2 / 255) - mean[c]) / std[c]; - } } } final floatList = Float32List.fromList(result); - - final inputOrt = OrtValueTensor.createTensorWithDataList(floatList, [1, 3, 224, 224]); + final inputOrt = + OrtValueTensor.createTensorWithDataList(floatList, [1, 3, 224, 224]); final inputs = {'input': inputOrt}; final session = OrtSession.fromAddress(args["address"]); final outputs = session.run(runOptions, inputs); diff --git a/lib/services/semantic_search/frameworks/onnx/onnx_text_encoder.dart b/lib/services/semantic_search/frameworks/onnx/onnx_text_encoder.dart index bd8a0559b..6a8bb3b74 100644 --- a/lib/services/semantic_search/frameworks/onnx/onnx_text_encoder.dart +++ b/lib/services/semantic_search/frameworks/onnx/onnx_text_encoder.dart @@ -8,19 +8,17 @@ import "package:onnxruntime/onnxruntime.dart"; import "package:photos/services/semantic_search/frameworks/onnx/onnx_text_tokenizer.dart"; class OnnxTextEncoder { - static const vocabFilePath = "assets/models/clip/bpe_simple_vocab_16e6.txt"; + static const kVocabFilePath = "assets/models/clip/bpe_simple_vocab_16e6.txt"; final _logger = Logger("OnnxTextEncoder"); final OnnxTextTokenizer _tokenizer = OnnxTextTokenizer(); - OnnxTextEncoder() { + Future init() async { OrtEnv.instance.init(); - OrtEnv.instance.availableProviders().forEach((element) { - _logger.info('onnx provider=$element'); - }); } - Future init() async { - final vocab = await rootBundle.loadString(vocabFilePath); + // Do not run in an isolate since rootBundle can only be accessed in the main isolate + Future initTokenizer() async { + final vocab = await rootBundle.loadString(kVocabFilePath); await _tokenizer.init(vocab); }