Abhinav Kumar před 1 rokem
rodič
revize
9fbcf55e4e

+ 8 - 8
apps/photos/src/components/Sidebar/AdvancedSettings.tsx

@@ -52,15 +52,13 @@ export default function AdvancedSettings({ open, onClose, onRootClose }) {
     });
 
     useEffect(() => {
-        ClipService.setOnUpdateHandler(setIndexingStatus);
+        const main = async () => {
+            setIndexingStatus(await ClipService.getIndexingStatus());
+            ClipService.setOnUpdateHandler(setIndexingStatus);
+        };
+        main();
     }, []);
 
-    useEffect(() => {
-        if (open) {
-            ClipService.updateIndexStatus();
-        }
-    }, [open]);
-
     return (
         <EnteDrawer
             transitionDuration={0}
@@ -112,7 +110,9 @@ export default function AdvancedSettings({ open, onClose, onRootClose }) {
 
                         {isElectron() && (
                             <Box>
-                                <MenuSectionTitle title={t('STATUS')} />
+                                <MenuSectionTitle
+                                    title={t('MAGIC_SEARCH_STATUS')}
+                                />
                                 <Stack py={'12px'} px={'12px'} spacing={'24px'}>
                                     <VerticallyCenteredFlex
                                         justifyContent="space-between"

+ 63 - 45
apps/photos/src/services/clipService.ts

@@ -1,8 +1,4 @@
-import {
-    putEmbedding,
-    getLatestEmbeddings,
-    getLocalEmbeddings,
-} from './embeddingService';
+import { putEmbedding, getLocalEmbeddings } from './embeddingService';
 import { getAllLocalFiles, getLocalFiles } from './fileService';
 import downloadManager from './download';
 import { logError } from '@ente/shared/sentry';
@@ -100,14 +96,18 @@ class ClipServiceImpl {
         }
     };
 
-    updateIndexStatus = async () => {
+    getIndexingStatus = async () => {
         try {
-            addLogLine('loading local clip index status');
-            this.clipExtractionStatus = await getClipExtractionStatus();
-            this.onUpdateHandler(this.clipExtractionStatus);
-            addLogLine('loaded local clip index status');
+            if (
+                !this.clipExtractionStatus ||
+                (this.clipExtractionStatus.pending === 0 &&
+                    this.clipExtractionStatus.indexed === 0)
+            ) {
+                this.clipExtractionStatus = await getClipExtractionStatus();
+            }
+            return this.clipExtractionStatus;
         } catch (e) {
-            logError(e, 'failed to load local clip index status');
+            logError(e, 'failed to get clip indexing status');
         }
     };
 
@@ -116,7 +116,9 @@ class ClipServiceImpl {
         handler(this.clipExtractionStatus);
     };
 
-    scheduleImageEmbeddingExtraction = async () => {
+    scheduleImageEmbeddingExtraction = async (
+        model: Model = Model.ONNX_CLIP
+    ) => {
         try {
             if (this.embeddingExtractionInProgress) {
                 addLogLine(
@@ -132,7 +134,7 @@ class ClipServiceImpl {
             const canceller = new AbortController();
             this.embeddingExtractionInProgress = canceller;
             try {
-                await this.runClipEmbeddingExtraction(canceller);
+                await this.runClipEmbeddingExtraction(canceller, model);
             } finally {
                 this.embeddingExtractionInProgress = null;
                 if (!canceller.signal.aborted && this.reRunNeeded) {
@@ -151,9 +153,12 @@ class ClipServiceImpl {
         }
     };
 
-    getTextEmbedding = async (text: string): Promise<Float32Array> => {
+    getTextEmbedding = async (
+        text: string,
+        model: Model = Model.ONNX_CLIP
+    ): Promise<Float32Array> => {
         try {
-            return ElectronAPIs.computeTextEmbedding(text);
+            return ElectronAPIs.computeTextEmbedding(model, text);
         } catch (e) {
             if (e?.message?.includes(CustomError.UNSUPPORTED_PLATFORM)) {
                 this.unsupportedPlatform = true;
@@ -163,7 +168,10 @@ class ClipServiceImpl {
         }
     };
 
-    private runClipEmbeddingExtraction = async (canceller: AbortController) => {
+    private runClipEmbeddingExtraction = async (
+        canceller: AbortController,
+        model: Model
+    ) => {
         try {
             if (this.unsupportedPlatform) {
                 addLogLine(
@@ -176,7 +184,7 @@ class ClipServiceImpl {
                 return;
             }
             const localFiles = getPersonalFiles(await getAllLocalFiles(), user);
-            const existingEmbeddings = await getLatestClipImageEmbeddings();
+            const existingEmbeddings = await getLocalEmbeddings(model);
             const pendingFiles = await getNonClipEmbeddingExtractedFiles(
                 localFiles,
                 existingEmbeddings
@@ -201,11 +209,15 @@ class ClipServiceImpl {
                         throw Error(CustomError.REQUEST_CANCELLED);
                     }
                     const embeddingData =
-                        await this.extractFileClipImageEmbedding(file);
+                        await this.extractFileClipImageEmbedding(model, file);
                     addLogLine(
                         `successfully extracted clip embedding for file: ${file.metadata.title} fileID: ${file.id} embedding length: ${embeddingData?.length}`
                     );
-                    await this.encryptAndUploadEmbedding(file, embeddingData);
+                    await this.encryptAndUploadEmbedding(
+                        model,
+                        file,
+                        embeddingData
+                    );
                     this.onSuccessStatusUpdater();
                     addLogLine(
                         `successfully put clip embedding to server for file: ${file.metadata.title} fileID: ${file.id}`
@@ -238,10 +250,13 @@ class ClipServiceImpl {
         }
     };
 
-    private async runLocalFileClipExtraction(arg: {
-        enteFile: EnteFile;
-        localFile: globalThis.File;
-    }) {
+    private async runLocalFileClipExtraction(
+        arg: {
+            enteFile: EnteFile;
+            localFile: globalThis.File;
+        },
+        model: Model = Model.ONNX_CLIP
+    ) {
         const { enteFile, localFile } = arg;
         addLogLine(
             `clip embedding extraction onFileUploadedHandler file: ${enteFile.metadata.title} fileID: ${enteFile.id}`,
@@ -256,9 +271,14 @@ class ClipServiceImpl {
         try {
             await this.liveEmbeddingExtractionQueue.add(async () => {
                 const embedding = await this.extractLocalFileClipImageEmbedding(
+                    model,
                     localFile
                 );
-                await this.encryptAndUploadEmbedding(enteFile, embedding);
+                await this.encryptAndUploadEmbedding(
+                    model,
+                    enteFile,
+                    embedding
+                );
             });
             addLogLine(
                 `successfully extracted clip embedding for file: ${enteFile.metadata.title} fileID: ${enteFile.id}`
@@ -268,15 +288,19 @@ class ClipServiceImpl {
         }
     }
 
-    private extractLocalFileClipImageEmbedding = async (localFile: File) => {
+    private extractLocalFileClipImageEmbedding = async (
+        model: Model,
+        localFile: File
+    ) => {
         const file = await localFile
             .arrayBuffer()
             .then((buffer) => new Uint8Array(buffer));
-        const embedding = await ElectronAPIs.computeImageEmbedding(file);
+        const embedding = await ElectronAPIs.computeImageEmbedding(model, file);
         return embedding;
     };
 
     private encryptAndUploadEmbedding = async (
+        model: Model,
         file: EnteFile,
         embeddingData: Float32Array
     ) => {
@@ -295,7 +319,7 @@ class ClipServiceImpl {
             fileID: file.id,
             encryptedEmbedding: encryptedEmbeddingData.encryptedData,
             decryptionHeader: encryptedEmbeddingData.decryptionHeader,
-            model: Model.GGML_CLIP,
+            model,
         });
     };
 
@@ -306,9 +330,15 @@ class ClipServiceImpl {
         }
     };
 
-    private extractFileClipImageEmbedding = async (file: EnteFile) => {
+    private extractFileClipImageEmbedding = async (
+        model: Model,
+        file: EnteFile
+    ) => {
         const thumb = await downloadManager.getThumbnail(file);
-        const embedding = await ElectronAPIs.computeImageEmbedding(thumb);
+        const embedding = await ElectronAPIs.computeImageEmbedding(
+            model,
+            thumb
+        );
         return embedding;
     };
 
@@ -343,13 +373,6 @@ const getNonClipEmbeddingExtractedFiles = async (
     });
 };
 
-export const getLocalClipImageEmbeddings = async () => {
-    const allEmbeddings = await getLocalEmbeddings();
-    return allEmbeddings.filter(
-        (embedding) => embedding.model === Model.GGML_CLIP
-    );
-};
-
 export const computeClipMatchScore = async (
     imageEmbedding: Float32Array,
     textEmbedding: Float32Array
@@ -377,19 +400,14 @@ export const computeClipMatchScore = async (
     return score;
 };
 
-const getLatestClipImageEmbeddings = async () => {
-    const allEmbeddings = await getLatestEmbeddings();
-    return allEmbeddings.filter(
-        (embedding) => embedding.model === Model.GGML_CLIP
-    );
-};
-
-const getClipExtractionStatus = async (): Promise<ClipExtractionStatus> => {
+const getClipExtractionStatus = async (
+    model: Model = Model.ONNX_CLIP
+): Promise<ClipExtractionStatus> => {
     const user = getData(LS_KEYS.USER);
     if (!user) {
         return;
     }
-    const allEmbeddings = await getLocalClipImageEmbeddings();
+    const allEmbeddings = await getLocalEmbeddings(model);
     const localFiles = getPersonalFiles(await getLocalFiles(), user);
     const pendingFiles = await getNonClipEmbeddingExtractedFiles(
         localFiles,

+ 110 - 78
apps/photos/src/services/embeddingService.ts

@@ -2,6 +2,7 @@ import {
     Embedding,
     EncryptedEmbedding,
     GetEmbeddingDiffResponse,
+    Model,
     PutEmbeddingRequest,
 } from 'types/embedding';
 import ComlinkCryptoWorker from '@ente/shared/crypto';
@@ -16,105 +17,135 @@ import { getLatestVersionEmbeddings } from 'utils/embedding';
 import { getLocalTrashedFiles } from './trashService';
 import { getLocalCollections } from './collectionService';
 import { CustomError } from '@ente/shared/error';
+import { EnteFile } from 'types/file';
 
 const ENDPOINT = getEndpoint();
 
 const DIFF_LIMIT = 500;
 
-const EMBEDDINGS_TABLE = 'embeddings';
+const EMBEDDINGS_TABLE_V1 = 'embeddings';
+const EMBEDDINGS_TABLE = 'embeddings_v2';
 const EMBEDDING_SYNC_TIME_TABLE = 'embedding_sync_time';
 
-export const getLocalEmbeddings = async () => {
-    const embeddings: Array<Embedding> =
-        (await localForage.getItem<Embedding[]>(EMBEDDINGS_TABLE)) || [];
+export const getAllLocalEmbeddings = async () => {
+    const embeddings: Array<Embedding> = await localForage.getItem<Embedding[]>(
+        EMBEDDINGS_TABLE
+    );
+    if (!embeddings) {
+        await localForage.removeItem(EMBEDDINGS_TABLE_V1);
+        await localForage.removeItem(EMBEDDING_SYNC_TIME_TABLE);
+        await localForage.setItem(EMBEDDINGS_TABLE, []);
+        return [];
+    }
     return embeddings;
 };
 
-const getEmbeddingSyncTime = async () => {
-    return (await localForage.getItem<number>(EMBEDDING_SYNC_TIME_TABLE)) ?? 0;
+export const getLocalEmbeddings = async (model: Model) => {
+    const embeddings = await getAllLocalEmbeddings();
+    return embeddings.filter((embedding) => embedding.model === model);
 };
 
-export const getLatestEmbeddings = async () => {
-    await syncEmbeddings();
-    const embeddings = await getLocalEmbeddings();
-    return embeddings;
+const getModelEmbeddingSyncTime = async (model: Model) => {
+    return (
+        (await localForage.getItem<number>(
+            `${model}-${EMBEDDING_SYNC_TIME_TABLE}`
+        )) ?? 0
+    );
 };
 
-export const syncEmbeddings = async () => {
+const setModelEmbeddingSyncTime = async (model: Model, time: number) => {
+    await localForage.setItem(`${model}-${EMBEDDING_SYNC_TIME_TABLE}`, time);
+};
+
+export const syncEmbeddings = async (models: Model[] = [Model.ONNX_CLIP]) => {
     try {
-        let embeddings = await getLocalEmbeddings();
+        let allEmbeddings = await getAllLocalEmbeddings();
         const localFiles = await getAllLocalFiles();
         const hiddenAlbums = await getLocalCollections('hidden');
         const localTrashFiles = await getLocalTrashedFiles();
         const fileIdToKeyMap = new Map<number, string>();
-        [...localFiles, ...localTrashFiles].forEach((file) => {
+        const allLocalFiles = [...localFiles, ...localTrashFiles];
+        allLocalFiles.forEach((file) => {
             fileIdToKeyMap.set(file.id, file.key);
         });
-        addLogLine(`Syncing embeddings localCount: ${embeddings.length}`);
-        let sinceTime = await getEmbeddingSyncTime();
-        addLogLine(`Syncing embeddings sinceTime: ${sinceTime}`);
-        let response: GetEmbeddingDiffResponse;
-        do {
-            response = await getEmbeddingsDiff(sinceTime);
-            if (!response.diff?.length) {
-                return;
-            }
-            const newEmbeddings = await Promise.all(
-                response.diff.map(async (embedding) => {
-                    try {
-                        const {
-                            encryptedEmbedding,
-                            decryptionHeader,
-                            ...rest
-                        } = embedding;
-                        const worker = await ComlinkCryptoWorker.getInstance();
-                        const fileKey = fileIdToKeyMap.get(embedding.fileID);
-                        if (!fileKey) {
-                            throw Error(CustomError.FILE_NOT_FOUND);
-                        }
-                        const decryptedData = await worker.decryptEmbedding(
-                            encryptedEmbedding,
-                            decryptionHeader,
-                            fileIdToKeyMap.get(embedding.fileID)
-                        );
-
-                        return {
-                            ...rest,
-                            embedding: decryptedData,
-                        } as Embedding;
-                    } catch (e) {
-                        let info: Record<string, unknown>;
-                        if (e.message === CustomError.FILE_NOT_FOUND) {
-                            const hasHiddenAlbums = hiddenAlbums?.length > 0;
-                            info = {
-                                hasHiddenAlbums,
-                            };
-                        }
-                        logError(e, 'decryptEmbedding failed for file', info);
-                    }
-                })
-            );
-            embeddings = getLatestVersionEmbeddings([
-                ...embeddings,
-                ...newEmbeddings,
-            ]);
-            if (response.diff.length) {
-                sinceTime = response.diff.slice(-1)[0].updatedAt;
-            }
-            await localForage.setItem(EMBEDDINGS_TABLE, embeddings);
-            await localForage.setItem(EMBEDDING_SYNC_TIME_TABLE, sinceTime);
+        await cleanupDeletedEmbeddings(allLocalFiles, allEmbeddings);
+        addLogLine(`Syncing embeddings localCount: ${allEmbeddings.length}`);
+        for (const model of models) {
+            let modelLastSinceTime = await getModelEmbeddingSyncTime(model);
             addLogLine(
-                `Syncing embeddings syncedEmbeddingsCount: ${newEmbeddings.length}`
+                `Syncing ${model} model's embeddings sinceTime: ${modelLastSinceTime}`
             );
-        } while (response.diff.length === DIFF_LIMIT);
-        void cleanupDeletedEmbeddings();
+            let response: GetEmbeddingDiffResponse;
+            do {
+                response = await getEmbeddingsDiff(modelLastSinceTime, model);
+                if (!response.diff?.length) {
+                    return;
+                }
+                const newEmbeddings = await Promise.all(
+                    response.diff.map(async (embedding) => {
+                        try {
+                            const {
+                                encryptedEmbedding,
+                                decryptionHeader,
+                                ...rest
+                            } = embedding;
+                            const worker =
+                                await ComlinkCryptoWorker.getInstance();
+                            const fileKey = fileIdToKeyMap.get(
+                                embedding.fileID
+                            );
+                            if (!fileKey) {
+                                throw Error(CustomError.FILE_NOT_FOUND);
+                            }
+                            const decryptedData = await worker.decryptEmbedding(
+                                encryptedEmbedding,
+                                decryptionHeader,
+                                fileIdToKeyMap.get(embedding.fileID)
+                            );
+
+                            return {
+                                ...rest,
+                                embedding: decryptedData,
+                            } as Embedding;
+                        } catch (e) {
+                            let info: Record<string, unknown>;
+                            if (e.message === CustomError.FILE_NOT_FOUND) {
+                                const hasHiddenAlbums =
+                                    hiddenAlbums?.length > 0;
+                                info = {
+                                    hasHiddenAlbums,
+                                };
+                            }
+                            logError(
+                                e,
+                                'decryptEmbedding failed for file',
+                                info
+                            );
+                        }
+                    })
+                );
+                allEmbeddings = getLatestVersionEmbeddings([
+                    ...allEmbeddings,
+                    ...newEmbeddings,
+                ]);
+                if (response.diff.length) {
+                    modelLastSinceTime = response.diff.slice(-1)[0].updatedAt;
+                }
+                await localForage.setItem(EMBEDDINGS_TABLE, allEmbeddings);
+                await setModelEmbeddingSyncTime(model, modelLastSinceTime);
+                addLogLine(
+                    `Syncing embeddings syncedEmbeddingsCount: ${allEmbeddings.length}`
+                );
+            } while (response.diff.length === DIFF_LIMIT);
+        }
     } catch (e) {
         logError(e, 'Sync embeddings failed');
     }
 };
 
 export const getEmbeddingsDiff = async (
-    sinceTime: number
+    sinceTime: number,
+    model: Model
 ): Promise<GetEmbeddingDiffResponse> => {
     try {
         const token = getToken();
@@ -126,6 +157,7 @@ export const getEmbeddingsDiff = async (
             {
                 sinceTime,
                 limit: DIFF_LIMIT,
+                model,
             },
             {
                 'X-Auth-Token': token,
@@ -161,21 +193,21 @@ export const putEmbedding = async (
     }
 };
 
-export const cleanupDeletedEmbeddings = async () => {
-    const files = await getAllLocalFiles();
-    const trashedFiles = await getLocalTrashedFiles();
+export const cleanupDeletedEmbeddings = async (
+    allLocalFiles: EnteFile[],
+    allLocalEmbeddings: Embedding[]
+) => {
     const activeFileIds = new Set<number>();
-    [...files, ...trashedFiles].forEach((file) => {
+    allLocalFiles.forEach((file) => {
         activeFileIds.add(file.id);
     });
-    const embeddings = await getLocalEmbeddings();
 
-    const remainingEmbeddings = embeddings.filter((embedding) =>
+    const remainingEmbeddings = allLocalEmbeddings.filter((embedding) =>
         activeFileIds.has(embedding.fileID)
     );
-    if (embeddings.length !== remainingEmbeddings.length) {
+    if (allLocalEmbeddings.length !== remainingEmbeddings.length) {
         addLogLine(
-            `cleanupDeletedEmbeddings embeddingsCount: ${embeddings.length} remainingEmbeddingsCount: ${remainingEmbeddings.length}`
+            `cleanupDeletedEmbeddings embeddingsCount: ${allLocalEmbeddings.length} remainingEmbeddingsCount: ${remainingEmbeddings.length}`
         );
         await localForage.setItem(EMBEDDINGS_TABLE, remainingEmbeddings);
     }

+ 4 - 6
apps/photos/src/services/searchService.ts

@@ -27,12 +27,10 @@ import { getLatestEntities } from './entityService';
 import { LocationTag, LocationTagData, EntityType } from 'types/entity';
 import { addLogLine } from '@ente/shared/logging';
 import { FILE_TYPE } from 'constants/file';
-import {
-    ClipService,
-    computeClipMatchScore,
-    getLocalClipImageEmbeddings,
-} from './clipService';
+import { ClipService, computeClipMatchScore } from './clipService';
 import { CustomError } from '@ente/shared/error';
+import { Model } from 'types/embedding';
+import { getLocalEmbeddings } from './embeddingService';
 
 const DIGITS = new Set(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']);
 
@@ -389,7 +387,7 @@ async function searchThing(searchPhrase: string) {
 }
 
 async function searchClip(searchPhrase: string): Promise<ClipSearchScores> {
-    const imageEmbeddings = await getLocalClipImageEmbeddings();
+    const imageEmbeddings = await getLocalEmbeddings(Model.ONNX_CLIP);
     const textEmbedding = await ClipService.getTextEmbedding(searchPhrase);
     const clipSearchResult = new Map<number, number>(
         (

+ 1 - 0
apps/photos/src/types/embedding.tsx

@@ -1,5 +1,6 @@
 export enum Model {
     GGML_CLIP = 'ggml-clip',
+    ONNX_CLIP = 'onnx-clip',
 }
 
 export interface EncryptedEmbedding {

+ 10 - 2
packages/shared/electron/types.ts

@@ -7,6 +7,11 @@ export interface AppUpdateInfo {
     version: string;
 }
 
+export enum Model {
+    GGML_CLIP = 'ggml-clip',
+    ONNX_CLIP = 'onnx-clip',
+}
+
 export interface ElectronAPIsType {
     exists: (path: string) => boolean;
     checkExistsAndCreateDir: (dirPath: string) => Promise<void>;
@@ -97,8 +102,11 @@ export interface ElectronAPIsType {
     deleteFile: (path: string) => void;
     rename: (oldPath: string, newPath: string) => Promise<void>;
     updateOptOutOfCrashReports: (optOut: boolean) => Promise<void>;
-    computeImageEmbedding: (imageData: Uint8Array) => Promise<Float32Array>;
-    computeTextEmbedding: (text: string) => Promise<Float32Array>;
+    computeImageEmbedding: (
+        model: Model,
+        imageData: Uint8Array
+    ) => Promise<Float32Array>;
+    computeTextEmbedding: (model: Model, text: string) => Promise<Float32Array>;
     getPlatform: () => Promise<'mac' | 'windows' | 'linux'>;
     setCustomCacheDirectory: (directory: string) => Promise<void>;
     getCacheDirectory: () => Promise<string>;