Clean alignment
This commit is contained in:
parent
19f06e6494
commit
fbebbd3583
5 changed files with 23 additions and 63 deletions
|
@ -1,31 +1,8 @@
|
|||
import { Matrix } from "ml-matrix";
|
||||
import { Point } from "services/face/geom";
|
||||
import {
|
||||
FaceAlignment,
|
||||
FaceAlignmentMethod,
|
||||
FaceAlignmentService,
|
||||
FaceDetection,
|
||||
Versioned,
|
||||
} from "services/face/types";
|
||||
import { FaceAlignment, FaceDetection } from "services/face/types";
|
||||
import { getSimilarityTransformation } from "similarity-transformation";
|
||||
|
||||
class ArcfaceAlignmentService implements FaceAlignmentService {
|
||||
public method: Versioned<FaceAlignmentMethod>;
|
||||
|
||||
constructor() {
|
||||
this.method = {
|
||||
value: "ArcFace",
|
||||
version: 1,
|
||||
};
|
||||
}
|
||||
|
||||
public getFaceAlignment(faceDetection: FaceDetection): FaceAlignment {
|
||||
return getArcfaceAlignment(faceDetection);
|
||||
}
|
||||
}
|
||||
|
||||
export default new ArcfaceAlignmentService();
|
||||
|
||||
const ARCFACE_LANDMARKS = [
|
||||
[38.2946, 51.6963],
|
||||
[73.5318, 51.5014],
|
||||
|
@ -43,9 +20,12 @@ const ARC_FACE_5_LANDMARKS = [
|
|||
[70.7299, 92.2041],
|
||||
] as Array<[number, number]>;
|
||||
|
||||
export function getArcfaceAlignment(
|
||||
faceDetection: FaceDetection,
|
||||
): FaceAlignment {
|
||||
/**
|
||||
* Compute and return an {@link FaceAlignment} for the given face detection.
|
||||
*
|
||||
* @param faceDetection A geometry indicating a face detected in an image.
|
||||
*/
|
||||
export const faceAlignment = (faceDetection: FaceDetection): FaceAlignment => {
|
||||
const landmarkCount = faceDetection.landmarks.length;
|
||||
return getFaceAlignmentUsingSimilarityTransform(
|
||||
faceDetection,
|
||||
|
@ -54,12 +34,11 @@ export function getArcfaceAlignment(
|
|||
ARCFACE_LANDMARKS_FACE_SIZE,
|
||||
),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
function getFaceAlignmentUsingSimilarityTransform(
|
||||
faceDetection: FaceDetection,
|
||||
alignedLandmarks: Array<[number, number]>,
|
||||
// alignmentMethod: Versioned<FaceAlignmentMethod>
|
||||
): FaceAlignment {
|
||||
const landmarksMat = new Matrix(
|
||||
faceDetection.landmarks
|
||||
|
@ -90,7 +69,6 @@ function getFaceAlignmentUsingSimilarityTransform(
|
|||
simTransform.rotation.get(0, 1),
|
||||
simTransform.rotation.get(0, 0),
|
||||
);
|
||||
// log.info({ affineMatrix, meanTranslation, centerMat, center, toMean: simTransform.toMean, fromMean: simTransform.fromMean, size });
|
||||
|
||||
return {
|
||||
affineMatrix,
|
|
@ -210,7 +210,6 @@ export interface MLSyncContext {
|
|||
|
||||
faceDetectionService: FaceDetectionService;
|
||||
faceCropService: FaceCropService;
|
||||
faceAlignmentService: FaceAlignmentService;
|
||||
faceEmbeddingService: FaceEmbeddingService;
|
||||
blurDetectionService: BlurDetectionService;
|
||||
faceClusteringService: ClusteringService;
|
||||
|
@ -272,11 +271,6 @@ export interface FaceCropService {
|
|||
): Promise<FaceCrop>;
|
||||
}
|
||||
|
||||
export interface FaceAlignmentService {
|
||||
method: Versioned<FaceAlignmentMethod>;
|
||||
getFaceAlignment(faceDetection: FaceDetection): FaceAlignment;
|
||||
}
|
||||
|
||||
export interface FaceEmbeddingService {
|
||||
method: Versioned<FaceEmbeddingMethod>;
|
||||
faceSize: number;
|
||||
|
|
|
@ -8,7 +8,7 @@ import {
|
|||
Versioned,
|
||||
} from "services/face/types";
|
||||
import { cropWithRotation } from "utils/image";
|
||||
import { getArcfaceAlignment } from "./arcfaceAlignmentService";
|
||||
import { faceAlignment } from "../face/align";
|
||||
|
||||
class ArcFaceCropService implements FaceCropService {
|
||||
public method: Versioned<FaceCropMethod>;
|
||||
|
@ -24,8 +24,8 @@ class ArcFaceCropService implements FaceCropService {
|
|||
imageBitmap: ImageBitmap,
|
||||
faceDetection: FaceDetection,
|
||||
): Promise<FaceCrop> {
|
||||
const alignedFace = getArcfaceAlignment(faceDetection);
|
||||
const faceCrop = getFaceCrop(imageBitmap, alignedFace);
|
||||
const alignment = faceAlignment(faceDetection);
|
||||
const faceCrop = getFaceCrop(imageBitmap, alignment);
|
||||
|
||||
return faceCrop;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import { openCache } from "@/next/blob-cache";
|
||||
import log from "@/next/log";
|
||||
import { faceAlignment } from "services/face/align";
|
||||
import mlIDbStorage from "services/face/db";
|
||||
import {
|
||||
DetectedFace,
|
||||
|
@ -103,12 +104,14 @@ class FaceService {
|
|||
fileContext: MLSyncFileContext,
|
||||
): Promise<Float32Array> {
|
||||
const { oldMlFile, newMlFile } = fileContext;
|
||||
// TODO-ML(MR):
|
||||
const method = {
|
||||
value: "ArcFace",
|
||||
version: 1,
|
||||
};
|
||||
if (
|
||||
!fileContext.newDetection &&
|
||||
!isDifferentOrOld(
|
||||
oldMlFile?.faceAlignmentMethod,
|
||||
syncContext.faceAlignmentService.method,
|
||||
) &&
|
||||
!isDifferentOrOld(oldMlFile?.faceAlignmentMethod, method) &&
|
||||
areFaceIdsSame(newMlFile.faces, oldMlFile?.faces)
|
||||
) {
|
||||
for (const [index, face] of newMlFile.faces.entries()) {
|
||||
|
@ -118,7 +121,10 @@ class FaceService {
|
|||
return;
|
||||
}
|
||||
|
||||
newMlFile.faceAlignmentMethod = syncContext.faceAlignmentService.method;
|
||||
newMlFile.faceAlignmentMethod = {
|
||||
value: "ArcFace",
|
||||
version: 1,
|
||||
};
|
||||
fileContext.newAlignment = true;
|
||||
const imageBitmap =
|
||||
fileContext.imageBitmap ||
|
||||
|
@ -126,9 +132,7 @@ class FaceService {
|
|||
|
||||
// Execute the face alignment calculations
|
||||
for (const face of newMlFile.faces) {
|
||||
face.alignment = syncContext.faceAlignmentService.getFaceAlignment(
|
||||
face.detection,
|
||||
);
|
||||
face.alignment = faceAlignment(face.detection);
|
||||
}
|
||||
// Extract face images and convert to Float32Array
|
||||
const faceAlignments = newMlFile.faces.map((f) => f.alignment);
|
||||
|
|
|
@ -17,8 +17,6 @@ import {
|
|||
ClusteringMethod,
|
||||
ClusteringService,
|
||||
Face,
|
||||
FaceAlignmentMethod,
|
||||
FaceAlignmentService,
|
||||
FaceCropMethod,
|
||||
FaceCropService,
|
||||
FaceDetection,
|
||||
|
@ -38,7 +36,6 @@ import {
|
|||
import { getLocalFiles } from "services/fileService";
|
||||
import { EnteFile } from "types/file";
|
||||
import { isInternalUserForML } from "utils/user";
|
||||
import arcfaceAlignmentService from "./arcfaceAlignmentService";
|
||||
import arcfaceCropService from "./arcfaceCropService";
|
||||
import FaceService from "./faceService";
|
||||
import hdbscanClusteringService from "./hdbscanClusteringService";
|
||||
|
@ -139,16 +136,6 @@ export class MLFactory {
|
|||
throw Error("Unknon face crop method: " + method);
|
||||
}
|
||||
|
||||
public static getFaceAlignmentService(
|
||||
method: FaceAlignmentMethod,
|
||||
): FaceAlignmentService {
|
||||
if (method === "ArcFace") {
|
||||
return arcfaceAlignmentService;
|
||||
}
|
||||
|
||||
throw Error("Unknon face alignment method: " + method);
|
||||
}
|
||||
|
||||
public static getBlurDetectionService(
|
||||
method: BlurDetectionMethod,
|
||||
): BlurDetectionService {
|
||||
|
@ -187,7 +174,6 @@ export class LocalMLSyncContext implements MLSyncContext {
|
|||
|
||||
public faceDetectionService: FaceDetectionService;
|
||||
public faceCropService: FaceCropService;
|
||||
public faceAlignmentService: FaceAlignmentService;
|
||||
public blurDetectionService: BlurDetectionService;
|
||||
public faceEmbeddingService: FaceEmbeddingService;
|
||||
public faceClusteringService: ClusteringService;
|
||||
|
@ -225,8 +211,6 @@ export class LocalMLSyncContext implements MLSyncContext {
|
|||
this.faceDetectionService =
|
||||
MLFactory.getFaceDetectionService("YoloFace");
|
||||
this.faceCropService = MLFactory.getFaceCropService("ArcFace");
|
||||
this.faceAlignmentService =
|
||||
MLFactory.getFaceAlignmentService("ArcFace");
|
||||
this.blurDetectionService =
|
||||
MLFactory.getBlurDetectionService("Laplacian");
|
||||
this.faceEmbeddingService =
|
||||
|
|
Loading…
Add table
Reference in a new issue