From b4019f3c50c1bc8f01004980062d9cc0130fa2a4 Mon Sep 17 00:00:00 2001 From: vishnukvmd Date: Mon, 18 Dec 2023 22:21:01 +0530 Subject: [PATCH] Refactor commented code --- .../frameworks/ml_framework.dart | 31 +++++++------------ .../semantic_search_service.dart | 24 +++++++------- 2 files changed, 24 insertions(+), 31 deletions(-) diff --git a/lib/services/semantic_search/frameworks/ml_framework.dart b/lib/services/semantic_search/frameworks/ml_framework.dart index 3f98685b5..379dde9dd 100644 --- a/lib/services/semantic_search/frameworks/ml_framework.dart +++ b/lib/services/semantic_search/frameworks/ml_framework.dart @@ -7,6 +7,8 @@ import "package:path_provider/path_provider.dart"; import "package:photos/core/network/network.dart"; abstract class MLFramework { + static const kImageEncoderEnabled = false; + final _logger = Logger("MLFramework"); /// Returns the name of the framework @@ -57,26 +59,17 @@ abstract class MLFramework { // --- Future _initImageModel() async { - return; - // TODO: remove hardcoding - if (getFrameworkName() == "ggml") { - final path = await _getLocalImageModelPath(); - if (File(path).existsSync()) { - await loadImageModel(path); - } else { - final tempFile = File(path + ".temp"); - await _downloadFile(getImageModelRemotePath(), tempFile.path); - await tempFile.rename(path); - await loadImageModel(path); - } + if (!kImageEncoderEnabled) { + return; + } + final path = await _getLocalImageModelPath(); + if (File(path).existsSync()) { + await loadImageModel(path); } else { - const assetPath = "assets/models/clip/clip-image-vit-32-float32.onnx"; - await loadImageModel( - await getAccessiblePathForAsset( - assetPath, - "clip-image-vit-32-float32.onnx", - ), - ); + final tempFile = File(path + ".temp"); + await _downloadFile(getImageModelRemotePath(), tempFile.path); + await tempFile.rename(path); + await loadImageModel(path); } } diff --git a/lib/services/semantic_search/semantic_search_service.dart b/lib/services/semantic_search/semantic_search_service.dart index c7d530ad9..c646f4928 100644 --- a/lib/services/semantic_search/semantic_search_service.dart +++ b/lib/services/semantic_search/semantic_search_service.dart @@ -16,6 +16,7 @@ import "package:photos/models/embedding.dart"; import "package:photos/models/file/file.dart"; import "package:photos/objectbox.g.dart"; import "package:photos/services/semantic_search/embedding_store.dart"; +import "package:photos/services/semantic_search/frameworks/ml_framework.dart"; import 'package:photos/services/semantic_search/frameworks/onnx/onnx.dart'; import "package:photos/utils/file_util.dart"; import "package:photos/utils/local_settings.dart"; @@ -32,7 +33,6 @@ class SemanticSearchService { static const kModelName = "clip"; static const kEmbeddingLength = 512; static const kScoreThreshold = 0.23; - static const kImageEncoderEnabled = false; static const kShouldPushEmbeddings = false; final _logger = Logger("SemanticSearchService"); @@ -156,7 +156,7 @@ class SemanticSearchService { Future _backFill() async { if (!LocalSettings.instance.hasEnabledMagicSearch() || - !kImageEncoderEnabled) { + !MLFramework.kImageEncoderEnabled) { return; } await _frameworkInitialization.future; @@ -250,7 +250,7 @@ class SemanticSearchService { } Future computeImageEmbedding(EnteFile file) async { - if (!kImageEncoderEnabled) { + if (!MLFramework.kImageEncoderEnabled) { return; } if (!_frameworkInitialization.isCompleted) { @@ -269,15 +269,15 @@ class SemanticSearchService { // dev.log(computeScore(result, pyEmbedding).toString()); // dev.log(computeScore(pyEmbedding, webEmbedding).toString()); - // final embedding = Embedding( - // fileID: file.uploadedFileID!, - // model: _mlFramework.getFrameworkName() + "-" + kModelName, - // embedding: result, - // ); - // await EmbeddingStore.instance.storeEmbedding( - // file, - // embedding, - // ); + final embedding = Embedding( + fileID: file.uploadedFileID!, + model: _mlFramework.getFrameworkName() + "-" + kModelName, + embedding: result, + ); + await EmbeddingStore.instance.storeEmbedding( + file, + embedding, + ); } catch (e, s) { _logger.severe(e, s); }