diff --git a/web/apps/photos/src/services/face/crop.ts b/web/apps/photos/src/services/face/crop.ts new file mode 100644 index 000000000..acd49228e --- /dev/null +++ b/web/apps/photos/src/services/face/crop.ts @@ -0,0 +1,32 @@ +import { Box, enlargeBox } from "services/face/geom"; +import { FaceCrop, FaceDetection } from "services/face/types"; +import { cropWithRotation } from "utils/image"; +import { faceAlignment } from "./align"; + +export const getFaceCrop = ( + imageBitmap: ImageBitmap, + faceDetection: FaceDetection, +): FaceCrop => { + const alignment = faceAlignment(faceDetection); + + const padding = 0.25; + const maxSize = 256; + + const alignmentBox = new Box({ + x: alignment.center.x - alignment.size / 2, + y: alignment.center.y - alignment.size / 2, + width: alignment.size, + height: alignment.size, + }).round(); + const scaleForPadding = 1 + padding * 2; + const paddedBox = enlargeBox(alignmentBox, scaleForPadding).round(); + const faceImageBitmap = cropWithRotation(imageBitmap, paddedBox, 0, { + width: maxSize, + height: maxSize, + }); + + return { + image: faceImageBitmap, + imageBox: paddedBox, + }; +}; diff --git a/web/apps/photos/src/services/face/people.ts b/web/apps/photos/src/services/face/people.ts index e6cea9007..081962935 100644 --- a/web/apps/photos/src/services/face/people.ts +++ b/web/apps/photos/src/services/face/people.ts @@ -48,11 +48,7 @@ const syncPeopleFromClusters = async ( if (personFace && !personFace.crop?.cacheKey) { const file = await getLocalFile(personFace.fileId); const imageBitmap = await fetchImageBitmap(file); - await FaceService.saveFaceCrop( - imageBitmap, - personFace, - syncContext, - ); + await FaceService.saveFaceCrop(imageBitmap, personFace); } const person: Person = { diff --git a/web/apps/photos/src/services/machineLearning/arcfaceCropService.ts b/web/apps/photos/src/services/machineLearning/arcfaceCropService.ts deleted file mode 100644 index 81f2d4de5..000000000 --- a/web/apps/photos/src/services/machineLearning/arcfaceCropService.ts +++ /dev/null @@ -1,60 +0,0 @@ -import { Box, enlargeBox } from "services/face/geom"; -import { - FaceAlignment, - FaceCrop, - FaceCropMethod, - FaceCropService, - FaceDetection, - Versioned, -} from "services/face/types"; -import { cropWithRotation } from "utils/image"; -import { faceAlignment } from "../face/align"; - -class ArcFaceCropService implements FaceCropService { - public method: Versioned; - - constructor() { - this.method = { - value: "ArcFace", - version: 1, - }; - } - - public async getFaceCrop( - imageBitmap: ImageBitmap, - faceDetection: FaceDetection, - ): Promise { - const alignment = faceAlignment(faceDetection); - const faceCrop = getFaceCrop(imageBitmap, alignment); - - return faceCrop; - } -} - -export default new ArcFaceCropService(); - -export function getFaceCrop( - imageBitmap: ImageBitmap, - alignment: FaceAlignment, -): FaceCrop { - const padding = 0.25; - const maxSize = 256; - - const alignmentBox = new Box({ - x: alignment.center.x - alignment.size / 2, - y: alignment.center.y - alignment.size / 2, - width: alignment.size, - height: alignment.size, - }).round(); - const scaleForPadding = 1 + padding * 2; - const paddedBox = enlargeBox(alignmentBox, scaleForPadding).round(); - const faceImageBitmap = cropWithRotation(imageBitmap, paddedBox, 0, { - width: maxSize, - height: maxSize, - }); - - return { - image: faceImageBitmap, - imageBox: paddedBox, - }; -} diff --git a/web/apps/photos/src/services/machineLearning/faceService.ts b/web/apps/photos/src/services/machineLearning/faceService.ts index f5e83b8f5..c6c061af1 100644 --- a/web/apps/photos/src/services/machineLearning/faceService.ts +++ b/web/apps/photos/src/services/machineLearning/faceService.ts @@ -11,6 +11,7 @@ import { } from "services/face/types"; import { imageBitmapToBlob, warpAffineFloat32List } from "utils/image"; import { clusterFaces } from "../face/cluster"; +import { getFaceCrop } from "../face/crop"; import { fetchImageBitmap, fetchImageBitmapForContext, @@ -55,7 +56,7 @@ class FaceService { newMlFile.faceCropMethod = syncContext.faceCropService.method; for (const face of newMlFile.faces) { - await this.saveFaceCrop(imageBitmap, face, syncContext); + await this.saveFaceCrop(imageBitmap, face); } } @@ -132,15 +133,8 @@ class FaceService { } } - async saveFaceCrop( - imageBitmap: ImageBitmap, - face: Face, - syncContext: MLSyncContext, - ) { - const faceCrop = await syncContext.faceCropService.getFaceCrop( - imageBitmap, - face.detection, - ); + async saveFaceCrop(imageBitmap: ImageBitmap, face: Face) { + const faceCrop = getFaceCrop(imageBitmap, face.detection); const blob = await imageBitmapToBlob(faceCrop.image); @@ -197,10 +191,7 @@ class FaceService { // }; } - public async regenerateFaceCrop( - syncContext: MLSyncContext, - faceID: string, - ) { + public async regenerateFaceCrop(faceID: string) { const fileID = Number(faceID.split("-")[0]); const personFace = await mlIDbStorage.getFace(fileID, faceID); if (!personFace) { @@ -209,7 +200,7 @@ class FaceService { const file = await getLocalFile(personFace.fileId); const imageBitmap = await fetchImageBitmap(file); - return await this.saveFaceCrop(imageBitmap, personFace, syncContext); + return await this.saveFaceCrop(imageBitmap, personFace); } } diff --git a/web/apps/photos/src/services/machineLearning/machineLearningService.ts b/web/apps/photos/src/services/machineLearning/machineLearningService.ts index 302a80b53..ec7f97807 100644 --- a/web/apps/photos/src/services/machineLearning/machineLearningService.ts +++ b/web/apps/photos/src/services/machineLearning/machineLearningService.ts @@ -15,7 +15,6 @@ import { BlurDetectionMethod, BlurDetectionService, Face, - FaceCropMethod, FaceCropService, FaceDetection, FaceDetectionMethod, @@ -34,7 +33,6 @@ import { import { getLocalFiles } from "services/fileService"; import { EnteFile } from "types/file"; import { isInternalUserForML } from "utils/user"; -import arcfaceCropService from "./arcfaceCropService"; import FaceService from "./faceService"; import laplacianBlurDetectionService from "./laplacianBlurDetectionService"; import mobileFaceNetEmbeddingService from "./mobileFaceNetEmbeddingService"; @@ -125,14 +123,6 @@ export class MLFactory { throw Error("Unknon face detection method: " + method); } - public static getFaceCropService(method: FaceCropMethod) { - if (method === "ArcFace") { - return arcfaceCropService; - } - - throw Error("Unknon face crop method: " + method); - } - public static getBlurDetectionService( method: BlurDetectionMethod, ): BlurDetectionService { @@ -189,7 +179,6 @@ export class LocalMLSyncContext implements MLSyncContext { this.faceDetectionService = MLFactory.getFaceDetectionService("YoloFace"); - this.faceCropService = MLFactory.getFaceCropService("ArcFace"); this.blurDetectionService = MLFactory.getBlurDetectionService("Laplacian"); this.faceEmbeddingService = @@ -288,8 +277,7 @@ class MachineLearningService { faceID: string, ) { await downloadManager.init(APPS.PHOTOS, { token }); - const syncContext = await this.getSyncContext(token, userID); - return FaceService.regenerateFaceCrop(syncContext, faceID); + return FaceService.regenerateFaceCrop(faceID); } private newMlData(fileId: number) {