|
@@ -13,7 +13,11 @@ import type {
|
|
PutEmbeddingRequest,
|
|
PutEmbeddingRequest,
|
|
} from "types/embedding";
|
|
} from "types/embedding";
|
|
import { EnteFile } from "types/file";
|
|
import { EnteFile } from "types/file";
|
|
-import { getLatestVersionEmbeddings } from "utils/embedding";
|
|
|
|
|
|
+import {
|
|
|
|
+ getLatestVersionEmbeddings,
|
|
|
|
+ getLatestVersionFileEmbeddings,
|
|
|
|
+} from "utils/embedding";
|
|
|
|
+import { FileML } from "utils/machineLearning/mldataMappers";
|
|
import { getLocalCollections } from "./collectionService";
|
|
import { getLocalCollections } from "./collectionService";
|
|
import { getAllLocalFiles } from "./fileService";
|
|
import { getAllLocalFiles } from "./fileService";
|
|
import { getLocalTrashedFiles } from "./trashService";
|
|
import { getLocalTrashedFiles } from "./trashService";
|
|
@@ -24,6 +28,7 @@ const DIFF_LIMIT = 500;
|
|
|
|
|
|
const EMBEDDINGS_TABLE_V1 = "embeddings";
|
|
const EMBEDDINGS_TABLE_V1 = "embeddings";
|
|
const EMBEDDINGS_TABLE = "embeddings_v2";
|
|
const EMBEDDINGS_TABLE = "embeddings_v2";
|
|
|
|
+const FILE_EMBEDING_TABLE = "file_embeddings";
|
|
const EMBEDDING_SYNC_TIME_TABLE = "embedding_sync_time";
|
|
const EMBEDDING_SYNC_TIME_TABLE = "embedding_sync_time";
|
|
|
|
|
|
export const getAllLocalEmbeddings = async () => {
|
|
export const getAllLocalEmbeddings = async () => {
|
|
@@ -38,6 +43,15 @@ export const getAllLocalEmbeddings = async () => {
|
|
return embeddings;
|
|
return embeddings;
|
|
};
|
|
};
|
|
|
|
|
|
|
|
/**
 * Read the locally persisted file-level ML embeddings.
 *
 * Looks up the FILE_EMBEDING_TABLE key in localForage and resolves with the
 * stored value, or with a fresh empty array when the stored value is missing
 * (or otherwise falsy).
 */
export const getFileMLEmbeddings = async (): Promise<FileML[]> => {
    const stored = await localForage.getItem<FileML[]>(FILE_EMBEDING_TABLE);
    // `||` (not `??`) on purpose: any falsy stored value maps to [].
    return stored || [];
};
|
|
|
|
+
|
|
export const getLocalEmbeddings = async () => {
|
|
export const getLocalEmbeddings = async () => {
|
|
const embeddings = await getAllLocalEmbeddings();
|
|
const embeddings = await getAllLocalEmbeddings();
|
|
return embeddings.filter((embedding) => embedding.model === "onnx-clip");
|
|
return embeddings.filter((embedding) => embedding.model === "onnx-clip");
|
|
@@ -140,6 +154,83 @@ export const syncEmbeddings = async () => {
|
|
}
|
|
}
|
|
};
|
|
};
|
|
|
|
|
|
|
|
/**
 * Pull remote file-level ML embeddings and merge them into the local store.
 *
 * For every model in the (currently single-entry) model list this:
 *  1. reads the model's last persisted sync time,
 *  2. pages through the embeddings diff from that time onwards,
 *  3. decrypts each entry with the key of the matching local or trashed file,
 *  4. merges the decrypted entries into the local list via
 *     getLatestVersionFileEmbeddings, and
 *  5. persists the merged list and the advanced sync time after each page.
 *
 * Paging for a model stops when a page is empty or shorter than DIFF_LIMIT.
 *
 * Errors never propagate to the caller: a failure to decrypt one embedding is
 * logged and that embedding is skipped; any other failure is logged and ends
 * the sync.
 */
export const syncFileEmbeddings = async () => {
    const models: EmbeddingModel[] = ["file-ml-clip-face"];
    try {
        let allEmbeddings: FileML[] = await getFileMLEmbeddings();
        const localFiles = await getAllLocalFiles();
        const hiddenAlbums = await getLocalCollections("hidden");
        const localTrashFiles = await getLocalTrashedFiles();
        const allLocalFiles = [...localFiles, ...localTrashFiles];

        // Trashed files are included so their embeddings can still be decrypted.
        const fileIdToKeyMap = new Map<number, string>();
        allLocalFiles.forEach((file) => {
            fileIdToKeyMap.set(file.id, file.key);
        });

        // NOTE(review): the body of cleanupDeletedEmbeddings is not visible
        // from here — confirm it persists to the right table when handed
        // FileML entries, and whether `allEmbeddings` should be refreshed
        // after it runs.
        await cleanupDeletedEmbeddings(allLocalFiles, allEmbeddings);
        log.info(`Syncing embeddings localCount: ${allEmbeddings.length}`);

        for (const model of models) {
            let modelLastSinceTime = await getModelEmbeddingSyncTime(model);
            log.info(
                `Syncing ${model} model's embeddings sinceTime: ${modelLastSinceTime}`,
            );
            let response: GetEmbeddingDiffResponse;
            do {
                response = await getEmbeddingsDiff(modelLastSinceTime, model);
                if (!response.diff?.length) {
                    // Nothing (more) to sync for this model. Use `break`, not
                    // `return`, so any remaining models are still processed.
                    break;
                }

                // One worker handle per page instead of one per embedding.
                const worker = await ComlinkCryptoWorker.getInstance();
                const decryptedOrSkipped = await Promise.all(
                    response.diff.map(
                        async (embedding): Promise<FileML | undefined> => {
                            try {
                                const fileKey = fileIdToKeyMap.get(
                                    embedding.fileID,
                                );
                                if (!fileKey) {
                                    throw Error(CustomError.FILE_NOT_FOUND);
                                }
                                const decryptedData =
                                    await worker.decryptMetadata(
                                        embedding.encryptedEmbedding,
                                        embedding.decryptionHeader,
                                        fileKey,
                                    );

                                // The decrypted payload's shape is not
                                // validated here; it is trusted to be a
                                // FileML.
                                return {
                                    ...decryptedData,
                                    updatedAt: embedding.updatedAt,
                                } as unknown as FileML;
                            } catch (e) {
                                // A missing key is only flagged as possibly
                                // hidden-album related when hidden albums
                                // exist locally.
                                let hasHiddenAlbums = false;
                                if (
                                    e instanceof Error &&
                                    e.message === CustomError.FILE_NOT_FOUND
                                ) {
                                    hasHiddenAlbums = hiddenAlbums?.length > 0;
                                }
                                log.error(
                                    `decryptEmbedding failed for file (hasHiddenAlbums: ${hasHiddenAlbums})`,
                                    e,
                                );
                                return undefined;
                            }
                        },
                    ),
                );

                // Drop the entries that failed to decrypt so that `undefined`
                // never reaches the merge step or the persisted list.
                const newEmbeddings = decryptedOrSkipped.filter(
                    (item): item is FileML => item !== undefined,
                );
                allEmbeddings = getLatestVersionFileEmbeddings([
                    ...allEmbeddings,
                    ...newEmbeddings,
                ]);

                // The page is non-empty here (guarded above); resume the next
                // request from the last entry's timestamp.
                modelLastSinceTime =
                    response.diff[response.diff.length - 1].updatedAt;

                await localForage.setItem(FILE_EMBEDING_TABLE, allEmbeddings);
                await setModelEmbeddingSyncTime(model, modelLastSinceTime);
                log.info(
                    `Syncing embeddings syncedEmbeddingsCount: ${allEmbeddings.length}`,
                );
                // A full page means there may be more entries to fetch.
            } while (response.diff.length === DIFF_LIMIT);
        }
    } catch (e) {
        log.error("Sync embeddings failed", e);
    }
};
|
|
|
|
+
|
|
export const getEmbeddingsDiff = async (
|
|
export const getEmbeddingsDiff = async (
|
|
sinceTime: number,
|
|
sinceTime: number,
|
|
model: EmbeddingModel,
|
|
model: EmbeddingModel,
|
|
@@ -173,7 +264,8 @@ export const putEmbedding = async (
|
|
try {
|
|
try {
|
|
const token = getToken();
|
|
const token = getToken();
|
|
if (!token) {
|
|
if (!token) {
|
|
- return;
|
|
|
|
|
|
+ log.info("putEmbedding failed: token not found");
|
|
|
|
+ throw Error(CustomError.TOKEN_MISSING);
|
|
}
|
|
}
|
|
const resp = await HTTPService.put(
|
|
const resp = await HTTPService.put(
|
|
`${ENDPOINT}/embeddings`,
|
|
`${ENDPOINT}/embeddings`,
|
|
@@ -192,7 +284,7 @@ export const putEmbedding = async (
|
|
|
|
|
|
export const cleanupDeletedEmbeddings = async (
|
|
export const cleanupDeletedEmbeddings = async (
|
|
allLocalFiles: EnteFile[],
|
|
allLocalFiles: EnteFile[],
|
|
- allLocalEmbeddings: Embedding[],
|
|
|
|
|
|
+ allLocalEmbeddings: Embedding[] | FileML[],
|
|
) => {
|
|
) => {
|
|
const activeFileIds = new Set<number>();
|
|
const activeFileIds = new Set<number>();
|
|
allLocalFiles.forEach((file) => {
|
|
allLocalFiles.forEach((file) => {
|