Manav Rathi 1 anno fa
parent
commit
db05afb9ff

+ 32 - 0
web/apps/photos/src/services/face/crop.ts

@@ -0,0 +1,32 @@
+import { Box, enlargeBox } from "services/face/geom";
+import { FaceCrop, FaceDetection } from "services/face/types";
+import { cropWithRotation } from "utils/image";
+import { faceAlignment } from "./align";
+
+export const getFaceCrop = (
+    imageBitmap: ImageBitmap,
+    faceDetection: FaceDetection,
+): FaceCrop => {
+    const alignment = faceAlignment(faceDetection);
+
+    const padding = 0.25;
+    const maxSize = 256;
+
+    const alignmentBox = new Box({
+        x: alignment.center.x - alignment.size / 2,
+        y: alignment.center.y - alignment.size / 2,
+        width: alignment.size,
+        height: alignment.size,
+    }).round();
+    const scaleForPadding = 1 + padding * 2;
+    const paddedBox = enlargeBox(alignmentBox, scaleForPadding).round();
+    const faceImageBitmap = cropWithRotation(imageBitmap, paddedBox, 0, {
+        width: maxSize,
+        height: maxSize,
+    });
+
+    return {
+        image: faceImageBitmap,
+        imageBox: paddedBox,
+    };
+};

+ 1 - 5
web/apps/photos/src/services/face/people.ts

@@ -48,11 +48,7 @@ const syncPeopleFromClusters = async (
         if (personFace && !personFace.crop?.cacheKey) {
             const file = await getLocalFile(personFace.fileId);
             const imageBitmap = await fetchImageBitmap(file);
-            await FaceService.saveFaceCrop(
-                imageBitmap,
-                personFace,
-                syncContext,
-            );
+            await FaceService.saveFaceCrop(imageBitmap, personFace);
         }

         const person: Person = {

+ 0 - 60
web/apps/photos/src/services/machineLearning/arcfaceCropService.ts

@@ -1,60 +0,0 @@
-import { Box, enlargeBox } from "services/face/geom";
-import {
-    FaceAlignment,
-    FaceCrop,
-    FaceCropMethod,
-    FaceCropService,
-    FaceDetection,
-    Versioned,
-} from "services/face/types";
-import { cropWithRotation } from "utils/image";
-import { faceAlignment } from "../face/align";
-
-class ArcFaceCropService implements FaceCropService {
-    public method: Versioned<FaceCropMethod>;
-
-    constructor() {
-        this.method = {
-            value: "ArcFace",
-            version: 1,
-        };
-    }
-
-    public async getFaceCrop(
-        imageBitmap: ImageBitmap,
-        faceDetection: FaceDetection,
-    ): Promise<FaceCrop> {
-        const alignment = faceAlignment(faceDetection);
-        const faceCrop = getFaceCrop(imageBitmap, alignment);
-
-        return faceCrop;
-    }
-}
-
-export default new ArcFaceCropService();
-
-export function getFaceCrop(
-    imageBitmap: ImageBitmap,
-    alignment: FaceAlignment,
-): FaceCrop {
-    const padding = 0.25;
-    const maxSize = 256;
-
-    const alignmentBox = new Box({
-        x: alignment.center.x - alignment.size / 2,
-        y: alignment.center.y - alignment.size / 2,
-        width: alignment.size,
-        height: alignment.size,
-    }).round();
-    const scaleForPadding = 1 + padding * 2;
-    const paddedBox = enlargeBox(alignmentBox, scaleForPadding).round();
-    const faceImageBitmap = cropWithRotation(imageBitmap, paddedBox, 0, {
-        width: maxSize,
-        height: maxSize,
-    });
-
-    return {
-        image: faceImageBitmap,
-        imageBox: paddedBox,
-    };
-}

+ 6 - 15
web/apps/photos/src/services/machineLearning/faceService.ts

@@ -11,6 +11,7 @@ import {
 } from "services/face/types";
 import { imageBitmapToBlob, warpAffineFloat32List } from "utils/image";
 import { clusterFaces } from "../face/cluster";
+import { getFaceCrop } from "../face/crop";
 import {
     fetchImageBitmap,
     fetchImageBitmapForContext,
@@ -55,7 +56,7 @@ class FaceService {
         newMlFile.faceCropMethod = syncContext.faceCropService.method;

         for (const face of newMlFile.faces) {
-            await this.saveFaceCrop(imageBitmap, face, syncContext);
+            await this.saveFaceCrop(imageBitmap, face);
         }
     }

@@ -132,15 +133,8 @@ class FaceService {
         }
     }

-    async saveFaceCrop(
-        imageBitmap: ImageBitmap,
-        face: Face,
-        syncContext: MLSyncContext,
-    ) {
-        const faceCrop = await syncContext.faceCropService.getFaceCrop(
-            imageBitmap,
-            face.detection,
-        );
+    async saveFaceCrop(imageBitmap: ImageBitmap, face: Face) {
+        const faceCrop = getFaceCrop(imageBitmap, face.detection);

         const blob = await imageBitmapToBlob(faceCrop.image);

@@ -197,10 +191,7 @@ class FaceService {
         // };
     }

-    public async regenerateFaceCrop(
-        syncContext: MLSyncContext,
-        faceID: string,
-    ) {
+    public async regenerateFaceCrop(faceID: string) {
         const fileID = Number(faceID.split("-")[0]);
         const personFace = await mlIDbStorage.getFace(fileID, faceID);
         if (!personFace) {
@@ -209,7 +200,7 @@ class FaceService {

         const file = await getLocalFile(personFace.fileId);
         const imageBitmap = await fetchImageBitmap(file);
-        return await this.saveFaceCrop(imageBitmap, personFace, syncContext);
+        return await this.saveFaceCrop(imageBitmap, personFace);
     }
 }


+ 1 - 13
web/apps/photos/src/services/machineLearning/machineLearningService.ts

@@ -15,7 +15,6 @@ import {
     BlurDetectionMethod,
     BlurDetectionService,
     Face,
-    FaceCropMethod,
     FaceCropService,
     FaceDetection,
     FaceDetectionMethod,
@@ -34,7 +33,6 @@ import {
 import { getLocalFiles } from "services/fileService";
 import { EnteFile } from "types/file";
 import { isInternalUserForML } from "utils/user";
-import arcfaceCropService from "./arcfaceCropService";
 import FaceService from "./faceService";
 import laplacianBlurDetectionService from "./laplacianBlurDetectionService";
 import mobileFaceNetEmbeddingService from "./mobileFaceNetEmbeddingService";
@@ -125,14 +123,6 @@ export class MLFactory {
         throw Error("Unknon face detection method: " + method);
     }

-    public static getFaceCropService(method: FaceCropMethod) {
-        if (method === "ArcFace") {
-            return arcfaceCropService;
-        }
-
-        throw Error("Unknon face crop method: " + method);
-    }
-
     public static getBlurDetectionService(
         method: BlurDetectionMethod,
     ): BlurDetectionService {
@@ -189,7 +179,6 @@ export class LocalMLSyncContext implements MLSyncContext {

         this.faceDetectionService =
             MLFactory.getFaceDetectionService("YoloFace");
-        this.faceCropService = MLFactory.getFaceCropService("ArcFace");
         this.blurDetectionService =
             MLFactory.getBlurDetectionService("Laplacian");
         this.faceEmbeddingService =
@@ -288,8 +277,7 @@ class MachineLearningService {
         faceID: string,
     ) {
         await downloadManager.init(APPS.PHOTOS, { token });
-        const syncContext = await this.getSyncContext(token, userID);
-        return FaceService.regenerateFaceCrop(syncContext, faceID);
+        return FaceService.regenerateFaceCrop(faceID);
     }

     private newMlData(fileId: number) {