[web] Import the scaffolding to sync face embeddings from web_face_v2 (#1402)
This PR cherry-picks Neeraj's ML-related changes from the web_face_v2 branch. Similar to https://github.com/ente-io/ente/pull/1399, this gets us one step closer to integrating ONNX-YOLO with our desktop app, but it is not yet in a usable state (the web app's functionality remains untouched).
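At a glance, the scaffolding hooks the new face-embedding sync into the gallery's existing sync path: `syncEmbeddings()` keeps running everywhere, while the new `syncFileEmbeddings()` is only invoked when the desktop (Electron) shell is present. A minimal sketch of that call site, using the names from the diff below (the surrounding wrapper function is hypothetical and abridged):

```ts
import { syncEmbeddings, syncFileEmbeddings } from "services/embeddingService";

// Hypothetical wrapper sketching the gallery sync step shown in the diff below.
const syncMLEmbeddings = async () => {
    // CLIP embeddings sync as before, on all platforms.
    await syncEmbeddings();
    // Face embeddings ("file-ml-clip-face") are only pulled in the desktop app,
    // where the preload script exposes globalThis.electron.
    const electron = globalThis.electron;
    if (electron) {
        await syncFileEmbeddings();
    }
};
```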
commit 4a69e9260c
6 changed files with 415 additions and 13 deletions
@@ -105,7 +105,7 @@ import { AppContext } from "pages/_app";
 import { clipService } from "services/clip-service";
 import { constructUserIDToEmailMap } from "services/collectionService";
 import downloadManager from "services/download";
-import { syncEmbeddings } from "services/embeddingService";
+import { syncEmbeddings, syncFileEmbeddings } from "services/embeddingService";
 import { syncEntities } from "services/entityService";
 import locationSearchService from "services/locationSearchService";
 import { getLocalTrashedFiles, syncTrash } from "services/trashService";
@@ -702,6 +702,10 @@ export default function Gallery()
         await syncEntities();
         await syncMapEnabled();
         await syncEmbeddings();
+        const electron = globalThis.electron;
+        if (electron) {
+            await syncFileEmbeddings();
+        }
         if (clipService.isPlatformSupported()) {
             void clipService.scheduleImageEmbeddingExtraction();
         }

@@ -13,7 +13,11 @@ import type {
     PutEmbeddingRequest,
 } from "types/embedding";
 import { EnteFile } from "types/file";
-import { getLatestVersionEmbeddings } from "utils/embedding";
+import {
+    getLatestVersionEmbeddings,
+    getLatestVersionFileEmbeddings,
+} from "utils/embedding";
+import { FileML } from "utils/machineLearning/mldataMappers";
 import { getLocalCollections } from "./collectionService";
 import { getAllLocalFiles } from "./fileService";
 import { getLocalTrashedFiles } from "./trashService";
@@ -24,6 +28,7 @@ const DIFF_LIMIT = 500;

 const EMBEDDINGS_TABLE_V1 = "embeddings";
 const EMBEDDINGS_TABLE = "embeddings_v2";
+const FILE_EMBEDING_TABLE = "file_embeddings";
 const EMBEDDING_SYNC_TIME_TABLE = "embedding_sync_time";

 export const getAllLocalEmbeddings = async () => {
@@ -38,6 +43,15 @@ export const getAllLocalEmbeddings = async () => {
     return embeddings;
 };

+export const getFileMLEmbeddings = async (): Promise<FileML[]> => {
+    const embeddings: Array<FileML> =
+        await localForage.getItem<FileML[]>(FILE_EMBEDING_TABLE);
+    if (!embeddings) {
+        return [];
+    }
+    return embeddings;
+};
+
 export const getLocalEmbeddings = async () => {
     const embeddings = await getAllLocalEmbeddings();
     return embeddings.filter((embedding) => embedding.model === "onnx-clip");
@@ -140,6 +154,83 @@ export const syncEmbeddings = async () => {
     }
 };

+export const syncFileEmbeddings = async () => {
+    const models: EmbeddingModel[] = ["file-ml-clip-face"];
+    try {
+        let allEmbeddings: FileML[] = await getFileMLEmbeddings();
+        const localFiles = await getAllLocalFiles();
+        const hiddenAlbums = await getLocalCollections("hidden");
+        const localTrashFiles = await getLocalTrashedFiles();
+        const fileIdToKeyMap = new Map<number, string>();
+        const allLocalFiles = [...localFiles, ...localTrashFiles];
+        allLocalFiles.forEach((file) => {
+            fileIdToKeyMap.set(file.id, file.key);
+        });
+        await cleanupDeletedEmbeddings(allLocalFiles, allEmbeddings);
+        log.info(`Syncing embeddings localCount: ${allEmbeddings.length}`);
+        for (const model of models) {
+            let modelLastSinceTime = await getModelEmbeddingSyncTime(model);
+            log.info(
+                `Syncing ${model} model's embeddings sinceTime: ${modelLastSinceTime}`,
+            );
+            let response: GetEmbeddingDiffResponse;
+            do {
+                response = await getEmbeddingsDiff(modelLastSinceTime, model);
+                if (!response.diff?.length) {
+                    return;
+                }
+                const newEmbeddings = await Promise.all(
+                    response.diff.map(async (embedding) => {
+                        try {
+                            const worker =
+                                await ComlinkCryptoWorker.getInstance();
+                            const fileKey = fileIdToKeyMap.get(
+                                embedding.fileID,
+                            );
+                            if (!fileKey) {
+                                throw Error(CustomError.FILE_NOT_FOUND);
+                            }
+                            const decryptedData = await worker.decryptMetadata(
+                                embedding.encryptedEmbedding,
+                                embedding.decryptionHeader,
+                                fileIdToKeyMap.get(embedding.fileID),
+                            );
+
+                            return {
+                                ...decryptedData,
+                                updatedAt: embedding.updatedAt,
+                            } as unknown as FileML;
+                        } catch (e) {
+                            let hasHiddenAlbums = false;
+                            if (e.message === CustomError.FILE_NOT_FOUND) {
+                                hasHiddenAlbums = hiddenAlbums?.length > 0;
+                            }
+                            log.error(
+                                `decryptEmbedding failed for file (hasHiddenAlbums: ${hasHiddenAlbums})`,
+                                e,
+                            );
+                        }
+                    }),
+                );
+                allEmbeddings = getLatestVersionFileEmbeddings([
+                    ...allEmbeddings,
+                    ...newEmbeddings,
+                ]);
+                if (response.diff.length) {
+                    modelLastSinceTime = response.diff.slice(-1)[0].updatedAt;
+                }
+                await localForage.setItem(FILE_EMBEDING_TABLE, allEmbeddings);
+                await setModelEmbeddingSyncTime(model, modelLastSinceTime);
+                log.info(
+                    `Syncing embeddings syncedEmbeddingsCount: ${allEmbeddings.length}`,
+                );
+            } while (response.diff.length === DIFF_LIMIT);
+        }
+    } catch (e) {
+        log.error("Sync embeddings failed", e);
+    }
+};
+
 export const getEmbeddingsDiff = async (
     sinceTime: number,
     model: EmbeddingModel,
@@ -173,7 +264,8 @@ export const putEmbedding = async (
     try {
         const token = getToken();
         if (!token) {
-            return;
+            log.info("putEmbedding failed: token not found");
+            throw Error(CustomError.TOKEN_MISSING);
         }
         const resp = await HTTPService.put(
             `${ENDPOINT}/embeddings`,
@@ -192,7 +284,7 @@ export const putEmbedding = async (

 export const cleanupDeletedEmbeddings = async (
     allLocalFiles: EnteFile[],
-    allLocalEmbeddings: Embedding[],
+    allLocalEmbeddings: Embedding[] | FileML[],
 ) => {
     const activeFileIds = new Set<number>();
     allLocalFiles.forEach((file) => {

@@ -1,11 +1,13 @@
 import log from "@/next/log";
 import { APPS } from "@ente/shared/apps/constants";
+import ComlinkCryptoWorker from "@ente/shared/crypto";
 import { CustomError, parseUploadErrorCodes } from "@ente/shared/error";
 import "@tensorflow/tfjs-backend-cpu";
 import "@tensorflow/tfjs-backend-webgl";
 import * as tf from "@tensorflow/tfjs-core";
 import { MAX_ML_SYNC_ERROR_COUNT } from "constants/mlConfig";
 import downloadManager from "services/download";
+import { putEmbedding } from "services/embeddingService";
 import { getLocalFiles } from "services/fileService";
 import { EnteFile } from "types/file";
 import {
@@ -15,6 +17,7 @@ import {
     MlFileData,
 } from "types/machineLearning";
 import { getMLSyncConfig } from "utils/machineLearning/config";
+import { LocalFileMlDataToServerFileMl } from "utils/machineLearning/mldataMappers";
 import mlIDbStorage from "utils/storage/mlIDbStorage";
 import FaceService from "./faceService";
 import { MLFactory } from "./machineLearningFactory";
@@ -215,13 +218,13 @@ class MachineLearningService {
             syncContext,
             [...existingFilesMap.values()],
         );
-        // addLogLine("getUniqueOutOfSyncFiles");
-        // addLogLine(
-        //     "Got unique outOfSyncFiles: ",
-        //     syncContext.outOfSyncFiles.length,
-        //     "for batchSize: ",
-        //     syncContext.config.batchSize,
-        // );
+        log.info("getUniqueOutOfSyncFiles");
+        log.info(
+            "Got unique outOfSyncFiles: ",
+            syncContext.outOfSyncFiles.length,
+            "for batchSize: ",
+            syncContext.config.batchSize,
+        );
     }

     private async syncFiles(syncContext: MLSyncContext) {
@@ -415,6 +418,7 @@
         ]);
         newMlFile.errorCount = 0;
         newMlFile.lastErrorMessage = undefined;
+        await this.persistOnServer(newMlFile, enteFile);
         await this.persistMLFileData(syncContext, newMlFile);
     } catch (e) {
         log.error("ml detection failed", e);
@@ -435,6 +439,25 @@
         return newMlFile;
     }

+    private async persistOnServer(mlFileData: MlFileData, enteFile: EnteFile) {
+        const serverMl = LocalFileMlDataToServerFileMl(mlFileData);
+        log.info(mlFileData);
+
+        const comlinkCryptoWorker = await ComlinkCryptoWorker.getInstance();
+        const { file: encryptedEmbeddingData } =
+            await comlinkCryptoWorker.encryptMetadata(serverMl, enteFile.key);
+        log.info(
+            `putEmbedding embedding to server for file: ${enteFile.metadata.title} fileID: ${enteFile.id}`,
+        );
+        const res = await putEmbedding({
+            fileID: enteFile.id,
+            encryptedEmbedding: encryptedEmbeddingData.encryptedData,
+            decryptionHeader: encryptedEmbeddingData.decryptionHeader,
+            model: "file-ml-clip-face",
+        });
+        log.info("putEmbedding response: ", res);
+    }
+
     public async init() {
         if (this.initialized) {
             return;

@@ -5,7 +5,7 @@
  * embeddings on the server. However, we should be prepared to receive an
  * {@link EncryptedEmbedding} with a model value distinct from one of these.
  */
-export type EmbeddingModel = "onnx-clip";
+export type EmbeddingModel = "onnx-clip" | "file-ml-clip-face";

 export interface EncryptedEmbedding {
     fileID: number;
@@ -21,7 +21,7 @@ export interface Embedding
         EncryptedEmbedding,
         "encryptedEmbedding" | "decryptionHeader"
     > {
-    embedding: Float32Array;
+    embedding?: Float32Array;
 }

 export interface GetEmbeddingDiffResponse {

@@ -1,4 +1,5 @@
 import { Embedding } from "types/embedding";
+import { FileML } from "./machineLearning/mldataMappers";

 export const getLatestVersionEmbeddings = (embeddings: Embedding[]) => {
     const latestVersionEntities = new Map<number, Embedding>();
@@ -16,3 +17,20 @@
     });
     return Array.from(latestVersionEntities.values());
 };
+
+export const getLatestVersionFileEmbeddings = (embeddings: FileML[]) => {
+    const latestVersionEntities = new Map<number, FileML>();
+    embeddings.forEach((embedding) => {
+        if (!embedding?.fileID) {
+            return;
+        }
+        const existingEmbeddings = latestVersionEntities.get(embedding.fileID);
+        if (
+            !existingEmbeddings ||
+            existingEmbeddings.updatedAt < embedding.updatedAt
+        ) {
+            latestVersionEntities.set(embedding.fileID, embedding);
+        }
+    });
+    return Array.from(latestVersionEntities.values());
+};

web/apps/photos/src/utils/machineLearning/mldataMappers.ts (new file, 265 lines)
@@ -0,0 +1,265 @@
+import {
+    Face,
+    FaceDetection,
+    Landmark,
+    MlFileData,
+} from "types/machineLearning";
+import { ClipEmbedding } from "types/machineLearning/data/clip";
+
+export interface FileML extends ServerFileMl {
+    updatedAt: number;
+}
+
+class ServerFileMl {
+    public fileID: number;
+    public height?: number;
+    public width?: number;
+    public faceEmbedding: ServerFaceEmbeddings;
+    public clipEmbedding?: ClipEmbedding;
+
+    public constructor(
+        fileID: number,
+        faceEmbedding: ServerFaceEmbeddings,
+        clipEmbedding?: ClipEmbedding,
+        height?: number,
+        width?: number,
+    ) {
+        this.fileID = fileID;
+        this.height = height;
+        this.width = width;
+        this.faceEmbedding = faceEmbedding;
+        this.clipEmbedding = clipEmbedding;
+    }
+
+    toJson(): string {
+        return JSON.stringify(this);
+    }
+
+    static fromJson(json: string): ServerFileMl {
+        return JSON.parse(json);
+    }
+}
+
+class ServerFaceEmbeddings {
+    public faces: ServerFace[];
+    public version: number;
+    public client?: string;
+    public error?: boolean;
+
+    public constructor(
+        faces: ServerFace[],
+        version: number,
+        client?: string,
+        error?: boolean,
+    ) {
+        this.faces = faces;
+        this.version = version;
+        this.client = client;
+        this.error = error;
+    }
+
+    toJson(): string {
+        return JSON.stringify(this);
+    }
+
+    static fromJson(json: string): ServerFaceEmbeddings {
+        return JSON.parse(json);
+    }
+}
+
+class ServerFace {
+    public fileID: number;
+    public faceID: string;
+    public embeddings: number[];
+    public detection: ServerDetection;
+    public score: number;
+    public blur: number;
+    public fileInfo?: ServerFileInfo;
+
+    public constructor(
+        fileID: number,
+        faceID: string,
+        embeddings: number[],
+        detection: ServerDetection,
+        score: number,
+        blur: number,
+        fileInfo?: ServerFileInfo,
+    ) {
+        this.fileID = fileID;
+        this.faceID = faceID;
+        this.embeddings = embeddings;
+        this.detection = detection;
+        this.score = score;
+        this.blur = blur;
+        this.fileInfo = fileInfo;
+    }
+
+    toJson(): string {
+        return JSON.stringify(this);
+    }
+
+    static fromJson(json: string): ServerFace {
+        return JSON.parse(json);
+    }
+}
+
+class ServerFileInfo {
+    public imageWidth?: number;
+    public imageHeight?: number;
+
+    public constructor(imageWidth?: number, imageHeight?: number) {
+        this.imageWidth = imageWidth;
+        this.imageHeight = imageHeight;
+    }
+}
+
+class ServerDetection {
+    public box: ServerFaceBox;
+    public landmarks: Landmark[];
+
+    public constructor(box: ServerFaceBox, landmarks: Landmark[]) {
+        this.box = box;
+        this.landmarks = landmarks;
+    }
+
+    toJson(): string {
+        return JSON.stringify(this);
+    }
+
+    static fromJson(json: string): ServerDetection {
+        return JSON.parse(json);
+    }
+}
+
+class ServerFaceBox {
+    public xMin: number;
+    public yMin: number;
+    public width: number;
+    public height: number;
+
+    public constructor(
+        xMin: number,
+        yMin: number,
+        width: number,
+        height: number,
+    ) {
+        this.xMin = xMin;
+        this.yMin = yMin;
+        this.width = width;
+        this.height = height;
+    }
+
+    toJson(): string {
+        return JSON.stringify(this);
+    }
+
+    static fromJson(json: string): ServerFaceBox {
+        return JSON.parse(json);
+    }
+}
+
+export function LocalFileMlDataToServerFileMl(
+    localFileMlData: MlFileData,
+): ServerFileMl {
+    if (
+        localFileMlData.errorCount > 0 &&
+        localFileMlData.lastErrorMessage !== undefined
+    ) {
+        return null;
+    }
+    const imageDimensions = localFileMlData.imageDimensions;
+    const fileInfo = new ServerFileInfo(
+        imageDimensions.width,
+        imageDimensions.height,
+    );
+    const faces: ServerFace[] = [];
+    for (let i = 0; i < localFileMlData.faces.length; i++) {
+        const face: Face = localFileMlData.faces[i];
+        const faceID = face.id;
+        const embedding = face.embedding;
+        const score = face.detection.probability;
+        const blur = face.blurValue;
+        const detection: FaceDetection = face.detection;
+        const box = detection.box;
+        const landmarks = detection.landmarks;
+        const newBox = new ServerFaceBox(box.x, box.y, box.width, box.height);
+        const newLandmarks: Landmark[] = [];
+        for (let j = 0; j < landmarks.length; j++) {
+            newLandmarks.push({
+                x: landmarks[j].x,
+                y: landmarks[j].y,
+            } as Landmark);
+        }
+
+        const newFaceObject = new ServerFace(
+            localFileMlData.fileId,
+            faceID,
+            Array.from(embedding),
+            new ServerDetection(newBox, newLandmarks),
+            score,
+            blur,
+            fileInfo,
+        );
+        faces.push(newFaceObject);
+    }
+    const faceEmbeddings = new ServerFaceEmbeddings(
+        faces,
+        1,
+        localFileMlData.lastErrorMessage,
+    );
+    return new ServerFileMl(
+        localFileMlData.fileId,
+        faceEmbeddings,
+        null,
+        imageDimensions.height,
+        imageDimensions.width,
+    );
+}
+
+// // Not sure if this actually works
+// export function ServerFileMlToLocalFileMlData(
+//     serverFileMl: ServerFileMl,
+// ): MlFileData {
+//     const faces: Face[] = [];
+//     const mlVersion: number = serverFileMl.faceEmbeddings.version;
+//     const errorCount = serverFileMl.faceEmbeddings.error ? 1 : 0;
+//     for (let i = 0; i < serverFileMl.faceEmbeddings.faces.length; i++) {
+//         const face = serverFileMl.faceEmbeddings.faces[i];
+//         if(face.detection.landmarks.length === 0) {
+//             continue;
+//         }
+//         const detection = face.detection;
+//         const box = detection.box;
+//         const landmarks = detection.landmarks;
+//         const newBox = new FaceBox(
+//             box.xMin,
+//             box.yMin,
+//             box.width,
+//             box.height,
+//         );
+//         const newLandmarks: Landmark[] = [];
+//         for (let j = 0; j < landmarks.length; j++) {
+//             newLandmarks.push(
+//                 {
+//                     x: landmarks[j].x,
+//                     y: landmarks[j].y,
+//                 } as Landmark
+//             );
+//         }
+//         const newDetection = new Detection(newBox, newLandmarks);
+//         const newFace = {
+
+//         } as Face
+//         faces.push(newFace);
+//     }
+//     return {
+//         fileId: serverFileMl.fileID,
+//         imageDimensions: {
+//             width: serverFileMl.width,
+//             height: serverFileMl.height,
+//         },
+//         faces,
+//         mlVersion,
+//         errorCount,
+//     };
+// }