onnx_text_encoder.dart 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. import "dart:io";
  2. import "dart:math";
  3. import "dart:typed_data";
  4. import "package:flutter/services.dart";
  5. import "package:logging/logging.dart";
  6. import "package:onnxruntime/onnxruntime.dart";
  7. import "package:photos/services/semantic_search/frameworks/onnx/onnx_text_tokenizer.dart";
  /// Encodes text into L2-normalized embeddings using an ONNX text model.
  ///
  /// The tokenizer vocabulary is loaded from the asset bundle by [init]. The
  /// model is loaded by [loadModel], which returns the native session address.
  /// [infer] later reattaches to the session through that address.
  class OnnxTextEncoder {
    /// Asset path of the BPE vocabulary fed to [OnnxTextTokenizer.init].
    static const vocabFilePath = "assets/models/clip/bpe_simple_vocab_16e6.txt";

    /// Sequence length declared for the model's `input` tensor.
    ///
    /// NOTE(review): assumes the tokenizer always yields exactly this many
    /// tokens (padded/truncated); confirm against [OnnxTextTokenizer].
    static const _contextLength = 77;

    final _logger = Logger("OnnxTextEncoder");
    final OnnxTextTokenizer _tokenizer = OnnxTextTokenizer();

    /// Initializes the global ONNX Runtime environment and logs every
    /// available execution provider.
    OnnxTextEncoder() {
      OrtEnv.instance.init();
      for (final provider in OrtEnv.instance.availableProviders()) {
        _logger.info('onnx provider=$provider');
      }
    }

    /// Loads the BPE vocabulary from the asset bundle into the tokenizer.
    ///
    /// Presumably this must complete before [infer] tokenizes any text.
    Future<void> init() async {
      final vocab = await rootBundle.loadString(vocabFilePath);
      await _tokenizer.init(vocab);
    }

    /// Releases the global ONNX Runtime environment.
    ///
    /// This method does not itself release sessions created by [loadModel].
    void release() {
      OrtEnv.instance.release();
    }

    /// Loads the model file at `args["textModelPath"]` into a new [OrtSession].
    ///
    /// Returns the session's native address, or -1 if loading failed. Failures
    /// are logged instead of thrown, so callers can treat -1 as "no model".
    Future<int> loadModel(Map args) async {
      final sessionOptions = OrtSessionOptions()
        ..setInterOpNumThreads(1)
        ..setIntraOpNumThreads(1)
        ..setSessionGraphOptimizationLevel(GraphOptimizationLevel.ortEnableAll);
      try {
        _logger.info("Loading text model");
        final bytes = await File(args["textModelPath"]).readAsBytes();
        final session = OrtSession.fromBuffer(bytes, sessionOptions);
        _logger.info('text model loaded');
        return session.address;
      } catch (e, s) {
        _logger.severe('text model not loaded', e, s);
        return -1;
      } finally {
        // The options are only needed while the session is being created.
        // Free their native memory whether or not creation succeeded.
        sessionOptions.release();
      }
    }

    /// Tokenizes `args["text"]` and runs it through the session whose native
    /// address is `args["address"]`. Returns the first output row, scaled to
    /// unit L2 norm.
    ///
    /// All native values created here are released, including on failure.
    /// The session itself is not released, so the same address remains usable
    /// for later calls.
    Future<List<double>> infer(Map args) async {
      final text = args["text"];
      final address = args["address"] as int;
      final tokens = Int32List.fromList(_tokenizer.tokenize(text));
      final inputOrt = OrtValueTensor.createTensorWithDataList(
        [tokens],
        [1, _contextLength],
      );
      final runOptions = OrtRunOptions();
      try {
        final session = OrtSession.fromAddress(address);
        final outputs = session.run(runOptions, {'input': inputOrt});
        try {
          // Copy the row into Dart-owned memory before the native output
          // values are released below.
          final embedding = List<double>.of(
            (outputs[0]?.value as List<List<double>>)[0],
            growable: false,
          );
          _normalizeInPlace(embedding);
          return embedding;
        } finally {
          // Output tensors own native memory; free every one of them.
          for (final output in outputs) {
            output?.release();
          }
        }
      } finally {
        inputOrt.release();
        runOptions.release();
      }
    }

    /// Scales [vector] in place to unit L2 norm over all of its elements.
    ///
    /// An all-zero vector is left unchanged rather than filled with NaNs.
    static void _normalizeInPlace(List<double> vector) {
      var sumOfSquares = 0.0;
      for (final x in vector) {
        sumOfSquares += x * x;
      }
      if (sumOfSquares == 0) return;
      // Hoisted: the norm is loop-invariant.
      final norm = sqrt(sumOfSquares);
      for (var i = 0; i < vector.length; i++) {
        vector[i] /= norm;
      }
    }
  }