Use
parent 46a53d5fdf
commit 926bc33c79
3 changed files with 65 additions and 151 deletions
@@ -18,34 +18,17 @@ import { deleteTempFile } from "./ffmpeg";
 import {
     createInferenceSession,
     downloadModel,
-    modelPathDownloadingIfNeeded,
+    makeCachedInferenceSession,
     modelSavePath,
 } from "./ml";

 const textModelName = "clip-text-vit-32-uint8.onnx";
 const textModelByteSize = 64173509; // 61.2 MB

-const imageModelName = "clip-image-vit-32-float32.onnx";
-const imageModelByteSize = 351468764; // 335.2 MB
-
-let activeImageModelDownload: Promise<string> | undefined;
-
-const imageModelPathDownloadingIfNeeded = async () => {
-    try {
-        if (activeImageModelDownload) {
-            log.info("Waiting for CLIP image model download to finish");
-            await activeImageModelDownload;
-        } else {
-            activeImageModelDownload = modelPathDownloadingIfNeeded(
-                imageModelName,
-                imageModelByteSize,
-            );
-            return await activeImageModelDownload;
-        }
-    } finally {
-        activeImageModelDownload = undefined;
-    }
-};
+const cachedCLIPImageSession = makeCachedInferenceSession(
+    "clip-image-vit-32-float32.onnx",
+    351468764 /* 335.2 MB */,
+);

 let textModelDownloadInProgress = false;
@@ -90,18 +73,6 @@ const textModelPathDownloadingIfNeeded = async () => {
     return modelPath;
 };

-let imageSessionPromise: Promise<any> | undefined;
-
-const onnxImageSession = async () => {
-    if (!imageSessionPromise) {
-        imageSessionPromise = (async () => {
-            const modelPath = await imageModelPathDownloadingIfNeeded();
-            return createInferenceSession(modelPath);
-        })();
-    }
-    return imageSessionPromise;
-};
-
 let _textSession: any = null;

 const onnxTextSession = async () => {
@@ -124,7 +95,7 @@ export const clipImageEmbedding = async (jpegImageData: Uint8Array) => {
 };

 const clipImageEmbedding_ = async (jpegFilePath: string) => {
-    const session = await onnxImageSession();
+    const session = await cachedCLIPImageSession();
     const t1 = Date.now();
     const rgbData = await getRGBData(jpegFilePath);
     const feeds = {
@@ -136,7 +107,8 @@ const clipImageEmbedding_ = async (jpegFilePath: string) => {
         () =>
             `onnx/clip image embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
     );
-    const imageEmbedding = results["output"].data; // Float32Array
+    /* Need these model specific casts to type the result */
+    const imageEmbedding = results["output"].data as Float32Array;
     return normalizeEmbedding(imageEmbedding);
 };
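Review note on the hunks above: in the removed imageModelPathDownloadingIfNeeded, the waiting branch only did "await activeImageModelDownload" and fell through without returning the path, so a caller that arrived while a download was in flight resolved to undefined. Memoizing the promise itself, as the replacement does, avoids that whole class of bug. A minimal self-contained sketch of the shape (the makeOnce helper and init function here are illustrative, not code from this commit):

    // Cache the promise, not the value: every caller, first or concurrent,
    // awaits (and resolves to the value of) the same in-flight promise.
    const makeOnce = <T>(init: () => Promise<T>): (() => Promise<T>) => {
        let cached: Promise<T> | undefined;
        return () => (cached ??= init());
    };

    // Both calls below share a single invocation of the expensive init.
    const getModelPath = makeOnce(async () => "/tmp/model.onnx"); // hypothetical path
    const [p1, p2] = await Promise.all([getModelPath(), getModelPath()]);
    // p1 === p2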
@@ -8,121 +8,20 @@
  */
 import * as ort from "onnxruntime-node";
 import log from "../log";
-import { createInferenceSession, modelPathDownloadingIfNeeded } from "./ml";
+import { makeCachedInferenceSession } from "./ml";

-const faceDetectionModelName = "yolov5s_face_640_640_dynamic.onnx";
-const faceDetectionModelByteSize = 30762872; // 29.3 MB
+const cachedFaceDetectionSession = makeCachedInferenceSession(
+    "yolov5s_face_640_640_dynamic.onnx",
+    30762872 /* 29.3 MB */,
+);

-const faceEmbeddingModelName = "mobilefacenet_opset15.onnx";
-const faceEmbeddingModelByteSize = 5286998; // 5 MB
-
-/**
- * Return a function that can be used to trigger a download of the specified
- * model, and the creation of an ONNX inference session initialized using it.
- *
- * Multiple parallel calls to the returned function are fine: it ensures that
- * the model will be downloaded and the session created using it only once.
- * All pending calls to it meanwhile will just await on the same promise.
- *
- * And once the promise is resolved, the created ONNX inference session will
- * be cached, so subsequent calls to the returned function will just reuse the
- * same session.
- *
- * {@link makeCachedInferenceSession} can itself be called anytime; it doesn't
- * actively trigger a download until the returned function is called.
- *
- * @param modelName The name of the model to download.
- * @param modelByteSize The size in bytes that we expect the model to have. If
- * the size of the downloaded model does not match the expected size, then we
- * will redownload it.
- *
- * @returns A function. Calling that function returns a promise to an ONNX
- * session.
- */
-const makeCachedInferenceSession = (
-    modelName: string,
-    modelByteSize: number,
-) => {
-    let session: Promise<ort.InferenceSession> | undefined;
-
-    const download = () =>
-        modelPathDownloadingIfNeeded(modelName, modelByteSize);
-
-    const createSession = (modelPath: string) =>
-        createInferenceSession(modelPath);
-
-    const cachedInferenceSession = () => {
-        if (!session) session = download().then(createSession);
-        return session;
-    };
-
-    return cachedInferenceSession;
-};
-
-let activeFaceDetectionModelDownload: Promise<string> | undefined;
-
-const faceDetectionModelPathDownloadingIfNeeded = async () => {
-    try {
-        if (activeFaceDetectionModelDownload) {
-            log.info("Waiting for face detection model download to finish");
-            await activeFaceDetectionModelDownload;
-        } else {
-            activeFaceDetectionModelDownload = modelPathDownloadingIfNeeded(
-                faceDetectionModelName,
-                faceDetectionModelByteSize,
-            );
-            return await activeFaceDetectionModelDownload;
-        }
-    } finally {
-        activeFaceDetectionModelDownload = undefined;
-    }
-};
-
-let _faceDetectionSession: Promise<ort.InferenceSession> | undefined;
-
-const faceDetectionSession = async () => {
-    if (!_faceDetectionSession) {
-        _faceDetectionSession =
-            faceDetectionModelPathDownloadingIfNeeded().then((modelPath) =>
-                createInferenceSession(modelPath),
-            );
-    }
-    return _faceDetectionSession;
-};
-
-let activeFaceEmbeddingModelDownload: Promise<string> | undefined;
-
-const faceEmbeddingModelPathDownloadingIfNeeded = async () => {
-    try {
-        if (activeFaceEmbeddingModelDownload) {
-            log.info("Waiting for face embedding model download to finish");
-            await activeFaceEmbeddingModelDownload;
-        } else {
-            activeFaceEmbeddingModelDownload = modelPathDownloadingIfNeeded(
-                faceEmbeddingModelName,
-                faceEmbeddingModelByteSize,
-            );
-            return await activeFaceEmbeddingModelDownload;
-        }
-    } finally {
-        activeFaceEmbeddingModelDownload = undefined;
-    }
-};
-
-let _faceEmbeddingSession: Promise<ort.InferenceSession> | undefined;
-
-const faceEmbeddingSession = async () => {
-    if (!_faceEmbeddingSession) {
-        _faceEmbeddingSession =
-            faceEmbeddingModelPathDownloadingIfNeeded().then((modelPath) =>
-                createInferenceSession(modelPath),
-            );
-    }
-    return _faceEmbeddingSession;
-};
+const cachedFaceEmbeddingSession = makeCachedInferenceSession(
+    "mobilefacenet_opset15.onnx",
+    5286998 /* 5 MB */,
+);

 export const detectFaces = async (input: Float32Array) => {
-    const session = await faceDetectionSession();
+    const session = await cachedFaceDetectionSession();
     const t = Date.now();
     const feeds = {
         input: new ort.Tensor("float32", input, [1, 3, 640, 640]),
@@ -141,11 +40,11 @@ export const faceEmbedding = async (input: Float32Array) => {
     const n = Math.round(input.length / (z * z * 3));
     const inputTensor = new ort.Tensor("float32", input, [n, z, z, 3]);

-    const session = await faceEmbeddingSession();
+    const session = await cachedFaceEmbeddingSession();
     const t = Date.now();
     const feeds = { img_inputs: inputTensor };
     const results = await session.run(feeds);
     log.debug(() => `onnx/yolo face embedding took ${Date.now() - t} ms`);
-    // TODO: What's with this type? It works in practice, but double check.
-    return (results.embeddings as unknown as any)["cpuData"]; // as Float32Array;
+    /* Need these model specific casts to extract and type the result */
+    return (results.embeddings as unknown as any)["cpuData"] as Float32Array;
 };
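For readers skimming the diff: in isolation, the session.run shape used by detectFaces and faceEmbedding above looks like the sketch below. The model path, input/output names, and dimensions are placeholders, not values from this commit.

    import * as ort from "onnxruntime-node";

    const demo = async () => {
        // Create a session from a local .onnx file, feed it a named float32
        // tensor, and read the data out of a named output tensor.
        const session = await ort.InferenceSession.create("/tmp/model.onnx");
        const input = new ort.Tensor(
            "float32",
            new Float32Array(1 * 3 * 640 * 640),
            [1, 3, 640, 640],
        );
        const results = await session.run({ input });
        return results["output"].data as Float32Array;
    };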
@@ -18,6 +18,49 @@ import * as ort from "onnxruntime-node";
 import log from "../log";
 import { writeStream } from "../stream";

+/**
+ * Return a function that can be used to trigger a download of the specified
+ * model, and the creation of an ONNX inference session initialized using it.
+ *
+ * Multiple parallel calls to the returned function are fine: it ensures that
+ * the model will be downloaded and the session created using it only once.
+ * All pending calls to it meanwhile will just await on the same promise.
+ *
+ * And once the promise is resolved, the created ONNX inference session will
+ * be cached, so subsequent calls to the returned function will just reuse the
+ * same session.
+ *
+ * {@link makeCachedInferenceSession} can itself be called anytime; it doesn't
+ * actively trigger a download until the returned function is called.
+ *
+ * @param modelName The name of the model to download.
+ * @param modelByteSize The size in bytes that we expect the model to have. If
+ * the size of the downloaded model does not match the expected size, then we
+ * will redownload it.
+ *
+ * @returns A function. Calling that function returns a promise to an ONNX
+ * session.
+ */
+export const makeCachedInferenceSession = (
+    modelName: string,
+    modelByteSize: number,
+) => {
+    let session: Promise<ort.InferenceSession> | undefined;
+
+    const download = () =>
+        modelPathDownloadingIfNeeded(modelName, modelByteSize);
+
+    const createSession = (modelPath: string) =>
+        createInferenceSession(modelPath);
+
+    const cachedInferenceSession = () => {
+        if (!session) session = download().then(createSession);
+        return session;
+    };
+
+    return cachedInferenceSession;
+};
+
 /**
  * Download the model named {@link modelName} if we don't already have it.
  *
@@ -26,7 +69,7 @@ import { writeStream } from "../stream";
  *
  * @returns the path to the model on the local machine.
  */
-export const modelPathDownloadingIfNeeded = async (
+const modelPathDownloadingIfNeeded = async (
     modelName: string,
     expectedByteSize: number,
 ) => {
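Putting the pieces together, a consumer of the newly exported helper follows the same pattern the CLIP and face modules adopt above. A hypothetical module would look like this, assuming the same ort and ./ml imports as those modules; the model name, byte size, and tensor shape are placeholders:

    const cachedSession = makeCachedInferenceSession(
        "some-model.onnx",
        123456 /* expected byte size */,
    );

    export const embed = async (input: Float32Array) => {
        // The first call downloads the model and creates the ONNX session;
        // concurrent and later calls reuse the same cached promise.
        const session = await cachedSession();
        const feeds = { input: new ort.Tensor("float32", input, [1, input.length]) };
        return await session.run(feeds);
    };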