diff --git a/mobile/lib/face/db.dart b/mobile/lib/face/db.dart index 502f48dee..1e5c01343 100644 --- a/mobile/lib/face/db.dart +++ b/mobile/lib/face/db.dart @@ -181,6 +181,32 @@ class FaceMLDataDB { return maps.map((e) => e[faceEmbeddingBlob] as Uint8List); } + /** Loads the face-embedding blobs for all of the given clusters with a single raw query and returns them grouped by cluster ID. Only clusters for which the query yields at least one row get an entry in the returned map. When [limit] is non-null it is appended as a SQL LIMIT to the whole joined result, so it caps the total number of rows, not the rows per cluster. NOTE(review): the cluster IDs and the limit are interpolated into the SQL text rather than bound as query arguments, and an empty [clusterIDs] produces `IN ()` - confirm that callers only pass trusted integer IDs and that the empty-list form is acceptable for the database in use. */ Future>> getFaceEmbeddingsForClusters( + Iterable clusterIDs, { + int? limit, + }) async { + final db = await instance.database; + final Map> result = {}; + + final selectQuery = ''' + SELECT fc.$fcClusterID, fe.$faceEmbeddingBlob + FROM $faceClustersTable fc + INNER JOIN $facesTable fe ON fc.$fcFaceId = fe.$faceIDColumn + WHERE fc.$fcClusterID IN (${clusterIDs.join(',')}) + ${limit != null ? 'LIMIT $limit' : ''} + '''; + + final List> maps = await db.rawQuery(selectQuery); + + for (final map in maps) { + final clusterID = map[fcClusterID] as int; + final faceEmbedding = map[faceEmbeddingBlob] as Uint8List; + result.putIfAbsent(clusterID, () => []).add(faceEmbedding); + } + + return result; + } + Future getCoverFaceForPerson({ required int recentFileID, String? 
personID, @@ -668,9 +694,11 @@ class FaceMLDataDB { await db.execute(deletePersonTable); await db.execute(dropClusterPersonTable); await db.execute(dropNotPersonFeedbackTable); + await db.execute(dropClusterSummaryTable); await db.execute(createPersonTable); await db.execute(createClusterPersonTable); await db.execute(createNotPersonFeedbackTable); + await db.execute(createClusterSummaryTable); } Future removeFilesFromPerson(List files, Person p) async { diff --git a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart index 745c73245..e9c19a6fb 100644 --- a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart +++ b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart @@ -410,15 +410,44 @@ class ClusterFeedbackService { allClusterIdsToCountMap[b]!.compareTo(allClusterIdsToCountMap[a]!), ); int indexedInCurrentRun = 0; + final EnteWatch? w = kDebugMode ? EnteWatch("computeAvg") : null; + w?.start(); + w?.log( + 'reading embeddings for $maxClusterInCurrentRun or ${sortedClusterIDs.length} clusters', + ); + final int maxEmbeddingToRead = 10000; + int currentPendingRead = 0; + List clusterIdsToRead = []; for (final clusterID in sortedClusterIDs) { if (maxClusterInCurrentRun-- <= 0) { break; } - indexedInCurrentRun++; + if (currentPendingRead == 0) { + currentPendingRead = allClusterIdsToCountMap[clusterID] ?? 0; + clusterIdsToRead.add(clusterID); + } else { + if ((currentPendingRead + allClusterIdsToCountMap[clusterID]!) 
< + maxEmbeddingToRead) { + clusterIdsToRead.add(clusterID); + currentPendingRead += allClusterIdsToCountMap[clusterID]!; + } else { + break; + } + } + } + + final Map> clusterEmbeddings = await FaceMLDataDB + .instance + .getFaceEmbeddingsForClusters(clusterIdsToRead); + + w?.logAndReset( + 'read $currentPendingRead embeddings for ${clusterEmbeddings.length} clusters', + ); + + for (final clusterID in clusterEmbeddings.keys) { late List avg; - final Iterable embedings = - await FaceMLDataDB.instance.getFaceEmbeddingsForCluster(clusterID); + final Iterable embedings = clusterEmbeddings[clusterID]!; final List sum = List.filled(192, 0); for (final embedding in embedings) { final data = EVector.fromBuffer(embedding).values; @@ -431,12 +460,14 @@ class ClusterFeedbackService { updatesForClusterSummary[clusterID] = (avgEmbeedingBuffer, embedings.length); // store the intermediate updates + indexedInCurrentRun++; if (updatesForClusterSummary.length > 100) { await faceMlDb.clusterSummaryUpdate(updatesForClusterSummary); updatesForClusterSummary.clear(); if (kDebugMode) { _logger.info( - 'getUpdateClusterAvg $indexedInCurrentRun clusters in current one'); + 'getUpdateClusterAvg $indexedInCurrentRun clusters in current one', + ); } } clusterAvg[clusterID] = avg; @@ -444,6 +475,7 @@ class ClusterFeedbackService { if (updatesForClusterSummary.isNotEmpty) { await faceMlDb.clusterSummaryUpdate(updatesForClusterSummary); } + w?.logAndReset('done computing avg '); _logger.info('end getUpdateClusterAvg for ${clusterAvg.length} clusters'); return clusterAvg;