ml_framework.dart 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. import "dart:io";
  2. import "package:flutter/services.dart";
  3. import "package:logging/logging.dart";
  4. import "package:path/path.dart";
  5. import "package:path_provider/path_provider.dart";
  6. import "package:photos/core/network/network.dart";
  7. abstract class MLFramework {
  8. static const kImageEncoderEnabled = true;
  9. final _logger = Logger("MLFramework");
  10. /// Returns the path of the Image Model hosted remotely
  11. String getImageModelRemotePath();
  12. /// Returns the path of the Text Model hosted remotely
  13. String getTextModelRemotePath();
  14. /// Loads the Image Model stored at [path] into the framework
  15. Future<void> loadImageModel(String path);
  16. /// Loads the Text Model stored at [path] into the framework
  17. Future<void> loadTextModel(String path);
  18. /// Returns the Image Embedding for a file stored at [imagePath]
  19. Future<List<double>> getImageEmbedding(String imagePath);
  20. /// Returns the Text Embedding for [text]
  21. Future<List<double>> getTextEmbedding(String text);
  22. /// Downloads the models from remote, caches them and loads them into the
  23. /// framework. Override this method if you would like to control the
  24. /// initialization. For eg. if you wish to load the model from `/assets`
  25. /// instead of a CDN.
  26. Future<void> init() async {
  27. await _initImageModel();
  28. await _initTextModel();
  29. }
  30. // Releases any resources held by the framework
  31. Future<void> release() async {}
  32. /// Returns the cosine similarity between [imageEmbedding] and [textEmbedding]
  33. double computeScore(List<double> imageEmbedding, List<double> textEmbedding) {
  34. assert(
  35. imageEmbedding.length == textEmbedding.length,
  36. "The two embeddings should have the same length",
  37. );
  38. double score = 0;
  39. for (int index = 0; index < imageEmbedding.length; index++) {
  40. score += imageEmbedding[index] * textEmbedding[index];
  41. }
  42. return score;
  43. }
  44. // ---
  45. // Private methods
  46. // ---
  47. Future<void> _initImageModel() async {
  48. if (!kImageEncoderEnabled) {
  49. return;
  50. }
  51. final path = await _getLocalImageModelPath();
  52. if (File(path).existsSync()) {
  53. await loadImageModel(path);
  54. } else {
  55. final tempFile = File(path + ".temp");
  56. await _downloadFile(getImageModelRemotePath(), tempFile.path);
  57. await tempFile.rename(path);
  58. await loadImageModel(path);
  59. }
  60. }
  61. Future<void> _initTextModel() async {
  62. final path = await _getLocalTextModelPath();
  63. if (File(path).existsSync()) {
  64. await loadTextModel(path);
  65. } else {
  66. final tempFile = File(path + ".temp");
  67. await _downloadFile(getTextModelRemotePath(), tempFile.path);
  68. await tempFile.rename(path);
  69. await loadTextModel(path);
  70. }
  71. }
  72. Future<String> _getLocalImageModelPath() async {
  73. return (await getTemporaryDirectory()).path +
  74. "/models/" +
  75. basename(getImageModelRemotePath());
  76. }
  77. Future<String> _getLocalTextModelPath() async {
  78. return (await getTemporaryDirectory()).path +
  79. "/models/" +
  80. basename(getTextModelRemotePath());
  81. }
  82. Future<void> _downloadFile(String url, String savePath) async {
  83. _logger.info("Downloading " + url);
  84. final existingFile = File(savePath);
  85. if (await existingFile.exists()) {
  86. await existingFile.delete();
  87. }
  88. await NetworkClient.instance.getDio().download(url, savePath);
  89. }
  90. Future<String> getAccessiblePathForAsset(
  91. String assetPath,
  92. String tempName,
  93. ) async {
  94. final byteData = await rootBundle.load(assetPath);
  95. final tempDir = await getTemporaryDirectory();
  96. final file = await File('${tempDir.path}/$tempName')
  97. .writeAsBytes(byteData.buffer.asUint8List());
  98. return file.path;
  99. }
  100. }