Merge remote-tracking branch 'origin/mobile_face' into mobile_face

This commit is contained in:
laurenspriem 2024-04-12 15:11:15 +05:30
commit 21451efa6b
90 changed files with 701 additions and 17309 deletions

View file

@ -1,8 +1,8 @@
# Dependencies
* [Electron](#electron)
* [Dev dependencies](#dev)
* [Functionality](#functionality)
- [Electron](#electron)
- [Dev dependencies](#dev)
- [Functionality](#functionality)
## Electron
@ -114,8 +114,8 @@ available on the host machine, and is not bundled with our app.
AI/ML runtime. It powers both natural language searches (using CLIP) and face
detection (using YOLO).
[jpeg-js](https://github.com/jpeg-js/jpeg-js#readme) is used for decoding
JPEG data into raw RGB bytes before passing it to ONNX.
[jpeg-js](https://github.com/jpeg-js/jpeg-js#readme) is used for decoding JPEG
data into raw RGB bytes before passing it to ONNX.
html-entities is used by the bundled clip-bpe-ts tokenizer for CLIP.
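To illustrate the decode step mentioned above, a JPEG decoded by jpeg-js could be repacked into a float tensor for onnxruntime-node roughly as follows. This is a minimal sketch: the resizing, channel order, and normalization a particular model expects are assumptions here, and jpeg-js returns RGBA, so the alpha channel has to be dropped.

```ts
import jpeg from "jpeg-js";
import * as ort from "onnxruntime-node";

/** Decode JPEG bytes and pack them into a channels-first float32 ONNX tensor. */
const jpegToTensor = (jpegData: Uint8Array): ort.Tensor => {
    // jpeg-js returns interleaved RGBA bytes.
    const { width, height, data } = jpeg.decode(jpegData, { useTArray: true });
    const pixels = width * height;
    const chw = new Float32Array(3 * pixels);
    for (let i = 0; i < pixels; i++) {
        // Drop alpha, scale to [0, 1], and convert interleaved RGBA into
        // planar RGB (NCHW layout).
        chw[i] = data[4 * i] / 255;
        chw[pixels + i] = data[4 * i + 1] / 255;
        chw[2 * pixels + i] = data[4 * i + 2] / 255;
    }
    return new ort.Tensor("float32", chw, [1, 3, height, width]);
};
```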

View file

@ -36,13 +36,14 @@ import {
updateAndRestart,
updateOnNextRestart,
} from "./services/app-update";
import { clipImageEmbedding, clipTextEmbedding } from "./services/clip";
import { runFFmpegCmd } from "./services/ffmpeg";
import { getDirFiles } from "./services/fs";
import {
convertToJPEG,
generateImageThumbnail,
} from "./services/imageProcessor";
import { clipImageEmbedding, clipTextEmbedding } from "./services/ml-clip";
import { detectFaces, faceEmbedding } from "./services/ml-face";
import {
clearStores,
encryptionKey,
@ -146,6 +147,14 @@ export const attachIPCHandlers = () => {
clipTextEmbedding(text),
);
ipcMain.handle("detectFaces", (_, input: Float32Array) =>
detectFaces(input),
);
ipcMain.handle("faceEmbedding", (_, input: Float32Array) =>
faceEmbedding(input),
);
// - File selection
ipcMain.handle("selectDirectory", () => selectDirectory());

View file

@ -1,19 +1,13 @@
/**
* @file Compute CLIP embeddings
* @file Compute CLIP embeddings for images and text.
*
* @see `web/apps/photos/src/services/clip-service.ts` for more details. This
* file implements the Node.js implementation of the actual embedding
* computation. By doing it in the Node.js layer, we can use the binary ONNX
* runtimes which are 10-20x faster than the WASM based web ones.
* The embeddings are computed using ONNX runtime, with CLIP as the model.
*
* The embeddings are computed using ONNX runtime. The model itself is not
* shipped with the app but is downloaded on demand.
* @see `web/apps/photos/src/services/clip-service.ts` for more details.
*/
import { app, net } from "electron/main";
import { existsSync } from "fs";
import jpeg from "jpeg-js";
import fs from "node:fs/promises";
import path from "node:path";
import * as ort from "onnxruntime-node";
import Tokenizer from "../../thirdparty/clip-bpe-ts/mod";
import { CustomErrors } from "../../types/ipc";
@ -21,6 +15,12 @@ import { writeStream } from "../fs";
import log from "../log";
import { generateTempFilePath } from "../temp";
import { deleteTempFile } from "./ffmpeg";
import {
createInferenceSession,
downloadModel,
modelPathDownloadingIfNeeded,
modelSavePath,
} from "./ml";
const textModelName = "clip-text-vit-32-uint8.onnx";
const textModelByteSize = 64173509; // 61.2 MB
@ -28,55 +28,20 @@ const textModelByteSize = 64173509; // 61.2 MB
const imageModelName = "clip-image-vit-32-float32.onnx";
const imageModelByteSize = 351468764; // 335.2 MB
/** Return the path where the given {@link modelName} is meant to be saved */
const modelSavePath = (modelName: string) =>
path.join(app.getPath("userData"), "models", modelName);
const downloadModel = async (saveLocation: string, name: string) => {
// `mkdir -p` the directory where we want to save the model.
const saveDir = path.dirname(saveLocation);
await fs.mkdir(saveDir, { recursive: true });
// Download
log.info(`Downloading CLIP model from ${name}`);
const url = `https://models.ente.io/${name}`;
const res = await net.fetch(url);
if (!res.ok) throw new Error(`Failed to fetch ${url}: HTTP ${res.status}`);
// Save
await writeStream(saveLocation, res.body);
log.info(`Downloaded CLIP model ${name}`);
};
let activeImageModelDownload: Promise<void> | undefined;
let activeImageModelDownload: Promise<string> | undefined;
const imageModelPathDownloadingIfNeeded = async () => {
try {
const modelPath = modelSavePath(imageModelName);
if (activeImageModelDownload) {
log.info("Waiting for CLIP image model download to finish");
await activeImageModelDownload;
} else {
if (!existsSync(modelPath)) {
log.info("CLIP image model not found, downloading");
activeImageModelDownload = downloadModel(
modelPath,
imageModelName,
);
await activeImageModelDownload;
} else {
const localFileSize = (await fs.stat(modelPath)).size;
if (localFileSize !== imageModelByteSize) {
log.error(
`CLIP image model size ${localFileSize} does not match the expected size, downloading again`,
);
activeImageModelDownload = downloadModel(
modelPath,
imageModelName,
);
await activeImageModelDownload;
}
}
activeImageModelDownload = modelPathDownloadingIfNeeded(
imageModelName,
imageModelByteSize,
);
return await activeImageModelDownload;
}
return modelPath;
} finally {
activeImageModelDownload = undefined;
}
@ -84,6 +49,8 @@ const imageModelPathDownloadingIfNeeded = async () => {
let textModelDownloadInProgress = false;
/* TODO(MR): use the generic method. Then we can remove the exports for the
internal details functions that we use here */
const textModelPathDownloadingIfNeeded = async () => {
if (textModelDownloadInProgress)
throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
@ -123,13 +90,6 @@ const textModelPathDownloadingIfNeeded = async () => {
return modelPath;
};
const createInferenceSession = async (modelPath: string) => {
return await ort.InferenceSession.create(modelPath, {
intraOpNumThreads: 1,
enableCpuMemArena: false,
});
};
let imageSessionPromise: Promise<any> | undefined;
const onnxImageSession = async () => {
@ -174,7 +134,7 @@ const clipImageEmbedding_ = async (jpegFilePath: string) => {
const results = await imageSession.run(feeds);
log.debug(
() =>
`CLIP image embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
`onnx/clip image embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
);
const imageEmbedding = results["output"].data; // Float32Array
return normalizeEmbedding(imageEmbedding);
@ -281,7 +241,7 @@ export const clipTextEmbedding = async (text: string) => {
const results = await imageSession.run(feeds);
log.debug(
() =>
`CLIP text embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
`onnx/clip text embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
);
const textEmbedding = results["output"].data;
return normalizeEmbedding(textEmbedding);
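Since both `clipImageEmbedding` and `clipTextEmbedding` pass their results through `normalizeEmbedding`, the cosine similarity used for semantic search reduces to a plain dot product. A minimal sketch of that scoring step (the actual search and thresholding logic lives in the web app's `clip-service.ts`, not in this file):

```ts
// Both arguments are assumed to be L2-normalized CLIP embeddings of the same
// length, as returned by clipImageEmbedding / clipTextEmbedding.
const clipSimilarity = (
    imageEmbedding: Float32Array,
    textEmbedding: Float32Array,
): number => {
    let score = 0;
    for (let i = 0; i < imageEmbedding.length; i++)
        score += imageEmbedding[i] * textEmbedding[i];
    return score; // in [-1, 1]; higher means a closer match
};
```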

View file

@ -0,0 +1,108 @@
/**
* @file Various face recognition related tasks.
*
* - Face detection with the YOLO model.
* - Face embedding with the MobileFaceNet model.
*
* The runtime used is ONNX.
*/
import * as ort from "onnxruntime-node";
import log from "../log";
import { createInferenceSession, modelPathDownloadingIfNeeded } from "./ml";
const faceDetectionModelName = "yolov5s_face_640_640_dynamic.onnx";
const faceDetectionModelByteSize = 30762872; // 29.3 MB
const faceEmbeddingModelName = "mobilefacenet_opset15.onnx";
const faceEmbeddingModelByteSize = 5286998; // 5 MB
let activeFaceDetectionModelDownload: Promise<string> | undefined;
const faceDetectionModelPathDownloadingIfNeeded = async () => {
try {
if (activeFaceDetectionModelDownload) {
log.info("Waiting for face detection model download to finish");
await activeFaceDetectionModelDownload;
} else {
activeFaceDetectionModelDownload = modelPathDownloadingIfNeeded(
faceDetectionModelName,
faceDetectionModelByteSize,
);
return await activeFaceDetectionModelDownload;
}
} finally {
activeFaceDetectionModelDownload = undefined;
}
};
let _faceDetectionSession: Promise<ort.InferenceSession> | undefined;
const faceDetectionSession = async () => {
if (!_faceDetectionSession) {
_faceDetectionSession =
faceDetectionModelPathDownloadingIfNeeded().then((modelPath) =>
createInferenceSession(modelPath),
);
}
return _faceDetectionSession;
};
let activeFaceEmbeddingModelDownload: Promise<string> | undefined;
const faceEmbeddingModelPathDownloadingIfNeeded = async () => {
try {
if (activeFaceEmbeddingModelDownload) {
log.info("Waiting for face embedding model download to finish");
await activeFaceEmbeddingModelDownload;
} else {
activeFaceEmbeddingModelDownload = modelPathDownloadingIfNeeded(
faceEmbeddingModelName,
faceEmbeddingModelByteSize,
);
return await activeFaceEmbeddingModelDownload;
}
} finally {
activeFaceEmbeddingModelDownload = undefined;
}
};
let _faceEmbeddingSession: Promise<ort.InferenceSession> | undefined;
const faceEmbeddingSession = async () => {
if (!_faceEmbeddingSession) {
_faceEmbeddingSession =
faceEmbeddingModelPathDownloadingIfNeeded().then((modelPath) =>
createInferenceSession(modelPath),
);
}
return _faceEmbeddingSession;
};
export const detectFaces = async (input: Float32Array) => {
const session = await faceDetectionSession();
const t = Date.now();
const feeds = {
input: new ort.Tensor("float32", input, [1, 3, 640, 640]),
};
const results = await session.run(feeds);
log.debug(() => `onnx/yolo face detection took ${Date.now() - t} ms`);
return results["output"].data;
};
export const faceEmbedding = async (input: Float32Array) => {
// Dimension (width and height) of each face crop.
const mobileFaceNetFaceSize = 112;
// A smaller alias for convenience.
const z = mobileFaceNetFaceSize;
// Number of faces in the batch.
const n = Math.round(input.length / (z * z * 3));
const inputTensor = new ort.Tensor("float32", input, [n, z, z, 3]);
const session = await faceEmbeddingSession();
const t = Date.now();
const feeds = { img_inputs: inputTensor };
const results = await session.run(feeds);
log.debug(() => `onnx/mobilefacenet face embedding took ${Date.now() - t} ms`);
// TODO: What's with this type? It works in practice, but double check.
return (results.embeddings as unknown as any)["cpuData"]; // as Float32Array;
};
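A hypothetical end-to-end use of these two functions, mainly to make the expected input layouts explicit. The real callers sit in the renderer and reach here over IPC (using the same import that `main.ts` uses), and the preprocessing that produces these arrays happens in the web layer and is not shown here.

```ts
import { detectFaces, faceEmbedding } from "./services/ml-face";

const runFacePipeline = async (
    // One 640×640 RGB image, channels-first: a [1, 3, 640, 640] tensor
    // flattened into a Float32Array.
    yoloInput: Float32Array,
    // A batch of n 112×112 RGB face crops, channels-last: [n, 112, 112, 3].
    faceCrops: Float32Array,
) => {
    const detections = await detectFaces(yoloInput);
    const embeddings = await faceEmbedding(faceCrops);
    return { detections, embeddings };
};
```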

View file

@ -0,0 +1,79 @@
/**
* @file AI/ML related functionality.
*
* @see also `ml-clip.ts`, `ml-face.ts`.
*
* The ML runtime we use for inference is [ONNX](https://onnxruntime.ai). Models
* for various tasks are not shipped with the app but are downloaded on demand.
*
* The primary reason for doing these tasks in the Node.js layer is so that we
* can use the binary ONNX runtime which is 10-20x faster than the WASM based
* web one.
*/
import { app, net } from "electron/main";
import { existsSync } from "fs";
import fs from "node:fs/promises";
import path from "node:path";
import * as ort from "onnxruntime-node";
import { writeStream } from "../fs";
import log from "../log";
/**
* Download the model named {@link modelName} if we don't already have it.
*
* Also verify that the size of the model we get matches {@link expectedByteSize} (if
* not, redownload it).
*
* @returns the path to the model on the local machine.
*/
export const modelPathDownloadingIfNeeded = async (
modelName: string,
expectedByteSize: number,
) => {
const modelPath = modelSavePath(modelName);
if (!existsSync(modelPath)) {
log.info("CLIP image model not found, downloading");
await downloadModel(modelPath, modelName);
} else {
const size = (await fs.stat(modelPath)).size;
if (size !== expectedByteSize) {
log.error(
`The size ${size} of model ${modelName} does not match the expected size, downloading again`,
);
await downloadModel(modelPath, modelName);
}
}
return modelPath;
};
/** Return the path where the given {@link modelName} is meant to be saved */
export const modelSavePath = (modelName: string) =>
path.join(app.getPath("userData"), "models", modelName);
export const downloadModel = async (saveLocation: string, name: string) => {
// `mkdir -p` the directory where we want to save the model.
const saveDir = path.dirname(saveLocation);
await fs.mkdir(saveDir, { recursive: true });
// Download
log.info(`Downloading ML model ${name}`);
const url = `https://models.ente.io/${name}`;
const res = await net.fetch(url);
if (!res.ok) throw new Error(`Failed to fetch ${url}: HTTP ${res.status}`);
// Save
await writeStream(saveLocation, res.body);
log.info(`Downloaded ML model ${name}`);
};
/**
* Create an ONNX {@link InferenceSession} with some defaults.
*/
export const createInferenceSession = async (modelPath: string) => {
return await ort.InferenceSession.create(modelPath, {
// Restrict the number of threads to 1
intraOpNumThreads: 1,
// Be more conservative with RAM usage
enableCpuMemArena: false,
});
};
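Taken together, the intended usage of this module (the same pattern that `ml-clip.ts` and `ml-face.ts` above follow) looks roughly like the sketch below; the model name and byte size are made-up placeholders.

```ts
// Memoize the session at module level so the model is downloaded and the
// inference session created at most once per app run.
let _session: Promise<ort.InferenceSession> | undefined;

const exampleModelSession = () => {
    if (!_session) {
        _session = modelPathDownloadingIfNeeded(
            "some-model.onnx", // hypothetical model name
            123456789, // hypothetical expected byte size
        ).then((modelPath) => createInferenceSession(modelPath));
    }
    return _session;
};
```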

View file

@ -143,6 +143,12 @@ const clipImageEmbedding = (jpegImageData: Uint8Array): Promise<Float32Array> =>
const clipTextEmbedding = (text: string): Promise<Float32Array> =>
ipcRenderer.invoke("clipTextEmbedding", text);
const detectFaces = (input: Float32Array): Promise<Float32Array> =>
ipcRenderer.invoke("detectFaces", input);
const faceEmbedding = (input: Float32Array): Promise<Float32Array> =>
ipcRenderer.invoke("faceEmbedding", input);
// - File selection
// TODO: Deprecated - use dialogs on the renderer process itself
@ -322,6 +328,8 @@ contextBridge.exposeInMainWorld("electron", {
// - ML
clipImageEmbedding,
clipTextEmbedding,
detectFaces,
faceEmbedding,
// - File selection
selectDirectory,

View file

@ -9,6 +9,8 @@ PODS:
- connectivity_plus (0.0.1):
- Flutter
- ReachabilitySwift
- dart_ui_isolate (0.0.1):
- Flutter
- device_info_plus (0.0.1):
- Flutter
- file_saver (0.0.1):
@ -18,7 +20,7 @@ PODS:
- Firebase/Messaging (10.22.0):
- Firebase/CoreOnly
- FirebaseMessaging (~> 10.22.0)
- firebase_core (2.27.0):
- firebase_core (2.29.0):
- Firebase/CoreOnly (= 10.22.0)
- Flutter
- firebase_messaging (14.7.19):
@ -62,8 +64,6 @@ PODS:
- flutter_inappwebview/Core (0.0.1):
- Flutter
- OrderedSet (~> 5.0)
- flutter_isolate (0.0.1):
- Flutter
- flutter_local_notifications (0.0.1):
- Flutter
- flutter_native_splash (0.0.1):
@ -235,6 +235,7 @@ DEPENDENCIES:
- battery_info (from `.symlinks/plugins/battery_info/ios`)
- bonsoir_darwin (from `.symlinks/plugins/bonsoir_darwin/darwin`)
- connectivity_plus (from `.symlinks/plugins/connectivity_plus/ios`)
- dart_ui_isolate (from `.symlinks/plugins/dart_ui_isolate/ios`)
- device_info_plus (from `.symlinks/plugins/device_info_plus/ios`)
- file_saver (from `.symlinks/plugins/file_saver/ios`)
- firebase_core (from `.symlinks/plugins/firebase_core/ios`)
@ -244,7 +245,6 @@ DEPENDENCIES:
- flutter_email_sender (from `.symlinks/plugins/flutter_email_sender/ios`)
- flutter_image_compress (from `.symlinks/plugins/flutter_image_compress/ios`)
- flutter_inappwebview (from `.symlinks/plugins/flutter_inappwebview/ios`)
- flutter_isolate (from `.symlinks/plugins/flutter_isolate/ios`)
- flutter_local_notifications (from `.symlinks/plugins/flutter_local_notifications/ios`)
- flutter_native_splash (from `.symlinks/plugins/flutter_native_splash/ios`)
- flutter_secure_storage (from `.symlinks/plugins/flutter_secure_storage/ios`)
@ -276,7 +276,6 @@ DEPENDENCIES:
- shared_preferences_foundation (from `.symlinks/plugins/shared_preferences_foundation/darwin`)
- sqflite (from `.symlinks/plugins/sqflite/darwin`)
- sqlite3_flutter_libs (from `.symlinks/plugins/sqlite3_flutter_libs/ios`)
- tflite_flutter (from `.symlinks/plugins/tflite_flutter/ios`)
- uni_links (from `.symlinks/plugins/uni_links/ios`)
- url_launcher_ios (from `.symlinks/plugins/url_launcher_ios/ios`)
- video_player_avfoundation (from `.symlinks/plugins/video_player_avfoundation/darwin`)
@ -306,8 +305,6 @@ SPEC REPOS:
- Sentry
- SentryPrivate
- sqlite3
- TensorFlowLiteC
- TensorFlowLiteSwift
- Toast
EXTERNAL SOURCES:
@ -319,6 +316,8 @@ EXTERNAL SOURCES:
:path: ".symlinks/plugins/bonsoir_darwin/darwin"
connectivity_plus:
:path: ".symlinks/plugins/connectivity_plus/ios"
dart_ui_isolate:
:path: ".symlinks/plugins/dart_ui_isolate/ios"
device_info_plus:
:path: ".symlinks/plugins/device_info_plus/ios"
file_saver:
@ -337,8 +336,6 @@ EXTERNAL SOURCES:
:path: ".symlinks/plugins/flutter_image_compress/ios"
flutter_inappwebview:
:path: ".symlinks/plugins/flutter_inappwebview/ios"
flutter_isolate:
:path: ".symlinks/plugins/flutter_isolate/ios"
flutter_local_notifications:
:path: ".symlinks/plugins/flutter_local_notifications/ios"
flutter_native_splash:
@ -401,8 +398,6 @@ EXTERNAL SOURCES:
:path: ".symlinks/plugins/sqflite/darwin"
sqlite3_flutter_libs:
:path: ".symlinks/plugins/sqlite3_flutter_libs/ios"
tflite_flutter:
:path: ".symlinks/plugins/tflite_flutter/ios"
uni_links:
:path: ".symlinks/plugins/uni_links/ios"
url_launcher_ios:
@ -421,10 +416,11 @@ SPEC CHECKSUMS:
battery_info: 09f5c9ee65394f2291c8c6227bedff345b8a730c
bonsoir_darwin: 127bdc632fdc154ae2f277a4d5c86a6212bc75be
connectivity_plus: 07c49e96d7fc92bc9920617b83238c4d178b446a
dart_ui_isolate: d5bcda83ca4b04f129d70eb90110b7a567aece14
device_info_plus: c6fb39579d0f423935b0c9ce7ee2f44b71b9fce6
file_saver: 503e386464dbe118f630e17b4c2e1190fa0cf808
Firebase: 797fd7297b7e1be954432743a0b3f90038e45a71
firebase_core: 100945864b4aedce3cfef0c62ab864858bf013cf
firebase_core: aaadbddb3cb2ee3792b9804f9dbb63e5f6f7b55c
firebase_messaging: e65050bf9b187511d80ea3a4de7cf5573d2c7543
FirebaseCore: 0326ec9b05fbed8f8716cddbf0e36894a13837f7
FirebaseCoreInternal: bcb5acffd4ea05e12a783ecf835f2210ce3dc6af
@ -435,8 +431,7 @@ SPEC CHECKSUMS:
flutter_email_sender: 02d7443217d8c41483223627972bfdc09f74276b
flutter_image_compress: 5a5e9aee05b6553048b8df1c3bc456d0afaac433
flutter_inappwebview: 3d32228f1304635e7c028b0d4252937730bbc6cf
flutter_isolate: 0edf5081826d071adf21759d1eb10ff5c24503b5
flutter_local_notifications: 0c0b1ae97e741e1521e4c1629a459d04b9aec743
flutter_local_notifications: 4cde75091f6327eb8517fa068a0a5950212d2086
flutter_native_splash: 52501b97d1c0a5f898d687f1646226c1f93c56ef
flutter_secure_storage: 23fc622d89d073675f2eaa109381aefbcf5a49be
flutter_sodium: c84426b4de738514b5b66cfdeb8a06634e72fe0b
@ -477,7 +472,7 @@ SPEC CHECKSUMS:
Sentry: ebc12276bd17613a114ab359074096b6b3725203
sentry_flutter: 88ebea3f595b0bc16acc5bedacafe6d60c12dcd5
SentryPrivate: d651efb234cf385ec9a1cdd3eff94b5e78a0e0fe
share_plus: 056a1e8ac890df3e33cb503afffaf1e9b4fbae68
share_plus: 8875f4f2500512ea181eef553c3e27dba5135aad
shared_preferences_foundation: b4c3b4cddf1c21f02770737f147a3f5da9d39695
sqflite: 673a0e54cc04b7d6dba8d24fb8095b31c3a99eec
sqlite3: 73b7fc691fdc43277614250e04d183740cb15078
@ -492,4 +487,4 @@ SPEC CHECKSUMS:
PODFILE CHECKSUM: c1a8f198a245ed1f10e40b617efdb129b021b225
COCOAPODS: 1.15.2
COCOAPODS: 1.14.3

View file

@ -192,7 +192,7 @@
isa = PBXProject;
attributes = {
LastSwiftUpdateCheck = 1520;
LastUpgradeCheck = 1430;
LastUpgradeCheck = 1510;
ORGANIZATIONNAME = "The Chromium Authors";
TargetAttributes = {
97C146ED1CF9000F007C117D = {
@ -295,13 +295,13 @@
"${BUILT_PRODUCTS_DIR}/battery_info/battery_info.framework",
"${BUILT_PRODUCTS_DIR}/bonsoir_darwin/bonsoir_darwin.framework",
"${BUILT_PRODUCTS_DIR}/connectivity_plus/connectivity_plus.framework",
"${BUILT_PRODUCTS_DIR}/dart_ui_isolate/dart_ui_isolate.framework",
"${BUILT_PRODUCTS_DIR}/device_info_plus/device_info_plus.framework",
"${BUILT_PRODUCTS_DIR}/file_saver/file_saver.framework",
"${BUILT_PRODUCTS_DIR}/fk_user_agent/fk_user_agent.framework",
"${BUILT_PRODUCTS_DIR}/flutter_email_sender/flutter_email_sender.framework",
"${BUILT_PRODUCTS_DIR}/flutter_image_compress/flutter_image_compress.framework",
"${BUILT_PRODUCTS_DIR}/flutter_inappwebview/flutter_inappwebview.framework",
"${BUILT_PRODUCTS_DIR}/flutter_isolate/flutter_isolate.framework",
"${BUILT_PRODUCTS_DIR}/flutter_local_notifications/flutter_local_notifications.framework",
"${BUILT_PRODUCTS_DIR}/flutter_native_splash/flutter_native_splash.framework",
"${BUILT_PRODUCTS_DIR}/flutter_secure_storage/flutter_secure_storage.framework",
@ -380,13 +380,13 @@
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/battery_info.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/bonsoir_darwin.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/connectivity_plus.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/dart_ui_isolate.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/device_info_plus.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/file_saver.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/fk_user_agent.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/flutter_email_sender.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/flutter_image_compress.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/flutter_inappwebview.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/flutter_isolate.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/flutter_local_notifications.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/flutter_native_splash.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/flutter_secure_storage.framework",

View file

@ -1,6 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<Scheme
LastUpgradeVersion = "1430"
LastUpgradeVersion = "1510"
version = "1.3">
<BuildAction
parallelizeBuildables = "YES"

View file

@ -170,7 +170,13 @@ class PersonService {
faceCount += cluster.faces.length;
for (var faceId in cluster.faces) {
if (faceIdToClusterID.containsKey(faceId)) {
throw Exception("Face $faceId is already assigned to a cluster");
final otherPersonID = clusterToPersonID[faceIdToClusterID[faceId]!];
if (otherPersonID != e.id) {
final otherPerson = await getPerson(otherPersonID!);
throw Exception(
"Face $faceId is already assigned to person $otherPersonID (${otherPerson!.data.name}) and person ${e.id} (${personData.name})",
);
}
}
faceIdToClusterID[faceId] = cluster.id;
}

View file

@ -27,25 +27,33 @@ class EmbeddingStore {
late SharedPreferences _preferences;
Completer<void>? _syncStatus;
Completer<bool>? _remoteSyncStatus;
Future<void> init() async {
_preferences = await SharedPreferences.getInstance();
}
Future<void> pullEmbeddings(Model model) async {
if (_syncStatus != null) {
return _syncStatus!.future;
Future<bool> pullEmbeddings(Model model) async {
if (_remoteSyncStatus != null) {
return _remoteSyncStatus!.future;
}
_syncStatus = Completer();
var remoteEmbeddings = await _getRemoteEmbeddings(model);
await _storeRemoteEmbeddings(remoteEmbeddings.embeddings);
while (remoteEmbeddings.hasMore) {
remoteEmbeddings = await _getRemoteEmbeddings(model);
_remoteSyncStatus = Completer();
try {
var remoteEmbeddings = await _getRemoteEmbeddings(model);
await _storeRemoteEmbeddings(remoteEmbeddings.embeddings);
while (remoteEmbeddings.hasMore) {
remoteEmbeddings = await _getRemoteEmbeddings(model);
await _storeRemoteEmbeddings(remoteEmbeddings.embeddings);
}
_remoteSyncStatus!.complete(true);
_remoteSyncStatus = null;
return true;
} catch (e, s) {
_logger.severe("failed to fetch & store remote embeddings", e, s);
_remoteSyncStatus!.complete(false);
_remoteSyncStatus = null;
return false;
}
_syncStatus!.complete();
_syncStatus = null;
}
Future<void> pushEmbeddings() async {
@ -132,7 +140,8 @@ class EmbeddingStore {
remoteEmbeddings.add(embedding);
}
} catch (e, s) {
_logger.severe(e, s);
_logger.warning("Fetching embeddings failed", e, s);
rethrow;
}
_logger.info("${remoteEmbeddings.length} embeddings fetched");

View file

@ -49,9 +49,10 @@ class SemanticSearchService {
bool _hasInitialized = false;
bool _isComputingEmbeddings = false;
bool _isSyncing = false;
Future<List<EnteFile>>? _ongoingRequest;
List<Embedding> _cachedEmbeddings = <Embedding>[];
PendingQuery? _nextQuery;
Future<(String, List<EnteFile>)>? _searchScreenRequest;
String? _latestPendingQuery;
Completer<void> _mlController = Completer<void>();
get hasInitialized => _hasInitialized;
@ -125,37 +126,40 @@ class SemanticSearchService {
return;
}
_isSyncing = true;
await EmbeddingStore.instance.pullEmbeddings(_currentModel);
await _backFill();
final fetchCompleted =
await EmbeddingStore.instance.pullEmbeddings(_currentModel);
if (fetchCompleted) {
await _backFill();
}
_isSyncing = false;
}
Future<List<EnteFile>> search(String query) async {
// searchScreenQuery should only be used for user-initiated queries on the search screen.
// If there are multiple calls to this method, all of them resolve with the result of the last query.
Future<(String, List<EnteFile>)> searchScreenQuery(String query) async {
if (!LocalSettings.instance.hasEnabledMagicSearch() ||
!_frameworkInitialization.isCompleted) {
return [];
return (query, <EnteFile>[]);
}
if (_ongoingRequest == null) {
_ongoingRequest = _getMatchingFiles(query).then((result) {
_ongoingRequest = null;
if (_nextQuery != null) {
final next = _nextQuery;
_nextQuery = null;
search(next!.query).then((nextResult) {
next.completer.complete(nextResult);
});
}
return result;
});
return _ongoingRequest!;
// If there's an ongoing request, just update the last query and return its future.
if (_searchScreenRequest != null) {
_latestPendingQuery = query;
return _searchScreenRequest!;
} else {
// If there's an ongoing request, create or replace the nextCompleter.
_logger.info("Queuing query $query");
await _nextQuery?.completer.future
.timeout(const Duration(seconds: 0)); // Cancels the previous future.
_nextQuery = PendingQuery(query, Completer<List<EnteFile>>());
return _nextQuery!.completer.future;
// No ongoing request, start a new search.
_searchScreenRequest = _getMatchingFiles(query).then((result) {
// Search completed, reset the ongoing request.
_searchScreenRequest = null;
// If there was a new query during the last search, start a new search with the last query.
if (_latestPendingQuery != null) {
final String newQuery = _latestPendingQuery!;
_latestPendingQuery = null; // Reset last query.
// Recursively call search with the latest query.
return searchScreenQuery(newQuery);
}
return (query, result);
});
return _searchScreenRequest!;
}
}
@ -431,13 +435,6 @@ class QueryResult {
QueryResult(this.id, this.score);
}
class PendingQuery {
final String query;
final Completer<List<EnteFile>> completer;
PendingQuery(this.query, this.completer);
}
class IndexStatus {
final int indexedItems, pendingItems;

View file

@ -23,9 +23,8 @@ class NotificationService {
Future<void> init(
void Function(
NotificationResponse notificationResponse,
)
onNotificationTapped,
SharedPreferences preferences,
) onNotificationTapped,
SharedPreferences preferences,
) async {
_preferences = preferences;
const androidSettings = AndroidInitializationSettings('notification_icon');
@ -72,7 +71,7 @@ class NotificationService {
result = await _notificationsPlugin
.resolvePlatformSpecificImplementation<
AndroidFlutterLocalNotificationsPlugin>()
?.requestPermission();
?.requestNotificationsPermission();
}
if (result != null) {
await _preferences.setBool(keyGrantedNotificationPermission, result);

View file

@ -986,8 +986,16 @@ class SearchService {
String query,
) async {
final List<GenericSearchResult> searchResults = [];
final files = await SemanticSearchService.instance.search(query);
if (files.isNotEmpty) {
late List<EnteFile> files;
late String resultForQuery;
try {
(resultForQuery, files) =
await SemanticSearchService.instance.searchScreenQuery(query);
} catch (e, s) {
_logger.severe("Error occurred during magic search", e, s);
return searchResults;
}
if (files.isNotEmpty && resultForQuery == query) {
searchResults.add(GenericSearchResult(ResultType.magic, query, files));
}
return searchResults;

View file

@ -13,10 +13,10 @@ packages:
dependency: transitive
description:
name: _flutterfire_internals
sha256: "4eec93681221723a686ad580c2e7d960e1017cf1a4e0a263c2573c2c6b0bf5cd"
sha256: "0cb43f83f36ba8cb20502dee0c205e3f3aafb751732d724aeac3f2e044212cc2"
url: "https://pub.dev"
source: hosted
version: "1.3.25"
version: "1.3.29"
adaptive_theme:
dependency: "direct main"
description:
@ -350,18 +350,18 @@ packages:
dependency: transitive
description:
name: coverage
sha256: "595a29b55ce82d53398e1bcc2cba525d7bd7c59faeb2d2540e9d42c390cfeeeb"
sha256: "8acabb8306b57a409bf4c83522065672ee13179297a6bb0cb9ead73948df7c76"
url: "https://pub.dev"
source: hosted
version: "1.6.4"
version: "1.7.2"
cross_file:
dependency: "direct main"
description:
name: cross_file
sha256: fedaadfa3a6996f75211d835aaeb8fede285dae94262485698afd832371b9a5e
sha256: "55d7b444feb71301ef6b8838dbc1ae02e63dd48c8773f3810ff53bb1e2945b32"
url: "https://pub.dev"
source: hosted
version: "0.3.3+8"
version: "0.3.4+1"
crypto:
dependency: "direct main"
description:
@ -534,10 +534,10 @@ packages:
dependency: transitive
description:
name: extended_image_library
sha256: "9b55fc5ebc65fad984de66b8f177a1bef2a84d79203c9c213f75ff83c2c29edd"
sha256: c9caee8fe9b6547bd41c960c4f2d1ef8e34321804de6a1777f1d614a24247ad6
url: "https://pub.dev"
source: hosted
version: "4.0.1"
version: "4.0.4"
fade_indexed_stack:
dependency: "direct main"
description:
@ -582,10 +582,10 @@ packages:
dependency: transitive
description:
name: file
sha256: "1b92bec4fc2a72f59a8e15af5f52cd441e4a7860b49499d69dfa817af20e925d"
sha256: "5fc22d7c25582e38ad9a8515372cd9a93834027aacf1801cf01164dac0ffa08c"
url: "https://pub.dev"
source: hosted
version: "6.1.4"
version: "7.0.0"
file_saver:
dependency: "direct main"
description:
@ -599,10 +599,10 @@ packages:
dependency: "direct main"
description:
name: firebase_core
sha256: "53316975310c8af75a96e365f9fccb67d1c544ef0acdbf0d88bbe30eedd1c4f9"
sha256: a864d1b6afd25497a3b57b016886d1763df52baaa69758a46723164de8d187fe
url: "https://pub.dev"
source: hosted
version: "2.27.0"
version: "2.29.0"
firebase_core_platform_interface:
dependency: transitive
description:
@ -615,10 +615,10 @@ packages:
dependency: transitive
description:
name: firebase_core_web
sha256: c8e1d59385eee98de63c92f961d2a7062c5d9a65e7f45bdc7f1b0b205aab2492
sha256: c8b02226e548f35aace298e2bb2e6c24e34e8a203d614e742bb1146e5a4ad3c8
url: "https://pub.dev"
source: hosted
version: "2.11.5"
version: "2.15.0"
firebase_messaging:
dependency: "direct main"
description:
@ -631,18 +631,18 @@ packages:
dependency: transitive
description:
name: firebase_messaging_platform_interface
sha256: f7a9d74ff7fc588a924f6b2eaeaa148b0db521b13a9db55f6ad45864fa98c06e
sha256: "80b4ccf20066b0579ebc88d4678230a5f53ab282fe040e31671af745db1588f9"
url: "https://pub.dev"
source: hosted
version: "4.5.27"
version: "4.5.31"
firebase_messaging_web:
dependency: transitive
description:
name: firebase_messaging_web
sha256: fc21e771166860c55b103701c5ac7cdb2eec28897b97c42e6e5703cbedf9e02e
sha256: "9224aa4db1ce6f08d96a82978453d37e9980204a20e410a11d9b774b24c6841c"
url: "https://pub.dev"
source: hosted
version: "3.6.8"
version: "3.8.1"
fixnum:
dependency: transitive
description:
@ -753,26 +753,26 @@ packages:
dependency: "direct main"
description:
name: flutter_local_notifications
sha256: f222919a34545931e47b06000836b5101baeffb0e6eb5a4691d2d42851740dd9
sha256: f9a05409385b77b06c18f200a41c7c2711ebf7415669350bb0f8474c07bd40d1
url: "https://pub.dev"
source: hosted
version: "12.0.4"
version: "17.0.0"
flutter_local_notifications_linux:
dependency: transitive
description:
name: flutter_local_notifications_linux
sha256: "3c6d6db334f609a92be0c0915f40871ec56f5d2adf01e77ae364162c587c0ca8"
sha256: "33f741ef47b5f63cc7f78fe75eeeac7e19f171ff3c3df054d84c1e38bedb6a03"
url: "https://pub.dev"
source: hosted
version: "2.0.0"
version: "4.0.0+1"
flutter_local_notifications_platform_interface:
dependency: transitive
description:
name: flutter_local_notifications_platform_interface
sha256: "5ec1feac5f7f7d9266759488bc5f76416152baba9aa1b26fe572246caa00d1ab"
sha256: "7cf643d6d5022f3baed0be777b0662cce5919c0a7b86e700299f22dc4ae660ef"
url: "https://pub.dev"
source: hosted
version: "6.0.0"
version: "7.0.0+1"
flutter_localizations:
dependency: "direct main"
description: flutter
@ -1013,10 +1013,10 @@ packages:
dependency: "direct main"
description:
name: http
sha256: a2bbf9d017fcced29139daa8ed2bba4ece450ab222871df93ca9eec6f80c34ba
sha256: "761a297c042deedc1ffbb156d6e2af13886bb305c2a343a4d972504cd67dd938"
url: "https://pub.dev"
source: hosted
version: "1.2.0"
version: "1.2.1"
http_client_helper:
dependency: transitive
description:
@ -1190,6 +1190,30 @@ packages:
url: "https://pub.dev"
source: hosted
version: "0.9.1"
leak_tracker:
dependency: transitive
description:
name: leak_tracker
sha256: "78eb209deea09858f5269f5a5b02be4049535f568c07b275096836f01ea323fa"
url: "https://pub.dev"
source: hosted
version: "10.0.0"
leak_tracker_flutter_testing:
dependency: transitive
description:
name: leak_tracker_flutter_testing
sha256: b46c5e37c19120a8a01918cfaf293547f47269f7cb4b0058f21531c2465d6ef0
url: "https://pub.dev"
source: hosted
version: "2.0.1"
leak_tracker_testing:
dependency: transitive
description:
name: leak_tracker_testing
sha256: a597f72a664dbd293f3bfc51f9ba69816f84dcd403cdac7066cb3f6003f3ab47
url: "https://pub.dev"
source: hosted
version: "2.0.1"
like_button:
dependency: "direct main"
description:
@ -1290,18 +1314,18 @@ packages:
dependency: transitive
description:
name: matcher
sha256: "1803e76e6653768d64ed8ff2e1e67bea3ad4b923eb5c56a295c3e634bad5960e"
sha256: d2323aa2060500f906aa31a895b4030b6da3ebdcc5619d14ce1aada65cd161cb
url: "https://pub.dev"
source: hosted
version: "0.12.16"
version: "0.12.16+1"
material_color_utilities:
dependency: transitive
description:
name: material_color_utilities
sha256: "9528f2f296073ff54cb9fee677df673ace1218163c3bc7628093e7eed5203d41"
sha256: "0e0a020085b65b6083975e499759762399b4475f766c21668c4ecca34ea74e5a"
url: "https://pub.dev"
source: hosted
version: "0.5.0"
version: "0.8.0"
media_extension:
dependency: "direct main"
description:
@ -1386,10 +1410,10 @@ packages:
dependency: transitive
description:
name: meta
sha256: a6e590c838b18133bb482a2745ad77c5bb7715fb0451209e1a7567d416678b8e
sha256: d584fa6707a52763a52446f02cc621b077888fb63b93bbcb1143a7be5a0c0c04
url: "https://pub.dev"
source: hosted
version: "1.10.0"
version: "1.11.0"
mgrs_dart:
dependency: transitive
description:
@ -1541,10 +1565,10 @@ packages:
dependency: "direct main"
description:
name: path
sha256: "8829d8a55c13fc0e37127c29fedf290c102f4e40ae94ada574091fe0ff96c917"
sha256: "087ce49c3f0dc39180befefc60fdb4acd8f8620e5682fe2476afd0b3688bb4af"
url: "https://pub.dev"
source: hosted
version: "1.8.3"
version: "1.9.0"
path_drawing:
dependency: transitive
description:
@ -1709,10 +1733,10 @@ packages:
dependency: transitive
description:
name: platform
sha256: ae68c7bfcd7383af3629daafb32fb4e8681c7154428da4febcff06200585f102
sha256: "12220bb4b65720483f8fa9450b4332347737cf8213dd2840d8b2c823e47243ec"
url: "https://pub.dev"
source: hosted
version: "3.1.2"
version: "3.1.4"
plugin_platform_interface:
dependency: transitive
description:
@ -1749,10 +1773,10 @@ packages:
dependency: transitive
description:
name: process
sha256: "53fd8db9cec1d37b0574e12f07520d582019cb6c44abf5479a01505099a34a09"
sha256: "21e54fd2faf1b5bdd5102afd25012184a6793927648ea81eea80552ac9405b32"
url: "https://pub.dev"
source: hosted
version: "4.2.4"
version: "5.0.2"
proj4dart:
dependency: transitive
description:
@ -1901,26 +1925,10 @@ packages:
dependency: "direct main"
description:
name: share_plus
sha256: f582d5741930f3ad1bf0211d358eddc0508cc346e5b4b248bd1e569c995ebb7a
sha256: fb5319f3aab4c5dda5ebb92dca978179ba21f8c783ee4380910ef4c1c6824f51
url: "https://pub.dev"
source: hosted
version: "4.5.3"
share_plus_linux:
dependency: transitive
description:
name: share_plus_linux
sha256: dc32bf9f1151b9864bb86a997c61a487967a08f2e0b4feaa9a10538712224da4
url: "https://pub.dev"
source: hosted
version: "3.0.1"
share_plus_macos:
dependency: transitive
description:
name: share_plus_macos
sha256: "44daa946f2845045ecd7abb3569b61cd9a55ae9cc4cbec9895b2067b270697ae"
url: "https://pub.dev"
source: hosted
version: "3.0.1"
version: "8.0.3"
share_plus_platform_interface:
dependency: transitive
description:
@ -1929,22 +1937,6 @@ packages:
url: "https://pub.dev"
source: hosted
version: "3.4.0"
share_plus_web:
dependency: transitive
description:
name: share_plus_web
sha256: eaef05fa8548b372253e772837dd1fbe4ce3aca30ea330765c945d7d4f7c9935
url: "https://pub.dev"
source: hosted
version: "3.1.0"
share_plus_windows:
dependency: transitive
description:
name: share_plus_windows
sha256: "3a21515ae7d46988d42130cd53294849e280a5de6ace24bae6912a1bffd757d4"
url: "https://pub.dev"
source: hosted
version: "3.0.1"
shared_preferences:
dependency: "direct main"
description:
@ -1989,10 +1981,10 @@ packages:
dependency: transitive
description:
name: shared_preferences_web
sha256: "7b15ffb9387ea3e237bb7a66b8a23d2147663d391cafc5c8f37b2e7b4bde5d21"
sha256: "9aee1089b36bd2aafe06582b7d7817fd317ef05fc30e6ba14bff247d0933042a"
url: "https://pub.dev"
source: hosted
version: "2.2.2"
version: "2.3.0"
shared_preferences_windows:
dependency: transitive
description:
@ -2402,10 +2394,10 @@ packages:
dependency: transitive
description:
name: url_launcher_web
sha256: fff0932192afeedf63cdd50ecbb1bc825d31aed259f02bb8dba0f3b729a5e88b
sha256: "3692a459204a33e04bc94f5fb91158faf4f2c8903281ddd82915adecdb1a901d"
url: "https://pub.dev"
source: hosted
version: "2.2.3"
version: "2.3.0"
url_launcher_windows:
dependency: transitive
description:
@ -2491,10 +2483,10 @@ packages:
dependency: transitive
description:
name: vm_service
sha256: c538be99af830f478718b51630ec1b6bee5e74e52c8a802d328d9e71d35d2583
sha256: b3d56ff4341b8f182b96aceb2fa20e3dcb336b9f867bc0eafc0de10f1048e957
url: "https://pub.dev"
source: hosted
version: "11.10.0"
version: "13.0.0"
volume_controller:
dependency: transitive
description:
@ -2539,10 +2531,10 @@ packages:
dependency: transitive
description:
name: web
sha256: afe077240a270dcfd2aafe77602b4113645af95d0ad31128cc02bce5ac5d5152
sha256: "97da13628db363c635202ad97068d47c5b8aa555808e7a9411963c533b449b27"
url: "https://pub.dev"
source: hosted
version: "0.3.0"
version: "0.5.1"
web_socket_channel:
dependency: transitive
description:
@ -2555,10 +2547,10 @@ packages:
dependency: transitive
description:
name: webdriver
sha256: "3c923e918918feeb90c4c9fdf1fe39220fa4c0e8e2c0fffaded174498ef86c49"
sha256: "003d7da9519e1e5f329422b36c4dcdf18d7d2978d1ba099ea4e45ba490ed845e"
url: "https://pub.dev"
source: hosted
version: "3.0.2"
version: "3.0.3"
webkit_inspection_protocol:
dependency: transitive
description:
@ -2611,10 +2603,10 @@ packages:
dependency: transitive
description:
name: xdg_directories
sha256: bd512f03919aac5f1313eb8249f223bacf4927031bf60b02601f81f687689e86
sha256: faea9dee56b520b55a566385b84f2e8de55e7496104adada9962e0bd11bcff1d
url: "https://pub.dev"
source: hosted
version: "0.2.0+3"
version: "1.0.4"
xml:
dependency: transitive
description:
@ -2648,5 +2640,5 @@ packages:
source: hosted
version: "3.1.2"
sdks:
dart: ">=3.2.5 <4.0.0"
flutter: ">=3.16.6"
dart: ">=3.3.0 <4.0.0"
flutter: ">=3.19.0"

View file

@ -79,7 +79,7 @@ dependencies:
flutter_image_compress: ^1.1.0
flutter_inappwebview: ^5.8.0
flutter_launcher_icons: ^0.13.1
flutter_local_notifications: ^12.0.4
flutter_local_notifications: ^17.0.0
flutter_localizations:
sdk: flutter
flutter_map: ^5.0.0
@ -145,7 +145,7 @@ dependencies:
scrollable_positioned_list: ^0.3.5
sentry: ^7.9.0
sentry_flutter: ^7.9.0
share_plus: ^4.0.10
share_plus: ^8.0.3
shared_preferences: ^2.0.5
simple_cluster: ^0.3.0
sqflite: ^2.3.0

View file

@ -10,14 +10,7 @@
"@ente/shared": "*",
"@mui/x-date-pickers": "^5.0.0-alpha.6",
"@stripe/stripe-js": "^1.13.2",
"@tensorflow-models/coco-ssd": "^2.2.2",
"@tensorflow/tfjs-backend-cpu": "^4.10.0",
"@tensorflow/tfjs-backend-webgl": "^4.9.0",
"@tensorflow/tfjs-converter": "^4.10.0",
"@tensorflow/tfjs-core": "^4.10.0",
"@tensorflow/tfjs-tflite": "0.0.1-alpha.7",
"bip39": "^3.0.4",
"blazeface-back": "^0.0.9",
"bs58": "^5.0.0",
"chrono-node": "^2.2.6",
"date-fns": "^2",

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -1 +0,0 @@
"use strict";var Module={};var initializedJS=false;function threadPrintErr(){var text=Array.prototype.slice.call(arguments).join(" ");console.error(text)}function threadAlert(){var text=Array.prototype.slice.call(arguments).join(" ");postMessage({cmd:"alert",text:text,threadId:Module["_pthread_self"]()})}var err=threadPrintErr;self.alert=threadAlert;Module["instantiateWasm"]=function(info,receiveInstance){var instance=new WebAssembly.Instance(Module["wasmModule"],info);receiveInstance(instance);Module["wasmModule"]=null;return instance.exports};function moduleLoaded(){}self.onmessage=function(e){try{if(e.data.cmd==="load"){Module["wasmModule"]=e.data.wasmModule;Module["wasmMemory"]=e.data.wasmMemory;Module["buffer"]=Module["wasmMemory"].buffer;Module["ENVIRONMENT_IS_PTHREAD"]=true;if(typeof e.data.urlOrBlob==="string"){importScripts(e.data.urlOrBlob)}else{var objectUrl=URL.createObjectURL(e.data.urlOrBlob);importScripts(objectUrl);URL.revokeObjectURL(objectUrl)}tflite_web_api_ModuleFactory(Module).then(function(instance){Module=instance;moduleLoaded()})}else if(e.data.cmd==="objectTransfer"){Module["PThread"].receiveObjectTransfer(e.data)}else if(e.data.cmd==="run"){Module["__performance_now_clock_drift"]=performance.now()-e.data.time;Module["__emscripten_thread_init"](e.data.threadInfoStruct,0,0);var max=e.data.stackBase;var top=e.data.stackBase+e.data.stackSize;Module["establishStackSpace"](top,max);Module["PThread"].receiveObjectTransfer(e.data);Module["PThread"].threadInit();if(!initializedJS){Module["___embind_register_native_and_builtin_types"]();initializedJS=true}try{var result=Module["invokeEntryPoint"](e.data.start_routine,e.data.arg);if(Module["keepRuntimeAlive"]()){Module["PThread"].setExitStatus(result)}else{Module["PThread"].threadExit(result)}}catch(ex){if(ex==="Canceled!"){Module["PThread"].threadCancel()}else if(ex!="unwind"){if(ex instanceof Module["ExitStatus"]){if(Module["keepRuntimeAlive"]()){}else{Module["PThread"].threadExit(ex.status)}}else{Module["PThread"].threadExit(-2);throw ex}}}}else if(e.data.cmd==="cancel"){if(Module["_pthread_self"]()){Module["PThread"].threadCancel()}}else if(e.data.target==="setimmediate"){}else if(e.data.cmd==="processThreadQueue"){if(Module["_pthread_self"]()){Module["_emscripten_current_thread_process_queued_calls"]()}}else{err("worker.js received unknown command "+e.data.cmd);err(e.data)}}catch(ex){err("worker.js onmessage() captured an uncaught exception: "+ex);if(ex&&ex.stack)err(ex.stack);throw ex}};

File diff suppressed because one or more lines are too long

View file

@ -1 +0,0 @@
"use strict";var Module={};var initializedJS=false;function threadPrintErr(){var text=Array.prototype.slice.call(arguments).join(" ");console.error(text)}function threadAlert(){var text=Array.prototype.slice.call(arguments).join(" ");postMessage({cmd:"alert",text:text,threadId:Module["_pthread_self"]()})}var err=threadPrintErr;self.alert=threadAlert;Module["instantiateWasm"]=function(info,receiveInstance){var instance=new WebAssembly.Instance(Module["wasmModule"],info);receiveInstance(instance);Module["wasmModule"]=null;return instance.exports};function moduleLoaded(){}self.onmessage=function(e){try{if(e.data.cmd==="load"){Module["wasmModule"]=e.data.wasmModule;Module["wasmMemory"]=e.data.wasmMemory;Module["buffer"]=Module["wasmMemory"].buffer;Module["ENVIRONMENT_IS_PTHREAD"]=true;if(typeof e.data.urlOrBlob==="string"){importScripts(e.data.urlOrBlob)}else{var objectUrl=URL.createObjectURL(e.data.urlOrBlob);importScripts(objectUrl);URL.revokeObjectURL(objectUrl)}tflite_web_api_ModuleFactory(Module).then(function(instance){Module=instance;moduleLoaded()})}else if(e.data.cmd==="objectTransfer"){Module["PThread"].receiveObjectTransfer(e.data)}else if(e.data.cmd==="run"){Module["__performance_now_clock_drift"]=performance.now()-e.data.time;Module["__emscripten_thread_init"](e.data.threadInfoStruct,0,0);var max=e.data.stackBase;var top=e.data.stackBase+e.data.stackSize;Module["establishStackSpace"](top,max);Module["PThread"].receiveObjectTransfer(e.data);Module["PThread"].threadInit();if(!initializedJS){Module["___embind_register_native_and_builtin_types"]();initializedJS=true}try{var result=Module["invokeEntryPoint"](e.data.start_routine,e.data.arg);if(Module["keepRuntimeAlive"]()){Module["PThread"].setExitStatus(result)}else{Module["PThread"].threadExit(result)}}catch(ex){if(ex==="Canceled!"){Module["PThread"].threadCancel()}else if(ex!="unwind"){if(ex instanceof Module["ExitStatus"]){if(Module["keepRuntimeAlive"]()){}else{Module["PThread"].threadExit(ex.status)}}else{Module["PThread"].threadExit(-2);throw ex}}}}else if(e.data.cmd==="cancel"){if(Module["_pthread_self"]()){Module["PThread"].threadCancel()}}else if(e.data.target==="setimmediate"){}else if(e.data.cmd==="processThreadQueue"){if(Module["_pthread_self"]()){Module["_emscripten_current_thread_process_queued_calls"]()}}else{err("worker.js received unknown command "+e.data.cmd);err(e.data)}}catch(ex){err("worker.js onmessage() captured an uncaught exception: "+ex);if(ex&&ex.stack)err(ex.stack);throw ex}};

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -1,32 +0,0 @@
{
"0": "waterfall",
"1": "snow",
"2": "landscape",
"3": "underwater",
"4": "architecture",
"5": "sunset / sunrise",
"6": "blue sky",
"7": "cloudy sky",
"8": "greenery",
"9": "autumn leaves",
"10": "portrait",
"11": "flower",
"12": "night shot",
"13": "stage concert",
"14": "fireworks",
"15": "candle light",
"16": "neon lights",
"17": "indoor",
"18": "backlight",
"19": "text documents",
"20": "qr images",
"21": "group portrait",
"22": "computer screens",
"23": "kids",
"24": "dog",
"25": "cat",
"26": "macro",
"27": "food",
"28": "beach",
"29": "mountain"
}

File diff suppressed because it is too large

File diff suppressed because one or more lines are too long

View file

@ -1,51 +0,0 @@
import Box from "@mui/material/Box";
import { Chip } from "components/Chip";
import { Legend } from "components/PhotoViewer/styledComponents/Legend";
import { t } from "i18next";
import { useEffect, useState } from "react";
import { EnteFile } from "types/file";
import mlIDbStorage from "utils/storage/mlIDbStorage";
export function ObjectLabelList(props: {
file: EnteFile;
updateMLDataIndex: number;
}) {
const [objects, setObjects] = useState<Array<string>>([]);
useEffect(() => {
let didCancel = false;
const main = async () => {
const objects = await mlIDbStorage.getAllObjectsMap();
const uniqueObjectNames = [
...new Set(
(objects.get(props.file.id) ?? []).map(
(object) => object.detection.class,
),
),
];
!didCancel && setObjects(uniqueObjectNames);
};
main();
return () => {
didCancel = true;
};
}, [props.file, props.updateMLDataIndex]);
if (objects.length === 0) return <></>;
return (
<div>
<Legend sx={{ pb: 1, display: "block" }}>{t("OBJECTS")}</Legend>
<Box
display={"flex"}
gap={1}
flexWrap="wrap"
justifyContent={"flex-start"}
alignItems={"flex-start"}
>
{objects.map((object) => (
<Chip key={object}>{object}</Chip>
))}
</Box>
</div>
);
}

View file

@ -1,39 +0,0 @@
import * as tf from "@tensorflow/tfjs-core";
import { useEffect, useRef } from "react";
import { FaceImage } from "types/machineLearning";
interface FaceImageProps {
faceImage: FaceImage;
width?: number;
height?: number;
}
export default function TFJSImage(props: FaceImageProps) {
const canvasRef = useRef(null);
useEffect(() => {
if (!props || !props.faceImage) {
return;
}
const canvas = canvasRef.current;
const faceTensor = tf.tensor3d(props.faceImage);
const resized =
props.width && props.height
? tf.image.resizeBilinear(faceTensor, [
props.width,
props.height,
])
: faceTensor;
const normFaceImage = tf.div(tf.add(resized, 1.0), 2);
tf.browser.toPixels(normFaceImage as tf.Tensor3D, canvas);
}, [props]);
return (
<canvas
ref={canvasRef}
width={112}
height={112}
style={{ display: "inline" }}
/>
);
}

View file

@ -10,7 +10,6 @@ import TextSnippetOutlined from "@mui/icons-material/TextSnippetOutlined";
import { Box, DialogProps, Link, Stack, styled } from "@mui/material";
import { Chip } from "components/Chip";
import { EnteDrawer } from "components/EnteDrawer";
import { ObjectLabelList } from "components/MachineLearning/ObjectList";
import {
PhotoPeopleList,
UnidentifiedFaces,
@ -344,10 +343,6 @@ export function FileInfo({
file={file}
updateMLDataIndex={updateMLDataIndex}
/>
<ObjectLabelList
file={file}
updateMLDataIndex={updateMLDataIndex}
/>
</>
)}
</Stack>

View file

@ -17,7 +17,7 @@ import {
import { Collection } from "types/collection";
import { LocationTagData } from "types/entity";
import { EnteFile } from "types/file";
import { Person, Thing, WordGroup } from "types/machineLearning";
import { Person } from "types/machineLearning";
import {
ClipSearchScores,
DateValue,
@ -146,12 +146,6 @@ export default function SearchInput(props: Iprops) {
case SuggestionType.PERSON:
search = { person: selectedOption.value as Person };
break;
case SuggestionType.THING:
search = { thing: selectedOption.value as Thing };
break;
case SuggestionType.TEXT:
search = { text: selectedOption.value as WordGroup };
break;
case SuggestionType.FILE_TYPE:
search = { fileType: selectedOption.value as FILE_TYPE };
break;

View file

@ -46,22 +46,6 @@ export const DEFAULT_ML_SYNC_CONFIG: MLSyncConfig = {
// maxDistanceInsideCluster: 0.4,
generateDebugInfo: true,
},
objectDetection: {
method: "SSDMobileNetV2",
maxNumBoxes: 20,
minScore: 0.2,
},
sceneDetection: {
method: "ImageScene",
minScore: 0.1,
},
// tsne: {
// samples: 200,
// dim: 2,
// perplexity: 10.0,
// learningRate: 10.0,
// metric: 'euclidean',
// },
mlVersion: 3,
};
@ -69,29 +53,4 @@ export const DEFAULT_ML_SEARCH_CONFIG: MLSearchConfig = {
enabled: false,
};
export const ML_SYNC_DOWNLOAD_TIMEOUT_MS = 300000;
export const MAX_FACE_DISTANCE_PERCENT = Math.sqrt(2) / 100;
export const MAX_ML_SYNC_ERROR_COUNT = 1;
export const TEXT_DETECTION_TIMEOUT_MS = [10000, 30000, 60000, 120000, 240000];
export const BLAZEFACE_MAX_FACES = 50;
export const BLAZEFACE_INPUT_SIZE = 256;
export const BLAZEFACE_IOU_THRESHOLD = 0.3;
export const BLAZEFACE_SCORE_THRESHOLD = 0.75;
export const BLAZEFACE_PASS1_SCORE_THRESHOLD = 0.4;
export const BLAZEFACE_FACE_SIZE = 112;
export const MOBILEFACENET_FACE_SIZE = 112;
export const MOBILEFACENET_EMBEDDING_SIZE = 192;
// scene detection model takes fixed-shaped (224x224) inputs
// https://tfhub.dev/sayannath/lite-model/image-scene/1
export const SCENE_DETECTION_IMAGE_SIZE = 224;
// SSD with Mobilenet v2 initialized from Imagenet classification checkpoint. Trained on COCO 2017 dataset (images scaled to 320x320 resolution).
// https://tfhub.dev/tensorflow/ssd_mobilenet_v2/2
export const OBJECT_DETECTION_IMAGE_SIZE = 320;
export const BATCHES_BEFORE_SYNCING_INDEX = 5;

View file

@ -1,4 +1,6 @@
import { inWorker } from "@/next/env";
import log from "@/next/log";
import { workerBridge } from "@/next/worker/worker-bridge";
import ComlinkCryptoWorker from "@ente/shared/crypto";
import { CustomError } from "@ente/shared/error";
import HTTPService from "@ente/shared/network/HTTPService";
@ -262,7 +264,9 @@ export const putEmbedding = async (
putEmbeddingReq: PutEmbeddingRequest,
): Promise<EncryptedEmbedding> => {
try {
const token = getToken();
const token = inWorker()
? await workerBridge.getAuthToken()
: getToken();
if (!token) {
log.info("putEmbedding failed: token not found");
throw Error(CustomError.TOKEN_MISSING);

View file

@ -1,257 +0,0 @@
import log from "@/next/log";
import { GraphModel } from "@tensorflow/tfjs-converter";
import * as tf from "@tensorflow/tfjs-core";
import {
load as blazeFaceLoad,
BlazeFaceModel,
NormalizedFace,
} from "blazeface-back";
import {
BLAZEFACE_FACE_SIZE,
BLAZEFACE_INPUT_SIZE,
BLAZEFACE_IOU_THRESHOLD,
BLAZEFACE_MAX_FACES,
BLAZEFACE_PASS1_SCORE_THRESHOLD,
BLAZEFACE_SCORE_THRESHOLD,
MAX_FACE_DISTANCE_PERCENT,
} from "constants/mlConfig";
import {
FaceDetection,
FaceDetectionMethod,
FaceDetectionService,
Versioned,
} from "types/machineLearning";
import { addPadding, crop, resizeToSquare } from "utils/image";
import { enlargeBox, newBox, normFaceBox } from "utils/machineLearning";
import {
getNearestDetection,
removeDuplicateDetections,
transformPaddedToImage,
} from "utils/machineLearning/faceDetection";
import {
computeTransformToBox,
transformBox,
transformPoints,
} from "utils/machineLearning/transform";
import { Box, Point } from "../../../thirdparty/face-api/classes";
class BlazeFaceDetectionService implements FaceDetectionService {
private blazeFaceModel: Promise<BlazeFaceModel>;
private blazeFaceBackModel: GraphModel;
public method: Versioned<FaceDetectionMethod>;
private desiredLeftEye = [0.36, 0.45];
private desiredFaceSize;
public constructor(desiredFaceSize: number = BLAZEFACE_FACE_SIZE) {
this.method = {
value: "BlazeFace",
version: 1,
};
this.desiredFaceSize = desiredFaceSize;
}
public getRelativeDetection(): FaceDetection {
// TODO(MR): onnx-yolo
throw new Error();
}
private async init() {
this.blazeFaceModel = blazeFaceLoad({
maxFaces: BLAZEFACE_MAX_FACES,
scoreThreshold: BLAZEFACE_PASS1_SCORE_THRESHOLD,
iouThreshold: BLAZEFACE_IOU_THRESHOLD,
modelUrl: "/models/blazeface/back/model.json",
inputHeight: BLAZEFACE_INPUT_SIZE,
inputWidth: BLAZEFACE_INPUT_SIZE,
});
log.info(
"loaded blazeFaceModel: ",
// await this.blazeFaceModel,
// eslint-disable-next-line @typescript-eslint/await-thenable
await tf.getBackend(),
);
}
private getDlibAlignedFace(normFace: NormalizedFace): Box {
const relX = 0.5;
const relY = 0.43;
const relScale = 0.45;
const leftEyeCenter = normFace.landmarks[0];
const rightEyeCenter = normFace.landmarks[1];
const mountCenter = normFace.landmarks[3];
const distToMouth = (pt) => {
const dy = mountCenter[1] - pt[1];
const dx = mountCenter[0] - pt[0];
return Math.sqrt(dx * dx + dy * dy);
};
const eyeToMouthDist =
(distToMouth(leftEyeCenter) + distToMouth(rightEyeCenter)) / 2;
const size = Math.floor(eyeToMouthDist / relScale);
const center = [
(leftEyeCenter[0] + rightEyeCenter[0] + mountCenter[0]) / 3,
(leftEyeCenter[1] + rightEyeCenter[1] + mountCenter[1]) / 3,
];
const left = center[0] - relX * size;
const top = center[1] - relY * size;
const right = center[0] + relX * size;
const bottom = center[1] + relY * size;
return new Box({
left: left,
top: top,
right: right,
bottom: bottom,
});
}
private getAlignedFace(normFace: NormalizedFace): Box {
const leftEye = normFace.landmarks[0];
const rightEye = normFace.landmarks[1];
// const noseTip = normFace.landmarks[2];
const dy = rightEye[1] - leftEye[1];
const dx = rightEye[0] - leftEye[0];
const desiredRightEyeX = 1.0 - this.desiredLeftEye[0];
// const eyesCenterX = (leftEye[0] + rightEye[0]) / 2;
// const yaw = Math.abs(noseTip[0] - eyesCenterX)
const dist = Math.sqrt(dx * dx + dy * dy);
let desiredDist = desiredRightEyeX - this.desiredLeftEye[0];
desiredDist *= this.desiredFaceSize;
const scale = desiredDist / dist;
// log.info("scale: ", scale);
const eyesCenter = [];
eyesCenter[0] = Math.floor((leftEye[0] + rightEye[0]) / 2);
eyesCenter[1] = Math.floor((leftEye[1] + rightEye[1]) / 2);
// log.info("eyesCenter: ", eyesCenter);
const faceWidth = this.desiredFaceSize / scale;
const faceHeight = this.desiredFaceSize / scale;
// log.info("faceWidth: ", faceWidth, "faceHeight: ", faceHeight)
const tx = eyesCenter[0] - faceWidth * 0.5;
const ty = eyesCenter[1] - faceHeight * this.desiredLeftEye[1];
// log.info("tx: ", tx, "ty: ", ty);
return new Box({
left: tx,
top: ty,
right: tx + faceWidth,
bottom: ty + faceHeight,
});
}
public async detectFacesUsingModel(image: tf.Tensor3D) {
const resizedImage = tf.image.resizeBilinear(image, [256, 256]);
const reshapedImage = tf.reshape(resizedImage, [
1,
resizedImage.shape[0],
resizedImage.shape[1],
3,
]);
const normalizedImage = tf.sub(tf.div(reshapedImage, 127.5), 1.0);
// eslint-disable-next-line @typescript-eslint/await-thenable
const results = await this.blazeFaceBackModel.predict(normalizedImage);
// log.info('onFacesDetected: ', results);
return results;
}
private async getBlazefaceModel() {
if (!this.blazeFaceModel) {
await this.init();
}
return this.blazeFaceModel;
}
private async estimateFaces(
imageBitmap: ImageBitmap,
): Promise<Array<FaceDetection>> {
const resized = resizeToSquare(imageBitmap, BLAZEFACE_INPUT_SIZE);
const tfImage = tf.browser.fromPixels(resized.image);
const blazeFaceModel = await this.getBlazefaceModel();
// TODO: check if this works concurrently, else use serialqueue
const faces = await blazeFaceModel.estimateFaces(tfImage);
tf.dispose(tfImage);
const inBox = newBox(0, 0, resized.width, resized.height);
const toBox = newBox(0, 0, imageBitmap.width, imageBitmap.height);
const transform = computeTransformToBox(inBox, toBox);
// log.info("1st pass: ", { transform });
const faceDetections: Array<FaceDetection> = faces?.map((f) => {
const box = transformBox(normFaceBox(f), transform);
const normLandmarks = (f.landmarks as number[][])?.map(
(l) => new Point(l[0], l[1]),
);
const landmarks = transformPoints(normLandmarks, transform);
return {
box,
landmarks,
probability: f.probability as number,
// detectionMethod: this.method,
} as FaceDetection;
});
return faceDetections;
}
public async detectFaces(
imageBitmap: ImageBitmap,
): Promise<Array<FaceDetection>> {
const maxFaceDistance = imageBitmap.width * MAX_FACE_DISTANCE_PERCENT;
const pass1Detections = await this.estimateFaces(imageBitmap);
// run 2nd pass for accuracy
const detections: Array<FaceDetection> = [];
for (const pass1Detection of pass1Detections) {
const imageBox = enlargeBox(pass1Detection.box, 2);
const faceImage = crop(
imageBitmap,
imageBox,
BLAZEFACE_INPUT_SIZE / 2,
);
const paddedImage = addPadding(faceImage, 0.5);
const paddedBox = enlargeBox(imageBox, 2);
const pass2Detections = await this.estimateFaces(paddedImage);
pass2Detections?.forEach((d) =>
transformPaddedToImage(d, faceImage, imageBox, paddedBox),
);
let selected = pass2Detections?.[0];
if (pass2Detections?.length > 1) {
// log.info('2nd pass >1 face', pass2Detections.length);
selected = getNearestDetection(
pass1Detection,
pass2Detections,
// maxFaceDistance
);
}
// We might miss a 1st pass face whose score is actually within the threshold;
// that is fine, as the results stay consistent with 2nd-pass-only detections.
if (selected && selected.probability >= BLAZEFACE_SCORE_THRESHOLD) {
// log.info("pass2: ", { imageBox, paddedBox, transform, selected });
detections.push(selected);
}
}
return removeDuplicateDetections(detections, maxFaceDistance);
}
public async dispose() {
const blazeFaceModel = await this.getBlazefaceModel();
blazeFaceModel?.dispose();
this.blazeFaceModel = undefined;
}
}
export default new BlazeFaceDetectionService();
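Note (illustrative sketch, not part of this change): as a worked example of the scale computed in getAlignedFace above, assuming desiredLeftEye = [0.35, 0.35], desiredFaceSize = 112 and eyes detected 40 px apart (all assumed values):

// Sketch only: the alignment scale for the assumed values above.
const desiredLeftEye = [0.35, 0.35];
const desiredFaceSize = 112;
const eyeDist = 40;
const desiredRightEyeX = 1.0 - desiredLeftEye[0]; // 0.65
const desiredDist = (desiredRightEyeX - desiredLeftEye[0]) * desiredFaceSize; // 33.6
const scale = desiredDist / eyeDist; // 0.84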

View file

@@ -55,7 +55,7 @@ class FaceService {
await syncContext.faceDetectionService.detectFaces(imageBitmap);
console.timeEnd(timerId);
console.log("faceDetections: ", faceDetections?.length);
// log.info('3 TF Memory stats: ',JSON.stringify(tf.memory()));
// TODO: reenable faces filtering based on width
const detectedFaces = faceDetections?.map((detection) => {
return {
@@ -150,7 +150,7 @@ class FaceService {
imageBitmap.close();
log.info("[MLService] alignedFaces: ", newMlFile.faces?.length);
// log.info('4 TF Memory stats: ',JSON.stringify(tf.memory()));
return faceImages;
}
@@ -187,7 +187,6 @@ class FaceService {
newMlFile.faces.forEach((f, i) => (f.embedding = embeddings[i]));
log.info("[MLService] facesWithEmbeddings: ", newMlFile.faces.length);
// log.info('5 TF Memory stats: ',JSON.stringify(tf.memory()));
}
async syncFileFaceMakeRelativeDetections(
@@ -226,11 +225,21 @@
face.detection,
syncContext.config.faceCrop,
);
face.crop = await storeFaceCrop(
face.id,
faceCrop,
syncContext.config.faceCrop.blobOptions,
);
try {
face.crop = await storeFaceCrop(
face.id,
faceCrop,
syncContext.config.faceCrop.blobOptions,
);
} catch (e) {
// TODO(MR): Temporarily ignoring errors about failing cache puts
// when using a custom scheme in Electron. Needs an alternative
// approach, perhaps OPFS.
console.error(
"Ignoring error when caching face crop, the face crop will not be available",
e,
);
}
const blob = await imageBitmapToBlob(faceCrop.image);
faceCrop.image.close();
return blob;
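Note (illustrative sketch, not part of this change): the OPFS alternative hinted at in the TODO above could look roughly like the code below; the "face-crops" directory name and the helper name are assumptions.

// Sketch only: persist a face crop blob in the Origin Private File System
// instead of the cache API. Assumes navigator.storage is available here.
async function storeFaceCropInOPFS(faceId: string, blob: Blob): Promise<void> {
    const root = await navigator.storage.getDirectory();
    const dir = await root.getDirectoryHandle("face-crops", { create: true });
    const file = await dir.getFileHandle(faceId, { create: true });
    const writable = await file.createWritable();
    await writable.write(blob);
    await writable.close();
}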

View file

@@ -1,108 +0,0 @@
import log from "@/next/log";
import * as tfjsConverter from "@tensorflow/tfjs-converter";
import * as tf from "@tensorflow/tfjs-core";
import { SCENE_DETECTION_IMAGE_SIZE } from "constants/mlConfig";
import {
ObjectDetection,
SceneDetectionMethod,
SceneDetectionService,
Versioned,
} from "types/machineLearning";
import { resizeToSquare } from "utils/image";
class ImageScene implements SceneDetectionService {
method: Versioned<SceneDetectionMethod>;
private model: tfjsConverter.GraphModel;
private sceneMap: { [key: string]: string };
private ready: Promise<void>;
private workerID: number;
public constructor() {
this.method = {
value: "ImageScene",
version: 1,
};
this.workerID = Math.round(Math.random() * 1000);
}
private async init() {
log.info(`[${this.workerID}]`, "ImageScene init called");
if (this.model) {
return;
}
this.sceneMap = await (
await fetch("/models/imagescene/sceneMap.json")
).json();
this.model = await tfjsConverter.loadGraphModel(
"/models/imagescene/model.json",
);
log.info(
`[${this.workerID}]`,
"loaded ImageScene model",
tf.getBackend(),
);
tf.tidy(() => {
const zeroTensor = tf.zeros([1, 224, 224, 3]);
// warmup the model
this.model.predict(zeroTensor) as tf.Tensor;
});
}
private async getImageSceneModel() {
log.info(`[${this.workerID}]`, "ImageScene getImageSceneModel called");
if (!this.ready) {
this.ready = this.init();
}
await this.ready;
return this.model;
}
async detectScenes(image: ImageBitmap, minScore: number) {
const resized = resizeToSquare(image, SCENE_DETECTION_IMAGE_SIZE);
const model = await this.getImageSceneModel();
const output = tf.tidy(() => {
const tfImage = tf.browser.fromPixels(resized.image);
const input = tf.expandDims(tf.cast(tfImage, "float32"));
const output = model.predict(input) as tf.Tensor;
return output;
});
const data = (await output.data()) as Float32Array;
output.dispose();
const scenes = this.parseSceneDetectionResult(
data,
minScore,
image.width,
image.height,
);
return scenes;
}
private parseSceneDetectionResult(
outputData: Float32Array,
minScore: number,
width: number,
height: number,
): ObjectDetection[] {
const scenes = [];
for (let i = 0; i < outputData.length; i++) {
if (outputData[i] >= minScore) {
scenes.push({
class: this.sceneMap[i.toString()],
score: outputData[i],
bbox: [0, 0, width, height],
});
}
}
return scenes;
}
}
export default new ImageScene();

View file

@@ -1,10 +1,10 @@
import { MOBILEFACENET_FACE_SIZE } from "constants/mlConfig";
import {
BlurDetectionMethod,
BlurDetectionService,
Versioned,
} from "types/machineLearning";
import { createGrayscaleIntMatrixFromNormalized2List } from "utils/image";
import { mobileFaceNetFaceSize } from "./mobileFaceNetEmbeddingService";
class LaplacianBlurDetectionService implements BlurDetectionService {
public method: Versioned<BlurDetectionMethod>;
@@ -19,7 +19,7 @@ class LaplacianBlurDetectionService implements BlurDetectionService {
public detectBlur(alignedFaces: Float32Array): number[] {
const numFaces = Math.round(
alignedFaces.length /
(MOBILEFACENET_FACE_SIZE * MOBILEFACENET_FACE_SIZE * 3),
(mobileFaceNetFaceSize * mobileFaceNetFaceSize * 3),
);
const blurValues: number[] = [];
for (let i = 0; i < numFaces; i++) {

View file

@@ -22,20 +22,14 @@ import {
MLLibraryData,
MLSyncConfig,
MLSyncContext,
ObjectDetectionMethod,
ObjectDetectionService,
SceneDetectionMethod,
SceneDetectionService,
} from "types/machineLearning";
import { logQueueStats } from "utils/machineLearning";
import arcfaceAlignmentService from "./arcfaceAlignmentService";
import arcfaceCropService from "./arcfaceCropService";
import dbscanClusteringService from "./dbscanClusteringService";
import hdbscanClusteringService from "./hdbscanClusteringService";
import imageSceneService from "./imageSceneService";
import laplacianBlurDetectionService from "./laplacianBlurDetectionService";
import mobileFaceNetEmbeddingService from "./mobileFaceNetEmbeddingService";
import ssdMobileNetV2Service from "./ssdMobileNetV2Service";
import yoloFaceDetectionService from "./yoloFaceDetectionService";
export class MLFactory {
@@ -49,26 +43,6 @@ export class MLFactory {
throw Error("Unknown face detection method: " + method);
}
public static getObjectDetectionService(
method: ObjectDetectionMethod,
): ObjectDetectionService {
if (method === "SSDMobileNetV2") {
return ssdMobileNetV2Service;
}
throw Error("Unknown object detection method: " + method);
}
public static getSceneDetectionService(
method: SceneDetectionMethod,
): SceneDetectionService {
if (method === "ImageScene") {
return imageSceneService;
}
throw Error("Unknown scene detection method: " + method);
}
public static getFaceCropService(method: FaceCropMethod) {
if (method === "ArcFace") {
return arcfaceCropService;
@@ -147,15 +121,12 @@ export class LocalMLSyncContext implements MLSyncContext {
public blurDetectionService: BlurDetectionService;
public faceEmbeddingService: FaceEmbeddingService;
public faceClusteringService: ClusteringService;
public objectDetectionService: ObjectDetectionService;
public sceneDetectionService: SceneDetectionService;
public localFilesMap: Map<number, EnteFile>;
public outOfSyncFiles: EnteFile[];
public nSyncedFiles: number;
public nSyncedFaces: number;
public allSyncedFacesMap?: Map<number, Array<Face>>;
public tsne?: any;
public error?: Error;
@@ -202,13 +173,6 @@ export class LocalMLSyncContext implements MLSyncContext {
this.config.faceClustering.method,
);
this.objectDetectionService = MLFactory.getObjectDetectionService(
this.config.objectDetection.method,
);
this.sceneDetectionService = MLFactory.getSceneDetectionService(
this.config.sceneDetection.method,
);
this.outOfSyncFiles = [];
this.nSyncedFiles = 0;
this.nSyncedFaces = 0;
@@ -239,9 +203,6 @@ export class LocalMLSyncContext implements MLSyncContext {
}
public async dispose() {
// await this.faceDetectionService.dispose();
// await this.faceEmbeddingService.dispose();
this.localFilesMap = undefined;
await this.syncQueue.onIdle();
this.syncQueue.removeAllListeners();

View file

@@ -2,9 +2,6 @@ import log from "@/next/log";
import { APPS } from "@ente/shared/apps/constants";
import ComlinkCryptoWorker from "@ente/shared/crypto";
import { CustomError, parseUploadErrorCodes } from "@ente/shared/error";
import "@tensorflow/tfjs-backend-cpu";
import "@tensorflow/tfjs-backend-webgl";
import * as tf from "@tensorflow/tfjs-core";
import { MAX_ML_SYNC_ERROR_COUNT } from "constants/mlConfig";
import downloadManager from "services/download";
import { putEmbedding } from "services/embeddingService";
@@ -21,13 +18,10 @@ import { LocalFileMlDataToServerFileMl } from "utils/machineLearning/mldataMappe
import mlIDbStorage from "utils/storage/mlIDbStorage";
import FaceService from "./faceService";
import { MLFactory } from "./machineLearningFactory";
import ObjectService from "./objectService";
import PeopleService from "./peopleService";
import ReaderService from "./readerService";
class MachineLearningService {
private initialized = false;
private localSyncContext: Promise<MLSyncContext>;
private syncContext: Promise<MLSyncContext>;
@@ -58,12 +52,6 @@ class MachineLearningService {
await this.syncIndex(syncContext);
}
// tf.engine().endScope();
// if (syncContext.config.tsne) {
// await this.runTSNE(syncContext);
// }
const mlSyncResult: MLSyncResult = {
nOutOfSyncFiles: syncContext.outOfSyncFiles.length,
nSyncedFiles: syncContext.nSyncedFiles,
@@ -73,14 +61,10 @@
.length,
nFaceNoise:
syncContext.mlLibraryData?.faceClusteringResults?.noise.length,
tsne: syncContext.tsne,
error: syncContext.error,
};
// log.info('[MLService] sync results: ', mlSyncResult);
// await syncContext.dispose();
log.info("Final TF Memory stats: ", JSON.stringify(tf.memory()));
return mlSyncResult;
}
@@ -183,50 +167,6 @@ class MachineLearningService {
log.info("getOutOfSyncFiles", Date.now() - startTime, "ms");
}
// TODO: optimize, use indexdb indexes, move facecrops to cache to reduce io
// remove, already done
private async getUniqueOutOfSyncFilesNoIdx(
syncContext: MLSyncContext,
files: EnteFile[],
) {
const limit = syncContext.config.batchSize;
const mlVersion = syncContext.config.mlVersion;
const uniqueFiles: Map<number, EnteFile> = new Map<number, EnteFile>();
for (let i = 0; uniqueFiles.size < limit && i < files.length; i++) {
const mlFileData = await this.getMLFileData(files[i].id);
const mlFileVersion = mlFileData?.mlVersion || 0;
if (
!uniqueFiles.has(files[i].id) &&
(!mlFileData?.errorCount || mlFileData.errorCount < 2) &&
(mlFileVersion < mlVersion ||
syncContext.config.imageSource !== mlFileData.imageSource)
) {
uniqueFiles.set(files[i].id, files[i]);
}
}
return [...uniqueFiles.values()];
}
private async getOutOfSyncFilesNoIdx(syncContext: MLSyncContext) {
const existingFilesMap = await this.getLocalFilesMap(syncContext);
// existingFiles.sort(
// (a, b) => b.metadata.creationTime - a.metadata.creationTime
// );
console.time("getUniqueOutOfSyncFiles");
syncContext.outOfSyncFiles = await this.getUniqueOutOfSyncFilesNoIdx(
syncContext,
[...existingFilesMap.values()],
);
log.info("getUniqueOutOfSyncFiles");
log.info(
"Got unique outOfSyncFiles: ",
syncContext.outOfSyncFiles.length,
"for batchSize: ",
syncContext.config.batchSize,
);
}
private async syncFiles(syncContext: MLSyncContext) {
try {
const functions = syncContext.outOfSyncFiles.map(
@@ -295,7 +235,6 @@
userID: number,
enteFile: EnteFile,
localFile?: globalThis.File,
textDetectionTimeoutIndex?: number,
): Promise<MlFileData | Error> {
const syncContext = await this.getLocalSyncContext(token, userID);
@@ -304,7 +243,6 @@
syncContext,
enteFile,
localFile,
textDetectionTimeoutIndex,
);
if (syncContext.nSyncedFiles >= syncContext.config.batchSize) {
@@ -322,19 +260,15 @@
syncContext: MLSyncContext,
enteFile: EnteFile,
localFile?: globalThis.File,
textDetectionTimeoutIndex?: number,
): Promise<MlFileData> {
try {
console.log(
"Start index for ",
enteFile.title ?? "no title",
enteFile.id,
`Indexing ${enteFile.title ?? "<untitled>"} ${enteFile.id}`,
);
const mlFileData = await this.syncFile(
syncContext,
enteFile,
localFile,
textDetectionTimeoutIndex,
);
syncContext.nSyncedFaces += mlFileData.faces?.length || 0;
syncContext.nSyncedFiles += 1;
@@ -363,16 +297,8 @@
throw error;
}
await this.persistMLFileSyncError(syncContext, enteFile, error);
await this.persistMLFileSyncError(enteFile, error);
syncContext.nSyncedFiles += 1;
} finally {
console.log(
"done index for ",
enteFile.title ?? "no title",
enteFile.id,
);
// addLogLine('TF Memory stats: ', JSON.stringify(tf.memory()));
log.info("TF Memory stats: ", JSON.stringify(tf.memory()));
}
}
@@ -380,8 +306,6 @@
syncContext: MLSyncContext,
enteFile: EnteFile,
localFile?: globalThis.File,
// eslint-disable-next-line @typescript-eslint/no-unused-vars
textDetectionTimeoutIndex?: number,
) {
console.log("Syncing for file " + enteFile.title);
const fileContext: MLSyncFileContext = { enteFile, localFile };
@@ -406,34 +330,17 @@
await ReaderService.getImageBitmap(syncContext, fileContext);
await Promise.all([
this.syncFileAnalyzeFaces(syncContext, fileContext),
// ObjectService.syncFileObjectDetections(
// syncContext,
// fileContext
// ),
// TextService.syncFileTextDetections(
// syncContext,
// fileContext,
// textDetectionTimeoutIndex
// ),
]);
newMlFile.errorCount = 0;
newMlFile.lastErrorMessage = undefined;
await this.persistOnServer(newMlFile, enteFile);
await this.persistMLFileData(syncContext, newMlFile);
await mlIDbStorage.putFile(newMlFile);
} catch (e) {
log.error("ml detection failed", e);
newMlFile.mlVersion = oldMlFile.mlVersion;
throw e;
} finally {
fileContext.tfImage && fileContext.tfImage.dispose();
fileContext.imageBitmap && fileContext.imageBitmap.close();
// log.info('8 TF Memory stats: ',JSON.stringify(tf.memory()));
// TODO: enable once faceId changes go in
// await removeOldFaceCrops(
// fileContext.oldMlFile,
// fileContext.newMlFile
// );
}
return newMlFile;
@@ -458,38 +365,11 @@
log.info("putEmbedding response: ", res);
}
public async init() {
if (this.initialized) {
return;
}
await tf.ready();
log.info("01 TF Memory stats: ", JSON.stringify(tf.memory()));
this.initialized = true;
}
public async dispose() {
this.initialized = false;
}
private async getMLFileData(fileId: number) {
return mlIDbStorage.getFile(fileId);
}
private async persistMLFileData(
syncContext: MLSyncContext,
mlFileData: MlFileData,
) {
mlIDbStorage.putFile(mlFileData);
}
private async persistMLFileSyncError(
syncContext: MLSyncContext,
enteFile: EnteFile,
e: Error,
) {
private async persistMLFileSyncError(enteFile: EnteFile, e: Error) {
try {
await mlIDbStorage.upsertFileInTx(enteFile.id, (mlFileData) => {
if (!mlFileData) {
@@ -522,8 +402,6 @@
await PeopleService.syncPeopleIndex(syncContext);
await ObjectService.syncThingsIndex(syncContext);
await this.persistMLLibraryData(syncContext);
}

View file

@@ -1,9 +0,0 @@
import { JobResult } from "types/common/job";
import { MLSyncResult } from "types/machineLearning";
import { SimpleJob } from "utils/common/job";
export interface MLSyncJobResult extends JobResult {
mlSyncResult: MLSyncResult;
}
export class MLSyncJob extends SimpleJob<MLSyncJobResult> {}

View file

@@ -5,18 +5,26 @@ import { getToken, getUserID } from "@ente/shared/storage/localStorage/helpers";
import { FILE_TYPE } from "constants/file";
import debounce from "debounce";
import PQueue from "p-queue";
import { JobResult } from "types/common/job";
import { EnteFile } from "types/file";
import { MLSyncResult } from "types/machineLearning";
import { getDedicatedMLWorker } from "utils/comlink/ComlinkMLWorker";
import { SimpleJob } from "utils/common/job";
import { logQueueStats } from "utils/machineLearning";
import { getMLSyncJobConfig } from "utils/machineLearning/config";
import mlIDbStorage from "utils/storage/mlIDbStorage";
import { DedicatedMLWorker } from "worker/ml.worker";
import { MLSyncJob, MLSyncJobResult } from "./mlSyncJob";
const LIVE_SYNC_IDLE_DEBOUNCE_SEC = 30;
const LIVE_SYNC_QUEUE_TIMEOUT_SEC = 300;
const LOCAL_FILES_UPDATED_DEBOUNCE_SEC = 30;
export interface MLSyncJobResult extends JobResult {
mlSyncResult: MLSyncResult;
}
export class MLSyncJob extends SimpleJob<MLSyncJobResult> {}
class MLWorkManager {
private mlSyncJob: MLSyncJob;
private syncJobWorker: ComlinkWorker<typeof DedicatedMLWorker>;
@@ -178,8 +186,7 @@ class MLWorkManager {
return mlWorker.syncLocalFile(token, userID, enteFile, localFile);
});
// @ts-expect-error "TODO: Fix ML related type errors"
if ("message" in result) {
if (result instanceof Error) {
// TODO: redirect/refresh to gallery in case of session_expired
// may not be required as uploader should anyways take care of this
console.error("Error while syncing local file: ", result);

View file

@@ -1,11 +1,4 @@
import log from "@/next/log";
import * as tf from "@tensorflow/tfjs-core";
import {
MOBILEFACENET_EMBEDDING_SIZE,
MOBILEFACENET_FACE_SIZE,
} from "constants/mlConfig";
// import { TFLiteModel } from "@tensorflow/tfjs-tflite";
// import PQueue from "p-queue";
import { workerBridge } from "@/next/worker/worker-bridge";
import {
FaceEmbedding,
FaceEmbeddingMethod,
@@ -13,184 +6,36 @@
Versioned,
} from "types/machineLearning";
// TODO(MR): onnx-yolo
// import * as ort from "onnxruntime-web";
// import { env } from "onnxruntime-web";
const ort: any = {};
export const mobileFaceNetFaceSize = 112;
import {
clamp,
getPixelBilinear,
normalizePixelBetweenMinus1And1,
} from "utils/image";
// TODO(MR): onnx-yolo
// env.wasm.wasmPaths = "/js/onnx/";
class MobileFaceNetEmbeddingService implements FaceEmbeddingService {
// TODO(MR): onnx-yolo
// private onnxInferenceSession?: ort.InferenceSession;
private onnxInferenceSession?: any;
public method: Versioned<FaceEmbeddingMethod>;
public faceSize: number;
public constructor(faceSize: number = MOBILEFACENET_FACE_SIZE) {
public constructor() {
this.method = {
value: "MobileFaceNet",
version: 2,
};
this.faceSize = faceSize;
// TODO: set timeout
}
private async initOnnx() {
console.log("start ort mobilefacenet");
this.onnxInferenceSession = await ort.InferenceSession.create(
"/models/mobilefacenet/mobilefacenet_opset15.onnx",
);
const faceBatchSize = 1;
const data = new Float32Array(
faceBatchSize * 3 * this.faceSize * this.faceSize,
);
const inputTensor = new ort.Tensor("float32", data, [
faceBatchSize,
this.faceSize,
this.faceSize,
3,
]);
// TODO(MR): onnx-yolo
// const feeds: Record<string, ort.Tensor> = {};
const feeds: Record<string, any> = {};
const name = this.onnxInferenceSession.inputNames[0];
feeds[name] = inputTensor;
await this.onnxInferenceSession.run(feeds);
console.log("start end mobilefacenet");
}
private async getOnnxInferenceSession() {
if (!this.onnxInferenceSession) {
await this.initOnnx();
}
return this.onnxInferenceSession;
}
private preprocessImageBitmapToFloat32(
imageBitmap: ImageBitmap,
requiredWidth: number = this.faceSize,
requiredHeight: number = this.faceSize,
maintainAspectRatio: boolean = true,
normFunction: (
pixelValue: number,
) => number = normalizePixelBetweenMinus1And1,
) {
// Create an OffscreenCanvas and set its size
const offscreenCanvas = new OffscreenCanvas(
imageBitmap.width,
imageBitmap.height,
);
const ctx = offscreenCanvas.getContext("2d");
ctx.drawImage(imageBitmap, 0, 0, imageBitmap.width, imageBitmap.height);
const imageData = ctx.getImageData(
0,
0,
imageBitmap.width,
imageBitmap.height,
);
const pixelData = imageData.data;
let scaleW = requiredWidth / imageBitmap.width;
let scaleH = requiredHeight / imageBitmap.height;
if (maintainAspectRatio) {
const scale = Math.min(
requiredWidth / imageBitmap.width,
requiredHeight / imageBitmap.height,
);
scaleW = scale;
scaleH = scale;
}
const scaledWidth = clamp(
Math.round(imageBitmap.width * scaleW),
0,
requiredWidth,
);
const scaledHeight = clamp(
Math.round(imageBitmap.height * scaleH),
0,
requiredHeight,
);
const processedImage = new Float32Array(
1 * requiredWidth * requiredHeight * 3,
);
log.info("loaded mobileFaceNetModel: ", tf.getBackend());
// Populate the Float32Array with normalized pixel values
for (let h = 0; h < requiredHeight; h++) {
for (let w = 0; w < requiredWidth; w++) {
let pixel: {
r: number;
g: number;
b: number;
};
if (w >= scaledWidth || h >= scaledHeight) {
pixel = { r: 114, g: 114, b: 114 };
} else {
pixel = getPixelBilinear(
w / scaleW,
h / scaleH,
pixelData,
imageBitmap.width,
imageBitmap.height,
);
}
const pixelIndex = 3 * (h * requiredWidth + w);
processedImage[pixelIndex] = normFunction(pixel.r);
processedImage[pixelIndex + 1] = normFunction(pixel.g);
processedImage[pixelIndex + 2] = normFunction(pixel.b);
}
}
return processedImage;
this.faceSize = mobileFaceNetFaceSize;
}
public async getFaceEmbeddings(
faceData: Float32Array,
): Promise<Array<FaceEmbedding>> {
const inputTensor = new ort.Tensor("float32", faceData, [
Math.round(faceData.length / (this.faceSize * this.faceSize * 3)),
this.faceSize,
this.faceSize,
3,
]);
// TODO(MR): onnx-yolo
// const feeds: Record<string, ort.Tensor> = {};
const feeds: Record<string, any> = {};
feeds["img_inputs"] = inputTensor;
const inferenceSession = await this.getOnnxInferenceSession();
// TODO(MR): onnx-yolo
// const runout: ort.InferenceSession.OnnxValueMapType =
const runout: any = await inferenceSession.run(feeds);
// const test = runout.embeddings;
// const test2 = test.cpuData;
const outputData = runout.embeddings["cpuData"] as Float32Array;
const outputData = await workerBridge.faceEmbedding(faceData);
const embeddingSize = 192;
const embeddings = new Array<FaceEmbedding>(
outputData.length / MOBILEFACENET_EMBEDDING_SIZE,
outputData.length / embeddingSize,
);
for (let i = 0; i < embeddings.length; i++) {
embeddings[i] = new Float32Array(
outputData.slice(
i * MOBILEFACENET_EMBEDDING_SIZE,
(i + 1) * MOBILEFACENET_EMBEDDING_SIZE,
),
outputData.slice(i * embeddingSize, (i + 1) * embeddingSize),
);
}
return embeddings;
}
public async dispose() {
const inferenceSession = await this.getOnnxInferenceSession();
inferenceSession?.release();
this.onnxInferenceSession = undefined;
}
}
export default new MobileFaceNetEmbeddingService();
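Note (illustrative sketch, not part of this change): the 192-dimensional embeddings sliced out above are typically compared with cosine similarity, for example:

// Sketch only: cosine similarity between two face embeddings; higher values
// mean more similar faces.
const cosineSimilarity = (a: Float32Array, b: Float32Array): number => {
    let dot = 0;
    let normA = 0;
    let normB = 0;
    for (let i = 0; i < a.length; i++) {
        dot += a[i] * b[i];
        normA += a[i] * a[i];
        normB += b[i] * b[i];
    }
    return dot / (Math.sqrt(normA) * Math.sqrt(normB));
};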

View file

@@ -1,146 +0,0 @@
import log from "@/next/log";
import {
DetectedObject,
MLSyncContext,
MLSyncFileContext,
Thing,
} from "types/machineLearning";
import {
getAllObjectsFromMap,
getObjectId,
isDifferentOrOld,
} from "utils/machineLearning";
import mlIDbStorage from "utils/storage/mlIDbStorage";
import ReaderService from "./readerService";
class ObjectService {
async syncFileObjectDetections(
syncContext: MLSyncContext,
fileContext: MLSyncFileContext,
) {
const startTime = Date.now();
const { oldMlFile, newMlFile } = fileContext;
if (
!isDifferentOrOld(
oldMlFile?.objectDetectionMethod,
syncContext.objectDetectionService.method,
) &&
!isDifferentOrOld(
oldMlFile?.sceneDetectionMethod,
syncContext.sceneDetectionService.method,
) &&
oldMlFile?.imageSource === syncContext.config.imageSource
) {
newMlFile.objects = oldMlFile?.objects;
newMlFile.imageSource = oldMlFile.imageSource;
newMlFile.imageDimensions = oldMlFile.imageDimensions;
newMlFile.objectDetectionMethod = oldMlFile.objectDetectionMethod;
newMlFile.sceneDetectionMethod = oldMlFile.sceneDetectionMethod;
return;
}
newMlFile.objectDetectionMethod =
syncContext.objectDetectionService.method;
newMlFile.sceneDetectionMethod =
syncContext.sceneDetectionService.method;
fileContext.newDetection = true;
const imageBitmap = await ReaderService.getImageBitmap(
syncContext,
fileContext,
);
const objectDetections =
await syncContext.objectDetectionService.detectObjects(
imageBitmap,
syncContext.config.objectDetection.maxNumBoxes,
syncContext.config.objectDetection.minScore,
);
objectDetections.push(
...(await syncContext.sceneDetectionService.detectScenes(
imageBitmap,
syncContext.config.sceneDetection.minScore,
)),
);
// log.info('3 TF Memory stats: ',JSON.stringify(tf.memory()));
// TODO: reenable faces filtering based on width
const detectedObjects = objectDetections?.map((detection) => {
return {
fileID: fileContext.enteFile.id,
detection,
} as DetectedObject;
});
newMlFile.objects = detectedObjects?.map((detectedObject) => ({
...detectedObject,
id: getObjectId(detectedObject, newMlFile.imageDimensions),
className: detectedObject.detection.class,
}));
// ?.filter((f) =>
// f.box.width > syncContext.config.faceDetection.minFaceSize
// );
log.info(
`object detection time taken ${fileContext.enteFile.id}`,
Date.now() - startTime,
"ms",
);
log.info("[MLService] Detected Objects: ", newMlFile.objects?.length);
}
async getAllSyncedObjectsMap(syncContext: MLSyncContext) {
if (syncContext.allSyncedObjectsMap) {
return syncContext.allSyncedObjectsMap;
}
syncContext.allSyncedObjectsMap = await mlIDbStorage.getAllObjectsMap();
return syncContext.allSyncedObjectsMap;
}
public async clusterThings(syncContext: MLSyncContext): Promise<Thing[]> {
const allObjectsMap = await this.getAllSyncedObjectsMap(syncContext);
const allObjects = getAllObjectsFromMap(allObjectsMap);
const objectClusters = new Map<string, number[]>();
allObjects.map((object) => {
if (!objectClusters.has(object.detection.class)) {
objectClusters.set(object.detection.class, []);
}
const objectsInCluster = objectClusters.get(object.detection.class);
objectsInCluster.push(object.fileID);
});
return [...objectClusters.entries()].map(([className, files], id) => ({
id,
name: className,
files,
}));
}
async syncThingsIndex(syncContext: MLSyncContext) {
const filesVersion = await mlIDbStorage.getIndexVersion("files");
log.info("things", await mlIDbStorage.getIndexVersion("things"));
if (filesVersion <= (await mlIDbStorage.getIndexVersion("things"))) {
log.info(
"[MLService] Skipping things index as already synced to latest version",
);
return;
}
const things = await this.clusterThings(syncContext);
if (!things || things.length < 1) {
return;
}
await mlIDbStorage.clearAllThings();
for (const thing of things) {
await mlIDbStorage.putThing(thing);
}
await mlIDbStorage.setIndexVersion("things", filesVersion);
}
async getAllThings() {
return await mlIDbStorage.getAllThings();
}
}
export default new ObjectService();

View file

@@ -16,7 +16,6 @@ class ReaderService {
if (fileContext.imageBitmap) {
return fileContext.imageBitmap;
}
// log.info('1 TF Memory stats: ',JSON.stringify(tf.memory()));
if (fileContext.localFile) {
if (
fileContext.enteFile.metadata.fileType !== FILE_TYPE.IMAGE
@@ -47,7 +46,6 @@
fileContext.newMlFile.imageSource = syncContext.config.imageSource;
const { width, height } = fileContext.imageBitmap;
fileContext.newMlFile.imageDimensions = { width, height };
// log.info('2 TF Memory stats: ',JSON.stringify(tf.memory()));
return fileContext.imageBitmap;
} catch (e) {

View file

@@ -1,66 +0,0 @@
import log from "@/next/log";
import * as tf from "@tensorflow/tfjs-core";
import {
ObjectDetection,
ObjectDetectionMethod,
ObjectDetectionService,
Versioned,
} from "types/machineLearning";
import * as SSDMobileNet from "@tensorflow-models/coco-ssd";
import { OBJECT_DETECTION_IMAGE_SIZE } from "constants/mlConfig";
import { resizeToSquare } from "utils/image";
class SSDMobileNetV2 implements ObjectDetectionService {
private ssdMobileNetV2Model: SSDMobileNet.ObjectDetection;
public method: Versioned<ObjectDetectionMethod>;
private ready: Promise<void>;
public constructor() {
this.method = {
value: "SSDMobileNetV2",
version: 1,
};
}
private async init() {
this.ssdMobileNetV2Model = await SSDMobileNet.load({
base: "mobilenet_v2",
modelUrl: "/models/ssdmobilenet/model.json",
});
log.info("loaded ssdMobileNetV2Model", tf.getBackend());
}
private async getSSDMobileNetV2Model() {
if (!this.ready) {
this.ready = this.init();
}
await this.ready;
return this.ssdMobileNetV2Model;
}
public async detectObjects(
image: ImageBitmap,
maxNumberBoxes: number,
minScore: number,
): Promise<ObjectDetection[]> {
const ssdMobileNetV2Model = await this.getSSDMobileNetV2Model();
const resized = resizeToSquare(image, OBJECT_DETECTION_IMAGE_SIZE);
const tfImage = tf.browser.fromPixels(resized.image);
const detections = await ssdMobileNetV2Model.detect(
tfImage,
maxNumberBoxes,
minScore,
);
tfImage.dispose();
return detections;
}
public async dispose() {
const ssdMobileNetV2Model = await this.getSSDMobileNetV2Model();
ssdMobileNetV2Model?.dispose();
this.ssdMobileNetV2Model = null;
}
}
export default new SSDMobileNetV2();

View file

@@ -1,7 +1,12 @@
import { workerBridge } from "@/next/worker/worker-bridge";
import { euclidean } from "hdbscan";
import {
BLAZEFACE_FACE_SIZE,
MAX_FACE_DISTANCE_PERCENT,
} from "constants/mlConfig";
Matrix,
applyToPoint,
compose,
scale,
translate,
} from "transformation-matrix";
import { Dimensions } from "types/image";
import {
FaceDetection,
@@ -15,57 +20,50 @@
normalizePixelBetween0And1,
} from "utils/image";
import { newBox } from "utils/machineLearning";
import { removeDuplicateDetections } from "utils/machineLearning/faceDetection";
import {
computeTransformToBox,
transformBox,
transformPoints,
} from "utils/machineLearning/transform";
import { Box, Point } from "../../../thirdparty/face-api/classes";
// TODO(MR): onnx-yolo
// import * as ort from "onnxruntime-web";
// import { env } from "onnxruntime-web";
const ort: any = {};
// TODO(MR): onnx-yolo
// env.wasm.wasmPaths = "/js/onnx/";
class YoloFaceDetectionService implements FaceDetectionService {
// TODO(MR): onnx-yolo
// private onnxInferenceSession?: ort.InferenceSession;
private onnxInferenceSession?: any;
public method: Versioned<FaceDetectionMethod>;
private desiredFaceSize;
public constructor(desiredFaceSize: number = BLAZEFACE_FACE_SIZE) {
public constructor() {
this.method = {
value: "YoloFace",
version: 1,
};
this.desiredFaceSize = desiredFaceSize;
}
private async initOnnx() {
console.log("start ort");
this.onnxInferenceSession = await ort.InferenceSession.create(
"/models/yoloface/yolov5s_face_640_640_dynamic.onnx",
public async detectFaces(
imageBitmap: ImageBitmap,
): Promise<Array<FaceDetection>> {
const maxFaceDistancePercent = Math.sqrt(2) / 100;
const maxFaceDistance = imageBitmap.width * maxFaceDistancePercent;
const preprocessResult =
this.preprocessImageBitmapToFloat32ChannelsFirst(
imageBitmap,
640,
640,
);
const data = preprocessResult.data;
const resized = preprocessResult.newSize;
const outputData = await workerBridge.detectFaces(data);
const faces = this.getFacesFromYoloOutput(
outputData as Float32Array,
0.7,
);
const data = new Float32Array(1 * 3 * 640 * 640);
const inputTensor = new ort.Tensor("float32", data, [1, 3, 640, 640]);
// TODO(MR): onnx-yolo
// const feeds: Record<string, ort.Tensor> = {};
const feeds: Record<string, any> = {};
const name = this.onnxInferenceSession.inputNames[0];
feeds[name] = inputTensor;
await this.onnxInferenceSession.run(feeds);
console.log("start end");
}
private async getOnnxInferenceSession() {
if (!this.onnxInferenceSession) {
await this.initOnnx();
}
return this.onnxInferenceSession;
const inBox = newBox(0, 0, resized.width, resized.height);
const toBox = newBox(0, 0, imageBitmap.width, imageBitmap.height);
const transform = computeTransformToBox(inBox, toBox);
const faceDetections: Array<FaceDetection> = faces?.map((f) => {
const box = transformBox(f.box, transform);
const normLandmarks = f.landmarks;
const landmarks = transformPoints(normLandmarks, transform);
return {
box,
landmarks,
probability: f.probability as number,
} as FaceDetection;
});
return removeDuplicateDetections(faceDetections, maxFaceDistance);
}
private preprocessImageBitmapToFloat32ChannelsFirst(
@@ -160,43 +158,6 @@ class YoloFaceDetectionService implements FaceDetectionService {
};
}
/**
* @deprecated The method should not be used
*/
private imageBitmapToTensorData(imageBitmap) {
// Create an OffscreenCanvas and set its size
const offscreenCanvas = new OffscreenCanvas(
imageBitmap.width,
imageBitmap.height,
);
const ctx = offscreenCanvas.getContext("2d");
ctx.drawImage(imageBitmap, 0, 0, imageBitmap.width, imageBitmap.height);
const imageData = ctx.getImageData(
0,
0,
imageBitmap.width,
imageBitmap.height,
);
const pixelData = imageData.data;
const data = new Float32Array(
1 * 3 * imageBitmap.width * imageBitmap.height,
);
// Populate the Float32Array with normalized pixel values
for (let i = 0; i < pixelData.length; i += 4) {
// Normalize pixel values to the range [0, 1]
data[i / 4] = pixelData[i] / 255.0; // Red channel
data[i / 4 + imageBitmap.width * imageBitmap.height] =
pixelData[i + 1] / 255.0; // Green channel
data[i / 4 + 2 * imageBitmap.width * imageBitmap.height] =
pixelData[i + 2] / 255.0; // Blue channel
}
return {
data: data,
shape: [1, 3, imageBitmap.width, imageBitmap.height],
};
}
// The rowOutput is a Float32Array of shape [25200, 16], where each row represents a bounding box.
private getFacesFromYoloOutput(
rowOutput: Float32Array,
@@ -274,58 +235,98 @@
probability: faceDetection.probability,
};
}
private async estimateOnnx(imageBitmap: ImageBitmap) {
const maxFaceDistance = imageBitmap.width * MAX_FACE_DISTANCE_PERCENT;
const preprocessResult =
this.preprocessImageBitmapToFloat32ChannelsFirst(
imageBitmap,
640,
640,
);
const data = preprocessResult.data;
const resized = preprocessResult.newSize;
const inputTensor = new ort.Tensor("float32", data, [1, 3, 640, 640]);
// TODO(MR): onnx-yolo
// const feeds: Record<string, ort.Tensor> = {};
const feeds: Record<string, any> = {};
feeds["input"] = inputTensor;
const inferenceSession = await this.getOnnxInferenceSession();
const runout = await inferenceSession.run(feeds);
const outputData = runout.output.data;
const faces = this.getFacesFromYoloOutput(
outputData as Float32Array,
0.7,
);
const inBox = newBox(0, 0, resized.width, resized.height);
const toBox = newBox(0, 0, imageBitmap.width, imageBitmap.height);
const transform = computeTransformToBox(inBox, toBox);
const faceDetections: Array<FaceDetection> = faces?.map((f) => {
const box = transformBox(f.box, transform);
const normLandmarks = f.landmarks;
const landmarks = transformPoints(normLandmarks, transform);
return {
box,
landmarks,
probability: f.probability as number,
} as FaceDetection;
});
return removeDuplicateDetections(faceDetections, maxFaceDistance);
}
public async detectFaces(
imageBitmap: ImageBitmap,
): Promise<Array<FaceDetection>> {
// measure time taken
const facesFromOnnx = await this.estimateOnnx(imageBitmap);
return facesFromOnnx;
}
public async dispose() {
const inferenceSession = await this.getOnnxInferenceSession();
inferenceSession?.release();
this.onnxInferenceSession = undefined;
}
}
export default new YoloFaceDetectionService();
/**
* Removes duplicate face detections from an array of detections.
*
* This function sorts the detections by their probability in descending order, then iterates over them.
* For each detection, it calculates the Euclidean distance to all other detections.
* If the distance is less than or equal to the specified threshold (`withinDistance`), the other detection is considered a duplicate and is removed.
*
* @param detections - An array of face detections to remove duplicates from.
* @param withinDistance - The maximum Euclidean distance between two detections for them to be considered duplicates.
*
* @returns An array of face detections with duplicates removed.
*/
function removeDuplicateDetections(
detections: Array<FaceDetection>,
withinDistance: number,
) {
// console.time('removeDuplicates');
detections.sort((a, b) => b.probability - a.probability);
const isSelected = new Map<number, boolean>();
for (let i = 0; i < detections.length; i++) {
if (isSelected.get(i) === false) {
continue;
}
isSelected.set(i, true);
for (let j = i + 1; j < detections.length; j++) {
if (isSelected.get(j) === false) {
continue;
}
const centeri = getDetectionCenter(detections[i]);
const centerj = getDetectionCenter(detections[j]);
const dist = euclidean(
[centeri.x, centeri.y],
[centerj.x, centerj.y],
);
if (dist <= withinDistance) {
isSelected.set(j, false);
}
}
}
const uniques: Array<FaceDetection> = [];
for (let i = 0; i < detections.length; i++) {
isSelected.get(i) && uniques.push(detections[i]);
}
// console.timeEnd('removeDuplicates');
return uniques;
}
function getDetectionCenter(detection: FaceDetection) {
const center = new Point(0, 0);
// TODO: the first 4 landmarks are applicable to blazeface only
// this needs to consider eyes, nose and mouth landmarks to take center
detection.landmarks?.slice(0, 4).forEach((p) => {
center.x += p.x;
center.y += p.y;
});
return center.div({ x: 4, y: 4 });
}
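Note (illustrative arithmetic, not part of this change): for a sense of scale of the duplicate-removal threshold used above, with an assumed 1000 px wide image:

// Sketch only: maxFaceDistance for an assumed 1000 px wide image.
const imageWidth = 1000; // assumed
const maxFaceDistance = imageWidth * (Math.sqrt(2) / 100); // ≈ 14.14 px
// Detections whose landmark centers lie within ~14 px are treated as duplicates.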
function computeTransformToBox(inBox: Box, toBox: Box): Matrix {
return compose(
translate(toBox.x, toBox.y),
scale(toBox.width / inBox.width, toBox.height / inBox.height),
);
}
function transformPoint(point: Point, transform: Matrix) {
const txdPoint = applyToPoint(transform, point);
return new Point(txdPoint.x, txdPoint.y);
}
function transformPoints(points: Point[], transform: Matrix) {
return points?.map((p) => transformPoint(p, transform));
}
function transformBox(box: Box, transform: Matrix) {
const topLeft = transformPoint(box.topLeft, transform);
const bottomRight = transformPoint(box.bottomRight, transform);
return newBoxFromPoints(topLeft.x, topLeft.y, bottomRight.x, bottomRight.y);
}
function newBoxFromPoints(
left: number,
top: number,
right: number,
bottom: number,
) {
return new Box({ left, top, right, bottom });
}
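Note (illustrative sketch, not part of this change): to make the flat [25200, 16] YOLO output mentioned in the comment above concrete, a row-walking sketch follows; the per-row field order (score assumed at index 4) is an assumption for illustration.

// Sketch only: walk the flat output row by row and keep rows whose assumed
// score field clears the threshold.
const YOLO_ROW_LENGTH = 16;
const filterYoloRows = (rowOutput: Float32Array, minScore: number): Float32Array[] => {
    const rows: Float32Array[] = [];
    for (let i = 0; i + YOLO_ROW_LENGTH <= rowOutput.length; i += YOLO_ROW_LENGTH) {
        const row = rowOutput.subarray(i, i + YOLO_ROW_LENGTH);
        if (row[4] >= minScore) rows.push(row);
    }
    return rows;
};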

View file

@@ -6,7 +6,7 @@ import { t } from "i18next";
import { Collection } from "types/collection";
import { EntityType, LocationTag, LocationTagData } from "types/entity";
import { EnteFile } from "types/file";
import { Person, Thing } from "types/machineLearning";
import { Person } from "types/machineLearning";
import {
ClipSearchScores,
DateValue,
@@ -25,7 +25,6 @@ import { clipService, computeClipMatchScore } from "./clip-service";
import { getLocalEmbeddings } from "./embeddingService";
import { getLatestEntities } from "./entityService";
import locationSearchService, { City } from "./locationSearchService";
import ObjectService from "./machineLearning/objectService";
const DIGITS = new Set(["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]);
@@ -56,7 +55,6 @@ export const getAutoCompleteSuggestions =
getFileNameSuggestion(searchPhrase, files),
getFileCaptionSuggestion(searchPhrase, files),
...(await getLocationSuggestions(searchPhrase)),
...(await getThingSuggestion(searchPhrase)),
].filter((suggestion) => !!suggestion);
return convertSuggestionsToOptions(suggestions);
@@ -289,19 +287,6 @@ async function getLocationSuggestions(searchPhrase: string) {
return [...locationTagSuggestions, ...citySearchSuggestions];
}
async function getThingSuggestion(searchPhrase: string): Promise<Suggestion[]> {
const thingResults = await searchThing(searchPhrase);
return thingResults.map(
(searchResult) =>
({
type: SuggestionType.THING,
value: searchResult,
label: searchResult.name,
}) as Suggestion,
);
}
async function getClipSuggestion(searchPhrase: string): Promise<Suggestion> {
try {
if (!clipService.isPlatformSupported()) {
@@ -389,13 +374,6 @@ async function searchLocationTag(searchPhrase: string): Promise<LocationTag[]> {
return matchedLocationTags;
}
async function searchThing(searchPhrase: string) {
const things = await ObjectService.getAllThings();
return things.filter((thing) =>
thing.name.toLocaleLowerCase().includes(searchPhrase),
);
}
async function searchClip(searchPhrase: string): Promise<ClipSearchScores> {
const imageEmbeddings = await getLocalEmbeddings();
const textEmbedding = await clipService.getTextEmbedding(searchPhrase);
@@ -445,10 +423,9 @@ function convertSuggestionToSearchQuery(option: Suggestion): Search {
case SuggestionType.PERSON:
return { person: option.value as Person };
case SuggestionType.THING:
return { thing: option.value as Thing };
case SuggestionType.FILE_TYPE:
return { fileType: option.value as FILE_TYPE };
case SuggestionType.CLIP:
return { clip: option.value as ClipSearchScores };
}

View file

@@ -1,16 +0,0 @@
export const ARCFACE_LANDMARKS = [
[38.2946, 51.6963],
[73.5318, 51.5014],
[56.0252, 71.7366],
[56.1396, 92.2848],
] as Array<[number, number]>;
export const ARCFACE_LANDMARKS_FACE_SIZE = 112;
export const ARC_FACE_5_LANDMARKS = [
[38.2946, 51.6963],
[73.5318, 51.5014],
[56.0252, 71.7366],
[41.5493, 92.3655],
[70.7299, 92.2041],
] as Array<[number, number]>;

View file

@@ -1,5 +1,3 @@
import * as tf from "@tensorflow/tfjs-core";
import { DebugInfo } from "hdbscan";
import PQueue from "p-queue";
import { EnteFile } from "types/file";
@@ -12,21 +10,9 @@ export interface MLSyncResult {
nSyncedFaces: number;
nFaceClusters: number;
nFaceNoise: number;
tsne?: any;
error?: Error;
}
export interface DebugFace {
fileId: string;
// face: FaceApiResult;
face: AlignedFace;
embedding: FaceEmbedding;
faceImage: FaceImage;
}
export declare type FaceImage = Array<Array<Array<number>>>;
export declare type FaceImageBlob = Blob;
export declare type FaceDescriptor = Float32Array;
export declare type Cluster = Array<number>;
@@ -59,23 +45,13 @@ export declare type Landmark = Point;
export declare type ImageType = "Original" | "Preview";
export declare type FaceDetectionMethod =
| "BlazeFace"
| "FaceApiSSD"
| "YoloFace";
export declare type ObjectDetectionMethod = "SSDMobileNetV2";
export declare type SceneDetectionMethod = "ImageScene";
export declare type FaceDetectionMethod = "YoloFace";
export declare type FaceCropMethod = "ArcFace";
export declare type FaceAlignmentMethod =
| "ArcFace"
| "FaceApiDlib"
| "RotatedFaceApiDlib";
export declare type FaceAlignmentMethod = "ArcFace";
export declare type FaceEmbeddingMethod = "MobileFaceNet" | "FaceApiDlib";
export declare type FaceEmbeddingMethod = "MobileFaceNet";
export declare type BlurDetectionMethod = "Laplacian";
@@ -155,45 +131,15 @@ export interface Person {
displayImageUrl?: string;
}
export interface ObjectDetection {
bbox: [number, number, number, number];
class: string;
score: number;
}
export interface DetectedObject {
fileID: number;
detection: ObjectDetection;
}
export interface RealWorldObject extends DetectedObject {
id: string;
className: string;
}
export interface Thing {
id: number;
name: string;
files: Array<number>;
}
export interface WordGroup {
word: string;
files: Array<number>;
}
export interface MlFileData {
fileId: number;
faces?: Face[];
objects?: RealWorldObject[];
imageSource?: ImageType;
imageDimensions?: Dimensions;
faceDetectionMethod?: Versioned<FaceDetectionMethod>;
faceCropMethod?: Versioned<FaceCropMethod>;
faceAlignmentMethod?: Versioned<FaceAlignmentMethod>;
faceEmbeddingMethod?: Versioned<FaceEmbeddingMethod>;
objectDetectionMethod?: Versioned<ObjectDetectionMethod>;
sceneDetectionMethod?: Versioned<SceneDetectionMethod>;
mlVersion: number;
errorCount: number;
lastErrorMessage?: string;
@@ -203,17 +149,6 @@ export interface FaceDetectionConfig {
method: FaceDetectionMethod;
}
export interface ObjectDetectionConfig {
method: ObjectDetectionMethod;
maxNumBoxes: number;
minScore: number;
}
export interface SceneDetectionConfig {
method: SceneDetectionMethod;
minScore: number;
}
export interface FaceCropConfig {
enabled: boolean;
method: FaceCropMethod;
@@ -263,9 +198,6 @@ export interface MLSyncConfig {
blurDetection: BlurDetectionConfig;
faceEmbedding: FaceEmbeddingConfig;
faceClustering: FaceClusteringConfig;
objectDetection: ObjectDetectionConfig;
sceneDetection: SceneDetectionConfig;
tsne?: TSNEConfig;
mlVersion: number;
}
@@ -285,16 +217,12 @@ export interface MLSyncContext {
faceEmbeddingService: FaceEmbeddingService;
blurDetectionService: BlurDetectionService;
faceClusteringService: ClusteringService;
objectDetectionService: ObjectDetectionService;
sceneDetectionService: SceneDetectionService;
localFilesMap: Map<number, EnteFile>;
outOfSyncFiles: EnteFile[];
nSyncedFiles: number;
nSyncedFaces: number;
allSyncedFacesMap?: Map<number, Array<Face>>;
allSyncedObjectsMap?: Map<number, Array<RealWorldObject>>;
tsne?: any;
error?: Error;
@@ -314,7 +242,6 @@ export interface MLSyncFileContext {
oldMlFile?: MlFileData;
newMlFile?: MlFileData;
tfImage?: tf.Tensor3D;
imageBitmap?: ImageBitmap;
newDetection?: boolean;
@@ -331,33 +258,12 @@ export declare type MLIndex = "files" | "people";
export interface FaceDetectionService {
method: Versioned<FaceDetectionMethod>;
// init(): Promise<void>;
detectFaces(image: ImageBitmap): Promise<Array<FaceDetection>>;
getRelativeDetection(
faceDetection: FaceDetection,
imageDimensions: Dimensions,
): FaceDetection;
dispose(): Promise<void>;
}
export interface ObjectDetectionService {
method: Versioned<ObjectDetectionMethod>;
// init(): Promise<void>;
detectObjects(
image: ImageBitmap,
maxNumBoxes: number,
minScore: number,
): Promise<ObjectDetection[]>;
dispose(): Promise<void>;
}
export interface SceneDetectionService {
method: Versioned<SceneDetectionMethod>;
// init(): Promise<void>;
detectScenes(
image: ImageBitmap,
minScore: number,
): Promise<ObjectDetection[]>;
}
export interface FaceCropService {
@@ -378,9 +284,8 @@ export interface FaceAlignmentService {
export interface FaceEmbeddingService {
method: Versioned<FaceEmbeddingMethod>;
faceSize: number;
// init(): Promise<void>;
getFaceEmbeddings(faceImages: Float32Array): Promise<Array<FaceEmbedding>>;
dispose(): Promise<void>;
}
export interface BlurDetectionService {

View file

@@ -2,7 +2,7 @@ import { FILE_TYPE } from "constants/file";
import { City } from "services/locationSearchService";
import { LocationTagData } from "types/entity";
import { EnteFile } from "types/file";
import { Person, Thing, WordGroup } from "types/machineLearning";
import { Person } from "types/machineLearning";
import { IndexStatus } from "types/machineLearning/ui";
export enum SuggestionType {
@@ -12,8 +12,6 @@ export enum SuggestionType {
FILE_NAME = "FILE_NAME",
PERSON = "PERSON",
INDEX_STATUS = "INDEX_STATUS",
THING = "THING",
TEXT = "TEXT",
FILE_CAPTION = "FILE_CAPTION",
FILE_TYPE = "FILE_TYPE",
CLIP = "CLIP",
@@ -34,8 +32,6 @@ export interface Suggestion {
| number[]
| Person
| IndexStatus
| Thing
| WordGroup
| LocationTagData
| City
| FILE_TYPE
@@ -50,8 +46,6 @@ export type Search = {
collection?: number;
files?: number[];
person?: Person;
thing?: Thing;
text?: WordGroup;
fileType?: FILE_TYPE;
clip?: ClipSearchScores;
};

View file

@@ -50,7 +50,7 @@ export class SimpleJob<R extends JobResult> {
try {
const jobResult = await this.runCallback();
if (jobResult.shouldBackoff) {
if (jobResult && jobResult.shouldBackoff) {
this.intervalSec = Math.min(
this.config.maxItervalSec,
this.intervalSec * this.config.backoffMultiplier,
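Note (illustrative sketch, not part of this change): the backoff above multiplies the interval and caps it at the configured maximum; with assumed values (start 30 s, multiplier 2, cap 960 s):

// Sketch only: interval progression under backoff with assumed config values.
let intervalSec = 30;
const backoffMultiplier = 2;
const maxItervalSec = 960; // name mirrors the config field used above
for (let i = 0; i < 6; i++) {
    intervalSec = Math.min(maxItervalSec, intervalSec * backoffMultiplier);
}
// intervalSec progresses 60, 120, 240, 480, 960, then stays at 960.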

View file

@@ -1,34 +1,39 @@
import * as tf from "@tensorflow/tfjs-core";
import { Matrix, inverse } from "ml-matrix";
import { Matrix } from "ml-matrix";
import { getSimilarityTransformation } from "similarity-transformation";
import { Dimensions } from "types/image";
import { FaceAlignment, FaceDetection } from "types/machineLearning";
import {
ARCFACE_LANDMARKS,
ARCFACE_LANDMARKS_FACE_SIZE,
ARC_FACE_5_LANDMARKS,
} from "types/machineLearning/archface";
import { cropWithRotation, transform } from "utils/image";
import {
computeRotation,
enlargeBox,
extractFaces,
getBoxCenter,
getBoxCenterPt,
toTensor4D,
} from ".";
import { Box, Point } from "../../../thirdparty/face-api/classes";
import { Point } from "../../../thirdparty/face-api/classes";
export function normalizeLandmarks(
landmarks: Array<[number, number]>,
faceSize: number,
): Array<[number, number]> {
return landmarks.map((landmark) =>
landmark.map((p) => p / faceSize),
) as Array<[number, number]>;
const ARCFACE_LANDMARKS = [
[38.2946, 51.6963],
[73.5318, 51.5014],
[56.0252, 71.7366],
[56.1396, 92.2848],
] as Array<[number, number]>;
const ARCFACE_LANDMARKS_FACE_SIZE = 112;
const ARC_FACE_5_LANDMARKS = [
[38.2946, 51.6963],
[73.5318, 51.5014],
[56.0252, 71.7366],
[41.5493, 92.3655],
[70.7299, 92.2041],
] as Array<[number, number]>;
export function getArcfaceAlignment(
faceDetection: FaceDetection,
): FaceAlignment {
const landmarkCount = faceDetection.landmarks.length;
return getFaceAlignmentUsingSimilarityTransform(
faceDetection,
normalizeLandmarks(
landmarkCount === 5 ? ARC_FACE_5_LANDMARKS : ARCFACE_LANDMARKS,
ARCFACE_LANDMARKS_FACE_SIZE,
),
);
}
export function getFaceAlignmentUsingSimilarityTransform(
function getFaceAlignmentUsingSimilarityTransform(
faceDetection: FaceDetection,
alignedLandmarks: Array<[number, number]>,
// alignmentMethod: Versioned<FaceAlignmentMethod>
@@ -72,175 +77,11 @@ export function getFaceAlignmentUsingSimilarityTransform(
};
}
export function getArcfaceAlignment(
faceDetection: FaceDetection,
): FaceAlignment {
const landmarkCount = faceDetection.landmarks.length;
return getFaceAlignmentUsingSimilarityTransform(
faceDetection,
normalizeLandmarks(
landmarkCount === 5 ? ARC_FACE_5_LANDMARKS : ARCFACE_LANDMARKS,
ARCFACE_LANDMARKS_FACE_SIZE,
),
);
}
export function extractFaceImage(
image: tf.Tensor4D,
alignment: FaceAlignment,
function normalizeLandmarks(
landmarks: Array<[number, number]>,
faceSize: number,
) {
const affineMat = new Matrix(alignment.affineMatrix);
const I = inverse(affineMat);
return tf.tidy(() => {
const projection = tf.tensor2d([
[
I.get(0, 0),
I.get(0, 1),
I.get(0, 2),
I.get(1, 0),
I.get(1, 1),
I.get(1, 2),
0,
0,
],
]);
const faceImage = tf.image.transform(
image,
projection,
"bilinear",
"constant",
0,
[faceSize, faceSize],
);
return faceImage;
});
}
export function tfExtractFaceImages(
image: tf.Tensor3D | tf.Tensor4D,
alignments: Array<FaceAlignment>,
faceSize: number,
): tf.Tensor4D {
return tf.tidy(() => {
const tf4dFloat32Image = toTensor4D(image, "float32");
const faceImages = new Array<tf.Tensor3D>(alignments.length);
for (let i = 0; i < alignments.length; i++) {
faceImages[i] = tf.squeeze(
extractFaceImage(tf4dFloat32Image, alignments[i], faceSize),
[0],
);
}
return tf.stack(faceImages) as tf.Tensor4D;
});
}
export function getAlignedFaceBox(alignment: FaceAlignment) {
return new Box({
x: alignment.center.x - alignment.size / 2,
y: alignment.center.y - alignment.size / 2,
width: alignment.size,
height: alignment.size,
}).round();
}
export function ibExtractFaceImage(
image: ImageBitmap,
alignment: FaceAlignment,
faceSize: number,
): ImageBitmap {
const box = getAlignedFaceBox(alignment);
const faceSizeDimensions: Dimensions = {
width: faceSize,
height: faceSize,
};
return cropWithRotation(
image,
box,
alignment.rotation,
faceSizeDimensions,
faceSizeDimensions,
);
}
// Used in MLDebugViewOnly
export function ibExtractFaceImageUsingTransform(
image: ImageBitmap,
alignment: FaceAlignment,
faceSize: number,
): ImageBitmap {
const scaledMatrix = new Matrix(alignment.affineMatrix)
.mul(faceSize)
.to2DArray();
// log.info("scaledMatrix: ", scaledMatrix);
return transform(image, scaledMatrix, faceSize, faceSize);
}
export function ibExtractFaceImages(
image: ImageBitmap,
alignments: Array<FaceAlignment>,
faceSize: number,
): Array<ImageBitmap> {
return alignments.map((alignment) =>
ibExtractFaceImage(image, alignment, faceSize),
);
}
const BLAZEFACE_LEFT_EYE_INDEX = 0;
const BLAZEFACE_RIGHT_EYE_INDEX = 1;
// const BLAZEFACE_NOSE_INDEX = 2;
const BLAZEFACE_MOUTH_INDEX = 3;
export function getRotatedFaceImage(
image: tf.Tensor3D | tf.Tensor4D,
faceDetection: FaceDetection,
padding: number = 1.5,
): tf.Tensor4D {
const paddedBox = enlargeBox(faceDetection.box, padding);
// log.info("paddedBox", paddedBox);
const landmarkPoints = faceDetection.landmarks;
return tf.tidy(() => {
const tf4dFloat32Image = toTensor4D(image, "float32");
let angle = 0;
const leftEye = landmarkPoints[BLAZEFACE_LEFT_EYE_INDEX];
const rightEye = landmarkPoints[BLAZEFACE_RIGHT_EYE_INDEX];
const foreheadCenter = getBoxCenterPt(leftEye, rightEye);
angle = computeRotation(
landmarkPoints[BLAZEFACE_MOUTH_INDEX],
foreheadCenter,
); // landmarkPoints[BLAZEFACE_NOSE_INDEX]
// angle = computeRotation(leftEye, rightEye);
// log.info('angle: ', angle);
const faceCenter = getBoxCenter(faceDetection.box);
// log.info('faceCenter: ', faceCenter);
const faceCenterNormalized: [number, number] = [
faceCenter.x / tf4dFloat32Image.shape[2],
faceCenter.y / tf4dFloat32Image.shape[1],
];
// log.info('faceCenterNormalized: ', faceCenterNormalized);
let rotatedImage = tf4dFloat32Image;
if (angle !== 0) {
rotatedImage = tf.image.rotateWithOffset(
tf4dFloat32Image,
angle,
0,
faceCenterNormalized,
);
}
const faceImageTensor = extractFaces(
rotatedImage,
[paddedBox],
paddedBox.width > 224 ? 448 : 224,
);
return faceImageTensor;
// return tf.gather(faceImageTensor, 0);
});
): Array<[number, number]> {
return landmarks.map((landmark) =>
landmark.map((p) => p / faceSize),
) as Array<[number, number]>;
}
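Note (illustrative sketch, not part of this change): dividing the canonical ArcFace landmarks by the 112 px template size yields coordinates in the 0..1 range used by the similarity transform, for example:

// Sketch only: normalizing the canonical 5-point ArcFace landmarks.
const normalized = ARC_FACE_5_LANDMARKS.map(([x, y]) => [x / 112, y / 112]);
// normalized[0] is roughly [0.342, 0.462]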

View file

@@ -1,23 +1,15 @@
import log from "@/next/log";
import { CacheStorageService } from "@ente/shared/storage/cacheStorage";
import { CACHES } from "@ente/shared/storage/cacheStorage/constants";
import { getBlobFromCache } from "@ente/shared/storage/cacheStorage/helpers";
import { compose, Matrix, scale, translate } from "transformation-matrix";
import { BlobOptions, Dimensions } from "types/image";
import { BlobOptions } from "types/image";
import {
AlignedFace,
FaceAlignment,
FaceCrop,
FaceCropConfig,
FaceDetection,
MlFileData,
StoredFaceCrop,
} from "types/machineLearning";
import { cropWithRotation, imageBitmapToBlob } from "utils/image";
import { enlargeBox } from ".";
import { Box } from "../../../thirdparty/face-api/classes";
import { getAlignedFaceBox } from "./faceAlign";
import { transformBox, transformPoints } from "./transform";
export function getFaceCrop(
imageBitmap: ImageBitmap,
@@ -38,7 +30,25 @@ export function getFaceCrop(
};
}
export async function storeFaceCropForBlob(
function getAlignedFaceBox(alignment: FaceAlignment) {
return new Box({
x: alignment.center.x - alignment.size / 2,
y: alignment.center.y - alignment.size / 2,
width: alignment.size,
height: alignment.size,
}).round();
}
export async function storeFaceCrop(
faceId: string,
faceCrop: FaceCrop,
blobOptions: BlobOptions,
): Promise<StoredFaceCrop> {
const faceCropBlob = await imageBitmapToBlob(faceCrop.image, blobOptions);
return storeFaceCropForBlob(faceId, faceCrop.imageBox, faceCropBlob);
}
async function storeFaceCropForBlob(
faceId: string,
imageBox: Box,
faceCropBlob: Blob,
@@ -52,166 +62,3 @@ export async function storeFaceCropForBlob(
imageBox: imageBox,
};
}
export async function storeFaceCrop(
faceId: string,
faceCrop: FaceCrop,
blobOptions: BlobOptions,
): Promise<StoredFaceCrop> {
const faceCropBlob = await imageBitmapToBlob(faceCrop.image, blobOptions);
return storeFaceCropForBlob(faceId, faceCrop.imageBox, faceCropBlob);
}
export async function getFaceCropBlobFromStorage(
storedFaceCrop: StoredFaceCrop,
): Promise<Blob> {
return getBlobFromCache(CACHES.FACE_CROPS, storedFaceCrop.imageUrl);
}
export async function getFaceCropFromStorage(
storedFaceCrop: StoredFaceCrop,
): Promise<FaceCrop> {
const faceCropBlob = await getFaceCropBlobFromStorage(storedFaceCrop);
const faceCropImage = await createImageBitmap(faceCropBlob);
return {
image: faceCropImage,
imageBox: storedFaceCrop.imageBox,
};
}
export async function removeOldFaceCrops(
oldMLFileData: MlFileData,
newMLFileData: MlFileData,
) {
const newFaceCropUrls =
newMLFileData?.faces
?.map((f) => f.crop?.imageUrl)
?.filter((fc) => fc !== null && fc !== undefined) || [];
const oldFaceCropUrls =
oldMLFileData?.faces
?.map((f) => f.crop?.imageUrl)
?.filter((fc) => fc !== null && fc !== undefined) || [];
const unusedFaceCropUrls = oldFaceCropUrls.filter(
(oldUrl) => !newFaceCropUrls.includes(oldUrl),
);
if (!unusedFaceCropUrls || unusedFaceCropUrls.length < 1) {
return;
}
return removeFaceCropUrls(unusedFaceCropUrls);
}
export async function removeFaceCropUrls(faceCropUrls: Array<string>) {
log.info("Removing face crop urls: ", JSON.stringify(faceCropUrls));
const faceCropCache = await CacheStorageService.open(CACHES.FACE_CROPS);
const urlRemovalPromises = faceCropUrls?.map((url) =>
faceCropCache.delete(url),
);
return urlRemovalPromises && Promise.all(urlRemovalPromises);
}
export function extractFaceImageFromCrop(
faceCrop: FaceCrop,
box: Box,
rotation: number,
faceSize: number,
): ImageBitmap {
const faceCropImage = faceCrop?.image;
let imageBox = faceCrop?.imageBox;
if (!faceCropImage || !imageBox) {
throw Error("Face crop not present");
}
// TODO: Have better serialization to avoid manually creating a new object when calling class methods
imageBox = new Box(imageBox);
const scale = faceCropImage.width / imageBox.width;
const transformedBox = box
.shift(-imageBox.x, -imageBox.y)
.rescale(scale)
.round();
// log.info({ box, imageBox, faceCropImage, scale, scaledBox, scaledImageBox, shiftedBox });
const faceSizeDimentions: Dimensions = {
width: faceSize,
height: faceSize,
};
const faceImage = cropWithRotation(
faceCropImage,
transformedBox,
rotation,
faceSizeDimentions,
faceSizeDimentions,
);
return faceImage;
}
export async function ibExtractFaceImageFromCrop(
faceCrop: FaceCrop,
alignment: FaceAlignment,
faceSize: number,
): Promise<ImageBitmap> {
const box = getAlignedFaceBox(alignment);
return extractFaceImageFromCrop(
faceCrop,
box,
alignment.rotation,
faceSize,
);
}
export async function ibExtractFaceImagesFromCrops(
faces: Array<AlignedFace>,
faceSize: number,
): Promise<Array<ImageBitmap>> {
const faceImagePromises = faces.map(async (alignedFace) => {
const faceCrop = await getFaceCropFromStorage(alignedFace.crop);
return ibExtractFaceImageFromCrop(
faceCrop,
alignedFace.alignment,
faceSize,
);
});
return Promise.all(faceImagePromises);
}
export function transformFace(faceDetection: FaceDetection, transform: Matrix) {
return {
...faceDetection,
box: transformBox(faceDetection.box, transform),
landmarks: transformPoints(faceDetection.landmarks, transform),
};
}
export function transformToFaceCropDims(
faceCrop: FaceCrop,
faceDetection: FaceDetection,
) {
const imageBox = new Box(faceCrop.imageBox);
const transform = compose(
scale(faceCrop.image.width / imageBox.width),
translate(-imageBox.x, -imageBox.y),
);
return transformFace(faceDetection, transform);
}
export function transformToImageDims(
faceCrop: FaceCrop,
faceDetection: FaceDetection,
) {
const imageBox = new Box(faceCrop.imageBox);
const transform = compose(
translate(imageBox.x, imageBox.y),
scale(imageBox.width / faceCrop.image.width),
);
return transformFace(faceDetection, transform);
}

View file

@ -1,108 +0,0 @@
import { euclidean } from "hdbscan";
import { FaceDetection } from "types/machineLearning";
import { getNearestPointIndex, newBox } from ".";
import { Box, Point } from "../../../thirdparty/face-api/classes";
import {
computeTransformToBox,
transformBox,
transformPoints,
} from "./transform";
export function transformPaddedToImage(
detection: FaceDetection,
faceImage: ImageBitmap,
imageBox: Box,
paddedBox: Box,
) {
const inBox = newBox(0, 0, faceImage.width, faceImage.height);
imageBox.x = paddedBox.x;
imageBox.y = paddedBox.y;
const transform = computeTransformToBox(inBox, imageBox);
detection.box = transformBox(detection.box, transform);
detection.landmarks = transformPoints(detection.landmarks, transform);
}
export function getDetectionCenter(detection: FaceDetection) {
const center = new Point(0, 0);
// TODO: using the first 4 landmarks is applicable to blazeface only;
// this needs to consider the eye, nose and mouth landmarks to compute the center
detection.landmarks?.slice(0, 4).forEach((p) => {
center.x += p.x;
center.y += p.y;
});
return center.div({ x: 4, y: 4 });
}
/**
* Finds the nearest face detection from a list of detections to a specified detection.
*
* This function calculates the center of each detection and then finds the detection whose center is nearest to the center of the specified detection.
* If a maximum distance is specified, only detections within that distance are considered.
*
* @param toDetection - The face detection to find the nearest detection to.
* @param fromDetections - An array of face detections to search in.
* @param maxDistance - The maximum distance between the centers of the two detections for a detection to be considered. If not specified, all detections are considered.
*
* @returns The nearest face detection from the list, or `undefined` if no detection is within the maximum distance.
*/
export function getNearestDetection(
toDetection: FaceDetection,
fromDetections: Array<FaceDetection>,
maxDistance?: number,
) {
const toCenter = getDetectionCenter(toDetection);
const centers = fromDetections.map((d) => getDetectionCenter(d));
const nearestIndex = getNearestPointIndex(toCenter, centers, maxDistance);
return nearestIndex >= 0 && fromDetections[nearestIndex];
}
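A minimal usage sketch of the helper above, with a hypothetical caller that matches a previously tracked face against the current frame's detections; note that it yields `false`, not `undefined`, when nothing lies within `maxDistance`:

// Hypothetical caller; FaceDetection is the same type imported above.
function matchTrackedFace(
    tracked: FaceDetection,
    current: Array<FaceDetection>,
): FaceDetection | undefined {
    // Only consider candidates whose centers lie within 50 px.
    const nearest = getNearestDetection(tracked, current, 50);
    return nearest || undefined;
}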
/**
* Removes duplicate face detections from an array of detections.
*
* This function sorts the detections by their probability in descending order, then iterates over them.
* For each detection, it calculates the Euclidean distance to all other detections.
* If the distance is less than or equal to the specified threshold (`withinDistance`), the other detection is considered a duplicate and is removed.
*
* @param detections - An array of face detections to remove duplicates from.
* @param withinDistance - The maximum Euclidean distance between two detections for them to be considered duplicates.
*
* @returns An array of face detections with duplicates removed.
*/
export function removeDuplicateDetections(
detections: Array<FaceDetection>,
withinDistance: number,
) {
// console.time('removeDuplicates');
detections.sort((a, b) => b.probability - a.probability);
const isSelected = new Map<number, boolean>();
for (let i = 0; i < detections.length; i++) {
if (isSelected.get(i) === false) {
continue;
}
isSelected.set(i, true);
for (let j = i + 1; j < detections.length; j++) {
if (isSelected.get(j) === false) {
continue;
}
const centeri = getDetectionCenter(detections[i]);
const centerj = getDetectionCenter(detections[j]);
const dist = euclidean(
[centeri.x, centeri.y],
[centerj.x, centerj.y],
);
if (dist <= withinDistance) {
isSelected.set(j, false);
}
}
}
const uniques: Array<FaceDetection> = [];
for (let i = 0; i < detections.length; i++) {
isSelected.get(i) && uniques.push(detections[i]);
}
// console.timeEnd('removeDuplicates');
return uniques;
}
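A small worked sketch of the behaviour described above, using hypothetical detections built from four identical landmarks so that getDetectionCenter returns exactly the given point (the cast only keeps the sketch short; Point is the same class imported above):

// Hypothetical detections: box is omitted since it is not used by the
// duplicate check, which only looks at landmark centers and probability.
const fakeDetection = (x: number, y: number, probability: number) =>
    ({
        landmarks: [0, 1, 2, 3].map(() => new Point(x, y)),
        probability,
    }) as unknown as FaceDetection;

const d1 = fakeDetection(100, 100, 0.9);
const d2 = fakeDetection(104, 103, 0.7); // ~5 px from d1
const d3 = fakeDetection(400, 250, 0.8);

// With withinDistance = 10, d2 is dropped as a duplicate of the
// higher-probability d1; the result is [d1, d3].
const unique = removeDuplicateDetections([d1, d2, d3], 10);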

View file

@ -1,11 +1,7 @@
import log from "@/next/log";
import { CACHES } from "@ente/shared/storage/cacheStorage/constants";
import { cached } from "@ente/shared/storage/cacheStorage/helpers";
import * as tf from "@tensorflow/tfjs-core";
import { NormalizedFace } from "blazeface-back";
import { FILE_TYPE } from "constants/file";
import { BLAZEFACE_FACE_SIZE } from "constants/mlConfig";
import { euclidean } from "hdbscan";
import PQueue from "p-queue";
import DownloadManager from "services/download";
import { getLocalFiles } from "services/fileService";
@ -13,152 +9,22 @@ import { decodeLivePhoto } from "services/livePhotoService";
import { EnteFile } from "types/file";
import { Dimensions } from "types/image";
import {
AlignedFace,
DetectedFace,
DetectedObject,
Face,
FaceAlignment,
FaceImageBlob,
MlFileData,
Person,
RealWorldObject,
Versioned,
} from "types/machineLearning";
import { getRenderableImage } from "utils/file";
import { clamp, imageBitmapToBlob, warpAffineFloat32List } from "utils/image";
import { clamp, warpAffineFloat32List } from "utils/image";
import mlIDbStorage from "utils/storage/mlIDbStorage";
import { Box, Point } from "../../../thirdparty/face-api/classes";
import { ibExtractFaceImage, ibExtractFaceImages } from "./faceAlign";
import { getFaceCropBlobFromStorage } from "./faceCrop";
export function f32Average(descriptors: Float32Array[]) {
if (descriptors.length < 1) {
throw Error("f32Average: input size 0");
}
if (descriptors.length === 1) {
return descriptors[0];
}
const f32Size = descriptors[0].length;
const avg = new Float32Array(f32Size);
for (let index = 0; index < f32Size; index++) {
avg[index] = descriptors[0][index];
for (let desc = 1; desc < descriptors.length; desc++) {
avg[index] = avg[index] + descriptors[desc][index];
}
avg[index] = avg[index] / descriptors.length;
}
return avg;
}
export function isTensor(tensor: any, dim: number) {
return tensor instanceof tf.Tensor && tensor.shape.length === dim;
}
export function isTensor1D(tensor: any): tensor is tf.Tensor1D {
return isTensor(tensor, 1);
}
export function isTensor2D(tensor: any): tensor is tf.Tensor2D {
return isTensor(tensor, 2);
}
export function isTensor3D(tensor: any): tensor is tf.Tensor3D {
return isTensor(tensor, 3);
}
export function isTensor4D(tensor: any): tensor is tf.Tensor4D {
return isTensor(tensor, 4);
}
export function toTensor4D(
image: tf.Tensor3D | tf.Tensor4D,
dtype?: tf.DataType,
) {
return tf.tidy(() => {
let reshapedImage: tf.Tensor4D;
if (isTensor3D(image)) {
reshapedImage = tf.expandDims(image, 0);
} else if (isTensor4D(image)) {
reshapedImage = image;
} else {
throw Error("toTensor4D only supports Tensor3D and Tensor4D input");
}
if (dtype) {
reshapedImage = tf.cast(reshapedImage, dtype);
}
return reshapedImage;
});
}
export function imageBitmapsToTensor4D(imageBitmaps: Array<ImageBitmap>) {
return tf.tidy(() => {
const tfImages = imageBitmaps.map((ib) => tf.browser.fromPixels(ib));
return tf.stack(tfImages) as tf.Tensor4D;
});
}
export function extractFaces(
image: tf.Tensor3D | tf.Tensor4D,
facebBoxes: Array<Box>,
faceSize: number,
) {
return tf.tidy(() => {
const reshapedImage = toTensor4D(image, "float32");
const boxes = facebBoxes.map((box) => {
const normalized = box.rescale({
width: 1 / reshapedImage.shape[2],
height: 1 / reshapedImage.shape[1],
});
return [
normalized.top,
normalized.left,
normalized.bottom,
normalized.right,
];
});
// log.info('boxes: ', boxes[0]);
const faceImagesTensor = tf.image.cropAndResize(
reshapedImage,
boxes,
tf.fill([boxes.length], 0, "int32"),
[faceSize, faceSize],
);
return faceImagesTensor;
});
}
export function newBox(x: number, y: number, width: number, height: number) {
return new Box({ x, y, width, height });
}
export function newBoxFromPoints(
left: number,
top: number,
right: number,
bottom: number,
) {
return new Box({ left, top, right, bottom });
}
export function normFaceBox(face: NormalizedFace) {
return newBoxFromPoints(
face.topLeft[0],
face.topLeft[1],
face.bottomRight[0],
face.bottomRight[1],
);
}
export function getBoxCenterPt(topLeft: Point, bottomRight: Point): Point {
return topLeft.add(bottomRight.sub(topLeft).div(new Point(2, 2)));
}
@ -180,74 +46,17 @@ export function enlargeBox(box: Box, factor: number = 1.5) {
});
}
export function normalizeRadians(angle: number) {
return angle - 2 * Math.PI * Math.floor((angle + Math.PI) / (2 * Math.PI));
}
export function computeRotation(point1: Point, point2: Point) {
const radians =
Math.PI / 2 - Math.atan2(-(point2.y - point1.y), point2.x - point1.x);
return normalizeRadians(radians);
}
export function getAllFacesFromMap(allFacesMap: Map<number, Array<Face>>) {
const allFaces = [...allFacesMap.values()].flat();
return allFaces;
}
export function getAllObjectsFromMap(
allObjectsMap: Map<number, Array<RealWorldObject>>,
) {
return [...allObjectsMap.values()].flat();
}
export async function getLocalFile(fileId: number) {
const localFiles = await getLocalFiles();
return localFiles.find((f) => f.id === fileId);
}
export async function getFaceImage(
face: AlignedFace,
token: string,
faceSize: number = BLAZEFACE_FACE_SIZE,
file?: EnteFile,
): Promise<FaceImageBlob> {
if (!file) {
file = await getLocalFile(face.fileId);
}
const imageBitmap = await getOriginalImageBitmap(file);
const faceImageBitmap = ibExtractFaceImage(
imageBitmap,
face.alignment,
faceSize,
);
const faceImage = imageBitmapToBlob(faceImageBitmap);
faceImageBitmap.close();
imageBitmap.close();
return faceImage;
}
export async function extractFaceImages(
faces: Array<AlignedFace>,
faceSize: number,
image?: ImageBitmap,
) {
// if (faces.length === faces.filter((f) => f.crop).length) {
// return ibExtractFaceImagesFromCrops(faces, faceSize);
// } else
if (image) {
const faceAlignments = faces.map((f) => f.alignment);
return ibExtractFaceImages(image, faceAlignments, faceSize);
} else {
throw Error(
"Either face crops or image is required to extract face images",
);
}
}
export async function extractFaceImagesToFloat32(
faceAlignments: Array<FaceAlignment>,
faceSize: number,
@ -270,10 +79,6 @@ export async function extractFaceImagesToFloat32(
return faceData;
}
export function leftFillNum(num: number, length: number, padding: number) {
return num.toString().padStart(length, padding.toString());
}
export function getFaceId(detectedFace: DetectedFace, imageDims: Dimensions) {
const xMin = clamp(
detectedFace.detection.box.x / imageDims.width,
@ -312,45 +117,10 @@ export function getFaceId(detectedFace: DetectedFace, imageDims: Dimensions) {
return faceID;
}
export function getObjectId(
detectedObject: DetectedObject,
imageDims: Dimensions,
) {
const imgDimPoint = new Point(imageDims.width, imageDims.height);
const objectCenterPoint = new Point(
detectedObject.detection.bbox[2] / 2,
detectedObject.detection.bbox[3] / 2,
);
const gridPt = objectCenterPoint
.mul(new Point(100, 100))
.div(imgDimPoint)
.floor()
.bound(0, 99);
const gridPaddedX = leftFillNum(gridPt.x, 2, 0);
const gridPaddedY = leftFillNum(gridPt.y, 2, 0);
return `${detectedObject.fileID}-${gridPaddedX}-${gridPaddedY}`;
}
export async function getTFImage(blob): Promise<tf.Tensor3D> {
const imageBitmap = await createImageBitmap(blob);
const tfImage = tf.browser.fromPixels(imageBitmap);
imageBitmap.close();
return tfImage;
}
export async function getImageBlobBitmap(blob: Blob): Promise<ImageBitmap> {
return await createImageBitmap(blob);
}
// export async function getTFImageUsingJpegJS(blob: Blob): Promise<TFImageBitmap> {
// const imageData = jpegjs.decode(await blob.arrayBuffer());
// const tfImage = tf.browser.fromPixels(imageData);
// return new TFImageBitmap(undefined, tfImage);
// }
async function getOriginalFile(file: EnteFile, queue?: PQueue) {
let fileStream;
if (queue) {
@ -453,21 +223,6 @@ export async function getUnidentifiedFaces(
);
}
export async function getFaceCropBlobs(
faces: Array<Face>,
): Promise<Array<FaceImageBlob>> {
const faceCrops = faces
.map((f) => f.crop)
.filter((faceCrop) => faceCrop !== null && faceCrop !== undefined);
return (
faceCrops &&
Promise.all(
faceCrops.map((faceCrop) => getFaceCropBlobFromStorage(faceCrop)),
)
);
}
export async function getAllPeople(limit: number = undefined) {
let people: Array<Person> = await mlIDbStorage.getAllPeople();
// await mlPeopleStore.iterate<Person, void>((person) => {
@ -531,27 +286,6 @@ export function areFaceIdsSame(ofFaces: Array<Face>, toFaces: Array<Face>) {
);
}
export function getNearestPointIndex(
toPoint: Point,
fromPoints: Array<Point>,
maxDistance?: number,
) {
const dists = fromPoints.map((point, i) => ({
index: i,
point: point,
distance: euclidean([point.x, point.y], [toPoint.x, toPoint.y]),
}));
const nearest = findFirstIfSorted(
dists,
(a, b) => Math.abs(a.distance) - Math.abs(b.distance),
);
// log.info('Nearest dist: ', nearest.distance, maxDistance);
if (!maxDistance || nearest.distance <= maxDistance) {
return nearest.index;
}
}
export function logQueueStats(queue: PQueue, name: string) {
queue.on("active", () =>
log.info(

View file

@ -1,33 +0,0 @@
import { newBoxFromPoints } from ".";
import { Box, Point } from "../../../thirdparty/face-api/classes";
import {
Matrix,
applyToPoint,
compose,
scale,
translate,
} from "transformation-matrix";
export function computeTransformToBox(inBox: Box, toBox: Box): Matrix {
return compose(
translate(toBox.x, toBox.y),
scale(toBox.width / inBox.width, toBox.height / inBox.height),
);
}
export function transformPoint(point: Point, transform: Matrix) {
const txdPoint = applyToPoint(transform, point);
return new Point(txdPoint.x, txdPoint.y);
}
export function transformPoints(points: Point[], transform: Matrix) {
return points?.map((p) => transformPoint(p, transform));
}
export function transformBox(box: Box, transform: Matrix) {
const topLeft = transformPoint(box.topLeft, transform);
const bottomRight = transformPoint(box.bottomRight, transform);
return newBoxFromPoints(topLeft.x, topLeft.y, bottomRight.x, bottomRight.y);
}

View file

@ -15,14 +15,7 @@ import {
openDB,
} from "idb";
import isElectron from "is-electron";
import {
Face,
MLLibraryData,
MlFileData,
Person,
RealWorldObject,
Thing,
} from "types/machineLearning";
import { Face, MLLibraryData, MlFileData, Person } from "types/machineLearning";
import { IndexStatus } from "types/machineLearning/ui";
interface Config {}
@ -42,9 +35,11 @@ interface MLDb extends DBSchema {
key: number;
value: Person;
};
// Unused, we only retain this in the schema so that we can delete it during
// migration.
things: {
key: number;
value: Thing;
value: unknown;
};
versions: {
key: string;
@ -72,7 +67,7 @@ class MLIDbStorage {
}
private openDB(): Promise<IDBPDatabase<MLDb>> {
return openDB<MLDb>(MLDATA_DB_NAME, 3, {
return openDB<MLDb>(MLDATA_DB_NAME, 4, {
terminated: async () => {
log.error("ML Indexed DB terminated");
this._db = undefined;
@ -128,6 +123,14 @@ class MLIDbStorage {
.objectStore("configs")
.add(DEFAULT_ML_SEARCH_CONFIG, ML_SEARCH_CONFIG_NAME);
}
if (oldVersion < 4) {
// TODO(MR): This loses the user's settings.
db.deleteObjectStore("configs");
db.createObjectStore("configs");
db.deleteObjectStore("things");
}
log.info(
`Ml DB upgraded to version: ${newVersion} from version: ${oldVersion}`,
);
@ -299,21 +302,6 @@ class MLIDbStorage {
log.info("updateFaces", Date.now() - startTime, "ms");
}
public async getAllObjectsMap() {
const startTime = Date.now();
const db = await this.db;
const allFiles = await db.getAll("files");
const allObjectsMap = new Map<number, Array<RealWorldObject>>();
allFiles.forEach(
(mlFileData) =>
mlFileData.objects &&
allObjectsMap.set(mlFileData.fileId, mlFileData.objects),
);
log.info("allObjectsMap", Date.now() - startTime, "ms");
return allObjectsMap;
}
public async getPerson(id: number) {
const db = await this.db;
return db.get("people", id);
@ -334,20 +322,6 @@ class MLIDbStorage {
return db.clear("people");
}
public async getAllThings() {
const db = await this.db;
return db.getAll("things");
}
public async putThing(thing: Thing) {
const db = await this.db;
return db.put("things", thing);
}
public async clearAllThings() {
const db = await this.db;
return db.clear("things");
}
public async getIndexVersion(index: string) {
const db = await this.db;
return db.get("versions", index);

View file

@ -57,14 +57,6 @@ function isSearchedFile(file: EnteFile, search: Search) {
if (search?.person) {
return search.person.files.indexOf(file.id) !== -1;
}
if (search?.thing) {
return search.thing.files.indexOf(file.id) !== -1;
}
if (search?.text) {
return search.text.files.indexOf(file.id) !== -1;
}
if (typeof search?.fileType !== "undefined") {
return search.fileType === file.metadata.fileType;
}

View file

@ -1,4 +1,3 @@
import { isDimensions, isValidNumber } from '../utils';
import { IBoundingBox } from './BoundingBox';
import { IDimensions } from './Dimensions';
import { Point } from './Point';
@ -172,4 +171,12 @@ export class Box<BoxType = any> implements IBoundingBox, IRect {
bottom: this.bottom + (region.bottom * this.height)
}).toSquare().round()
}
}
}
export function isValidNumber(num: any) {
return (!!num && num !== Infinity && num !== -Infinity && !isNaN(num)) || num === 0
}
export function isDimensions(obj: any): boolean {
return obj && obj.width && obj.height
}

View file

@ -1,4 +1,4 @@
import { isValidNumber } from '../utils';
import { isValidNumber } from './Box';
export interface IDimensions {
width: number
@ -25,4 +25,4 @@ export class Dimensions implements IDimensions {
public reverse(): Dimensions {
return new Dimensions(1 / this.width, 1 / this.height)
}
}
}

View file

@ -1,63 +0,0 @@
import * as tf from '@tensorflow/tfjs-core';
import { Point } from '../classes';
import { Dimensions, IDimensions } from '../classes/Dimensions';
export function isTensor(tensor: any, dim: number) {
return tensor instanceof tf.Tensor && tensor.shape.length === dim
}
export function isTensor1D(tensor: any): tensor is tf.Tensor1D {
return isTensor(tensor, 1)
}
export function isTensor2D(tensor: any): tensor is tf.Tensor2D {
return isTensor(tensor, 2)
}
export function isTensor3D(tensor: any): tensor is tf.Tensor3D {
return isTensor(tensor, 3)
}
export function isTensor4D(tensor: any): tensor is tf.Tensor4D {
return isTensor(tensor, 4)
}
export function isFloat(num: number) {
return num % 1 !== 0
}
export function isEven(num: number) {
return num % 2 === 0
}
export function round(num: number, prec: number = 2) {
const f = Math.pow(10, prec)
return Math.floor(num * f) / f
}
export function isDimensions(obj: any): boolean {
return obj && obj.width && obj.height
}
export function computeReshapedDimensions({ width, height }: IDimensions, inputSize: number) {
const scale = inputSize / Math.max(height, width)
return new Dimensions(Math.round(width * scale), Math.round(height * scale))
}
export function getCenterPoint(pts: Point[]): Point {
return pts.reduce((sum, pt) => sum.add(pt), new Point(0, 0))
.div(new Point(pts.length, pts.length))
}
export function range(num: number, start: number, step: number): number[] {
return Array(num).fill(0).map((_, i) => start + (i * step))
}
export function isValidNumber(num: any) {
return !!num && num !== Infinity && num !== -Infinity && !isNaN(num) || num === 0
}
export function isValidProbablitiy(num: any) {
return isValidNumber(num) && 0 <= num && num <= 1.0
}

View file

@ -196,7 +196,7 @@ export interface Electron {
// - ML
/**
* Compute and return a CLIP embedding of the given image.
* Return a CLIP embedding of the given image.
*
* See: [Note: CLIP based magic search]
*
@ -207,7 +207,7 @@ export interface Electron {
clipImageEmbedding: (jpegImageData: Uint8Array) => Promise<Float32Array>;
/**
* Compute and return a CLIP embedding of the given image.
* Return a CLIP embedding of the given text.
*
* See: [Note: CLIP based magic search]
*
@ -217,6 +217,22 @@ export interface Electron {
*/
clipTextEmbedding: (text: string) => Promise<Float32Array>;
/**
* Detect faces in the given image using YOLO.
*
* Both the input and output are opaque binary data whose internal structure
* is specific to our implementation and the model (YOLO) we use.
*/
detectFaces: (input: Float32Array) => Promise<Float32Array>;
/**
* Return a MobileFaceNet embedding for the given face data.
*
* Both the input and output are opaque binary data whose internal structure
* is specific to our implementation and the model (MobileFaceNet) we use.
*/
faceEmbedding: (input: Float32Array) => Promise<Float32Array>;
// - File selection
// TODO: Deprecated - use dialogs on the renderer process itself
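As with the CLIP methods above, the face methods are thin, typed pass-throughs over IPC. A minimal sketch of how a renderer-side caller might use them, assuming it already holds preprocessed input in the layouts the models expect (the surrounding pipeline that produces those Float32Arrays is not shown):

// Hedged sketch: both inputs are opaque Float32Arrays prepared elsewhere by
// our ML pipeline; this only illustrates the IPC surface, not the pipeline.
const runFaceIndexingStep = async (
    electron: Electron,
    yoloInput: Float32Array,
    mobileFaceNetInput: Float32Array,
) => {
    const detections = await electron.detectFaces(yoloInput);
    const embeddings = await electron.faceEmbedding(mobileFaceNetInput);
    return { detections, embeddings };
};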

View file

@ -35,6 +35,19 @@ export class ComlinkWorker<T extends new () => InstanceType<T>> {
}
}
// TODO(MR): Temporary method to forward auth tokens to workers
const getAuthToken = () => {
// LS_KEYS.USER
const userJSONString = localStorage.getItem("user");
if (!userJSONString) return undefined;
const json: unknown = JSON.parse(userJSONString);
if (!json || typeof json != "object" || !("token" in json))
return undefined;
const token = json.token;
if (typeof token != "string") return undefined;
return token;
};
/**
* A minimal set of utility functions that we expose to all workers that we
* create.
@ -44,8 +57,12 @@ export class ComlinkWorker<T extends new () => InstanceType<T>> {
*/
const workerBridge = {
logToDisk,
getAuthToken,
convertToJPEG: (inputFileData: Uint8Array, filename: string) =>
ensureElectron().convertToJPEG(inputFileData, filename),
detectFaces: (input: Float32Array) => ensureElectron().detectFaces(input),
faceEmbedding: (input: Float32Array) =>
ensureElectron().faceEmbedding(input),
};
export type WorkerBridge = typeof workerBridge;
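A hedged sketch of the worker side, assuming the bridge reaches the worker as a Comlink Remote (the import path and the wiring of the endpoint are illustrative, not the actual setup):

// Hypothetical worker-side usage; how the bridge endpoint is exposed to the
// worker (e.g. over a Comlink-wrapped MessagePort) is outside this snippet.
import type { Remote } from "comlink";
import type { WorkerBridge } from "./comlinkWorker"; // illustrative path

const detectWithBridge = async (
    bridge: Remote<WorkerBridge>,
    yoloInput: Float32Array,
) => {
    const token = await bridge.getAuthToken(); // may be undefined pre-login
    const faces = await bridge.detectFaces(yoloInput);
    return { token, faces };
};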

View file

@ -954,52 +954,6 @@
dependencies:
tslib "^2.4.0"
"@tensorflow-models/coco-ssd@^2.2.2":
version "2.2.3"
resolved "https://registry.yarnpkg.com/@tensorflow-models/coco-ssd/-/coco-ssd-2.2.3.tgz#3825286569076d6788199c9cb89fb2fa31f7d2f2"
integrity sha512-iCLGktG/XhHbP6h2FWxqCKMp/Px0lCp6MZU1fjNhjDHeaWEC9G7S7cZrnPXsfH+NewCM53YShlrHnknxU3SQig==
"@tensorflow/tfjs-backend-cpu@4.17.0", "@tensorflow/tfjs-backend-cpu@^4.10.0":
version "4.17.0"
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-backend-cpu/-/tfjs-backend-cpu-4.17.0.tgz#b0c495de686cf700f2ae1f6d8bc2eb6f1964d250"
integrity sha512-2VSCHnX9qhYTjw9HiVwTBSnRVlntKXeBlK7aSVsmZfHGwWE2faErTtO7bWmqNqw0U7gyznJbVAjlow/p+0RNGw==
dependencies:
"@types/seedrandom" "^2.4.28"
seedrandom "^3.0.5"
"@tensorflow/tfjs-backend-webgl@^4.9.0":
version "4.17.0"
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-backend-webgl/-/tfjs-backend-webgl-4.17.0.tgz#7d540a92343582d37d2cdf9509060598a19cd17a"
integrity sha512-CC5GsGECCd7eYAUaKq0XJ48FjEZdgXZWPxgUYx4djvfUx5fQPp35hCSP9w/k463jllBMbjl2tKRg8u7Ia/LYzg==
dependencies:
"@tensorflow/tfjs-backend-cpu" "4.17.0"
"@types/offscreencanvas" "~2019.3.0"
"@types/seedrandom" "^2.4.28"
seedrandom "^3.0.5"
"@tensorflow/tfjs-converter@^4.10.0":
version "4.17.0"
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-converter/-/tfjs-converter-4.17.0.tgz#f4407bd53d5e300b05ed0b0f068506bc50c956b0"
integrity sha512-qFxIjPfomCuTrYxsFjtKbi3QfdmTTCWo+RvqD64oCMS0sjp7sUDNhJyKDoLx6LZhXlwXpHIVDJctLMRMwet0Zw==
"@tensorflow/tfjs-core@^4.10.0":
version "4.17.0"
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-core/-/tfjs-core-4.17.0.tgz#1ea128555a4d197aed417d70461fcbc7eaec635f"
integrity sha512-v9Q5430EnRpyhWNd9LVgXadciKvxLiq+sTrLKRowh26BHyAsams4tZIgX3lFKjB7b90p+FYifVMcqLTTHgjGpQ==
dependencies:
"@types/long" "^4.0.1"
"@types/offscreencanvas" "~2019.7.0"
"@types/seedrandom" "^2.4.28"
"@webgpu/types" "0.1.38"
long "4.0.0"
node-fetch "~2.6.1"
seedrandom "^3.0.5"
"@tensorflow/tfjs-tflite@0.0.1-alpha.7":
version "0.0.1-alpha.7"
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-tflite/-/tfjs-tflite-0.0.1-alpha.7.tgz#647c088689131fee424b7ae0bb9b7fdc74a61475"
integrity sha512-aOmmEC/AHzfc/u1Q6ccY6Kr7CfNwjonqyTGVU1OqlQGDrH2IopcCjNSZdatJIB6J2RxlBs979JilCOUpK1LXng==
"@tokenizer/token@^0.3.0":
version "0.3.0"
resolved "https://registry.yarnpkg.com/@tokenizer/token/-/token-0.3.0.tgz#fe98a93fe789247e998c75e74e9c7c63217aa276"
@ -1098,11 +1052,6 @@
resolved "https://registry.yarnpkg.com/@types/lodash/-/lodash-4.14.202.tgz#f09dbd2fb082d507178b2f2a5c7e74bd72ff98f8"
integrity sha512-OvlIYQK9tNneDlS0VN54LLd5uiPCBOp7gS5Z0f1mjoJYBrtStzgmJBxONW3U6OZqdtNzZPmn9BS/7WI7BFFcFQ==
"@types/long@^4.0.1":
version "4.0.2"
resolved "https://registry.yarnpkg.com/@types/long/-/long-4.0.2.tgz#b74129719fc8d11c01868010082d483b7545591a"
integrity sha512-MqTGEo5bj5t157U6fA/BiDynNkn0YknVdh48CMPkTSpFTVmvao5UQmm7uEF6xBEo7qIMAlY/JSleYaE6VOdpaA==
"@types/node@*":
version "20.11.20"
resolved "https://registry.yarnpkg.com/@types/node/-/node-20.11.20.tgz#f0a2aee575215149a62784210ad88b3a34843659"
@ -1110,16 +1059,6 @@
dependencies:
undici-types "~5.26.4"
"@types/offscreencanvas@~2019.3.0":
version "2019.3.0"
resolved "https://registry.yarnpkg.com/@types/offscreencanvas/-/offscreencanvas-2019.3.0.tgz#3336428ec7e9180cf4566dfea5da04eb586a6553"
integrity sha512-esIJx9bQg+QYF0ra8GnvfianIY8qWB0GBx54PK5Eps6m+xTj86KLavHv6qDhzKcu5UUOgNfJ2pWaIIV7TRUd9Q==
"@types/offscreencanvas@~2019.7.0":
version "2019.7.3"
resolved "https://registry.yarnpkg.com/@types/offscreencanvas/-/offscreencanvas-2019.7.3.tgz#90267db13f64d6e9ccb5ae3eac92786a7c77a516"
integrity sha512-ieXiYmgSRXUDeOntE1InxjWyvEelZGP63M+cGuquuRLuIKKT1osnkXjxev9B7d1nXSug5vpunx+gNlbVxMlC9A==
"@types/parse-json@^4.0.0":
version "4.0.2"
resolved "https://registry.yarnpkg.com/@types/parse-json/-/parse-json-4.0.2.tgz#5950e50960793055845e956c427fc2b0d70c5239"
@ -1195,11 +1134,6 @@
resolved "https://registry.yarnpkg.com/@types/scheduler/-/scheduler-0.16.8.tgz#ce5ace04cfeabe7ef87c0091e50752e36707deff"
integrity sha512-WZLiwShhwLRmeV6zH+GkbOFT6Z6VklCItrDioxUnv+u4Ll+8vKeFySoFyK/0ctcRpOmwAicELfmys1sDc/Rw+A==
"@types/seedrandom@^2.4.28":
version "2.4.34"
resolved "https://registry.yarnpkg.com/@types/seedrandom/-/seedrandom-2.4.34.tgz#c725cd0fc0442e2d3d0e5913af005686ffb7eb99"
integrity sha512-ytDiArvrn/3Xk6/vtylys5tlY6eo7Ane0hvcx++TKo6RxQXuVfW0AF/oeWqAj9dN29SyhtawuXstgmPlwNcv/A==
"@types/semver@^7.5.0":
version "7.5.7"
resolved "https://registry.yarnpkg.com/@types/semver/-/semver-7.5.7.tgz#326f5fdda70d13580777bcaa1bc6fa772a5aef0e"
@ -1363,11 +1297,6 @@
"@types/babel__core" "^7.20.5"
react-refresh "^0.14.0"
"@webgpu/types@0.1.38":
version "0.1.38"
resolved "https://registry.yarnpkg.com/@webgpu/types/-/types-0.1.38.tgz#6fda4b410edc753d3213c648320ebcf319669020"
integrity sha512-7LrhVKz2PRh+DD7+S+PVaFd5HxaWQvoMqBbsV9fNJO1pjUs1P8bM2vQVNfk+3URTqbuTI7gkXi0rfsN0IadoBA==
acorn-jsx@^5.3.2:
version "5.3.2"
resolved "https://registry.yarnpkg.com/acorn-jsx/-/acorn-jsx-5.3.2.tgz#7ed5bb55908b3b2f1bc55c6af1653bada7f07937"
@ -1630,11 +1559,6 @@ bip39@^3.0.4:
dependencies:
"@noble/hashes" "^1.2.0"
blazeface-back@^0.0.9:
version "0.0.9"
resolved "https://registry.yarnpkg.com/blazeface-back/-/blazeface-back-0.0.9.tgz#a8a26a0022950eb21136693f2fca3c52315ad2a4"
integrity sha512-t0i5V117j074d7d7mlLaRq9n/bYchXcSEgpWVbGGloV68A6Jn22t4SNoEC3t+MOsU8H+eXoDv2/6+JsqActM1g==
brace-expansion@^1.1.7:
version "1.1.11"
resolved "https://registry.yarnpkg.com/brace-expansion/-/brace-expansion-1.1.11.tgz#3c7fcbf529d87226f3d2f52b966ff5271eb441dd"
@ -3445,11 +3369,6 @@ lodash@^4.17.21:
resolved "https://registry.yarnpkg.com/lodash/-/lodash-4.17.21.tgz#679591c564c3bffaae8454cf0b3df370c3d6911c"
integrity sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==
long@4.0.0:
version "4.0.0"
resolved "https://registry.yarnpkg.com/long/-/long-4.0.0.tgz#9a7b71cfb7d361a194ea555241c92f7468d5bf28"
integrity sha512-XsP+KhQif4bjX1kbuSiySJFNAehNxgLb6hPRGJ9QsUr8ajHkuXGdrHmFUTUUXhDwVX2R5bY4JNZEwbUiMhV+MA==
loose-envify@^1.1.0, loose-envify@^1.4.0:
version "1.4.0"
resolved "https://registry.yarnpkg.com/loose-envify/-/loose-envify-1.4.0.tgz#71ee51fa7be4caec1a63839f7e682d8132d30caf"
@ -3621,13 +3540,6 @@ node-fetch@^2.6.1:
dependencies:
whatwg-url "^5.0.0"
node-fetch@~2.6.1:
version "2.6.13"
resolved "https://registry.yarnpkg.com/node-fetch/-/node-fetch-2.6.13.tgz#a20acbbec73c2e09f9007de5cda17104122e0010"
integrity sha512-StxNAxh15zr77QvvkmveSQ8uCQ4+v5FkvNTj0OESmiHu+VRi/gXArXtkWMElOsOUNLtUEvI4yS+rdtOHZTwlQA==
dependencies:
whatwg-url "^5.0.0"
node-releases@^2.0.14:
version "2.0.14"
resolved "https://registry.yarnpkg.com/node-releases/-/node-releases-2.0.14.tgz#2ffb053bceb8b2be8495ece1ab6ce600c4461b0b"
@ -4251,11 +4163,6 @@ scheduler@^0.23.0:
dependencies:
loose-envify "^1.1.0"
seedrandom@^3.0.5:
version "3.0.5"
resolved "https://registry.yarnpkg.com/seedrandom/-/seedrandom-3.0.5.tgz#54edc85c95222525b0c7a6f6b3543d8e0b3aa0a7"
integrity sha512-8OwmbklUNzwezjGInmZ+2clQmExQPvomqjL7LFqOYqtmuxRgQYqOD3mHaU+MvZn5FLUeVxVfQjwLZW/n/JFuqg==
semver@^6.3.1:
version "6.3.1"
resolved "https://registry.yarnpkg.com/semver/-/semver-6.3.1.tgz#556d2ef8689146e46dcea4bfdd095f3434dffcb4"