@@ -2,56 +2,59 @@ import "dart:io";
 import "dart:math";
 import "dart:typed_data";
 
+import "package:flutter/services.dart";
 import "package:logging/logging.dart";
 import "package:onnxruntime/onnxruntime.dart";
 import "package:photos/services/semantic_search/frameworks/onnx/onnx_text_tokenizer.dart";
 
 class OnnxTextEncoder {
-  static const vocabFilePath = "assets/clip/bpe_simple_vocab_16e6.txt";
-  final _logger = Logger("CLIPTextEncoder");
-  OrtSessionOptions? _sessionOptions;
-  OrtSession? _session;
+  static const vocabFilePath = "assets/models/clip/bpe_simple_vocab_16e6.txt";
+  final _logger = Logger("OnnxTextEncoder");
+  final OnnxTextTokenizer _tokenizer = OnnxTextTokenizer();
 
   OnnxTextEncoder() {
     OrtEnv.instance.init();
     OrtEnv.instance.availableProviders().forEach((element) {
-      print('onnx provider=$element');
+      _logger.info('onnx provider=$element');
    });
  }
 
+  Future<void> init() async {
+    final vocab = await rootBundle.loadString(vocabFilePath);
+    await _tokenizer.init(vocab);
+  }
+
  release() {
-    _sessionOptions?.release();
-    _sessionOptions = null;
-    _session?.release();
-    _session = null;
     OrtEnv.instance.release();
   }
 
-  Future<void> loadModel(Map args) async {
-    _sessionOptions = OrtSessionOptions()
+  Future<int> loadModel(Map args) async {
+    final sessionOptions = OrtSessionOptions()
       ..setInterOpNumThreads(1)
       ..setIntraOpNumThreads(1)
       ..setSessionGraphOptimizationLevel(GraphOptimizationLevel.ortEnableAll);
-
     try {
+      _logger.info("Loading text model");
       final bytes = File(args["textModelPath"]).readAsBytesSync();
-      _session = OrtSession.fromBuffer(bytes, _sessionOptions!);
+      final session = OrtSession.fromBuffer(bytes, sessionOptions);
       _logger.info('text model loaded');
+      return session.address;
     } catch (e, s) {
       _logger.severe('text model not loaded', e, s);
     }
+    return -1;
   }
 
   Future<List<double>> infer(Map args) async {
     final text = args["text"];
+    final address = args["address"] as int;
     final runOptions = OrtRunOptions();
-    final tokenizer = OnnxTextTokenizer(vocabFilePath);
-    await tokenizer.init();
-    final data = List.filled(1, Int32List.fromList(tokenizer.tokenize(text)));
+    final data = List.filled(1, Int32List.fromList(_tokenizer.tokenize(text)));
     final inputOrt = OrtValueTensor.createTensorWithDataList(data, [1, 77]);
     final inputs = {'input': inputOrt};
-    final outputs = _session?.run(runOptions, inputs);
-    final embedding = (outputs?[0]?.value as List<List<double>>)[0];
+    final session = OrtSession.fromAddress(address);
+    final outputs = session.run(runOptions, inputs);
+    final embedding = (outputs[0]?.value as List<List<double>>)[0];
     double textNormalization = 0;
     for (int i = 0; i < 512; i++) {
       textNormalization += embedding[i] * embedding[i];
@@ -63,7 +66,6 @@ class OnnxTextEncoder {
 
     inputOrt.release();
     runOptions.release();
-    _session?.release();
     return (embedding);
   }
 }
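
For context, a minimal caller-side sketch of how the refactored API might be driven: init() is awaited once to load the tokenizer vocab, loadModel now hands back the native OrtSession address (or -1 on failure), and that address is passed into infer through the args map. The encodeText wrapper, the modelPath argument, and the error handling below are illustrative assumptions, not part of this diff.

// Hypothetical usage sketch (not part of this diff), assuming the encoder is
// driven by a simple caller rather than the actual service code.
Future<List<double>> encodeText(String query, String modelPath) async {
  final encoder = OnnxTextEncoder();
  await encoder.init(); // loads the BPE vocab via rootBundle once

  // loadModel now returns the native OrtSession address (-1 on failure)
  // instead of storing the session on the encoder itself.
  final address = await encoder.loadModel({"textModelPath": modelPath});
  if (address == -1) {
    throw Exception("CLIP text model failed to load");
  }

  // infer rebuilds the session handle from the plain int address, so the args
  // map stays sendable if the call is moved into a worker isolate.
  return encoder.infer({"text": query, "address": address});
}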