浏览代码

Clean alignment

Manav Rathi 1 年之前
父节点
当前提交
fbebbd3583

+ 8 - 30
web/apps/photos/src/services/machineLearning/arcfaceAlignmentService.ts → web/apps/photos/src/services/face/align.ts

@@ -1,31 +1,8 @@
 import { Matrix } from "ml-matrix";
 import { Matrix } from "ml-matrix";
 import { Point } from "services/face/geom";
 import { Point } from "services/face/geom";
-import {
-    FaceAlignment,
-    FaceAlignmentMethod,
-    FaceAlignmentService,
-    FaceDetection,
-    Versioned,
-} from "services/face/types";
+import { FaceAlignment, FaceDetection } from "services/face/types";
 import { getSimilarityTransformation } from "similarity-transformation";
 import { getSimilarityTransformation } from "similarity-transformation";
 
 
-class ArcfaceAlignmentService implements FaceAlignmentService {
-    public method: Versioned<FaceAlignmentMethod>;
-
-    constructor() {
-        this.method = {
-            value: "ArcFace",
-            version: 1,
-        };
-    }
-
-    public getFaceAlignment(faceDetection: FaceDetection): FaceAlignment {
-        return getArcfaceAlignment(faceDetection);
-    }
-}
-
-export default new ArcfaceAlignmentService();
-
 const ARCFACE_LANDMARKS = [
 const ARCFACE_LANDMARKS = [
     [38.2946, 51.6963],
     [38.2946, 51.6963],
     [73.5318, 51.5014],
     [73.5318, 51.5014],
@@ -43,9 +20,12 @@ const ARC_FACE_5_LANDMARKS = [
     [70.7299, 92.2041],
     [70.7299, 92.2041],
 ] as Array<[number, number]>;
 ] as Array<[number, number]>;
 
 
-export function getArcfaceAlignment(
-    faceDetection: FaceDetection,
-): FaceAlignment {
+/**
+ * Compute and return a {@link FaceAlignment} for the given face detection.
+ *
+ * @param faceDetection A geometry indicating a face detected in an image.
+ */
+export const faceAlignment = (faceDetection: FaceDetection): FaceAlignment => {
     const landmarkCount = faceDetection.landmarks.length;
     const landmarkCount = faceDetection.landmarks.length;
     return getFaceAlignmentUsingSimilarityTransform(
     return getFaceAlignmentUsingSimilarityTransform(
         faceDetection,
         faceDetection,
@@ -54,12 +34,11 @@ export function getArcfaceAlignment(
             ARCFACE_LANDMARKS_FACE_SIZE,
             ARCFACE_LANDMARKS_FACE_SIZE,
         ),
         ),
     );
     );
-}
+};
 
 
 function getFaceAlignmentUsingSimilarityTransform(
 function getFaceAlignmentUsingSimilarityTransform(
     faceDetection: FaceDetection,
     faceDetection: FaceDetection,
     alignedLandmarks: Array<[number, number]>,
     alignedLandmarks: Array<[number, number]>,
-    // alignmentMethod: Versioned<FaceAlignmentMethod>
 ): FaceAlignment {
 ): FaceAlignment {
     const landmarksMat = new Matrix(
     const landmarksMat = new Matrix(
         faceDetection.landmarks
         faceDetection.landmarks
@@ -90,7 +69,6 @@ function getFaceAlignmentUsingSimilarityTransform(
         simTransform.rotation.get(0, 1),
         simTransform.rotation.get(0, 1),
         simTransform.rotation.get(0, 0),
         simTransform.rotation.get(0, 0),
     );
     );
-    // log.info({ affineMatrix, meanTranslation, centerMat, center, toMean: simTransform.toMean, fromMean: simTransform.fromMean, size });
 
 
     return {
     return {
         affineMatrix,
         affineMatrix,

+ 0 - 6
web/apps/photos/src/services/face/types.ts

@@ -210,7 +210,6 @@ export interface MLSyncContext {
 
 
     faceDetectionService: FaceDetectionService;
     faceDetectionService: FaceDetectionService;
     faceCropService: FaceCropService;
     faceCropService: FaceCropService;
-    faceAlignmentService: FaceAlignmentService;
     faceEmbeddingService: FaceEmbeddingService;
     faceEmbeddingService: FaceEmbeddingService;
     blurDetectionService: BlurDetectionService;
     blurDetectionService: BlurDetectionService;
     faceClusteringService: ClusteringService;
     faceClusteringService: ClusteringService;
@@ -272,11 +271,6 @@ export interface FaceCropService {
     ): Promise<FaceCrop>;
     ): Promise<FaceCrop>;
 }
 }
 
 
-export interface FaceAlignmentService {
-    method: Versioned<FaceAlignmentMethod>;
-    getFaceAlignment(faceDetection: FaceDetection): FaceAlignment;
-}
-
 export interface FaceEmbeddingService {
 export interface FaceEmbeddingService {
     method: Versioned<FaceEmbeddingMethod>;
     method: Versioned<FaceEmbeddingMethod>;
     faceSize: number;
     faceSize: number;

+ 3 - 3
web/apps/photos/src/services/machineLearning/arcfaceCropService.ts

@@ -8,7 +8,7 @@ import {
     Versioned,
     Versioned,
 } from "services/face/types";
 } from "services/face/types";
 import { cropWithRotation } from "utils/image";
 import { cropWithRotation } from "utils/image";
-import { getArcfaceAlignment } from "./arcfaceAlignmentService";
+import { faceAlignment } from "../face/align";
 
 
 class ArcFaceCropService implements FaceCropService {
 class ArcFaceCropService implements FaceCropService {
     public method: Versioned<FaceCropMethod>;
     public method: Versioned<FaceCropMethod>;
@@ -24,8 +24,8 @@ class ArcFaceCropService implements FaceCropService {
         imageBitmap: ImageBitmap,
         imageBitmap: ImageBitmap,
         faceDetection: FaceDetection,
         faceDetection: FaceDetection,
     ): Promise<FaceCrop> {
     ): Promise<FaceCrop> {
-        const alignedFace = getArcfaceAlignment(faceDetection);
-        const faceCrop = getFaceCrop(imageBitmap, alignedFace);
+        const alignment = faceAlignment(faceDetection);
+        const faceCrop = getFaceCrop(imageBitmap, alignment);
 
 
         return faceCrop;
         return faceCrop;
     }
     }

+ 12 - 8
web/apps/photos/src/services/machineLearning/faceService.ts

@@ -1,5 +1,6 @@
 import { openCache } from "@/next/blob-cache";
 import { openCache } from "@/next/blob-cache";
 import log from "@/next/log";
 import log from "@/next/log";
+import { faceAlignment } from "services/face/align";
 import mlIDbStorage from "services/face/db";
 import mlIDbStorage from "services/face/db";
 import {
 import {
     DetectedFace,
     DetectedFace,
@@ -103,12 +104,14 @@ class FaceService {
         fileContext: MLSyncFileContext,
         fileContext: MLSyncFileContext,
     ): Promise<Float32Array> {
     ): Promise<Float32Array> {
         const { oldMlFile, newMlFile } = fileContext;
         const { oldMlFile, newMlFile } = fileContext;
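+        // TODO-ML(MR): The alignment method is hardcoded here, and the same
+        // literal is repeated below when setting newMlFile.faceAlignmentMethod.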
+        const method = {
+            value: "ArcFace",
+            version: 1,
+        };
         if (
         if (
             !fileContext.newDetection &&
             !fileContext.newDetection &&
-            !isDifferentOrOld(
-                oldMlFile?.faceAlignmentMethod,
-                syncContext.faceAlignmentService.method,
-            ) &&
+            !isDifferentOrOld(oldMlFile?.faceAlignmentMethod, method) &&
             areFaceIdsSame(newMlFile.faces, oldMlFile?.faces)
             areFaceIdsSame(newMlFile.faces, oldMlFile?.faces)
         ) {
         ) {
             for (const [index, face] of newMlFile.faces.entries()) {
             for (const [index, face] of newMlFile.faces.entries()) {
@@ -118,7 +121,10 @@ class FaceService {
             return;
             return;
         }
         }
 
 
-        newMlFile.faceAlignmentMethod = syncContext.faceAlignmentService.method;
+        newMlFile.faceAlignmentMethod = {
+            value: "ArcFace",
+            version: 1,
+        };
         fileContext.newAlignment = true;
         fileContext.newAlignment = true;
         const imageBitmap =
         const imageBitmap =
             fileContext.imageBitmap ||
             fileContext.imageBitmap ||
@@ -126,9 +132,7 @@ class FaceService {
 
 
         // Execute the face alignment calculations
         // Execute the face alignment calculations
         for (const face of newMlFile.faces) {
         for (const face of newMlFile.faces) {
-            face.alignment = syncContext.faceAlignmentService.getFaceAlignment(
-                face.detection,
-            );
+            face.alignment = faceAlignment(face.detection);
         }
         }
         // Extract face images and convert to Float32Array
         // Extract face images and convert to Float32Array
         const faceAlignments = newMlFile.faces.map((f) => f.alignment);
         const faceAlignments = newMlFile.faces.map((f) => f.alignment);

+ 0 - 16
web/apps/photos/src/services/machineLearning/machineLearningService.ts

@@ -17,8 +17,6 @@ import {
     ClusteringMethod,
     ClusteringMethod,
     ClusteringService,
     ClusteringService,
     Face,
     Face,
-    FaceAlignmentMethod,
-    FaceAlignmentService,
     FaceCropMethod,
     FaceCropMethod,
     FaceCropService,
     FaceCropService,
     FaceDetection,
     FaceDetection,
@@ -38,7 +36,6 @@ import {
 import { getLocalFiles } from "services/fileService";
 import { getLocalFiles } from "services/fileService";
 import { EnteFile } from "types/file";
 import { EnteFile } from "types/file";
 import { isInternalUserForML } from "utils/user";
 import { isInternalUserForML } from "utils/user";
-import arcfaceAlignmentService from "./arcfaceAlignmentService";
 import arcfaceCropService from "./arcfaceCropService";
 import arcfaceCropService from "./arcfaceCropService";
 import FaceService from "./faceService";
 import FaceService from "./faceService";
 import hdbscanClusteringService from "./hdbscanClusteringService";
 import hdbscanClusteringService from "./hdbscanClusteringService";
@@ -139,16 +136,6 @@ export class MLFactory {
         throw Error("Unknon face crop method: " + method);
         throw Error("Unknon face crop method: " + method);
     }
     }
 
 
-    public static getFaceAlignmentService(
-        method: FaceAlignmentMethod,
-    ): FaceAlignmentService {
-        if (method === "ArcFace") {
-            return arcfaceAlignmentService;
-        }
-
-        throw Error("Unknon face alignment method: " + method);
-    }
-
     public static getBlurDetectionService(
     public static getBlurDetectionService(
         method: BlurDetectionMethod,
         method: BlurDetectionMethod,
     ): BlurDetectionService {
     ): BlurDetectionService {
@@ -187,7 +174,6 @@ export class LocalMLSyncContext implements MLSyncContext {
 
 
     public faceDetectionService: FaceDetectionService;
     public faceDetectionService: FaceDetectionService;
     public faceCropService: FaceCropService;
     public faceCropService: FaceCropService;
-    public faceAlignmentService: FaceAlignmentService;
     public blurDetectionService: BlurDetectionService;
     public blurDetectionService: BlurDetectionService;
     public faceEmbeddingService: FaceEmbeddingService;
     public faceEmbeddingService: FaceEmbeddingService;
     public faceClusteringService: ClusteringService;
     public faceClusteringService: ClusteringService;
@@ -225,8 +211,6 @@ export class LocalMLSyncContext implements MLSyncContext {
         this.faceDetectionService =
         this.faceDetectionService =
             MLFactory.getFaceDetectionService("YoloFace");
             MLFactory.getFaceDetectionService("YoloFace");
         this.faceCropService = MLFactory.getFaceCropService("ArcFace");
         this.faceCropService = MLFactory.getFaceCropService("ArcFace");
-        this.faceAlignmentService =
-            MLFactory.getFaceAlignmentService("ArcFace");
         this.blurDetectionService =
         this.blurDetectionService =
             MLFactory.getBlurDetectionService("Laplacian");
             MLFactory.getBlurDetectionService("Laplacian");
         this.faceEmbeddingService =
         this.faceEmbeddingService =