diff --git a/mobile/lib/face/db.dart b/mobile/lib/face/db.dart index 7a8bfa459..21d965e2b 100644 --- a/mobile/lib/face/db.dart +++ b/mobile/lib/face/db.dart @@ -53,6 +53,8 @@ class FaceMLDataDB { await db.execute(createClusterPersonTable); await db.execute(createClusterSummaryTable); await db.execute(createNotPersonFeedbackTable); + await db.execute(createFaceClustersTable); + await db.execute(fcClusterIDIndex); } // bulkInsertFaces inserts the faces in the database in batches of 1000. @@ -96,12 +98,10 @@ class FaceMLDataDB { for (final entry in batch) { final faceID = entry.key; final personID = entry.value; - - batchUpdate.update( - facesTable, - {faceClusterId: personID}, - where: '$faceIDColumn = ? AND $faceClusterId IS NULL', - whereArgs: [faceID], + batchUpdate.insert( + faceClustersTable, + {fcClusterID: personID, fcFaceId: faceID}, + conflictAlgorithm: ConflictAlgorithm.replace, ); } @@ -243,12 +243,19 @@ class FaceMLDataDB { } } if (clusterID != null) { - final clusterIDs = [clusterID]; - final List> faceMaps = await db.rawQuery( - 'SELECT * FROM $facesTable where $faceClusterId IN (${clusterIDs.join(",")}) AND $fileIDColumn = $recentFileID ', + final List> faceMaps = await db.query( + faceClustersTable, + columns: [fcFaceId], + where: '$fcClusterID = ?', + whereArgs: [clusterID], ); - if (faceMaps.isNotEmpty) { - return mapRowToFace(faceMaps.first); + final List? faces = await getFacesForGivenFileID(recentFileID); + if (faces != null) { + for (final face in faces) { + if (faceMaps.any((element) => element[fcFaceId] == face.faceID)) { + return face; + } + } } } if (personID == null && clusterID == null) { @@ -296,11 +303,11 @@ class FaceMLDataDB { ) async { final db = await instance.database; final List> maps = await db.rawQuery( - 'SELECT $faceIDColumn, $faceClusterId FROM $facesTable where $faceIDColumn IN (${faceIds.map((id) => "'$id'").join(",")})', + 'SELECT $fcFaceId, $fcClusterID FROM $faceClustersTable where $fcFaceId IN (${faceIds.map((id) => "'$id'").join(",")})', ); final Map result = {}; for (final map in maps) { - result[map[faceIDColumn] as String] = map[faceClusterId] as int?; + result[map[fcFaceId] as String] = map[fcClusterID] as int?; } return result; } @@ -309,13 +316,15 @@ class FaceMLDataDB { final Map> result = {}; final db = await instance.database; final List> maps = await db.rawQuery( - 'SELECT $faceClusterId, $fileIDColumn FROM $facesTable where $faceClusterId IS NOT NULL', + 'SELECT $fcClusterID, $fcFaceId FROM $faceClustersTable', ); for (final map in maps) { - final personID = map[faceClusterId] as int; - final fileID = map[fileIDColumn] as int; - result[fileID] = (result[fileID] ?? {})..add(personID); + final clusterID = map[fcClusterID] as int; + final faceID = map[fcFaceId] as String; + final x = faceID.split('_').first; + final fileID = int.parse(x); + result[fileID] = (result[fileID] ?? {})..add(clusterID); } return result; } @@ -331,18 +340,17 @@ class FaceMLDataDB { for (final map in faceIDToPersonID.entries) { final faceID = map.key; final clusterID = map.value; - batch.update( - facesTable, - {faceClusterId: clusterID}, - where: '$faceIDColumn = ?', - whereArgs: [faceID], + batch.insert( + faceClustersTable, + {fcFaceId: faceID, fcClusterID: clusterID}, + conflictAlgorithm: ConflictAlgorithm.replace, ); } // Commit the batch await batch.commit(noResult: true); } - /// Returns a map of faceID to record of faceClusterID and faceEmbeddingBlob + /// Returns a map of faceID to record of clusterId and faceEmbeddingBlob /// /// Only selects faces with score greater than [minScore] and blur score greater than [minClarity] Future> getFaceEmbeddingMap({ @@ -372,10 +380,15 @@ class FaceMLDataDB { if (maps.isEmpty) { break; } + final List faceIds = []; + for (final map in maps) { + faceIds.add(map[faceIDColumn] as String); + } + final faceIdToClusterId = await getFaceIdsToClusterIds(faceIds); for (final map in maps) { final faceID = map[faceIDColumn] as String; result[faceID] = - (map[faceClusterId] as int?, map[faceEmbeddingBlob] as Uint8List); + (faceIdToClusterId[faceID], map[faceEmbeddingBlob] as Uint8List); } if (result.length >= maxFaces) { break; @@ -435,10 +448,9 @@ class FaceMLDataDB { Future resetClusterIDs() async { final db = await instance.database; - await db.update( - facesTable, - {faceClusterId: null}, - ); + await db.rawQuery(dropFaceClustersTable); + await db.rawQuery(createFaceClustersTable); + await db.rawQuery(fcClusterIDIndex); } Future insert(Person p, int cluserID) async { @@ -514,16 +526,17 @@ class FaceMLDataDB { final db = instance.database; return db.then((db) async { final List> maps = await db.rawQuery( - 'SELECT $clusterPersonTable.$cluserIDColumn, $fileIDColumn FROM $facesTable ' + 'SELECT $clusterPersonTable.$cluserIDColumn, $fcFaceId FROM $faceClustersTable ' 'INNER JOIN $clusterPersonTable ' - 'ON $facesTable.$faceClusterId = $clusterPersonTable.$cluserIDColumn ' + 'ON $faceClustersTable.$fcClusterID = $clusterPersonTable.$cluserIDColumn ' 'WHERE $clusterPersonTable.$personIdColumn = ?', [personID], ); final Map> result = {}; for (final map in maps) { final clusterID = map[cluserIDColumn] as int; - final fileID = map[fileIDColumn] as int; + final String faceID = map[fcFaceId] as String; + final fileID = int.parse(faceID.split('_').first); result[fileID] = (result[fileID] ?? {})..add(clusterID); } return result; @@ -664,21 +677,24 @@ class FaceMLDataDB { Future removeFilesFromPerson(List files, Person p) async { final db = await instance.database; - final result = await db.rawQuery( - 'SELECT $faceIDColumn FROM $facesTable LEFT JOIN $clusterPersonTable ' - 'ON $facesTable.$faceClusterId = $clusterPersonTable.$cluserIDColumn ' - 'WHERE $clusterPersonTable.$personIdColumn = ? AND $facesTable.$fileIDColumn IN (${files.map((e) => e.uploadedFileID).join(",")})', + final faceIdsResult = await db.rawQuery( + 'SELECT $fcFaceId FROM $faceClustersTable LEFT JOIN $clusterPersonTable ' + 'ON $faceClustersTable.$fcClusterID = $clusterPersonTable.$cluserIDColumn ' + 'WHERE $clusterPersonTable.$personIdColumn = ?', [p.remoteID], ); - // get max clusterID - final maxRows = - await db.rawQuery('SELECT max($faceClusterId) from $facesTable'); - int maxClusterID = maxRows.first.values.first as int; + final Set fileIds = {}; + for (final enteFile in files) { + fileIds.add(enteFile.uploadedFileID.toString()); + } + int maxClusterID = DateTime.now().millisecondsSinceEpoch; final Map faceIDToClusterID = {}; - for (final faceRow in result) { - final faceID = faceRow[faceIDColumn] as String; - faceIDToClusterID[faceID] = maxClusterID + 1; - maxClusterID = maxClusterID + 1; + for (final row in faceIdsResult) { + final faceID = row[fcFaceId] as String; + if (fileIds.contains(faceID.split('_').first)) { + maxClusterID += 1; + faceIDToClusterID[faceID] = maxClusterID; + } } await forceUpdateClusterIds(faceIDToClusterID); } @@ -688,17 +704,23 @@ class FaceMLDataDB { int clusterID, ) async { final db = await instance.database; - final result = await db.rawQuery( - 'SELECT $faceIDColumn FROM $facesTable ' - 'WHERE $facesTable.$faceClusterId = ? AND $facesTable.$fileIDColumn IN (${files.map((e) => e.uploadedFileID).join(",")})', + final faceIdsResult = await db.rawQuery( + 'SELECT $fcFaceId FROM $faceClustersTable ' + 'WHERE $faceClustersTable.$fcClusterID = ?', [clusterID], ); - final Map faceIDToClusterID = {}; + final Set fileIds = {}; + for (final enteFile in files) { + fileIds.add(enteFile.uploadedFileID.toString()); + } int maxClusterID = DateTime.now().millisecondsSinceEpoch; - for (final faceRow in result) { - maxClusterID += 1; - final faceID = faceRow[faceIDColumn] as String; - faceIDToClusterID[faceID] = maxClusterID; + final Map faceIDToClusterID = {}; + for (final row in faceIdsResult) { + final faceID = row[fcFaceId] as String; + if (fileIds.contains(faceID.split('_').first)) { + maxClusterID += 1; + faceIDToClusterID[faceID] = maxClusterID; + } } await forceUpdateClusterIds(faceIDToClusterID); } diff --git a/mobile/lib/face/db_fields.dart b/mobile/lib/face/db_fields.dart index 29be20b6a..a4e249a97 100644 --- a/mobile/lib/face/db_fields.dart +++ b/mobile/lib/face/db_fields.dart @@ -35,13 +35,15 @@ const fcFaceId = 'face_id'; // fcClusterId & fcFaceId are the primary keys and fcClusterId is a foreign key to faces table const createFaceClustersTable = ''' CREATE TABLE IF NOT EXISTS $faceClustersTable ( - $fcClusterID INTEGER NOT NULL, $fcFaceId TEXT NOT NULL, - PRIMARY KEY($fcClusterID, $fcFaceId), + $fcClusterID INTEGER NOT NULL, + PRIMARY KEY($fcFaceId), FOREIGN KEY($fcFaceId) REFERENCES $facesTable($faceIDColumn) ); '''; - +// -- Creating a non-unique index on clusterID for query optimization +const fcClusterIDIndex = + '''CREATE INDEX IF NOT EXISTS idx_fcClusterID ON faceClustersTable(fcClusterID);'''; const dropFaceClustersTable = 'DROP TABLE IF EXISTS $faceClustersTable'; //##endregion diff --git a/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart b/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart index d8c003afb..746112b82 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart @@ -80,7 +80,7 @@ class FaceMlService { bool isInitialized = false; bool isImageIndexRunning = false; - int kParallelism = 15; + int kParallelism = 100; Future init({bool initializeImageMlIsolate = false}) async { return _initLock.synchronized(() async { @@ -524,6 +524,7 @@ class FaceMlService { try { final EnteWatch? w = kDebugMode ? EnteWatch("face_em_fetch") : null; w?.start(); + w?.log('starting remote fetch for ${fileIds.length} files'); final res = await RemoteFileMLService.instance.getFilessEmbedding(fileIds); w?.logAndReset('fetched ${res.mlData.length} embeddings'); @@ -1222,5 +1223,4 @@ class FaceMlService { return indexedFileIds.containsKey(id) && indexedFileIds[id]! >= faceMlVersion; } - } diff --git a/mobile/lib/ui/viewer/search_tab/search_tab.dart b/mobile/lib/ui/viewer/search_tab/search_tab.dart index f00c981af..d0db69690 100644 --- a/mobile/lib/ui/viewer/search_tab/search_tab.dart +++ b/mobile/lib/ui/viewer/search_tab/search_tab.dart @@ -169,7 +169,8 @@ class _AllSearchSectionsState extends State { curve: Curves.easeOut, ); } else if (snapshot.hasError) { - _logger.severe('Failed to load sections: ', snapshot.error); + _logger.severe('Failed to load sections: ', snapshot.error, + snapshot.stackTrace,); if (kDebugMode) { return Padding( padding: const EdgeInsets.only(bottom: 72),