瀏覽代碼

The ML code runs in workers

Manav Rathi 1 年之前
父節點
當前提交
ff66a2f44c

+ 2 - 2
web/apps/photos/src/services/machineLearning/mobileFaceNetEmbeddingService.ts

@@ -1,4 +1,4 @@
-import { ensureElectron } from "@/next/electron";
+import { workerBridge } from "@/next/worker/worker-bridge";
 import {
 import {
     FaceEmbedding,
     FaceEmbedding,
     FaceEmbeddingMethod,
     FaceEmbeddingMethod,
@@ -23,7 +23,7 @@ class MobileFaceNetEmbeddingService implements FaceEmbeddingService {
     public async getFaceEmbeddings(
     public async getFaceEmbeddings(
         faceData: Float32Array,
         faceData: Float32Array,
     ): Promise<Array<FaceEmbedding>> {
     ): Promise<Array<FaceEmbedding>> {
-        const outputData = await ensureElectron().faceEmbedding(faceData);
+        const outputData = await workerBridge.faceEmbedding(faceData);
 
 
         const embeddingSize = 192;
         const embeddingSize = 192;
         const embeddings = new Array<FaceEmbedding>(
         const embeddings = new Array<FaceEmbedding>(

+ 2 - 2
web/apps/photos/src/services/machineLearning/yoloFaceDetectionService.ts

@@ -1,4 +1,3 @@
-import { ensureElectron } from "@/next/electron";
 import { euclidean } from "hdbscan";
 import { euclidean } from "hdbscan";
 import {
 import {
     Matrix,
     Matrix,
@@ -21,6 +20,7 @@ import {
 } from "utils/image";
 } from "utils/image";
 import { newBox } from "utils/machineLearning";
 import { newBox } from "utils/machineLearning";
 import { Box, Point } from "../../../thirdparty/face-api/classes";
 import { Box, Point } from "../../../thirdparty/face-api/classes";
+import { workerBridge } from "@/next/worker/worker-bridge";
 
 
 class YoloFaceDetectionService implements FaceDetectionService {
 class YoloFaceDetectionService implements FaceDetectionService {
     public method: Versioned<FaceDetectionMethod>;
     public method: Versioned<FaceDetectionMethod>;
@@ -45,7 +45,7 @@ class YoloFaceDetectionService implements FaceDetectionService {
             );
             );
         const data = preprocessResult.data;
         const data = preprocessResult.data;
         const resized = preprocessResult.newSize;
         const resized = preprocessResult.newSize;
-        const outputData = await ensureElectron().detectFaces(data);
+        const outputData = await workerBridge.detectFaces(data);
         const faces = this.getFacesFromYoloOutput(
         const faces = this.getFacesFromYoloOutput(
             outputData as Float32Array,
             outputData as Float32Array,
             0.7,
             0.7,

+ 1 - 0
web/apps/photos/src/utils/storage/mlIDbStorage.ts

@@ -124,6 +124,7 @@ class MLIDbStorage {
                         .add(DEFAULT_ML_SEARCH_CONFIG, ML_SEARCH_CONFIG_NAME);
                         .add(DEFAULT_ML_SEARCH_CONFIG, ML_SEARCH_CONFIG_NAME);
                 }
                 }
                 if (oldVersion < 4) {
                 if (oldVersion < 4) {
+                    // TODO(MR): This loses the user's settings.
                     db.deleteObjectStore("configs");
                     db.deleteObjectStore("configs");
                     db.createObjectStore("configs");
                     db.createObjectStore("configs");
 
 

+ 3 - 0
web/packages/next/worker/comlink-worker.ts

@@ -46,6 +46,9 @@ const workerBridge = {
     logToDisk,
     logToDisk,
     convertToJPEG: (inputFileData: Uint8Array, filename: string) =>
     convertToJPEG: (inputFileData: Uint8Array, filename: string) =>
         ensureElectron().convertToJPEG(inputFileData, filename),
         ensureElectron().convertToJPEG(inputFileData, filename),
+    detectFaces: (input: Float32Array) => ensureElectron().detectFaces(input),
+    faceEmbedding: (input: Float32Array) =>
+        ensureElectron().faceEmbedding(input),
 };
 };
 
 
 export type WorkerBridge = typeof workerBridge;
 export type WorkerBridge = typeof workerBridge;