diff --git a/web/apps/photos/src/services/machineLearning/arcfaceAlignmentService.ts b/web/apps/photos/src/services/machineLearning/arcfaceAlignmentService.ts index 1d7deac5e..f23a065c8 100644 --- a/web/apps/photos/src/services/machineLearning/arcfaceAlignmentService.ts +++ b/web/apps/photos/src/services/machineLearning/arcfaceAlignmentService.ts @@ -1,3 +1,5 @@ +import { Matrix } from "ml-matrix"; +import { Point } from "services/ml/geom"; import { FaceAlignment, FaceAlignmentMethod, @@ -5,7 +7,7 @@ import { FaceDetection, Versioned, } from "services/ml/types"; -import { getArcfaceAlignment } from "utils/machineLearning/faceAlign"; +import { getSimilarityTransformation } from "similarity-transformation"; class ArcfaceAlignmentService implements FaceAlignmentService { public method: Versioned; @@ -23,3 +25,86 @@ class ArcfaceAlignmentService implements FaceAlignmentService { } export default new ArcfaceAlignmentService(); + +const ARCFACE_LANDMARKS = [ + [38.2946, 51.6963], + [73.5318, 51.5014], + [56.0252, 71.7366], + [56.1396, 92.2848], +] as Array<[number, number]>; + +const ARCFACE_LANDMARKS_FACE_SIZE = 112; + +const ARC_FACE_5_LANDMARKS = [ + [38.2946, 51.6963], + [73.5318, 51.5014], + [56.0252, 71.7366], + [41.5493, 92.3655], + [70.7299, 92.2041], +] as Array<[number, number]>; + +export function getArcfaceAlignment( + faceDetection: FaceDetection, +): FaceAlignment { + const landmarkCount = faceDetection.landmarks.length; + return getFaceAlignmentUsingSimilarityTransform( + faceDetection, + normalizeLandmarks( + landmarkCount === 5 ? ARC_FACE_5_LANDMARKS : ARCFACE_LANDMARKS, + ARCFACE_LANDMARKS_FACE_SIZE, + ), + ); +} + +function getFaceAlignmentUsingSimilarityTransform( + faceDetection: FaceDetection, + alignedLandmarks: Array<[number, number]>, + // alignmentMethod: Versioned +): FaceAlignment { + const landmarksMat = new Matrix( + faceDetection.landmarks + .map((p) => [p.x, p.y]) + .slice(0, alignedLandmarks.length), + ).transpose(); + const alignedLandmarksMat = new Matrix(alignedLandmarks).transpose(); + + const simTransform = getSimilarityTransformation( + landmarksMat, + alignedLandmarksMat, + ); + + const RS = Matrix.mul(simTransform.rotation, simTransform.scale); + const TR = simTransform.translation; + + const affineMatrix = [ + [RS.get(0, 0), RS.get(0, 1), TR.get(0, 0)], + [RS.get(1, 0), RS.get(1, 1), TR.get(1, 0)], + [0, 0, 1], + ]; + + const size = 1 / simTransform.scale; + const meanTranslation = simTransform.toMean.sub(0.5).mul(size); + const centerMat = simTransform.fromMean.sub(meanTranslation); + const center = new Point(centerMat.get(0, 0), centerMat.get(1, 0)); + const rotation = -Math.atan2( + simTransform.rotation.get(0, 1), + simTransform.rotation.get(0, 0), + ); + // log.info({ affineMatrix, meanTranslation, centerMat, center, toMean: simTransform.toMean, fromMean: simTransform.fromMean, size }); + + return { + affineMatrix, + center, + size, + rotation, + }; +} + +function normalizeLandmarks( + landmarks: Array<[number, number]>, + faceSize: number, +): Array<[number, number]> { + return landmarks.map((landmark) => + landmark.map((p) => p / faceSize), + ) as Array<[number, number]>; +} diff --git a/web/apps/photos/src/services/machineLearning/arcfaceCropService.ts b/web/apps/photos/src/services/machineLearning/arcfaceCropService.ts index 77ec03b6b..2075d6acf 100644 --- a/web/apps/photos/src/services/machineLearning/arcfaceCropService.ts +++ b/web/apps/photos/src/services/machineLearning/arcfaceCropService.ts @@ -1,4 +1,4 @@ -import { Box } from "services/ml/geom"; +import { Box, enlargeBox } from "services/ml/geom"; import { FaceAlignment, FaceCrop, @@ -9,8 +9,7 @@ import { Versioned, } from "services/ml/types"; import { cropWithRotation } from "utils/image"; -import { getArcfaceAlignment } from "utils/machineLearning/faceAlign"; -import { enlargeBox } from "services/ml/geom"; +import { getArcfaceAlignment } from "./arcfaceAlignmentService"; class ArcFaceCropService implements FaceCropService { public method: Versioned; diff --git a/web/apps/photos/src/utils/machineLearning/faceAlign.ts b/web/apps/photos/src/utils/machineLearning/faceAlign.ts index 3e6846ff8..e69de29bb 100644 --- a/web/apps/photos/src/utils/machineLearning/faceAlign.ts +++ b/web/apps/photos/src/utils/machineLearning/faceAlign.ts @@ -1,87 +0,0 @@ -import { Matrix } from "ml-matrix"; -import { Point } from "services/ml/geom"; -import { FaceAlignment, FaceDetection } from "services/ml/types"; -import { getSimilarityTransformation } from "similarity-transformation"; - -const ARCFACE_LANDMARKS = [ - [38.2946, 51.6963], - [73.5318, 51.5014], - [56.0252, 71.7366], - [56.1396, 92.2848], -] as Array<[number, number]>; - -const ARCFACE_LANDMARKS_FACE_SIZE = 112; - -const ARC_FACE_5_LANDMARKS = [ - [38.2946, 51.6963], - [73.5318, 51.5014], - [56.0252, 71.7366], - [41.5493, 92.3655], - [70.7299, 92.2041], -] as Array<[number, number]>; - -export function getArcfaceAlignment( - faceDetection: FaceDetection, -): FaceAlignment { - const landmarkCount = faceDetection.landmarks.length; - return getFaceAlignmentUsingSimilarityTransform( - faceDetection, - normalizeLandmarks( - landmarkCount === 5 ? ARC_FACE_5_LANDMARKS : ARCFACE_LANDMARKS, - ARCFACE_LANDMARKS_FACE_SIZE, - ), - ); -} - -function getFaceAlignmentUsingSimilarityTransform( - faceDetection: FaceDetection, - alignedLandmarks: Array<[number, number]>, - // alignmentMethod: Versioned -): FaceAlignment { - const landmarksMat = new Matrix( - faceDetection.landmarks - .map((p) => [p.x, p.y]) - .slice(0, alignedLandmarks.length), - ).transpose(); - const alignedLandmarksMat = new Matrix(alignedLandmarks).transpose(); - - const simTransform = getSimilarityTransformation( - landmarksMat, - alignedLandmarksMat, - ); - - const RS = Matrix.mul(simTransform.rotation, simTransform.scale); - const TR = simTransform.translation; - - const affineMatrix = [ - [RS.get(0, 0), RS.get(0, 1), TR.get(0, 0)], - [RS.get(1, 0), RS.get(1, 1), TR.get(1, 0)], - [0, 0, 1], - ]; - - const size = 1 / simTransform.scale; - const meanTranslation = simTransform.toMean.sub(0.5).mul(size); - const centerMat = simTransform.fromMean.sub(meanTranslation); - const center = new Point(centerMat.get(0, 0), centerMat.get(1, 0)); - const rotation = -Math.atan2( - simTransform.rotation.get(0, 1), - simTransform.rotation.get(0, 0), - ); - // log.info({ affineMatrix, meanTranslation, centerMat, center, toMean: simTransform.toMean, fromMean: simTransform.fromMean, size }); - - return { - affineMatrix, - center, - size, - rotation, - }; -} - -function normalizeLandmarks( - landmarks: Array<[number, number]>, - faceSize: number, -): Array<[number, number]> { - return landmarks.map((landmark) => - landmark.map((p) => p / faceSize), - ) as Array<[number, number]>; -}