ml_framework.dart 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. import "dart:async";
  2. import "dart:io";
  3. import "package:connectivity_plus/connectivity_plus.dart";
  4. import "package:flutter/services.dart";
  5. import "package:logging/logging.dart";
  6. import "package:path/path.dart";
  7. import "package:path_provider/path_provider.dart";
  8. import "package:photos/core/errors.dart";
  9. import "package:photos/core/event_bus.dart";
  10. import "package:photos/core/network/network.dart";
  11. import "package:photos/events/event.dart";
  12. abstract class MLFramework {
  13. static const kImageEncoderEnabled = true;
  14. static const kMaximumRetrials = 3;
  15. static final _logger = Logger("MLFramework");
  16. final bool shouldDownloadOverMobileData;
  17. InitializationState _state = InitializationState.notInitialized;
  18. final _initializationCompleter = Completer<void>();
  19. MLFramework(this.shouldDownloadOverMobileData) {
  20. Connectivity()
  21. .onConnectivityChanged
  22. .listen((ConnectivityResult result) async {
  23. _logger.info("Connectivity changed to $result");
  24. if (_state == InitializationState.waitingForNetwork &&
  25. await _canDownload()) {
  26. unawaited(init());
  27. }
  28. });
  29. }
  30. InitializationState get initializationState => _state;
  31. set _initState(InitializationState state) {
  32. Bus.instance.fire(MLFrameworkInitializationUpdateEvent(state));
  33. _logger.info("Init state is $state");
  34. _state = state;
  35. }
  36. /// Returns the path of the Image Model hosted remotely
  37. String getImageModelRemotePath();
  38. /// Returns the path of the Text Model hosted remotely
  39. String getTextModelRemotePath();
  40. /// Loads the Image Model stored at [path] into the framework
  41. Future<void> loadImageModel(String path);
  42. /// Loads the Text Model stored at [path] into the framework
  43. Future<void> loadTextModel(String path);
  44. /// Returns the Image Embedding for a file stored at [imagePath]
  45. Future<List<double>> getImageEmbedding(String imagePath);
  46. /// Returns the Text Embedding for [text]
  47. Future<List<double>> getTextEmbedding(String text);
  48. /// Downloads the models from remote, caches them and loads them into the
  49. /// framework. Override this method if you would like to control the
  50. /// initialization. For eg. if you wish to load the model from `/assets`
  51. /// instead of a CDN.
  52. Future<void> init() async {
  53. try {
  54. await Future.wait([_initImageModel(), _initTextModel()]);
  55. } catch (e, s) {
  56. _logger.warning(e, s);
  57. if (e is WiFiUnavailableError) {
  58. return _initializationCompleter.future;
  59. } else {
  60. rethrow;
  61. }
  62. }
  63. _initState = InitializationState.initialized;
  64. _initializationCompleter.complete();
  65. }
  66. // Releases any resources held by the framework
  67. Future<void> release() async {}
  68. /// Returns the cosine similarity between [imageEmbedding] and [textEmbedding]
  69. double computeScore(List<double> imageEmbedding, List<double> textEmbedding) {
  70. assert(
  71. imageEmbedding.length == textEmbedding.length,
  72. "The two embeddings should have the same length",
  73. );
  74. double score = 0;
  75. for (int index = 0; index < imageEmbedding.length; index++) {
  76. score += imageEmbedding[index] * textEmbedding[index];
  77. }
  78. return score;
  79. }
  80. // ---
  81. // Private methods
  82. // ---
  83. Future<void> _initImageModel() async {
  84. if (!kImageEncoderEnabled) {
  85. return;
  86. }
  87. _initState = InitializationState.initializingImageModel;
  88. final path = await _getLocalImageModelPath();
  89. if (await File(path).exists()) {
  90. await loadImageModel(path);
  91. } else {
  92. _initState = InitializationState.downloadingImageModel;
  93. final tempFile = File(path + ".temp");
  94. await _downloadFile(getImageModelRemotePath(), tempFile.path);
  95. await tempFile.rename(path);
  96. await loadImageModel(path);
  97. }
  98. _initState = InitializationState.initializedImageModel;
  99. }
  100. Future<void> _initTextModel() async {
  101. final path = await _getLocalTextModelPath();
  102. _initState = InitializationState.initializingTextModel;
  103. if (await File(path).exists()) {
  104. await loadTextModel(path);
  105. } else {
  106. _initState = InitializationState.downloadingTextModel;
  107. final tempFile = File(path + ".temp");
  108. await _downloadFile(getTextModelRemotePath(), tempFile.path);
  109. await tempFile.rename(path);
  110. await loadTextModel(path);
  111. }
  112. _initState = InitializationState.initializedTextModel;
  113. }
  114. Future<String> _getLocalImageModelPath() async {
  115. return (await getTemporaryDirectory()).path +
  116. "/models/" +
  117. basename(getImageModelRemotePath());
  118. }
  119. Future<String> _getLocalTextModelPath() async {
  120. return (await getTemporaryDirectory()).path +
  121. "/models/" +
  122. basename(getTextModelRemotePath());
  123. }
  124. Future<void> _downloadFile(
  125. String url,
  126. String savePath, {
  127. int trialCount = 1,
  128. }) async {
  129. if (!await _canDownload()) {
  130. _initState = InitializationState.waitingForNetwork;
  131. throw WiFiUnavailableError();
  132. }
  133. _logger.info("Downloading " + url);
  134. final existingFile = File(savePath);
  135. if (await existingFile.exists()) {
  136. await existingFile.delete();
  137. }
  138. try {
  139. await NetworkClient.instance.getDio().download(url, savePath);
  140. } catch (e, s) {
  141. _logger.severe(e, s);
  142. if (trialCount < kMaximumRetrials) {
  143. return _downloadFile(url, savePath, trialCount: trialCount + 1);
  144. } else {
  145. rethrow;
  146. }
  147. }
  148. }
  149. Future<bool> _canDownload() async {
  150. final connectivityResult = await (Connectivity().checkConnectivity());
  151. return connectivityResult != ConnectivityResult.mobile ||
  152. shouldDownloadOverMobileData;
  153. }
  154. Future<String> getAccessiblePathForAsset(
  155. String assetPath,
  156. String tempName,
  157. ) async {
  158. final byteData = await rootBundle.load(assetPath);
  159. final tempDir = await getTemporaryDirectory();
  160. final file = await File('${tempDir.path}/$tempName')
  161. .writeAsBytes(byteData.buffer.asUint8List());
  162. return file.path;
  163. }
  164. }
  165. class MLFrameworkInitializationUpdateEvent extends Event {
  166. final InitializationState state;
  167. MLFrameworkInitializationUpdateEvent(this.state);
  168. }
  169. enum InitializationState {
  170. notInitialized,
  171. waitingForNetwork,
  172. downloadingImageModel,
  173. initializingImageModel,
  174. initializedImageModel,
  175. downloadingTextModel,
  176. initializingTextModel,
  177. initializedTextModel,
  178. initialized,
  179. }