|
@@ -1,9 +1,13 @@
|
|
|
+import "dart:async";
|
|
|
import "dart:io";
|
|
|
|
|
|
+import "package:connectivity_plus/connectivity_plus.dart";
|
|
|
import "package:flutter/services.dart";
|
|
|
import "package:logging/logging.dart";
|
|
|
import "package:path/path.dart";
|
|
|
import "package:path_provider/path_provider.dart";
|
|
|
+import "package:photos/core/errors.dart";
|
|
|
+
|
|
|
import "package:photos/core/event_bus.dart";
|
|
|
import "package:photos/core/network/network.dart";
|
|
|
import "package:photos/events/event.dart";
|
|
@@ -12,9 +16,24 @@ abstract class MLFramework {
|
|
|
static const kImageEncoderEnabled = true;
|
|
|
static const kMaximumRetrials = 3;
|
|
|
|
|
|
- InitializationState _state = InitializationState.notInitialized;
|
|
|
+ static final _logger = Logger("MLFramework");
|
|
|
|
|
|
- final _logger = Logger("MLFramework");
|
|
|
+ final bool shouldDownloadOverMobileData;
|
|
|
+
|
|
|
+ InitializationState _state = InitializationState.notInitialized;
|
|
|
+ final _initializationCompleter = Completer<void>();
|
|
|
+
|
|
|
+ MLFramework(this.shouldDownloadOverMobileData) {
|
|
|
+ Connectivity()
|
|
|
+ .onConnectivityChanged
|
|
|
+ .listen((ConnectivityResult result) async {
|
|
|
+ _logger.info("Connectivity changed to $result");
|
|
|
+ if (_state == InitializationState.waitingForNetwork &&
|
|
|
+ await _canDownload()) {
|
|
|
+ unawaited(init());
|
|
|
+ }
|
|
|
+ });
|
|
|
+ }
|
|
|
|
|
|
InitializationState get initializationState => _state;
|
|
|
|
|
@@ -47,8 +66,18 @@ abstract class MLFramework {
|
|
|
/// initialization. For eg. if you wish to load the model from `/assets`
|
|
|
/// instead of a CDN.
|
|
|
Future<void> init() async {
|
|
|
- await Future.wait([_initImageModel(), _initTextModel()]);
|
|
|
+ try {
|
|
|
+ await Future.wait([_initImageModel(), _initTextModel()]);
|
|
|
+ } catch (e, s) {
|
|
|
+ _logger.warning(e, s);
|
|
|
+ if (e is WiFiUnavailableError) {
|
|
|
+ return _initializationCompleter.future;
|
|
|
+ } else {
|
|
|
+ rethrow;
|
|
|
+ }
|
|
|
+ }
|
|
|
_initState = InitializationState.initialized;
|
|
|
+ _initializationCompleter.complete();
|
|
|
}
|
|
|
|
|
|
// Releases any resources held by the framework
|
|
@@ -75,9 +104,9 @@ abstract class MLFramework {
|
|
|
if (!kImageEncoderEnabled) {
|
|
|
return;
|
|
|
}
|
|
|
+ _initState = InitializationState.initializingImageModel;
|
|
|
final path = await _getLocalImageModelPath();
|
|
|
- if (File(path).existsSync()) {
|
|
|
- _initState = InitializationState.initializingImageModel;
|
|
|
+ if (await File(path).exists()) {
|
|
|
await loadImageModel(path);
|
|
|
} else {
|
|
|
_initState = InitializationState.downloadingImageModel;
|
|
@@ -86,12 +115,13 @@ abstract class MLFramework {
|
|
|
await tempFile.rename(path);
|
|
|
await loadImageModel(path);
|
|
|
}
|
|
|
+ _initState = InitializationState.initializedImageModel;
|
|
|
}
|
|
|
|
|
|
Future<void> _initTextModel() async {
|
|
|
final path = await _getLocalTextModelPath();
|
|
|
- if (File(path).existsSync()) {
|
|
|
- _initState = InitializationState.initializingTextModel;
|
|
|
+ _initState = InitializationState.initializingTextModel;
|
|
|
+ if (await File(path).exists()) {
|
|
|
await loadTextModel(path);
|
|
|
} else {
|
|
|
_initState = InitializationState.downloadingTextModel;
|
|
@@ -100,6 +130,7 @@ abstract class MLFramework {
|
|
|
await tempFile.rename(path);
|
|
|
await loadTextModel(path);
|
|
|
}
|
|
|
+ _initState = InitializationState.initializedTextModel;
|
|
|
}
|
|
|
|
|
|
Future<String> _getLocalImageModelPath() async {
|
|
@@ -119,6 +150,10 @@ abstract class MLFramework {
|
|
|
String savePath, {
|
|
|
int trialCount = 1,
|
|
|
}) async {
|
|
|
+ if (!await _canDownload()) {
|
|
|
+ _initState = InitializationState.waitingForNetwork;
|
|
|
+ throw WiFiUnavailableError();
|
|
|
+ }
|
|
|
_logger.info("Downloading " + url);
|
|
|
final existingFile = File(savePath);
|
|
|
if (await existingFile.exists()) {
|
|
@@ -136,6 +171,15 @@ abstract class MLFramework {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ Future<bool> _canDownload() async {
|
|
|
+ final connectivityResult = await (Connectivity().checkConnectivity());
|
|
|
+ bool canDownloadUnderCurrentNetworkConditions = true;
|
|
|
+ if (connectivityResult == ConnectivityResult.mobile) {
|
|
|
+ canDownloadUnderCurrentNetworkConditions = shouldDownloadOverMobileData;
|
|
|
+ }
|
|
|
+ return canDownloadUnderCurrentNetworkConditions;
|
|
|
+ }
|
|
|
+
|
|
|
Future<String> getAccessiblePathForAsset(
|
|
|
String assetPath,
|
|
|
String tempName,
|
|
@@ -158,8 +202,10 @@ enum InitializationState {
|
|
|
notInitialized,
|
|
|
waitingForNetwork,
|
|
|
downloadingImageModel,
|
|
|
- downloadingTextModel,
|
|
|
initializingImageModel,
|
|
|
+ initializedImageModel,
|
|
|
+ downloadingTextModel,
|
|
|
initializingTextModel,
|
|
|
+ initializedTextModel,
|
|
|
initialized,
|
|
|
}
|