Onnx clip UI (#1543)

This commit is contained in:
Abhinav Kumar 2024-01-16 11:45:59 +05:30 committed by GitHub
commit 9fbcf55e4e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 196 additions and 139 deletions

View file

@ -52,15 +52,13 @@ export default function AdvancedSettings({ open, onClose, onRootClose }) {
});
useEffect(() => {
ClipService.setOnUpdateHandler(setIndexingStatus);
const main = async () => {
setIndexingStatus(await ClipService.getIndexingStatus());
ClipService.setOnUpdateHandler(setIndexingStatus);
};
main();
}, []);
useEffect(() => {
if (open) {
ClipService.updateIndexStatus();
}
}, [open]);
return (
<EnteDrawer
transitionDuration={0}
@ -112,7 +110,9 @@ export default function AdvancedSettings({ open, onClose, onRootClose }) {
{isElectron() && (
<Box>
<MenuSectionTitle title={t('STATUS')} />
<MenuSectionTitle
title={t('MAGIC_SEARCH_STATUS')}
/>
<Stack py={'12px'} px={'12px'} spacing={'24px'}>
<VerticallyCenteredFlex
justifyContent="space-between"

View file

@ -1,8 +1,4 @@
import {
putEmbedding,
getLatestEmbeddings,
getLocalEmbeddings,
} from './embeddingService';
import { putEmbedding, getLocalEmbeddings } from './embeddingService';
import { getAllLocalFiles, getLocalFiles } from './fileService';
import downloadManager from './download';
import { logError } from '@ente/shared/sentry';
@ -100,14 +96,18 @@ class ClipServiceImpl {
}
};
updateIndexStatus = async () => {
getIndexingStatus = async () => {
try {
addLogLine('loading local clip index status');
this.clipExtractionStatus = await getClipExtractionStatus();
this.onUpdateHandler(this.clipExtractionStatus);
addLogLine('loaded local clip index status');
if (
!this.clipExtractionStatus ||
(this.clipExtractionStatus.pending === 0 &&
this.clipExtractionStatus.indexed === 0)
) {
this.clipExtractionStatus = await getClipExtractionStatus();
}
return this.clipExtractionStatus;
} catch (e) {
logError(e, 'failed to load local clip index status');
logError(e, 'failed to get clip indexing status');
}
};
@ -116,7 +116,9 @@ class ClipServiceImpl {
handler(this.clipExtractionStatus);
};
scheduleImageEmbeddingExtraction = async () => {
scheduleImageEmbeddingExtraction = async (
model: Model = Model.ONNX_CLIP
) => {
try {
if (this.embeddingExtractionInProgress) {
addLogLine(
@ -132,7 +134,7 @@ class ClipServiceImpl {
const canceller = new AbortController();
this.embeddingExtractionInProgress = canceller;
try {
await this.runClipEmbeddingExtraction(canceller);
await this.runClipEmbeddingExtraction(canceller, model);
} finally {
this.embeddingExtractionInProgress = null;
if (!canceller.signal.aborted && this.reRunNeeded) {
@ -151,9 +153,12 @@ class ClipServiceImpl {
}
};
getTextEmbedding = async (text: string): Promise<Float32Array> => {
getTextEmbedding = async (
text: string,
model: Model = Model.ONNX_CLIP
): Promise<Float32Array> => {
try {
return ElectronAPIs.computeTextEmbedding(text);
return ElectronAPIs.computeTextEmbedding(model, text);
} catch (e) {
if (e?.message?.includes(CustomError.UNSUPPORTED_PLATFORM)) {
this.unsupportedPlatform = true;
@ -163,7 +168,10 @@ class ClipServiceImpl {
}
};
private runClipEmbeddingExtraction = async (canceller: AbortController) => {
private runClipEmbeddingExtraction = async (
canceller: AbortController,
model: Model
) => {
try {
if (this.unsupportedPlatform) {
addLogLine(
@ -176,7 +184,7 @@ class ClipServiceImpl {
return;
}
const localFiles = getPersonalFiles(await getAllLocalFiles(), user);
const existingEmbeddings = await getLatestClipImageEmbeddings();
const existingEmbeddings = await getLocalEmbeddings(model);
const pendingFiles = await getNonClipEmbeddingExtractedFiles(
localFiles,
existingEmbeddings
@ -201,11 +209,15 @@ class ClipServiceImpl {
throw Error(CustomError.REQUEST_CANCELLED);
}
const embeddingData =
await this.extractFileClipImageEmbedding(file);
await this.extractFileClipImageEmbedding(model, file);
addLogLine(
`successfully extracted clip embedding for file: ${file.metadata.title} fileID: ${file.id} embedding length: ${embeddingData?.length}`
);
await this.encryptAndUploadEmbedding(file, embeddingData);
await this.encryptAndUploadEmbedding(
model,
file,
embeddingData
);
this.onSuccessStatusUpdater();
addLogLine(
`successfully put clip embedding to server for file: ${file.metadata.title} fileID: ${file.id}`
@ -238,10 +250,13 @@ class ClipServiceImpl {
}
};
private async runLocalFileClipExtraction(arg: {
enteFile: EnteFile;
localFile: globalThis.File;
}) {
private async runLocalFileClipExtraction(
arg: {
enteFile: EnteFile;
localFile: globalThis.File;
},
model: Model = Model.ONNX_CLIP
) {
const { enteFile, localFile } = arg;
addLogLine(
`clip embedding extraction onFileUploadedHandler file: ${enteFile.metadata.title} fileID: ${enteFile.id}`,
@ -256,9 +271,14 @@ class ClipServiceImpl {
try {
await this.liveEmbeddingExtractionQueue.add(async () => {
const embedding = await this.extractLocalFileClipImageEmbedding(
model,
localFile
);
await this.encryptAndUploadEmbedding(enteFile, embedding);
await this.encryptAndUploadEmbedding(
model,
enteFile,
embedding
);
});
addLogLine(
`successfully extracted clip embedding for file: ${enteFile.metadata.title} fileID: ${enteFile.id}`
@ -268,15 +288,19 @@ class ClipServiceImpl {
}
}
private extractLocalFileClipImageEmbedding = async (localFile: File) => {
private extractLocalFileClipImageEmbedding = async (
model: Model,
localFile: File
) => {
const file = await localFile
.arrayBuffer()
.then((buffer) => new Uint8Array(buffer));
const embedding = await ElectronAPIs.computeImageEmbedding(file);
const embedding = await ElectronAPIs.computeImageEmbedding(model, file);
return embedding;
};
private encryptAndUploadEmbedding = async (
model: Model,
file: EnteFile,
embeddingData: Float32Array
) => {
@ -295,7 +319,7 @@ class ClipServiceImpl {
fileID: file.id,
encryptedEmbedding: encryptedEmbeddingData.encryptedData,
decryptionHeader: encryptedEmbeddingData.decryptionHeader,
model: Model.GGML_CLIP,
model,
});
};
@ -306,9 +330,15 @@ class ClipServiceImpl {
}
};
private extractFileClipImageEmbedding = async (file: EnteFile) => {
private extractFileClipImageEmbedding = async (
model: Model,
file: EnteFile
) => {
const thumb = await downloadManager.getThumbnail(file);
const embedding = await ElectronAPIs.computeImageEmbedding(thumb);
const embedding = await ElectronAPIs.computeImageEmbedding(
model,
thumb
);
return embedding;
};
@ -343,13 +373,6 @@ const getNonClipEmbeddingExtractedFiles = async (
});
};
export const getLocalClipImageEmbeddings = async () => {
const allEmbeddings = await getLocalEmbeddings();
return allEmbeddings.filter(
(embedding) => embedding.model === Model.GGML_CLIP
);
};
export const computeClipMatchScore = async (
imageEmbedding: Float32Array,
textEmbedding: Float32Array
@ -377,19 +400,14 @@ export const computeClipMatchScore = async (
return score;
};
const getLatestClipImageEmbeddings = async () => {
const allEmbeddings = await getLatestEmbeddings();
return allEmbeddings.filter(
(embedding) => embedding.model === Model.GGML_CLIP
);
};
const getClipExtractionStatus = async (): Promise<ClipExtractionStatus> => {
const getClipExtractionStatus = async (
model: Model = Model.ONNX_CLIP
): Promise<ClipExtractionStatus> => {
const user = getData(LS_KEYS.USER);
if (!user) {
return;
}
const allEmbeddings = await getLocalClipImageEmbeddings();
const allEmbeddings = await getLocalEmbeddings(model);
const localFiles = getPersonalFiles(await getLocalFiles(), user);
const pendingFiles = await getNonClipEmbeddingExtractedFiles(
localFiles,

View file

@ -2,6 +2,7 @@ import {
Embedding,
EncryptedEmbedding,
GetEmbeddingDiffResponse,
Model,
PutEmbeddingRequest,
} from 'types/embedding';
import ComlinkCryptoWorker from '@ente/shared/crypto';
@ -16,105 +17,135 @@ import { getLatestVersionEmbeddings } from 'utils/embedding';
import { getLocalTrashedFiles } from './trashService';
import { getLocalCollections } from './collectionService';
import { CustomError } from '@ente/shared/error';
import { EnteFile } from 'types/file';
const ENDPOINT = getEndpoint();
const DIFF_LIMIT = 500;
const EMBEDDINGS_TABLE = 'embeddings';
const EMBEDDINGS_TABLE_V1 = 'embeddings';
const EMBEDDINGS_TABLE = 'embeddings_v2';
const EMBEDDING_SYNC_TIME_TABLE = 'embedding_sync_time';
export const getLocalEmbeddings = async () => {
const embeddings: Array<Embedding> =
(await localForage.getItem<Embedding[]>(EMBEDDINGS_TABLE)) || [];
/**
 * Reads every locally stored embedding, regardless of model.
 *
 * When nothing is stored under `EMBEDDINGS_TABLE` yet, the entries under
 * `EMBEDDINGS_TABLE_V1` and `EMBEDDING_SYNC_TIME_TABLE` are removed and
 * `EMBEDDINGS_TABLE` is initialised with an empty list, so the first call also
 * acts as a one-time reset (presumably a migration away from the v1 table —
 * confirm against the release notes).
 *
 * @returns the stored embeddings, or an empty array on first use.
 */
export const getAllLocalEmbeddings = async () => {
    const stored: Array<Embedding> =
        await localForage.getItem<Embedding[]>(EMBEDDINGS_TABLE);
    if (stored) {
        return stored;
    }
    // Nothing under the current table: drop the older entries and start empty.
    await localForage.removeItem(EMBEDDINGS_TABLE_V1);
    await localForage.removeItem(EMBEDDING_SYNC_TIME_TABLE);
    await localForage.setItem(EMBEDDINGS_TABLE, []);
    return [];
};
const getEmbeddingSyncTime = async () => {
return (await localForage.getItem<number>(EMBEDDING_SYNC_TIME_TABLE)) ?? 0;
/**
 * Returns the locally stored embeddings that were produced by `model`.
 *
 * @param model only embeddings whose `model` field equals this value are kept.
 */
export const getLocalEmbeddings = async (model: Model) => {
    const everyEmbedding = await getAllLocalEmbeddings();
    return everyEmbedding.filter(
        ({ model: embeddingModel }) => embeddingModel === model
    );
};
export const getLatestEmbeddings = async () => {
await syncEmbeddings();
const embeddings = await getLocalEmbeddings();
return embeddings;
/**
 * Reads the last sync time recorded for `model`.
 *
 * The value lives under the key `${model}-${EMBEDDING_SYNC_TIME_TABLE}`.
 *
 * @returns the stored time, or 0 when none has been recorded.
 */
const getModelEmbeddingSyncTime = async (model: Model) => {
    const storageKey = `${model}-${EMBEDDING_SYNC_TIME_TABLE}`;
    const storedTime = await localForage.getItem<number>(storageKey);
    return storedTime ?? 0;
};
export const syncEmbeddings = async () => {
/**
 * Records `time` as the last sync time for `model`, under the key
 * `${model}-${EMBEDDING_SYNC_TIME_TABLE}`.
 */
const setModelEmbeddingSyncTime = async (model: Model, time: number) => {
    const storageKey = `${model}-${EMBEDDING_SYNC_TIME_TABLE}`;
    await localForage.setItem(storageKey, time);
};
/**
 * Fetches remote embedding diffs for each requested model, decrypts them and
 * merges them into the local embeddings table.
 *
 * For every model the diff endpoint is paged (pages of `DIFF_LIMIT`) starting
 * from that model's stored sync time; after each page the merged embeddings
 * and the model's new sync time are persisted.
 *
 * Any error is reported through `logError` and swallowed; the function never
 * rejects.
 *
 * @param models models to sync; defaults to `[Model.ONNX_CLIP]`.
 */
export const syncEmbeddings = async (models: Model[] = [Model.ONNX_CLIP]) => {
    try {
        let allEmbeddings = await getAllLocalEmbeddings();
        const localFiles = await getAllLocalFiles();
        const hiddenAlbums = await getLocalCollections('hidden');
        const localTrashFiles = await getLocalTrashedFiles();
        const fileIdToKeyMap = new Map<number, string>();
        const allLocalFiles = [...localFiles, ...localTrashFiles];
        allLocalFiles.forEach((file) => {
            fileIdToKeyMap.set(file.id, file.key);
        });
        await cleanupDeletedEmbeddings(allLocalFiles, allEmbeddings);
        // The cleanup call persists a pruned list, but its result is not used
        // here. Prune the in-memory copy the same way (by known file id),
        // otherwise the setItem below would write the embeddings of deleted
        // files straight back.
        allEmbeddings = allEmbeddings.filter((embedding) =>
            fileIdToKeyMap.has(embedding.fileID)
        );
        addLogLine(`Syncing embeddings localCount: ${allEmbeddings.length}`);
        for (const model of models) {
            let modelLastSinceTime = await getModelEmbeddingSyncTime(model);
            addLogLine(
                `Syncing ${model} model's embeddings sinceTime: ${modelLastSinceTime}`
            );
            let response: GetEmbeddingDiffResponse;
            do {
                response = await getEmbeddingsDiff(modelLastSinceTime, model);
                if (!response.diff?.length) {
                    // Nothing (more) to pull for this model. Only leave the
                    // paging loop so the remaining models still get synced.
                    break;
                }
                const decryptedEmbeddings = await Promise.all(
                    response.diff.map(async (embedding) => {
                        try {
                            const {
                                encryptedEmbedding,
                                decryptionHeader,
                                ...rest
                            } = embedding;
                            const worker =
                                await ComlinkCryptoWorker.getInstance();
                            const fileKey = fileIdToKeyMap.get(
                                embedding.fileID
                            );
                            if (!fileKey) {
                                throw Error(CustomError.FILE_NOT_FOUND);
                            }
                            const decryptedData = await worker.decryptEmbedding(
                                encryptedEmbedding,
                                decryptionHeader,
                                fileKey
                            );
                            return {
                                ...rest,
                                embedding: decryptedData,
                            } as Embedding;
                        } catch (e) {
                            let info: Record<string, unknown>;
                            if (e.message === CustomError.FILE_NOT_FOUND) {
                                const hasHiddenAlbums =
                                    hiddenAlbums?.length > 0;
                                info = {
                                    hasHiddenAlbums,
                                };
                            }
                            logError(
                                e,
                                'decryptEmbedding failed for file',
                                info
                            );
                            // Falls through: this entry resolves to undefined.
                        }
                    })
                );
                // Entries that failed to decrypt are undefined; drop them so
                // they are never merged into the local list.
                const newEmbeddings = decryptedEmbeddings.filter(
                    (embedding): embedding is Embedding =>
                        embedding !== undefined
                );
                allEmbeddings = getLatestVersionEmbeddings([
                    ...allEmbeddings,
                    ...newEmbeddings,
                ]);
                // The page is non-empty here (guarded above); advance the
                // cursor to the last entry's update time.
                modelLastSinceTime =
                    response.diff[response.diff.length - 1].updatedAt;
                await localForage.setItem(EMBEDDINGS_TABLE, allEmbeddings);
                await setModelEmbeddingSyncTime(model, modelLastSinceTime);
                addLogLine(
                    `Syncing embeddings syncedEmbeddingsCount: ${allEmbeddings.length}`
                );
            } while (response.diff.length === DIFF_LIMIT);
        }
    } catch (e) {
        logError(e, 'Sync embeddings failed');
    }
};
export const getEmbeddingsDiff = async (
sinceTime: number
sinceTime: number,
model: Model
): Promise<GetEmbeddingDiffResponse> => {
try {
const token = getToken();
@ -126,6 +157,7 @@ export const getEmbeddingsDiff = async (
{
sinceTime,
limit: DIFF_LIMIT,
model,
},
{
'X-Auth-Token': token,
@ -161,21 +193,21 @@ export const putEmbedding = async (
}
};
export const cleanupDeletedEmbeddings = async () => {
const files = await getAllLocalFiles();
const trashedFiles = await getLocalTrashedFiles();
export const cleanupDeletedEmbeddings = async (
allLocalFiles: EnteFile[],
allLocalEmbeddings: Embedding[]
) => {
const activeFileIds = new Set<number>();
[...files, ...trashedFiles].forEach((file) => {
allLocalFiles.forEach((file) => {
activeFileIds.add(file.id);
});
const embeddings = await getLocalEmbeddings();
const remainingEmbeddings = embeddings.filter((embedding) =>
const remainingEmbeddings = allLocalEmbeddings.filter((embedding) =>
activeFileIds.has(embedding.fileID)
);
if (embeddings.length !== remainingEmbeddings.length) {
if (allLocalEmbeddings.length !== remainingEmbeddings.length) {
addLogLine(
`cleanupDeletedEmbeddings embeddingsCount: ${embeddings.length} remainingEmbeddingsCount: ${remainingEmbeddings.length}`
`cleanupDeletedEmbeddings embeddingsCount: ${allLocalEmbeddings.length} remainingEmbeddingsCount: ${remainingEmbeddings.length}`
);
await localForage.setItem(EMBEDDINGS_TABLE, remainingEmbeddings);
}

View file

@ -27,12 +27,10 @@ import { getLatestEntities } from './entityService';
import { LocationTag, LocationTagData, EntityType } from 'types/entity';
import { addLogLine } from '@ente/shared/logging';
import { FILE_TYPE } from 'constants/file';
import {
ClipService,
computeClipMatchScore,
getLocalClipImageEmbeddings,
} from './clipService';
import { ClipService, computeClipMatchScore } from './clipService';
import { CustomError } from '@ente/shared/error';
import { Model } from 'types/embedding';
import { getLocalEmbeddings } from './embeddingService';
const DIGITS = new Set(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']);
@ -389,7 +387,7 @@ async function searchThing(searchPhrase: string) {
}
async function searchClip(searchPhrase: string): Promise<ClipSearchScores> {
const imageEmbeddings = await getLocalClipImageEmbeddings();
const imageEmbeddings = await getLocalEmbeddings(Model.ONNX_CLIP);
const textEmbedding = await ClipService.getTextEmbedding(searchPhrase);
const clipSearchResult = new Map<number, number>(
(

View file

@ -1,5 +1,6 @@
/**
 * Identifies the model an embedding was produced with.
 *
 * The string value is what is compared against an embedding's `model` field
 * and what is sent as the `model` parameter of embedding requests.
 */
export enum Model {
/** Value `'ggml-clip'`. */
GGML_CLIP = 'ggml-clip',
/** Value `'onnx-clip'`. */
ONNX_CLIP = 'onnx-clip',
}
export interface EncryptedEmbedding {

View file

@ -7,6 +7,11 @@ export interface AppUpdateInfo {
version: string;
}
/**
 * Model selector accepted by `computeImageEmbedding` and
 * `computeTextEmbedding` on `ElectronAPIsType`.
 */
export enum Model {
/** Value `'ggml-clip'`. */
GGML_CLIP = 'ggml-clip',
/** Value `'onnx-clip'`. */
ONNX_CLIP = 'onnx-clip',
}
export interface ElectronAPIsType {
exists: (path: string) => boolean;
checkExistsAndCreateDir: (dirPath: string) => Promise<void>;
@ -97,8 +102,11 @@ export interface ElectronAPIsType {
deleteFile: (path: string) => void;
rename: (oldPath: string, newPath: string) => Promise<void>;
updateOptOutOfCrashReports: (optOut: boolean) => Promise<void>;
computeImageEmbedding: (imageData: Uint8Array) => Promise<Float32Array>;
computeTextEmbedding: (text: string) => Promise<Float32Array>;
computeImageEmbedding: (
model: Model,
imageData: Uint8Array
) => Promise<Float32Array>;
computeTextEmbedding: (model: Model, text: string) => Promise<Float32Array>;
getPlatform: () => Promise<'mac' | 'windows' | 'linux'>;
setCustomCacheDirectory: (directory: string) => Promise<void>;
getCacheDirectory: () => Promise<string>;