diff --git a/mobile/lib/services/machine_learning/semantic_search/frameworks/ml_framework.dart b/mobile/lib/services/machine_learning/semantic_search/frameworks/ml_framework.dart index 35616b1b1..94f480583 100644 --- a/mobile/lib/services/machine_learning/semantic_search/frameworks/ml_framework.dart +++ b/mobile/lib/services/machine_learning/semantic_search/frameworks/ml_framework.dart @@ -6,7 +6,6 @@ import "package:logging/logging.dart"; import "package:photos/core/errors.dart"; import "package:photos/core/event_bus.dart"; -import "package:photos/core/network/network.dart"; import "package:photos/events/event.dart"; import "package:photos/services/remote_assets_service.dart"; @@ -103,40 +102,35 @@ abstract class MLFramework { return; } _initState = InitializationState.initializingImageModel; - final imageModel = - await RemoteAssetsService.instance.getAsset(getImageModelRemotePath()); + final imageModel = await _getModel(getImageModelRemotePath()); await loadImageModel(imageModel.path); _initState = InitializationState.initializedImageModel; } Future _initTextModel() async { _initState = InitializationState.initializingTextModel; - final textModel = - await RemoteAssetsService.instance.getAsset(getTextModelRemotePath()); + final textModel = await _getModel(getTextModelRemotePath()); await loadTextModel(textModel.path); _initState = InitializationState.initializedTextModel; } - Future _downloadFile( - String url, - String savePath, { + Future _getModel( + String url, { int trialCount = 1, }) async { + if (await RemoteAssetsService.instance.hasAsset(url)) { + return RemoteAssetsService.instance.getAsset(url); + } if (!await _canDownload()) { _initState = InitializationState.waitingForNetwork; throw WiFiUnavailableError(); } - _logger.info("Downloading " + url); - final existingFile = File(savePath); - if (await existingFile.exists()) { - await existingFile.delete(); - } try { - await NetworkClient.instance.getDio().download(url, savePath); + return RemoteAssetsService.instance.getAsset(url); } catch (e, s) { _logger.severe(e, s); if (trialCount < kMaximumRetrials) { - return _downloadFile(url, savePath, trialCount: trialCount + 1); + return _getModel(url, trialCount: trialCount + 1); } else { rethrow; } diff --git a/mobile/lib/services/remote_assets_service.dart b/mobile/lib/services/remote_assets_service.dart index 0e75b983d..251ce6c15 100644 --- a/mobile/lib/services/remote_assets_service.dart +++ b/mobile/lib/services/remote_assets_service.dart @@ -26,6 +26,11 @@ class RemoteAssetsService { } } + Future hasAsset(String remotePath) async { + final path = await _getLocalPath(remotePath); + return File(path).exists(); + } + Future _getLocalPath(String remotePath) async { return (await getApplicationSupportDirectory()).path + "/assets/" +