
[web] Import the scaffolding to sync face embeddings from web_face_v2 (#1402)

This PR cherry-picks Neeraj's ML-related changes from the web_face_v2 branch.

Similar to https://github.com/ente-io/ente/pull/1399, this gets us one
step closer to integrating ONNX-YOLO with our desktop app. It is not
yet in a usable state, though; the web app's functionality remains
untouched.
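
At a high level, this scaffolding wires up both halves of a round trip for per-file face data: the desktop ML pipeline encrypts the derived ServerFileMl payload with the file's own key and uploads it under a new "file-ml-clip-face" model, while gallery sync (desktop only, for now) pulls the diff back down and decrypts it. A minimal sketch of that flow using the helpers added in this diff (the wrapper names are just for illustration, not code in this PR):

import ComlinkCryptoWorker from "@ente/shared/crypto";
import { getEmbeddingsDiff, putEmbedding } from "services/embeddingService";
import { EnteFile } from "types/file";

// Write path (desktop ML pipeline): encrypt the derived face data with the
// file's own key and upload it under the new "file-ml-clip-face" model.
const uploadFileML = async (serverMl: unknown, enteFile: EnteFile) => {
    const worker = await ComlinkCryptoWorker.getInstance();
    const { file: encrypted } = await worker.encryptMetadata(serverMl, enteFile.key);
    await putEmbedding({
        fileID: enteFile.id,
        encryptedEmbedding: encrypted.encryptedData,
        decryptionHeader: encrypted.decryptionHeader,
        model: "file-ml-clip-face",
    });
};

// Read path (currently gated behind globalThis.electron): fetch the server
// diff and decrypt each entry with the corresponding file key.
const downloadFileML = async (sinceTime: number, keyForFile: Map<number, string>) => {
    const worker = await ComlinkCryptoWorker.getInstance();
    const { diff } = await getEmbeddingsDiff(sinceTime, "file-ml-clip-face");
    return Promise.all(
        diff.map((e) =>
            worker.decryptMetadata(
                e.encryptedEmbedding,
                e.decryptionHeader,
                keyForFile.get(e.fileID),
            ),
        ),
    );
};
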
Manav Rathi · 1 year ago · commit 4a69e9260c

+ 5 - 1
web/apps/photos/src/pages/gallery/index.tsx

@@ -105,7 +105,7 @@ import { AppContext } from "pages/_app";
 import { clipService } from "services/clip-service";
 import { constructUserIDToEmailMap } from "services/collectionService";
 import downloadManager from "services/download";
-import { syncEmbeddings } from "services/embeddingService";
+import { syncEmbeddings, syncFileEmbeddings } from "services/embeddingService";
 import { syncEntities } from "services/entityService";
 import locationSearchService from "services/locationSearchService";
 import { getLocalTrashedFiles, syncTrash } from "services/trashService";
@@ -702,6 +702,10 @@ export default function Gallery() {
             await syncEntities();
             await syncMapEnabled();
             await syncEmbeddings();
+            const electron = globalThis.electron;
+            if (electron) {
+                await syncFileEmbeddings();
+            }
             if (clipService.isPlatformSupported()) {
                 void clipService.scheduleImageEmbeddingExtraction();
             }

+ 95 - 3
web/apps/photos/src/services/embeddingService.ts

@@ -13,7 +13,11 @@ import type {
     PutEmbeddingRequest,
 } from "types/embedding";
 import { EnteFile } from "types/file";
-import { getLatestVersionEmbeddings } from "utils/embedding";
+import {
+    getLatestVersionEmbeddings,
+    getLatestVersionFileEmbeddings,
+} from "utils/embedding";
+import { FileML } from "utils/machineLearning/mldataMappers";
 import { getLocalCollections } from "./collectionService";
 import { getAllLocalFiles } from "./fileService";
 import { getLocalTrashedFiles } from "./trashService";
@@ -24,6 +28,7 @@ const DIFF_LIMIT = 500;
 
 const EMBEDDINGS_TABLE_V1 = "embeddings";
 const EMBEDDINGS_TABLE = "embeddings_v2";
+const FILE_EMBEDING_TABLE = "file_embeddings";
 const EMBEDDING_SYNC_TIME_TABLE = "embedding_sync_time";
 
 export const getAllLocalEmbeddings = async () => {
@@ -38,6 +43,15 @@ export const getAllLocalEmbeddings = async () => {
     return embeddings;
 };
 
+export const getFileMLEmbeddings = async (): Promise<FileML[]> => {
+    const embeddings: Array<FileML> =
+        await localForage.getItem<FileML[]>(FILE_EMBEDING_TABLE);
+    if (!embeddings) {
+        return [];
+    }
+    return embeddings;
+};
+
 export const getLocalEmbeddings = async () => {
     const embeddings = await getAllLocalEmbeddings();
     return embeddings.filter((embedding) => embedding.model === "onnx-clip");
@@ -140,6 +154,83 @@ export const syncEmbeddings = async () => {
     }
 };
 
+export const syncFileEmbeddings = async () => {
+    const models: EmbeddingModel[] = ["file-ml-clip-face"];
+    try {
+        let allEmbeddings: FileML[] = await getFileMLEmbeddings();
+        const localFiles = await getAllLocalFiles();
+        const hiddenAlbums = await getLocalCollections("hidden");
+        const localTrashFiles = await getLocalTrashedFiles();
+        const fileIdToKeyMap = new Map<number, string>();
+        const allLocalFiles = [...localFiles, ...localTrashFiles];
+        allLocalFiles.forEach((file) => {
+            fileIdToKeyMap.set(file.id, file.key);
+        });
+        await cleanupDeletedEmbeddings(allLocalFiles, allEmbeddings);
+        log.info(`Syncing embeddings localCount: ${allEmbeddings.length}`);
+        for (const model of models) {
+            let modelLastSinceTime = await getModelEmbeddingSyncTime(model);
+            log.info(
+                `Syncing ${model} model's embeddings sinceTime: ${modelLastSinceTime}`,
+            );
+            let response: GetEmbeddingDiffResponse;
+            do {
+                response = await getEmbeddingsDiff(modelLastSinceTime, model);
+                if (!response.diff?.length) {
+                    return;
+                }
+                const newEmbeddings = await Promise.all(
+                    response.diff.map(async (embedding) => {
+                        try {
+                            const worker =
+                                await ComlinkCryptoWorker.getInstance();
+                            const fileKey = fileIdToKeyMap.get(
+                                embedding.fileID,
+                            );
+                            if (!fileKey) {
+                                throw Error(CustomError.FILE_NOT_FOUND);
+                            }
+                            const decryptedData = await worker.decryptMetadata(
+                                embedding.encryptedEmbedding,
+                                embedding.decryptionHeader,
+                                fileIdToKeyMap.get(embedding.fileID),
+                            );
+
+                            return {
+                                ...decryptedData,
+                                updatedAt: embedding.updatedAt,
+                            } as unknown as FileML;
+                        } catch (e) {
+                            let hasHiddenAlbums = false;
+                            if (e.message === CustomError.FILE_NOT_FOUND) {
+                                hasHiddenAlbums = hiddenAlbums?.length > 0;
+                            }
+                            log.error(
+                                `decryptEmbedding failed for file (hasHiddenAlbums: ${hasHiddenAlbums})`,
+                                e,
+                            );
+                        }
+                    }),
+                );
+                allEmbeddings = getLatestVersionFileEmbeddings([
+                    ...allEmbeddings,
+                    ...newEmbeddings,
+                ]);
+                if (response.diff.length) {
+                    modelLastSinceTime = response.diff.slice(-1)[0].updatedAt;
+                }
+                await localForage.setItem(FILE_EMBEDING_TABLE, allEmbeddings);
+                await setModelEmbeddingSyncTime(model, modelLastSinceTime);
+                log.info(
+                    `Syncing embeddings syncedEmbeddingsCount: ${allEmbeddings.length}`,
+                );
+            } while (response.diff.length === DIFF_LIMIT);
+        }
+    } catch (e) {
+        log.error("Sync embeddings failed", e);
+    }
+};
+
 export const getEmbeddingsDiff = async (
     sinceTime: number,
     model: EmbeddingModel,
@@ -173,7 +264,8 @@ export const putEmbedding = async (
     try {
         const token = getToken();
         if (!token) {
-            return;
+            log.info("putEmbedding failed: token not found");
+            throw Error(CustomError.TOKEN_MISSING);
         }
         const resp = await HTTPService.put(
             `${ENDPOINT}/embeddings`,
@@ -192,7 +284,7 @@ export const putEmbedding = async (
 
 export const cleanupDeletedEmbeddings = async (
     allLocalFiles: EnteFile[],
-    allLocalEmbeddings: Embedding[],
+    allLocalEmbeddings: Embedding[] | FileML[],
 ) => {
     const activeFileIds = new Set<number>();
     allLocalFiles.forEach((file) => {
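
Once syncFileEmbeddings has run, the decrypted FileML records sit in localForage and can be read back via getFileMLEmbeddings. A hypothetical consumer (not part of this PR) that indexes the synced faces might look like:

import { getFileMLEmbeddings } from "services/embeddingService";

// Illustrative only: build a faceID -> embedding lookup from the locally
// cached "file-ml-clip-face" records.
const buildFaceIndex = async (): Promise<Map<string, number[]>> => {
    const fileMLs = await getFileMLEmbeddings();
    const index = new Map<string, number[]>();
    for (const fileML of fileMLs) {
        for (const face of fileML.faceEmbedding.faces) {
            index.set(face.faceID, face.embeddings);
        }
    }
    return index;
};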

+ 30 - 7
web/apps/photos/src/services/machineLearning/machineLearningService.ts

@@ -1,11 +1,13 @@
 import log from "@/next/log";
 import { APPS } from "@ente/shared/apps/constants";
+import ComlinkCryptoWorker from "@ente/shared/crypto";
 import { CustomError, parseUploadErrorCodes } from "@ente/shared/error";
 import "@tensorflow/tfjs-backend-cpu";
 import "@tensorflow/tfjs-backend-webgl";
 import * as tf from "@tensorflow/tfjs-core";
 import { MAX_ML_SYNC_ERROR_COUNT } from "constants/mlConfig";
 import downloadManager from "services/download";
+import { putEmbedding } from "services/embeddingService";
 import { getLocalFiles } from "services/fileService";
 import { EnteFile } from "types/file";
 import {
@@ -15,6 +17,7 @@ import {
     MlFileData,
 } from "types/machineLearning";
 import { getMLSyncConfig } from "utils/machineLearning/config";
+import { LocalFileMlDataToServerFileMl } from "utils/machineLearning/mldataMappers";
 import mlIDbStorage from "utils/storage/mlIDbStorage";
 import FaceService from "./faceService";
 import { MLFactory } from "./machineLearningFactory";
@@ -215,13 +218,13 @@ class MachineLearningService {
             syncContext,
             [...existingFilesMap.values()],
         );
-        // addLogLine("getUniqueOutOfSyncFiles");
-        // addLogLine(
-        //     "Got unique outOfSyncFiles: ",
-        //     syncContext.outOfSyncFiles.length,
-        //     "for batchSize: ",
-        //     syncContext.config.batchSize,
-        // );
+        log.info("getUniqueOutOfSyncFiles");
+        log.info(
+            "Got unique outOfSyncFiles: ",
+            syncContext.outOfSyncFiles.length,
+            "for batchSize: ",
+            syncContext.config.batchSize,
+        );
     }
 
     private async syncFiles(syncContext: MLSyncContext) {
@@ -415,6 +418,7 @@ class MachineLearningService {
             ]);
             newMlFile.errorCount = 0;
             newMlFile.lastErrorMessage = undefined;
+            await this.persistOnServer(newMlFile, enteFile);
             await this.persistMLFileData(syncContext, newMlFile);
         } catch (e) {
             log.error("ml detection failed", e);
@@ -435,6 +439,25 @@ class MachineLearningService {
         return newMlFile;
     }
 
+    private async persistOnServer(mlFileData: MlFileData, enteFile: EnteFile) {
+        const serverMl = LocalFileMlDataToServerFileMl(mlFileData);
+        log.info(mlFileData);
+
+        const comlinkCryptoWorker = await ComlinkCryptoWorker.getInstance();
+        const { file: encryptedEmbeddingData } =
+            await comlinkCryptoWorker.encryptMetadata(serverMl, enteFile.key);
+        log.info(
+            `putEmbedding embedding to server for file: ${enteFile.metadata.title} fileID: ${enteFile.id}`,
+        );
+        const res = await putEmbedding({
+            fileID: enteFile.id,
+            encryptedEmbedding: encryptedEmbeddingData.encryptedData,
+            decryptionHeader: encryptedEmbeddingData.decryptionHeader,
+            model: "file-ml-clip-face",
+        });
+        log.info("putEmbedding response: ", res);
+    }
+
     public async init() {
         if (this.initialized) {
             return;
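
For context, the plaintext that persistOnServer encrypts is the ServerFileMl structure defined in mldataMappers.ts further down in this diff. Roughly, the uploaded payload looks like the sketch below (all values invented):

// Invented values, shaped per the ServerFileMl / ServerFace classes below.
const examplePayload = {
    fileID: 12345,
    width: 4032,
    height: 3024,
    faceEmbedding: {
        version: 1,
        faces: [
            {
                fileID: 12345,
                faceID: "12345-0",
                embeddings: [0.012, -0.034 /* ... */],
                detection: {
                    box: { xMin: 0.31, yMin: 0.22, width: 0.18, height: 0.24 },
                    landmarks: [
                        { x: 0.36, y: 0.3 },
                        { x: 0.44, y: 0.31 },
                    ],
                },
                score: 0.97,
                blur: 12.4,
                fileInfo: { imageWidth: 4032, imageHeight: 3024 },
            },
        ],
    },
};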

+ 2 - 2
web/apps/photos/src/types/embedding.tsx

@@ -5,7 +5,7 @@
  * embeddings on the server. However, we should be prepared to receive an
  * {@link EncryptedEmbedding} with a model value distinct from one of these.
  */
-export type EmbeddingModel = "onnx-clip";
+export type EmbeddingModel = "onnx-clip" | "file-ml-clip-face";
 
 export interface EncryptedEmbedding {
     fileID: number;
@@ -21,7 +21,7 @@ export interface Embedding
         EncryptedEmbedding,
         "encryptedEmbedding" | "decryptionHeader"
     > {
-    embedding: Float32Array;
+    embedding?: Float32Array;
 }
 
 export interface GetEmbeddingDiffResponse {

+ 18 - 0
web/apps/photos/src/utils/embedding.ts

@@ -1,4 +1,5 @@
 import { Embedding } from "types/embedding";
+import { FileML } from "./machineLearning/mldataMappers";
 
 export const getLatestVersionEmbeddings = (embeddings: Embedding[]) => {
     const latestVersionEntities = new Map<number, Embedding>();
@@ -16,3 +17,20 @@ export const getLatestVersionEmbeddings = (embeddings: Embedding[]) => {
     });
     return Array.from(latestVersionEntities.values());
 };
+
+export const getLatestVersionFileEmbeddings = (embeddings: FileML[]) => {
+    const latestVersionEntities = new Map<number, FileML>();
+    embeddings.forEach((embedding) => {
+        if (!embedding?.fileID) {
+            return;
+        }
+        const existingEmbeddings = latestVersionEntities.get(embedding.fileID);
+        if (
+            !existingEmbeddings ||
+            existingEmbeddings.updatedAt < embedding.updatedAt
+        ) {
+            latestVersionEntities.set(embedding.fileID, embedding);
+        }
+    });
+    return Array.from(latestVersionEntities.values());
+};
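
getLatestVersionFileEmbeddings simply keeps the newest record per fileID. A small usage sketch with stubbed inputs:

import { getLatestVersionFileEmbeddings } from "utils/embedding";
import { FileML } from "utils/machineLearning/mldataMappers";

// Stubbed records: fileID 1 appears twice; only the updatedAt: 200 copy survives.
const records = [
    { fileID: 1, updatedAt: 100 },
    { fileID: 1, updatedAt: 200 },
    { fileID: 2, updatedAt: 150 },
] as unknown as FileML[];

const merged = getLatestVersionFileEmbeddings(records);
// merged has two entries: fileID 1 (updatedAt 200) and fileID 2 (updatedAt 150)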

+ 265 - 0
web/apps/photos/src/utils/machineLearning/mldataMappers.ts

@@ -0,0 +1,265 @@
+import {
+    Face,
+    FaceDetection,
+    Landmark,
+    MlFileData,
+} from "types/machineLearning";
+import { ClipEmbedding } from "types/machineLearning/data/clip";
+
+export interface FileML extends ServerFileMl {
+    updatedAt: number;
+}
+
+class ServerFileMl {
+    public fileID: number;
+    public height?: number;
+    public width?: number;
+    public faceEmbedding: ServerFaceEmbeddings;
+    public clipEmbedding?: ClipEmbedding;
+
+    public constructor(
+        fileID: number,
+        faceEmbedding: ServerFaceEmbeddings,
+        clipEmbedding?: ClipEmbedding,
+        height?: number,
+        width?: number,
+    ) {
+        this.fileID = fileID;
+        this.height = height;
+        this.width = width;
+        this.faceEmbedding = faceEmbedding;
+        this.clipEmbedding = clipEmbedding;
+    }
+
+    toJson(): string {
+        return JSON.stringify(this);
+    }
+
+    static fromJson(json: string): ServerFileMl {
+        return JSON.parse(json);
+    }
+}
+
+class ServerFaceEmbeddings {
+    public faces: ServerFace[];
+    public version: number;
+    public client?: string;
+    public error?: boolean;
+
+    public constructor(
+        faces: ServerFace[],
+        version: number,
+        client?: string,
+        error?: boolean,
+    ) {
+        this.faces = faces;
+        this.version = version;
+        this.client = client;
+        this.error = error;
+    }
+
+    toJson(): string {
+        return JSON.stringify(this);
+    }
+
+    static fromJson(json: string): ServerFaceEmbeddings {
+        return JSON.parse(json);
+    }
+}
+
+class ServerFace {
+    public fileID: number;
+    public faceID: string;
+    public embeddings: number[];
+    public detection: ServerDetection;
+    public score: number;
+    public blur: number;
+    public fileInfo?: ServerFileInfo;
+
+    public constructor(
+        fileID: number,
+        faceID: string,
+        embeddings: number[],
+        detection: ServerDetection,
+        score: number,
+        blur: number,
+        fileInfo?: ServerFileInfo,
+    ) {
+        this.fileID = fileID;
+        this.faceID = faceID;
+        this.embeddings = embeddings;
+        this.detection = detection;
+        this.score = score;
+        this.blur = blur;
+        this.fileInfo = fileInfo;
+    }
+
+    toJson(): string {
+        return JSON.stringify(this);
+    }
+
+    static fromJson(json: string): ServerFace {
+        return JSON.parse(json);
+    }
+}
+
+class ServerFileInfo {
+    public imageWidth?: number;
+    public imageHeight?: number;
+
+    public constructor(imageWidth?: number, imageHeight?: number) {
+        this.imageWidth = imageWidth;
+        this.imageHeight = imageHeight;
+    }
+}
+
+class ServerDetection {
+    public box: ServerFaceBox;
+    public landmarks: Landmark[];
+
+    public constructor(box: ServerFaceBox, landmarks: Landmark[]) {
+        this.box = box;
+        this.landmarks = landmarks;
+    }
+
+    toJson(): string {
+        return JSON.stringify(this);
+    }
+
+    static fromJson(json: string): ServerDetection {
+        return JSON.parse(json);
+    }
+}
+
+class ServerFaceBox {
+    public xMin: number;
+    public yMin: number;
+    public width: number;
+    public height: number;
+
+    public constructor(
+        xMin: number,
+        yMin: number,
+        width: number,
+        height: number,
+    ) {
+        this.xMin = xMin;
+        this.yMin = yMin;
+        this.width = width;
+        this.height = height;
+    }
+
+    toJson(): string {
+        return JSON.stringify(this);
+    }
+
+    static fromJson(json: string): ServerFaceBox {
+        return JSON.parse(json);
+    }
+}
+
+export function LocalFileMlDataToServerFileMl(
+    localFileMlData: MlFileData,
+): ServerFileMl {
+    if (
+        localFileMlData.errorCount > 0 &&
+        localFileMlData.lastErrorMessage !== undefined
+    ) {
+        return null;
+    }
+    const imageDimensions = localFileMlData.imageDimensions;
+    const fileInfo = new ServerFileInfo(
+        imageDimensions.width,
+        imageDimensions.height,
+    );
+    const faces: ServerFace[] = [];
+    for (let i = 0; i < localFileMlData.faces.length; i++) {
+        const face: Face = localFileMlData.faces[i];
+        const faceID = face.id;
+        const embedding = face.embedding;
+        const score = face.detection.probability;
+        const blur = face.blurValue;
+        const detection: FaceDetection = face.detection;
+        const box = detection.box;
+        const landmarks = detection.landmarks;
+        const newBox = new ServerFaceBox(box.x, box.y, box.width, box.height);
+        const newLandmarks: Landmark[] = [];
+        for (let j = 0; j < landmarks.length; j++) {
+            newLandmarks.push({
+                x: landmarks[j].x,
+                y: landmarks[j].y,
+            } as Landmark);
+        }
+
+        const newFaceObject = new ServerFace(
+            localFileMlData.fileId,
+            faceID,
+            Array.from(embedding),
+            new ServerDetection(newBox, newLandmarks),
+            score,
+            blur,
+            fileInfo,
+        );
+        faces.push(newFaceObject);
+    }
+    const faceEmbeddings = new ServerFaceEmbeddings(
+        faces,
+        1,
+        localFileMlData.lastErrorMessage,
+    );
+    return new ServerFileMl(
+        localFileMlData.fileId,
+        faceEmbeddings,
+        null,
+        imageDimensions.height,
+        imageDimensions.width,
+    );
+}
+
+// // Not sure if this actually works
+// export function ServerFileMlToLocalFileMlData(
+//     serverFileMl: ServerFileMl,
+// ): MlFileData {
+//     const faces: Face[] = [];
+//     const mlVersion: number = serverFileMl.faceEmbeddings.version;
+//     const errorCount = serverFileMl.faceEmbeddings.error ? 1 : 0;
+//     for (let i = 0; i < serverFileMl.faceEmbeddings.faces.length; i++) {
+//         const face = serverFileMl.faceEmbeddings.faces[i];
+//         if(face.detection.landmarks.length === 0) {
+//             continue;
+//         }
+//         const detection = face.detection;
+//         const box = detection.box;
+//         const landmarks = detection.landmarks;
+//         const newBox = new FaceBox(
+//             box.xMin,
+//             box.yMin,
+//             box.width,
+//             box.height,
+//         );
+//         const newLandmarks: Landmark[] = [];
+//         for (let j = 0; j < landmarks.length; j++) {
+//             newLandmarks.push({
+//                 x: landmarks[j].x,
+//                 y: landmarks[j].y,
+//             } as Landmark);
+//         }
+//         const newDetection = new Detection(newBox, newLandmarks);
+//         const newFace = {
+
+//         } as Face
+//         faces.push(newFace);
+//     }
+//     return {
+//         fileId: serverFileMl.fileID,
+//         imageDimensions: {
+//             width: serverFileMl.width,
+//             height: serverFileMl.height,
+//         },
+//         faces,
+//         mlVersion,
+//         errorCount,
+//     };
+// }