[mob] Show progress for downloading of clip models (#1738)

## Description

## Tests
Tested locally, enabled for internal users only.
This commit is contained in:
Neeraj Gupta 2024-05-16 17:09:26 +05:30 committed by GitHub
commit a4ef4ce2c1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 74 additions and 3 deletions

View file

@ -1,5 +1,7 @@
import "dart:async";
import "dart:io";
import "package:flutter/foundation.dart";
import "package:logging/logging.dart";
import "package:path_provider/path_provider.dart";
import "package:photos/core/network/network.dart";
@ -8,6 +10,10 @@ class RemoteAssetsService {
static final _logger = Logger("RemoteAssetsService");
RemoteAssetsService._privateConstructor();
final StreamController<(String, int, int)> _progressController =
StreamController<(String, int, int)>.broadcast();
Stream<(String, int, int)> get progressStream => _progressController.stream;
static final RemoteAssetsService instance =
RemoteAssetsService._privateConstructor();
@ -57,7 +63,19 @@ class RemoteAssetsService {
if (await existingFile.exists()) {
await existingFile.delete();
}
await NetworkClient.instance.getDio().download(url, savePath);
await NetworkClient.instance.getDio().download(
url,
savePath,
onReceiveProgress: (received, total) {
if (received > 0 && total > 0) {
_progressController.add((url, received, total));
} else if (kDebugMode) {
debugPrint("$url Received: $received, Total: $total");
}
},
);
_logger.info("Downloaded " + url);
}
}

View file

@ -8,6 +8,7 @@ import "package:photos/generated/l10n.dart";
import "package:photos/service_locator.dart";
import 'package:photos/services/machine_learning/semantic_search/frameworks/ml_framework.dart';
import 'package:photos/services/machine_learning/semantic_search/semantic_search_service.dart';
import "package:photos/services/remote_assets_service.dart";
import "package:photos/theme/ente_theme.dart";
import "package:photos/ui/common/loading_widget.dart";
import "package:photos/ui/components/buttons/icon_button_widget.dart";
@ -18,6 +19,7 @@ import "package:photos/ui/components/menu_section_title.dart";
import "package:photos/ui/components/title_bar_title_widget.dart";
import "package:photos/ui/components/title_bar_widget.dart";
import "package:photos/ui/components/toggle_switch_widget.dart";
import "package:photos/utils/data_util.dart";
import "package:photos/utils/local_settings.dart";
class MachineLearningSettingsPage extends StatefulWidget {
@ -176,7 +178,7 @@ class _MachineLearningSettingsPageState
}
}
class ModelLoadingState extends StatelessWidget {
class ModelLoadingState extends StatefulWidget {
final InitializationState state;
const ModelLoadingState(
@ -184,6 +186,38 @@ class ModelLoadingState extends StatelessWidget {
Key? key,
}) : super(key: key);
@override
State<ModelLoadingState> createState() => _ModelLoadingStateState();
}
class _ModelLoadingStateState extends State<ModelLoadingState> {
StreamSubscription<(String, int, int)>? _progressStream;
final Map<String, (int, int)> _progressMap = {};
@override
void initState() {
_progressStream =
RemoteAssetsService.instance.progressStream.listen((event) {
final String url = event.$1;
String title = "";
if (url.contains("clip-image")) {
title = "Image Model";
} else if (url.contains("clip-text")) {
title = "Text Model";
}
if (title.isNotEmpty) {
_progressMap[title] = (event.$2, event.$3);
setState(() {});
}
});
super.initState();
}
@override
void dispose() {
super.dispose();
_progressStream?.cancel();
}
@override
Widget build(BuildContext context) {
return Column(
@ -201,12 +235,31 @@ class ModelLoadingState extends StatelessWidget {
alignCaptionedTextToLeft: true,
isGestureDetectorDisabled: true,
),
// show the progress map if in debug mode
if (flagService.internalUser)
..._progressMap.entries.map((entry) {
return MenuItemWidget(
key: ValueKey(entry.value),
captionedTextWidget: CaptionedTextWidget(
title: entry.key,
),
trailingWidget: Text(
entry.value.$1 == entry.value.$2
? "Done"
: "${formatBytes(entry.value.$1)} / ${formatBytes(entry.value.$2)}",
style: Theme.of(context).textTheme.bodySmall,
),
singleBorderRadius: 8,
alignCaptionedTextToLeft: true,
isGestureDetectorDisabled: true,
);
}).toList(),
],
);
}
String _getTitle(BuildContext context) {
switch (state) {
switch (widget.state) {
case InitializationState.waitingForNetwork:
return S.of(context).waitingForWifi;
default: