embeddings

Manav Rathi, 1 year ago
Parent commit 43a3df5bbf

+ 3 - 3
desktop/src/main/ipc.ts

@@ -46,7 +46,7 @@ import {
     clipImageEmbedding,
     clipTextEmbeddingIfAvailable,
 } from "./services/ml-clip";
-import { detectFaces, faceEmbedding } from "./services/ml-face";
+import { detectFaces, faceEmbeddings } from "./services/ml-face";
 import { encryptionKey, saveEncryptionKey } from "./services/store";
 import {
     clearPendingUploads,
@@ -182,8 +182,8 @@ export const attachIPCHandlers = () => {
         detectFaces(input),
     );
 
-    ipcMain.handle("faceEmbedding", (_, input: Float32Array) =>
-        faceEmbedding(input),
+    ipcMain.handle("faceEmbeddings", (_, input: Float32Array) =>
+        faceEmbeddings(input),
     );
 
     ipcMain.handle("legacyFaceCrop", (_, faceID: string) =>

+ 1 - 1
desktop/src/main/services/ml-face.ts

@@ -32,7 +32,7 @@ const cachedFaceEmbeddingSession = makeCachedInferenceSession(
     5286998 /* 5 MB */,
 );
 
-export const faceEmbedding = async (input: Float32Array) => {
+export const faceEmbeddings = async (input: Float32Array) => {
     // Dimension of each face aligned image.
     const mobileFaceNetFaceSize = 112;
     // Smaller alias
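
The singular-to-plural rename reflects that the main-process service now embeds a whole batch of aligned faces in one inference call and returns their embeddings concatenated in a single Float32Array. Below is a minimal sketch of what such a batched implementation could look like; the tensor names and the NHWC batch layout are assumptions, since the function body is not part of this diff:

    import * as ort from "onnxruntime-node";

    // Assumed to resolve to the cached ONNX InferenceSession created by
    // makeCachedInferenceSession above.
    declare const cachedFaceEmbeddingSession: () => Promise<ort.InferenceSession>;

    export const faceEmbeddings = async (input: Float32Array) => {
        // Each aligned face is a 112 x 112 RGB image, so the batch size follows
        // from the length of the flat input buffer.
        const z = 112; /* mobileFaceNetFaceSize */
        const n = Math.round(input.length / (z * z * 3));
        const inputTensor = new ort.Tensor("float32", input, [n, z, z, 3]);

        const session = await cachedFaceEmbeddingSession();
        // "img_inputs" and "embeddings" are placeholder tensor names.
        const results = await session.run({ img_inputs: inputTensor });
        // Flat buffer holding n concatenated 192-dimensional embeddings; the
        // renderer side slices it back into per-face arrays.
        return results.embeddings.data as Float32Array;
    };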

+ 3 - 3
desktop/src/preload.ts

@@ -162,8 +162,8 @@ const clipTextEmbeddingIfAvailable = (text: string) =>
 const detectFaces = (input: Float32Array) =>
     ipcRenderer.invoke("detectFaces", input);
 
-const faceEmbedding = (input: Float32Array) =>
-    ipcRenderer.invoke("faceEmbedding", input);
+const faceEmbeddings = (input: Float32Array) =>
+    ipcRenderer.invoke("faceEmbeddings", input);
 
 const legacyFaceCrop = (faceID: string) =>
     ipcRenderer.invoke("legacyFaceCrop", faceID);
@@ -343,7 +343,7 @@ contextBridge.exposeInMainWorld("electron", {
     clipImageEmbedding,
     clipTextEmbeddingIfAvailable,
     detectFaces,
-    faceEmbedding,
+    faceEmbeddings,
     legacyFaceCrop,
 
     // - Watch

+ 1 - 1
web/apps/photos/src/services/face/blur.ts

@@ -1,6 +1,6 @@
 import { Face } from "services/face/types";
 import { createGrayscaleIntMatrixFromNormalized2List } from "utils/image";
-import { mobileFaceNetFaceSize } from "../machineLearning/mobileFaceNetEmbeddingService";
+import { mobileFaceNetFaceSize } from "../machineLearning/embed";
 
 /**
  * Laplacian blur detection.

+ 1 - 1
web/apps/photos/src/services/face/detect.ts

@@ -24,7 +24,7 @@ import {
 /**
  * Detect faces in the given {@link imageBitmap}.
  *
- * The ML model used is YOLO, running in an ONNX runtime.
+ * The model used is YOLO, running in an ONNX runtime.
  */
 export const detectFaces = async (
     imageBitmap: ImageBitmap,

+ 26 - 0
web/apps/photos/src/services/machineLearning/embed.ts

@@ -0,0 +1,26 @@
+import { workerBridge } from "@/next/worker/worker-bridge";
+import { FaceEmbedding } from "services/face/types";
+
+export const mobileFaceNetFaceSize = 112;
+
+/**
+ * Compute embeddings for the given {@link faceData}.
+ *
+ * The model used is MobileFaceNet, running in an ONNX runtime.
+ */
+export const getFaceEmbeddings = async (
+    faceData: Float32Array,
+): Promise<Array<FaceEmbedding>> => {
+    const outputData = await workerBridge.faceEmbeddings(faceData);
+
+    const embeddingSize = 192;
+    const embeddings = new Array<FaceEmbedding>(
+        outputData.length / embeddingSize,
+    );
+    for (let i = 0; i < embeddings.length; i++) {
+        embeddings[i] = new Float32Array(
+            outputData.slice(i * embeddingSize, (i + 1) * embeddingSize),
+        );
+    }
+    return embeddings;
+};

+ 1 - 1
web/apps/photos/src/services/machineLearning/machineLearningService.ts

@@ -31,8 +31,8 @@ import { EnteFile } from "types/file";
 import { isInternalUserForML } from "utils/user";
 import { fetchImageBitmapForContext } from "../face/image";
 import { syncPeopleIndex } from "../face/people";
+import mobileFaceNetEmbeddingService from "./embed";
 import FaceService from "./faceService";
-import mobileFaceNetEmbeddingService from "./mobileFaceNetEmbeddingService";
 
 /**
  * TODO-ML(MR): What and why.

+ 0 - 41
web/apps/photos/src/services/machineLearning/mobileFaceNetEmbeddingService.ts

@@ -1,41 +0,0 @@
-import { workerBridge } from "@/next/worker/worker-bridge";
-import {
-    FaceEmbedding,
-    FaceEmbeddingMethod,
-    FaceEmbeddingService,
-    Versioned,
-} from "services/face/types";
-
-export const mobileFaceNetFaceSize = 112;
-
-class MobileFaceNetEmbeddingService implements FaceEmbeddingService {
-    public method: Versioned<FaceEmbeddingMethod>;
-    public faceSize: number;
-
-    public constructor() {
-        this.method = {
-            value: "MobileFaceNet",
-            version: 2,
-        };
-        this.faceSize = mobileFaceNetFaceSize;
-    }
-
-    public async getFaceEmbeddings(
-        faceData: Float32Array,
-    ): Promise<Array<FaceEmbedding>> {
-        const outputData = await workerBridge.faceEmbedding(faceData);
-
-        const embeddingSize = 192;
-        const embeddings = new Array<FaceEmbedding>(
-            outputData.length / embeddingSize,
-        );
-        for (let i = 0; i < embeddings.length; i++) {
-            embeddings[i] = new Float32Array(
-                outputData.slice(i * embeddingSize, (i + 1) * embeddingSize),
-            );
-        }
-        return embeddings;
-    }
-}
-
-export default new MobileFaceNetEmbeddingService();

+ 2 - 2
web/packages/next/types/ipc.ts

@@ -332,12 +332,12 @@ export interface Electron {
     detectFaces: (input: Float32Array) => Promise<Float32Array>;
 
     /**
-     * Return a MobileFaceNet embedding for the given face data.
+     * Return MobileFaceNet embeddings for the given faces.
      *
      * Both the input and output are opaque binary data whose internal structure
      * is specific to our implementation and the model (MobileFaceNet) we use.
      */
-    faceEmbedding: (input: Float32Array) => Promise<Float32Array>;
+    faceEmbeddings: (input: Float32Array) => Promise<Float32Array>;
 
     /**
      * Return a face crop stored by a previous version of ML.

+ 2 - 2
web/packages/next/worker/comlink-worker.ts

@@ -47,8 +47,8 @@ const workerBridge = {
     convertToJPEG: (imageData: Uint8Array) =>
         ensureElectron().convertToJPEG(imageData),
     detectFaces: (input: Float32Array) => ensureElectron().detectFaces(input),
-    faceEmbedding: (input: Float32Array) =>
-        ensureElectron().faceEmbedding(input),
+    faceEmbeddings: (input: Float32Array) =>
+        ensureElectron().faceEmbeddings(input),
 };
 
 export type WorkerBridge = typeof workerBridge;
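
End to end, a faceEmbeddings call from ML worker code now crosses three boundaries: Comlink proxies it from the web worker to the main thread's workerBridge, which forwards it over the context bridge to ipcRenderer.invoke("faceEmbeddings", ...), which lands in the ipcMain handler that runs the ONNX session. A rough sketch of the worker-side plumbing, assuming the bridge is exposed to the worker via Comlink's wrap (worker-bridge.ts itself is not touched by this commit):

    import { wrap } from "comlink";
    import type { WorkerBridge } from "@/next/worker/comlink-worker";

    // Assumed wiring: the main thread exposes the bridge on the worker's
    // message port, so the worker wraps its own global scope as the endpoint.
    const workerBridge = wrap<WorkerBridge>(self as unknown as Worker);

    export const computeFaceEmbeddings = (faceData: Float32Array) =>
        // worker -> main thread (Comlink) -> Electron main ("faceEmbeddings" IPC).
        workerBridge.faceEmbeddings(faceData);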