[mob] Show progress for downloading of clip models (#1738)
## Description ## Tests Tested locally, enabled for internal users only.
This commit is contained in:
commit
a4ef4ce2c1
2 changed files with 74 additions and 3 deletions
|
@ -1,5 +1,7 @@
|
|||
import "dart:async";
|
||||
import "dart:io";
|
||||
|
||||
import "package:flutter/foundation.dart";
|
||||
import "package:logging/logging.dart";
|
||||
import "package:path_provider/path_provider.dart";
|
||||
import "package:photos/core/network/network.dart";
|
||||
|
@ -8,6 +10,10 @@ class RemoteAssetsService {
|
|||
static final _logger = Logger("RemoteAssetsService");
|
||||
|
||||
RemoteAssetsService._privateConstructor();
|
||||
final StreamController<(String, int, int)> _progressController =
|
||||
StreamController<(String, int, int)>.broadcast();
|
||||
|
||||
Stream<(String, int, int)> get progressStream => _progressController.stream;
|
||||
|
||||
static final RemoteAssetsService instance =
|
||||
RemoteAssetsService._privateConstructor();
|
||||
|
@ -57,7 +63,19 @@ class RemoteAssetsService {
|
|||
if (await existingFile.exists()) {
|
||||
await existingFile.delete();
|
||||
}
|
||||
await NetworkClient.instance.getDio().download(url, savePath);
|
||||
|
||||
await NetworkClient.instance.getDio().download(
|
||||
url,
|
||||
savePath,
|
||||
onReceiveProgress: (received, total) {
|
||||
if (received > 0 && total > 0) {
|
||||
_progressController.add((url, received, total));
|
||||
} else if (kDebugMode) {
|
||||
debugPrint("$url Received: $received, Total: $total");
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
_logger.info("Downloaded " + url);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ import "package:photos/generated/l10n.dart";
|
|||
import "package:photos/service_locator.dart";
|
||||
import 'package:photos/services/machine_learning/semantic_search/frameworks/ml_framework.dart';
|
||||
import 'package:photos/services/machine_learning/semantic_search/semantic_search_service.dart';
|
||||
import "package:photos/services/remote_assets_service.dart";
|
||||
import "package:photos/theme/ente_theme.dart";
|
||||
import "package:photos/ui/common/loading_widget.dart";
|
||||
import "package:photos/ui/components/buttons/icon_button_widget.dart";
|
||||
|
@ -18,6 +19,7 @@ import "package:photos/ui/components/menu_section_title.dart";
|
|||
import "package:photos/ui/components/title_bar_title_widget.dart";
|
||||
import "package:photos/ui/components/title_bar_widget.dart";
|
||||
import "package:photos/ui/components/toggle_switch_widget.dart";
|
||||
import "package:photos/utils/data_util.dart";
|
||||
import "package:photos/utils/local_settings.dart";
|
||||
|
||||
class MachineLearningSettingsPage extends StatefulWidget {
|
||||
|
@ -176,7 +178,7 @@ class _MachineLearningSettingsPageState
|
|||
}
|
||||
}
|
||||
|
||||
class ModelLoadingState extends StatelessWidget {
|
||||
class ModelLoadingState extends StatefulWidget {
|
||||
final InitializationState state;
|
||||
|
||||
const ModelLoadingState(
|
||||
|
@ -184,6 +186,38 @@ class ModelLoadingState extends StatelessWidget {
|
|||
Key? key,
|
||||
}) : super(key: key);
|
||||
|
||||
@override
|
||||
State<ModelLoadingState> createState() => _ModelLoadingStateState();
|
||||
}
|
||||
|
||||
class _ModelLoadingStateState extends State<ModelLoadingState> {
|
||||
StreamSubscription<(String, int, int)>? _progressStream;
|
||||
final Map<String, (int, int)> _progressMap = {};
|
||||
@override
|
||||
void initState() {
|
||||
_progressStream =
|
||||
RemoteAssetsService.instance.progressStream.listen((event) {
|
||||
final String url = event.$1;
|
||||
String title = "";
|
||||
if (url.contains("clip-image")) {
|
||||
title = "Image Model";
|
||||
} else if (url.contains("clip-text")) {
|
||||
title = "Text Model";
|
||||
}
|
||||
if (title.isNotEmpty) {
|
||||
_progressMap[title] = (event.$2, event.$3);
|
||||
setState(() {});
|
||||
}
|
||||
});
|
||||
super.initState();
|
||||
}
|
||||
|
||||
@override
|
||||
void dispose() {
|
||||
super.dispose();
|
||||
_progressStream?.cancel();
|
||||
}
|
||||
|
||||
@override
|
||||
Widget build(BuildContext context) {
|
||||
return Column(
|
||||
|
@ -201,12 +235,31 @@ class ModelLoadingState extends StatelessWidget {
|
|||
alignCaptionedTextToLeft: true,
|
||||
isGestureDetectorDisabled: true,
|
||||
),
|
||||
// show the progress map if in debug mode
|
||||
if (flagService.internalUser)
|
||||
..._progressMap.entries.map((entry) {
|
||||
return MenuItemWidget(
|
||||
key: ValueKey(entry.value),
|
||||
captionedTextWidget: CaptionedTextWidget(
|
||||
title: entry.key,
|
||||
),
|
||||
trailingWidget: Text(
|
||||
entry.value.$1 == entry.value.$2
|
||||
? "Done"
|
||||
: "${formatBytes(entry.value.$1)} / ${formatBytes(entry.value.$2)}",
|
||||
style: Theme.of(context).textTheme.bodySmall,
|
||||
),
|
||||
singleBorderRadius: 8,
|
||||
alignCaptionedTextToLeft: true,
|
||||
isGestureDetectorDisabled: true,
|
||||
);
|
||||
}).toList(),
|
||||
],
|
||||
);
|
||||
}
|
||||
|
||||
String _getTitle(BuildContext context) {
|
||||
switch (state) {
|
||||
switch (widget.state) {
|
||||
case InitializationState.waitingForNetwork:
|
||||
return S.of(context).waitingForWifi;
|
||||
default:
|
||||
|
|
Loading…
Add table
Reference in a new issue