|
@@ -2,6 +2,7 @@ import {
|
|
|
Embedding,
|
|
|
EncryptedEmbedding,
|
|
|
GetEmbeddingDiffResponse,
|
|
|
+ Model,
|
|
|
PutEmbeddingRequest,
|
|
|
} from 'types/embedding';
|
|
|
import ComlinkCryptoWorker from '@ente/shared/crypto';
|
|
@@ -16,105 +17,135 @@ import { getLatestVersionEmbeddings } from 'utils/embedding';
|
|
|
import { getLocalTrashedFiles } from './trashService';
|
|
|
import { getLocalCollections } from './collectionService';
|
|
|
import { CustomError } from '@ente/shared/error';
|
|
|
+import { EnteFile } from 'types/file';
|
|
|
|
|
|
// Value returned by getEndpoint(); presumably the API base URL used by the
// request helpers in this file — TODO confirm against their bodies.
const ENDPOINT = getEndpoint();
|
|
|
|
|
|
// Page size sent as `limit` when requesting an embeddings diff. A response
// containing exactly this many entries makes syncEmbeddings fetch another page.
const DIFF_LIMIT = 500;
|
|
|
|
|
|
-const EMBEDDINGS_TABLE = 'embeddings';
|
|
|
// Storage key of the previous embeddings table. getAllLocalEmbeddings removes
// it when the current table has not been created yet.
+const EMBEDDINGS_TABLE_V1 = 'embeddings';
|
|
|
// localForage key under which the array of all locally cached embeddings
// (across models) is stored.
+const EMBEDDINGS_TABLE = 'embeddings_v2';
|
|
|
// Key of the old, model-independent sync time (removed together with the V1
// table) and suffix of the per-model keys `${model}-embedding_sync_time`.
const EMBEDDING_SYNC_TIME_TABLE = 'embedding_sync_time';
|
|
|
|
|
|
-export const getLocalEmbeddings = async () => {
|
|
|
- const embeddings: Array<Embedding> =
|
|
|
- (await localForage.getItem<Embedding[]>(EMBEDDINGS_TABLE)) || [];
|
|
|
/**
 * Read every locally cached embedding, regardless of model.
 *
 * When nothing is stored under the current table key, the previous table and
 * the old model-independent sync time are removed, an empty table is written
 * under the current key, and an empty list is returned.
 *
 * @returns The stored embeddings, or an empty array when none were stored.
 */
export const getAllLocalEmbeddings = async () => {
    const stored: Array<Embedding> = await localForage.getItem<Embedding[]>(
        EMBEDDINGS_TABLE
    );
    if (stored) {
        return stored;
    }

    // No table under the current key yet: drop the old-layout entries and
    // initialise an empty table so the next read finds it.
    await localForage.removeItem(EMBEDDINGS_TABLE_V1);
    await localForage.removeItem(EMBEDDING_SYNC_TIME_TABLE);
    await localForage.setItem(EMBEDDINGS_TABLE, []);
    return [];
};
|
|
|
|
|
|
-const getEmbeddingSyncTime = async () => {
|
|
|
- return (await localForage.getItem<number>(EMBEDDING_SYNC_TIME_TABLE)) ?? 0;
|
|
|
/**
 * Return the locally cached embeddings that were produced by `model`.
 *
 * @param model Model whose embeddings should be returned.
 * @returns The subset of all local embeddings whose `model` equals `model`.
 */
export const getLocalEmbeddings = async (model: Model) => {
    const everyEmbedding = await getAllLocalEmbeddings();
    const belongsToModel = (candidate: Embedding) => candidate.model === model;
    return everyEmbedding.filter(belongsToModel);
};
|
|
|
|
|
|
-export const getLatestEmbeddings = async () => {
|
|
|
- await syncEmbeddings();
|
|
|
- const embeddings = await getLocalEmbeddings();
|
|
|
- return embeddings;
|
|
|
/**
 * Look up the sync time stored for `model`.
 *
 * @param model Model whose sync time is requested.
 * @returns The stored value, or 0 when nothing has been stored for the model.
 */
const getModelEmbeddingSyncTime = async (model: Model) => {
    // Each model keeps its own entry, keyed by the model name as prefix.
    const storageKey = `${model}-${EMBEDDING_SYNC_TIME_TABLE}`;
    const storedTime = await localForage.getItem<number>(storageKey);
    return storedTime ?? 0;
};
|
|
|
|
|
|
-export const syncEmbeddings = async () => {
|
|
|
/**
 * Store the sync time for `model`.
 *
 * @param model Model the sync time belongs to.
 * @param time Value to store; syncEmbeddings passes the `updatedAt` of the
 *     last entry of the most recently processed diff page.
 */
const setModelEmbeddingSyncTime = async (model: Model, time: number) => {
    const storageKey = `${model}-${EMBEDDING_SYNC_TIME_TABLE}`;
    await localForage.setItem(storageKey, time);
};
|
|
|
+
|
|
|
/**
 * Pull remote embeddings for each of `models` into the local cache.
 *
 * For every model the remote diff is fetched page by page, starting from the
 * sync time stored for that model. Each entry is decrypted with the key of
 * the file it belongs to, merged into the local list through
 * getLatestVersionEmbeddings, and the list and the model's sync time are
 * persisted after every page.
 *
 * Entries that cannot be decrypted (including those whose file is not known
 * locally) are logged and skipped. Any other failure is logged and swallowed,
 * so this function never rejects.
 *
 * @param models Models to sync; defaults to `[Model.ONNX_CLIP]`.
 */
export const syncEmbeddings = async (models: Model[] = [Model.ONNX_CLIP]) => {
    try {
        let allEmbeddings = await getAllLocalEmbeddings();
        const localFiles = await getAllLocalFiles();
        const hiddenAlbums = await getLocalCollections('hidden');
        const localTrashFiles = await getLocalTrashedFiles();
        const allLocalFiles = [...localFiles, ...localTrashFiles];
        const fileIdToKeyMap = new Map<number, string>();
        allLocalFiles.forEach((file) => {
            fileIdToKeyMap.set(file.id, file.key);
        });

        await cleanupDeletedEmbeddings(allLocalFiles, allEmbeddings);
        // cleanupDeletedEmbeddings only rewrites the stored table. Apply the
        // same "file still exists locally" filter to the in-memory copy,
        // otherwise the pruned entries would be written back by the next
        // setItem below.
        allEmbeddings = allEmbeddings.filter((embedding) =>
            fileIdToKeyMap.has(embedding.fileID)
        );
        addLogLine(`Syncing embeddings localCount: ${allEmbeddings.length}`);

        for (const model of models) {
            let modelLastSinceTime = await getModelEmbeddingSyncTime(model);
            addLogLine(
                `Syncing ${model} model's embeddings sinceTime: ${modelLastSinceTime}`
            );
            let response: GetEmbeddingDiffResponse;
            do {
                response = await getEmbeddingsDiff(modelLastSinceTime, model);
                if (!response.diff?.length) {
                    // Nothing (more) to fetch for this model. Leave only the
                    // paging loop so that the remaining models still sync.
                    break;
                }
                const decryptionResults = await Promise.all(
                    response.diff.map(async (embedding) => {
                        try {
                            const {
                                encryptedEmbedding,
                                decryptionHeader,
                                ...rest
                            } = embedding;
                            const worker =
                                await ComlinkCryptoWorker.getInstance();
                            const fileKey = fileIdToKeyMap.get(
                                embedding.fileID
                            );
                            if (!fileKey) {
                                throw Error(CustomError.FILE_NOT_FOUND);
                            }
                            const decryptedData = await worker.decryptEmbedding(
                                encryptedEmbedding,
                                decryptionHeader,
                                fileKey
                            );

                            return {
                                ...rest,
                                embedding: decryptedData,
                            } as Embedding;
                        } catch (e) {
                            // For a missing file, record whether hidden
                            // albums exist to help diagnose the cause.
                            const isFileMissing =
                                e instanceof Error &&
                                e.message === CustomError.FILE_NOT_FOUND;
                            const info: Record<string, unknown> | undefined =
                                isFileMissing
                                    ? {
                                          hasHiddenAlbums:
                                              hiddenAlbums?.length > 0,
                                      }
                                    : undefined;
                            logError(
                                e,
                                'decryptEmbedding failed for file',
                                info
                            );
                            return undefined;
                        }
                    })
                );
                // Failed decryptions resolve to undefined; keep them out of
                // the merged list.
                const newEmbeddings = decryptionResults.filter(
                    (result): result is Embedding => result !== undefined
                );
                allEmbeddings = getLatestVersionEmbeddings([
                    ...allEmbeddings,
                    ...newEmbeddings,
                ]);
                // The diff is non-empty here, so the last entry exists; its
                // updatedAt becomes the starting point of the next page.
                modelLastSinceTime = response.diff.slice(-1)[0].updatedAt;
                await localForage.setItem(EMBEDDINGS_TABLE, allEmbeddings);
                await setModelEmbeddingSyncTime(model, modelLastSinceTime);
                addLogLine(
                    `Syncing embeddings syncedEmbeddingsCount: ${allEmbeddings.length}`
                );
                // A full page means there may be more entries to fetch.
            } while (response.diff.length === DIFF_LIMIT);
        }
    } catch (e) {
        logError(e, 'Sync embeddings failed');
    }
};
|
|
|
|
|
|
export const getEmbeddingsDiff = async (
|
|
|
- sinceTime: number
|
|
|
+ sinceTime: number,
|
|
|
+ model: Model
|
|
|
): Promise<GetEmbeddingDiffResponse> => {
|
|
|
try {
|
|
|
const token = getToken();
|
|
@@ -126,6 +157,7 @@ export const getEmbeddingsDiff = async (
|
|
|
{
|
|
|
sinceTime,
|
|
|
limit: DIFF_LIMIT,
|
|
|
+ model,
|
|
|
},
|
|
|
{
|
|
|
'X-Auth-Token': token,
|
|
@@ -161,21 +193,21 @@ export const putEmbedding = async (
|
|
|
}
|
|
|
};
|
|
|
|
|
|
-export const cleanupDeletedEmbeddings = async () => {
|
|
|
- const files = await getAllLocalFiles();
|
|
|
- const trashedFiles = await getLocalTrashedFiles();
|
|
|
+export const cleanupDeletedEmbeddings = async (
|
|
|
+ allLocalFiles: EnteFile[],
|
|
|
+ allLocalEmbeddings: Embedding[]
|
|
|
+) => {
|
|
|
const activeFileIds = new Set<number>();
|
|
|
- [...files, ...trashedFiles].forEach((file) => {
|
|
|
+ allLocalFiles.forEach((file) => {
|
|
|
activeFileIds.add(file.id);
|
|
|
});
|
|
|
- const embeddings = await getLocalEmbeddings();
|
|
|
|
|
|
- const remainingEmbeddings = embeddings.filter((embedding) =>
|
|
|
+ const remainingEmbeddings = allLocalEmbeddings.filter((embedding) =>
|
|
|
activeFileIds.has(embedding.fileID)
|
|
|
);
|
|
|
- if (embeddings.length !== remainingEmbeddings.length) {
|
|
|
+ if (allLocalEmbeddings.length !== remainingEmbeddings.length) {
|
|
|
addLogLine(
|
|
|
- `cleanupDeletedEmbeddings embeddingsCount: ${embeddings.length} remainingEmbeddingsCount: ${remainingEmbeddings.length}`
|
|
|
+ `cleanupDeletedEmbeddings embeddingsCount: ${allLocalEmbeddings.length} remainingEmbeddingsCount: ${remainingEmbeddings.length}`
|
|
|
);
|
|
|
await localForage.setItem(EMBEDDINGS_TABLE, remainingEmbeddings);
|
|
|
}
|