onnx_text_encoder.dart 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. import "dart:io";
  2. import "dart:math";
  3. import "dart:typed_data";
  4. import "package:flutter/services.dart";
  5. import "package:logging/logging.dart";
  6. import "package:onnxruntime/onnxruntime.dart";
  7. import "package:photos/services/semantic_search/frameworks/onnx/onnx_text_tokenizer.dart";
  8. class OnnxTextEncoder {
  9. static const kVocabFilePath = "assets/models/clip/bpe_simple_vocab_16e6.txt";
  10. final _logger = Logger("OnnxTextEncoder");
  11. final OnnxTextTokenizer _tokenizer = OnnxTextTokenizer();
  12. // Do not run in an isolate since rootBundle can only be accessed in the main isolate
  13. Future<void> initTokenizer() async {
  14. final vocab = await rootBundle.loadString(kVocabFilePath);
  15. await _tokenizer.init(vocab);
  16. }
  17. Future<int> loadModel(Map args) async {
  18. final sessionOptions = OrtSessionOptions()
  19. ..setInterOpNumThreads(1)
  20. ..setIntraOpNumThreads(1)
  21. ..setSessionGraphOptimizationLevel(GraphOptimizationLevel.ortEnableAll);
  22. try {
  23. _logger.info("Loading text model");
  24. final session =
  25. OrtSession.fromFile(File(args["textModelPath"]), sessionOptions);
  26. _logger.info('text model loaded');
  27. return session.address;
  28. } catch (e, s) {
  29. _logger.severe('text model not loaded', e, s);
  30. }
  31. return -1;
  32. }
  33. Future<List<double>> infer(Map args) async {
  34. final text = args["text"];
  35. final address = args["address"] as int;
  36. final runOptions = OrtRunOptions();
  37. final tokenize = _tokenizer.tokenize(text);
  38. final data = List.filled(1, Int32List.fromList(tokenize));
  39. final inputOrt = OrtValueTensor.createTensorWithDataList(data, [1, 77]);
  40. final inputs = {'input': inputOrt};
  41. final session = OrtSession.fromAddress(address);
  42. final outputs = session.run(runOptions, inputs);
  43. final embedding = (outputs[0]?.value as List<List<double>>)[0];
  44. double textNormalization = 0;
  45. for (int i = 0; i < 512; i++) {
  46. textNormalization += embedding[i] * embedding[i];
  47. }
  48. final double sqrtTextNormalization = sqrt(textNormalization);
  49. for (int i = 0; i < 512; i++) {
  50. embedding[i] = embedding[i] / sqrtTextNormalization;
  51. }
  52. return (embedding);
  53. }
  54. }