Handle ML related functions in updated preload

This commit is contained in:
Manav Rathi 2024-03-25 12:09:11 +05:30
parent ed4886a6a5
commit 106ba270fe
No known key found for this signature in database
5 changed files with 121 additions and 108 deletions

View file

@ -1,3 +1,17 @@
/**
* [Note: Custom errors across Electron/Renderer boundary]
*
* We need to use the `message` field to disambiguate between errors thrown by
* the main process when invoked from the renderer process. This is because:
*
* > Errors thrown throw `handle` in the main process are not transparent as
* > they are serialized and only the `message` property from the original error
* > is provided to the renderer process.
* >
* > - https://www.electronjs.org/docs/latest/tutorial/ipc
* >
* > Ref: https://github.com/electron/electron/issues/24427
*/
export const CustomErrors = { export const CustomErrors = {
WINDOWS_NATIVE_IMAGE_PROCESSING_NOT_SUPPORTED: WINDOWS_NATIVE_IMAGE_PROCESSING_NOT_SUPPORTED:
"Windows native image processing is not supported", "Windows native image processing is not supported",

View file

@ -7,6 +7,11 @@
*/ */
import { ipcMain } from "electron/main"; import { ipcMain } from "electron/main";
import {
computeImageEmbedding,
computeTextEmbedding,
} from "services/clipService";
import type { Model } from "types";
import { clearElectronStore } from "../api/electronStore"; import { clearElectronStore } from "../api/electronStore";
import { import {
appVersion, appVersion,
@ -59,6 +64,7 @@ export const attachIPCHandlers = () => {
ipcMain.on("update-and-restart", (_) => { ipcMain.on("update-and-restart", (_) => {
updateAndRestart(); updateAndRestart();
}); });
ipcMain.on("skip-app-update", (_, version) => { ipcMain.on("skip-app-update", (_, version) => {
skipAppUpdate(version); skipAppUpdate(version);
}); });
@ -66,4 +72,14 @@ export const attachIPCHandlers = () => {
ipcMain.on("mute-update-notification", (_, version) => { ipcMain.on("mute-update-notification", (_, version) => {
muteUpdateNotification(version); muteUpdateNotification(version);
}); });
ipcMain.handle(
"computeImageEmbedding",
(_, model: Model, imageData: Uint8Array) =>
computeImageEmbedding(model, imageData),
);
ipcMain.handle("computeTextEmbedding", (_, model: Model, text: string) =>
computeTextEmbedding(model, text),
);
}; };

View file

@ -136,6 +136,19 @@ const muteUpdateNotification = (version: string) => {
ipcRenderer.send("mute-update-notification", version); ipcRenderer.send("mute-update-notification", version);
}; };
// - ML
const computeImageEmbedding = (
model: Model,
imageData: Uint8Array,
): Promise<Float32Array> =>
ipcRenderer.invoke("computeImageEmbedding", model, imageData);
const computeTextEmbedding = (
model: Model,
text: string,
): Promise<Float32Array> =>
ipcRenderer.invoke("computeTextEmbedding", model, text);
// - FIXME below this // - FIXME below this
@ -301,104 +314,8 @@ export enum Model {
ONNX_CLIP = "onnx-clip", ONNX_CLIP = "onnx-clip",
} }
const computeImageEmbedding = async (
model: Model,
imageData: Uint8Array,
): Promise<Float32Array> => {
let tempInputFilePath = null;
try {
tempInputFilePath = await ipcRenderer.invoke("get-temp-file-path", "");
const imageStream = new Response(imageData.buffer).body;
await writeStream(tempInputFilePath, imageStream);
const embedding = await ipcRenderer.invoke(
"compute-image-embedding",
model,
tempInputFilePath,
);
return embedding;
} catch (err) {
if (isExecError(err)) {
const parsedExecError = parseExecError(err);
throw Error(parsedExecError);
} else {
throw err;
}
} finally {
if (tempInputFilePath) {
await ipcRenderer.invoke("remove-temp-file", tempInputFilePath);
}
}
};
export async function computeTextEmbedding(
model: Model,
text: string,
): Promise<Float32Array> {
try {
const embedding = await ipcRenderer.invoke(
"compute-text-embedding",
model,
text,
);
return embedding;
} catch (err) {
if (isExecError(err)) {
const parsedExecError = parseExecError(err);
throw Error(parsedExecError);
} else {
throw err;
}
}
}
// - // -
/**
* [Note: Custom errors across Electron/Renderer boundary]
*
* We need to use the `message` field to disambiguate between errors thrown by
* the main process when invoked from the renderer process. This is because:
*
* > Errors thrown throw `handle` in the main process are not transparent as
* > they are serialized and only the `message` property from the original error
* > is provided to the renderer process.
* >
* > - https://www.electronjs.org/docs/latest/tutorial/ipc
* >
* > Ref: https://github.com/electron/electron/issues/24427
*/
/* preload: duplicated CustomErrors */
const CustomErrorsP = {
WINDOWS_NATIVE_IMAGE_PROCESSING_NOT_SUPPORTED:
"Windows native image processing is not supported",
INVALID_OS: (os: string) => `Invalid OS - ${os}`,
WAIT_TIME_EXCEEDED: "Wait time exceeded",
UNSUPPORTED_PLATFORM: (platform: string, arch: string) =>
`Unsupported platform - ${platform} ${arch}`,
MODEL_DOWNLOAD_PENDING:
"Model download pending, skipping clip search request",
INVALID_FILE_PATH: "Invalid file path",
INVALID_CLIP_MODEL: (model: string) => `Invalid Clip model - ${model}`,
};
const isExecError = (err: any) => {
return err.message.includes("Command failed:");
};
const parseExecError = (err: any) => {
const errMessage = err.message;
if (errMessage.includes("Bad CPU type in executable")) {
return CustomErrorsP.UNSUPPORTED_PLATFORM(
process.platform,
process.arch,
);
} else {
return errMessage;
}
};
// - General
const selectDirectory = async (): Promise<string> => { const selectDirectory = async (): Promise<string> => {
try { try {
return await ipcRenderer.invoke("select-dir"); return await ipcRenderer.invoke("select-dir");
@ -458,6 +375,10 @@ contextBridge.exposeInMainWorld("ElectronAPIs", {
muteUpdateNotification, muteUpdateNotification,
registerUpdateEventListener, registerUpdateEventListener,
// - ML
computeImageEmbedding,
computeTextEmbedding,
// - FS // - FS
fs: { fs: {
exists: fsExists, exists: fsExists,
@ -498,8 +419,4 @@ contextBridge.exposeInMainWorld("ElectronAPIs", {
deleteFolder, deleteFolder,
rename, rename,
deleteFile, deleteFile,
// - ML
computeImageEmbedding,
computeTextEmbedding,
}); });

View file

@ -4,13 +4,15 @@ import { existsSync } from "fs";
import * as fs from "node:fs/promises"; import * as fs from "node:fs/promises";
import * as path from "node:path"; import * as path from "node:path";
import util from "util"; import util from "util";
import { generateTempFilePath } from "utils/temp";
import { CustomErrors } from "../constants/errors"; import { CustomErrors } from "../constants/errors";
import { isDev } from "../main/general";
import { logErrorSentry } from "../main/log";
import { Model } from "../types"; import { Model } from "../types";
import Tokenizer from "../utils/clip-bpe-ts/mod"; import Tokenizer from "../utils/clip-bpe-ts/mod";
import { isDev } from "../main/general";
import { getPlatform } from "../utils/common/platform"; import { getPlatform } from "../utils/common/platform";
import { deleteTempFile } from "./ffmpeg";
import { writeStream } from "./fs"; import { writeStream } from "./fs";
import { logErrorSentry } from "../main/log";
const shellescape = require("any-shell-escape"); const shellescape = require("any-shell-escape");
const execAsync = util.promisify(require("child_process").exec); const execAsync = util.promisify(require("child_process").exec);
const jpeg = require("jpeg-js"); const jpeg = require("jpeg-js");
@ -198,7 +200,51 @@ function getTokenizer() {
return tokenizer; return tokenizer;
} }
export async function computeImageEmbedding( export const computeImageEmbedding = async (
model: Model,
imageData: Uint8Array,
): Promise<Float32Array> => {
let tempInputFilePath = null;
try {
tempInputFilePath = await generateTempFilePath("");
const imageStream = new Response(imageData.buffer).body;
await writeStream(tempInputFilePath, imageStream);
const embedding = await computeImageEmbedding_(
model,
tempInputFilePath,
);
return embedding;
} catch (err) {
if (isExecError(err)) {
const parsedExecError = parseExecError(err);
throw Error(parsedExecError);
} else {
throw err;
}
} finally {
if (tempInputFilePath) {
await deleteTempFile(tempInputFilePath);
}
}
};
const isExecError = (err: any) => {
return err.message.includes("Command failed:");
};
const parseExecError = (err: any) => {
const errMessage = err.message;
if (errMessage.includes("Bad CPU type in executable")) {
return CustomErrors.UNSUPPORTED_PLATFORM(
process.platform,
process.arch,
);
} else {
return errMessage;
}
};
async function computeImageEmbedding_(
model: Model, model: Model,
inputFilePath: string, inputFilePath: string,
): Promise<Float32Array> { ): Promise<Float32Array> {
@ -278,6 +324,23 @@ export async function computeONNXImageEmbedding(
export async function computeTextEmbedding( export async function computeTextEmbedding(
model: Model, model: Model,
text: string, text: string,
): Promise<Float32Array> {
try {
const embedding = computeTextEmbedding_(model, text);
return embedding;
} catch (err) {
if (isExecError(err)) {
const parsedExecError = parseExecError(err);
throw Error(parsedExecError);
} else {
throw err;
}
}
}
async function computeTextEmbedding_(
model: Model,
text: string,
): Promise<Float32Array> { ): Promise<Float32Array> {
if (model === Model.GGML_CLIP) { if (model === Model.GGML_CLIP) {
return await computeGGMLTextEmbedding(text); return await computeGGMLTextEmbedding(text);

View file

@ -95,6 +95,14 @@ export interface ElectronAPIsType {
showUpdateDialog: (updateInfo: AppUpdateInfo) => void, showUpdateDialog: (updateInfo: AppUpdateInfo) => void,
) => void; ) => void;
// - ML
computeImageEmbedding: (
model: Model,
imageData: Uint8Array,
) => Promise<Float32Array>;
computeTextEmbedding: (model: Model, text: string) => Promise<Float32Array>;
/** TODO: FIXME or migrate below this */ /** TODO: FIXME or migrate below this */
saveStreamToDisk: ( saveStreamToDisk: (
path: string, path: string,
@ -163,9 +171,4 @@ export interface ElectronAPIsType {
deleteFolder: (path: string) => Promise<void>; deleteFolder: (path: string) => Promise<void>;
deleteFile: (path: string) => Promise<void>; deleteFile: (path: string) => Promise<void>;
rename: (oldPath: string, newPath: string) => Promise<void>; rename: (oldPath: string, newPath: string) => Promise<void>;
computeImageEmbedding: (
model: Model,
imageData: Uint8Array,
) => Promise<Float32Array>;
computeTextEmbedding: (model: Model, text: string) => Promise<Float32Array>;
} }