diff --git a/mobile/lib/services/remote_assets_service.dart b/mobile/lib/services/remote_assets_service.dart index 251ce6c15..1e2cb3b6d 100644 --- a/mobile/lib/services/remote_assets_service.dart +++ b/mobile/lib/services/remote_assets_service.dart @@ -1,5 +1,7 @@ +import "dart:async"; import "dart:io"; +import "package:flutter/foundation.dart"; import "package:logging/logging.dart"; import "package:path_provider/path_provider.dart"; import "package:photos/core/network/network.dart"; @@ -8,6 +10,10 @@ class RemoteAssetsService { static final _logger = Logger("RemoteAssetsService"); RemoteAssetsService._privateConstructor(); + final StreamController<(String, int, int)> _progressController = + StreamController<(String, int, int)>.broadcast(); + + Stream<(String, int, int)> get progressStream => _progressController.stream; static final RemoteAssetsService instance = RemoteAssetsService._privateConstructor(); @@ -57,7 +63,19 @@ class RemoteAssetsService { if (await existingFile.exists()) { await existingFile.delete(); } - await NetworkClient.instance.getDio().download(url, savePath); + + await NetworkClient.instance.getDio().download( + url, + savePath, + onReceiveProgress: (received, total) { + if (received > 0 && total > 0) { + _progressController.add((url, received, total)); + } else if (kDebugMode) { + debugPrint("$url Received: $received, Total: $total"); + } + }, + ); + _logger.info("Downloaded " + url); } } diff --git a/mobile/lib/ui/settings/machine_learning_settings_page.dart b/mobile/lib/ui/settings/machine_learning_settings_page.dart index a0b72ae09..9820d882f 100644 --- a/mobile/lib/ui/settings/machine_learning_settings_page.dart +++ b/mobile/lib/ui/settings/machine_learning_settings_page.dart @@ -8,6 +8,7 @@ import "package:photos/generated/l10n.dart"; import "package:photos/service_locator.dart"; import 'package:photos/services/machine_learning/semantic_search/frameworks/ml_framework.dart'; import 'package:photos/services/machine_learning/semantic_search/semantic_search_service.dart'; +import "package:photos/services/remote_assets_service.dart"; import "package:photos/theme/ente_theme.dart"; import "package:photos/ui/common/loading_widget.dart"; import "package:photos/ui/components/buttons/icon_button_widget.dart"; @@ -18,6 +19,7 @@ import "package:photos/ui/components/menu_section_title.dart"; import "package:photos/ui/components/title_bar_title_widget.dart"; import "package:photos/ui/components/title_bar_widget.dart"; import "package:photos/ui/components/toggle_switch_widget.dart"; +import "package:photos/utils/data_util.dart"; import "package:photos/utils/local_settings.dart"; class MachineLearningSettingsPage extends StatefulWidget { @@ -176,7 +178,7 @@ class _MachineLearningSettingsPageState } } -class ModelLoadingState extends StatelessWidget { +class ModelLoadingState extends StatefulWidget { final InitializationState state; const ModelLoadingState( @@ -184,6 +186,38 @@ class ModelLoadingState extends StatelessWidget { Key? key, }) : super(key: key); + @override + State createState() => _ModelLoadingStateState(); +} + +class _ModelLoadingStateState extends State { + StreamSubscription<(String, int, int)>? _progressStream; + final Map _progressMap = {}; + @override + void initState() { + _progressStream = + RemoteAssetsService.instance.progressStream.listen((event) { + final String url = event.$1; + String title = ""; + if (url.contains("clip-image")) { + title = "Image Model"; + } else if (url.contains("clip-text")) { + title = "Text Model"; + } + if (title.isNotEmpty) { + _progressMap[title] = (event.$2, event.$3); + setState(() {}); + } + }); + super.initState(); + } + + @override + void dispose() { + super.dispose(); + _progressStream?.cancel(); + } + @override Widget build(BuildContext context) { return Column( @@ -201,12 +235,31 @@ class ModelLoadingState extends StatelessWidget { alignCaptionedTextToLeft: true, isGestureDetectorDisabled: true, ), + // show the progress map if in debug mode + if (flagService.internalUser) + ..._progressMap.entries.map((entry) { + return MenuItemWidget( + key: ValueKey(entry.value), + captionedTextWidget: CaptionedTextWidget( + title: entry.key, + ), + trailingWidget: Text( + entry.value.$1 == entry.value.$2 + ? "Done" + : "${formatBytes(entry.value.$1)} / ${formatBytes(entry.value.$2)}", + style: Theme.of(context).textTheme.bodySmall, + ), + singleBorderRadius: 8, + alignCaptionedTextToLeft: true, + isGestureDetectorDisabled: true, + ); + }).toList(), ], ); } String _getTitle(BuildContext context) { - switch (state) { + switch (widget.state) { case InitializationState.waitingForNetwork: return S.of(context).waitingForWifi; default: