diff --git a/web/apps/photos/src/services/machineLearning/arcfaceAlignmentService.ts b/web/apps/photos/src/services/face/align.ts similarity index 73% rename from web/apps/photos/src/services/machineLearning/arcfaceAlignmentService.ts rename to web/apps/photos/src/services/face/align.ts index 749da9591..7a3bf7a04 100644 --- a/web/apps/photos/src/services/machineLearning/arcfaceAlignmentService.ts +++ b/web/apps/photos/src/services/face/align.ts @@ -1,31 +1,8 @@ import { Matrix } from "ml-matrix"; import { Point } from "services/face/geom"; -import { - FaceAlignment, - FaceAlignmentMethod, - FaceAlignmentService, - FaceDetection, - Versioned, -} from "services/face/types"; +import { FaceAlignment, FaceDetection } from "services/face/types"; import { getSimilarityTransformation } from "similarity-transformation"; -class ArcfaceAlignmentService implements FaceAlignmentService { - public method: Versioned; - - constructor() { - this.method = { - value: "ArcFace", - version: 1, - }; - } - - public getFaceAlignment(faceDetection: FaceDetection): FaceAlignment { - return getArcfaceAlignment(faceDetection); - } -} - -export default new ArcfaceAlignmentService(); - const ARCFACE_LANDMARKS = [ [38.2946, 51.6963], [73.5318, 51.5014], @@ -43,9 +20,12 @@ const ARC_FACE_5_LANDMARKS = [ [70.7299, 92.2041], ] as Array<[number, number]>; -export function getArcfaceAlignment( - faceDetection: FaceDetection, -): FaceAlignment { +/** + * Compute and return an {@link FaceAlignment} for the given face detection. + * + * @param faceDetection A geometry indicating a face detected in an image. + */ +export const faceAlignment = (faceDetection: FaceDetection): FaceAlignment => { const landmarkCount = faceDetection.landmarks.length; return getFaceAlignmentUsingSimilarityTransform( faceDetection, @@ -54,12 +34,11 @@ export function getArcfaceAlignment( ARCFACE_LANDMARKS_FACE_SIZE, ), ); -} +}; function getFaceAlignmentUsingSimilarityTransform( faceDetection: FaceDetection, alignedLandmarks: Array<[number, number]>, - // alignmentMethod: Versioned ): FaceAlignment { const landmarksMat = new Matrix( faceDetection.landmarks @@ -90,7 +69,6 @@ function getFaceAlignmentUsingSimilarityTransform( simTransform.rotation.get(0, 1), simTransform.rotation.get(0, 0), ); - // log.info({ affineMatrix, meanTranslation, centerMat, center, toMean: simTransform.toMean, fromMean: simTransform.fromMean, size }); return { affineMatrix, diff --git a/web/apps/photos/src/services/face/types.ts b/web/apps/photos/src/services/face/types.ts index 1a5e75a1b..b652657dd 100644 --- a/web/apps/photos/src/services/face/types.ts +++ b/web/apps/photos/src/services/face/types.ts @@ -210,7 +210,6 @@ export interface MLSyncContext { faceDetectionService: FaceDetectionService; faceCropService: FaceCropService; - faceAlignmentService: FaceAlignmentService; faceEmbeddingService: FaceEmbeddingService; blurDetectionService: BlurDetectionService; faceClusteringService: ClusteringService; @@ -272,11 +271,6 @@ export interface FaceCropService { ): Promise; } -export interface FaceAlignmentService { - method: Versioned; - getFaceAlignment(faceDetection: FaceDetection): FaceAlignment; -} - export interface FaceEmbeddingService { method: Versioned; faceSize: number; diff --git a/web/apps/photos/src/services/machineLearning/arcfaceCropService.ts b/web/apps/photos/src/services/machineLearning/arcfaceCropService.ts index 9aff4b606..81f2d4de5 100644 --- a/web/apps/photos/src/services/machineLearning/arcfaceCropService.ts +++ b/web/apps/photos/src/services/machineLearning/arcfaceCropService.ts @@ -8,7 +8,7 @@ import { Versioned, } from "services/face/types"; import { cropWithRotation } from "utils/image"; -import { getArcfaceAlignment } from "./arcfaceAlignmentService"; +import { faceAlignment } from "../face/align"; class ArcFaceCropService implements FaceCropService { public method: Versioned; @@ -24,8 +24,8 @@ class ArcFaceCropService implements FaceCropService { imageBitmap: ImageBitmap, faceDetection: FaceDetection, ): Promise { - const alignedFace = getArcfaceAlignment(faceDetection); - const faceCrop = getFaceCrop(imageBitmap, alignedFace); + const alignment = faceAlignment(faceDetection); + const faceCrop = getFaceCrop(imageBitmap, alignment); return faceCrop; } diff --git a/web/apps/photos/src/services/machineLearning/faceService.ts b/web/apps/photos/src/services/machineLearning/faceService.ts index 99c6bd99e..7183db1f1 100644 --- a/web/apps/photos/src/services/machineLearning/faceService.ts +++ b/web/apps/photos/src/services/machineLearning/faceService.ts @@ -1,5 +1,6 @@ import { openCache } from "@/next/blob-cache"; import log from "@/next/log"; +import { faceAlignment } from "services/face/align"; import mlIDbStorage from "services/face/db"; import { DetectedFace, @@ -103,12 +104,14 @@ class FaceService { fileContext: MLSyncFileContext, ): Promise { const { oldMlFile, newMlFile } = fileContext; + // TODO-ML(MR): + const method = { + value: "ArcFace", + version: 1, + }; if ( !fileContext.newDetection && - !isDifferentOrOld( - oldMlFile?.faceAlignmentMethod, - syncContext.faceAlignmentService.method, - ) && + !isDifferentOrOld(oldMlFile?.faceAlignmentMethod, method) && areFaceIdsSame(newMlFile.faces, oldMlFile?.faces) ) { for (const [index, face] of newMlFile.faces.entries()) { @@ -118,7 +121,10 @@ class FaceService { return; } - newMlFile.faceAlignmentMethod = syncContext.faceAlignmentService.method; + newMlFile.faceAlignmentMethod = { + value: "ArcFace", + version: 1, + }; fileContext.newAlignment = true; const imageBitmap = fileContext.imageBitmap || @@ -126,9 +132,7 @@ class FaceService { // Execute the face alignment calculations for (const face of newMlFile.faces) { - face.alignment = syncContext.faceAlignmentService.getFaceAlignment( - face.detection, - ); + face.alignment = faceAlignment(face.detection); } // Extract face images and convert to Float32Array const faceAlignments = newMlFile.faces.map((f) => f.alignment); diff --git a/web/apps/photos/src/services/machineLearning/machineLearningService.ts b/web/apps/photos/src/services/machineLearning/machineLearningService.ts index efc470c4d..eb3d50558 100644 --- a/web/apps/photos/src/services/machineLearning/machineLearningService.ts +++ b/web/apps/photos/src/services/machineLearning/machineLearningService.ts @@ -17,8 +17,6 @@ import { ClusteringMethod, ClusteringService, Face, - FaceAlignmentMethod, - FaceAlignmentService, FaceCropMethod, FaceCropService, FaceDetection, @@ -38,7 +36,6 @@ import { import { getLocalFiles } from "services/fileService"; import { EnteFile } from "types/file"; import { isInternalUserForML } from "utils/user"; -import arcfaceAlignmentService from "./arcfaceAlignmentService"; import arcfaceCropService from "./arcfaceCropService"; import FaceService from "./faceService"; import hdbscanClusteringService from "./hdbscanClusteringService"; @@ -139,16 +136,6 @@ export class MLFactory { throw Error("Unknon face crop method: " + method); } - public static getFaceAlignmentService( - method: FaceAlignmentMethod, - ): FaceAlignmentService { - if (method === "ArcFace") { - return arcfaceAlignmentService; - } - - throw Error("Unknon face alignment method: " + method); - } - public static getBlurDetectionService( method: BlurDetectionMethod, ): BlurDetectionService { @@ -187,7 +174,6 @@ export class LocalMLSyncContext implements MLSyncContext { public faceDetectionService: FaceDetectionService; public faceCropService: FaceCropService; - public faceAlignmentService: FaceAlignmentService; public blurDetectionService: BlurDetectionService; public faceEmbeddingService: FaceEmbeddingService; public faceClusteringService: ClusteringService; @@ -225,8 +211,6 @@ export class LocalMLSyncContext implements MLSyncContext { this.faceDetectionService = MLFactory.getFaceDetectionService("YoloFace"); this.faceCropService = MLFactory.getFaceCropService("ArcFace"); - this.faceAlignmentService = - MLFactory.getFaceAlignmentService("ArcFace"); this.blurDetectionService = MLFactory.getBlurDetectionService("Laplacian"); this.faceEmbeddingService =