[web] Import the scaffolding to sync face embeddings from web_face_v2 (#1402)

This PR cherry-picks Neeraj's ML-related changes from the web_face_v2 branch.

Similar to https://github.com/ente-io/ente/pull/1399, this gets us one
step closer to integrating ONNX-YOLO with our desktop app. It is not yet
in a usable state, though (the web app's functionality remains
untouched).
commit 4a69e9260c
Manav Rathi, 2024-04-10 16:53:48 +05:30 (committed by GitHub)
6 changed files with 415 additions and 13 deletions
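
As context for the diffs that follow: the gallery sync now calls syncFileEmbeddings only when running under Electron, and that function pulls the "file-ml-clip-face" embedding diff from the server in pages, decrypting each entry with the owning file's key. The snippet below is a rough, illustrative TypeScript sketch of that loop, not code from this commit; syncFaceEmbeddingsSketch, fetchDiff, and decrypt are made-up stand-ins for the real service and crypto-worker APIs, and the real code additionally dedups by fileID and persists the results to localForage.

// A simplified, self-contained sketch of the paginated pull-and-decrypt loop
// that syncFileEmbeddings (below) implements. fetchDiff and decrypt are
// hypothetical stand-ins for getEmbeddingsDiff and the crypto worker.

interface DiffEntry {
    fileID: number;
    updatedAt: number;
    encryptedEmbedding: string;
    decryptionHeader: string;
}

interface FileEmbedding {
    fileID: number;
    updatedAt: number;
}

const DIFF_LIMIT = 500;

export const syncFaceEmbeddingsSketch = async (
    fileIdToKeyMap: Map<number, string>,
    sinceTime: number,
    fetchDiff: (sinceTime: number) => Promise<DiffEntry[]>,
    decrypt: (entry: DiffEntry, fileKey: string) => Promise<FileEmbedding>,
): Promise<FileEmbedding[]> => {
    const synced: FileEmbedding[] = [];
    let diff: DiffEntry[] = [];
    do {
        // Fetch the next page of embeddings changed since the last sync time.
        diff = await fetchDiff(sinceTime);
        for (const entry of diff) {
            // Skip entries whose file we don't know about locally.
            const fileKey = fileIdToKeyMap.get(entry.fileID);
            if (!fileKey) continue;
            synced.push(await decrypt(entry, fileKey));
            // Advance the sync cursor to the newest entry seen so far.
            sinceTime = Math.max(sinceTime, entry.updatedAt);
        }
        // A full page means there may be more to fetch.
    } while (diff.length === DIFF_LIMIT);
    return synced;
};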


@@ -105,7 +105,7 @@ import { AppContext } from "pages/_app";
import { clipService } from "services/clip-service";
import { constructUserIDToEmailMap } from "services/collectionService";
import downloadManager from "services/download";
import { syncEmbeddings } from "services/embeddingService";
import { syncEmbeddings, syncFileEmbeddings } from "services/embeddingService";
import { syncEntities } from "services/entityService";
import locationSearchService from "services/locationSearchService";
import { getLocalTrashedFiles, syncTrash } from "services/trashService";
@@ -702,6 +702,10 @@ export default function Gallery() {
await syncEntities();
await syncMapEnabled();
await syncEmbeddings();
const electron = globalThis.electron;
if (electron) {
await syncFileEmbeddings();
}
if (clipService.isPlatformSupported()) {
void clipService.scheduleImageEmbeddingExtraction();
}


@@ -13,7 +13,11 @@ import type {
PutEmbeddingRequest,
} from "types/embedding";
import { EnteFile } from "types/file";
import { getLatestVersionEmbeddings } from "utils/embedding";
import {
getLatestVersionEmbeddings,
getLatestVersionFileEmbeddings,
} from "utils/embedding";
import { FileML } from "utils/machineLearning/mldataMappers";
import { getLocalCollections } from "./collectionService";
import { getAllLocalFiles } from "./fileService";
import { getLocalTrashedFiles } from "./trashService";
@@ -24,6 +28,7 @@ const DIFF_LIMIT = 500;
const EMBEDDINGS_TABLE_V1 = "embeddings";
const EMBEDDINGS_TABLE = "embeddings_v2";
const FILE_EMBEDING_TABLE = "file_embeddings";
const EMBEDDING_SYNC_TIME_TABLE = "embedding_sync_time";
export const getAllLocalEmbeddings = async () => {
@@ -38,6 +43,15 @@ export const getAllLocalEmbeddings = async () => {
return embeddings;
};
export const getFileMLEmbeddings = async (): Promise<FileML[]> => {
const embeddings: Array<FileML> =
await localForage.getItem<FileML[]>(FILE_EMBEDING_TABLE);
if (!embeddings) {
return [];
}
return embeddings;
};
export const getLocalEmbeddings = async () => {
const embeddings = await getAllLocalEmbeddings();
return embeddings.filter((embedding) => embedding.model === "onnx-clip");
@@ -140,6 +154,83 @@ export const syncEmbeddings = async () => {
}
};
export const syncFileEmbeddings = async () => {
const models: EmbeddingModel[] = ["file-ml-clip-face"];
try {
let allEmbeddings: FileML[] = await getFileMLEmbeddings();
const localFiles = await getAllLocalFiles();
const hiddenAlbums = await getLocalCollections("hidden");
const localTrashFiles = await getLocalTrashedFiles();
const fileIdToKeyMap = new Map<number, string>();
const allLocalFiles = [...localFiles, ...localTrashFiles];
allLocalFiles.forEach((file) => {
fileIdToKeyMap.set(file.id, file.key);
});
await cleanupDeletedEmbeddings(allLocalFiles, allEmbeddings);
log.info(`Syncing embeddings localCount: ${allEmbeddings.length}`);
for (const model of models) {
let modelLastSinceTime = await getModelEmbeddingSyncTime(model);
log.info(
`Syncing ${model} model's embeddings sinceTime: ${modelLastSinceTime}`,
);
let response: GetEmbeddingDiffResponse;
do {
response = await getEmbeddingsDiff(modelLastSinceTime, model);
if (!response.diff?.length) {
return;
}
const newEmbeddings = await Promise.all(
response.diff.map(async (embedding) => {
try {
const worker =
await ComlinkCryptoWorker.getInstance();
const fileKey = fileIdToKeyMap.get(
embedding.fileID,
);
if (!fileKey) {
throw Error(CustomError.FILE_NOT_FOUND);
}
const decryptedData = await worker.decryptMetadata(
embedding.encryptedEmbedding,
embedding.decryptionHeader,
fileIdToKeyMap.get(embedding.fileID),
);
return {
...decryptedData,
updatedAt: embedding.updatedAt,
} as unknown as FileML;
} catch (e) {
let hasHiddenAlbums = false;
if (e.message === CustomError.FILE_NOT_FOUND) {
hasHiddenAlbums = hiddenAlbums?.length > 0;
}
log.error(
`decryptEmbedding failed for file (hasHiddenAlbums: ${hasHiddenAlbums})`,
e,
);
}
}),
);
allEmbeddings = getLatestVersionFileEmbeddings([
...allEmbeddings,
...newEmbeddings,
]);
if (response.diff.length) {
modelLastSinceTime = response.diff.slice(-1)[0].updatedAt;
}
await localForage.setItem(FILE_EMBEDING_TABLE, allEmbeddings);
await setModelEmbeddingSyncTime(model, modelLastSinceTime);
log.info(
`Syncing embeddings syncedEmbeddingsCount: ${allEmbeddings.length}`,
);
} while (response.diff.length === DIFF_LIMIT);
}
} catch (e) {
log.error("Sync embeddings failed", e);
}
};
export const getEmbeddingsDiff = async (
sinceTime: number,
model: EmbeddingModel,
@@ -173,7 +264,8 @@ export const putEmbedding = async (
try {
const token = getToken();
if (!token) {
return;
log.info("putEmbedding failed: token not found");
throw Error(CustomError.TOKEN_MISSING);
}
const resp = await HTTPService.put(
`${ENDPOINT}/embeddings`,
@@ -192,7 +284,7 @@
export const cleanupDeletedEmbeddings = async (
allLocalFiles: EnteFile[],
allLocalEmbeddings: Embedding[],
allLocalEmbeddings: Embedding[] | FileML[],
) => {
const activeFileIds = new Set<number>();
allLocalFiles.forEach((file) => {


@@ -1,11 +1,13 @@
import log from "@/next/log";
import { APPS } from "@ente/shared/apps/constants";
import ComlinkCryptoWorker from "@ente/shared/crypto";
import { CustomError, parseUploadErrorCodes } from "@ente/shared/error";
import "@tensorflow/tfjs-backend-cpu";
import "@tensorflow/tfjs-backend-webgl";
import * as tf from "@tensorflow/tfjs-core";
import { MAX_ML_SYNC_ERROR_COUNT } from "constants/mlConfig";
import downloadManager from "services/download";
import { putEmbedding } from "services/embeddingService";
import { getLocalFiles } from "services/fileService";
import { EnteFile } from "types/file";
import {
@@ -15,6 +17,7 @@ import {
MlFileData,
} from "types/machineLearning";
import { getMLSyncConfig } from "utils/machineLearning/config";
import { LocalFileMlDataToServerFileMl } from "utils/machineLearning/mldataMappers";
import mlIDbStorage from "utils/storage/mlIDbStorage";
import FaceService from "./faceService";
import { MLFactory } from "./machineLearningFactory";
@@ -215,13 +218,13 @@ class MachineLearningService {
syncContext,
[...existingFilesMap.values()],
);
// addLogLine("getUniqueOutOfSyncFiles");
// addLogLine(
// "Got unique outOfSyncFiles: ",
// syncContext.outOfSyncFiles.length,
// "for batchSize: ",
// syncContext.config.batchSize,
// );
log.info("getUniqueOutOfSyncFiles");
log.info(
"Got unique outOfSyncFiles: ",
syncContext.outOfSyncFiles.length,
"for batchSize: ",
syncContext.config.batchSize,
);
}
private async syncFiles(syncContext: MLSyncContext) {
@@ -415,6 +418,7 @@
]);
newMlFile.errorCount = 0;
newMlFile.lastErrorMessage = undefined;
await this.persistOnServer(newMlFile, enteFile);
await this.persistMLFileData(syncContext, newMlFile);
} catch (e) {
log.error("ml detection failed", e);
@@ -435,6 +439,25 @@
return newMlFile;
}
private async persistOnServer(mlFileData: MlFileData, enteFile: EnteFile) {
const serverMl = LocalFileMlDataToServerFileMl(mlFileData);
log.info(mlFileData);
const comlinkCryptoWorker = await ComlinkCryptoWorker.getInstance();
const { file: encryptedEmbeddingData } =
await comlinkCryptoWorker.encryptMetadata(serverMl, enteFile.key);
log.info(
`putEmbedding embedding to server for file: ${enteFile.metadata.title} fileID: ${enteFile.id}`,
);
const res = await putEmbedding({
fileID: enteFile.id,
encryptedEmbedding: encryptedEmbeddingData.encryptedData,
decryptionHeader: encryptedEmbeddingData.decryptionHeader,
model: "file-ml-clip-face",
});
log.info("putEmbedding response: ", res);
}
public async init() {
if (this.initialized) {
return;


@@ -5,7 +5,7 @@
* embeddings on the server. However, we should be prepared to receive an
* {@link EncryptedEmbedding} with a model value distinct from one of these.
*/
export type EmbeddingModel = "onnx-clip";
export type EmbeddingModel = "onnx-clip" | "file-ml-clip-face";
export interface EncryptedEmbedding {
fileID: number;
@@ -21,7 +21,7 @@ export interface Embedding
EncryptedEmbedding,
"encryptedEmbedding" | "decryptionHeader"
> {
embedding: Float32Array;
embedding?: Float32Array;
}
export interface GetEmbeddingDiffResponse {


@@ -1,4 +1,5 @@
import { Embedding } from "types/embedding";
import { FileML } from "./machineLearning/mldataMappers";
export const getLatestVersionEmbeddings = (embeddings: Embedding[]) => {
const latestVersionEntities = new Map<number, Embedding>();
@@ -16,3 +17,20 @@ export const getLatestVersionEmbeddings = (embeddings: Embedding[]) => {
});
return Array.from(latestVersionEntities.values());
};
export const getLatestVersionFileEmbeddings = (embeddings: FileML[]) => {
const latestVersionEntities = new Map<number, FileML>();
embeddings.forEach((embedding) => {
if (!embedding?.fileID) {
return;
}
const existingEmbeddings = latestVersionEntities.get(embedding.fileID);
if (
!existingEmbeddings ||
existingEmbeddings.updatedAt < embedding.updatedAt
) {
latestVersionEntities.set(embedding.fileID, embedding);
}
});
return Array.from(latestVersionEntities.values());
};


@@ -0,0 +1,265 @@
import {
Face,
FaceDetection,
Landmark,
MlFileData,
} from "types/machineLearning";
import { ClipEmbedding } from "types/machineLearning/data/clip";
export interface FileML extends ServerFileMl {
updatedAt: number;
}
class ServerFileMl {
public fileID: number;
public height?: number;
public width?: number;
public faceEmbedding: ServerFaceEmbeddings;
public clipEmbedding?: ClipEmbedding;
public constructor(
fileID: number,
faceEmbedding: ServerFaceEmbeddings,
clipEmbedding?: ClipEmbedding,
height?: number,
width?: number,
) {
this.fileID = fileID;
this.height = height;
this.width = width;
this.faceEmbedding = faceEmbedding;
this.clipEmbedding = clipEmbedding;
}
toJson(): string {
return JSON.stringify(this);
}
static fromJson(json: string): ServerFileMl {
return JSON.parse(json);
}
}
class ServerFaceEmbeddings {
public faces: ServerFace[];
public version: number;
public client?: string;
public error?: boolean;
public constructor(
faces: ServerFace[],
version: number,
client?: string,
error?: boolean,
) {
this.faces = faces;
this.version = version;
this.client = client;
this.error = error;
}
toJson(): string {
return JSON.stringify(this);
}
static fromJson(json: string): ServerFaceEmbeddings {
return JSON.parse(json);
}
}
class ServerFace {
public fileID: number;
public faceID: string;
public embeddings: number[];
public detection: ServerDetection;
public score: number;
public blur: number;
public fileInfo?: ServerFileInfo;
public constructor(
fileID: number,
faceID: string,
embeddings: number[],
detection: ServerDetection,
score: number,
blur: number,
fileInfo?: ServerFileInfo,
) {
this.fileID = fileID;
this.faceID = faceID;
this.embeddings = embeddings;
this.detection = detection;
this.score = score;
this.blur = blur;
this.fileInfo = fileInfo;
}
toJson(): string {
return JSON.stringify(this);
}
static fromJson(json: string): ServerFace {
return JSON.parse(json);
}
}
class ServerFileInfo {
public imageWidth?: number;
public imageHeight?: number;
public constructor(imageWidth?: number, imageHeight?: number) {
this.imageWidth = imageWidth;
this.imageHeight = imageHeight;
}
}
class ServerDetection {
public box: ServerFaceBox;
public landmarks: Landmark[];
public constructor(box: ServerFaceBox, landmarks: Landmark[]) {
this.box = box;
this.landmarks = landmarks;
}
toJson(): string {
return JSON.stringify(this);
}
static fromJson(json: string): ServerDetection {
return JSON.parse(json);
}
}
class ServerFaceBox {
public xMin: number;
public yMin: number;
public width: number;
public height: number;
public constructor(
xMin: number,
yMin: number,
width: number,
height: number,
) {
this.xMin = xMin;
this.yMin = yMin;
this.width = width;
this.height = height;
}
toJson(): string {
return JSON.stringify(this);
}
static fromJson(json: string): ServerFaceBox {
return JSON.parse(json);
}
}
export function LocalFileMlDataToServerFileMl(
localFileMlData: MlFileData,
): ServerFileMl {
if (
localFileMlData.errorCount > 0 &&
localFileMlData.lastErrorMessage !== undefined
) {
return null;
}
const imageDimensions = localFileMlData.imageDimensions;
const fileInfo = new ServerFileInfo(
imageDimensions.width,
imageDimensions.height,
);
const faces: ServerFace[] = [];
for (let i = 0; i < localFileMlData.faces.length; i++) {
const face: Face = localFileMlData.faces[i];
const faceID = face.id;
const embedding = face.embedding;
const score = face.detection.probability;
const blur = face.blurValue;
const detection: FaceDetection = face.detection;
const box = detection.box;
const landmarks = detection.landmarks;
const newBox = new ServerFaceBox(box.x, box.y, box.width, box.height);
const newLandmarks: Landmark[] = [];
for (let j = 0; j < landmarks.length; j++) {
newLandmarks.push({
x: landmarks[j].x,
y: landmarks[j].y,
} as Landmark);
}
const newFaceObject = new ServerFace(
localFileMlData.fileId,
faceID,
Array.from(embedding),
new ServerDetection(newBox, newLandmarks),
score,
blur,
fileInfo,
);
faces.push(newFaceObject);
}
const faceEmbeddings = new ServerFaceEmbeddings(
faces,
1,
localFileMlData.lastErrorMessage,
);
return new ServerFileMl(
localFileMlData.fileId,
faceEmbeddings,
null,
imageDimensions.height,
imageDimensions.width,
);
}
// // Not sure if this actually works
// export function ServerFileMlToLocalFileMlData(
// serverFileMl: ServerFileMl,
// ): MlFileData {
// const faces: Face[] = [];
// const mlVersion: number = serverFileMl.faceEmbeddings.version;
// const errorCount = serverFileMl.faceEmbeddings.error ? 1 : 0;
// for (let i = 0; i < serverFileMl.faceEmbeddings.faces.length; i++) {
// const face = serverFileMl.faceEmbeddings.faces[i];
// if(face.detection.landmarks.length === 0) {
// continue;
// }
// const detection = face.detection;
// const box = detection.box;
// const landmarks = detection.landmarks;
// const newBox = new FaceBox(
// box.xMin,
// box.yMin,
// box.width,
// box.height,
// );
// const newLandmarks: Landmark[] = [];
// for (let j = 0; j < landmarks.length; j++) {
// newLandmarks.push(
// {
// x: landmarks[j].x,
// y: landmarks[j].y,
// } as Landmark
// );
// }
// const newDetection = new Detection(newBox, newLandmarks);
// const newFace = {
// } as Face
// faces.push(newFace);
// }
// return {
// fileId: serverFileMl.fileID,
// imageDimensions: {
// width: serverFileMl.width,
// height: serverFileMl.height,
// },
// faces,
// mlVersion,
// errorCount,
// };
// }