Use the cached address to retrieve the encoder

This commit is contained in:
vishnukvmd 2023-12-13 14:39:18 +05:30
parent b2f9dd2c8b
commit 1c0ff2f10c
2 changed files with 25 additions and 20 deletions

View file

@@ -12,6 +12,7 @@ class ONNX extends MLFramework {
final _logger = Logger("ONNX"); final _logger = Logger("ONNX");
final _clipImage = OnnxImageEncoder(); final _clipImage = OnnxImageEncoder();
final _clipText = OnnxTextEncoder(); final _clipText = OnnxTextEncoder();
int _textEncoderAddress = 0;
@override @override
String getImageModelRemotePath() { String getImageModelRemotePath() {
@@ -41,7 +42,8 @@ class ONNX extends MLFramework {
@override @override
Future<void> loadTextModel(String path) async { Future<void> loadTextModel(String path) async {
final startTime = DateTime.now(); final startTime = DateTime.now();
await _computer.compute( await _clipText.init();
_textEncoderAddress = await _computer.compute(
_clipText.loadModel, _clipText.loadModel,
param: { param: {
"textModelPath": path, "textModelPath": path,
@@ -83,6 +85,7 @@ class ONNX extends MLFramework {
_clipText.infer, _clipText.infer,
param: { param: {
"text": text, "text": text,
"address": _textEncoderAddress,
}, },
taskName: "createTextEmbedding", taskName: "createTextEmbedding",
) as List<double>; ) as List<double>;

View file

@@ -2,56 +2,59 @@ import "dart:io";
import "dart:math"; import "dart:math";
import "dart:typed_data"; import "dart:typed_data";
import "package:flutter/services.dart";
import "package:logging/logging.dart"; import "package:logging/logging.dart";
import "package:onnxruntime/onnxruntime.dart"; import "package:onnxruntime/onnxruntime.dart";
import "package:photos/services/semantic_search/frameworks/onnx/onnx_text_tokenizer.dart"; import "package:photos/services/semantic_search/frameworks/onnx/onnx_text_tokenizer.dart";
class OnnxTextEncoder { class OnnxTextEncoder {
static const vocabFilePath = "assets/clip/bpe_simple_vocab_16e6.txt"; static const vocabFilePath = "assets/models/clip/bpe_simple_vocab_16e6.txt";
final _logger = Logger("CLIPTextEncoder"); final _logger = Logger("OnnxTextEncoder");
OrtSessionOptions? _sessionOptions; final OnnxTextTokenizer _tokenizer = OnnxTextTokenizer();
OrtSession? _session;
OnnxTextEncoder() { OnnxTextEncoder() {
OrtEnv.instance.init(); OrtEnv.instance.init();
OrtEnv.instance.availableProviders().forEach((element) { OrtEnv.instance.availableProviders().forEach((element) {
print('onnx provider=$element'); _logger.info('onnx provider=$element');
}); });
} }
Future<void> init() async {
final vocab = await rootBundle.loadString(vocabFilePath);
await _tokenizer.init(vocab);
}
release() { release() {
_sessionOptions?.release();
_sessionOptions = null;
_session?.release();
_session = null;
OrtEnv.instance.release(); OrtEnv.instance.release();
} }
Future<void> loadModel(Map args) async { Future<int> loadModel(Map args) async {
_sessionOptions = OrtSessionOptions() final sessionOptions = OrtSessionOptions()
..setInterOpNumThreads(1) ..setInterOpNumThreads(1)
..setIntraOpNumThreads(1) ..setIntraOpNumThreads(1)
..setSessionGraphOptimizationLevel(GraphOptimizationLevel.ortEnableAll); ..setSessionGraphOptimizationLevel(GraphOptimizationLevel.ortEnableAll);
try { try {
_logger.info("Loading text model");
final bytes = File(args["textModelPath"]).readAsBytesSync(); final bytes = File(args["textModelPath"]).readAsBytesSync();
_session = OrtSession.fromBuffer(bytes, _sessionOptions!); final session = OrtSession.fromBuffer(bytes, sessionOptions);
_logger.info('text model loaded'); _logger.info('text model loaded');
return session.address;
} catch (e, s) { } catch (e, s) {
_logger.severe('text model not loaded', e, s); _logger.severe('text model not loaded', e, s);
} }
return -1;
} }
Future<List<double>> infer(Map args) async { Future<List<double>> infer(Map args) async {
final text = args["text"]; final text = args["text"];
final address = args["address"] as int;
final runOptions = OrtRunOptions(); final runOptions = OrtRunOptions();
final tokenizer = OnnxTextTokenizer(vocabFilePath); final data = List.filled(1, Int32List.fromList(_tokenizer.tokenize(text)));
await tokenizer.init();
final data = List.filled(1, Int32List.fromList(tokenizer.tokenize(text)));
final inputOrt = OrtValueTensor.createTensorWithDataList(data, [1, 77]); final inputOrt = OrtValueTensor.createTensorWithDataList(data, [1, 77]);
final inputs = {'input': inputOrt}; final inputs = {'input': inputOrt};
final outputs = _session?.run(runOptions, inputs); final session = OrtSession.fromAddress(address);
final embedding = (outputs?[0]?.value as List<List<double>>)[0]; final outputs = session.run(runOptions, inputs);
final embedding = (outputs[0]?.value as List<List<double>>)[0];
double textNormalization = 0; double textNormalization = 0;
for (int i = 0; i < 512; i++) { for (int i = 0; i < 512; i++) {
textNormalization += embedding[i] * embedding[i]; textNormalization += embedding[i] * embedding[i];
@@ -63,7 +66,6 @@ class OnnxTextEncoder {
inputOrt.release(); inputOrt.release();
runOptions.release(); runOptions.release();
_session?.release();
return (embedding); return (embedding);
} }
} }