ソースを参照

Download models on mobile data only if enabled

vishnukvmd 1 年間 前
コミット
d79ad9f02d

+ 2 - 0
lib/services/semantic_search/frameworks/ggml.dart

@@ -11,6 +11,8 @@ class GGML extends MLFramework {
   final _computer = Computer.shared();
   final _logger = Logger("GGML");
 
+  GGML(super.shouldDownloadOverMobileData);
+
   @override
   String getImageModelRemotePath() {
     return kModelBucketEndpoint + kImageModel;

+ 54 - 8
lib/services/semantic_search/frameworks/ml_framework.dart

@@ -1,9 +1,13 @@
+import "dart:async";
 import "dart:io";
 
+import "package:connectivity_plus/connectivity_plus.dart";
 import "package:flutter/services.dart";
 import "package:logging/logging.dart";
 import "package:path/path.dart";
 import "package:path_provider/path_provider.dart";
+import "package:photos/core/errors.dart";
+
 import "package:photos/core/event_bus.dart";
 import "package:photos/core/network/network.dart";
 import "package:photos/events/event.dart";
@@ -12,9 +16,24 @@ abstract class MLFramework {
   static const kImageEncoderEnabled = true;
   static const kMaximumRetrials = 3;
 
-  InitializationState _state = InitializationState.notInitialized;
+  static final _logger = Logger("MLFramework");
 
-  final _logger = Logger("MLFramework");
+  final bool shouldDownloadOverMobileData;
+
+  InitializationState _state = InitializationState.notInitialized;
+  final _initializationCompleter = Completer<void>();
+
+  MLFramework(this.shouldDownloadOverMobileData) {
+    Connectivity()
+        .onConnectivityChanged
+        .listen((ConnectivityResult result) async {
+      _logger.info("Connectivity changed to $result");
+      if (_state == InitializationState.waitingForNetwork &&
+          await _canDownload()) {
+        unawaited(init());
+      }
+    });
+  }
 
   InitializationState get initializationState => _state;
 
@@ -47,8 +66,18 @@ abstract class MLFramework {
   /// initialization. For eg. if you wish to load the model from `/assets`
   /// instead of a CDN.
   Future<void> init() async {
-    await Future.wait([_initImageModel(), _initTextModel()]);
+    try {
+      await Future.wait([_initImageModel(), _initTextModel()]);
+    } catch (e, s) {
+      _logger.warning(e, s);
+      if (e is WiFiUnavailableError) {
+        return _initializationCompleter.future;
+      } else {
+        rethrow;
+      }
+    }
     _initState = InitializationState.initialized;
+    _initializationCompleter.complete();
   }
 
   // Releases any resources held by the framework
@@ -75,9 +104,9 @@ abstract class MLFramework {
     if (!kImageEncoderEnabled) {
       return;
     }
+    _initState = InitializationState.initializingImageModel;
     final path = await _getLocalImageModelPath();
-    if (File(path).existsSync()) {
-      _initState = InitializationState.initializingImageModel;
+    if (await File(path).exists()) {
       await loadImageModel(path);
     } else {
       _initState = InitializationState.downloadingImageModel;
@@ -86,12 +115,13 @@ abstract class MLFramework {
       await tempFile.rename(path);
       await loadImageModel(path);
     }
+    _initState = InitializationState.initializedImageModel;
   }
 
   Future<void> _initTextModel() async {
     final path = await _getLocalTextModelPath();
-    if (File(path).existsSync()) {
-      _initState = InitializationState.initializingTextModel;
+    _initState = InitializationState.initializingTextModel;
+    if (await File(path).exists()) {
       await loadTextModel(path);
     } else {
       _initState = InitializationState.downloadingTextModel;
@@ -100,6 +130,7 @@ abstract class MLFramework {
       await tempFile.rename(path);
       await loadTextModel(path);
     }
+    _initState = InitializationState.initializedTextModel;
   }
 
   Future<String> _getLocalImageModelPath() async {
@@ -119,6 +150,10 @@ abstract class MLFramework {
     String savePath, {
     int trialCount = 1,
   }) async {
+    if (!await _canDownload()) {
+      _initState = InitializationState.waitingForNetwork;
+      throw WiFiUnavailableError();
+    }
     _logger.info("Downloading " + url);
     final existingFile = File(savePath);
     if (await existingFile.exists()) {
@@ -136,6 +171,15 @@ abstract class MLFramework {
     }
   }
 
+  Future<bool> _canDownload() async {
+    final connectivityResult = await (Connectivity().checkConnectivity());
+    bool canDownloadUnderCurrentNetworkConditions = true;
+    if (connectivityResult == ConnectivityResult.mobile) {
+      canDownloadUnderCurrentNetworkConditions = shouldDownloadOverMobileData;
+    }
+    return canDownloadUnderCurrentNetworkConditions;
+  }
+
   Future<String> getAccessiblePathForAsset(
     String assetPath,
     String tempName,
@@ -158,8 +202,10 @@ enum InitializationState {
   notInitialized,
   waitingForNetwork,
   downloadingImageModel,
-  downloadingTextModel,
   initializingImageModel,
+  initializedImageModel,
+  downloadingTextModel,
   initializingTextModel,
+  initializedTextModel,
   initialized,
 }

+ 2 - 0
lib/services/semantic_search/frameworks/onnx/onnx.dart

@@ -17,6 +17,8 @@ class ONNX extends MLFramework {
   int _textEncoderAddress = 0;
   int _imageEncoderAddress = 0;
 
+  ONNX(super.shouldDownloadOverMobileData);
+
   @override
   String getImageModelRemotePath() {
     return kModelBucketEndpoint + kImageModel;

+ 6 - 1
lib/services/semantic_search/semantic_search_service.dart

@@ -38,11 +38,11 @@ class SemanticSearchService {
 
   final _logger = Logger("SemanticSearchService");
   final _queue = Queue<EnteFile>();
-  final _mlFramework = kCurrentModel == Model.onnxClip ? ONNX() : GGML();
   final _frameworkInitialization = Completer<bool>();
   final _embeddingLoaderDebouncer =
       Debouncer(kDebounceDuration, executionInterval: kDebounceDuration);
 
+  late MLFramework _mlFramework;
   bool _hasInitialized = false;
   bool _isComputingEmbeddings = false;
   bool _isSyncing = false;
@@ -61,6 +61,11 @@ class SemanticSearchService {
       return;
     }
     _hasInitialized = true;
+    final shouldDownloadOverMobileData =
+        Configuration.instance.shouldBackupOverMobileData();
+    _mlFramework = kCurrentModel == Model.onnxClip
+        ? ONNX(shouldDownloadOverMobileData)
+        : GGML(shouldDownloadOverMobileData);
     await EmbeddingsDB.instance.init();
     await EmbeddingStore.instance.init();
     await _loadEmbeddings();