Merge pull request #304 from ente-io/onnx-clip

ONNX clip
Abhinav Kumar 2024-01-16 11:47:18 +05:30 committed by GitHub
commit 7bd908c142
13 changed files with 819 additions and 43 deletions


@ -1,7 +1,7 @@
{
"name": "ente",
"productName": "ente",
"version": "1.6.60",
"version": "1.6.61-alpha.6",
"private": true,
"description": "Desktop client for ente.io",
"main": "app/main.js",
@ -134,9 +134,12 @@
"electron-updater": "^4.3.8",
"ffmpeg-static": "^5.1.0",
"get-folder-size": "^2.0.1",
"html-entities": "^2.4.0",
"jpeg-js": "^0.4.4",
"next-electron-server": "file:./thirdparty/next-electron-server",
"node-fetch": "^2.6.7",
"node-stream-zip": "^1.15.0",
"onnxruntime-node": "^1.16.3",
"promise-fs": "^2.1.1"
},
"standard": {


@ -2,8 +2,10 @@ import { ipcRenderer } from 'electron';
import { writeStream } from '../services/fs';
import { isExecError } from '../utils/error';
import { parseExecError } from '../utils/error';
import { Model } from '../types';
export async function computeImageEmbedding(
model: Model,
imageData: Uint8Array
): Promise<Float32Array> {
let tempInputFilePath = null;
@ -13,6 +15,7 @@ export async function computeImageEmbedding(
await writeStream(tempInputFilePath, imageStream);
const embedding = await ipcRenderer.invoke(
'compute-image-embedding',
model,
tempInputFilePath
);
return embedding;
@ -31,11 +34,13 @@ export async function computeImageEmbedding(
}
export async function computeTextEmbedding(
model: Model,
text: string
): Promise<Float32Array> {
try {
const embedding = await ipcRenderer.invoke(
'compute-text-embedding',
model,
text
);
return embedding;


@ -7,4 +7,6 @@ export const CustomErrors = {
`Unsupported platform - ${platform} ${arch}`,
MODEL_DOWNLOAD_PENDING:
'Model download pending, skipping clip search request',
INVALID_FILE_PATH: 'Invalid file path',
INVALID_CLIP_MODEL: (model: string) => `Invalid Clip model - ${model}`,
};


@ -12,6 +12,7 @@ import fetch from 'node-fetch';
import { writeNodeStream } from './fs';
import { getPlatform } from '../utils/common/platform';
import { CustomErrors } from '../constants/errors';
const jpeg = require('jpeg-js');
const CLIP_MODEL_PATH_PLACEHOLDER = 'CLIP_MODEL';
const GGMLCLIP_PATH_PLACEHOLDER = 'GGML_PATH';
@ -32,17 +33,38 @@ const TEXT_EMBEDDING_EXTRACT_CMD: string[] = [
'--text',
INPUT_PATH_PLACEHOLDER,
];
const ort = require('onnxruntime-node');
import Tokenizer from '../utils/clip-bpe-ts/mod';
import { readFile } from 'promise-fs';
import { Model } from '../types';
const TEXT_MODEL_DOWNLOAD_URL =
'https://models.ente.io/clip-vit-base-patch32_ggml-text-model-f16.gguf';
const IMAGE_MODEL_DOWNLOAD_URL =
'https://models.ente.io/clip-vit-base-patch32_ggml-vision-model-f16.gguf';
const TEXT_MODEL_DOWNLOAD_URL = {
ggml: 'https://models.ente.io/clip-vit-base-patch32_ggml-text-model-f16.gguf',
onnx: 'https://models.ente.io/clip-text-vit-32-uint8.onnx',
};
const IMAGE_MODEL_DOWNLOAD_URL = {
ggml: 'https://models.ente.io/clip-vit-base-patch32_ggml-vision-model-f16.gguf',
onnx: 'https://models.ente.io/clip-image-vit-32-float32.onnx',
};
const TEXT_MODEL_NAME = 'clip-vit-base-patch32_ggml-text-model-f16.gguf';
const IMAGE_MODEL_NAME = 'clip-vit-base-patch32_ggml-vision-model-f16.gguf';
const TEXT_MODEL_NAME = {
ggml: 'clip-vit-base-patch32_ggml-text-model-f16.gguf',
onnx: 'clip-text-vit-32-uint8.onnx',
};
const IMAGE_MODEL_NAME = {
ggml: 'clip-vit-base-patch32_ggml-vision-model-f16.gguf',
onnx: 'clip-image-vit-32-float32.onnx',
};
const IMAGE_MODEL_SIZE_IN_BYTES = {
ggml: 175957504, // 167.8 MB
onnx: 351468764, // 335.2 MB
};
const TEXT_MODEL_SIZE_IN_BYTES = {
ggml: 127853440, // 121.9 MB,
onnx: 64173509, // 61.2 MB
};
const IMAGE_MODEL_SIZE_IN_BYTES = 175957504; // 167.8 MB
const TEXT_MODEL_SIZE_IN_BYTES = 127853440; // 121.9 MB
const MODEL_SAVE_FOLDER = 'models';
function getModelSavePath(modelName: string) {
@ -64,15 +86,15 @@ async function downloadModel(saveLocation: string, url: string) {
}
log.info('downloading clip model');
const resp = await fetch(url);
await writeNodeStream(saveLocation, resp.body, true);
await writeNodeStream(saveLocation, resp.body);
log.info('clip model downloaded');
}
let imageModelDownloadInProgress: Promise<void> = null;
export async function getClipImageModelPath() {
export async function getClipImageModelPath(type: 'ggml' | 'onnx') {
try {
const modelSavePath = getModelSavePath(IMAGE_MODEL_NAME);
const modelSavePath = getModelSavePath(IMAGE_MODEL_NAME[type]);
if (imageModelDownloadInProgress) {
log.info('waiting for image model download to finish');
await imageModelDownloadInProgress;
@ -81,16 +103,19 @@ export async function getClipImageModelPath() {
log.info('clip image model not found, downloading');
imageModelDownloadInProgress = downloadModel(
modelSavePath,
IMAGE_MODEL_DOWNLOAD_URL
IMAGE_MODEL_DOWNLOAD_URL[type]
);
await imageModelDownloadInProgress;
} else {
const localFileSize = (await fs.stat(modelSavePath)).size;
if (localFileSize !== IMAGE_MODEL_SIZE_IN_BYTES) {
log.info('clip model size mismatch, downloading again');
if (localFileSize !== IMAGE_MODEL_SIZE_IN_BYTES[type]) {
log.info(
'clip image model size mismatch, downloading again got:',
localFileSize
);
imageModelDownloadInProgress = downloadModel(
modelSavePath,
IMAGE_MODEL_DOWNLOAD_URL
IMAGE_MODEL_DOWNLOAD_URL[type]
);
await imageModelDownloadInProgress;
}
@ -104,15 +129,15 @@ export async function getClipImageModelPath() {
let textModelDownloadInProgress: boolean = false;
export async function getClipTextModelPath() {
const modelSavePath = getModelSavePath(TEXT_MODEL_NAME);
export async function getClipTextModelPath(type: 'ggml' | 'onnx') {
const modelSavePath = getModelSavePath(TEXT_MODEL_NAME[type]);
if (textModelDownloadInProgress) {
throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
} else {
if (!existsSync(modelSavePath)) {
log.info('clip text model not found, downloading');
textModelDownloadInProgress = true;
downloadModel(modelSavePath, TEXT_MODEL_DOWNLOAD_URL)
downloadModel(modelSavePath, TEXT_MODEL_DOWNLOAD_URL[type])
.catch(() => {
// ignore
})
@ -122,10 +147,13 @@ export async function getClipTextModelPath() {
throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
} else {
const localFileSize = (await fs.stat(modelSavePath)).size;
if (localFileSize !== TEXT_MODEL_SIZE_IN_BYTES) {
log.info('clip model size mismatch, downloading again');
if (localFileSize !== TEXT_MODEL_SIZE_IN_BYTES[type]) {
log.info(
'clip text model size mismatch, downloading again got:',
localFileSize
);
textModelDownloadInProgress = true;
downloadModel(modelSavePath, TEXT_MODEL_DOWNLOAD_URL)
downloadModel(modelSavePath, TEXT_MODEL_DOWNLOAD_URL[type])
.catch(() => {
// ignore
})
@ -145,11 +173,64 @@ function getGGMLClipPath() {
: path.join(process.resourcesPath, `ggmlclip-${getPlatform()}`);
}
async function createOnnxSession(modelPath: string) {
return await ort.InferenceSession.create(modelPath, {
intraOpNumThreads: 1,
enableCpuMemArena: false,
});
}
let onnxImageSessionPromise: Promise<any> = null;
async function getOnnxImageSession() {
if (!onnxImageSessionPromise) {
onnxImageSessionPromise = (async () => {
const clipModelPath = await getClipImageModelPath('onnx');
return createOnnxSession(clipModelPath);
})();
}
return onnxImageSessionPromise;
}
let onnxTextSession: any = null;
async function getOnnxTextSession() {
if (!onnxTextSession) {
const clipModelPath = await getClipTextModelPath('onnx');
onnxTextSession = await createOnnxSession(clipModelPath);
}
return onnxTextSession;
}
let tokenizer: Tokenizer = null;
function getTokenizer() {
if (!tokenizer) {
tokenizer = new Tokenizer();
}
return tokenizer;
}
export async function computeImageEmbedding(
model: Model,
inputFilePath: string
): Promise<Float32Array> {
if (!existsSync(inputFilePath)) {
throw Error(CustomErrors.INVALID_FILE_PATH);
}
if (model === Model.GGML_CLIP) {
return await computeGGMLImageEmbedding(inputFilePath);
} else if (model === Model.ONNX_CLIP) {
return await computeONNXImageEmbedding(inputFilePath);
} else {
throw Error(CustomErrors.INVALID_CLIP_MODEL(model));
}
}
export async function computeGGMLImageEmbedding(
inputFilePath: string
): Promise<Float32Array> {
try {
const clipModelPath = await getClipImageModelPath();
const clipModelPath = await getClipImageModelPath('ggml');
const ggmlclipPath = getGGMLClipPath();
const cmd = IMAGE_EMBEDDING_EXTRACT_CMD.map((cmdPart) => {
if (cmdPart === GGMLCLIP_PATH_PLACEHOLDER) {
@ -176,16 +257,45 @@ export async function computeImageEmbedding(
const embeddingArray = new Float32Array(embedding);
return embeddingArray;
} catch (err) {
logErrorSentry(err, 'Error in computeImageEmbedding');
logErrorSentry(err, 'Error in computeGGMLImageEmbedding');
throw err;
}
}
export async function computeONNXImageEmbedding(
inputFilePath: string
): Promise<Float32Array> {
try {
const imageSession = await getOnnxImageSession();
const rgbData = await getRGBData(inputFilePath);
const feeds = {
input: new ort.Tensor('float32', rgbData, [1, 3, 224, 224]),
};
const results = await imageSession.run(feeds);
const imageEmbedding = results['output'].data; // Float32Array
return normalizeEmbedding(imageEmbedding);
} catch (err) {
logErrorSentry(err, 'Error in computeONNXImageEmbedding');
throw err;
}
}
export async function computeTextEmbedding(
model: Model,
text: string
): Promise<Float32Array> {
if (model === Model.GGML_CLIP) {
return await computeGGMLTextEmbedding(text);
} else {
return await computeONNXTextEmbedding(text);
}
}
export async function computeGGMLTextEmbedding(
text: string
): Promise<Float32Array> {
try {
const clipModelPath = await getClipTextModelPath();
const clipModelPath = await getClipTextModelPath('ggml');
const ggmlclipPath = getGGMLClipPath();
const cmd = TEXT_EMBEDDING_EXTRACT_CMD.map((cmdPart) => {
if (cmdPart === GGMLCLIP_PATH_PLACEHOLDER) {
@ -215,8 +325,131 @@ export async function computeTextEmbedding(
if (err.message === CustomErrors.MODEL_DOWNLOAD_PENDING) {
log.info(CustomErrors.MODEL_DOWNLOAD_PENDING);
} else {
logErrorSentry(err, 'Error in computeTextEmbedding');
logErrorSentry(err, 'Error in computeGGMLTextEmbedding');
}
throw err;
}
}
export async function computeONNXTextEmbedding(
text: string
): Promise<Float32Array> {
try {
const imageSession = await getOnnxTextSession();
const tokenizer = getTokenizer();
const tokenizedText = Int32Array.from(tokenizer.encodeForCLIP(text));
const feeds = {
input: new ort.Tensor('int32', tokenizedText, [1, 77]),
};
const results = await imageSession.run(feeds);
const textEmbedding = results['output'].data; // Float32Array
return normalizeEmbedding(textEmbedding);
} catch (err) {
if (err.message === CustomErrors.MODEL_DOWNLOAD_PENDING) {
log.info(CustomErrors.MODEL_DOWNLOAD_PENDING);
} else {
logErrorSentry(err, 'Error in computeONNXTextEmbedding');
}
throw err;
}
}
async function getRGBData(inputFilePath: string) {
const jpegData = await readFile(inputFilePath);
let rawImageData;
try {
rawImageData = jpeg.decode(jpegData, {
useTArray: true,
formatAsRGBA: false,
});
} catch (err) {
logErrorSentry(err, 'JPEG decode error');
throw err;
}
const nx: number = rawImageData.width;
const ny: number = rawImageData.height;
const inputImage: Uint8Array = rawImageData.data;
const nx2: number = 224;
const ny2: number = 224;
const totalSize: number = 3 * nx2 * ny2;
const result: number[] = Array(totalSize).fill(0);
const scale: number = Math.max(nx, ny) / 224;
const nx3: number = Math.round(nx / scale);
const ny3: number = Math.round(ny / scale);
const mean: number[] = [0.48145466, 0.4578275, 0.40821073];
const std: number[] = [0.26862954, 0.26130258, 0.27577711];
for (let y = 0; y < ny3; y++) {
for (let x = 0; x < nx3; x++) {
for (let c = 0; c < 3; c++) {
// linear interpolation
const sx: number = (x + 0.5) * scale - 0.5;
const sy: number = (y + 0.5) * scale - 0.5;
const x0: number = Math.max(0, Math.floor(sx));
const y0: number = Math.max(0, Math.floor(sy));
const x1: number = Math.min(x0 + 1, nx - 1);
const y1: number = Math.min(y0 + 1, ny - 1);
const dx: number = sx - x0;
const dy: number = sy - y0;
const j00: number = 3 * (y0 * nx + x0) + c;
const j01: number = 3 * (y0 * nx + x1) + c;
const j10: number = 3 * (y1 * nx + x0) + c;
const j11: number = 3 * (y1 * nx + x1) + c;
const v00: number = inputImage[j00];
const v01: number = inputImage[j01];
const v10: number = inputImage[j10];
const v11: number = inputImage[j11];
const v0: number = v00 * (1 - dx) + v01 * dx;
const v1: number = v10 * (1 - dx) + v11 * dx;
const v: number = v0 * (1 - dy) + v1 * dy;
const v2: number = Math.min(Math.max(Math.round(v), 0), 255);
// createTensorWithDataList is dumb compared to reshape and hence has to be given one channel after another
const i: number = y * nx3 + x + (c % 3) * 224 * 224;
result[i] = (v2 / 255 - mean[c]) / std[c];
}
}
}
return result;
}
export const computeClipMatchScore = async (
imageEmbedding: Float32Array,
textEmbedding: Float32Array
) => {
if (imageEmbedding.length !== textEmbedding.length) {
throw Error('imageEmbedding and textEmbedding length mismatch');
}
let score = 0;
for (let index = 0; index < imageEmbedding.length; index++) {
score += imageEmbedding[index] * textEmbedding[index];
}
return score;
};
export const normalizeEmbedding = (embedding: Float32Array) => {
let normalization = 0;
for (let index = 0; index < embedding.length; index++) {
normalization += embedding[index] * embedding[index];
}
const sqrtNormalization = Math.sqrt(normalization);
for (let index = 0; index < embedding.length; index++) {
embedding[index] = embedding[index] / sqrtNormalization;
}
return embedding;
};
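Taken together, the new exports can be exercised along these lines. This is a hypothetical sketch, not part of the diff: the `./clipService` module path and the sample inputs are assumptions, and because both embeddings are L2-normalized by `normalizeEmbedding`, `computeClipMatchScore` reduces to their cosine similarity.

```js
// Hypothetical usage sketch (main process). Module path and sample inputs are assumed.
import {
    computeImageEmbedding,
    computeTextEmbedding,
    computeClipMatchScore,
} from './clipService';
import { Model } from '../types';

async function scorePhotoAgainstQuery(jpegPath, query) {
    const imageEmbedding = await computeImageEmbedding(Model.ONNX_CLIP, jpegPath);
    const textEmbedding = await computeTextEmbedding(Model.ONNX_CLIP, query);
    // Dot product of two unit vectors, i.e. cosine similarity; higher means a closer match.
    return await computeClipMatchScore(imageEmbedding, textEmbedding);
}
```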


@ -6,8 +6,6 @@ import StreamZip from 'node-stream-zip';
import { Readable } from 'stream';
import { logError } from './logging';
import { existsSync } from 'fs';
import { log } from 'electron-log';
import { convertBytesToHumanReadable } from '../utils/logging';
// https://stackoverflow.com/a/63111390
export const getDirFilePaths = async (dirPath: string) => {
@ -230,8 +228,7 @@ export const convertBrowserStreamToNode = (
export async function writeNodeStream(
filePath: string,
fileStream: NodeJS.ReadableStream,
enableLogging = false
fileStream: NodeJS.ReadableStream
) {
const writeable = fs.createWriteStream(filePath);
@ -241,14 +238,6 @@ export async function writeNodeStream(
fileStream.pipe(writeable);
let downloaded = 0;
if (enableLogging) {
fileStream.on('data', (chunk) => {
downloaded += chunk.length;
log(`Received ${convertBytesToHumanReadable(downloaded)} of data.`);
});
}
await new Promise((resolve, reject) => {
writeable.on('finish', resolve);
writeable.on('error', async (e) => {


@ -70,3 +70,8 @@ export interface AppUpdateInfo {
export interface GetFeatureFlagResponse {
desktopCutoffVersion?: string;
}
export enum Model {
GGML_CLIP = 'ggml-clip',
ONNX_CLIP = 'onnx-clip',
}


@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 josephrocca
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


@ -0,0 +1,26 @@
# CLIP Byte Pair Encoding JavaScript Port
A JavaScript port of [OpenAI's CLIP byte-pair-encoding tokenizer](https://github.com/openai/CLIP/blob/3bee28119e6b28e75b82b811b87b56935314e6a5/clip/simple_tokenizer.py).
```js
import Tokenizer from "https://deno.land/x/clip_bpe@v0.0.6/mod.js";
let t = new Tokenizer();
t.encode("hello") // [3306]
t.encode("magnificent") // [10724]
t.encode("magnificently") // [9725, 2922]
t.decode(t.encode("HELLO")) // "hello "
t.decode(t.encode("abc123")) // "abc 1 2 3 "
t.decode(t.encode("let's see here")) // "let 's see here "
t.encode("hello world!") // [3306, 1002, 256]
// to encode for CLIP (trims to maximum of 77 tokens and adds start and end token, and pads with zeros if less than 77 tokens):
t.encodeForCLIP("hello world!") // [49406,3306,1002,256,49407,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]
```
This encoder/decoder behaves differently to the GPT-2/3 tokenizer (JavaScript version of that [here](https://github.com/latitudegames/GPT-3-Encoder)). For example, it doesn't preserve capital letters, as shown above.
The [Python version](https://github.com/openai/CLIP/blob/3bee28119e6b28e75b82b811b87b56935314e6a5/clip/simple_tokenizer.py) of this tokenizer uses the `ftfy` module to clean up the text before encoding it. I didn't include that module by default because currently the only version available in JavaScript is [this one](https://github.com/josephrocca/ftfy-pyodide), which requires importing a full Python runtime as a WebAssembly module. If you want the `ftfy` cleaning, just import it and clean your text with it before passing it to the `.encode()` method.
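If you do want that cleaning step, a minimal sketch could look like the following (the import URL and the `fix_text` call mirror the commented-out lines in this port's `mod.ts`; treat the exact ftfy API surface as an assumption):

```js
// Sketch: run ftfy over the text before tokenizing (pulls in a Python runtime as WebAssembly).
import ftfy from "https://deno.land/x/ftfy_pyodide@v0.1.1/mod.js";
import Tokenizer from "https://deno.land/x/clip_bpe@v0.0.6/mod.js";

const t = new Tokenizer();
const cleaned = ftfy.fix_text("let's see here");
t.encodeForCLIP(cleaned); // 77 token ids: start token, text, end token, zero padding
```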
# License
To the extent that there is any original work in this repo, it is MIT Licensed, just like [openai/CLIP](https://github.com/openai/CLIP).

File diff suppressed because one or more lines are too long


@ -0,0 +1,466 @@
import * as htmlEntities from 'html-entities';
import bpeVocabData from './bpe_simple_vocab_16e6';
// import ftfy from "https://deno.land/x/ftfy_pyodide@v0.1.1/mod.js";
function ord(c: string) {
return c.charCodeAt(0);
}
function range(start: number, stop?: number, step: number = 1) {
if (stop === undefined) {
stop = start;
start = 0;
}
if ((step > 0 && start >= stop) || (step < 0 && start <= stop)) {
return [];
}
const result: number[] = [];
for (let i = start; step > 0 ? i < stop : i > stop; i += step) {
result.push(i);
}
return result;
}
function bytesToUnicode() {
const bs = [
...range(ord('!'), ord('~') + 1),
...range(ord('¡'), ord('¬') + 1),
...range(ord('®'), ord('ÿ') + 1),
];
const cs = bs.slice(0);
let n = 0;
for (const b of range(2 ** 8)) {
if (!bs.includes(b)) {
bs.push(b);
cs.push(2 ** 8 + n);
n += 1;
}
}
const csString = cs.map((n) => String.fromCharCode(n));
return Object.fromEntries(bs.map((v, i) => [v, csString[i]]));
}
function getPairs(word: string | any[]) {
const pairs: [string, string][] = [];
let prevChar = word[0];
for (const char of word.slice(1)) {
pairs.push([prevChar, char]);
prevChar = char;
}
return pairs;
}
function basicClean(text: string) {
// text = ftfy.fix_text(text);
text = htmlEntities.decode(htmlEntities.decode(text));
return text.trim();
}
function whitespaceClean(text: string) {
return text.replace(/\s+/g, ' ').trim();
}
export default class {
byteEncoder;
byteDecoder: {
[k: string]: number;
};
encoder;
decoder: any;
bpeRanks: any;
cache: Record<string, string>;
pat: RegExp;
constructor() {
this.byteEncoder = bytesToUnicode();
this.byteDecoder = Object.fromEntries(
Object.entries(this.byteEncoder).map(([k, v]) => [v, Number(k)])
);
let merges = bpeVocabData.text.split('\n');
merges = merges.slice(1, 49152 - 256 - 2 + 1);
const mergedMerges = merges.map((merge) => merge.split(' '));
// There was a bug related to the ordering of Python's .values() output. I'm lazy so I've just copy-pasted the Python output:
let vocab = [
'!',
'"',
'#',
'$',
'%',
'&',
"'",
'(',
')',
'*',
'+',
',',
'-',
'.',
'/',
'0',
'1',
'2',
'3',
'4',
'5',
'6',
'7',
'8',
'9',
':',
';',
'<',
'=',
'>',
'?',
'@',
'A',
'B',
'C',
'D',
'E',
'F',
'G',
'H',
'I',
'J',
'K',
'L',
'M',
'N',
'O',
'P',
'Q',
'R',
'S',
'T',
'U',
'V',
'W',
'X',
'Y',
'Z',
'[',
'\\',
']',
'^',
'_',
'`',
'a',
'b',
'c',
'd',
'e',
'f',
'g',
'h',
'i',
'j',
'k',
'l',
'm',
'n',
'o',
'p',
'q',
'r',
's',
't',
'u',
'v',
'w',
'x',
'y',
'z',
'{',
'|',
'}',
'~',
'¡',
'¢',
'£',
'¤',
'¥',
'¦',
'§',
'¨',
'©',
'ª',
'«',
'¬',
'®',
'¯',
'°',
'±',
'²',
'³',
'´',
'µ',
'¶',
'·',
'¸',
'¹',
'º',
'»',
'¼',
'½',
'¾',
'¿',
'À',
'Á',
'Â',
'Ã',
'Ä',
'Å',
'Æ',
'Ç',
'È',
'É',
'Ê',
'Ë',
'Ì',
'Í',
'Î',
'Ï',
'Ð',
'Ñ',
'Ò',
'Ó',
'Ô',
'Õ',
'Ö',
'×',
'Ø',
'Ù',
'Ú',
'Û',
'Ü',
'Ý',
'Þ',
'ß',
'à',
'á',
'â',
'ã',
'ä',
'å',
'æ',
'ç',
'è',
'é',
'ê',
'ë',
'ì',
'í',
'î',
'ï',
'ð',
'ñ',
'ò',
'ó',
'ô',
'õ',
'ö',
'÷',
'ø',
'ù',
'ú',
'û',
'ü',
'ý',
'þ',
'ÿ',
'Ā',
'ā',
'Ă',
'ă',
'Ą',
'ą',
'Ć',
'ć',
'Ĉ',
'ĉ',
'Ċ',
'ċ',
'Č',
'č',
'Ď',
'ď',
'Đ',
'đ',
'Ē',
'ē',
'Ĕ',
'ĕ',
'Ė',
'ė',
'Ę',
'ę',
'Ě',
'ě',
'Ĝ',
'ĝ',
'Ğ',
'ğ',
'Ġ',
'ġ',
'Ģ',
'ģ',
'Ĥ',
'ĥ',
'Ħ',
'ħ',
'Ĩ',
'ĩ',
'Ī',
'ī',
'Ĭ',
'ĭ',
'Į',
'į',
'İ',
'ı',
'IJ',
'ij',
'Ĵ',
'ĵ',
'Ķ',
'ķ',
'ĸ',
'Ĺ',
'ĺ',
'Ļ',
'ļ',
'Ľ',
'ľ',
'Ŀ',
'ŀ',
'Ł',
'ł',
'Ń',
];
vocab = [...vocab, ...vocab.map((v) => v + '</w>')];
for (const merge of mergedMerges) {
vocab.push(merge.join(''));
}
vocab.push('<|startoftext|>', '<|endoftext|>');
this.encoder = Object.fromEntries(vocab.map((v, i) => [v, i]));
this.decoder = Object.fromEntries(
Object.entries(this.encoder).map(([k, v]) => [v, k])
);
this.bpeRanks = Object.fromEntries(
mergedMerges.map((v, i) => [v.join('·😎·'), i])
); // ·😎· because js doesn't yet have tuples
this.cache = {
'<|startoftext|>': '<|startoftext|>',
'<|endoftext|>': '<|endoftext|>',
};
this.pat =
/<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+/giu;
}
bpe(token: string) {
if (this.cache[token] !== undefined) {
return this.cache[token];
}
let word = [...token.slice(0, -1), token.slice(-1) + '</w>'];
let pairs = getPairs(word);
if (pairs.length === 0) {
return token + '</w>';
}
// eslint-disable-next-line no-constant-condition
while (1) {
let bigram: [string, string] | null = null;
let minRank = Infinity;
for (const p of pairs) {
const r = this.bpeRanks[p.join('·😎·')];
if (r === undefined) continue;
if (r < minRank) {
minRank = r;
bigram = p;
}
}
if (bigram === null) {
break;
}
const [first, second] = bigram;
const newWord: string[] = [];
let i = 0;
while (i < word.length) {
const j = word.indexOf(first, i);
if (j === -1) {
newWord.push(...word.slice(i));
break;
}
newWord.push(...word.slice(i, j));
i = j;
if (
word[i] === first &&
i < word.length - 1 &&
word[i + 1] === second
) {
newWord.push(first + second);
i += 2;
} else {
newWord.push(word[i]);
i += 1;
}
}
word = newWord;
if (word.length === 1) {
break;
} else {
pairs = getPairs(word);
}
}
const joinedWord = word.join(' ');
this.cache[token] = joinedWord;
return joinedWord;
}
encode(text: string) {
const bpeTokens: number[] = [];
text = whitespaceClean(basicClean(text)).toLowerCase();
for (let token of [...text.matchAll(this.pat)].map((m) => m[0])) {
token = [...token]
.map((b) => this.byteEncoder[b.charCodeAt(0) as number])
.join('');
bpeTokens.push(
...this.bpe(token)
.split(' ')
.map((bpeToken: string) => this.encoder[bpeToken])
);
}
return bpeTokens;
}
// adds start and end token, and adds padding 0's and ensures it's 77 tokens long
encodeForCLIP(text: string) {
let tokens = this.encode(text);
tokens.unshift(49406); // start token
tokens = tokens.slice(0, 76);
tokens.push(49407); // end token
while (tokens.length < 77) tokens.push(0);
return tokens;
}
decode(tokens: any[]) {
let text = tokens
.map((token: string | number) => this.decoder[token])
.join('');
text = [...text]
.map((c) => this.byteDecoder[c])
.map((v) => String.fromCharCode(v))
.join('')
.replace(/<\/w>/g, ' ');
return text;
}
}


@ -175,11 +175,11 @@ export default function setupIpcComs(
setOptOutOfCrashReports(optOut);
updateOptOutOfCrashReports(optOut);
});
ipcMain.handle('compute-image-embedding', (_, inputFilePath) => {
return computeImageEmbedding(inputFilePath);
ipcMain.handle('compute-image-embedding', (_, model, inputFilePath) => {
return computeImageEmbedding(model, inputFilePath);
});
ipcMain.handle('compute-text-embedding', (_, text) => {
return computeTextEmbedding(text);
ipcMain.handle('compute-text-embedding', (_, model, text) => {
return computeTextEmbedding(model, text);
});
ipcMain.handle('get-platform', () => {
return getPlatform();

ui

@ -1 +1 @@
Subproject commit 018666cbe1e7a73119933785d9cc889541349bcf
Subproject commit 6837225281247174db0fdc463e2341cccdc081f6


@ -2102,6 +2102,11 @@ hosted-git-info@^4.1.0:
dependencies:
lru-cache "^6.0.0"
html-entities@^2.4.0:
version "2.4.0"
resolved "https://registry.yarnpkg.com/html-entities/-/html-entities-2.4.0.tgz#edd0cee70402584c8c76cc2c0556db09d1f45061"
integrity sha512-igBTJcNNNhvZFRtm8uA6xMY6xYleeDwn3PeBCkDz7tHttv4F2hsDI2aPgNERWzvRcNYHNT3ymRaQzllmXj4YsQ==
http-cache-semantics@^4.0.0:
version "4.1.0"
resolved "https://registry.yarnpkg.com/http-cache-semantics/-/http-cache-semantics-4.1.0.tgz#49e91c5cbf36c9b94bcfcd71c23d5249ec74e390"
@ -2339,6 +2344,11 @@ jake@^10.8.5:
filelist "^1.0.1"
minimatch "^3.0.4"
jpeg-js@^0.4.4:
version "0.4.4"
resolved "https://registry.yarnpkg.com/jpeg-js/-/jpeg-js-0.4.4.tgz#a9f1c6f1f9f0fa80cdb3484ed9635054d28936aa"
integrity sha512-WZzeDOEtTOBK4Mdsar0IqEU5sMr3vSV2RqkAIzUEV2BHnUfKGyswWFPFwK5EeDo93K3FohSHbLAjj0s1Wzd+dg==
js-tokens@^4.0.0:
version "4.0.0"
resolved "https://registry.yarnpkg.com/js-tokens/-/js-tokens-4.0.0.tgz#19203fb59991df98e3a287050d4647cdeaf32499"
@ -2831,6 +2841,18 @@ onetime@^6.0.0:
dependencies:
mimic-fn "^4.0.0"
onnxruntime-common@~1.16.3:
version "1.16.3"
resolved "https://registry.yarnpkg.com/onnxruntime-common/-/onnxruntime-common-1.16.3.tgz#216bd1318d171496f1e92906a801c95bd2fb1aaa"
integrity sha512-ZZfFzEqBf6YIGwB9PtBLESHI53jMXA+/hn+ACVUbEfPuK2xI5vMGpLPn+idpwCmHsKJNRzRwqV12K+6TQj6tug==
onnxruntime-node@^1.16.3:
version "1.16.3"
resolved "https://registry.yarnpkg.com/onnxruntime-node/-/onnxruntime-node-1.16.3.tgz#8530439f4a513b17e4d3df0073f54c4614a46070"
integrity sha512-6T2pjwg5ik74VnI1IXFzxvPAm2UCo+vNNsDGbMP+A2q6GZPMYai2pMA17g3YMUvgOZLwsjWBUwNIlP4QaVRFlA==
dependencies:
onnxruntime-common "~1.16.3"
optionator@^0.9.1:
version "0.9.1"
resolved "https://registry.yarnpkg.com/optionator/-/optionator-0.9.1.tgz#4f236a6373dae0566a6d43e1326674f50c291499"