@@ -1,5 +1,6 @@
 import { ensureElectron } from "@/next/electron";
 import log from "@/next/log";
+import type { Electron } from "@/next/types/ipc";
 import ComlinkCryptoWorker from "@ente/shared/crypto";
 import { CustomError } from "@ente/shared/error";
 import { Events, eventBus } from "@ente/shared/events";
@@ -7,29 +8,71 @@ import { LS_KEYS, getData } from "@ente/shared/storage/localStorage";
 import { FILE_TYPE } from "constants/file";
 import isElectron from "is-electron";
 import PQueue from "p-queue";
-import { Embedding, Model } from "types/embedding";
+import { Embedding } from "types/embedding";
 import { EnteFile } from "types/file";
 import { getPersonalFiles } from "utils/file";
 import downloadManager from "./download";
 import { getLocalEmbeddings, putEmbedding } from "./embeddingService";
 import { getAllLocalFiles, getLocalFiles } from "./fileService";
 
-const CLIP_EMBEDDING_LENGTH = 512;
-
-export interface ClipExtractionStatus {
+/** Status of CLIP indexing on the images in the user's local library. */
+export interface CLIPIndexingStatus {
+    /** Number of items pending indexing. */
     pending: number;
+    /** Number of items that have already been indexed. */
     indexed: number;
 }
 
-class ClipServiceImpl {
+/**
+ * Use a CLIP based neural network for natural language search.
+ *
+ * [Note: CLIP based magic search]
+ *
+ * CLIP (Contrastive Language-Image Pretraining) is a neural network trained on
+ * (image, text) pairs. It can be thought of as two separate (but jointly
+ * trained) encoders - one for images, and one for text - that both map to the
+ * same embedding space.
+ *
+ * We use this for natural language search within the app (aka "magic search"):
+ *
+ * 1. Pre-compute an embedding for each image.
+ *
+ * 2. When the user searches, compute an embedding for the search term.
+ *
+ * 3. Use cosine similarity to find the image (embedding) closest to the text
+ *    (embedding).
+ *
+ * More details are in our [blog
+ * post](https://ente.io/blog/image-search-with-clip-ggml/) that describes the
+ * initial launch of this feature using the GGML runtime.
+ *
+ * Since the initial launch, we've switched over to another runtime,
+ * [ONNX](https://onnxruntime.ai).
+ *
+ * Note that we don't train the neural network - we only use one of the publicly
+ * available pre-trained neural networks for inference. These neural networks
+ * are wholly defined by their connectivity and weights. ONNX, our ML runtime,
+ * loads these weights and instantiates a running network that we can use to
+ * compute the embeddings.
+ *
+ * Theoretically, the same CLIP model can be loaded by different frameworks /
+ * runtimes, but in practice each runtime has its own preferred format, and
+ * there are also quantization tradeoffs. So there is a specific model (a binary
+ * encoding of weights) tied to our current runtime that we use.
+ *
+ * To ensure that the embeddings, for the most part, can be shared, whenever
+ * possible we try to ensure that all the preprocessing steps, and the model
+ * itself, are the same across clients - web and mobile.
+ */
+class CLIPService {
+    private electron: Electron;
     private embeddingExtractionInProgress: AbortController | null = null;
     private reRunNeeded = false;
-    private clipExtractionStatus: ClipExtractionStatus = {
+    private indexingStatus: CLIPIndexingStatus = {
         pending: 0,
         indexed: 0,
     };
-    private onUpdateHandler: ((status: ClipExtractionStatus) => void) | null =
-        null;
+    private onUpdateHandler: ((status: CLIPIndexingStatus) => void) | undefined;
     private liveEmbeddingExtractionQueue: PQueue;
     private onFileUploadedHandler:
         | ((arg: { enteFile: EnteFile; localFile: globalThis.File }) => void)
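For step 3 in the comment above, here is a minimal illustrative sketch (not part of this diff) of cosine similarity between two CLIP embeddings, assuming both are unnormalized Float32Arrays of the same length; the production scoring lives in computeClipMatchScore further down in this file:

// Illustrative sketch only - not from this patch.
const cosineSimilarity = (a: Float32Array, b: Float32Array): number => {
    let dot = 0, normA = 0, normB = 0;
    for (let i = 0; i < a.length; i++) {
        dot += a[i] * b[i];
        normA += a[i] * a[i];
        normB += b[i] * b[i];
    }
    // Higher is better; 1 means the image and text embeddings coincide.
    return dot / (Math.sqrt(normA) * Math.sqrt(normB));
};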
@@ -37,6 +80,7 @@ class ClipServiceImpl {
     private unsupportedPlatform = false;
 
     constructor() {
+        this.electron = ensureElectron();
         this.liveEmbeddingExtractionQueue = new PQueue({
             concurrency: 1,
         });
@@ -96,28 +140,23 @@ class ClipServiceImpl {
     };
 
     getIndexingStatus = async () => {
-        try {
-            if (
-                !this.clipExtractionStatus ||
-                (this.clipExtractionStatus.pending === 0 &&
-                    this.clipExtractionStatus.indexed === 0)
-            ) {
-                this.clipExtractionStatus = await getClipExtractionStatus();
-            }
-            return this.clipExtractionStatus;
-        } catch (e) {
-            log.error("failed to get clip indexing status", e);
+        if (
+            this.indexingStatus.pending === 0 &&
+            this.indexingStatus.indexed === 0
+        ) {
+            this.indexingStatus = await initialIndexingStatus();
         }
+        return this.indexingStatus;
     };
 
-    setOnUpdateHandler = (handler: (status: ClipExtractionStatus) => void) => {
+    /**
+     * Set the {@link handler} to invoke whenever our indexing status changes.
+     */
+    setOnUpdateHandler = (handler?: (status: CLIPIndexingStatus) => void) => {
         this.onUpdateHandler = handler;
-        handler(this.clipExtractionStatus);
     };
 
-    scheduleImageEmbeddingExtraction = async (
-        model: Model = Model.ONNX_CLIP,
-    ) => {
+    scheduleImageEmbeddingExtraction = async () => {
         try {
             if (this.embeddingExtractionInProgress) {
                 log.info(
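A hypothetical caller of the two public methods touched above, using the clipService singleton exported later in this diff (sketch only; the surrounding UI code is assumed):

// Sketch: subscribe to indexing progress, then kick off indexing.
clipService.setOnUpdateHandler(({ pending, indexed }) => {
    console.log(`CLIP indexing: ${indexed} indexed, ${pending} pending`);
});
void clipService.scheduleImageEmbeddingExtraction();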
@@ -133,7 +172,7 @@ class ClipServiceImpl {
         const canceller = new AbortController();
         this.embeddingExtractionInProgress = canceller;
         try {
-            await this.runClipEmbeddingExtraction(canceller, model);
+            await this.runClipEmbeddingExtraction(canceller);
         } finally {
             this.embeddingExtractionInProgress = null;
             if (!canceller.signal.aborted && this.reRunNeeded) {
@@ -152,25 +191,19 @@ class ClipServiceImpl {
         }
     };
 
-    getTextEmbedding = async (
-        text: string,
-        model: Model = Model.ONNX_CLIP,
-    ): Promise<Float32Array> => {
+    getTextEmbedding = async (text: string): Promise<Float32Array> => {
         try {
-            return ensureElectron().computeTextEmbedding(model, text);
+            return this.electron.clipTextEmbedding(text);
         } catch (e) {
             if (e?.message?.includes(CustomError.UNSUPPORTED_PLATFORM)) {
                 this.unsupportedPlatform = true;
             }
-            log.error("failed to compute text embedding", e);
+            log.error("Failed to compute CLIP text embedding", e);
             throw e;
         }
     };
 
-    private runClipEmbeddingExtraction = async (
-        canceller: AbortController,
-        model: Model,
-    ) => {
+    private runClipEmbeddingExtraction = async (canceller: AbortController) => {
         try {
             if (this.unsupportedPlatform) {
                 log.info(
@@ -183,12 +216,12 @@ class ClipServiceImpl {
                 return;
             }
             const localFiles = getPersonalFiles(await getAllLocalFiles(), user);
-            const existingEmbeddings = await getLocalEmbeddings(model);
+            const existingEmbeddings = await getLocalEmbeddings();
             const pendingFiles = await getNonClipEmbeddingExtractedFiles(
                 localFiles,
                 existingEmbeddings,
             );
-            this.updateClipEmbeddingExtractionStatus({
+            this.updateIndexingStatus({
                 indexed: existingEmbeddings.length,
                 pending: pendingFiles.length,
             });
@@ -208,15 +241,11 @@ class ClipServiceImpl {
                         throw Error(CustomError.REQUEST_CANCELLED);
                     }
                     const embeddingData =
-                        await this.extractFileClipImageEmbedding(model, file);
+                        await this.extractFileClipImageEmbedding(file);
                     log.info(
                         `successfully extracted clip embedding for file: ${file.metadata.title} fileID: ${file.id} embedding length: ${embeddingData?.length}`,
                     );
-                    await this.encryptAndUploadEmbedding(
-                        model,
-                        file,
-                        embeddingData,
-                    );
+                    await this.encryptAndUploadEmbedding(file, embeddingData);
                     this.onSuccessStatusUpdater();
                     log.info(
                         `successfully put clip embedding to server for file: ${file.metadata.title} fileID: ${file.id}`,
@@ -249,13 +278,10 @@ class ClipServiceImpl {
         }
     };
 
-    private async runLocalFileClipExtraction(
-        arg: {
-            enteFile: EnteFile;
-            localFile: globalThis.File;
-        },
-        model: Model = Model.ONNX_CLIP,
-    ) {
+    private async runLocalFileClipExtraction(arg: {
+        enteFile: EnteFile;
+        localFile: globalThis.File;
+    }) {
         const { enteFile, localFile } = arg;
         log.info(
             `clip embedding extraction onFileUploadedHandler file: ${enteFile.metadata.title} fileID: ${enteFile.id}`,
@@ -279,15 +305,9 @@ class ClipServiceImpl {
         );
         try {
             await this.liveEmbeddingExtractionQueue.add(async () => {
-                const embedding = await this.extractLocalFileClipImageEmbedding(
-                    model,
-                    localFile,
-                );
-                await this.encryptAndUploadEmbedding(
-                    model,
-                    enteFile,
-                    embedding,
-                );
+                const embedding =
+                    await this.extractLocalFileClipImageEmbedding(localFile);
+                await this.encryptAndUploadEmbedding(enteFile, embedding);
             });
             log.info(
                 `successfully extracted clip embedding for file: ${enteFile.metadata.title} fileID: ${enteFile.id}`,
@@ -297,26 +317,18 @@ class ClipServiceImpl {
         }
     }
 
-    private extractLocalFileClipImageEmbedding = async (
-        model: Model,
-        localFile: File,
-    ) => {
+    private extractLocalFileClipImageEmbedding = async (localFile: File) => {
         const file = await localFile
             .arrayBuffer()
             .then((buffer) => new Uint8Array(buffer));
-        const embedding = await ensureElectron().computeImageEmbedding(
-            model,
-            file,
-        );
-        return embedding;
+        return await this.electron.clipImageEmbedding(file);
     };
 
     private encryptAndUploadEmbedding = async (
-        model: Model,
         file: EnteFile,
         embeddingData: Float32Array,
     ) => {
-        if (embeddingData?.length !== CLIP_EMBEDDING_LENGTH) {
+        if (embeddingData?.length !== 512) {
             throw Error(
                 `invalid length embedding data length: ${embeddingData?.length}`,
             );
@@ -331,38 +343,31 @@ class ClipServiceImpl {
             fileID: file.id,
             encryptedEmbedding: encryptedEmbeddingData.encryptedData,
             decryptionHeader: encryptedEmbeddingData.decryptionHeader,
-            model,
+            model: "onnx-clip",
         });
     };
 
-    updateClipEmbeddingExtractionStatus = (status: ClipExtractionStatus) => {
-        this.clipExtractionStatus = status;
-        if (this.onUpdateHandler) {
-            this.onUpdateHandler(status);
-        }
+    private updateIndexingStatus = (status: CLIPIndexingStatus) => {
+        this.indexingStatus = status;
+        const handler = this.onUpdateHandler;
+        if (handler) handler(status);
     };
 
-    private extractFileClipImageEmbedding = async (
-        model: Model,
-        file: EnteFile,
-    ) => {
+    private extractFileClipImageEmbedding = async (file: EnteFile) => {
         const thumb = await downloadManager.getThumbnail(file);
-        const embedding = await ensureElectron().computeImageEmbedding(
-            model,
-            thumb,
-        );
+        const embedding = await ensureElectron().clipImageEmbedding(thumb);
         return embedding;
     };
 
     private onSuccessStatusUpdater = () => {
-        this.updateClipEmbeddingExtractionStatus({
-            pending: this.clipExtractionStatus.pending - 1,
-            indexed: this.clipExtractionStatus.indexed + 1,
+        this.updateIndexingStatus({
+            pending: this.indexingStatus.pending - 1,
+            indexed: this.indexingStatus.indexed + 1,
         });
     };
 }
 
-export const ClipService = new ClipServiceImpl();
+export const clipService = new CLIPService();
 
 const getNonClipEmbeddingExtractedFiles = async (
     files: EnteFile[],
@@ -412,14 +417,10 @@ export const computeClipMatchScore = async (
     return score;
 };
 
-const getClipExtractionStatus = async (
-    model: Model = Model.ONNX_CLIP,
-): Promise<ClipExtractionStatus> => {
+const initialIndexingStatus = async (): Promise<CLIPIndexingStatus> => {
     const user = getData(LS_KEYS.USER);
-    if (!user) {
-        return;
-    }
-    const allEmbeddings = await getLocalEmbeddings(model);
+    if (!user) throw new Error("Orphan CLIP indexing without a login");
+    const allEmbeddings = await getLocalEmbeddings();
     const localFiles = getPersonalFiles(await getLocalFiles(), user);
     const pendingFiles = await getNonClipEmbeddingExtractedFiles(
         localFiles,