@@ -1,27 +1,68 @@
 import * as log from 'electron-log';
+import util from 'util';
 import { logErrorSentry } from './sentry';
 import { isDev } from '../utils/common';
 import { app } from 'electron';
 import path from 'path';
 import { existsSync } from 'fs';
 import fs from 'fs/promises';
+const shellescape = require('any-shell-escape');
+const execAsync = util.promisify(require('child_process').exec);
 import fetch from 'node-fetch';
 import { writeNodeStream } from './fs';
+import { getPlatform } from '../utils/common/platform';
 import { CustomErrors } from '../constants/errors';
+
+const CLIP_MODEL_PATH_PLACEHOLDER = 'CLIP_MODEL';
+const GGMLCLIP_PATH_PLACEHOLDER = 'GGML_PATH';
+const INPUT_PATH_PLACEHOLDER = 'INPUT';
+
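+// Command templates for invoking the ggmlclip binary; the placeholder tokens
+// above are swapped for the real binary path, model path, and input at call
+// time. The `-mv`/`-mt` flags presumably select the vision and text model
+// files respectively (an assumption about the ggmlclip CLI, not documented
+// here), while `--image`/`--text` carry the input.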
+const IMAGE_EMBEDDING_EXTRACT_CMD: string[] = [
+    GGMLCLIP_PATH_PLACEHOLDER,
+    '-mv',
+    CLIP_MODEL_PATH_PLACEHOLDER,
+    '--image',
+    INPUT_PATH_PLACEHOLDER,
+];
+
+const TEXT_EMBEDDING_EXTRACT_CMD: string[] = [
+    GGMLCLIP_PATH_PLACEHOLDER,
+    '-mt',
+    CLIP_MODEL_PATH_PLACEHOLDER,
+    '--text',
+    INPUT_PATH_PLACEHOLDER,
+];
 const ort = require('onnxruntime-node');
 const { encode } = require('gpt-3-encoder');
 const { createCanvas, Image } = require('canvas');
 
-const TEXT_MODEL_DOWNLOAD_URL =
-    'https://huggingface.co/rocca/openai-clip-js/resolve/main/clip-text-vit-32-float32-int32.onnx';
-const IMAGE_MODEL_DOWNLOAD_URL =
-    'https://huggingface.co/rocca/openai-clip-js/resolve/main/clip-image-vit-32-float32.onnx';
+const TEXT_MODEL_DOWNLOAD_URL = {
+    ggml: 'https://models.ente.io/clip-vit-base-patch32_ggml-text-model-f16.gguf',
+    onnx: 'https://huggingface.co/rocca/openai-clip-js/resolve/main/clip-text-vit-32-float32-int32.onnx',
+};
+const IMAGE_MODEL_DOWNLOAD_URL = {
+    ggml: 'https://models.ente.io/clip-vit-base-patch32_ggml-vision-model-f16.gguf',
+    onnx: 'https://huggingface.co/rocca/openai-clip-js/resolve/main/clip-image-vit-32-float32.onnx',
+};
+
+const TEXT_MODEL_NAME = {
+    ggml: 'clip-vit-base-patch32_ggml-text-model-f16.gguf',
+    onnx: 'clip-text-vit-32-float32-int32.onnx',
+};
+const IMAGE_MODEL_NAME = {
+    ggml: 'clip-vit-base-patch32_ggml-vision-model-f16.gguf',
+    onnx: 'clip-image-vit-32-float32.onnx',
+};
 
-const TEXT_MODEL_NAME = 'clip-text-vit-32-float32-int32.onnx';
-const IMAGE_MODEL_NAME = 'clip-image-vit-32-float32.onnx';
+const IMAGE_MODEL_SIZE_IN_BYTES = {
+    ggml: 175957504, // 167.8 MB
+    onnx: 351468764, // 335.2 MB
+};
+const TEXT_MODEL_SIZE_IN_BYTES = {
+    ggml: 127853440, // 121.9 MB
+    onnx: 254069585, // 242.3 MB
+};
 
-const IMAGE_MODEL_SIZE_IN_BYTES = 351468764; // 335.2 MB
-const TEXT_MODEL_SIZE_IN_BYTES = 254069585; // 242.3 MB
 const MODEL_SAVE_FOLDER = 'models';
 
 function getModelSavePath(modelName: string) {
@@ -49,9 +90,9 @@ async function downloadModel(saveLocation: string, url: string) {
 
 let imageModelDownloadInProgress: Promise<void> = null;
 
-export async function getClipImageModelPath() {
+export async function getClipImageModelPath(type: 'ggml' | 'onnx') {
     try {
-        const modelSavePath = getModelSavePath(IMAGE_MODEL_NAME);
+        const modelSavePath = getModelSavePath(IMAGE_MODEL_NAME[type]);
         if (imageModelDownloadInProgress) {
             log.info('waiting for image model download to finish');
             await imageModelDownloadInProgress;
@@ -60,19 +101,19 @@ export async function getClipImageModelPath() {
             log.info('clip image model not found, downloading');
             imageModelDownloadInProgress = downloadModel(
                 modelSavePath,
-                IMAGE_MODEL_DOWNLOAD_URL
+                IMAGE_MODEL_DOWNLOAD_URL[type]
             );
             await imageModelDownloadInProgress;
         } else {
             const localFileSize = (await fs.stat(modelSavePath)).size;
-            if (localFileSize !== IMAGE_MODEL_SIZE_IN_BYTES) {
+            if (localFileSize !== IMAGE_MODEL_SIZE_IN_BYTES[type]) {
                 log.info(
                     'clip image model size mismatch, downloading again got:',
                     localFileSize
                 );
                 imageModelDownloadInProgress = downloadModel(
                     modelSavePath,
-                    IMAGE_MODEL_DOWNLOAD_URL
+                    IMAGE_MODEL_DOWNLOAD_URL[type]
                 );
                 await imageModelDownloadInProgress;
             }
@@ -86,15 +127,15 @@ export async function getClipImageModelPath() {
 
 let textModelDownloadInProgress: boolean = false;
 
-export async function getClipTextModelPath() {
-    const modelSavePath = getModelSavePath(TEXT_MODEL_NAME);
+export async function getClipTextModelPath(type: 'ggml' | 'onnx') {
+    const modelSavePath = getModelSavePath(TEXT_MODEL_NAME[type]);
     if (textModelDownloadInProgress) {
         throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
     } else {
         if (!existsSync(modelSavePath)) {
             log.info('clip text model not found, downloading');
             textModelDownloadInProgress = true;
-            downloadModel(modelSavePath, TEXT_MODEL_DOWNLOAD_URL)
+            downloadModel(modelSavePath, TEXT_MODEL_DOWNLOAD_URL[type])
                 .catch(() => {
                     // ignore
                 })
@@ -104,13 +145,13 @@ export async function getClipTextModelPath() {
             throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
         } else {
             const localFileSize = (await fs.stat(modelSavePath)).size;
-            if (localFileSize !== TEXT_MODEL_SIZE_IN_BYTES) {
+            if (localFileSize !== TEXT_MODEL_SIZE_IN_BYTES[type]) {
                 log.info(
                     'clip text model size mismatch, downloading again',
                     localFileSize
                 );
                 textModelDownloadInProgress = true;
-                downloadModel(modelSavePath, TEXT_MODEL_DOWNLOAD_URL)
+                downloadModel(modelSavePath, TEXT_MODEL_DOWNLOAD_URL[type])
                     .catch(() => {
                         // ignore
                     })
@@ -124,6 +165,12 @@ export async function getClipTextModelPath() {
     return modelSavePath;
 }
 
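+// Path to the platform-specific ggmlclip binary: resolved from ./build during
+// development and from the packaged app's resources directory in production.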
+function getGGMLClipPath() {
+    return isDev
+        ? path.join('./build', `ggmlclip-${getPlatform()}`)
+        : path.join(process.resourcesPath, `ggmlclip-${getPlatform()}`);
+}
+
 async function createOnnxSession(modelPath: string) {
     return await ort.InferenceSession.create(modelPath, {
         intraOpNumThreads: 1,
@@ -135,7 +182,7 @@ let onnxImageSession: any = null;
 
 async function getOnnxImageSession() {
     if (!onnxImageSession) {
-        const clipModelPath = await getClipImageModelPath();
+        const clipModelPath = await getClipImageModelPath('onnx');
         onnxImageSession = createOnnxSession(clipModelPath);
     }
     return onnxImageSession;
@@ -145,7 +192,7 @@ let onnxTextSession: any = null;
 
 async function getOnnxTextSession() {
     if (!onnxTextSession) {
-        const clipModelPath = await getClipTextModelPath();
+        const clipModelPath = await getClipTextModelPath('onnx');
         onnxTextSession = createOnnxSession(clipModelPath);
     }
     return onnxTextSession;
@@ -153,6 +200,55 @@ async function getOnnxTextSession() {
 
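+// Computes the embedding with both backends and logs their cosine similarity,
+// so the new GGML path can be checked against the existing ONNX path; the
+// ONNX embedding is what actually gets returned.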
 export async function computeImageEmbedding(
     inputFilePath: string
+): Promise<Float32Array> {
+    const ggmlImageEmbedding = await computeGGMLImageEmbedding(inputFilePath);
+    const onnxImageEmbedding = await computeONNXImageEmbedding(inputFilePath);
+    const score = await computeClipMatchScore(
+        ggmlImageEmbedding,
+        onnxImageEmbedding
+    );
+    log.info('imageEmbeddingScore', score);
+    return onnxImageEmbedding;
+}
+
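+// Extracts an image embedding by shelling out to the ggmlclip binary, which
+// is expected to print the embedding as a JSON array on the last line of its
+// stdout (the parsing below relies on this).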
+export async function computeGGMLImageEmbedding(
+    inputFilePath: string
+): Promise<Float32Array> {
+    try {
+        const clipModelPath = await getClipImageModelPath('ggml');
+        const ggmlclipPath = getGGMLClipPath();
+        const cmd = IMAGE_EMBEDDING_EXTRACT_CMD.map((cmdPart) => {
+            if (cmdPart === GGMLCLIP_PATH_PLACEHOLDER) {
+                return ggmlclipPath;
+            } else if (cmdPart === CLIP_MODEL_PATH_PLACEHOLDER) {
+                return clipModelPath;
+            } else if (cmdPart === INPUT_PATH_PLACEHOLDER) {
+                return inputFilePath;
+            } else {
+                return cmdPart;
+            }
+        });
+
+        const escapedCmd = shellescape(cmd);
+        log.info('running clip command', escapedCmd);
+        const startTime = Date.now();
+        const { stdout } = await execAsync(escapedCmd);
+        log.info('clip command execution time', Date.now() - startTime);
+        // the embedding is the JSON array on the last line of stdout
+        const lines = stdout.split('\n');
+        const lastLine = lines[lines.length - 1];
+        const embedding = JSON.parse(lastLine);
+        return new Float32Array(embedding);
+    } catch (err) {
+        logErrorSentry(err, 'Error in computeGGMLImageEmbedding');
+        throw err;
+    }
+}
+
+export async function computeONNXImageEmbedding(
+    inputFilePath: string
 ): Promise<Float32Array> {
     try {
         const imageSession = await getOnnxImageSession();
@@ -170,6 +266,59 @@ export async function computeImageEmbedding(
 }
 
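+// Same two-backend parity check as computeImageEmbedding, for text queries.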
 export async function computeTextEmbedding(
+    text: string
+): Promise<Float32Array> {
+    const ggmlTextEmbedding = await computeGGMLTextEmbedding(text);
+    const onnxTextEmbedding = await computeONNXTextEmbedding(text);
+    const score = await computeClipMatchScore(
+        ggmlTextEmbedding,
+        onnxTextEmbedding
+    );
+    log.info('textEmbeddingScore', score);
+    return onnxTextEmbedding;
+}
+
+export async function computeGGMLTextEmbedding(
+    text: string
+): Promise<Float32Array> {
+    try {
+        const clipModelPath = await getClipTextModelPath('ggml');
+        const ggmlclipPath = getGGMLClipPath();
+        const cmd = TEXT_EMBEDDING_EXTRACT_CMD.map((cmdPart) => {
+            if (cmdPart === GGMLCLIP_PATH_PLACEHOLDER) {
+                return ggmlclipPath;
+            } else if (cmdPart === CLIP_MODEL_PATH_PLACEHOLDER) {
+                return clipModelPath;
+            } else if (cmdPart === INPUT_PATH_PLACEHOLDER) {
+                return text;
+            } else {
+                return cmdPart;
+            }
+        });
+
+        const escapedCmd = shellescape(cmd);
+        log.info('running clip command', escapedCmd);
+        const startTime = Date.now();
+        const { stdout } = await execAsync(escapedCmd);
+        log.info('clip command execution time', Date.now() - startTime);
+        // the embedding is the JSON array on the last line of stdout
+        const lines = stdout.split('\n');
+        const lastLine = lines[lines.length - 1];
+        const embedding = JSON.parse(lastLine);
+        return new Float32Array(embedding);
+    } catch (err) {
+        if (err.message === CustomErrors.MODEL_DOWNLOAD_PENDING) {
+            log.info(CustomErrors.MODEL_DOWNLOAD_PENDING);
+        } else {
+            logErrorSentry(err, 'Error in computeGGMLTextEmbedding');
+        }
+        throw err;
+    }
+}
+
+export async function computeONNXTextEmbedding(
     text: string
 ): Promise<Float32Array> {
     try {
@@ -239,3 +388,30 @@ async function getRgbData(inputFilePath: string) {
     }
     return Float32Array.from(rgbData.flat().flat());
 }
+
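+// Cosine similarity between an image embedding and a text embedding (also
+// used above to compare the GGML and ONNX backends). Note: both input arrays
+// are normalized in place as a side effect.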
+export const computeClipMatchScore = async (
+    imageEmbedding: Float32Array,
+    textEmbedding: Float32Array
+) => {
+    if (imageEmbedding.length !== textEmbedding.length) {
+        throw Error('imageEmbedding and textEmbedding length mismatch');
+    }
+    let score = 0;
+    let imageNormalization = 0;
+    let textNormalization = 0;
+
+    for (let index = 0; index < imageEmbedding.length; index++) {
+        imageNormalization += imageEmbedding[index] * imageEmbedding[index];
+        textNormalization += textEmbedding[index] * textEmbedding[index];
+    }
+    const imageNorm = Math.sqrt(imageNormalization);
+    const textNorm = Math.sqrt(textNormalization);
+    for (let index = 0; index < imageEmbedding.length; index++) {
+        imageEmbedding[index] /= imageNorm;
+        textEmbedding[index] /= textNorm;
+    }
+    for (let index = 0; index < imageEmbedding.length; index++) {
+        score += imageEmbedding[index] * textEmbedding[index];
+    }
+    return score;
+};
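+
+// Usage sketch (hypothetical caller, not part of this change): embeddings of
+// the same input from the two backends should score close to 1, e.g.
+//
+//     const score = await computeClipMatchScore(
+//         await computeGGMLImageEmbedding(filePath),
+//         await computeONNXImageEmbedding(filePath)
+//     );
+//     // a score close to 1 means the GGML and ONNX outputs agree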
|