diff --git a/lib/generated/intl/messages_en.dart b/lib/generated/intl/messages_en.dart index 3febd6407..254eacaa9 100644 --- a/lib/generated/intl/messages_en.dart +++ b/lib/generated/intl/messages_en.dart @@ -69,7 +69,7 @@ class MessageLookup extends MessageLookupByLibrary { "Please drop an email to ${supportEmail} from your registered email address"; static String m16(count, storageSaved) => - "Your have cleaned up ${Intl.plural(count, one: '${count} duplicate file', other: '${count} duplicate files')}, saving (${storageSaved}!)"; + "You have cleaned up ${Intl.plural(count, one: '${count} duplicate file', other: '${count} duplicate files')}, saving (${storageSaved}!)"; static String m17(count, formattedSize) => "${count} files, ${formattedSize} each"; @@ -1451,6 +1451,8 @@ class MessageLookup extends MessageLookupByLibrary { "viewer": MessageLookupByLibrary.simpleMessage("Viewer"), "visitWebToManage": MessageLookupByLibrary.simpleMessage( "Please visit web.ente.io to manage your subscription"), + "waitingForWifi": + MessageLookupByLibrary.simpleMessage("Waiting for WiFi..."), "weAreOpenSource": MessageLookupByLibrary.simpleMessage("We are open source!"), "weDontSupportEditingPhotosAndAlbumsThatYouDont": diff --git a/lib/generated/intl/messages_zh.dart b/lib/generated/intl/messages_zh.dart index ecf65f05e..d04c68521 100644 --- a/lib/generated/intl/messages_zh.dart +++ b/lib/generated/intl/messages_zh.dart @@ -704,6 +704,7 @@ class MessageLookup extends MessageLookupByLibrary { MessageLookupByLibrary.simpleMessage("正在加载 EXIF 数据..."), "loadingGallery": MessageLookupByLibrary.simpleMessage("正在加载图库..."), "loadingMessage": MessageLookupByLibrary.simpleMessage("正在加载您的照片..."), + "loadingModel": MessageLookupByLibrary.simpleMessage("正在下载模型..."), "localGallery": MessageLookupByLibrary.simpleMessage("本地相册"), "location": MessageLookupByLibrary.simpleMessage("地理位置"), "locationName": MessageLookupByLibrary.simpleMessage("地点名称"), diff --git a/lib/generated/l10n.dart b/lib/generated/l10n.dart index 9f3136f8a..e882159ed 100644 --- a/lib/generated/l10n.dart +++ b/lib/generated/l10n.dart @@ -2886,6 +2886,16 @@ class S { ); } + /// `Waiting for WiFi...` + String get waitingForWifi { + return Intl.message( + 'Waiting for WiFi...', + name: 'waitingForWifi', + desc: '', + args: [], + ); + } + /// `Status` String get status { return Intl.message( @@ -3523,10 +3533,10 @@ class S { ); } - /// `Your have cleaned up {count, plural, one{{count} duplicate file} other{{count} duplicate files}}, saving ({storageSaved}!)` + /// `You have cleaned up {count, plural, one{{count} duplicate file} other{{count} duplicate files}}, saving ({storageSaved}!)` String duplicateFileCountWithStorageSaved(int count, String storageSaved) { return Intl.message( - 'Your have cleaned up ${Intl.plural(count, one: '$count duplicate file', other: '$count duplicate files')}, saving ($storageSaved!)', + 'You have cleaned up ${Intl.plural(count, one: '$count duplicate file', other: '$count duplicate files')}, saving ($storageSaved!)', name: 'duplicateFileCountWithStorageSaved', desc: 'The text to display when the user has successfully cleaned up duplicate files', diff --git a/lib/l10n/intl_en.arb b/lib/l10n/intl_en.arb index 873593fe0..ae31c07d6 100644 --- a/lib/l10n/intl_en.arb +++ b/lib/l10n/intl_en.arb @@ -410,6 +410,7 @@ "magicSearch": "Magic search", "magicSearchDescription": "Please note that this will result in a higher bandwidth and battery usage until all items are indexed.", "loadingModel": "Downloading models...", + "waitingForWifi": "Waiting for WiFi...", "status": "Status", "indexedItems": "Indexed items", "pendingItems": "Pending items", diff --git a/lib/services/semantic_search/embedding_store.dart b/lib/services/semantic_search/embedding_store.dart index 6aedf8547..4b665c8af 100644 --- a/lib/services/semantic_search/embedding_store.dart +++ b/lib/services/semantic_search/embedding_store.dart @@ -50,9 +50,15 @@ class EmbeddingStore { Future pushEmbeddings() async { final pendingItems = await EmbeddingsDB.instance.getUnsyncedEmbeddings(); + final fileMap = await FilesDB.instance + .getFilesFromIDs(pendingItems.map((e) => e.fileID).toList()); + _logger.info("Pushing ${pendingItems.length} embeddings"); for (final item in pendingItems) { - final file = await FilesDB.instance.getAnyUploadedFile(item.fileID); - await _pushEmbedding(file!, item); + try { + await _pushEmbedding(fileMap[item.fileID]!, item); + } catch (e, s) { + _logger.severe(e, s); + } } } @@ -67,6 +73,7 @@ class EmbeddingStore { } Future _pushEmbedding(EnteFile file, Embedding embedding) async { + _logger.info("Pushing embedding for $file"); final encryptionKey = getFileKey(file); final embeddingJSON = jsonEncode(embedding.embedding); final encryptedEmbedding = await CryptoUtil.encryptChaCha( @@ -86,7 +93,7 @@ class EmbeddingStore { "decryptionHeader": header, }, ); - final updationTime = response.data["updationTime"]; + final updationTime = response.data["updatedAt"]; embedding.updationTime = updationTime; await EmbeddingsDB.instance.put(embedding); } catch (e, s) { @@ -138,7 +145,10 @@ class EmbeddingStore { for (final embedding in remoteEmbeddings) { final file = fileMap[embedding.fileID]; - final fileKey = getFileKey(file!); + if (file == null) { + continue; + } + final fileKey = getFileKey(file); final input = EmbeddingsDecoderInput(embedding, fileKey); inputs.add(input); } diff --git a/lib/services/semantic_search/frameworks/ggml.dart b/lib/services/semantic_search/frameworks/ggml.dart index e4903091c..6ff862084 100644 --- a/lib/services/semantic_search/frameworks/ggml.dart +++ b/lib/services/semantic_search/frameworks/ggml.dart @@ -11,6 +11,8 @@ class GGML extends MLFramework { final _computer = Computer.shared(); final _logger = Logger("GGML"); + GGML(super.shouldDownloadOverMobileData); + @override String getImageModelRemotePath() { return kModelBucketEndpoint + kImageModel; diff --git a/lib/services/semantic_search/frameworks/ml_framework.dart b/lib/services/semantic_search/frameworks/ml_framework.dart index 2bf415881..3523c721e 100644 --- a/lib/services/semantic_search/frameworks/ml_framework.dart +++ b/lib/services/semantic_search/frameworks/ml_framework.dart @@ -1,16 +1,47 @@ +import "dart:async"; import "dart:io"; +import "package:connectivity_plus/connectivity_plus.dart"; import "package:flutter/services.dart"; import "package:logging/logging.dart"; import "package:path/path.dart"; import "package:path_provider/path_provider.dart"; +import "package:photos/core/errors.dart"; + +import "package:photos/core/event_bus.dart"; import "package:photos/core/network/network.dart"; +import "package:photos/events/event.dart"; abstract class MLFramework { static const kImageEncoderEnabled = true; static const kMaximumRetrials = 3; - final _logger = Logger("MLFramework"); + static final _logger = Logger("MLFramework"); + + final bool shouldDownloadOverMobileData; + + InitializationState _state = InitializationState.notInitialized; + final _initializationCompleter = Completer(); + + MLFramework(this.shouldDownloadOverMobileData) { + Connectivity() + .onConnectivityChanged + .listen((ConnectivityResult result) async { + _logger.info("Connectivity changed to $result"); + if (_state == InitializationState.waitingForNetwork && + await _canDownload()) { + unawaited(init()); + } + }); + } + + InitializationState get initializationState => _state; + + set _initState(InitializationState state) { + Bus.instance.fire(MLFrameworkInitializationUpdateEvent(state)); + _logger.info("Init state is $state"); + _state = state; + } /// Returns the path of the Image Model hosted remotely String getImageModelRemotePath(); @@ -35,8 +66,18 @@ abstract class MLFramework { /// initialization. For eg. if you wish to load the model from `/assets` /// instead of a CDN. Future init() async { - await _initImageModel(); - await _initTextModel(); + try { + await Future.wait([_initImageModel(), _initTextModel()]); + } catch (e, s) { + _logger.warning(e, s); + if (e is WiFiUnavailableError) { + return _initializationCompleter.future; + } else { + rethrow; + } + } + _initState = InitializationState.initialized; + _initializationCompleter.complete(); } // Releases any resources held by the framework @@ -63,27 +104,33 @@ abstract class MLFramework { if (!kImageEncoderEnabled) { return; } + _initState = InitializationState.initializingImageModel; final path = await _getLocalImageModelPath(); - if (File(path).existsSync()) { + if (await File(path).exists()) { await loadImageModel(path); } else { + _initState = InitializationState.downloadingImageModel; final tempFile = File(path + ".temp"); await _downloadFile(getImageModelRemotePath(), tempFile.path); await tempFile.rename(path); await loadImageModel(path); } + _initState = InitializationState.initializedImageModel; } Future _initTextModel() async { final path = await _getLocalTextModelPath(); - if (File(path).existsSync()) { + _initState = InitializationState.initializingTextModel; + if (await File(path).exists()) { await loadTextModel(path); } else { + _initState = InitializationState.downloadingTextModel; final tempFile = File(path + ".temp"); await _downloadFile(getTextModelRemotePath(), tempFile.path); await tempFile.rename(path); await loadTextModel(path); } + _initState = InitializationState.initializedTextModel; } Future _getLocalImageModelPath() async { @@ -103,6 +150,10 @@ abstract class MLFramework { String savePath, { int trialCount = 1, }) async { + if (!await _canDownload()) { + _initState = InitializationState.waitingForNetwork; + throw WiFiUnavailableError(); + } _logger.info("Downloading " + url); final existingFile = File(savePath); if (await existingFile.exists()) { @@ -120,6 +171,12 @@ abstract class MLFramework { } } + Future _canDownload() async { + final connectivityResult = await (Connectivity().checkConnectivity()); + return connectivityResult != ConnectivityResult.mobile || + shouldDownloadOverMobileData; + } + Future getAccessiblePathForAsset( String assetPath, String tempName, @@ -131,3 +188,21 @@ abstract class MLFramework { return file.path; } } + +class MLFrameworkInitializationUpdateEvent extends Event { + final InitializationState state; + + MLFrameworkInitializationUpdateEvent(this.state); +} + +enum InitializationState { + notInitialized, + waitingForNetwork, + downloadingImageModel, + initializingImageModel, + initializedImageModel, + downloadingTextModel, + initializingTextModel, + initializedTextModel, + initialized, +} diff --git a/lib/services/semantic_search/frameworks/onnx/onnx.dart b/lib/services/semantic_search/frameworks/onnx/onnx.dart index 0a18c66e6..00930ccac 100644 --- a/lib/services/semantic_search/frameworks/onnx/onnx.dart +++ b/lib/services/semantic_search/frameworks/onnx/onnx.dart @@ -17,6 +17,8 @@ class ONNX extends MLFramework { int _textEncoderAddress = 0; int _imageEncoderAddress = 0; + ONNX(super.shouldDownloadOverMobileData); + @override String getImageModelRemotePath() { return kModelBucketEndpoint + kImageModel; diff --git a/lib/services/semantic_search/semantic_search_service.dart b/lib/services/semantic_search/semantic_search_service.dart index b233559cf..ddb03d78b 100644 --- a/lib/services/semantic_search/semantic_search_service.dart +++ b/lib/services/semantic_search/semantic_search_service.dart @@ -38,11 +38,11 @@ class SemanticSearchService { final _logger = Logger("SemanticSearchService"); final _queue = Queue(); - final _mlFramework = kCurrentModel == Model.onnxClip ? ONNX() : GGML(); final _frameworkInitialization = Completer(); final _embeddingLoaderDebouncer = Debouncer(kDebounceDuration, executionInterval: kDebounceDuration); + late MLFramework _mlFramework; bool _hasInitialized = false; bool _isComputingEmbeddings = false; bool _isSyncing = false; @@ -61,6 +61,11 @@ class SemanticSearchService { return; } _hasInitialized = true; + final shouldDownloadOverMobileData = + Configuration.instance.shouldBackupOverMobileData(); + _mlFramework = kCurrentModel == Model.onnxClip + ? ONNX(shouldDownloadOverMobileData) + : GGML(shouldDownloadOverMobileData); await EmbeddingsDB.instance.init(); await EmbeddingStore.instance.init(); await _loadEmbeddings(); @@ -145,8 +150,8 @@ class SemanticSearchService { ); } - Future getFrameworkInitializationStatus() { - return _frameworkInitialization.future; + InitializationState getFrameworkInitializationState() { + return _mlFramework.initializationState; } Future clearIndexes() async { diff --git a/lib/ui/settings/machine_learning_settings_page.dart b/lib/ui/settings/machine_learning_settings_page.dart index c64ae5714..6afb3704a 100644 --- a/lib/ui/settings/machine_learning_settings_page.dart +++ b/lib/ui/settings/machine_learning_settings_page.dart @@ -6,6 +6,7 @@ import "package:photos/core/event_bus.dart"; import 'package:photos/events/embedding_updated_event.dart'; import "package:photos/generated/l10n.dart"; import "package:photos/services/feature_flag_service.dart"; +import "package:photos/services/semantic_search/frameworks/ml_framework.dart"; import "package:photos/services/semantic_search/semantic_search_service.dart"; import "package:photos/theme/ente_theme.dart"; import "package:photos/ui/common/loading_widget.dart"; @@ -29,6 +30,32 @@ class MachineLearningSettingsPage extends StatefulWidget { class _MachineLearningSettingsPageState extends State { + late InitializationState _state; + + late StreamSubscription + _eventSubscription; + + @override + void initState() { + super.initState(); + _eventSubscription = + Bus.instance.on().listen((event) { + _fetchState(); + setState(() {}); + }); + _fetchState(); + } + + void _fetchState() { + _state = SemanticSearchService.instance.getFrameworkInitializationState(); + } + + @override + void dispose() { + super.dispose(); + _eventSubscription.cancel(); + } + @override Widget build(BuildContext context) { return Scaffold( @@ -118,17 +145,9 @@ class _MachineLearningSettingsPageState hasEnabled ? Column( children: [ - FutureBuilder( - future: SemanticSearchService.instance - .getFrameworkInitializationStatus(), - builder: (BuildContext context, AsyncSnapshot snapshot) { - if (snapshot.hasData) { - return const MagicSearchIndexStatsWidget(); - } else { - return const ModelLoadingState(); - } - }, - ), + _state == InitializationState.initialized + ? const MagicSearchIndexStatsWidget() + : ModelLoadingState(_state), const SizedBox( height: 12, ), @@ -158,7 +177,12 @@ class _MachineLearningSettingsPageState } class ModelLoadingState extends StatelessWidget { - const ModelLoadingState({super.key}); + final InitializationState state; + + const ModelLoadingState( + this.state, { + Key? key, + }) : super(key: key); @override Widget build(BuildContext context) { @@ -167,7 +191,7 @@ class ModelLoadingState extends StatelessWidget { MenuSectionTitle(title: S.of(context).status), MenuItemWidget( captionedTextWidget: CaptionedTextWidget( - title: S.of(context).loadingModel, + title: _getTitle(context), ), trailingWidget: EnteLoadingWidget( size: 12, @@ -180,6 +204,15 @@ class ModelLoadingState extends StatelessWidget { ], ); } + + String _getTitle(BuildContext context) { + switch (state) { + case InitializationState.waitingForNetwork: + return S.of(context).waitingForWifi; + default: + return S.of(context).loadingModel; + } + } } class MagicSearchIndexStatsWidget extends StatefulWidget { diff --git a/pubspec.yaml b/pubspec.yaml index d38ab111e..e74135aa7 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -12,7 +12,7 @@ description: ente photos application # Read more about iOS versioning at # https://developer.apple.com/library/archive/documentation/General/Reference/InfoPlistKeyReference/Articles/CoreFoundationKeys.html -version: 0.8.33+553 +version: 0.8.35+555 environment: sdk: ">=3.0.0 <4.0.0"