|
@@ -1,16 +1,47 @@
|
|
|
|
+import "dart:async";
|
|
import "dart:io";
|
|
import "dart:io";
|
|
|
|
|
|
|
|
+import "package:connectivity_plus/connectivity_plus.dart";
|
|
import "package:flutter/services.dart";
|
|
import "package:flutter/services.dart";
|
|
import "package:logging/logging.dart";
|
|
import "package:logging/logging.dart";
|
|
import "package:path/path.dart";
|
|
import "package:path/path.dart";
|
|
import "package:path_provider/path_provider.dart";
|
|
import "package:path_provider/path_provider.dart";
|
|
|
|
+import "package:photos/core/errors.dart";
|
|
|
|
+
|
|
|
|
+import "package:photos/core/event_bus.dart";
|
|
import "package:photos/core/network/network.dart";
|
|
import "package:photos/core/network/network.dart";
|
|
|
|
+import "package:photos/events/event.dart";
|
|
|
|
|
|
abstract class MLFramework {
|
|
abstract class MLFramework {
|
|
static const kImageEncoderEnabled = true;
|
|
static const kImageEncoderEnabled = true;
|
|
static const kMaximumRetrials = 3;
|
|
static const kMaximumRetrials = 3;
|
|
|
|
|
|
- final _logger = Logger("MLFramework");
|
|
|
|
|
|
+ static final _logger = Logger("MLFramework");
|
|
|
|
+
|
|
|
|
+ final bool shouldDownloadOverMobileData;
|
|
|
|
+
|
|
|
|
+ InitializationState _state = InitializationState.notInitialized;
|
|
|
|
+ final _initializationCompleter = Completer<void>();
|
|
|
|
+
|
|
|
|
+ MLFramework(this.shouldDownloadOverMobileData) {
|
|
|
|
+ Connectivity()
|
|
|
|
+ .onConnectivityChanged
|
|
|
|
+ .listen((ConnectivityResult result) async {
|
|
|
|
+ _logger.info("Connectivity changed to $result");
|
|
|
|
+ if (_state == InitializationState.waitingForNetwork &&
|
|
|
|
+ await _canDownload()) {
|
|
|
|
+ unawaited(init());
|
|
|
|
+ }
|
|
|
|
+ });
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ InitializationState get initializationState => _state;
|
|
|
|
+
|
|
|
|
+ set _initState(InitializationState state) {
|
|
|
|
+ Bus.instance.fire(MLFrameworkInitializationUpdateEvent(state));
|
|
|
|
+ _logger.info("Init state is $state");
|
|
|
|
+ _state = state;
|
|
|
|
+ }
|
|
|
|
|
|
/// Returns the path of the Image Model hosted remotely
|
|
/// Returns the path of the Image Model hosted remotely
|
|
String getImageModelRemotePath();
|
|
String getImageModelRemotePath();
|
|
@@ -35,8 +66,18 @@ abstract class MLFramework {
|
|
/// initialization. For eg. if you wish to load the model from `/assets`
|
|
/// initialization. For eg. if you wish to load the model from `/assets`
|
|
/// instead of a CDN.
|
|
/// instead of a CDN.
|
|
Future<void> init() async {
|
|
Future<void> init() async {
|
|
- await _initImageModel();
|
|
|
|
- await _initTextModel();
|
|
|
|
|
|
+ try {
|
|
|
|
+ await Future.wait([_initImageModel(), _initTextModel()]);
|
|
|
|
+ } catch (e, s) {
|
|
|
|
+ _logger.warning(e, s);
|
|
|
|
+ if (e is WiFiUnavailableError) {
|
|
|
|
+ return _initializationCompleter.future;
|
|
|
|
+ } else {
|
|
|
|
+ rethrow;
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ _initState = InitializationState.initialized;
|
|
|
|
+ _initializationCompleter.complete();
|
|
}
|
|
}
|
|
|
|
|
|
// Releases any resources held by the framework
|
|
// Releases any resources held by the framework
|
|
@@ -63,27 +104,33 @@ abstract class MLFramework {
|
|
if (!kImageEncoderEnabled) {
|
|
if (!kImageEncoderEnabled) {
|
|
return;
|
|
return;
|
|
}
|
|
}
|
|
|
|
+ _initState = InitializationState.initializingImageModel;
|
|
final path = await _getLocalImageModelPath();
|
|
final path = await _getLocalImageModelPath();
|
|
- if (File(path).existsSync()) {
|
|
|
|
|
|
+ if (await File(path).exists()) {
|
|
await loadImageModel(path);
|
|
await loadImageModel(path);
|
|
} else {
|
|
} else {
|
|
|
|
+ _initState = InitializationState.downloadingImageModel;
|
|
final tempFile = File(path + ".temp");
|
|
final tempFile = File(path + ".temp");
|
|
await _downloadFile(getImageModelRemotePath(), tempFile.path);
|
|
await _downloadFile(getImageModelRemotePath(), tempFile.path);
|
|
await tempFile.rename(path);
|
|
await tempFile.rename(path);
|
|
await loadImageModel(path);
|
|
await loadImageModel(path);
|
|
}
|
|
}
|
|
|
|
+ _initState = InitializationState.initializedImageModel;
|
|
}
|
|
}
|
|
|
|
|
|
Future<void> _initTextModel() async {
|
|
Future<void> _initTextModel() async {
|
|
final path = await _getLocalTextModelPath();
|
|
final path = await _getLocalTextModelPath();
|
|
- if (File(path).existsSync()) {
|
|
|
|
|
|
+ _initState = InitializationState.initializingTextModel;
|
|
|
|
+ if (await File(path).exists()) {
|
|
await loadTextModel(path);
|
|
await loadTextModel(path);
|
|
} else {
|
|
} else {
|
|
|
|
+ _initState = InitializationState.downloadingTextModel;
|
|
final tempFile = File(path + ".temp");
|
|
final tempFile = File(path + ".temp");
|
|
await _downloadFile(getTextModelRemotePath(), tempFile.path);
|
|
await _downloadFile(getTextModelRemotePath(), tempFile.path);
|
|
await tempFile.rename(path);
|
|
await tempFile.rename(path);
|
|
await loadTextModel(path);
|
|
await loadTextModel(path);
|
|
}
|
|
}
|
|
|
|
+ _initState = InitializationState.initializedTextModel;
|
|
}
|
|
}
|
|
|
|
|
|
Future<String> _getLocalImageModelPath() async {
|
|
Future<String> _getLocalImageModelPath() async {
|
|
@@ -103,6 +150,10 @@ abstract class MLFramework {
|
|
String savePath, {
|
|
String savePath, {
|
|
int trialCount = 1,
|
|
int trialCount = 1,
|
|
}) async {
|
|
}) async {
|
|
|
|
+ if (!await _canDownload()) {
|
|
|
|
+ _initState = InitializationState.waitingForNetwork;
|
|
|
|
+ throw WiFiUnavailableError();
|
|
|
|
+ }
|
|
_logger.info("Downloading " + url);
|
|
_logger.info("Downloading " + url);
|
|
final existingFile = File(savePath);
|
|
final existingFile = File(savePath);
|
|
if (await existingFile.exists()) {
|
|
if (await existingFile.exists()) {
|
|
@@ -120,6 +171,12 @@ abstract class MLFramework {
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ Future<bool> _canDownload() async {
|
|
|
|
+ final connectivityResult = await (Connectivity().checkConnectivity());
|
|
|
|
+ return connectivityResult != ConnectivityResult.mobile ||
|
|
|
|
+ shouldDownloadOverMobileData;
|
|
|
|
+ }
|
|
|
|
+
|
|
Future<String> getAccessiblePathForAsset(
|
|
Future<String> getAccessiblePathForAsset(
|
|
String assetPath,
|
|
String assetPath,
|
|
String tempName,
|
|
String tempName,
|
|
@@ -131,3 +188,21 @@ abstract class MLFramework {
|
|
return file.path;
|
|
return file.path;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
+
|
|
|
|
+class MLFrameworkInitializationUpdateEvent extends Event {
|
|
|
|
+ final InitializationState state;
|
|
|
|
+
|
|
|
|
+ MLFrameworkInitializationUpdateEvent(this.state);
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+enum InitializationState {
|
|
|
|
+ notInitialized,
|
|
|
|
+ waitingForNetwork,
|
|
|
|
+ downloadingImageModel,
|
|
|
|
+ initializingImageModel,
|
|
|
|
+ initializedImageModel,
|
|
|
|
+ downloadingTextModel,
|
|
|
|
+ initializingTextModel,
|
|
|
|
+ initializedTextModel,
|
|
|
|
+ initialized,
|
|
|
|
+}
|