Only ONNX, desktop
This commit is contained in:
parent b76b57d07e
commit 4327cfdb23
9 changed files with 304 additions and 507 deletions
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -116,7 +116,6 @@ watcher for the watch folders functionality.
 - html-entities is used by the bundled clip-bpe-ts tokenizer.
 - [jpeg-js](https://github.com/jpeg-js/jpeg-js#readme) is used for decoding
   JPEG data into raw RGB bytes before passing it to ONNX.
-- GGML binaries are bundled.
 
 ## ZIP
 
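As context for the jpeg-js bullet above, the decode call it refers to looks roughly like this. This is an illustrative sketch using the same options as the `getRGBData` function added later in this commit; the helper name `decodeJPEGToRGB` is not from the codebase.

```ts
const jpeg = require("jpeg-js");

// Decode JPEG bytes into tightly packed 3-byte RGB pixels (no alpha), which
// is the layout the ONNX preprocessing below expects.
const decodeJPEGToRGB = (jpegData: Uint8Array) => {
    const { width, height, data } = jpeg.decode(jpegData, {
        useTArray: true, // return a Uint8Array rather than a plain Buffer
        formatAsRGBA: false, // 3 channels (RGB) instead of 4 (RGBA)
    });
    return { width, height, rgb: data as Uint8Array };
};
```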
@@ -17,9 +17,9 @@ import {
     updateAndRestart,
 } from "../services/appUpdater";
 import {
-    computeImageEmbedding,
-    computeTextEmbedding,
-} from "../services/clipService";
+    clipImageEmbedding,
+    clipTextEmbedding,
+} from "../services/clip-service";
 import { runFFmpegCmd } from "../services/ffmpeg";
 import { getDirFiles } from "../services/fs";
 import {
@@ -44,12 +44,7 @@ import {
     updateWatchMappingIgnoredFiles,
     updateWatchMappingSyncedFiles,
 } from "../services/watch";
-import type {
-    ElectronFile,
-    FILE_PATH_TYPE,
-    Model,
-    WatchMapping,
-} from "../types/ipc";
+import type { ElectronFile, FILE_PATH_TYPE, WatchMapping } from "../types/ipc";
 import {
     selectDirectory,
     showUploadDirsDialog,
@@ -148,14 +143,12 @@ export const attachIPCHandlers = () => {
 
     // - ML
 
-    ipcMain.handle(
-        "computeImageEmbedding",
-        (_, model: Model, imageData: Uint8Array) =>
-            computeImageEmbedding(model, imageData),
+    ipcMain.handle("clipImageEmbedding", (_, jpegImageData: Uint8Array) =>
+        clipImageEmbedding(jpegImageData),
     );
 
-    ipcMain.handle("computeTextEmbedding", (_, model: Model, text: string) =>
-        computeTextEmbedding(model, text),
+    ipcMain.handle("clipTextEmbedding", (_, text: string) =>
+        clipTextEmbedding(text),
     );
 
     // - File selection
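The renamed IPC channels tie the hunk above to the preload changes below through nothing but the channel string. A condensed sketch of the text-embedding round trip, assembled from lines in this diff; the main-process half is shown in comments since the two halves live in different files:

```ts
// Main process side (registered in the hunk above):
//
//   ipcMain.handle("clipTextEmbedding", (_, text: string) =>
//       clipTextEmbedding(text),
//   );
//
// Preload side (defined in the hunks below). The "clipTextEmbedding" channel
// string is the only contract tying the two sides together.
import { ipcRenderer } from "electron";

const clipTextEmbedding = (text: string): Promise<Float32Array> =>
    ipcRenderer.invoke("clipTextEmbedding", text);
```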
@@ -45,7 +45,6 @@ import type {
     AppUpdateInfo,
     ElectronFile,
     FILE_PATH_TYPE,
-    Model,
     WatchMapping,
 } from "./types/ipc";
 
@@ -141,17 +140,11 @@ const runFFmpegCmd = (
 
 // - ML
 
-const computeImageEmbedding = (
-    model: Model,
-    imageData: Uint8Array,
-): Promise<Float32Array> =>
-    ipcRenderer.invoke("computeImageEmbedding", model, imageData);
+const clipImageEmbedding = (jpegImageData: Uint8Array): Promise<Float32Array> =>
+    ipcRenderer.invoke("clipImageEmbedding", jpegImageData);
 
-const computeTextEmbedding = (
-    model: Model,
-    text: string,
-): Promise<Float32Array> =>
-    ipcRenderer.invoke("computeTextEmbedding", model, text);
+const clipTextEmbedding = (text: string): Promise<Float32Array> =>
+    ipcRenderer.invoke("clipTextEmbedding", text);
 
 // - File selection
 
@@ -332,8 +325,8 @@ contextBridge.exposeInMainWorld("electron", {
     runFFmpegCmd,
 
     // - ML
-    computeImageEmbedding,
-    computeTextEmbedding,
+    clipImageEmbedding,
+    clipTextEmbedding,
 
     // - File selection
     selectDirectory,
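For reference, a hypothetical sketch of how the web layer would call the renamed bridge functions. The actual call sites live in web/apps/photos and are not part of this diff; the `electron` global is an assumption based on the `exposeInMainWorld("electron", ...)` call above.

```ts
const electron = (globalThis as any).electron;

const embedBoth = async (jpegBytes: Uint8Array, query: string) => {
    // Both calls cross the contextBridge into the Node.js layer, which runs
    // the ONNX models and resolves with normalized Float32Array embeddings.
    const imageEmbedding: Float32Array =
        await electron.clipImageEmbedding(jpegBytes);
    const textEmbedding: Float32Array =
        await electron.clipTextEmbedding(query);
    return { imageEmbedding, textEmbedding };
};
```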
290 desktop/src/services/clip-service.ts Normal file

@@ -0,0 +1,290 @@
/**
 * @file Compute CLIP embeddings
 *
 * @see `web/apps/photos/src/services/clip-service.ts` for more details. This
 * file implements the Node.js implementation of the actual embedding
 * computation. By doing it in the Node.js layer, we can use the binary ONNX
 * runtimes which are 10-20x faster than the WASM based web ones.
 *
 * The embeddings are computed using ONNX runtime. The model itself is not
 * shipped with the app but is downloaded on demand.
 */
import { app, net } from "electron/main";
import { existsSync } from "fs";
import fs from "node:fs/promises";
import path from "node:path";
import { writeStream } from "../main/fs";
import log from "../main/log";
import { CustomErrors } from "../types/ipc";
import Tokenizer from "../utils/clip-bpe-ts/mod";
import { generateTempFilePath } from "../utils/temp";
import { deleteTempFile } from "./ffmpeg";
const jpeg = require("jpeg-js");
const ort = require("onnxruntime-node");

const textModelName = "clip-text-vit-32-uint8.onnx";
const textModelByteSize = 64173509; // 61.2 MB

const imageModelName = "clip-image-vit-32-float32.onnx";
const imageModelByteSize = 351468764; // 335.2 MB

/** Return the path where the given {@link modelName} is meant to be saved */
const modelSavePath = (modelName: string) =>
    path.join(app.getPath("userData"), "models", modelName);

const downloadModel = async (saveLocation: string, name: string) => {
    // `mkdir -p` the directory where we want to save the model.
    const saveDir = path.dirname(saveLocation);
    await fs.mkdir(saveDir, { recursive: true });
    // Download
    log.info(`Downloading CLIP model from ${name}`);
    const url = `https://models.ente.io/${name}`;
    const res = await net.fetch(url);
    if (!res.ok) throw new Error(`Failed to fetch ${url}: HTTP ${res.status}`);
    // Save
    await writeStream(saveLocation, res.body);
    log.info(`Downloaded CLIP model ${name}`);
};

let activeImageModelDownload: Promise<void> | undefined;

const imageModelPathDownloadingIfNeeded = async () => {
    try {
        const modelPath = modelSavePath(imageModelName);
        if (activeImageModelDownload) {
            log.info("Waiting for CLIP image model download to finish");
            await activeImageModelDownload;
        } else {
            if (!existsSync(modelPath)) {
                log.info("CLIP image model not found, downloading");
                activeImageModelDownload = downloadModel(
                    modelPath,
                    imageModelName,
                );
                await activeImageModelDownload;
            } else {
                const localFileSize = (await fs.stat(modelPath)).size;
                if (localFileSize !== imageModelByteSize) {
                    log.error(
                        `CLIP image model size ${localFileSize} does not match the expected size, downloading again`,
                    );
                    activeImageModelDownload = downloadModel(
                        modelPath,
                        imageModelName,
                    );
                    await activeImageModelDownload;
                }
            }
        }
        return modelPath;
    } finally {
        activeImageModelDownload = undefined;
    }
};

let textModelDownloadInProgress = false;

const textModelPathDownloadingIfNeeded = async () => {
    if (textModelDownloadInProgress)
        throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);

    const modelPath = modelSavePath(textModelName);
    if (!existsSync(modelPath)) {
        log.info("CLIP text model not found, downloading");
        textModelDownloadInProgress = true;
        downloadModel(modelPath, textModelName)
            .catch((e) => {
                // log but otherwise ignore
                log.error("CLIP text model download failed", e);
            })
            .finally(() => {
                textModelDownloadInProgress = false;
            });
        throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
    } else {
        const localFileSize = (await fs.stat(modelPath)).size;
        if (localFileSize !== textModelByteSize) {
            log.error(
                `CLIP text model size ${localFileSize} does not match the expected size, downloading again`,
            );
            textModelDownloadInProgress = true;
            downloadModel(modelPath, textModelName)
                .catch((e) => {
                    // log but otherwise ignore
                    log.error("CLIP text model download failed", e);
                })
                .finally(() => {
                    textModelDownloadInProgress = false;
                });
            throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
        }
    }

    return modelPath;
};

const createInferenceSession = async (modelPath: string) => {
    return await ort.InferenceSession.create(modelPath, {
        intraOpNumThreads: 1,
        enableCpuMemArena: false,
    });
};

let imageSessionPromise: Promise<any> | undefined;

const onnxImageSession = async () => {
    if (!imageSessionPromise) {
        imageSessionPromise = (async () => {
            const modelPath = await imageModelPathDownloadingIfNeeded();
            return createInferenceSession(modelPath);
        })();
    }
    return imageSessionPromise;
};

let _textSession: any = null;

const onnxTextSession = async () => {
    if (!_textSession) {
        const modelPath = await textModelPathDownloadingIfNeeded();
        _textSession = await createInferenceSession(modelPath);
    }
    return _textSession;
};

export const clipImageEmbedding = async (jpegImageData: Uint8Array) => {
    const tempFilePath = await generateTempFilePath("");
    const imageStream = new Response(jpegImageData.buffer).body;
    await writeStream(tempFilePath, imageStream);
    try {
        return await clipImageEmbedding_(tempFilePath);
    } finally {
        await deleteTempFile(tempFilePath);
    }
};

const clipImageEmbedding_ = async (jpegFilePath: string) => {
    const imageSession = await onnxImageSession();
    const t1 = Date.now();
    const rgbData = await getRGBData(jpegFilePath);
    const feeds = {
        input: new ort.Tensor("float32", rgbData, [1, 3, 224, 224]),
    };
    const t2 = Date.now();
    const results = await imageSession.run(feeds);
    log.info(
        `onnx image embedding time: ${Date.now() - t1} ms (prep:${
            t2 - t1
        } ms, extraction: ${Date.now() - t2} ms)`,
    );
    const imageEmbedding = results["output"].data; // Float32Array
    return normalizeEmbedding(imageEmbedding);
};

const getRGBData = async (jpegFilePath: string) => {
    const jpegData = await fs.readFile(jpegFilePath);
    const rawImageData = jpeg.decode(jpegData, {
        useTArray: true,
        formatAsRGBA: false,
    });

    const nx: number = rawImageData.width;
    const ny: number = rawImageData.height;
    const inputImage: Uint8Array = rawImageData.data;

    const nx2: number = 224;
    const ny2: number = 224;
    const totalSize: number = 3 * nx2 * ny2;

    const result: number[] = Array(totalSize).fill(0);
    const scale: number = Math.max(nx, ny) / 224;

    const nx3: number = Math.round(nx / scale);
    const ny3: number = Math.round(ny / scale);

    const mean: number[] = [0.48145466, 0.4578275, 0.40821073];
    const std: number[] = [0.26862954, 0.26130258, 0.27577711];

    for (let y = 0; y < ny3; y++) {
        for (let x = 0; x < nx3; x++) {
            for (let c = 0; c < 3; c++) {
                // Linear interpolation
                const sx: number = (x + 0.5) * scale - 0.5;
                const sy: number = (y + 0.5) * scale - 0.5;

                const x0: number = Math.max(0, Math.floor(sx));
                const y0: number = Math.max(0, Math.floor(sy));

                const x1: number = Math.min(x0 + 1, nx - 1);
                const y1: number = Math.min(y0 + 1, ny - 1);

                const dx: number = sx - x0;
                const dy: number = sy - y0;

                const j00: number = 3 * (y0 * nx + x0) + c;
                const j01: number = 3 * (y0 * nx + x1) + c;
                const j10: number = 3 * (y1 * nx + x0) + c;
                const j11: number = 3 * (y1 * nx + x1) + c;

                const v00: number = inputImage[j00];
                const v01: number = inputImage[j01];
                const v10: number = inputImage[j10];
                const v11: number = inputImage[j11];

                const v0: number = v00 * (1 - dx) + v01 * dx;
                const v1: number = v10 * (1 - dx) + v11 * dx;

                const v: number = v0 * (1 - dy) + v1 * dy;

                const v2: number = Math.min(Math.max(Math.round(v), 0), 255);

                // createTensorWithDataList is dumb compared to reshape and
                // hence has to be given with one channel after another
                const i: number = y * nx3 + x + (c % 3) * 224 * 224;

                result[i] = (v2 / 255 - mean[c]) / std[c];
            }
        }
    }

    return result;
};

const normalizeEmbedding = (embedding: Float32Array) => {
    let normalization = 0;
    for (let index = 0; index < embedding.length; index++) {
        normalization += embedding[index] * embedding[index];
    }
    const sqrtNormalization = Math.sqrt(normalization);
    for (let index = 0; index < embedding.length; index++) {
        embedding[index] = embedding[index] / sqrtNormalization;
    }
    return embedding;
};

let _tokenizer: Tokenizer = null;
const getTokenizer = () => {
    if (!_tokenizer) {
        _tokenizer = new Tokenizer();
    }
    return _tokenizer;
};

export const clipTextEmbedding = async (text: string) => {
    const imageSession = await onnxTextSession();
    const t1 = Date.now();
    const tokenizer = getTokenizer();
    const tokenizedText = Int32Array.from(tokenizer.encodeForCLIP(text));
    const feeds = {
        input: new ort.Tensor("int32", tokenizedText, [1, 77]),
    };
    const t2 = Date.now();
    const results = await imageSession.run(feeds);
    log.info(
        `onnx text embedding time: ${Date.now() - t1} ms (prep:${
            t2 - t1
        } ms, extraction: ${Date.now() - t2} ms)`,
    );
    const textEmbedding = results["output"].data;
    return normalizeEmbedding(textEmbedding);
};
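Since `normalizeEmbedding` above scales every returned embedding to unit length, comparing a text query against an image embedding reduces to a plain dot product. A sketch of how a caller might score the two (not part of this commit):

```ts
// Cosine similarity between two unit-length CLIP embeddings. Because both
// inputs are already normalized, the dot product is the cosine similarity.
const clipScore = (image: Float32Array, text: Float32Array): number => {
    let dot = 0;
    for (let i = 0; i < image.length; i++) dot += image[i] * text[i];
    return dot;
};
```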

@@ -1,474 +0,0 @@
/**
 * @file Compute CLIP embeddings
 *
 * @see `web/apps/photos/services/clip-service.ts` for more details. This file
 * implements the Node.js implementation of the actual embedding computation. By
 * doing it in the Node.js layer, we can use the binary ONNX runtimes which are
 * 10-20x faster than the WASM based web ones.
 *
 * The embeddings are computed using ONNX runtime. The model itself is not
 * shipped with the app but is downloaded on demand.
 */
import { app, net } from "electron/main";
import { existsSync } from "fs";
import fs from "node:fs/promises";
import path from "node:path";
import { writeStream } from "../main/fs";
import log from "../main/log";
import { execAsync, isDev } from "../main/util";
import { CustomErrors, Model, isModel } from "../types/ipc";
import Tokenizer from "../utils/clip-bpe-ts/mod";
import { getPlatform } from "../utils/common/platform";
import { generateTempFilePath } from "../utils/temp";
import { deleteTempFile } from "./ffmpeg";
const jpeg = require("jpeg-js");

const CLIP_MODEL_PATH_PLACEHOLDER = "CLIP_MODEL";
const GGMLCLIP_PATH_PLACEHOLDER = "GGML_PATH";
const INPUT_PATH_PLACEHOLDER = "INPUT";

const IMAGE_EMBEDDING_EXTRACT_CMD: string[] = [
    GGMLCLIP_PATH_PLACEHOLDER,
    "-mv",
    CLIP_MODEL_PATH_PLACEHOLDER,
    "--image",
    INPUT_PATH_PLACEHOLDER,
];

const TEXT_EMBEDDING_EXTRACT_CMD: string[] = [
    GGMLCLIP_PATH_PLACEHOLDER,
    "-mt",
    CLIP_MODEL_PATH_PLACEHOLDER,
    "--text",
    INPUT_PATH_PLACEHOLDER,
];
const ort = require("onnxruntime-node");

const TEXT_MODEL_DOWNLOAD_URL = {
    ggml: "https://models.ente.io/clip-vit-base-patch32_ggml-text-model-f16.gguf",
    onnx: "https://models.ente.io/clip-text-vit-32-uint8.onnx",
};
const IMAGE_MODEL_DOWNLOAD_URL = {
    ggml: "https://models.ente.io/clip-vit-base-patch32_ggml-vision-model-f16.gguf",
    onnx: "https://models.ente.io/clip-image-vit-32-float32.onnx",
};

const TEXT_MODEL_NAME = {
    ggml: "clip-vit-base-patch32_ggml-text-model-f16.gguf",
    onnx: "clip-text-vit-32-uint8.onnx",
};
const IMAGE_MODEL_NAME = {
    ggml: "clip-vit-base-patch32_ggml-vision-model-f16.gguf",
    onnx: "clip-image-vit-32-float32.onnx",
};

const IMAGE_MODEL_SIZE_IN_BYTES = {
    ggml: 175957504, // 167.8 MB
    onnx: 351468764, // 335.2 MB
};
const TEXT_MODEL_SIZE_IN_BYTES = {
    ggml: 127853440, // 121.9 MB,
    onnx: 64173509, // 61.2 MB
};

/** Return the path where the given {@link modelName} is meant to be saved */
const getModelSavePath = (modelName: string) =>
    path.join(app.getPath("userData"), "models", modelName);

async function downloadModel(saveLocation: string, url: string) {
    // confirm that the save location exists
    const saveDir = path.dirname(saveLocation);
    await fs.mkdir(saveDir, { recursive: true });
    log.info("downloading clip model");
    const res = await net.fetch(url);
    if (!res.ok) throw new Error(`Failed to fetch ${url}: HTTP ${res.status}`);
    await writeStream(saveLocation, res.body);
    log.info("clip model downloaded");
}

let imageModelDownloadInProgress: Promise<void> = null;

const getClipImageModelPath = async (type: "ggml" | "onnx") => {
    try {
        const modelSavePath = getModelSavePath(IMAGE_MODEL_NAME[type]);
        if (imageModelDownloadInProgress) {
            log.info("waiting for image model download to finish");
            await imageModelDownloadInProgress;
        } else {
            if (!existsSync(modelSavePath)) {
                log.info("CLIP image model not found, downloading");
                imageModelDownloadInProgress = downloadModel(
                    modelSavePath,
                    IMAGE_MODEL_DOWNLOAD_URL[type],
                );
                await imageModelDownloadInProgress;
            } else {
                const localFileSize = (await fs.stat(modelSavePath)).size;
                if (localFileSize !== IMAGE_MODEL_SIZE_IN_BYTES[type]) {
                    log.info(
                        `CLIP image model size mismatch, downloading again got: ${localFileSize}`,
                    );
                    imageModelDownloadInProgress = downloadModel(
                        modelSavePath,
                        IMAGE_MODEL_DOWNLOAD_URL[type],
                    );
                    await imageModelDownloadInProgress;
                }
            }
        }
        return modelSavePath;
    } finally {
        imageModelDownloadInProgress = null;
    }
};

let textModelDownloadInProgress: boolean = false;

const getClipTextModelPath = async (type: "ggml" | "onnx") => {
    const modelSavePath = getModelSavePath(TEXT_MODEL_NAME[type]);
    if (textModelDownloadInProgress) {
        throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
    } else {
        if (!existsSync(modelSavePath)) {
            log.info("CLIP text model not found, downloading");
            textModelDownloadInProgress = true;
            downloadModel(modelSavePath, TEXT_MODEL_DOWNLOAD_URL[type])
                .catch((e) => {
                    // log but otherwise ignore
                    log.error("CLIP text model download failed", e);
                })
                .finally(() => {
                    textModelDownloadInProgress = false;
                });
            throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
        } else {
            const localFileSize = (await fs.stat(modelSavePath)).size;
            if (localFileSize !== TEXT_MODEL_SIZE_IN_BYTES[type]) {
                log.info(
                    `CLIP text model size mismatch, downloading again got: ${localFileSize}`,
                );
                textModelDownloadInProgress = true;
                downloadModel(modelSavePath, TEXT_MODEL_DOWNLOAD_URL[type])
                    .catch((e) => {
                        // log but otherwise ignore
                        log.error("CLIP text model download failed", e);
                    })
                    .finally(() => {
                        textModelDownloadInProgress = false;
                    });
                throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
            }
        }
    }
    return modelSavePath;
};

function getGGMLClipPath() {
    return isDev
        ? path.join("./build", `ggmlclip-${getPlatform()}`)
        : path.join(process.resourcesPath, `ggmlclip-${getPlatform()}`);
}

async function createOnnxSession(modelPath: string) {
    return await ort.InferenceSession.create(modelPath, {
        intraOpNumThreads: 1,
        enableCpuMemArena: false,
    });
}

let onnxImageSessionPromise: Promise<any> = null;

async function getOnnxImageSession() {
    if (!onnxImageSessionPromise) {
        onnxImageSessionPromise = (async () => {
            const clipModelPath = await getClipImageModelPath("onnx");
            return createOnnxSession(clipModelPath);
        })();
    }
    return onnxImageSessionPromise;
}

let onnxTextSession: any = null;

async function getOnnxTextSession() {
    if (!onnxTextSession) {
        const clipModelPath = await getClipTextModelPath("onnx");
        onnxTextSession = await createOnnxSession(clipModelPath);
    }
    return onnxTextSession;
}

let tokenizer: Tokenizer = null;
function getTokenizer() {
    if (!tokenizer) {
        tokenizer = new Tokenizer();
    }
    return tokenizer;
}

export const computeImageEmbedding = async (
    model: Model,
    imageData: Uint8Array,
): Promise<Float32Array> => {
    if (!isModel(model)) throw new Error(`Invalid CLIP model ${model}`);

    let tempInputFilePath = null;
    try {
        tempInputFilePath = await generateTempFilePath("");
        const imageStream = new Response(imageData.buffer).body;
        await writeStream(tempInputFilePath, imageStream);
        const embedding = await computeImageEmbedding_(
            model,
            tempInputFilePath,
        );
        return embedding;
    } catch (err) {
        if (isExecError(err)) {
            const parsedExecError = parseExecError(err);
            throw Error(parsedExecError);
        } else {
            throw err;
        }
    } finally {
        if (tempInputFilePath) {
            await deleteTempFile(tempInputFilePath);
        }
    }
};

const isExecError = (err: any) => {
    return err.message.includes("Command failed:");
};

const parseExecError = (err: any) => {
    const errMessage = err.message;
    if (errMessage.includes("Bad CPU type in executable")) {
        return CustomErrors.UNSUPPORTED_PLATFORM(
            process.platform,
            process.arch,
        );
    } else {
        return errMessage;
    }
};

async function computeImageEmbedding_(
    model: Model,
    inputFilePath: string,
): Promise<Float32Array> {
    if (!existsSync(inputFilePath)) {
        throw new Error("Invalid file path");
    }
    switch (model) {
        case "ggml-clip":
            return await computeGGMLImageEmbedding(inputFilePath);
        case "onnx-clip":
            return await computeONNXImageEmbedding(inputFilePath);
    }
}

const computeGGMLImageEmbedding = async (
    inputFilePath: string,
): Promise<Float32Array> => {
    const clipModelPath = await getClipImageModelPath("ggml");
    const ggmlclipPath = getGGMLClipPath();
    const cmd = IMAGE_EMBEDDING_EXTRACT_CMD.map((cmdPart) => {
        if (cmdPart === GGMLCLIP_PATH_PLACEHOLDER) {
            return ggmlclipPath;
        } else if (cmdPart === CLIP_MODEL_PATH_PLACEHOLDER) {
            return clipModelPath;
        } else if (cmdPart === INPUT_PATH_PLACEHOLDER) {
            return inputFilePath;
        } else {
            return cmdPart;
        }
    });

    const { stdout } = await execAsync(cmd);
    // parse stdout and return embedding
    // get the last line of stdout
    const lines = stdout.split("\n");
    const lastLine = lines[lines.length - 1];
    const embedding = JSON.parse(lastLine);
    const embeddingArray = new Float32Array(embedding);
    return embeddingArray;
};

const computeONNXImageEmbedding = async (
    inputFilePath: string,
): Promise<Float32Array> => {
    const imageSession = await getOnnxImageSession();
    const t1 = Date.now();
    const rgbData = await getRGBData(inputFilePath);
    const feeds = {
        input: new ort.Tensor("float32", rgbData, [1, 3, 224, 224]),
    };
    const t2 = Date.now();
    const results = await imageSession.run(feeds);
    log.info(
        `onnx image embedding time: ${Date.now() - t1} ms (prep:${
            t2 - t1
        } ms, extraction: ${Date.now() - t2} ms)`,
    );
    const imageEmbedding = results["output"].data; // Float32Array
    return normalizeEmbedding(imageEmbedding);
};

async function getRGBData(inputFilePath: string) {
    const jpegData = await fs.readFile(inputFilePath);
    const rawImageData = jpeg.decode(jpegData, {
        useTArray: true,
        formatAsRGBA: false,
    });

    const nx: number = rawImageData.width;
    const ny: number = rawImageData.height;
    const inputImage: Uint8Array = rawImageData.data;

    const nx2: number = 224;
    const ny2: number = 224;
    const totalSize: number = 3 * nx2 * ny2;

    const result: number[] = Array(totalSize).fill(0);
    const scale: number = Math.max(nx, ny) / 224;

    const nx3: number = Math.round(nx / scale);
    const ny3: number = Math.round(ny / scale);

    const mean: number[] = [0.48145466, 0.4578275, 0.40821073];
    const std: number[] = [0.26862954, 0.26130258, 0.27577711];

    for (let y = 0; y < ny3; y++) {
        for (let x = 0; x < nx3; x++) {
            for (let c = 0; c < 3; c++) {
                // linear interpolation
                const sx: number = (x + 0.5) * scale - 0.5;
                const sy: number = (y + 0.5) * scale - 0.5;

                const x0: number = Math.max(0, Math.floor(sx));
                const y0: number = Math.max(0, Math.floor(sy));

                const x1: number = Math.min(x0 + 1, nx - 1);
                const y1: number = Math.min(y0 + 1, ny - 1);

                const dx: number = sx - x0;
                const dy: number = sy - y0;

                const j00: number = 3 * (y0 * nx + x0) + c;
                const j01: number = 3 * (y0 * nx + x1) + c;
                const j10: number = 3 * (y1 * nx + x0) + c;
                const j11: number = 3 * (y1 * nx + x1) + c;

                const v00: number = inputImage[j00];
                const v01: number = inputImage[j01];
                const v10: number = inputImage[j10];
                const v11: number = inputImage[j11];

                const v0: number = v00 * (1 - dx) + v01 * dx;
                const v1: number = v10 * (1 - dx) + v11 * dx;

                const v: number = v0 * (1 - dy) + v1 * dy;

                const v2: number = Math.min(Math.max(Math.round(v), 0), 255);

                // createTensorWithDataList is dump compared to reshape and hence has to be given with one channel after another
                const i: number = y * nx3 + x + (c % 3) * 224 * 224;

                result[i] = (v2 / 255 - mean[c]) / std[c];
            }
        }
    }

    return result;
}

const normalizeEmbedding = (embedding: Float32Array) => {
    let normalization = 0;
    for (let index = 0; index < embedding.length; index++) {
        normalization += embedding[index] * embedding[index];
    }
    const sqrtNormalization = Math.sqrt(normalization);
    for (let index = 0; index < embedding.length; index++) {
        embedding[index] = embedding[index] / sqrtNormalization;
    }
    return embedding;
};

export async function computeTextEmbedding(
    model: Model,
    text: string,
): Promise<Float32Array> {
    if (!isModel(model)) throw new Error(`Invalid CLIP model ${model}`);

    try {
        const embedding = computeTextEmbedding_(model, text);
        return embedding;
    } catch (err) {
        if (isExecError(err)) {
            const parsedExecError = parseExecError(err);
            throw Error(parsedExecError);
        } else {
            throw err;
        }
    }
}

async function computeTextEmbedding_(
    model: Model,
    text: string,
): Promise<Float32Array> {
    switch (model) {
        case "ggml-clip":
            return await computeGGMLTextEmbedding(text);
        case "onnx-clip":
            return await computeONNXTextEmbedding(text);
    }
}

export async function computeGGMLTextEmbedding(
    text: string,
): Promise<Float32Array> {
    const clipModelPath = await getClipTextModelPath("ggml");
    const ggmlclipPath = getGGMLClipPath();
    const cmd = TEXT_EMBEDDING_EXTRACT_CMD.map((cmdPart) => {
        if (cmdPart === GGMLCLIP_PATH_PLACEHOLDER) {
            return ggmlclipPath;
        } else if (cmdPart === CLIP_MODEL_PATH_PLACEHOLDER) {
            return clipModelPath;
        } else if (cmdPart === INPUT_PATH_PLACEHOLDER) {
            return text;
        } else {
            return cmdPart;
        }
    });

    const { stdout } = await execAsync(cmd);
    // parse stdout and return embedding
    // get the last line of stdout
    const lines = stdout.split("\n");
    const lastLine = lines[lines.length - 1];
    const embedding = JSON.parse(lastLine);
    const embeddingArray = new Float32Array(embedding);
    return embeddingArray;
}

export async function computeONNXTextEmbedding(
    text: string,
): Promise<Float32Array> {
    const imageSession = await getOnnxTextSession();
    const t1 = Date.now();
    const tokenizer = getTokenizer();
    const tokenizedText = Int32Array.from(tokenizer.encodeForCLIP(text));
    const feeds = {
        input: new ort.Tensor("int32", tokenizedText, [1, 77]),
    };
    const t2 = Date.now();
    const results = await imageSession.run(feeds);
    log.info(
        `onnx text embedding time: ${Date.now() - t1} ms (prep:${
            t2 - t1
        } ms, extraction: ${Date.now() - t2} ms)`,
    );
    const textEmbedding = results["output"].data; // Float32Array
    return normalizeEmbedding(textEmbedding);
}
@@ -80,7 +80,3 @@ export interface AppUpdateInfo {
     autoUpdatable: boolean;
     version: string;
 }
-
-export type Model = "ggml-clip" | "onnx-clip";
-
-export const isModel = (s: unknown) => s == "ggml-clip" || s == "onnx-clip";
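With the `Model` union and `isModel` guard removed, the renderer-facing ML surface no longer takes a model parameter. A sketch of the ambient declaration the web layer might pair with the bridge exposed above; this is an assumption for illustration, the real declaration lives in the web codebase, not this diff:

```ts
// Hypothetical typing of the ML portion of the "electron" bridge after this
// commit: just bytes or text in, a normalized embedding out.
interface ElectronMLAPI {
    clipImageEmbedding: (jpegImageData: Uint8Array) => Promise<Float32Array>;
    clipTextEmbedding: (text: string) => Promise<Float32Array>;
}
```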