[desktop] Remove GGML (#1394)

This commit is contained in:
Manav Rathi 2024-04-09 21:36:22 +05:30 committed by GitHub
commit eebb90fb40
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 464 additions and 638 deletions

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View file

@ -111,11 +111,11 @@ watcher for the watch folders functionality.
### AI/ML ### AI/ML
- [onnxruntime-node](https://github.com/Microsoft/onnxruntime) - [onnxruntime-node](https://github.com/Microsoft/onnxruntime) is used for
- html-entities is used by the bundled clip-bpe-ts. natural language searches based on CLIP.
- GGML binaries are bundled - html-entities is used by the bundled clip-bpe-ts tokenizer.
- We also use [jpeg-js](https://github.com/jpeg-js/jpeg-js#readme) for - [jpeg-js](https://github.com/jpeg-js/jpeg-js#readme) is used for decoding
conversion of all images to JPEG before processing. JPEG data into raw RGB bytes before passing it to ONNX.
## ZIP ## ZIP

View file

@ -19,7 +19,6 @@ mac:
arch: [universal] arch: [universal]
category: public.app-category.photography category: public.app-category.photography
hardenedRuntime: true hardenedRuntime: true
x64ArchFiles: Contents/Resources/ggmlclip-mac
afterSign: electron-builder-notarize afterSign: electron-builder-notarize
extraFiles: extraFiles:
- from: build - from: build

View file

@ -17,9 +17,9 @@ import {
updateAndRestart, updateAndRestart,
} from "../services/appUpdater"; } from "../services/appUpdater";
import { import {
computeImageEmbedding, clipImageEmbedding,
computeTextEmbedding, clipTextEmbedding,
} from "../services/clipService"; } from "../services/clip-service";
import { runFFmpegCmd } from "../services/ffmpeg"; import { runFFmpegCmd } from "../services/ffmpeg";
import { getDirFiles } from "../services/fs"; import { getDirFiles } from "../services/fs";
import { import {
@ -44,12 +44,7 @@ import {
updateWatchMappingIgnoredFiles, updateWatchMappingIgnoredFiles,
updateWatchMappingSyncedFiles, updateWatchMappingSyncedFiles,
} from "../services/watch"; } from "../services/watch";
import type { import type { ElectronFile, FILE_PATH_TYPE, WatchMapping } from "../types/ipc";
ElectronFile,
FILE_PATH_TYPE,
Model,
WatchMapping,
} from "../types/ipc";
import { import {
selectDirectory, selectDirectory,
showUploadDirsDialog, showUploadDirsDialog,
@ -148,14 +143,12 @@ export const attachIPCHandlers = () => {
// - ML // - ML
ipcMain.handle( ipcMain.handle("clipImageEmbedding", (_, jpegImageData: Uint8Array) =>
"computeImageEmbedding", clipImageEmbedding(jpegImageData),
(_, model: Model, imageData: Uint8Array) =>
computeImageEmbedding(model, imageData),
); );
ipcMain.handle("computeTextEmbedding", (_, model: Model, text: string) => ipcMain.handle("clipTextEmbedding", (_, text: string) =>
computeTextEmbedding(model, text), clipTextEmbedding(text),
); );
// - File selection // - File selection

View file

@ -64,7 +64,10 @@ const logInfo = (...params: any[]) => {
}; };
const logDebug = (param: () => any) => { const logDebug = (param: () => any) => {
if (isDev) console.log(`[debug] ${util.inspect(param())}`); if (isDev) {
const p = param();
console.log(`[debug] ${typeof p == "string" ? p : util.inspect(p)}`);
}
}; };
/** /**

View file

@ -45,7 +45,6 @@ import type {
AppUpdateInfo, AppUpdateInfo,
ElectronFile, ElectronFile,
FILE_PATH_TYPE, FILE_PATH_TYPE,
Model,
WatchMapping, WatchMapping,
} from "./types/ipc"; } from "./types/ipc";
@ -141,17 +140,11 @@ const runFFmpegCmd = (
// - ML // - ML
const computeImageEmbedding = ( const clipImageEmbedding = (jpegImageData: Uint8Array): Promise<Float32Array> =>
model: Model, ipcRenderer.invoke("clipImageEmbedding", jpegImageData);
imageData: Uint8Array,
): Promise<Float32Array> =>
ipcRenderer.invoke("computeImageEmbedding", model, imageData);
const computeTextEmbedding = ( const clipTextEmbedding = (text: string): Promise<Float32Array> =>
model: Model, ipcRenderer.invoke("clipTextEmbedding", text);
text: string,
): Promise<Float32Array> =>
ipcRenderer.invoke("computeTextEmbedding", model, text);
// - File selection // - File selection
@ -332,8 +325,8 @@ contextBridge.exposeInMainWorld("electron", {
runFFmpegCmd, runFFmpegCmd,
// - ML // - ML
computeImageEmbedding, clipImageEmbedding,
computeTextEmbedding, clipTextEmbedding,
// - File selection // - File selection
selectDirectory, selectDirectory,

View file

@ -0,0 +1,288 @@
/**
* @file Compute CLIP embeddings
*
* @see `web/apps/photos/src/services/clip-service.ts` for more details. This
* file implements the Node.js implementation of the actual embedding
* computation. By doing it in the Node.js layer, we can use the binary ONNX
* runtimes which are 10-20x faster than the WASM based web ones.
*
* The embeddings are computed using ONNX runtime. The model itself is not
* shipped with the app but is downloaded on demand.
*/
import { app, net } from "electron/main";
import { existsSync } from "fs";
import fs from "node:fs/promises";
import path from "node:path";
import { writeStream } from "../main/fs";
import log from "../main/log";
import { CustomErrors } from "../types/ipc";
import Tokenizer from "../utils/clip-bpe-ts/mod";
import { generateTempFilePath } from "../utils/temp";
import { deleteTempFile } from "./ffmpeg";
const jpeg = require("jpeg-js");
const ort = require("onnxruntime-node");
const textModelName = "clip-text-vit-32-uint8.onnx";
const textModelByteSize = 64173509; // 61.2 MB
const imageModelName = "clip-image-vit-32-float32.onnx";
const imageModelByteSize = 351468764; // 335.2 MB
/** Return the path where the given {@link modelName} is meant to be saved */
const modelSavePath = (modelName: string) =>
path.join(app.getPath("userData"), "models", modelName);
const downloadModel = async (saveLocation: string, name: string) => {
// `mkdir -p` the directory where we want to save the model.
const saveDir = path.dirname(saveLocation);
await fs.mkdir(saveDir, { recursive: true });
// Download
log.info(`Downloading CLIP model from ${name}`);
const url = `https://models.ente.io/${name}`;
const res = await net.fetch(url);
if (!res.ok) throw new Error(`Failed to fetch ${url}: HTTP ${res.status}`);
// Save
await writeStream(saveLocation, res.body);
log.info(`Downloaded CLIP model ${name}`);
};
let activeImageModelDownload: Promise<void> | undefined;
const imageModelPathDownloadingIfNeeded = async () => {
try {
const modelPath = modelSavePath(imageModelName);
if (activeImageModelDownload) {
log.info("Waiting for CLIP image model download to finish");
await activeImageModelDownload;
} else {
if (!existsSync(modelPath)) {
log.info("CLIP image model not found, downloading");
activeImageModelDownload = downloadModel(
modelPath,
imageModelName,
);
await activeImageModelDownload;
} else {
const localFileSize = (await fs.stat(modelPath)).size;
if (localFileSize !== imageModelByteSize) {
log.error(
`CLIP image model size ${localFileSize} does not match the expected size, downloading again`,
);
activeImageModelDownload = downloadModel(
modelPath,
imageModelName,
);
await activeImageModelDownload;
}
}
}
return modelPath;
} finally {
activeImageModelDownload = undefined;
}
};
let textModelDownloadInProgress = false;
const textModelPathDownloadingIfNeeded = async () => {
if (textModelDownloadInProgress)
throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
const modelPath = modelSavePath(textModelName);
if (!existsSync(modelPath)) {
log.info("CLIP text model not found, downloading");
textModelDownloadInProgress = true;
downloadModel(modelPath, textModelName)
.catch((e) => {
// log but otherwise ignore
log.error("CLIP text model download failed", e);
})
.finally(() => {
textModelDownloadInProgress = false;
});
throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
} else {
const localFileSize = (await fs.stat(modelPath)).size;
if (localFileSize !== textModelByteSize) {
log.error(
`CLIP text model size ${localFileSize} does not match the expected size, downloading again`,
);
textModelDownloadInProgress = true;
downloadModel(modelPath, textModelName)
.catch((e) => {
// log but otherwise ignore
log.error("CLIP text model download failed", e);
})
.finally(() => {
textModelDownloadInProgress = false;
});
throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
}
}
return modelPath;
};
const createInferenceSession = async (modelPath: string) => {
return await ort.InferenceSession.create(modelPath, {
intraOpNumThreads: 1,
enableCpuMemArena: false,
});
};
let imageSessionPromise: Promise<any> | undefined;
const onnxImageSession = async () => {
if (!imageSessionPromise) {
imageSessionPromise = (async () => {
const modelPath = await imageModelPathDownloadingIfNeeded();
return createInferenceSession(modelPath);
})();
}
return imageSessionPromise;
};
let _textSession: any = null;
const onnxTextSession = async () => {
if (!_textSession) {
const modelPath = await textModelPathDownloadingIfNeeded();
_textSession = await createInferenceSession(modelPath);
}
return _textSession;
};
export const clipImageEmbedding = async (jpegImageData: Uint8Array) => {
const tempFilePath = await generateTempFilePath("");
const imageStream = new Response(jpegImageData.buffer).body;
await writeStream(tempFilePath, imageStream);
try {
return await clipImageEmbedding_(tempFilePath);
} finally {
await deleteTempFile(tempFilePath);
}
};
const clipImageEmbedding_ = async (jpegFilePath: string) => {
const imageSession = await onnxImageSession();
const t1 = Date.now();
const rgbData = await getRGBData(jpegFilePath);
const feeds = {
input: new ort.Tensor("float32", rgbData, [1, 3, 224, 224]),
};
const t2 = Date.now();
const results = await imageSession.run(feeds);
log.debug(
() =>
`CLIP image embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
);
const imageEmbedding = results["output"].data; // Float32Array
return normalizeEmbedding(imageEmbedding);
};
const getRGBData = async (jpegFilePath: string) => {
const jpegData = await fs.readFile(jpegFilePath);
const rawImageData = jpeg.decode(jpegData, {
useTArray: true,
formatAsRGBA: false,
});
const nx: number = rawImageData.width;
const ny: number = rawImageData.height;
const inputImage: Uint8Array = rawImageData.data;
const nx2: number = 224;
const ny2: number = 224;
const totalSize: number = 3 * nx2 * ny2;
const result: number[] = Array(totalSize).fill(0);
const scale: number = Math.max(nx, ny) / 224;
const nx3: number = Math.round(nx / scale);
const ny3: number = Math.round(ny / scale);
const mean: number[] = [0.48145466, 0.4578275, 0.40821073];
const std: number[] = [0.26862954, 0.26130258, 0.27577711];
for (let y = 0; y < ny3; y++) {
for (let x = 0; x < nx3; x++) {
for (let c = 0; c < 3; c++) {
// Linear interpolation
const sx: number = (x + 0.5) * scale - 0.5;
const sy: number = (y + 0.5) * scale - 0.5;
const x0: number = Math.max(0, Math.floor(sx));
const y0: number = Math.max(0, Math.floor(sy));
const x1: number = Math.min(x0 + 1, nx - 1);
const y1: number = Math.min(y0 + 1, ny - 1);
const dx: number = sx - x0;
const dy: number = sy - y0;
const j00: number = 3 * (y0 * nx + x0) + c;
const j01: number = 3 * (y0 * nx + x1) + c;
const j10: number = 3 * (y1 * nx + x0) + c;
const j11: number = 3 * (y1 * nx + x1) + c;
const v00: number = inputImage[j00];
const v01: number = inputImage[j01];
const v10: number = inputImage[j10];
const v11: number = inputImage[j11];
const v0: number = v00 * (1 - dx) + v01 * dx;
const v1: number = v10 * (1 - dx) + v11 * dx;
const v: number = v0 * (1 - dy) + v1 * dy;
const v2: number = Math.min(Math.max(Math.round(v), 0), 255);
// createTensorWithDataList is dumb compared to reshape and
// hence has to be given with one channel after another
const i: number = y * nx3 + x + (c % 3) * 224 * 224;
result[i] = (v2 / 255 - mean[c]) / std[c];
}
}
}
return result;
};
const normalizeEmbedding = (embedding: Float32Array) => {
let normalization = 0;
for (let index = 0; index < embedding.length; index++) {
normalization += embedding[index] * embedding[index];
}
const sqrtNormalization = Math.sqrt(normalization);
for (let index = 0; index < embedding.length; index++) {
embedding[index] = embedding[index] / sqrtNormalization;
}
return embedding;
};
let _tokenizer: Tokenizer = null;
const getTokenizer = () => {
if (!_tokenizer) {
_tokenizer = new Tokenizer();
}
return _tokenizer;
};
export const clipTextEmbedding = async (text: string) => {
const imageSession = await onnxTextSession();
const t1 = Date.now();
const tokenizer = getTokenizer();
const tokenizedText = Int32Array.from(tokenizer.encodeForCLIP(text));
const feeds = {
input: new ort.Tensor("int32", tokenizedText, [1, 77]),
};
const t2 = Date.now();
const results = await imageSession.run(feeds);
log.debug(
() =>
`CLIP text embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
);
const textEmbedding = results["output"].data;
return normalizeEmbedding(textEmbedding);
};

View file

@ -1,463 +0,0 @@
import { app, net } from "electron/main";
import { existsSync } from "fs";
import fs from "node:fs/promises";
import path from "node:path";
import { writeStream } from "../main/fs";
import log from "../main/log";
import { execAsync, isDev } from "../main/util";
import { CustomErrors, Model, isModel } from "../types/ipc";
import Tokenizer from "../utils/clip-bpe-ts/mod";
import { getPlatform } from "../utils/common/platform";
import { generateTempFilePath } from "../utils/temp";
import { deleteTempFile } from "./ffmpeg";
const jpeg = require("jpeg-js");
const CLIP_MODEL_PATH_PLACEHOLDER = "CLIP_MODEL";
const GGMLCLIP_PATH_PLACEHOLDER = "GGML_PATH";
const INPUT_PATH_PLACEHOLDER = "INPUT";
const IMAGE_EMBEDDING_EXTRACT_CMD: string[] = [
GGMLCLIP_PATH_PLACEHOLDER,
"-mv",
CLIP_MODEL_PATH_PLACEHOLDER,
"--image",
INPUT_PATH_PLACEHOLDER,
];
const TEXT_EMBEDDING_EXTRACT_CMD: string[] = [
GGMLCLIP_PATH_PLACEHOLDER,
"-mt",
CLIP_MODEL_PATH_PLACEHOLDER,
"--text",
INPUT_PATH_PLACEHOLDER,
];
const ort = require("onnxruntime-node");
const TEXT_MODEL_DOWNLOAD_URL = {
ggml: "https://models.ente.io/clip-vit-base-patch32_ggml-text-model-f16.gguf",
onnx: "https://models.ente.io/clip-text-vit-32-uint8.onnx",
};
const IMAGE_MODEL_DOWNLOAD_URL = {
ggml: "https://models.ente.io/clip-vit-base-patch32_ggml-vision-model-f16.gguf",
onnx: "https://models.ente.io/clip-image-vit-32-float32.onnx",
};
const TEXT_MODEL_NAME = {
ggml: "clip-vit-base-patch32_ggml-text-model-f16.gguf",
onnx: "clip-text-vit-32-uint8.onnx",
};
const IMAGE_MODEL_NAME = {
ggml: "clip-vit-base-patch32_ggml-vision-model-f16.gguf",
onnx: "clip-image-vit-32-float32.onnx",
};
const IMAGE_MODEL_SIZE_IN_BYTES = {
ggml: 175957504, // 167.8 MB
onnx: 351468764, // 335.2 MB
};
const TEXT_MODEL_SIZE_IN_BYTES = {
ggml: 127853440, // 121.9 MB,
onnx: 64173509, // 61.2 MB
};
/** Return the path where the given {@link modelName} is meant to be saved */
const getModelSavePath = (modelName: string) =>
path.join(app.getPath("userData"), "models", modelName);
async function downloadModel(saveLocation: string, url: string) {
// confirm that the save location exists
const saveDir = path.dirname(saveLocation);
await fs.mkdir(saveDir, { recursive: true });
log.info("downloading clip model");
const res = await net.fetch(url);
if (!res.ok) throw new Error(`Failed to fetch ${url}: HTTP ${res.status}`);
await writeStream(saveLocation, res.body);
log.info("clip model downloaded");
}
let imageModelDownloadInProgress: Promise<void> = null;
const getClipImageModelPath = async (type: "ggml" | "onnx") => {
try {
const modelSavePath = getModelSavePath(IMAGE_MODEL_NAME[type]);
if (imageModelDownloadInProgress) {
log.info("waiting for image model download to finish");
await imageModelDownloadInProgress;
} else {
if (!existsSync(modelSavePath)) {
log.info("CLIP image model not found, downloading");
imageModelDownloadInProgress = downloadModel(
modelSavePath,
IMAGE_MODEL_DOWNLOAD_URL[type],
);
await imageModelDownloadInProgress;
} else {
const localFileSize = (await fs.stat(modelSavePath)).size;
if (localFileSize !== IMAGE_MODEL_SIZE_IN_BYTES[type]) {
log.info(
`CLIP image model size mismatch, downloading again got: ${localFileSize}`,
);
imageModelDownloadInProgress = downloadModel(
modelSavePath,
IMAGE_MODEL_DOWNLOAD_URL[type],
);
await imageModelDownloadInProgress;
}
}
}
return modelSavePath;
} finally {
imageModelDownloadInProgress = null;
}
};
let textModelDownloadInProgress: boolean = false;
const getClipTextModelPath = async (type: "ggml" | "onnx") => {
const modelSavePath = getModelSavePath(TEXT_MODEL_NAME[type]);
if (textModelDownloadInProgress) {
throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
} else {
if (!existsSync(modelSavePath)) {
log.info("CLIP text model not found, downloading");
textModelDownloadInProgress = true;
downloadModel(modelSavePath, TEXT_MODEL_DOWNLOAD_URL[type])
.catch((e) => {
// log but otherwise ignore
log.error("CLIP text model download failed", e);
})
.finally(() => {
textModelDownloadInProgress = false;
});
throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
} else {
const localFileSize = (await fs.stat(modelSavePath)).size;
if (localFileSize !== TEXT_MODEL_SIZE_IN_BYTES[type]) {
log.info(
`CLIP text model size mismatch, downloading again got: ${localFileSize}`,
);
textModelDownloadInProgress = true;
downloadModel(modelSavePath, TEXT_MODEL_DOWNLOAD_URL[type])
.catch((e) => {
// log but otherwise ignore
log.error("CLIP text model download failed", e);
})
.finally(() => {
textModelDownloadInProgress = false;
});
throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
}
}
}
return modelSavePath;
};
function getGGMLClipPath() {
return isDev
? path.join("./build", `ggmlclip-${getPlatform()}`)
: path.join(process.resourcesPath, `ggmlclip-${getPlatform()}`);
}
async function createOnnxSession(modelPath: string) {
return await ort.InferenceSession.create(modelPath, {
intraOpNumThreads: 1,
enableCpuMemArena: false,
});
}
let onnxImageSessionPromise: Promise<any> = null;
async function getOnnxImageSession() {
if (!onnxImageSessionPromise) {
onnxImageSessionPromise = (async () => {
const clipModelPath = await getClipImageModelPath("onnx");
return createOnnxSession(clipModelPath);
})();
}
return onnxImageSessionPromise;
}
let onnxTextSession: any = null;
async function getOnnxTextSession() {
if (!onnxTextSession) {
const clipModelPath = await getClipTextModelPath("onnx");
onnxTextSession = await createOnnxSession(clipModelPath);
}
return onnxTextSession;
}
let tokenizer: Tokenizer = null;
function getTokenizer() {
if (!tokenizer) {
tokenizer = new Tokenizer();
}
return tokenizer;
}
export const computeImageEmbedding = async (
model: Model,
imageData: Uint8Array,
): Promise<Float32Array> => {
if (!isModel(model)) throw new Error(`Invalid CLIP model ${model}`);
let tempInputFilePath = null;
try {
tempInputFilePath = await generateTempFilePath("");
const imageStream = new Response(imageData.buffer).body;
await writeStream(tempInputFilePath, imageStream);
const embedding = await computeImageEmbedding_(
model,
tempInputFilePath,
);
return embedding;
} catch (err) {
if (isExecError(err)) {
const parsedExecError = parseExecError(err);
throw Error(parsedExecError);
} else {
throw err;
}
} finally {
if (tempInputFilePath) {
await deleteTempFile(tempInputFilePath);
}
}
};
const isExecError = (err: any) => {
return err.message.includes("Command failed:");
};
const parseExecError = (err: any) => {
const errMessage = err.message;
if (errMessage.includes("Bad CPU type in executable")) {
return CustomErrors.UNSUPPORTED_PLATFORM(
process.platform,
process.arch,
);
} else {
return errMessage;
}
};
async function computeImageEmbedding_(
model: Model,
inputFilePath: string,
): Promise<Float32Array> {
if (!existsSync(inputFilePath)) {
throw new Error("Invalid file path");
}
switch (model) {
case "ggml-clip":
return await computeGGMLImageEmbedding(inputFilePath);
case "onnx-clip":
return await computeONNXImageEmbedding(inputFilePath);
}
}
const computeGGMLImageEmbedding = async (
inputFilePath: string,
): Promise<Float32Array> => {
const clipModelPath = await getClipImageModelPath("ggml");
const ggmlclipPath = getGGMLClipPath();
const cmd = IMAGE_EMBEDDING_EXTRACT_CMD.map((cmdPart) => {
if (cmdPart === GGMLCLIP_PATH_PLACEHOLDER) {
return ggmlclipPath;
} else if (cmdPart === CLIP_MODEL_PATH_PLACEHOLDER) {
return clipModelPath;
} else if (cmdPart === INPUT_PATH_PLACEHOLDER) {
return inputFilePath;
} else {
return cmdPart;
}
});
const { stdout } = await execAsync(cmd);
// parse stdout and return embedding
// get the last line of stdout
const lines = stdout.split("\n");
const lastLine = lines[lines.length - 1];
const embedding = JSON.parse(lastLine);
const embeddingArray = new Float32Array(embedding);
return embeddingArray;
};
const computeONNXImageEmbedding = async (
inputFilePath: string,
): Promise<Float32Array> => {
const imageSession = await getOnnxImageSession();
const t1 = Date.now();
const rgbData = await getRGBData(inputFilePath);
const feeds = {
input: new ort.Tensor("float32", rgbData, [1, 3, 224, 224]),
};
const t2 = Date.now();
const results = await imageSession.run(feeds);
log.info(
`onnx image embedding time: ${Date.now() - t1} ms (prep:${
t2 - t1
} ms, extraction: ${Date.now() - t2} ms)`,
);
const imageEmbedding = results["output"].data; // Float32Array
return normalizeEmbedding(imageEmbedding);
};
async function getRGBData(inputFilePath: string) {
const jpegData = await fs.readFile(inputFilePath);
const rawImageData = jpeg.decode(jpegData, {
useTArray: true,
formatAsRGBA: false,
});
const nx: number = rawImageData.width;
const ny: number = rawImageData.height;
const inputImage: Uint8Array = rawImageData.data;
const nx2: number = 224;
const ny2: number = 224;
const totalSize: number = 3 * nx2 * ny2;
const result: number[] = Array(totalSize).fill(0);
const scale: number = Math.max(nx, ny) / 224;
const nx3: number = Math.round(nx / scale);
const ny3: number = Math.round(ny / scale);
const mean: number[] = [0.48145466, 0.4578275, 0.40821073];
const std: number[] = [0.26862954, 0.26130258, 0.27577711];
for (let y = 0; y < ny3; y++) {
for (let x = 0; x < nx3; x++) {
for (let c = 0; c < 3; c++) {
// linear interpolation
const sx: number = (x + 0.5) * scale - 0.5;
const sy: number = (y + 0.5) * scale - 0.5;
const x0: number = Math.max(0, Math.floor(sx));
const y0: number = Math.max(0, Math.floor(sy));
const x1: number = Math.min(x0 + 1, nx - 1);
const y1: number = Math.min(y0 + 1, ny - 1);
const dx: number = sx - x0;
const dy: number = sy - y0;
const j00: number = 3 * (y0 * nx + x0) + c;
const j01: number = 3 * (y0 * nx + x1) + c;
const j10: number = 3 * (y1 * nx + x0) + c;
const j11: number = 3 * (y1 * nx + x1) + c;
const v00: number = inputImage[j00];
const v01: number = inputImage[j01];
const v10: number = inputImage[j10];
const v11: number = inputImage[j11];
const v0: number = v00 * (1 - dx) + v01 * dx;
const v1: number = v10 * (1 - dx) + v11 * dx;
const v: number = v0 * (1 - dy) + v1 * dy;
const v2: number = Math.min(Math.max(Math.round(v), 0), 255);
// createTensorWithDataList is dump compared to reshape and hence has to be given with one channel after another
const i: number = y * nx3 + x + (c % 3) * 224 * 224;
result[i] = (v2 / 255 - mean[c]) / std[c];
}
}
}
return result;
}
const normalizeEmbedding = (embedding: Float32Array) => {
let normalization = 0;
for (let index = 0; index < embedding.length; index++) {
normalization += embedding[index] * embedding[index];
}
const sqrtNormalization = Math.sqrt(normalization);
for (let index = 0; index < embedding.length; index++) {
embedding[index] = embedding[index] / sqrtNormalization;
}
return embedding;
};
export async function computeTextEmbedding(
model: Model,
text: string,
): Promise<Float32Array> {
if (!isModel(model)) throw new Error(`Invalid CLIP model ${model}`);
try {
const embedding = computeTextEmbedding_(model, text);
return embedding;
} catch (err) {
if (isExecError(err)) {
const parsedExecError = parseExecError(err);
throw Error(parsedExecError);
} else {
throw err;
}
}
}
async function computeTextEmbedding_(
model: Model,
text: string,
): Promise<Float32Array> {
switch (model) {
case "ggml-clip":
return await computeGGMLTextEmbedding(text);
case "onnx-clip":
return await computeONNXTextEmbedding(text);
}
}
export async function computeGGMLTextEmbedding(
text: string,
): Promise<Float32Array> {
const clipModelPath = await getClipTextModelPath("ggml");
const ggmlclipPath = getGGMLClipPath();
const cmd = TEXT_EMBEDDING_EXTRACT_CMD.map((cmdPart) => {
if (cmdPart === GGMLCLIP_PATH_PLACEHOLDER) {
return ggmlclipPath;
} else if (cmdPart === CLIP_MODEL_PATH_PLACEHOLDER) {
return clipModelPath;
} else if (cmdPart === INPUT_PATH_PLACEHOLDER) {
return text;
} else {
return cmdPart;
}
});
const { stdout } = await execAsync(cmd);
// parse stdout and return embedding
// get the last line of stdout
const lines = stdout.split("\n");
const lastLine = lines[lines.length - 1];
const embedding = JSON.parse(lastLine);
const embeddingArray = new Float32Array(embedding);
return embeddingArray;
}
export async function computeONNXTextEmbedding(
text: string,
): Promise<Float32Array> {
const imageSession = await getOnnxTextSession();
const t1 = Date.now();
const tokenizer = getTokenizer();
const tokenizedText = Int32Array.from(tokenizer.encodeForCLIP(text));
const feeds = {
input: new ort.Tensor("int32", tokenizedText, [1, 77]),
};
const t2 = Date.now();
const results = await imageSession.run(feeds);
log.info(
`onnx text embedding time: ${Date.now() - t1} ms (prep:${
t2 - t1
} ms, extraction: ${Date.now() - t2} ms)`,
);
const textEmbedding = results["output"].data; // Float32Array
return normalizeEmbedding(textEmbedding);
}

View file

@ -80,7 +80,3 @@ export interface AppUpdateInfo {
autoUpdatable: boolean; autoUpdatable: boolean;
version: string; version: string;
} }
export type Model = "ggml-clip" | "onnx-clip";
export const isModel = (s: unknown) => s == "ggml-clip" || s == "onnx-clip";

View file

@ -14,7 +14,7 @@ import { EnteMenuItem } from "components/Menu/EnteMenuItem";
import { MenuItemGroup } from "components/Menu/MenuItemGroup"; import { MenuItemGroup } from "components/Menu/MenuItemGroup";
import isElectron from "is-electron"; import isElectron from "is-electron";
import { AppContext } from "pages/_app"; import { AppContext } from "pages/_app";
import { ClipExtractionStatus, ClipService } from "services/clipService"; import { CLIPIndexingStatus, clipService } from "services/clip-service";
import { formatNumber } from "utils/number/format"; import { formatNumber } from "utils/number/format";
export default function AdvancedSettings({ open, onClose, onRootClose }) { export default function AdvancedSettings({ open, onClose, onRootClose }) {
@ -44,17 +44,15 @@ export default function AdvancedSettings({ open, onClose, onRootClose }) {
log.error("toggleFasterUpload failed", e); log.error("toggleFasterUpload failed", e);
} }
}; };
const [indexingStatus, setIndexingStatus] = useState<ClipExtractionStatus>({ const [indexingStatus, setIndexingStatus] = useState<CLIPIndexingStatus>({
indexed: 0, indexed: 0,
pending: 0, pending: 0,
}); });
useEffect(() => { useEffect(() => {
const main = async () => { clipService.setOnUpdateHandler(setIndexingStatus);
setIndexingStatus(await ClipService.getIndexingStatus()); clipService.getIndexingStatus().then((st) => setIndexingStatus(st));
ClipService.setOnUpdateHandler(setIndexingStatus); return () => clipService.setOnUpdateHandler(undefined);
};
main();
}, []); }, []);
return ( return (

View file

@ -102,7 +102,7 @@ import {
} from "constants/collection"; } from "constants/collection";
import { SYNC_INTERVAL_IN_MICROSECONDS } from "constants/gallery"; import { SYNC_INTERVAL_IN_MICROSECONDS } from "constants/gallery";
import { AppContext } from "pages/_app"; import { AppContext } from "pages/_app";
import { ClipService } from "services/clipService"; import { clipService } from "services/clip-service";
import { constructUserIDToEmailMap } from "services/collectionService"; import { constructUserIDToEmailMap } from "services/collectionService";
import downloadManager from "services/download"; import downloadManager from "services/download";
import { syncEmbeddings } from "services/embeddingService"; import { syncEmbeddings } from "services/embeddingService";
@ -362,7 +362,7 @@ export default function Gallery() {
syncWithRemote(false, true); syncWithRemote(false, true);
}, SYNC_INTERVAL_IN_MICROSECONDS); }, SYNC_INTERVAL_IN_MICROSECONDS);
if (electron) { if (electron) {
void ClipService.setupOnFileUploadListener(); void clipService.setupOnFileUploadListener();
electron.registerForegroundEventListener(() => { electron.registerForegroundEventListener(() => {
syncWithRemote(false, true); syncWithRemote(false, true);
}); });
@ -373,7 +373,7 @@ export default function Gallery() {
clearInterval(syncInterval.current); clearInterval(syncInterval.current);
if (electron) { if (electron) {
electron.registerForegroundEventListener(() => {}); electron.registerForegroundEventListener(() => {});
ClipService.removeOnFileUploadListener(); clipService.removeOnFileUploadListener();
} }
}; };
}, []); }, []);
@ -704,8 +704,8 @@ export default function Gallery() {
await syncEntities(); await syncEntities();
await syncMapEnabled(); await syncMapEnabled();
await syncEmbeddings(); await syncEmbeddings();
if (ClipService.isPlatformSupported()) { if (clipService.isPlatformSupported()) {
void ClipService.scheduleImageEmbeddingExtraction(); void clipService.scheduleImageEmbeddingExtraction();
} }
} catch (e) { } catch (e) {
switch (e.message) { switch (e.message) {

View file

@ -1,5 +1,6 @@
import { ensureElectron } from "@/next/electron"; import { ensureElectron } from "@/next/electron";
import log from "@/next/log"; import log from "@/next/log";
import type { Electron } from "@/next/types/ipc";
import ComlinkCryptoWorker from "@ente/shared/crypto"; import ComlinkCryptoWorker from "@ente/shared/crypto";
import { CustomError } from "@ente/shared/error"; import { CustomError } from "@ente/shared/error";
import { Events, eventBus } from "@ente/shared/events"; import { Events, eventBus } from "@ente/shared/events";
@ -7,29 +8,71 @@ import { LS_KEYS, getData } from "@ente/shared/storage/localStorage";
import { FILE_TYPE } from "constants/file"; import { FILE_TYPE } from "constants/file";
import isElectron from "is-electron"; import isElectron from "is-electron";
import PQueue from "p-queue"; import PQueue from "p-queue";
import { Embedding, Model } from "types/embedding"; import { Embedding } from "types/embedding";
import { EnteFile } from "types/file"; import { EnteFile } from "types/file";
import { getPersonalFiles } from "utils/file"; import { getPersonalFiles } from "utils/file";
import downloadManager from "./download"; import downloadManager from "./download";
import { getLocalEmbeddings, putEmbedding } from "./embeddingService"; import { getLocalEmbeddings, putEmbedding } from "./embeddingService";
import { getAllLocalFiles, getLocalFiles } from "./fileService"; import { getAllLocalFiles, getLocalFiles } from "./fileService";
const CLIP_EMBEDDING_LENGTH = 512; /** Status of CLIP indexing on the images in the user's local library. */
export interface CLIPIndexingStatus {
export interface ClipExtractionStatus { /** Number of items pending indexing. */
pending: number; pending: number;
/** Number of items that have already been indexed. */
indexed: number; indexed: number;
} }
class ClipServiceImpl { /**
* Use a CLIP based neural network for natural language search.
*
* [Note: CLIP based magic search]
*
* CLIP (Contrastive Language-Image Pretraining) is a neural network trained on
* (image, text) pairs. It can be thought of as two separate (but jointly
* trained) encoders - one for images, and one for text - that both map to the
* same embedding space.
*
* We use this for natural language search within the app (aka "magic search"):
*
* 1. Pre-compute an embedding for each image.
*
* 2. When the user searches, compute an embedding for the search term.
*
* 3. Use cosine similarity to find the find the image (embedding) closest to
* the text (embedding).
*
* More details are in our [blog
* post](https://ente.io/blog/image-search-with-clip-ggml/) that describes the
* initial launch of this feature using the GGML runtime.
*
* Since the initial launch, we've switched over to another runtime,
* [ONNX](https://onnxruntime.ai).
*
* Note that we don't train the neural network - we only use one of the publicly
* available pre-trained neural networks for inference. These neural networks
* are wholly defined by their connectivity and weights. ONNX, our ML runtimes,
* loads these weights and instantiates a running network that we can use to
* compute the embeddings.
*
* Theoretically, the same CLIP model can be loaded by different frameworks /
* runtimes, but in practice each runtime has its own preferred format, and
* there are also quantization tradeoffs. So there is a specific model (a binary
* encoding of weights) tied to our current runtime that we use.
*
* To ensure that the embeddings, for the most part, can be shared, whenever
* possible we try to ensure that all the preprocessing steps, and the model
* itself, is the same across clients - web and mobile.
*/
class CLIPService {
private electron: Electron;
private embeddingExtractionInProgress: AbortController | null = null; private embeddingExtractionInProgress: AbortController | null = null;
private reRunNeeded = false; private reRunNeeded = false;
private clipExtractionStatus: ClipExtractionStatus = { private indexingStatus: CLIPIndexingStatus = {
pending: 0, pending: 0,
indexed: 0, indexed: 0,
}; };
private onUpdateHandler: ((status: ClipExtractionStatus) => void) | null = private onUpdateHandler: ((status: CLIPIndexingStatus) => void) | undefined;
null;
private liveEmbeddingExtractionQueue: PQueue; private liveEmbeddingExtractionQueue: PQueue;
private onFileUploadedHandler: private onFileUploadedHandler:
| ((arg: { enteFile: EnteFile; localFile: globalThis.File }) => void) | ((arg: { enteFile: EnteFile; localFile: globalThis.File }) => void)
@ -37,6 +80,7 @@ class ClipServiceImpl {
private unsupportedPlatform = false; private unsupportedPlatform = false;
constructor() { constructor() {
this.electron = ensureElectron();
this.liveEmbeddingExtractionQueue = new PQueue({ this.liveEmbeddingExtractionQueue = new PQueue({
concurrency: 1, concurrency: 1,
}); });
@ -96,28 +140,23 @@ class ClipServiceImpl {
}; };
getIndexingStatus = async () => { getIndexingStatus = async () => {
try {
if ( if (
!this.clipExtractionStatus || this.indexingStatus.pending === 0 &&
(this.clipExtractionStatus.pending === 0 && this.indexingStatus.indexed === 0
this.clipExtractionStatus.indexed === 0)
) { ) {
this.clipExtractionStatus = await getClipExtractionStatus(); this.indexingStatus = await initialIndexingStatus();
}
return this.clipExtractionStatus;
} catch (e) {
log.error("failed to get clip indexing status", e);
} }
return this.indexingStatus;
}; };
setOnUpdateHandler = (handler: (status: ClipExtractionStatus) => void) => { /**
* Set the {@link handler} to invoke whenever our indexing status changes.
*/
setOnUpdateHandler = (handler?: (status: CLIPIndexingStatus) => void) => {
this.onUpdateHandler = handler; this.onUpdateHandler = handler;
handler(this.clipExtractionStatus);
}; };
scheduleImageEmbeddingExtraction = async ( scheduleImageEmbeddingExtraction = async () => {
model: Model = Model.ONNX_CLIP,
) => {
try { try {
if (this.embeddingExtractionInProgress) { if (this.embeddingExtractionInProgress) {
log.info( log.info(
@ -133,7 +172,7 @@ class ClipServiceImpl {
const canceller = new AbortController(); const canceller = new AbortController();
this.embeddingExtractionInProgress = canceller; this.embeddingExtractionInProgress = canceller;
try { try {
await this.runClipEmbeddingExtraction(canceller, model); await this.runClipEmbeddingExtraction(canceller);
} finally { } finally {
this.embeddingExtractionInProgress = null; this.embeddingExtractionInProgress = null;
if (!canceller.signal.aborted && this.reRunNeeded) { if (!canceller.signal.aborted && this.reRunNeeded) {
@ -152,25 +191,19 @@ class ClipServiceImpl {
} }
}; };
getTextEmbedding = async ( getTextEmbedding = async (text: string): Promise<Float32Array> => {
text: string,
model: Model = Model.ONNX_CLIP,
): Promise<Float32Array> => {
try { try {
return ensureElectron().computeTextEmbedding(model, text); return electron.clipTextEmbedding(text);
} catch (e) { } catch (e) {
if (e?.message?.includes(CustomError.UNSUPPORTED_PLATFORM)) { if (e?.message?.includes(CustomError.UNSUPPORTED_PLATFORM)) {
this.unsupportedPlatform = true; this.unsupportedPlatform = true;
} }
log.error("failed to compute text embedding", e); log.error("Failed to compute CLIP text embedding", e);
throw e; throw e;
} }
}; };
private runClipEmbeddingExtraction = async ( private runClipEmbeddingExtraction = async (canceller: AbortController) => {
canceller: AbortController,
model: Model,
) => {
try { try {
if (this.unsupportedPlatform) { if (this.unsupportedPlatform) {
log.info( log.info(
@ -183,12 +216,12 @@ class ClipServiceImpl {
return; return;
} }
const localFiles = getPersonalFiles(await getAllLocalFiles(), user); const localFiles = getPersonalFiles(await getAllLocalFiles(), user);
const existingEmbeddings = await getLocalEmbeddings(model); const existingEmbeddings = await getLocalEmbeddings();
const pendingFiles = await getNonClipEmbeddingExtractedFiles( const pendingFiles = await getNonClipEmbeddingExtractedFiles(
localFiles, localFiles,
existingEmbeddings, existingEmbeddings,
); );
this.updateClipEmbeddingExtractionStatus({ this.updateIndexingStatus({
indexed: existingEmbeddings.length, indexed: existingEmbeddings.length,
pending: pendingFiles.length, pending: pendingFiles.length,
}); });
@ -208,15 +241,11 @@ class ClipServiceImpl {
throw Error(CustomError.REQUEST_CANCELLED); throw Error(CustomError.REQUEST_CANCELLED);
} }
const embeddingData = const embeddingData =
await this.extractFileClipImageEmbedding(model, file); await this.extractFileClipImageEmbedding(file);
log.info( log.info(
`successfully extracted clip embedding for file: ${file.metadata.title} fileID: ${file.id} embedding length: ${embeddingData?.length}`, `successfully extracted clip embedding for file: ${file.metadata.title} fileID: ${file.id} embedding length: ${embeddingData?.length}`,
); );
await this.encryptAndUploadEmbedding( await this.encryptAndUploadEmbedding(file, embeddingData);
model,
file,
embeddingData,
);
this.onSuccessStatusUpdater(); this.onSuccessStatusUpdater();
log.info( log.info(
`successfully put clip embedding to server for file: ${file.metadata.title} fileID: ${file.id}`, `successfully put clip embedding to server for file: ${file.metadata.title} fileID: ${file.id}`,
@ -249,13 +278,10 @@ class ClipServiceImpl {
} }
}; };
private async runLocalFileClipExtraction( private async runLocalFileClipExtraction(arg: {
arg: {
enteFile: EnteFile; enteFile: EnteFile;
localFile: globalThis.File; localFile: globalThis.File;
}, }) {
model: Model = Model.ONNX_CLIP,
) {
const { enteFile, localFile } = arg; const { enteFile, localFile } = arg;
log.info( log.info(
`clip embedding extraction onFileUploadedHandler file: ${enteFile.metadata.title} fileID: ${enteFile.id}`, `clip embedding extraction onFileUploadedHandler file: ${enteFile.metadata.title} fileID: ${enteFile.id}`,
@ -279,15 +305,9 @@ class ClipServiceImpl {
); );
try { try {
await this.liveEmbeddingExtractionQueue.add(async () => { await this.liveEmbeddingExtractionQueue.add(async () => {
const embedding = await this.extractLocalFileClipImageEmbedding( const embedding =
model, await this.extractLocalFileClipImageEmbedding(localFile);
localFile, await this.encryptAndUploadEmbedding(enteFile, embedding);
);
await this.encryptAndUploadEmbedding(
model,
enteFile,
embedding,
);
}); });
log.info( log.info(
`successfully extracted clip embedding for file: ${enteFile.metadata.title} fileID: ${enteFile.id}`, `successfully extracted clip embedding for file: ${enteFile.metadata.title} fileID: ${enteFile.id}`,
@ -297,26 +317,18 @@ class ClipServiceImpl {
} }
} }
private extractLocalFileClipImageEmbedding = async ( private extractLocalFileClipImageEmbedding = async (localFile: File) => {
model: Model,
localFile: File,
) => {
const file = await localFile const file = await localFile
.arrayBuffer() .arrayBuffer()
.then((buffer) => new Uint8Array(buffer)); .then((buffer) => new Uint8Array(buffer));
const embedding = await ensureElectron().computeImageEmbedding( return await electron.clipImageEmbedding(file);
model,
file,
);
return embedding;
}; };
private encryptAndUploadEmbedding = async ( private encryptAndUploadEmbedding = async (
model: Model,
file: EnteFile, file: EnteFile,
embeddingData: Float32Array, embeddingData: Float32Array,
) => { ) => {
if (embeddingData?.length !== CLIP_EMBEDDING_LENGTH) { if (embeddingData?.length !== 512) {
throw Error( throw Error(
`invalid length embedding data length: ${embeddingData?.length}`, `invalid length embedding data length: ${embeddingData?.length}`,
); );
@ -331,38 +343,31 @@ class ClipServiceImpl {
fileID: file.id, fileID: file.id,
encryptedEmbedding: encryptedEmbeddingData.encryptedData, encryptedEmbedding: encryptedEmbeddingData.encryptedData,
decryptionHeader: encryptedEmbeddingData.decryptionHeader, decryptionHeader: encryptedEmbeddingData.decryptionHeader,
model, model: "onnx-clip",
}); });
}; };
updateClipEmbeddingExtractionStatus = (status: ClipExtractionStatus) => { private updateIndexingStatus = (status: CLIPIndexingStatus) => {
this.clipExtractionStatus = status; this.indexingStatus = status;
if (this.onUpdateHandler) { const handler = this.onUpdateHandler;
this.onUpdateHandler(status); if (handler) handler(status);
}
}; };
private extractFileClipImageEmbedding = async ( private extractFileClipImageEmbedding = async (file: EnteFile) => {
model: Model,
file: EnteFile,
) => {
const thumb = await downloadManager.getThumbnail(file); const thumb = await downloadManager.getThumbnail(file);
const embedding = await ensureElectron().computeImageEmbedding( const embedding = await ensureElectron().clipImageEmbedding(thumb);
model,
thumb,
);
return embedding; return embedding;
}; };
private onSuccessStatusUpdater = () => { private onSuccessStatusUpdater = () => {
this.updateClipEmbeddingExtractionStatus({ this.updateIndexingStatus({
pending: this.clipExtractionStatus.pending - 1, pending: this.indexingStatus.pending - 1,
indexed: this.clipExtractionStatus.indexed + 1, indexed: this.indexingStatus.indexed + 1,
}); });
}; };
} }
export const ClipService = new ClipServiceImpl(); export const clipService = new CLIPService();
const getNonClipEmbeddingExtractedFiles = async ( const getNonClipEmbeddingExtractedFiles = async (
files: EnteFile[], files: EnteFile[],
@ -412,14 +417,10 @@ export const computeClipMatchScore = async (
return score; return score;
}; };
const getClipExtractionStatus = async ( const initialIndexingStatus = async (): Promise<CLIPIndexingStatus> => {
model: Model = Model.ONNX_CLIP,
): Promise<ClipExtractionStatus> => {
const user = getData(LS_KEYS.USER); const user = getData(LS_KEYS.USER);
if (!user) { if (!user) throw new Error("Orphan CLIP indexing without a login");
return; const allEmbeddings = await getLocalEmbeddings();
}
const allEmbeddings = await getLocalEmbeddings(model);
const localFiles = getPersonalFiles(await getLocalFiles(), user); const localFiles = getPersonalFiles(await getLocalFiles(), user);
const pendingFiles = await getNonClipEmbeddingExtractedFiles( const pendingFiles = await getNonClipEmbeddingExtractedFiles(
localFiles, localFiles,

View file

@ -5,11 +5,11 @@ import HTTPService from "@ente/shared/network/HTTPService";
import { getEndpoint } from "@ente/shared/network/api"; import { getEndpoint } from "@ente/shared/network/api";
import localForage from "@ente/shared/storage/localForage"; import localForage from "@ente/shared/storage/localForage";
import { getToken } from "@ente/shared/storage/localStorage/helpers"; import { getToken } from "@ente/shared/storage/localStorage/helpers";
import { import type {
Embedding, Embedding,
EmbeddingModel,
EncryptedEmbedding, EncryptedEmbedding,
GetEmbeddingDiffResponse, GetEmbeddingDiffResponse,
Model,
PutEmbeddingRequest, PutEmbeddingRequest,
} from "types/embedding"; } from "types/embedding";
import { EnteFile } from "types/file"; import { EnteFile } from "types/file";
@ -38,12 +38,12 @@ export const getAllLocalEmbeddings = async () => {
return embeddings; return embeddings;
}; };
export const getLocalEmbeddings = async (model: Model) => { export const getLocalEmbeddings = async () => {
const embeddings = await getAllLocalEmbeddings(); const embeddings = await getAllLocalEmbeddings();
return embeddings.filter((embedding) => embedding.model === model); return embeddings.filter((embedding) => embedding.model === "onnx-clip");
}; };
const getModelEmbeddingSyncTime = async (model: Model) => { const getModelEmbeddingSyncTime = async (model: EmbeddingModel) => {
return ( return (
(await localForage.getItem<number>( (await localForage.getItem<number>(
`${model}-${EMBEDDING_SYNC_TIME_TABLE}`, `${model}-${EMBEDDING_SYNC_TIME_TABLE}`,
@ -51,11 +51,15 @@ const getModelEmbeddingSyncTime = async (model: Model) => {
); );
}; };
const setModelEmbeddingSyncTime = async (model: Model, time: number) => { const setModelEmbeddingSyncTime = async (
model: EmbeddingModel,
time: number,
) => {
await localForage.setItem(`${model}-${EMBEDDING_SYNC_TIME_TABLE}`, time); await localForage.setItem(`${model}-${EMBEDDING_SYNC_TIME_TABLE}`, time);
}; };
export const syncEmbeddings = async (models: Model[] = [Model.ONNX_CLIP]) => { export const syncEmbeddings = async () => {
const models: EmbeddingModel[] = ["onnx-clip"];
try { try {
let allEmbeddings = await getAllLocalEmbeddings(); let allEmbeddings = await getAllLocalEmbeddings();
const localFiles = await getAllLocalFiles(); const localFiles = await getAllLocalFiles();
@ -138,7 +142,7 @@ export const syncEmbeddings = async (models: Model[] = [Model.ONNX_CLIP]) => {
export const getEmbeddingsDiff = async ( export const getEmbeddingsDiff = async (
sinceTime: number, sinceTime: number,
model: Model, model: EmbeddingModel,
): Promise<GetEmbeddingDiffResponse> => { ): Promise<GetEmbeddingDiffResponse> => {
try { try {
const token = getToken(); const token = getToken();

View file

@ -4,7 +4,6 @@ import * as chrono from "chrono-node";
import { FILE_TYPE } from "constants/file"; import { FILE_TYPE } from "constants/file";
import { t } from "i18next"; import { t } from "i18next";
import { Collection } from "types/collection"; import { Collection } from "types/collection";
import { Model } from "types/embedding";
import { EntityType, LocationTag, LocationTagData } from "types/entity"; import { EntityType, LocationTag, LocationTagData } from "types/entity";
import { EnteFile } from "types/file"; import { EnteFile } from "types/file";
import { Person, Thing } from "types/machineLearning"; import { Person, Thing } from "types/machineLearning";
@ -22,7 +21,7 @@ import { getAllPeople } from "utils/machineLearning";
import { getMLSyncConfig } from "utils/machineLearning/config"; import { getMLSyncConfig } from "utils/machineLearning/config";
import { getFormattedDate } from "utils/search"; import { getFormattedDate } from "utils/search";
import mlIDbStorage from "utils/storage/mlIDbStorage"; import mlIDbStorage from "utils/storage/mlIDbStorage";
import { ClipService, computeClipMatchScore } from "./clipService"; import { clipService, computeClipMatchScore } from "./clip-service";
import { getLocalEmbeddings } from "./embeddingService"; import { getLocalEmbeddings } from "./embeddingService";
import { getLatestEntities } from "./entityService"; import { getLatestEntities } from "./entityService";
import locationSearchService, { City } from "./locationSearchService"; import locationSearchService, { City } from "./locationSearchService";
@ -305,7 +304,7 @@ async function getThingSuggestion(searchPhrase: string): Promise<Suggestion[]> {
async function getClipSuggestion(searchPhrase: string): Promise<Suggestion> { async function getClipSuggestion(searchPhrase: string): Promise<Suggestion> {
try { try {
if (!ClipService.isPlatformSupported()) { if (!clipService.isPlatformSupported()) {
return null; return null;
} }
@ -396,8 +395,8 @@ async function searchThing(searchPhrase: string) {
} }
async function searchClip(searchPhrase: string): Promise<ClipSearchScores> { async function searchClip(searchPhrase: string): Promise<ClipSearchScores> {
const imageEmbeddings = await getLocalEmbeddings(Model.ONNX_CLIP); const imageEmbeddings = await getLocalEmbeddings();
const textEmbedding = await ClipService.getTextEmbedding(searchPhrase); const textEmbedding = await clipService.getTextEmbedding(searchPhrase);
const clipSearchResult = new Map<number, number>( const clipSearchResult = new Map<number, number>(
( (
await Promise.all( await Promise.all(

View file

@ -1,11 +1,16 @@
export enum Model { /**
GGML_CLIP = "ggml-clip", * The embeddings models that we support.
ONNX_CLIP = "onnx-clip", *
} * This is an exhaustive set of values we pass when PUT-ting encrypted
* embeddings on the server. However, we should be prepared to receive an
* {@link EncryptedEmbedding} with a model value distinct from one of these.
*/
export type EmbeddingModel = "onnx-clip";
export interface EncryptedEmbedding { export interface EncryptedEmbedding {
fileID: number; fileID: number;
model: Model; /** @see {@link EmbeddingModel} */
model: string;
encryptedEmbedding: string; encryptedEmbedding: string;
decryptionHeader: string; decryptionHeader: string;
updatedAt: number; updatedAt: number;
@ -25,7 +30,7 @@ export interface GetEmbeddingDiffResponse {
export interface PutEmbeddingRequest { export interface PutEmbeddingRequest {
fileID: number; fileID: number;
model: Model; model: EmbeddingModel;
encryptedEmbedding: string; encryptedEmbedding: string;
decryptionHeader: string; decryptionHeader: string;
} }

View file

@ -10,11 +10,6 @@ export interface AppUpdateInfo {
version: string; version: string;
} }
export enum Model {
GGML_CLIP = "ggml-clip",
ONNX_CLIP = "onnx-clip",
}
export enum FILE_PATH_TYPE { export enum FILE_PATH_TYPE {
FILES = "files", FILES = "files",
ZIPS = "zips", ZIPS = "zips",
@ -147,12 +142,27 @@ export interface Electron {
// - ML // - ML
computeImageEmbedding: ( /**
model: Model, * Compute and return a CLIP embedding of the given image.
imageData: Uint8Array, *
) => Promise<Float32Array>; * See: [Note: CLIP based magic search]
*
* @param jpegImageData The raw bytes of the image encoded as an JPEG.
*
* @returns A CLIP embedding.
*/
clipImageEmbedding: (jpegImageData: Uint8Array) => Promise<Float32Array>;
computeTextEmbedding: (model: Model, text: string) => Promise<Float32Array>; /**
* Compute and return a CLIP embedding of the given image.
*
* See: [Note: CLIP based magic search]
*
* @param text The string whose embedding we want to compute.
*
* @returns A CLIP embedding.
*/
clipTextEmbedding: (text: string) => Promise<Float32Array>;
// - File selection // - File selection
// TODO: Deprecated - use dialogs on the renderer process itself // TODO: Deprecated - use dialogs on the renderer process itself