[web] Import the scaffolding to sync face embeddings from web_face_v2

This PR cherry-picks Neeraj's ML-related changes from the web_face_v2 branch.

Similar to https://github.com/ente-io/ente/pull/1399, this gets us one step
closer to integrating ONNX-YOLO with our desktop app, but it is not yet in a
usable state (the web app's functionality remains untouched).
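
For reference, here is a hypothetical consumer of the new scaffolding. It is not
part of this PR; it only assumes the exports added below (syncFileEmbeddings,
getFileMLEmbeddings, and the FileML type) and the globalThis.electron guard used
in the Gallery sync:

import { getFileMLEmbeddings, syncFileEmbeddings } from "services/embeddingService";
import type { FileML } from "utils/machineLearning/mldataMappers";

// Hypothetical helper: pull the latest "file-ml-clip-face" embeddings and
// return the one belonging to the given file, if any. Only meaningful when
// running inside the desktop (Electron) shell.
export const getFaceEmbeddingForFile = async (
    fileID: number,
): Promise<FileML | undefined> => {
    if (!globalThis.electron) return undefined;
    await syncFileEmbeddings();
    const embeddings = await getFileMLEmbeddings();
    return embeddings.find((e) => e.fileID === fileID);
};
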
Neeraj Gupta 2024-04-10 16:38:12 +05:30 committed by Manav Rathi
parent 334fd61ea3
commit eefac7fd01
6 changed files with 417 additions and 16 deletions

View file

@@ -105,7 +105,7 @@ import { AppContext } from "pages/_app";
import { clipService } from "services/clip-service";
import { constructUserIDToEmailMap } from "services/collectionService";
import downloadManager from "services/download";
import { syncEmbeddings } from "services/embeddingService";
import { syncEmbeddings, syncFileEmbeddings } from "services/embeddingService";
import { syncEntities } from "services/entityService";
import locationSearchService from "services/locationSearchService";
import { getLocalTrashedFiles, syncTrash } from "services/trashService";
@@ -702,6 +702,10 @@ export default function Gallery() {
await syncEntities();
await syncMapEnabled();
await syncEmbeddings();
const electron = globalThis.electron;
if (electron) {
await syncFileEmbeddings();
}
if (clipService.isPlatformSupported()) {
void clipService.scheduleImageEmbeddingExtraction();
}

View file

@@ -13,10 +13,11 @@ import type {
PutEmbeddingRequest,
} from "types/embedding";
import { EnteFile } from "types/file";
import { getLatestVersionEmbeddings } from "utils/embedding";
import { getLatestVersionEmbeddings, getLatestVersionFileEmbeddings } from "utils/embedding";
import { getLocalCollections } from "./collectionService";
import { getAllLocalFiles } from "./fileService";
import { getLocalTrashedFiles } from "./trashService";
import { FileML } from "utils/machineLearning/mldataMappers";
const ENDPOINT = getEndpoint();
@@ -24,6 +25,7 @@ const DIFF_LIMIT = 500;
const EMBEDDINGS_TABLE_V1 = "embeddings";
const EMBEDDINGS_TABLE = "embeddings_v2";
const FILE_EMBEDING_TABLE = "file_embeddings";
const EMBEDDING_SYNC_TIME_TABLE = "embedding_sync_time";
export const getAllLocalEmbeddings = async () => {
@@ -38,6 +40,15 @@ export const getAllLocalEmbeddings = async () => {
return embeddings;
};
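// Return the locally cached "file-ml-clip-face" embeddings, or an empty
// array if none have been synced yet.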
export const getFileMLEmbeddings = async (): Promise<FileML[]> => {
const embeddings: Array<FileML> =
await localForage.getItem<FileML[]>(FILE_EMBEDING_TABLE);
if (!embeddings) {
return [];
}
return embeddings;
};
export const getLocalEmbeddings = async () => {
const embeddings = await getAllLocalEmbeddings();
return embeddings.filter((embedding) => embedding.model === "onnx-clip");
@@ -140,6 +151,85 @@ export const syncEmbeddings = async () => {
}
};
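// Pull the "file-ml-clip-face" embedding diff from remote, decrypt each entry
// with its file's key, and merge the result into the local cache.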
export const syncFileEmbeddings = async () => {
const models: EmbeddingModel[] = ["file-ml-clip-face"];
try {
let allEmbeddings: FileML[] = await getFileMLEmbeddings();
const localFiles = await getAllLocalFiles();
const hiddenAlbums = await getLocalCollections("hidden");
const localTrashFiles = await getLocalTrashedFiles();
const fileIdToKeyMap = new Map<number, string>();
const allLocalFiles = [...localFiles, ...localTrashFiles];
allLocalFiles.forEach((file) => {
fileIdToKeyMap.set(file.id, file.key);
});
await cleanupDeletedEmbeddings(allLocalFiles, allEmbeddings);
log.info(`Syncing embeddings localCount: ${allEmbeddings.length}`);
for (const model of models) {
let modelLastSinceTime = await getModelEmbeddingSyncTime(model);
log.info(
`Syncing ${model} model's embeddings sinceTime: ${modelLastSinceTime}`,
);
let response: GetEmbeddingDiffResponse;
do {
response = await getEmbeddingsDiff(modelLastSinceTime, model);
if (!response.diff?.length) {
return;
}
const newEmbeddings = await Promise.all(
response.diff.map(async (embedding) => {
try {
const worker =
await ComlinkCryptoWorker.getInstance();
const fileKey = fileIdToKeyMap.get(
embedding.fileID,
);
if (!fileKey) {
throw Error(CustomError.FILE_NOT_FOUND);
}
const decryptedData = await worker.decryptMetadata(
embedding.encryptedEmbedding,
embedding.decryptionHeader,
fileKey,
);
return {
...decryptedData,
updatedAt: embedding.updatedAt
} as unknown as FileML;
} catch (e) {
let hasHiddenAlbums = false;
if (e.message === CustomError.FILE_NOT_FOUND) {
hasHiddenAlbums = hiddenAlbums?.length > 0;
}
log.error(
`decryptEmbedding failed for file (hasHiddenAlbums: ${hasHiddenAlbums})`,
e,
);
}
}),
);
allEmbeddings = getLatestVersionFileEmbeddings([
...allEmbeddings,
...newEmbeddings,
]);
if (response.diff.length) {
modelLastSinceTime = response.diff.slice(-1)[0].updatedAt;
}
await localForage.setItem(FILE_EMBEDING_TABLE, allEmbeddings);
await setModelEmbeddingSyncTime(model, modelLastSinceTime);
log.info(
`Syncing embeddings syncedEmbeddingsCount: ${allEmbeddings.length}`,
);
} while (response.diff.length === DIFF_LIMIT);
}
} catch (e) {
log.error("Sync embeddings failed", e);
}
};
export const getEmbeddingsDiff = async (
sinceTime: number,
model: EmbeddingModel,
@@ -173,7 +263,8 @@ export const putEmbedding = async (
try {
const token = getToken();
if (!token) {
return;
log.info('putEmbedding failed: token not found');
throw Error(CustomError.TOKEN_MISSING);
}
const resp = await HTTPService.put(
`${ENDPOINT}/embeddings`,
@@ -192,7 +283,7 @@
export const cleanupDeletedEmbeddings = async (
allLocalFiles: EnteFile[],
allLocalEmbeddings: Embedding[],
allLocalEmbeddings: Embedding[] | FileML[],
) => {
const activeFileIds = new Set<number>();
allLocalFiles.forEach((file) => {

View file

@@ -1,6 +1,7 @@
import log from "@/next/log";
import { APPS } from "@ente/shared/apps/constants";
import { CustomError, parseUploadErrorCodes } from "@ente/shared/error";
import ComlinkCryptoWorker from "@ente/shared/crypto";
import "@tensorflow/tfjs-backend-cpu";
import "@tensorflow/tfjs-backend-webgl";
import * as tf from "@tensorflow/tfjs-core";
@@ -21,6 +22,8 @@ import { MLFactory } from "./machineLearningFactory";
import ObjectService from "./objectService";
import PeopleService from "./peopleService";
import ReaderService from "./readerService";
import { LocalFileMlDataToServerFileMl } from "utils/machineLearning/mldataMappers";
import { putEmbedding } from "services/embeddingService";
class MachineLearningService {
private initialized = false;
@@ -162,7 +165,7 @@
log.info("syncLocalFiles", Date.now() - startTime, "ms");
}
private async getOutOfSyncFiles(syncContext: MLSyncContext) {
const startTime = Date.now();
const fileIds = await mlIDbStorage.getFileIds(
@@ -210,18 +213,19 @@
// existingFiles.sort(
// (a, b) => b.metadata.creationTime - a.metadata.creationTime
// );
console.time("getUniqueOutOfSyncFiles");
console.time('getUniqueOutOfSyncFiles');
syncContext.outOfSyncFiles = await this.getUniqueOutOfSyncFilesNoIdx(
syncContext,
[...existingFilesMap.values()],
[...existingFilesMap.values()]
);
// addLogLine("getUniqueOutOfSyncFiles");
// addLogLine(
// "Got unique outOfSyncFiles: ",
// syncContext.outOfSyncFiles.length,
// "for batchSize: ",
// syncContext.config.batchSize,
// );
log.info('getUniqueOutOfSyncFiles');
log.info(
'Got unique outOfSyncFiles: ',
syncContext.outOfSyncFiles.length,
'for batchSize: ',
syncContext.config.batchSize
);
}
private async syncFiles(syncContext: MLSyncContext) {
@@ -415,6 +419,7 @@
]);
newMlFile.errorCount = 0;
newMlFile.lastErrorMessage = undefined;
await this.persistOnServer(newMlFile, enteFile);
await this.persistMLFileData(syncContext, newMlFile);
} catch (e) {
log.error("ml detection failed", e);
@@ -435,6 +440,25 @@
return newMlFile;
}
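// Map the locally computed ML data to the server schema, encrypt it with the
// file's key, and upload it as a "file-ml-clip-face" embedding.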
private async persistOnServer(mlFileData: MlFileData, enteFile: EnteFile) {
const serverMl = LocalFileMlDataToServerFileMl(mlFileData);
log.info(mlFileData);
const comlinkCryptoWorker = await ComlinkCryptoWorker.getInstance();
const { file: encryptedEmbeddingData } =
await comlinkCryptoWorker.encryptMetadata(serverMl, enteFile.key);
log.info(
`putEmbedding embedding to server for file: ${enteFile.metadata.title} fileID: ${enteFile.id}`,
);
const res = await putEmbedding({
fileID: enteFile.id,
encryptedEmbedding: encryptedEmbeddingData.encryptedData,
decryptionHeader: encryptedEmbeddingData.decryptionHeader,
model: "file-ml-clip-face",
});
log.info("putEmbedding response: ", res);
}
public async init() {
if (this.initialized) {
return;

View file

@@ -5,7 +5,7 @@
* embeddings on the server. However, we should be prepared to receive an
* {@link EncryptedEmbedding} with a model value distinct from one of these.
*/
export type EmbeddingModel = "onnx-clip";
export type EmbeddingModel = "onnx-clip" | "file-ml-clip-face";
export interface EncryptedEmbedding {
fileID: number;
@@ -21,9 +21,10 @@ export interface Embedding
EncryptedEmbedding,
"encryptedEmbedding" | "decryptionHeader"
> {
embedding: Float32Array;
embedding?: Float32Array;
}
export interface GetEmbeddingDiffResponse {
diff: EncryptedEmbedding[];
}

View file

@@ -1,4 +1,5 @@
import { Embedding } from "types/embedding";
import { FileML } from "./machineLearning/mldataMappers";
export const getLatestVersionEmbeddings = (embeddings: Embedding[]) => {
const latestVersionEntities = new Map<number, Embedding>();
@@ -16,3 +17,20 @@ export const getLatestVersionEmbeddings = (embeddings: Embedding[]) => {
});
return Array.from(latestVersionEntities.values());
};
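// Deduplicate file embeddings by fileID, keeping the most recently updated
// entry for each file.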
export const getLatestVersionFileEmbeddings = (embeddings: FileML[]) => {
const latestVersionEntities = new Map<number, FileML>();
embeddings.forEach((embedding) => {
if (!embedding?.fileID) {
return;
}
const existingEmbeddings = latestVersionEntities.get(embedding.fileID);
if (
!existingEmbeddings ||
existingEmbeddings.updatedAt < embedding.updatedAt
) {
latestVersionEntities.set(embedding.fileID, embedding);
}
});
return Array.from(latestVersionEntities.values());
};

View file

@@ -0,0 +1,263 @@
import { Face, FaceDetection, Landmark, MlFileData } from "types/machineLearning";
import { ClipEmbedding } from "types/machineLearning/data/clip";
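// Wire format for the per-file ML data (face embeddings, and optionally a CLIP
// embedding) as stored on the server; FileML adds the local updatedAt bookkeeping.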
export interface FileML extends ServerFileMl {
updatedAt: number;
}
class ServerFileMl {
public fileID: number;
public height?: number;
public width?: number;
public faceEmbedding: ServerFaceEmbeddings;
public clipEmbedding?: ClipEmbedding;
public constructor(
fileID: number,
faceEmbedding: ServerFaceEmbeddings,
clipEmbedding?: ClipEmbedding,
height?: number,
width?: number,
) {
this.fileID = fileID;
this.height = height;
this.width = width;
this.faceEmbedding = faceEmbedding;
this.clipEmbedding = clipEmbedding;
}
toJson(): string {
return JSON.stringify(this);
}
static fromJson(json: string): ServerFileMl {
return JSON.parse(json);
}
}
class ServerFaceEmbeddings {
public faces: ServerFace[];
public version: number;
public client?: string;
public error?: boolean;
public constructor(
faces: ServerFace[],
version: number,
client?: string,
error?: boolean,
) {
this.faces = faces;
this.version = version;
this.client = client;
this.error = error;
}
toJson(): string {
return JSON.stringify(this);
}
static fromJson(json: string): ServerFaceEmbeddings {
return JSON.parse(json);
}
}
class ServerFace {
public fileID: number;
public faceID: string;
public embeddings: number[];
public detection: ServerDetection;
public score: number;
public blur: number;
public fileInfo?: ServerFileInfo;
public constructor(
fileID: number,
faceID: string,
embeddings: number[],
detection: ServerDetection,
score: number,
blur: number,
fileInfo?: ServerFileInfo,
) {
this.fileID = fileID;
this.faceID = faceID;
this.embeddings = embeddings;
this.detection = detection;
this.score = score;
this.blur = blur;
this.fileInfo = fileInfo;
}
toJson(): string {
return JSON.stringify(this);
}
static fromJson(json: string): ServerFace {
return JSON.parse(json);
}
}
class ServerFileInfo {
public imageWidth?: number;
public imageHeight?: number;
public constructor(imageWidth?: number, imageHeight?: number) {
this.imageWidth = imageWidth;
this.imageHeight = imageHeight;
}
}
class ServerDetection {
public box: ServerFaceBox;
public landmarks: Landmark[];
public constructor(box: ServerFaceBox, landmarks: Landmark[]) {
this.box = box;
this.landmarks = landmarks;
}
toJson(): string {
return JSON.stringify(this);
}
static fromJson(json: string): ServerDetection {
return JSON.parse(json);
}
}
class ServerFaceBox {
public xMin: number;
public yMin: number;
public width: number;
public height: number;
public constructor(
xMin: number,
yMin: number,
width: number,
height: number,
) {
this.xMin = xMin;
this.yMin = yMin;
this.width = width;
this.height = height;
}
toJson(): string {
return JSON.stringify(this);
}
static fromJson(json: string): ServerFaceBox {
return JSON.parse(json);
}
}
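// Convert locally stored MlFileData into the ServerFileMl wire format.
// Returns null if the local data recorded an error and should not be uploaded.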
export function LocalFileMlDataToServerFileMl(
localFileMlData: MlFileData,
): ServerFileMl {
if (localFileMlData.errorCount > 0 && localFileMlData.lastErrorMessage !== undefined) {
return null;
}
const imageDimensions = localFileMlData.imageDimensions;
const fileInfo = new ServerFileInfo(
imageDimensions.width,
imageDimensions.height,
);
const faces: ServerFace[] = [];
for (let i = 0; i < localFileMlData.faces.length; i++) {
const face: Face = localFileMlData.faces[i];
const faceID = face.id;
const embedding = face.embedding;
const score = face.detection.probability;
const blur = face.blurValue;
const detection: FaceDetection = face.detection;
const box = detection.box;
const landmarks = detection.landmarks;
const newBox = new ServerFaceBox(
box.x,
box.y,
box.width,
box.height,
);
const newLandmarks: Landmark[] = [];
for (let j = 0; j < landmarks.length; j++) {
newLandmarks.push(
{
x: landmarks[j].x,
y: landmarks[j].y,
} as Landmark,
);
}
const newFaceObject = new ServerFace(
localFileMlData.fileId,
faceID,
Array.from(embedding),
new ServerDetection(newBox, newLandmarks),
score,
blur,
fileInfo,
);
faces.push(newFaceObject);
}
const faceEmbeddings = new ServerFaceEmbeddings(faces, 1, localFileMlData.lastErrorMessage);
return new ServerFileMl(
localFileMlData.fileId,
faceEmbeddings,
null,
imageDimensions.height,
imageDimensions.width,
);
}
// // Not sure if this actually works
// export function ServerFileMlToLocalFileMlData(
// serverFileMl: ServerFileMl,
// ): MlFileData {
// const faces: Face[] = [];
// const mlVersion: number = serverFileMl.faceEmbeddings.version;
// const errorCount = serverFileMl.faceEmbeddings.error ? 1 : 0;
// for (let i = 0; i < serverFileMl.faceEmbeddings.faces.length; i++) {
// const face = serverFileMl.faceEmbeddings.faces[i];
// if(face.detection.landmarks.length === 0) {
// continue;
// }
// const detection = face.detection;
// const box = detection.box;
// const landmarks = detection.landmarks;
// const newBox = new FaceBox(
// box.xMin,
// box.yMin,
// box.width,
// box.height,
// );
// const newLandmarks: Landmark[] = [];
// for (let j = 0; j < landmarks.length; j++) {
// newLandmarks.push(
// {
// x: landmarks[j].x,
// y: landmarks[j].y,
// } as Landmark
// );
// }
// const newDetection = new Detection(newBox, newLandmarks);
// const newFace = {
// } as Face
// faces.push(newFace);
// }
// return {
// fileId: serverFileMl.fileID,
// imageDimensions: {
// width: serverFileMl.width,
// height: serverFileMl.height,
// },
// faces,
// mlVersion,
// errorCount,
// };
// }