Explorar o código

Merge branch 'main' into memories_redesign

ashilkn hai 1 ano
pai
achega
62ea6cdf2c

+ 3 - 1
lib/generated/intl/messages_en.dart

@@ -69,7 +69,7 @@ class MessageLookup extends MessageLookupByLibrary {
       "Please drop an email to ${supportEmail} from your registered email address";
 
   static String m16(count, storageSaved) =>
-      "Your have cleaned up ${Intl.plural(count, one: '${count} duplicate file', other: '${count} duplicate files')}, saving (${storageSaved}!)";
+      "You have cleaned up ${Intl.plural(count, one: '${count} duplicate file', other: '${count} duplicate files')}, saving (${storageSaved}!)";
 
   static String m17(count, formattedSize) =>
       "${count} files, ${formattedSize} each";
@@ -1451,6 +1451,8 @@ class MessageLookup extends MessageLookupByLibrary {
         "viewer": MessageLookupByLibrary.simpleMessage("Viewer"),
         "visitWebToManage": MessageLookupByLibrary.simpleMessage(
             "Please visit web.ente.io to manage your subscription"),
+        "waitingForWifi":
+            MessageLookupByLibrary.simpleMessage("Waiting for WiFi..."),
         "weAreOpenSource":
             MessageLookupByLibrary.simpleMessage("We are open source!"),
         "weDontSupportEditingPhotosAndAlbumsThatYouDont":

+ 1 - 0
lib/generated/intl/messages_zh.dart

@@ -704,6 +704,7 @@ class MessageLookup extends MessageLookupByLibrary {
             MessageLookupByLibrary.simpleMessage("正在加载 EXIF 数据..."),
         "loadingGallery": MessageLookupByLibrary.simpleMessage("正在加载图库..."),
         "loadingMessage": MessageLookupByLibrary.simpleMessage("正在加载您的照片..."),
+        "loadingModel": MessageLookupByLibrary.simpleMessage("正在下载模型..."),
         "localGallery": MessageLookupByLibrary.simpleMessage("本地相册"),
         "location": MessageLookupByLibrary.simpleMessage("地理位置"),
         "locationName": MessageLookupByLibrary.simpleMessage("地点名称"),

+ 12 - 2
lib/generated/l10n.dart

@@ -2886,6 +2886,16 @@ class S {
     );
   }
 
+  /// `Waiting for WiFi...`
+  String get waitingForWifi {
+    return Intl.message(
+      'Waiting for WiFi...',
+      name: 'waitingForWifi',
+      desc: '',
+      args: [],
+    );
+  }
+
   /// `Status`
   String get status {
     return Intl.message(
@@ -3523,10 +3533,10 @@ class S {
     );
   }
 
-  /// `Your have cleaned up {count, plural, one{{count} duplicate file} other{{count} duplicate files}}, saving ({storageSaved}!)`
+  /// `You have cleaned up {count, plural, one{{count} duplicate file} other{{count} duplicate files}}, saving ({storageSaved}!)`
   String duplicateFileCountWithStorageSaved(int count, String storageSaved) {
     return Intl.message(
-      'Your have cleaned up ${Intl.plural(count, one: '$count duplicate file', other: '$count duplicate files')}, saving ($storageSaved!)',
+      'You have cleaned up ${Intl.plural(count, one: '$count duplicate file', other: '$count duplicate files')}, saving ($storageSaved!)',
       name: 'duplicateFileCountWithStorageSaved',
       desc:
           'The text to display when the user has successfully cleaned up duplicate files',

+ 1 - 0
lib/l10n/intl_en.arb

@@ -410,6 +410,7 @@
   "magicSearch": "Magic search",
   "magicSearchDescription": "Please note that this will result in a higher bandwidth and battery usage until all items are indexed.",
   "loadingModel": "Downloading models...",
+  "waitingForWifi": "Waiting for WiFi...",
   "status": "Status",
   "indexedItems": "Indexed items",
   "pendingItems": "Pending items",

+ 14 - 4
lib/services/semantic_search/embedding_store.dart

@@ -50,9 +50,15 @@ class EmbeddingStore {
 
   Future<void> pushEmbeddings() async {
     final pendingItems = await EmbeddingsDB.instance.getUnsyncedEmbeddings();
+    final fileMap = await FilesDB.instance
+        .getFilesFromIDs(pendingItems.map((e) => e.fileID).toList());
+    _logger.info("Pushing ${pendingItems.length} embeddings");
     for (final item in pendingItems) {
-      final file = await FilesDB.instance.getAnyUploadedFile(item.fileID);
-      await _pushEmbedding(file!, item);
+      try {
+        await _pushEmbedding(fileMap[item.fileID]!, item);
+      } catch (e, s) {
+        _logger.severe(e, s);
+      }
     }
   }
 
@@ -67,6 +73,7 @@ class EmbeddingStore {
   }
 
   Future<void> _pushEmbedding(EnteFile file, Embedding embedding) async {
+    _logger.info("Pushing embedding for $file");
     final encryptionKey = getFileKey(file);
     final embeddingJSON = jsonEncode(embedding.embedding);
     final encryptedEmbedding = await CryptoUtil.encryptChaCha(
@@ -86,7 +93,7 @@ class EmbeddingStore {
           "decryptionHeader": header,
         },
       );
-      final updationTime = response.data["updationTime"];
+      final updationTime = response.data["updatedAt"];
       embedding.updationTime = updationTime;
       await EmbeddingsDB.instance.put(embedding);
     } catch (e, s) {
@@ -138,7 +145,10 @@ class EmbeddingStore {
 
     for (final embedding in remoteEmbeddings) {
       final file = fileMap[embedding.fileID];
-      final fileKey = getFileKey(file!);
+      if (file == null) {
+        continue;
+      }
+      final fileKey = getFileKey(file);
       final input = EmbeddingsDecoderInput(embedding, fileKey);
       inputs.add(input);
     }

+ 2 - 0
lib/services/semantic_search/frameworks/ggml.dart

@@ -11,6 +11,8 @@ class GGML extends MLFramework {
   final _computer = Computer.shared();
   final _logger = Logger("GGML");
 
+  GGML(super.shouldDownloadOverMobileData);
+
   @override
   String getImageModelRemotePath() {
     return kModelBucketEndpoint + kImageModel;

+ 80 - 5
lib/services/semantic_search/frameworks/ml_framework.dart

@@ -1,16 +1,47 @@
+import "dart:async";
 import "dart:io";
 
+import "package:connectivity_plus/connectivity_plus.dart";
 import "package:flutter/services.dart";
 import "package:logging/logging.dart";
 import "package:path/path.dart";
 import "package:path_provider/path_provider.dart";
+import "package:photos/core/errors.dart";
+
+import "package:photos/core/event_bus.dart";
 import "package:photos/core/network/network.dart";
+import "package:photos/events/event.dart";
 
 abstract class MLFramework {
   static const kImageEncoderEnabled = true;
   static const kMaximumRetrials = 3;
 
-  final _logger = Logger("MLFramework");
+  static final _logger = Logger("MLFramework");
+
+  final bool shouldDownloadOverMobileData;
+
+  InitializationState _state = InitializationState.notInitialized;
+  final _initializationCompleter = Completer<void>();
+
+  MLFramework(this.shouldDownloadOverMobileData) {
+    Connectivity()
+        .onConnectivityChanged
+        .listen((ConnectivityResult result) async {
+      _logger.info("Connectivity changed to $result");
+      if (_state == InitializationState.waitingForNetwork &&
+          await _canDownload()) {
+        unawaited(init());
+      }
+    });
+  }
+
+  InitializationState get initializationState => _state;
+
+  set _initState(InitializationState state) {
+    Bus.instance.fire(MLFrameworkInitializationUpdateEvent(state));
+    _logger.info("Init state is $state");
+    _state = state;
+  }
 
   /// Returns the path of the Image Model hosted remotely
   String getImageModelRemotePath();
@@ -35,8 +66,18 @@ abstract class MLFramework {
   /// initialization. For eg. if you wish to load the model from `/assets`
   /// instead of a CDN.
   Future<void> init() async {
-    await _initImageModel();
-    await _initTextModel();
+    try {
+      await Future.wait([_initImageModel(), _initTextModel()]);
+    } catch (e, s) {
+      _logger.warning(e, s);
+      if (e is WiFiUnavailableError) {
+        return _initializationCompleter.future;
+      } else {
+        rethrow;
+      }
+    }
+    _initState = InitializationState.initialized;
+    _initializationCompleter.complete();
   }
 
   // Releases any resources held by the framework
@@ -63,27 +104,33 @@ abstract class MLFramework {
     if (!kImageEncoderEnabled) {
       return;
     }
+    _initState = InitializationState.initializingImageModel;
     final path = await _getLocalImageModelPath();
-    if (File(path).existsSync()) {
+    if (await File(path).exists()) {
       await loadImageModel(path);
     } else {
+      _initState = InitializationState.downloadingImageModel;
       final tempFile = File(path + ".temp");
       await _downloadFile(getImageModelRemotePath(), tempFile.path);
       await tempFile.rename(path);
       await loadImageModel(path);
     }
+    _initState = InitializationState.initializedImageModel;
   }
 
   Future<void> _initTextModel() async {
     final path = await _getLocalTextModelPath();
-    if (File(path).existsSync()) {
+    _initState = InitializationState.initializingTextModel;
+    if (await File(path).exists()) {
       await loadTextModel(path);
     } else {
+      _initState = InitializationState.downloadingTextModel;
       final tempFile = File(path + ".temp");
       await _downloadFile(getTextModelRemotePath(), tempFile.path);
       await tempFile.rename(path);
       await loadTextModel(path);
     }
+    _initState = InitializationState.initializedTextModel;
   }
 
   Future<String> _getLocalImageModelPath() async {
@@ -103,6 +150,10 @@ abstract class MLFramework {
     String savePath, {
     int trialCount = 1,
   }) async {
+    if (!await _canDownload()) {
+      _initState = InitializationState.waitingForNetwork;
+      throw WiFiUnavailableError();
+    }
     _logger.info("Downloading " + url);
     final existingFile = File(savePath);
     if (await existingFile.exists()) {
@@ -120,6 +171,12 @@ abstract class MLFramework {
     }
   }
 
+  Future<bool> _canDownload() async {
+    final connectivityResult = await (Connectivity().checkConnectivity());
+    return connectivityResult != ConnectivityResult.mobile ||
+        shouldDownloadOverMobileData;
+  }
+
   Future<String> getAccessiblePathForAsset(
     String assetPath,
     String tempName,
@@ -131,3 +188,21 @@ abstract class MLFramework {
     return file.path;
   }
 }
+
+class MLFrameworkInitializationUpdateEvent extends Event {
+  final InitializationState state;
+
+  MLFrameworkInitializationUpdateEvent(this.state);
+}
+
+enum InitializationState {
+  notInitialized,
+  waitingForNetwork,
+  downloadingImageModel,
+  initializingImageModel,
+  initializedImageModel,
+  downloadingTextModel,
+  initializingTextModel,
+  initializedTextModel,
+  initialized,
+}

+ 2 - 0
lib/services/semantic_search/frameworks/onnx/onnx.dart

@@ -17,6 +17,8 @@ class ONNX extends MLFramework {
   int _textEncoderAddress = 0;
   int _imageEncoderAddress = 0;
 
+  ONNX(super.shouldDownloadOverMobileData);
+
   @override
   String getImageModelRemotePath() {
     return kModelBucketEndpoint + kImageModel;

+ 8 - 3
lib/services/semantic_search/semantic_search_service.dart

@@ -38,11 +38,11 @@ class SemanticSearchService {
 
   final _logger = Logger("SemanticSearchService");
   final _queue = Queue<EnteFile>();
-  final _mlFramework = kCurrentModel == Model.onnxClip ? ONNX() : GGML();
   final _frameworkInitialization = Completer<bool>();
   final _embeddingLoaderDebouncer =
       Debouncer(kDebounceDuration, executionInterval: kDebounceDuration);
 
+  late MLFramework _mlFramework;
   bool _hasInitialized = false;
   bool _isComputingEmbeddings = false;
   bool _isSyncing = false;
@@ -61,6 +61,11 @@ class SemanticSearchService {
       return;
     }
     _hasInitialized = true;
+    final shouldDownloadOverMobileData =
+        Configuration.instance.shouldBackupOverMobileData();
+    _mlFramework = kCurrentModel == Model.onnxClip
+        ? ONNX(shouldDownloadOverMobileData)
+        : GGML(shouldDownloadOverMobileData);
     await EmbeddingsDB.instance.init();
     await EmbeddingStore.instance.init();
     await _loadEmbeddings();
@@ -145,8 +150,8 @@ class SemanticSearchService {
     );
   }
 
-  Future<bool> getFrameworkInitializationStatus() {
-    return _frameworkInitialization.future;
+  InitializationState getFrameworkInitializationState() {
+    return _mlFramework.initializationState;
   }
 
   Future<void> clearIndexes() async {

+ 46 - 13
lib/ui/settings/machine_learning_settings_page.dart

@@ -6,6 +6,7 @@ import "package:photos/core/event_bus.dart";
 import 'package:photos/events/embedding_updated_event.dart';
 import "package:photos/generated/l10n.dart";
 import "package:photos/services/feature_flag_service.dart";
+import "package:photos/services/semantic_search/frameworks/ml_framework.dart";
 import "package:photos/services/semantic_search/semantic_search_service.dart";
 import "package:photos/theme/ente_theme.dart";
 import "package:photos/ui/common/loading_widget.dart";
@@ -29,6 +30,32 @@ class MachineLearningSettingsPage extends StatefulWidget {
 
 class _MachineLearningSettingsPageState
     extends State<MachineLearningSettingsPage> {
+  late InitializationState _state;
+
+  late StreamSubscription<MLFrameworkInitializationUpdateEvent>
+      _eventSubscription;
+
+  @override
+  void initState() {
+    super.initState();
+    _eventSubscription =
+        Bus.instance.on<MLFrameworkInitializationUpdateEvent>().listen((event) {
+      _fetchState();
+      setState(() {});
+    });
+    _fetchState();
+  }
+
+  void _fetchState() {
+    _state = SemanticSearchService.instance.getFrameworkInitializationState();
+  }
+
+  @override
+  void dispose() {
+    super.dispose();
+    _eventSubscription.cancel();
+  }
+
   @override
   Widget build(BuildContext context) {
     return Scaffold(
@@ -118,17 +145,9 @@ class _MachineLearningSettingsPageState
         hasEnabled
             ? Column(
                 children: [
-                  FutureBuilder(
-                    future: SemanticSearchService.instance
-                        .getFrameworkInitializationStatus(),
-                    builder: (BuildContext context, AsyncSnapshot snapshot) {
-                      if (snapshot.hasData) {
-                        return const MagicSearchIndexStatsWidget();
-                      } else {
-                        return const ModelLoadingState();
-                      }
-                    },
-                  ),
+                  _state == InitializationState.initialized
+                      ? const MagicSearchIndexStatsWidget()
+                      : ModelLoadingState(_state),
                   const SizedBox(
                     height: 12,
                   ),
@@ -158,7 +177,12 @@ class _MachineLearningSettingsPageState
 }
 
 class ModelLoadingState extends StatelessWidget {
-  const ModelLoadingState({super.key});
+  final InitializationState state;
+
+  const ModelLoadingState(
+    this.state, {
+    Key? key,
+  }) : super(key: key);
 
   @override
   Widget build(BuildContext context) {
@@ -167,7 +191,7 @@ class ModelLoadingState extends StatelessWidget {
         MenuSectionTitle(title: S.of(context).status),
         MenuItemWidget(
           captionedTextWidget: CaptionedTextWidget(
-            title: S.of(context).loadingModel,
+            title: _getTitle(context),
           ),
           trailingWidget: EnteLoadingWidget(
             size: 12,
@@ -180,6 +204,15 @@ class ModelLoadingState extends StatelessWidget {
       ],
     );
   }
+
+  String _getTitle(BuildContext context) {
+    switch (state) {
+      case InitializationState.waitingForNetwork:
+        return S.of(context).waitingForWifi;
+      default:
+        return S.of(context).loadingModel;
+    }
+  }
 }
 
 class MagicSearchIndexStatsWidget extends StatefulWidget {

+ 1 - 1
pubspec.yaml

@@ -12,7 +12,7 @@ description: ente photos application
 # Read more about iOS versioning at
 # https://developer.apple.com/library/archive/documentation/General/Reference/InfoPlistKeyReference/Articles/CoreFoundationKeys.html
 
-version: 0.8.33+553
+version: 0.8.35+555
 
 environment:
   sdk: ">=3.0.0 <4.0.0"