Clean alignment

This commit is contained in:
Manav Rathi 2024-05-16 11:43:04 +05:30
parent 19f06e6494
commit fbebbd3583
No known key found for this signature in database
5 changed files with 23 additions and 63 deletions

View file

@ -1,31 +1,8 @@
import { Matrix } from "ml-matrix";
import { Point } from "services/face/geom";
import {
FaceAlignment,
FaceAlignmentMethod,
FaceAlignmentService,
FaceDetection,
Versioned,
} from "services/face/types";
import { FaceAlignment, FaceDetection } from "services/face/types";
import { getSimilarityTransformation } from "similarity-transformation";
class ArcfaceAlignmentService implements FaceAlignmentService {
public method: Versioned<FaceAlignmentMethod>;
constructor() {
this.method = {
value: "ArcFace",
version: 1,
};
}
public getFaceAlignment(faceDetection: FaceDetection): FaceAlignment {
return getArcfaceAlignment(faceDetection);
}
}
export default new ArcfaceAlignmentService();
const ARCFACE_LANDMARKS = [
[38.2946, 51.6963],
[73.5318, 51.5014],
@ -43,9 +20,12 @@ const ARC_FACE_5_LANDMARKS = [
[70.7299, 92.2041],
] as Array<[number, number]>;
export function getArcfaceAlignment(
faceDetection: FaceDetection,
): FaceAlignment {
/**
* Compute and return an {@link FaceAlignment} for the given face detection.
*
* @param faceDetection A geometry indicating a face detected in an image.
*/
export const faceAlignment = (faceDetection: FaceDetection): FaceAlignment => {
const landmarkCount = faceDetection.landmarks.length;
return getFaceAlignmentUsingSimilarityTransform(
faceDetection,
@ -54,12 +34,11 @@ export function getArcfaceAlignment(
ARCFACE_LANDMARKS_FACE_SIZE,
),
);
}
};
function getFaceAlignmentUsingSimilarityTransform(
faceDetection: FaceDetection,
alignedLandmarks: Array<[number, number]>,
// alignmentMethod: Versioned<FaceAlignmentMethod>
): FaceAlignment {
const landmarksMat = new Matrix(
faceDetection.landmarks
@ -90,7 +69,6 @@ function getFaceAlignmentUsingSimilarityTransform(
simTransform.rotation.get(0, 1),
simTransform.rotation.get(0, 0),
);
// log.info({ affineMatrix, meanTranslation, centerMat, center, toMean: simTransform.toMean, fromMean: simTransform.fromMean, size });
return {
affineMatrix,

View file

@ -210,7 +210,6 @@ export interface MLSyncContext {
faceDetectionService: FaceDetectionService;
faceCropService: FaceCropService;
faceAlignmentService: FaceAlignmentService;
faceEmbeddingService: FaceEmbeddingService;
blurDetectionService: BlurDetectionService;
faceClusteringService: ClusteringService;
@ -272,11 +271,6 @@ export interface FaceCropService {
): Promise<FaceCrop>;
}
export interface FaceAlignmentService {
method: Versioned<FaceAlignmentMethod>;
getFaceAlignment(faceDetection: FaceDetection): FaceAlignment;
}
export interface FaceEmbeddingService {
method: Versioned<FaceEmbeddingMethod>;
faceSize: number;

View file

@ -8,7 +8,7 @@ import {
Versioned,
} from "services/face/types";
import { cropWithRotation } from "utils/image";
import { getArcfaceAlignment } from "./arcfaceAlignmentService";
import { faceAlignment } from "../face/align";
class ArcFaceCropService implements FaceCropService {
public method: Versioned<FaceCropMethod>;
@ -24,8 +24,8 @@ class ArcFaceCropService implements FaceCropService {
imageBitmap: ImageBitmap,
faceDetection: FaceDetection,
): Promise<FaceCrop> {
const alignedFace = getArcfaceAlignment(faceDetection);
const faceCrop = getFaceCrop(imageBitmap, alignedFace);
const alignment = faceAlignment(faceDetection);
const faceCrop = getFaceCrop(imageBitmap, alignment);
return faceCrop;
}

View file

@ -1,5 +1,6 @@
import { openCache } from "@/next/blob-cache";
import log from "@/next/log";
import { faceAlignment } from "services/face/align";
import mlIDbStorage from "services/face/db";
import {
DetectedFace,
@ -103,12 +104,14 @@ class FaceService {
fileContext: MLSyncFileContext,
): Promise<Float32Array> {
const { oldMlFile, newMlFile } = fileContext;
// TODO-ML(MR):
const method = {
value: "ArcFace",
version: 1,
};
if (
!fileContext.newDetection &&
!isDifferentOrOld(
oldMlFile?.faceAlignmentMethod,
syncContext.faceAlignmentService.method,
) &&
!isDifferentOrOld(oldMlFile?.faceAlignmentMethod, method) &&
areFaceIdsSame(newMlFile.faces, oldMlFile?.faces)
) {
for (const [index, face] of newMlFile.faces.entries()) {
@ -118,7 +121,10 @@ class FaceService {
return;
}
newMlFile.faceAlignmentMethod = syncContext.faceAlignmentService.method;
newMlFile.faceAlignmentMethod = {
value: "ArcFace",
version: 1,
};
fileContext.newAlignment = true;
const imageBitmap =
fileContext.imageBitmap ||
@ -126,9 +132,7 @@ class FaceService {
// Execute the face alignment calculations
for (const face of newMlFile.faces) {
face.alignment = syncContext.faceAlignmentService.getFaceAlignment(
face.detection,
);
face.alignment = faceAlignment(face.detection);
}
// Extract face images and convert to Float32Array
const faceAlignments = newMlFile.faces.map((f) => f.alignment);

View file

@ -17,8 +17,6 @@ import {
ClusteringMethod,
ClusteringService,
Face,
FaceAlignmentMethod,
FaceAlignmentService,
FaceCropMethod,
FaceCropService,
FaceDetection,
@ -38,7 +36,6 @@ import {
import { getLocalFiles } from "services/fileService";
import { EnteFile } from "types/file";
import { isInternalUserForML } from "utils/user";
import arcfaceAlignmentService from "./arcfaceAlignmentService";
import arcfaceCropService from "./arcfaceCropService";
import FaceService from "./faceService";
import hdbscanClusteringService from "./hdbscanClusteringService";
@ -139,16 +136,6 @@ export class MLFactory {
throw Error("Unknon face crop method: " + method);
}
public static getFaceAlignmentService(
method: FaceAlignmentMethod,
): FaceAlignmentService {
if (method === "ArcFace") {
return arcfaceAlignmentService;
}
throw Error("Unknon face alignment method: " + method);
}
public static getBlurDetectionService(
method: BlurDetectionMethod,
): BlurDetectionService {
@ -187,7 +174,6 @@ export class LocalMLSyncContext implements MLSyncContext {
public faceDetectionService: FaceDetectionService;
public faceCropService: FaceCropService;
public faceAlignmentService: FaceAlignmentService;
public blurDetectionService: BlurDetectionService;
public faceEmbeddingService: FaceEmbeddingService;
public faceClusteringService: ClusteringService;
@ -225,8 +211,6 @@ export class LocalMLSyncContext implements MLSyncContext {
this.faceDetectionService =
MLFactory.getFaceDetectionService("YoloFace");
this.faceCropService = MLFactory.getFaceCropService("ArcFace");
this.faceAlignmentService =
MLFactory.getFaceAlignmentService("ArcFace");
this.blurDetectionService =
MLFactory.getBlurDetectionService("Laplacian");
this.faceEmbeddingService =