소스 검색

In band signalling

Manav Rathi 1 년 전
부모
커밋
d0b1ff5520

+ 3 - 3
desktop/src/main/ipc.ts

@@ -45,7 +45,7 @@ import {
     convertToJPEG,
     generateImageThumbnail,
 } from "./services/imageProcessor";
-import { clipImageEmbedding, clipTextEmbedding } from "./services/ml-clip";
+import { clipImageEmbedding, clipTextEmbeddingIfAvailable } from "./services/ml-clip";
 import { detectFaces, faceEmbedding } from "./services/ml-face";
 import {
     clearStores,
@@ -169,8 +169,8 @@ export const attachIPCHandlers = () => {
         clipImageEmbedding(jpegImageData),
     );
 
-    ipcMain.handle("clipTextEmbedding", (_, text: string) =>
-        clipTextEmbedding(text),
+    ipcMain.handle("clipTextEmbeddingIfAvailable", (_, text: string) =>
+        clipTextEmbeddingIfAvailable(text),
     );
 
     ipcMain.handle("detectFaces", (_, input: Float32Array) =>

+ 20 - 73
desktop/src/main/services/ml-clip.ts

@@ -5,86 +5,21 @@
  *
  * @see `web/apps/photos/src/services/clip-service.ts` for more details.
  */
-import { existsSync } from "fs";
 import jpeg from "jpeg-js";
 import fs from "node:fs/promises";
 import * as ort from "onnxruntime-node";
 import Tokenizer from "../../thirdparty/clip-bpe-ts/mod";
-import { CustomErrors } from "../../types/ipc";
 import log from "../log";
 import { writeStream } from "../stream";
 import { generateTempFilePath } from "../temp";
 import { deleteTempFile } from "./ffmpeg";
-import {
-    createInferenceSession,
-    downloadModel,
-    makeCachedInferenceSession,
-    modelSavePath,
-} from "./ml";
+import { makeCachedInferenceSession } from "./ml";
 
 const cachedCLIPImageSession = makeCachedInferenceSession(
     "clip-image-vit-32-float32.onnx",
     351468764 /* 335.2 MB */,
 );
 
-const cachedCLIPTextSession = makeCachedInferenceSession(
-    "clip-text-vit-32-uint8.onnx",
-    64173509 /* 61.2 MB */,
-);
-
-let textModelDownloadInProgress = false;
-
-/* TODO(MR): use the generic method. Then we can remove the exports for the
-   internal details functions that we use here */
-const textModelPathDownloadingIfNeeded = async () => {
-    if (textModelDownloadInProgress)
-        throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
-
-    const modelPath = modelSavePath(textModelName);
-    if (!existsSync(modelPath)) {
-        log.info("CLIP text model not found, downloading");
-        textModelDownloadInProgress = true;
-        downloadModel(modelPath, textModelName)
-            .catch((e) => {
-                // log but otherwise ignore
-                log.error("CLIP text model download failed", e);
-            })
-            .finally(() => {
-                textModelDownloadInProgress = false;
-            });
-        throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
-    } else {
-        const localFileSize = (await fs.stat(modelPath)).size;
-        if (localFileSize !== textModelByteSize) {
-            log.error(
-                `CLIP text model size ${localFileSize} does not match the expected size, downloading again`,
-            );
-            textModelDownloadInProgress = true;
-            downloadModel(modelPath, textModelName)
-                .catch((e) => {
-                    // log but otherwise ignore
-                    log.error("CLIP text model download failed", e);
-                })
-                .finally(() => {
-                    textModelDownloadInProgress = false;
-                });
-            throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
-        }
-    }
-
-    return modelPath;
-};
-
-let _textSession: any = null;
-
-const onnxTextSession = async () => {
-    if (!_textSession) {
-        const modelPath = await textModelPathDownloadingIfNeeded();
-        _textSession = await createInferenceSession(modelPath);
-    }
-    return _textSession;
-};
-
 export const clipImageEmbedding = async (jpegImageData: Uint8Array) => {
     const tempFilePath = await generateTempFilePath("");
     const imageStream = new Response(jpegImageData.buffer).body;
@@ -195,6 +130,11 @@ const normalizeEmbedding = (embedding: Float32Array) => {
     return embedding;
 };
 
+const cachedCLIPTextSession = makeCachedInferenceSession(
+    "clip-text-vit-32-uint8.onnx",
+    64173509 /* 61.2 MB */,
+);
+
 let _tokenizer: Tokenizer = null;
 const getTokenizer = () => {
     if (!_tokenizer) {
@@ -203,14 +143,21 @@ const getTokenizer = () => {
     return _tokenizer;
 };
 
-export const clipTextEmbedding = async (text: string) => {
-    const session = await Promise.race([
+export const clipTextEmbeddingIfAvailable = async (text: string) => {
+    const sessionOrStatus = await Promise.race([
         cachedCLIPTextSession(),
-        new Promise<"downloading-model">((resolve) =>
-            setTimeout(() => resolve("downloading-model"), 100),
-        ),
+        "downloading-model",
     ]);
-    await onnxTextSession();
+
+    // Don't wait for the download to complete
+    if (typeof sessionOrStatus == "string") {
+        console.log(
+            "Ignoring CLIP text embedding request because model download is pending",
+        );
+        return undefined;
+    }
+
+    const session = sessionOrStatus;
     const t1 = Date.now();
     const tokenizer = getTokenizer();
     const tokenizedText = Int32Array.from(tokenizer.encodeForCLIP(text));
@@ -223,6 +170,6 @@ export const clipTextEmbedding = async (text: string) => {
         () =>
             `onnx/clip text embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
     );
-    const textEmbedding = results["output"].data;
+    const textEmbedding = results["output"].data as Float32Array;
     return normalizeEmbedding(textEmbedding);
 };

+ 5 - 5
desktop/src/main/services/ml-face.ts

@@ -15,11 +15,6 @@ const cachedFaceDetectionSession = makeCachedInferenceSession(
     30762872 /* 29.3 MB */,
 );
 
-const cachedFaceEmbeddingSession = makeCachedInferenceSession(
-    "mobilefacenet_opset15.onnx",
-    5286998 /* 5 MB */,
-);
-
 export const detectFaces = async (input: Float32Array) => {
     const session = await cachedFaceDetectionSession();
     const t = Date.now();
@@ -31,6 +26,11 @@ export const detectFaces = async (input: Float32Array) => {
     return results["output"].data;
 };
 
+const cachedFaceEmbeddingSession = makeCachedInferenceSession(
+    "mobilefacenet_opset15.onnx",
+    5286998 /* 5 MB */,
+);
+
 export const faceEmbedding = async (input: Float32Array) => {
     // Dimension of each face (alias)
     const mobileFaceNetFaceSize = 112;

+ 4 - 4
desktop/src/main/services/ml.ts

@@ -1,5 +1,5 @@
 /**
- * @file AI/ML related functionality.
+ * @file AI/ML related functionality, generic layer.
  *
  * @see also `ml-clip.ts`, `ml-face.ts`.
  *
@@ -92,10 +92,10 @@ const modelPathDownloadingIfNeeded = async (
 };
 
 /** Return the path where the given {@link modelName} is meant to be saved */
-export const modelSavePath = (modelName: string) =>
+const modelSavePath = (modelName: string) =>
     path.join(app.getPath("userData"), "models", modelName);
 
-export const downloadModel = async (saveLocation: string, name: string) => {
+const downloadModel = async (saveLocation: string, name: string) => {
     // `mkdir -p` the directory where we want to save the model.
     const saveDir = path.dirname(saveLocation);
     await fs.mkdir(saveDir, { recursive: true });
@@ -112,7 +112,7 @@ export const downloadModel = async (saveLocation: string, name: string) => {
 /**
  * Crete an ONNX {@link InferenceSession} with some defaults.
  */
-export const createInferenceSession = async (modelPath: string) => {
+const createInferenceSession = async (modelPath: string) => {
     return await ort.InferenceSession.create(modelPath, {
         // Restrict the number of threads to 1
         intraOpNumThreads: 1,

+ 60 - 39
desktop/src/preload.ts

@@ -163,8 +163,10 @@ const runFFmpegCmd = (
 const clipImageEmbedding = (jpegImageData: Uint8Array): Promise<Float32Array> =>
     ipcRenderer.invoke("clipImageEmbedding", jpegImageData);
 
-const clipTextEmbedding = (text: string): Promise<Float32Array> =>
-    ipcRenderer.invoke("clipTextEmbedding", text);
+const clipTextEmbeddingIfAvailable = (
+    text: string,
+): Promise<Float32Array | undefined> =>
+    ipcRenderer.invoke("clipTextEmbeddingIfAvailable", text);
 
 const detectFaces = (input: Float32Array): Promise<Float32Array> =>
     ipcRenderer.invoke("detectFaces", input);
@@ -263,42 +265,61 @@ const getElectronFilesFromGoogleZip = (
 const getDirFiles = (dirPath: string): Promise<ElectronFile[]> =>
     ipcRenderer.invoke("getDirFiles", dirPath);
 
-//
-// These objects exposed here will become available to the JS code in our
-// renderer (the web/ code) as `window.ElectronAPIs.*`
-//
-// There are a few related concepts at play here, and it might be worthwhile to
-// read their (excellent) documentation to get an understanding;
-//`
-// - ContextIsolation:
-//   https://www.electronjs.org/docs/latest/tutorial/context-isolation
-//
-// - IPC https://www.electronjs.org/docs/latest/tutorial/ipc
-//
-// [Note: Transferring large amount of data over IPC]
-//
-// Electron's IPC implementation uses the HTML standard Structured Clone
-// Algorithm to serialize objects passed between processes.
-// https://www.electronjs.org/docs/latest/tutorial/ipc#object-serialization
-//
-// In particular, ArrayBuffer is eligible for structured cloning.
-// https://developer.mozilla.org/en-US/docs/Web/API/Web_Workers_API/Structured_clone_algorithm
-//
-// Also, ArrayBuffer is "transferable", which means it is a zero-copy operation
-// operation when it happens across threads.
-// https://developer.mozilla.org/en-US/docs/Web/API/Web_Workers_API/Transferable_objects
-//
-// In our case though, we're not dealing with threads but separate processes. So
-// the ArrayBuffer will be copied:
-// > "parameters, errors and return values are **copied** when they're sent over
-//   the bridge".
-//   https://www.electronjs.org/docs/latest/api/context-bridge#methods
-//
-// The copy itself is relatively fast, but the problem with transfering large
-// amounts of data is potentially running out of memory during the copy.
-//
-// For an alternative, see [Note: IPC streams].
-//
+/**
+ * These objects exposed here will become available to the JS code in our
+ * renderer (the web/ code) as `window.ElectronAPIs.*`
+ *
+ * There are a few related concepts at play here, and it might be worthwhile to
+ * read their (excellent) documentation to get an understanding;
+ *`
+ * - ContextIsolation:
+ *   https://www.electronjs.org/docs/latest/tutorial/context-isolation
+ *
+ * - IPC https://www.electronjs.org/docs/latest/tutorial/ipc
+ *
+ * ---
+ *
+ * [Note: Custom errors across Electron/Renderer boundary]
+ *
+ * If we need to identify errors thrown by the main process when invoked from
+ * the renderer process, we can only use the `message` field because:
+ *
+ * > Errors thrown throw `handle` in the main process are not transparent as
+ * > they are serialized and only the `message` property from the original error
+ * > is provided to the renderer process.
+ * >
+ * > - https://www.electronjs.org/docs/latest/tutorial/ipc
+ * >
+ * > Ref: https://github.com/electron/electron/issues/24427
+ *
+ * ---
+ *
+ * [Note: Transferring large amount of data over IPC]
+ *
+ * Electron's IPC implementation uses the HTML standard Structured Clone
+ * Algorithm to serialize objects passed between processes.
+ * https://www.electronjs.org/docs/latest/tutorial/ipc#object-serialization
+ *
+ * In particular, ArrayBuffer is eligible for structured cloning.
+ * https://developer.mozilla.org/en-US/docs/Web/API/Web_Workers_API/Structured_clone_algorithm
+ *
+ * Also, ArrayBuffer is "transferable", which means it is a zero-copy operation
+ * operation when it happens across threads.
+ * https://developer.mozilla.org/en-US/docs/Web/API/Web_Workers_API/Transferable_objects
+ *
+ * In our case though, we're not dealing with threads but separate processes. So
+ * the ArrayBuffer will be copied:
+ *
+ * > "parameters, errors and return values are **copied** when they're sent over
+ * > the bridge".
+ * >
+ * > https://www.electronjs.org/docs/latest/api/context-bridge#methods
+ *
+ * The copy itself is relatively fast, but the problem with transfering large
+ * amounts of data is potentially running out of memory during the copy.
+ *
+ * For an alternative, see [Note: IPC streams].
+ */
 contextBridge.exposeInMainWorld("electron", {
     // - General
 
@@ -340,7 +361,7 @@ contextBridge.exposeInMainWorld("electron", {
     // - ML
 
     clipImageEmbedding,
-    clipTextEmbedding,
+    clipTextEmbeddingIfAvailable,
     detectFaces,
     faceEmbedding,
 

+ 0 - 15
desktop/src/types/ipc.ts

@@ -33,25 +33,10 @@ export interface PendingUploads {
 
 /**
  * Errors that have special semantics on the web side.
- *
- * [Note: Custom errors across Electron/Renderer boundary]
- *
- * We need to use the `message` field to disambiguate between errors thrown by
- * the main process when invoked from the renderer process. This is because:
- *
- * > Errors thrown throw `handle` in the main process are not transparent as
- * > they are serialized and only the `message` property from the original error
- * > is provided to the renderer process.
- * >
- * > - https://www.electronjs.org/docs/latest/tutorial/ipc
- * >
- * > Ref: https://github.com/electron/electron/issues/24427
  */
 export const CustomErrors = {
     WINDOWS_NATIVE_IMAGE_PROCESSING_NOT_SUPPORTED:
         "Windows native image processing is not supported",
-    MODEL_DOWNLOAD_PENDING:
-        "Model download pending, skipping clip search request",
 };
 
 /**

+ 2 - 2
web/apps/photos/src/services/clip-service.ts

@@ -184,8 +184,8 @@ class CLIPService {
         }
     };
 
-    getTextEmbedding = async (text: string) => {
-        return ensureElectron().clipTextEmbedding(text);
+    getTextEmbeddingIfAvailable = async (text: string) => {
+        return ensureElectron().clipTextEmbeddingIfAvailable(text);
     };
 
     private runClipEmbeddingExtraction = async (canceller: AbortController) => {

+ 20 - 20
web/apps/photos/src/services/searchService.ts

@@ -1,5 +1,4 @@
 import log from "@/next/log";
-import { CustomError } from "@ente/shared/error";
 import * as chrono from "chrono-node";
 import { FILE_TYPE } from "constants/file";
 import { t } from "i18next";
@@ -287,24 +286,20 @@ async function getLocationSuggestions(searchPhrase: string) {
     return [...locationTagSuggestions, ...citySearchSuggestions];
 }
 
-async function getClipSuggestion(searchPhrase: string): Promise<Suggestion> {
-    try {
-        if (!clipService.isPlatformSupported()) {
-            return null;
-        }
-
-        const clipResults = await searchClip(searchPhrase);
-        return {
-            type: SuggestionType.CLIP,
-            value: clipResults,
-            label: searchPhrase,
-        };
-    } catch (e) {
-        if (!e.message?.includes(CustomError.MODEL_DOWNLOAD_PENDING)) {
-            log.error("getClipSuggestion failed", e);
-        }
+async function getClipSuggestion(
+    searchPhrase: string,
+): Promise<Suggestion | undefined> {
+    if (!clipService.isPlatformSupported()) {
         return null;
     }
+
+    const clipResults = await searchClip(searchPhrase);
+    if (!clipResults) return clipResults;
+    return {
+        type: SuggestionType.CLIP,
+        value: clipResults,
+        label: searchPhrase,
+    };
 }
 
 function searchCollection(
@@ -374,9 +369,14 @@ async function searchLocationTag(searchPhrase: string): Promise<LocationTag[]> {
     return matchedLocationTags;
 }
 
-async function searchClip(searchPhrase: string): Promise<ClipSearchScores> {
+const searchClip = async (
+    searchPhrase: string,
+): Promise<ClipSearchScores | undefined> => {
+    const textEmbedding =
+        await clipService.getTextEmbeddingIfAvailable(searchPhrase);
+    if (!textEmbedding) return undefined;
+
     const imageEmbeddings = await getLocalEmbeddings();
-    const textEmbedding = await clipService.getTextEmbedding(searchPhrase);
     const clipSearchResult = new Map<number, number>(
         (
             await Promise.all(
@@ -394,7 +394,7 @@ async function searchClip(searchPhrase: string): Promise<ClipSearchScores> {
     );
 
     return clipSearchResult;
-}
+};
 
 function convertSuggestionToSearchQuery(option: Suggestion): Search {
     switch (option.type) {

+ 15 - 2
web/packages/next/types/ipc.ts

@@ -240,7 +240,18 @@ export interface Electron {
     clipImageEmbedding: (jpegImageData: Uint8Array) => Promise<Float32Array>;
 
     /**
-     * Return a CLIP embedding of the given image.
+     * Return a CLIP embedding of the given image if we already have the model
+     * downloaded and prepped. If the model is not available return `undefined`.
+     *
+     * This differs from the other sibling ML functions in that it doesn't wait
+     * for the model download to finish. It does trigger a model download, but
+     * then immediately returns `undefined`. At some future point, when the
+     * model downloaded finishes, calls to this function will start returning
+     * the result we seek.
+     *
+     * The reason for doing it in this asymmetric way is because CLIP text
+     * embeddings are used as part of deducing user initiated search results,
+     * and we don't want to block that interaction on a large network request.
      *
      * See: [Note: CLIP based magic search]
      *
@@ -248,7 +259,9 @@ export interface Electron {
      *
      * @returns A CLIP embedding.
      */
-    clipTextEmbedding: (text: string) => Promise<Float32Array>;
+    clipTextEmbeddingIfAvailable: (
+        text: string,
+    ) => Promise<Float32Array | undefined>;
 
     /**
      * Detect faces in the given image using YOLO.

+ 0 - 2
web/packages/shared/error/index.ts

@@ -84,8 +84,6 @@ export const CustomError = {
     ServerError: "server error",
     FILE_NOT_FOUND: "file not found",
     UNSUPPORTED_PLATFORM: "Unsupported platform",
-    MODEL_DOWNLOAD_PENDING:
-        "Model download pending, skipping clip search request",
     UPDATE_URL_FILE_ID_MISMATCH: "update url file id mismatch",
     URL_ALREADY_SET: "url already set",
     FILE_CONVERSION_FAILED: "file conversion failed",