@@ -2,11 +2,10 @@ import { app, net } from "electron/main";
 import { existsSync } from "fs";
 import fs from "node:fs/promises";
 import path from "node:path";
-import { CustomErrors } from "../constants/errors";
 import { writeStream } from "../main/fs";
-import log, { logErrorSentry } from "../main/log";
+import log from "../main/log";
 import { execAsync, isDev } from "../main/util";
-import { Model } from "../types/ipc";
+import { CustomErrors, Model, isModel } from "../types/ipc";
 import Tokenizer from "../utils/clip-bpe-ts/mod";
 import { getPlatform } from "../utils/common/platform";
 import { generateTempFilePath } from "../utils/temp";
@@ -78,7 +77,7 @@ async function downloadModel(saveLocation: string, url: string) {
 
 let imageModelDownloadInProgress: Promise<void> = null;
 
-export async function getClipImageModelPath(type: "ggml" | "onnx") {
+const getClipImageModelPath = async (type: "ggml" | "onnx") => {
     try {
         const modelSavePath = getModelSavePath(IMAGE_MODEL_NAME[type]);
         if (imageModelDownloadInProgress) {
@@ -86,7 +85,7 @@ export async function getClipImageModelPath(type: "ggml" | "onnx") {
             await imageModelDownloadInProgress;
         } else {
             if (!existsSync(modelSavePath)) {
-                log.info("clip image model not found, downloading");
+                log.info("CLIP image model not found, downloading");
                 imageModelDownloadInProgress = downloadModel(
                     modelSavePath,
                     IMAGE_MODEL_DOWNLOAD_URL[type],
@@ -96,7 +95,7 @@ export async function getClipImageModelPath(type: "ggml" | "onnx") {
                 const localFileSize = (await fs.stat(modelSavePath)).size;
                 if (localFileSize !== IMAGE_MODEL_SIZE_IN_BYTES[type]) {
                     log.info(
-                        `clip image model size mismatch, downloading again got: ${localFileSize}`,
+                        `CLIP image model size mismatch, downloading again got: ${localFileSize}`,
                     );
                     imageModelDownloadInProgress = downloadModel(
                         modelSavePath,
@@ -110,21 +109,22 @@ export async function getClipImageModelPath(type: "ggml" | "onnx") {
     } finally {
         imageModelDownloadInProgress = null;
     }
-}
+};
 
 let textModelDownloadInProgress: boolean = false;
 
-export async function getClipTextModelPath(type: "ggml" | "onnx") {
+const getClipTextModelPath = async (type: "ggml" | "onnx") => {
     const modelSavePath = getModelSavePath(TEXT_MODEL_NAME[type]);
     if (textModelDownloadInProgress) {
         throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
     } else {
         if (!existsSync(modelSavePath)) {
-            log.info("clip text model not found, downloading");
+            log.info("CLIP text model not found, downloading");
             textModelDownloadInProgress = true;
             downloadModel(modelSavePath, TEXT_MODEL_DOWNLOAD_URL[type])
-                .catch(() => {
-                    // ignore
+                .catch((e) => {
+                    // log but otherwise ignore
+                    log.error("CLIP text model download failed", e);
                 })
                 .finally(() => {
                     textModelDownloadInProgress = false;
@@ -134,12 +134,13 @@ export async function getClipTextModelPath(type: "ggml" | "onnx") {
             const localFileSize = (await fs.stat(modelSavePath)).size;
             if (localFileSize !== TEXT_MODEL_SIZE_IN_BYTES[type]) {
                 log.info(
-                    `clip text model size mismatch, downloading again got: ${localFileSize}`,
+                    `CLIP text model size mismatch, downloading again got: ${localFileSize}`,
                 );
                 textModelDownloadInProgress = true;
                 downloadModel(modelSavePath, TEXT_MODEL_DOWNLOAD_URL[type])
-                    .catch(() => {
-                        // ignore
+                    .catch((e) => {
+                        // log but otherwise ignore
+                        log.error("CLIP text model download failed", e);
                     })
                     .finally(() => {
                         textModelDownloadInProgress = false;
@@ -149,7 +150,7 @@ export async function getClipTextModelPath(type: "ggml" | "onnx") {
         }
     }
     return modelSavePath;
-}
+};
 
 function getGGMLClipPath() {
     return isDev
@@ -198,6 +199,8 @@ export const computeImageEmbedding = async (
     model: Model,
     imageData: Uint8Array,
 ): Promise<Float32Array> => {
+    if (!isModel(model)) throw new Error(`Invalid CLIP model ${model}`);
+
     let tempInputFilePath = null;
     try {
         tempInputFilePath = await generateTempFilePath("");
@@ -243,180 +246,69 @@ async function computeImageEmbedding_(
     inputFilePath: string,
 ): Promise<Float32Array> {
     if (!existsSync(inputFilePath)) {
-        throw Error(CustomErrors.INVALID_FILE_PATH);
-    }
-    if (model === Model.GGML_CLIP) {
-        return await computeGGMLImageEmbedding(inputFilePath);
-    } else if (model === Model.ONNX_CLIP) {
-        return await computeONNXImageEmbedding(inputFilePath);
-    } else {
-        throw Error(CustomErrors.INVALID_CLIP_MODEL(model));
+        throw new Error("Invalid file path");
     }
-}
-
-export async function computeGGMLImageEmbedding(
-    inputFilePath: string,
-): Promise<Float32Array> {
-    try {
-        const clipModelPath = await getClipImageModelPath("ggml");
-        const ggmlclipPath = getGGMLClipPath();
-        const cmd = IMAGE_EMBEDDING_EXTRACT_CMD.map((cmdPart) => {
-            if (cmdPart === GGMLCLIP_PATH_PLACEHOLDER) {
-                return ggmlclipPath;
-            } else if (cmdPart === CLIP_MODEL_PATH_PLACEHOLDER) {
-                return clipModelPath;
-            } else if (cmdPart === INPUT_PATH_PLACEHOLDER) {
-                return inputFilePath;
-            } else {
-                return cmdPart;
-            }
-        });
-
-        const { stdout } = await execAsync(cmd);
-        // parse stdout and return embedding
-        // get the last line of stdout
-        const lines = stdout.split("\n");
-        const lastLine = lines[lines.length - 1];
-        const embedding = JSON.parse(lastLine);
-        const embeddingArray = new Float32Array(embedding);
-        return embeddingArray;
-    } catch (err) {
-        log.error("Failed to compute GGML image embedding", err);
-        throw err;
+    switch (model) {
+        case "ggml-clip":
+            return await computeGGMLImageEmbedding(inputFilePath);
+        case "onnx-clip":
+            return await computeONNXImageEmbedding(inputFilePath);
     }
 }
 
-export async function computeONNXImageEmbedding(
+const computeGGMLImageEmbedding = async (
     inputFilePath: string,
-): Promise<Float32Array> {
-    try {
-        const imageSession = await getOnnxImageSession();
-        const t1 = Date.now();
-        const rgbData = await getRGBData(inputFilePath);
-        const feeds = {
-            input: new ort.Tensor("float32", rgbData, [1, 3, 224, 224]),
-        };
-        const t2 = Date.now();
-        const results = await imageSession.run(feeds);
-        log.info(
-            `onnx image embedding time: ${Date.now() - t1} ms (prep:${
-                t2 - t1
-            } ms, extraction: ${Date.now() - t2} ms)`,
-        );
-        const imageEmbedding = results["output"].data; // Float32Array
-        return normalizeEmbedding(imageEmbedding);
-    } catch (err) {
-        log.error("Failed to compute ONNX image embedding", err);
-        throw err;
-    }
-}
-
-export async function computeTextEmbedding(
-    model: Model,
-    text: string,
-): Promise<Float32Array> {
-    try {
-        const embedding = computeTextEmbedding_(model, text);
-        return embedding;
-    } catch (err) {
-        if (isExecError(err)) {
-            const parsedExecError = parseExecError(err);
-            throw Error(parsedExecError);
+): Promise<Float32Array> => {
+    const clipModelPath = await getClipImageModelPath("ggml");
+    const ggmlclipPath = getGGMLClipPath();
+    const cmd = IMAGE_EMBEDDING_EXTRACT_CMD.map((cmdPart) => {
+        if (cmdPart === GGMLCLIP_PATH_PLACEHOLDER) {
+            return ggmlclipPath;
+        } else if (cmdPart === CLIP_MODEL_PATH_PLACEHOLDER) {
+            return clipModelPath;
+        } else if (cmdPart === INPUT_PATH_PLACEHOLDER) {
+            return inputFilePath;
         } else {
-            throw err;
+            return cmdPart;
        }
-    }
-}
-
-async function computeTextEmbedding_(
-    model: Model,
-    text: string,
-): Promise<Float32Array> {
-    if (model === Model.GGML_CLIP) {
-        return await computeGGMLTextEmbedding(text);
-    } else {
-        return await computeONNXTextEmbedding(text);
-    }
-}
+    });
 
-export async function computeGGMLTextEmbedding(
-    text: string,
-): Promise<Float32Array> {
-    try {
-        const clipModelPath = await getClipTextModelPath("ggml");
-        const ggmlclipPath = getGGMLClipPath();
-        const cmd = TEXT_EMBEDDING_EXTRACT_CMD.map((cmdPart) => {
-            if (cmdPart === GGMLCLIP_PATH_PLACEHOLDER) {
-                return ggmlclipPath;
-            } else if (cmdPart === CLIP_MODEL_PATH_PLACEHOLDER) {
-                return clipModelPath;
-            } else if (cmdPart === INPUT_PATH_PLACEHOLDER) {
-                return text;
-            } else {
-                return cmdPart;
-            }
-        });
-
-        const { stdout } = await execAsync(cmd);
-        // parse stdout and return embedding
-        // get the last line of stdout
-        const lines = stdout.split("\n");
-        const lastLine = lines[lines.length - 1];
-        const embedding = JSON.parse(lastLine);
-        const embeddingArray = new Float32Array(embedding);
-        return embeddingArray;
-    } catch (err) {
-        if (err.message === CustomErrors.MODEL_DOWNLOAD_PENDING) {
-            log.info(CustomErrors.MODEL_DOWNLOAD_PENDING);
-        } else {
-            log.error("Failed to compute GGML text embedding", err);
-        }
-        throw err;
-    }
-}
+    const { stdout } = await execAsync(cmd);
+    // parse stdout and return embedding
+    // get the last line of stdout
+    const lines = stdout.split("\n");
+    const lastLine = lines[lines.length - 1];
+    const embedding = JSON.parse(lastLine);
+    const embeddingArray = new Float32Array(embedding);
+    return embeddingArray;
+};
 
-export async function computeONNXTextEmbedding(
-    text: string,
-): Promise<Float32Array> {
-    try {
-        const imageSession = await getOnnxTextSession();
-        const t1 = Date.now();
-        const tokenizer = getTokenizer();
-        const tokenizedText = Int32Array.from(tokenizer.encodeForCLIP(text));
-        const feeds = {
-            input: new ort.Tensor("int32", tokenizedText, [1, 77]),
-        };
-        const t2 = Date.now();
-        const results = await imageSession.run(feeds);
-        log.info(
-            `onnx text embedding time: ${Date.now() - t1} ms (prep:${
-                t2 - t1
-            } ms, extraction: ${Date.now() - t2} ms)`,
-        );
-        const textEmbedding = results["output"].data; // Float32Array
-        return normalizeEmbedding(textEmbedding);
-    } catch (err) {
-        if (err.message === CustomErrors.MODEL_DOWNLOAD_PENDING) {
-            log.info(CustomErrors.MODEL_DOWNLOAD_PENDING);
-        } else {
-            logErrorSentry(err, "Error in computeONNXTextEmbedding");
-        }
-        throw err;
-    }
-}
+const computeONNXImageEmbedding = async (
+    inputFilePath: string,
+): Promise<Float32Array> => {
+    const imageSession = await getOnnxImageSession();
+    const t1 = Date.now();
+    const rgbData = await getRGBData(inputFilePath);
+    const feeds = {
+        input: new ort.Tensor("float32", rgbData, [1, 3, 224, 224]),
+    };
+    const t2 = Date.now();
+    const results = await imageSession.run(feeds);
+    log.info(
+        `onnx image embedding time: ${Date.now() - t1} ms (prep:${
+            t2 - t1
+        } ms, extraction: ${Date.now() - t2} ms)`,
+    );
+    const imageEmbedding = results["output"].data; // Float32Array
+    return normalizeEmbedding(imageEmbedding);
+};
 
 async function getRGBData(inputFilePath: string) {
     const jpegData = await fs.readFile(inputFilePath);
-    let rawImageData;
-    try {
-        rawImageData = jpeg.decode(jpegData, {
-            useTArray: true,
-            formatAsRGBA: false,
-        });
-    } catch (err) {
-        logErrorSentry(err, "JPEG decode error");
-        throw err;
-    }
+    const rawImageData = jpeg.decode(jpegData, {
+        useTArray: true,
+        formatAsRGBA: false,
+    });
 
     const nx: number = rawImageData.width;
     const ny: number = rawImageData.height;
@@ -479,21 +371,7 @@ async function getRGBData(inputFilePath: string) {
     return result;
 }
 
-export const computeClipMatchScore = async (
-    imageEmbedding: Float32Array,
-    textEmbedding: Float32Array,
-) => {
-    if (imageEmbedding.length !== textEmbedding.length) {
-        throw Error("imageEmbedding and textEmbedding length mismatch");
-    }
-    let score = 0;
-    for (let index = 0; index < imageEmbedding.length; index++) {
-        score += imageEmbedding[index] * textEmbedding[index];
-    }
-    return score;
-};
-
-export const normalizeEmbedding = (embedding: Float32Array) => {
+const normalizeEmbedding = (embedding: Float32Array) => {
     let normalization = 0;
     for (let index = 0; index < embedding.length; index++) {
         normalization += embedding[index] * embedding[index];
@@ -504,3 +382,82 @@ export const normalizeEmbedding = (embedding: Float32Array) => {
     }
     return embedding;
 };
+
+export async function computeTextEmbedding(
+    model: Model,
+    text: string,
+): Promise<Float32Array> {
+    if (!isModel(model)) throw new Error(`Invalid CLIP model ${model}`);
+
+    try {
+        const embedding = computeTextEmbedding_(model, text);
+        return embedding;
+    } catch (err) {
+        if (isExecError(err)) {
+            const parsedExecError = parseExecError(err);
+            throw Error(parsedExecError);
+        } else {
+            throw err;
+        }
+    }
+}
+
+async function computeTextEmbedding_(
+    model: Model,
+    text: string,
+): Promise<Float32Array> {
+    switch (model) {
+        case "ggml-clip":
+            return await computeGGMLTextEmbedding(text);
+        case "onnx-clip":
+            return await computeONNXTextEmbedding(text);
+    }
+}
+
+export async function computeGGMLTextEmbedding(
+    text: string,
+): Promise<Float32Array> {
+    const clipModelPath = await getClipTextModelPath("ggml");
+    const ggmlclipPath = getGGMLClipPath();
+    const cmd = TEXT_EMBEDDING_EXTRACT_CMD.map((cmdPart) => {
+        if (cmdPart === GGMLCLIP_PATH_PLACEHOLDER) {
+            return ggmlclipPath;
+        } else if (cmdPart === CLIP_MODEL_PATH_PLACEHOLDER) {
+            return clipModelPath;
+        } else if (cmdPart === INPUT_PATH_PLACEHOLDER) {
+            return text;
+        } else {
+            return cmdPart;
+        }
+    });
+
+    const { stdout } = await execAsync(cmd);
+    // parse stdout and return embedding
+    // get the last line of stdout
+    const lines = stdout.split("\n");
+    const lastLine = lines[lines.length - 1];
+    const embedding = JSON.parse(lastLine);
+    const embeddingArray = new Float32Array(embedding);
+    return embeddingArray;
+}
+
+export async function computeONNXTextEmbedding(
+    text: string,
+): Promise<Float32Array> {
+    const imageSession = await getOnnxTextSession();
+    const t1 = Date.now();
+    const tokenizer = getTokenizer();
+    const tokenizedText = Int32Array.from(tokenizer.encodeForCLIP(text));
+    const feeds = {
+        input: new ort.Tensor("int32", tokenizedText, [1, 77]),
+    };
+    const t2 = Date.now();
+    const results = await imageSession.run(feeds);
+    log.info(
+        `onnx text embedding time: ${Date.now() - t1} ms (prep:${
+            t2 - t1
+        } ms, extraction: ${Date.now() - t2} ms)`,
+    );
+    const textEmbedding = results["output"].data; // Float32Array
+    return normalizeEmbedding(textEmbedding);
+}