diff --git a/mobile/lib/face/model/file_ml.dart b/mobile/lib/face/model/file_ml.dart new file mode 100644 index 000000000..41f008c54 --- /dev/null +++ b/mobile/lib/face/model/file_ml.dart @@ -0,0 +1,84 @@ +import "package:photos/face/model/face.dart"; + +class FaceEmbeddings { + final List faces; + final int version; + // Platform: appVersion + final String? client; + final bool? error; + + FaceEmbeddings( + this.faces, + this.version, { + this.client, + this.error, + }); + + // toJson + Map toJson() => { + 'faces': faces.map((x) => x.toJson()).toList(), + 'version': version, + 'client': client, + 'error': error, + }; + // fromJson + factory FaceEmbeddings.fromJson(Map json) { + return FaceEmbeddings( + List.from( + json['faces'].map((x) => Face.fromJson(x as Map)), + ), + json['version'] as int, + client: json['client'] as String?, + error: json['error'] as bool?, + ); + } +} + +class ClipEmbedding { + final int? version; + final String framwork; + final List embedding; + ClipEmbedding(this.embedding, this.framwork, {this.version}); + // toJson + Map toJson() => { + 'version': version, + 'framwork': framwork, + 'embedding': embedding, + }; + // fromJson + factory ClipEmbedding.fromJson(Map json) { + return ClipEmbedding( + List.from(json['embedding'] as List), + json['framwork'] as String, + version: json['version'] as int?, + ); + } +} + +class FileMl { + final int fileID; + final FaceEmbeddings face; + final ClipEmbedding? clip; + final String? last4Hash; + + FileMl(this.fileID, this.face, {this.clip, this.last4Hash}); + + // toJson + Map toJson() => { + 'fileID': fileID, + 'face': face.toJson(), + 'clip': clip?.toJson(), + 'last4Hash': last4Hash, + }; + // fromJson + factory FileMl.fromJson(Map json) { + return FileMl( + json['fileID'] as int, + FaceEmbeddings.fromJson(json['face'] as Map), + clip: json['clip'] == null + ? null + : ClipEmbedding.fromJson(json['clip'] as Map), + last4Hash: json['last4Hash'] as String?, + ); + } +} diff --git a/mobile/lib/main.dart b/mobile/lib/main.dart index a3b38d8eb..f7f8ab223 100644 --- a/mobile/lib/main.dart +++ b/mobile/lib/main.dart @@ -33,6 +33,8 @@ import 'package:photos/services/local_file_update_service.dart'; import 'package:photos/services/local_sync_service.dart'; import "package:photos/services/location_service.dart"; import "package:photos/services/machine_learning/machine_learning_controller.dart"; +import "package:photos/services/machine_learning/remote_embedding_service.dart"; +import "package:photos/services/machine_learning/semantic_search/remote_embedding.dart"; import 'package:photos/services/machine_learning/semantic_search/semantic_search_service.dart'; import 'package:photos/services/memories_service.dart'; import 'package:photos/services/push_service.dart'; @@ -211,6 +213,7 @@ Future _init(bool isBackground, {String via = ''}) async { LocalFileUpdateService.instance.init(preferences); SearchService.instance.init(); StorageBonusService.instance.init(preferences); + RemoteEmbeddingService.instance.init(preferences); if (!isBackground && Platform.isAndroid && await HomeWidgetService.instance.countHomeWidgets() == 0) { diff --git a/mobile/lib/services/face_ml/face_ml_service.dart b/mobile/lib/services/face_ml/face_ml_service.dart index 1e48bfc1a..875342b7e 100644 --- a/mobile/lib/services/face_ml/face_ml_service.dart +++ b/mobile/lib/services/face_ml/face_ml_service.dart @@ -20,6 +20,7 @@ import "package:photos/face/db.dart"; import "package:photos/face/model/box.dart"; import "package:photos/face/model/detection.dart" as face_detection; import "package:photos/face/model/face.dart"; +import "package:photos/face/model/file_ml.dart"; import "package:photos/face/model/landmark.dart"; import "package:photos/models/file/extensions/file_props.dart"; import "package:photos/models/file/file.dart"; @@ -33,6 +34,7 @@ import "package:photos/services/face_ml/face_embedding/face_embedding_exceptions import 'package:photos/services/face_ml/face_embedding/onnx_face_embedding.dart'; import "package:photos/services/face_ml/face_ml_exceptions.dart"; import "package:photos/services/face_ml/face_ml_result.dart"; +import "package:photos/services/machine_learning/remote_embedding_service.dart"; import "package:photos/services/search_service.dart"; import "package:photos/utils/file_util.dart"; import 'package:photos/utils/image_ml_isolate.dart'; @@ -543,6 +545,13 @@ class FaceMlService { } } _logger.info("inserting ${faces.length} faces for ${result.fileId}"); + await RemoteEmbeddingService.instance.putFaceEmbedding( + enteFile, + FileMl( + enteFile.uploadedFileID!, + FaceEmbeddings(faces, result.mlVersion), + ), + ); await FaceMLDataDB.instance.bulkInsertFaces(faces); } catch (e, s) { _logger.severe( diff --git a/mobile/lib/services/machine_learning/remote_embedding_service.dart b/mobile/lib/services/machine_learning/remote_embedding_service.dart new file mode 100644 index 000000000..7a4cb60c4 --- /dev/null +++ b/mobile/lib/services/machine_learning/remote_embedding_service.dart @@ -0,0 +1,62 @@ +import "dart:async"; +import "dart:convert"; +import "dart:typed_data"; + +import "package:computer/computer.dart"; +import "package:logging/logging.dart"; +import "package:photos/core/network/network.dart"; +import "package:photos/face/model/face.dart"; +import "package:photos/face/model/file_ml.dart"; +import "package:photos/models/file/file.dart"; +import "package:photos/utils/crypto_util.dart"; +import "package:photos/utils/file_download_util.dart"; +import "package:shared_preferences/shared_preferences.dart"; + +class RemoteEmbeddingService { + RemoteEmbeddingService._privateConstructor(); + + static final RemoteEmbeddingService instance = + RemoteEmbeddingService._privateConstructor(); + + static const kEmbeddingsSyncTimeKey = "sync_time_embeddings_v2"; + + final _logger = Logger("RemoteEmbeddingService"); + final _dio = NetworkClient.instance.enteDio; + final _computer = Computer.shared(); + + late SharedPreferences _preferences; + + Completer? _syncStatus; + + void init(SharedPreferences prefs) { + _preferences = prefs; + } + + Future putFaceEmbedding(EnteFile file, FileMl fileML) async { + _logger.info("Pushing embedding for $file"); + final encryptionKey = getFileKey(file); + final embeddingJSON = jsonEncode(fileML.toJson()); + final encryptedEmbedding = await CryptoUtil.encryptChaCha( + utf8.encode(embeddingJSON) as Uint8List, + encryptionKey, + ); + final encryptedData = + CryptoUtil.bin2base64(encryptedEmbedding.encryptedData!); + final header = CryptoUtil.bin2base64(encryptedEmbedding.header!); + try { + final response = await _dio.put( + "/embeddings", + data: { + "fileID": file.uploadedFileID!, + "model": 'onnx-yolo5-mobile', + "encryptedEmbedding": encryptedData, + "decryptionHeader": header, + }, + ); + // final updationTime = response.data["updatedAt"]; + } catch (e, s) { + _logger.severe("Failed to put embedding", e, s); + rethrow; + } + } +}