[web] ML purge - Part 2/x (#1725)
This commit is contained in:
commit
bbfd2ae640
34 changed files with 1148 additions and 1242 deletions
|
@ -5,8 +5,8 @@ import { t } from "i18next";
|
|||
import { AppContext } from "pages/_app";
|
||||
import { useContext } from "react";
|
||||
import { components } from "react-select";
|
||||
import { IndexStatus } from "services/ml/db";
|
||||
import { Suggestion, SuggestionType } from "types/search";
|
||||
import { IndexStatus } from "utils/storage/mlIDbStorage";
|
||||
|
||||
const { Menu } = components;
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@ import { t } from "i18next";
|
|||
import { useRouter } from "next/router";
|
||||
import { AppContext } from "pages/_app";
|
||||
import { useContext, useState } from "react";
|
||||
// import mlIDbStorage from 'utils/storage/mlIDbStorage';
|
||||
// import mlIDbStorage from 'services/ml/db';
|
||||
import {
|
||||
configurePasskeyRecovery,
|
||||
isPasskeyRecoveryEnabled,
|
||||
|
|
|
@ -3,9 +3,9 @@ import { Skeleton, styled } from "@mui/material";
|
|||
import { Legend } from "components/PhotoViewer/styledComponents/Legend";
|
||||
import { t } from "i18next";
|
||||
import React, { useEffect, useState } from "react";
|
||||
import { Face, Person } from "services/ml/types";
|
||||
import mlIDbStorage from "services/ml/db";
|
||||
import { Face, Person, type MlFileData } from "services/ml/types";
|
||||
import { EnteFile } from "types/file";
|
||||
import { getPeopleList, getUnidentifiedFaces } from "utils/machineLearning";
|
||||
|
||||
const FaceChipContainer = styled("div")`
|
||||
display: flex;
|
||||
|
@ -194,3 +194,45 @@ const FaceCropImageView: React.FC<FaceCropImageViewProps> = ({
|
|||
<Skeleton variant="circular" height={120} width={120} />
|
||||
);
|
||||
};
|
||||
|
||||
/**
 * Return the people attached to the faces stored for the given file.
 *
 * Reads the file's ML data from the ML IndexedDB storage, collects the
 * `personId` of every face that has one, and then fetches the corresponding
 * person records. The time taken by each of the two storage reads is logged.
 *
 * The result has one entry per face that carries a `personId`, in face order;
 * IDs are not de-duplicated.
 *
 * @param file The file whose people should be listed.
 * @returns The fetched person records, or an empty array if the file has no ML
 * data, no faces, or no face with a `personId`.
 */
async function getPeopleList(file: EnteFile): Promise<Array<Person>> {
    let startTime = Date.now();
    const mlFileData: MlFileData = await mlIDbStorage.getFile(file.id);
    log.info(
        "getPeopleList:mlFilesStore:getItem",
        Date.now() - startTime,
        "ms",
    );
    if (!mlFileData?.faces || mlFileData.faces.length < 1) {
        return [];
    }

    // `!== null && !== undefined` (rather than truthiness) so that an ID of 0
    // is still treated as a valid person.
    const peopleIds = mlFileData.faces
        .filter((f) => f.personId !== null && f.personId !== undefined)
        .map((f) => f.personId);
    if (peopleIds.length < 1) {
        return [];
    }

    startTime = Date.now();
    const peopleList = await Promise.all(
        peopleIds.map((p) => mlIDbStorage.getPerson(p) as Promise<Person>),
    );
    log.info(
        "getPeopleList:mlPeopleStore:getItems",
        Date.now() - startTime,
        "ms",
    );

    return peopleList;
}
|
||||
|
||||
/**
 * Return the faces stored for the given file that have no person attached.
 *
 * Reads the file's ML data from the ML IndexedDB storage and keeps the faces
 * whose `personId` is `null` or `undefined`.
 *
 * @param file The file whose unidentified faces should be listed.
 * @returns The matching faces, or an empty array when the file has no ML data
 * or no faces. An array is always returned so that the declared return type
 * holds.
 */
async function getUnidentifiedFaces(file: EnteFile): Promise<Array<Face>> {
    const mlFileData: MlFileData = await mlIDbStorage.getFile(file.id);
    const faces = mlFileData?.faces ?? [];
    return faces.filter(
        (f) => f.personId === null || f.personId === undefined,
    );
}
|
||||
|
|
|
@ -1,56 +0,0 @@
|
|||
import { MLSearchConfig, MLSyncConfig } from "services/ml/types";
|
||||
import { JobConfig } from "types/common/job";
|
||||
|
||||
/**
 * Default {@link JobConfig} for the ML sync job.
 *
 * The meaning of each field is defined by `JobConfig` (declared elsewhere).
 * NOTE(review): judging by the names, this is a 5 second interval that backs
 * off by a factor of 2 up to 960 seconds — confirm against the job runner.
 */
export const DEFAULT_ML_SYNC_JOB_CONFIG: JobConfig = {
    intervalSec: 5,
    // TODO: finalize this after seeing effects on and from machine sleep
    // (The key spelling "maxItervalSec" has to match the `JobConfig` type.)
    maxItervalSec: 960,
    backoffMultiplier: 2,
};
|
||||
|
||||
/**
 * Default {@link MLSyncConfig}.
 *
 * Names the method (and its tunables) used for each ML stage: face detection,
 * face crop, face alignment, blur detection, face embedding and face
 * clustering. The semantics of the individual fields are defined by
 * `MLSyncConfig` and the services that consume it (not visible here).
 */
export const DEFAULT_ML_SYNC_CONFIG: MLSyncConfig = {
    batchSize: 200,
    imageSource: "Original",
    faceDetection: {
        method: "YoloFace",
    },
    faceCrop: {
        enabled: true,
        method: "ArcFace",
        padding: 0.25,
        maxSize: 256,
        blobOptions: {
            type: "image/jpeg",
            quality: 0.8,
        },
    },
    faceAlignment: {
        method: "ArcFace",
    },
    blurDetection: {
        method: "Laplacian",
        threshold: 15,
    },
    faceEmbedding: {
        method: "MobileFaceNet",
        faceSize: 112,
        generateTsne: true,
    },
    faceClustering: {
        method: "Hdbscan",
        minClusterSize: 3,
        minSamples: 5,
        clusterSelectionEpsilon: 0.6,
        clusterSelectionMethod: "leaf",
        minInputSize: 50,
        // maxDistanceInsideCluster: 0.4,
        generateDebugInfo: true,
    },
    mlVersion: 3,
};
|
||||
|
||||
/** Default {@link MLSearchConfig}: ML search starts out disabled. */
export const DEFAULT_ML_SEARCH_CONFIG: MLSearchConfig = {
    enabled: false,
};

/**
 * Error-count limit for ML sync.
 *
 * NOTE(review): how this limit is applied (per file, per run, ...) is decided
 * by the code that reads it, which is not visible here — confirm there.
 */
export const MAX_ML_SYNC_ERROR_COUNT = 1;
|
|
@ -53,6 +53,10 @@ import { createContext, useEffect, useRef, useState } from "react";
|
|||
import LoadingBar from "react-top-loading-bar";
|
||||
import DownloadManager from "services/download";
|
||||
import exportService, { resumeExportsIfNeeded } from "services/export";
|
||||
import {
|
||||
getMLSearchConfig,
|
||||
updateMLSearchConfig,
|
||||
} from "services/machineLearning/machineLearningService";
|
||||
import mlWorkManager from "services/machineLearning/mlWorkManager";
|
||||
import {
|
||||
getFamilyPortalRedirectURL,
|
||||
|
@ -64,10 +68,6 @@ import {
|
|||
NotificationAttributes,
|
||||
SetNotificationAttributes,
|
||||
} from "types/Notification";
|
||||
import {
|
||||
getMLSearchConfig,
|
||||
updateMLSearchConfig,
|
||||
} from "utils/machineLearning/config";
|
||||
import {
|
||||
getUpdateAvailableForDownloadMessage,
|
||||
getUpdateReadyToInstallMessage,
|
||||
|
|
|
@ -84,7 +84,10 @@ import {
|
|||
getSectionSummaries,
|
||||
} from "services/collectionService";
|
||||
import downloadManager from "services/download";
|
||||
import { syncEmbeddings, syncFileEmbeddings } from "services/embeddingService";
|
||||
import {
|
||||
syncCLIPEmbeddings,
|
||||
syncFaceEmbeddings,
|
||||
} from "services/embeddingService";
|
||||
import { syncEntities } from "services/entityService";
|
||||
import { getLocalFiles, syncFiles } from "services/fileService";
|
||||
import locationSearchService from "services/locationSearchService";
|
||||
|
@ -130,6 +133,7 @@ import {
|
|||
} from "utils/file";
|
||||
import { isArchivedFile } from "utils/magicMetadata";
|
||||
import { getSessionExpiredMessage } from "utils/ui";
|
||||
import { isInternalUserForML } from "utils/user";
|
||||
import { getLocalFamilyData } from "utils/user/family";
|
||||
|
||||
export const DeadCenter = styled("div")`
|
||||
|
@ -698,10 +702,10 @@ export default function Gallery() {
|
|||
await syncTrash(collections, setTrashedFiles);
|
||||
await syncEntities();
|
||||
await syncMapEnabled();
|
||||
await syncEmbeddings();
|
||||
await syncCLIPEmbeddings();
|
||||
const electron = globalThis.electron;
|
||||
if (electron) {
|
||||
await syncFileEmbeddings();
|
||||
if (isInternalUserForML() && electron) {
|
||||
await syncFaceEmbeddings();
|
||||
}
|
||||
if (clipService.isPlatformSupported()) {
|
||||
void clipService.scheduleImageEmbeddingExtraction();
|
||||
|
|
|
@ -11,7 +11,7 @@ import { Embedding } from "types/embedding";
|
|||
import { EnteFile } from "types/file";
|
||||
import { getPersonalFiles } from "utils/file";
|
||||
import downloadManager from "./download";
|
||||
import { getLocalEmbeddings, putEmbedding } from "./embeddingService";
|
||||
import { localCLIPEmbeddings, putEmbedding } from "./embeddingService";
|
||||
import { getAllLocalFiles, getLocalFiles } from "./fileService";
|
||||
|
||||
/** Status of CLIP indexing on the images in the user's local library. */
|
||||
|
@ -195,7 +195,7 @@ class CLIPService {
|
|||
return;
|
||||
}
|
||||
const localFiles = getPersonalFiles(await getAllLocalFiles(), user);
|
||||
const existingEmbeddings = await getLocalEmbeddings();
|
||||
const existingEmbeddings = await localCLIPEmbeddings();
|
||||
const pendingFiles = await getNonClipEmbeddingExtractedFiles(
|
||||
localFiles,
|
||||
existingEmbeddings,
|
||||
|
@ -394,7 +394,7 @@ export const computeClipMatchScore = async (
|
|||
const initialIndexingStatus = async (): Promise<CLIPIndexingStatus> => {
|
||||
const user = getData(LS_KEYS.USER);
|
||||
if (!user) throw new Error("Orphan CLIP indexing without a login");
|
||||
const allEmbeddings = await getLocalEmbeddings();
|
||||
const allEmbeddings = await localCLIPEmbeddings();
|
||||
const localFiles = getPersonalFiles(await getLocalFiles(), user);
|
||||
const pendingFiles = await getNonClipEmbeddingExtractedFiles(
|
||||
localFiles,
|
||||
|
|
|
@ -7,6 +7,7 @@ import HTTPService from "@ente/shared/network/HTTPService";
|
|||
import { getEndpoint } from "@ente/shared/network/api";
|
||||
import localForage from "@ente/shared/storage/localForage";
|
||||
import { getToken } from "@ente/shared/storage/localStorage/helpers";
|
||||
import { FileML } from "services/machineLearning/machineLearningService";
|
||||
import type {
|
||||
Embedding,
|
||||
EmbeddingModel,
|
||||
|
@ -15,31 +16,30 @@ import type {
|
|||
PutEmbeddingRequest,
|
||||
} from "types/embedding";
|
||||
import { EnteFile } from "types/file";
|
||||
import {
|
||||
getLatestVersionEmbeddings,
|
||||
getLatestVersionFileEmbeddings,
|
||||
} from "utils/embedding";
|
||||
import { FileML } from "utils/machineLearning/mldataMappers";
|
||||
import { getLocalCollections } from "./collectionService";
|
||||
import { getAllLocalFiles } from "./fileService";
|
||||
import { getLocalTrashedFiles } from "./trashService";
|
||||
|
||||
const ENDPOINT = getEndpoint();
|
||||
|
||||
const DIFF_LIMIT = 500;
|
||||
|
||||
const EMBEDDINGS_TABLE_V1 = "embeddings";
|
||||
const EMBEDDINGS_TABLE = "embeddings_v2";
|
||||
/** Local storage key suffix for embedding sync times */
|
||||
const embeddingSyncTimeLSKeySuffix = "embedding_sync_time";
|
||||
/** Local storage key for CLIP embeddings. */
|
||||
const clipEmbeddingsLSKey = "embeddings_v2";
|
||||
const FILE_EMBEDING_TABLE = "file_embeddings";
|
||||
const EMBEDDING_SYNC_TIME_TABLE = "embedding_sync_time";
|
||||
|
||||
export const getAllLocalEmbeddings = async () => {
|
||||
/** Return all CLIP embeddings that we have available locally. */
|
||||
export const localCLIPEmbeddings = async () =>
|
||||
(await storedCLIPEmbeddings()).filter(({ model }) => model === "onnx-clip");
|
||||
|
||||
const storedCLIPEmbeddings = async () => {
|
||||
const embeddings: Array<Embedding> =
|
||||
await localForage.getItem<Embedding[]>(EMBEDDINGS_TABLE);
|
||||
await localForage.getItem<Embedding[]>(clipEmbeddingsLSKey);
|
||||
if (!embeddings) {
|
||||
await localForage.removeItem(EMBEDDINGS_TABLE_V1);
|
||||
await localForage.removeItem(EMBEDDING_SYNC_TIME_TABLE);
|
||||
await localForage.setItem(EMBEDDINGS_TABLE, []);
|
||||
// Migrate
|
||||
await localForage.removeItem("embeddings");
|
||||
await localForage.removeItem("embedding_sync_time");
|
||||
await localForage.setItem(clipEmbeddingsLSKey, []);
|
||||
return [];
|
||||
}
|
||||
return embeddings;
|
||||
|
@ -54,15 +54,10 @@ export const getFileMLEmbeddings = async (): Promise<FileML[]> => {
|
|||
return embeddings;
|
||||
};
|
||||
|
||||
export const getLocalEmbeddings = async () => {
|
||||
const embeddings = await getAllLocalEmbeddings();
|
||||
return embeddings.filter((embedding) => embedding.model === "onnx-clip");
|
||||
};
|
||||
|
||||
/**
 * Return the last saved embedding sync time for the given model.
 *
 * The value is read from localForage under the key
 * `${model}-${embeddingSyncTimeLSKeySuffix}`.
 *
 * @param model The embedding model whose sync time should be read.
 * @returns The stored time, or 0 if nothing has been stored yet.
 */
const getModelEmbeddingSyncTime = async (model: EmbeddingModel) => {
    const storedTime = await localForage.getItem<number>(
        `${model}-${embeddingSyncTimeLSKeySuffix}`,
    );
    return storedTime ?? 0;
};
|
||||
|
@ -71,13 +66,17 @@ const setModelEmbeddingSyncTime = async (
|
|||
model: EmbeddingModel,
|
||||
time: number,
|
||||
) => {
|
||||
await localForage.setItem(`${model}-${EMBEDDING_SYNC_TIME_TABLE}`, time);
|
||||
await localForage.setItem(`${model}-${embeddingSyncTimeLSKeySuffix}`, time);
|
||||
};
|
||||
|
||||
export const syncEmbeddings = async () => {
|
||||
const models: EmbeddingModel[] = ["onnx-clip"];
|
||||
/**
|
||||
* Fetch new CLIP embeddings from the server and save them locally. Also prune
|
||||
* local embeddings for any files that no longer exist locally.
|
||||
*/
|
||||
export const syncCLIPEmbeddings = async () => {
|
||||
const model: EmbeddingModel = "onnx-clip";
|
||||
try {
|
||||
let allEmbeddings = await getAllLocalEmbeddings();
|
||||
let allEmbeddings = await storedCLIPEmbeddings();
|
||||
const localFiles = await getAllLocalFiles();
|
||||
const hiddenAlbums = await getLocalCollections("hidden");
|
||||
const localTrashFiles = await getLocalTrashedFiles();
|
||||
|
@ -89,79 +88,75 @@ export const syncEmbeddings = async () => {
|
|||
await cleanupDeletedEmbeddings(
|
||||
allLocalFiles,
|
||||
allEmbeddings,
|
||||
EMBEDDINGS_TABLE,
|
||||
clipEmbeddingsLSKey,
|
||||
);
|
||||
log.info(`Syncing embeddings localCount: ${allEmbeddings.length}`);
|
||||
for (const model of models) {
|
||||
let modelLastSinceTime = await getModelEmbeddingSyncTime(model);
|
||||
log.info(
|
||||
`Syncing ${model} model's embeddings sinceTime: ${modelLastSinceTime}`,
|
||||
);
|
||||
let response: GetEmbeddingDiffResponse;
|
||||
do {
|
||||
response = await getEmbeddingsDiff(modelLastSinceTime, model);
|
||||
if (!response.diff?.length) {
|
||||
return;
|
||||
}
|
||||
const newEmbeddings = await Promise.all(
|
||||
response.diff.map(async (embedding) => {
|
||||
try {
|
||||
const {
|
||||
encryptedEmbedding,
|
||||
decryptionHeader,
|
||||
...rest
|
||||
} = embedding;
|
||||
const worker =
|
||||
await ComlinkCryptoWorker.getInstance();
|
||||
const fileKey = fileIdToKeyMap.get(
|
||||
embedding.fileID,
|
||||
);
|
||||
if (!fileKey) {
|
||||
throw Error(CustomError.FILE_NOT_FOUND);
|
||||
}
|
||||
const decryptedData = await worker.decryptEmbedding(
|
||||
encryptedEmbedding,
|
||||
decryptionHeader,
|
||||
fileIdToKeyMap.get(embedding.fileID),
|
||||
);
|
||||
|
||||
return {
|
||||
...rest,
|
||||
embedding: decryptedData,
|
||||
} as Embedding;
|
||||
} catch (e) {
|
||||
let hasHiddenAlbums = false;
|
||||
if (e.message === CustomError.FILE_NOT_FOUND) {
|
||||
hasHiddenAlbums = hiddenAlbums?.length > 0;
|
||||
}
|
||||
log.error(
|
||||
`decryptEmbedding failed for file (hasHiddenAlbums: ${hasHiddenAlbums})`,
|
||||
e,
|
||||
);
|
||||
let modelLastSinceTime = await getModelEmbeddingSyncTime(model);
|
||||
log.info(
|
||||
`Syncing ${model} model's embeddings sinceTime: ${modelLastSinceTime}`,
|
||||
);
|
||||
let response: GetEmbeddingDiffResponse;
|
||||
do {
|
||||
response = await getEmbeddingsDiff(modelLastSinceTime, model);
|
||||
if (!response.diff?.length) {
|
||||
return;
|
||||
}
|
||||
const newEmbeddings = await Promise.all(
|
||||
response.diff.map(async (embedding) => {
|
||||
try {
|
||||
const {
|
||||
encryptedEmbedding,
|
||||
decryptionHeader,
|
||||
...rest
|
||||
} = embedding;
|
||||
const worker = await ComlinkCryptoWorker.getInstance();
|
||||
const fileKey = fileIdToKeyMap.get(embedding.fileID);
|
||||
if (!fileKey) {
|
||||
throw Error(CustomError.FILE_NOT_FOUND);
|
||||
}
|
||||
}),
|
||||
);
|
||||
allEmbeddings = getLatestVersionEmbeddings([
|
||||
...allEmbeddings,
|
||||
...newEmbeddings,
|
||||
]);
|
||||
if (response.diff.length) {
|
||||
modelLastSinceTime = response.diff.slice(-1)[0].updatedAt;
|
||||
}
|
||||
await localForage.setItem(EMBEDDINGS_TABLE, allEmbeddings);
|
||||
await setModelEmbeddingSyncTime(model, modelLastSinceTime);
|
||||
log.info(
|
||||
`Syncing embeddings syncedEmbeddingsCount: ${allEmbeddings.length}`,
|
||||
);
|
||||
} while (response.diff.length === DIFF_LIMIT);
|
||||
}
|
||||
const decryptedData = await worker.decryptEmbedding(
|
||||
encryptedEmbedding,
|
||||
decryptionHeader,
|
||||
fileIdToKeyMap.get(embedding.fileID),
|
||||
);
|
||||
|
||||
return {
|
||||
...rest,
|
||||
embedding: decryptedData,
|
||||
} as Embedding;
|
||||
} catch (e) {
|
||||
let hasHiddenAlbums = false;
|
||||
if (e.message === CustomError.FILE_NOT_FOUND) {
|
||||
hasHiddenAlbums = hiddenAlbums?.length > 0;
|
||||
}
|
||||
log.error(
|
||||
`decryptEmbedding failed for file (hasHiddenAlbums: ${hasHiddenAlbums})`,
|
||||
e,
|
||||
);
|
||||
}
|
||||
}),
|
||||
);
|
||||
allEmbeddings = getLatestVersionEmbeddings([
|
||||
...allEmbeddings,
|
||||
...newEmbeddings,
|
||||
]);
|
||||
if (response.diff.length) {
|
||||
modelLastSinceTime = response.diff.slice(-1)[0].updatedAt;
|
||||
}
|
||||
await localForage.setItem(clipEmbeddingsLSKey, allEmbeddings);
|
||||
await setModelEmbeddingSyncTime(model, modelLastSinceTime);
|
||||
log.info(
|
||||
`Syncing embeddings syncedEmbeddingsCount: ${allEmbeddings.length}`,
|
||||
);
|
||||
} while (response.diff.length === DIFF_LIMIT);
|
||||
} catch (e) {
|
||||
log.error("Sync embeddings failed", e);
|
||||
}
|
||||
};
|
||||
|
||||
export const syncFileEmbeddings = async () => {
|
||||
const models: EmbeddingModel[] = ["file-ml-clip-face"];
|
||||
export const syncFaceEmbeddings = async () => {
|
||||
const model: EmbeddingModel = "file-ml-clip-face";
|
||||
try {
|
||||
let allEmbeddings: FileML[] = await getFileMLEmbeddings();
|
||||
const localFiles = await getAllLocalFiles();
|
||||
|
@ -178,69 +173,99 @@ export const syncFileEmbeddings = async () => {
|
|||
FILE_EMBEDING_TABLE,
|
||||
);
|
||||
log.info(`Syncing embeddings localCount: ${allEmbeddings.length}`);
|
||||
for (const model of models) {
|
||||
let modelLastSinceTime = await getModelEmbeddingSyncTime(model);
|
||||
log.info(
|
||||
`Syncing ${model} model's embeddings sinceTime: ${modelLastSinceTime}`,
|
||||
);
|
||||
let response: GetEmbeddingDiffResponse;
|
||||
do {
|
||||
response = await getEmbeddingsDiff(modelLastSinceTime, model);
|
||||
if (!response.diff?.length) {
|
||||
return;
|
||||
}
|
||||
const newEmbeddings = await Promise.all(
|
||||
response.diff.map(async (embedding) => {
|
||||
try {
|
||||
const worker =
|
||||
await ComlinkCryptoWorker.getInstance();
|
||||
const fileKey = fileIdToKeyMap.get(
|
||||
embedding.fileID,
|
||||
);
|
||||
if (!fileKey) {
|
||||
throw Error(CustomError.FILE_NOT_FOUND);
|
||||
}
|
||||
const decryptedData = await worker.decryptMetadata(
|
||||
embedding.encryptedEmbedding,
|
||||
embedding.decryptionHeader,
|
||||
fileIdToKeyMap.get(embedding.fileID),
|
||||
);
|
||||
|
||||
return {
|
||||
...decryptedData,
|
||||
updatedAt: embedding.updatedAt,
|
||||
} as unknown as FileML;
|
||||
} catch (e) {
|
||||
let hasHiddenAlbums = false;
|
||||
if (e.message === CustomError.FILE_NOT_FOUND) {
|
||||
hasHiddenAlbums = hiddenAlbums?.length > 0;
|
||||
}
|
||||
log.error(
|
||||
`decryptEmbedding failed for file (hasHiddenAlbums: ${hasHiddenAlbums})`,
|
||||
e,
|
||||
);
|
||||
let modelLastSinceTime = await getModelEmbeddingSyncTime(model);
|
||||
log.info(
|
||||
`Syncing ${model} model's embeddings sinceTime: ${modelLastSinceTime}`,
|
||||
);
|
||||
let response: GetEmbeddingDiffResponse;
|
||||
do {
|
||||
response = await getEmbeddingsDiff(modelLastSinceTime, model);
|
||||
if (!response.diff?.length) {
|
||||
return;
|
||||
}
|
||||
const newEmbeddings = await Promise.all(
|
||||
response.diff.map(async (embedding) => {
|
||||
try {
|
||||
const worker = await ComlinkCryptoWorker.getInstance();
|
||||
const fileKey = fileIdToKeyMap.get(embedding.fileID);
|
||||
if (!fileKey) {
|
||||
throw Error(CustomError.FILE_NOT_FOUND);
|
||||
}
|
||||
}),
|
||||
);
|
||||
allEmbeddings = getLatestVersionFileEmbeddings([
|
||||
...allEmbeddings,
|
||||
...newEmbeddings,
|
||||
]);
|
||||
if (response.diff.length) {
|
||||
modelLastSinceTime = response.diff.slice(-1)[0].updatedAt;
|
||||
}
|
||||
await localForage.setItem(FILE_EMBEDING_TABLE, allEmbeddings);
|
||||
await setModelEmbeddingSyncTime(model, modelLastSinceTime);
|
||||
log.info(
|
||||
`Syncing embeddings syncedEmbeddingsCount: ${allEmbeddings.length}`,
|
||||
);
|
||||
} while (response.diff.length === DIFF_LIMIT);
|
||||
}
|
||||
const decryptedData = await worker.decryptMetadata(
|
||||
embedding.encryptedEmbedding,
|
||||
embedding.decryptionHeader,
|
||||
fileIdToKeyMap.get(embedding.fileID),
|
||||
);
|
||||
|
||||
return {
|
||||
...decryptedData,
|
||||
updatedAt: embedding.updatedAt,
|
||||
} as unknown as FileML;
|
||||
} catch (e) {
|
||||
let hasHiddenAlbums = false;
|
||||
if (e.message === CustomError.FILE_NOT_FOUND) {
|
||||
hasHiddenAlbums = hiddenAlbums?.length > 0;
|
||||
}
|
||||
log.error(
|
||||
`decryptEmbedding failed for file (hasHiddenAlbums: ${hasHiddenAlbums})`,
|
||||
e,
|
||||
);
|
||||
}
|
||||
}),
|
||||
);
|
||||
allEmbeddings = getLatestVersionFileEmbeddings([
|
||||
...allEmbeddings,
|
||||
...newEmbeddings,
|
||||
]);
|
||||
if (response.diff.length) {
|
||||
modelLastSinceTime = response.diff.slice(-1)[0].updatedAt;
|
||||
}
|
||||
await localForage.setItem(FILE_EMBEDING_TABLE, allEmbeddings);
|
||||
await setModelEmbeddingSyncTime(model, modelLastSinceTime);
|
||||
log.info(
|
||||
`Syncing embeddings syncedEmbeddingsCount: ${allEmbeddings.length}`,
|
||||
);
|
||||
} while (response.diff.length === DIFF_LIMIT);
|
||||
} catch (e) {
|
||||
log.error("Sync embeddings failed", e);
|
||||
}
|
||||
};
|
||||
|
||||
/**
 * Reduce a list of embeddings to a single entry per file ID, keeping the one
 * with the greatest `updatedAt`.
 *
 * Entries that are nullish or have a falsy `fileID` are skipped. When two
 * entries for the same file have equal `updatedAt`, the one seen first wins.
 * The result is ordered by the first appearance of each file ID.
 *
 * @param embeddings The embeddings to de-duplicate.
 * @returns One embedding per file ID.
 */
const getLatestVersionEmbeddings = (embeddings: Embedding[]) => {
    const latestByFileID = new Map<number, Embedding>();
    for (const embedding of embeddings) {
        if (!embedding?.fileID) {
            continue;
        }
        const existing = latestByFileID.get(embedding.fileID);
        if (!existing || existing.updatedAt < embedding.updatedAt) {
            latestByFileID.set(embedding.fileID, embedding);
        }
    }
    return Array.from(latestByFileID.values());
};
|
||||
|
||||
/**
 * Reduce a list of file ML entries to a single entry per file ID, keeping the
 * one with the greatest `updatedAt`.
 *
 * Entries that are nullish or have a falsy `fileID` are skipped. When two
 * entries for the same file have equal `updatedAt`, the one seen first wins.
 * The result is ordered by the first appearance of each file ID.
 *
 * @param embeddings The file ML entries to de-duplicate.
 * @returns One entry per file ID.
 */
const getLatestVersionFileEmbeddings = (embeddings: FileML[]) => {
    const latestByFileID = new Map<number, FileML>();
    for (const embedding of embeddings) {
        if (!embedding?.fileID) {
            continue;
        }
        const existing = latestByFileID.get(embedding.fileID);
        if (!existing || existing.updatedAt < embedding.updatedAt) {
            latestByFileID.set(embedding.fileID, embedding);
        }
    }
    return Array.from(latestByFileID.values());
};
|
||||
|
||||
export const getEmbeddingsDiff = async (
|
||||
sinceTime: number,
|
||||
model: EmbeddingModel,
|
||||
|
@ -251,7 +276,7 @@ export const getEmbeddingsDiff = async (
|
|||
return;
|
||||
}
|
||||
const response = await HTTPService.get(
|
||||
`${ENDPOINT}/embeddings/diff`,
|
||||
`${getEndpoint()}/embeddings/diff`,
|
||||
{
|
||||
sinceTime,
|
||||
limit: DIFF_LIMIT,
|
||||
|
@ -280,7 +305,7 @@ export const putEmbedding = async (
|
|||
throw Error(CustomError.TOKEN_MISSING);
|
||||
}
|
||||
const resp = await HTTPService.put(
|
||||
`${ENDPOINT}/embeddings`,
|
||||
`${getEndpoint()}/embeddings`,
|
||||
putEmbeddingReq,
|
||||
null,
|
||||
{
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
import { Matrix } from "ml-matrix";
|
||||
import { Point } from "services/ml/geom";
|
||||
import {
|
||||
FaceAlignment,
|
||||
FaceAlignmentMethod,
|
||||
|
@ -5,7 +7,7 @@ import {
|
|||
FaceDetection,
|
||||
Versioned,
|
||||
} from "services/ml/types";
|
||||
import { getArcfaceAlignment } from "utils/machineLearning/faceAlign";
|
||||
import { getSimilarityTransformation } from "similarity-transformation";
|
||||
|
||||
class ArcfaceAlignmentService implements FaceAlignmentService {
|
||||
public method: Versioned<FaceAlignmentMethod>;
|
||||
|
@ -23,3 +25,86 @@ class ArcfaceAlignmentService implements FaceAlignmentService {
|
|||
}
|
||||
|
||||
export default new ArcfaceAlignmentService();
|
||||
|
||||
/**
 * Reference landmark coordinates used when the detection does not have
 * exactly 5 landmarks (4 points).
 *
 * Coordinates are in the pixel space of a face image whose side is
 * {@link ARCFACE_LANDMARKS_FACE_SIZE}.
 */
const ARCFACE_LANDMARKS: Array<[number, number]> = [
    [38.2946, 51.6963],
    [73.5318, 51.5014],
    [56.0252, 71.7366],
    [56.1396, 92.2848],
];

/** Side length by which the reference landmark coordinates are normalized. */
const ARCFACE_LANDMARKS_FACE_SIZE = 112;

/**
 * Reference landmark coordinates used when the detection has exactly 5
 * landmarks.
 *
 * Same coordinate space as {@link ARCFACE_LANDMARKS}.
 */
const ARC_FACE_5_LANDMARKS: Array<[number, number]> = [
    [38.2946, 51.6963],
    [73.5318, 51.5014],
    [56.0252, 71.7366],
    [41.5493, 92.3655],
    [70.7299, 92.2041],
];
|
||||
|
||||
/**
 * Compute the alignment of a detected face against the ArcFace reference
 * landmarks.
 *
 * The 5-point reference set is used when the detection has exactly 5
 * landmarks, otherwise the 4-point set is used. The reference coordinates are
 * normalized by `ARCFACE_LANDMARKS_FACE_SIZE` before the similarity transform
 * is computed.
 *
 * @param faceDetection The detection whose landmarks should be aligned.
 * @returns The alignment derived from the similarity transform.
 */
export function getArcfaceAlignment(
    faceDetection: FaceDetection,
): FaceAlignment {
    const referenceLandmarks =
        faceDetection.landmarks.length === 5
            ? ARC_FACE_5_LANDMARKS
            : ARCFACE_LANDMARKS;
    return getFaceAlignmentUsingSimilarityTransform(
        faceDetection,
        normalizeLandmarks(referenceLandmarks, ARCFACE_LANDMARKS_FACE_SIZE),
    );
}
|
||||
|
||||
/**
 * Compute a face alignment from the similarity transform that maps the
 * detected landmarks onto the given reference landmarks.
 *
 * @param faceDetection The detection providing the source landmarks. Only the
 * first `alignedLandmarks.length` landmarks are used.
 * @param alignedLandmarks The reference landmark positions to align to.
 * @returns An alignment consisting of:
 * - `affineMatrix`: a 3x3 matrix built from rotation-times-scale and the
 *   translation of the similarity transform,
 * - `center`: `fromMean - (toMean - 0.5) * size`,
 * - `size`: the reciprocal of the transform's scale,
 * - `rotation`: `-atan2(R[0][1], R[0][0])` of the transform's rotation matrix.
 */
function getFaceAlignmentUsingSimilarityTransform(
    faceDetection: FaceDetection,
    alignedLandmarks: Array<[number, number]>,
): FaceAlignment {
    // Both landmark sets are arranged as 2 x N matrices (one column per point).
    const landmarksMat = new Matrix(
        faceDetection.landmarks
            .map((p) => [p.x, p.y])
            .slice(0, alignedLandmarks.length),
    ).transpose();
    const alignedLandmarksMat = new Matrix(alignedLandmarks).transpose();

    const simTransform = getSimilarityTransformation(
        landmarksMat,
        alignedLandmarksMat,
    );

    const RS = Matrix.mul(simTransform.rotation, simTransform.scale);
    const TR = simTransform.translation;

    const affineMatrix = [
        [RS.get(0, 0), RS.get(0, 1), TR.get(0, 0)],
        [RS.get(1, 0), RS.get(1, 1), TR.get(1, 0)],
        [0, 0, 1],
    ];

    const size = 1 / simTransform.scale;
    // NOTE(review): ml-matrix's `sub` / `mul` may modify the receiver in
    // place — confirm. That would be harmless here because `simTransform` is
    // local to this call.
    const meanTranslation = simTransform.toMean.sub(0.5).mul(size);
    const centerMat = simTransform.fromMean.sub(meanTranslation);
    const center = new Point(centerMat.get(0, 0), centerMat.get(1, 0));
    const rotation = -Math.atan2(
        simTransform.rotation.get(0, 1),
        simTransform.rotation.get(0, 0),
    );

    return {
        affineMatrix,
        center,
        size,
        rotation,
    };
}
|
||||
|
||||
/**
 * Scale landmark coordinates by dividing both components by the face size.
 *
 * @param landmarks The `[x, y]` landmark coordinates.
 * @param faceSize The divisor applied to every coordinate.
 * @returns A new array of scaled `[x, y]` pairs; the input is not modified.
 */
function normalizeLandmarks(
    landmarks: Array<[number, number]>,
    faceSize: number,
): Array<[number, number]> {
    return landmarks.map(([x, y]): [number, number] => [
        x / faceSize,
        y / faceSize,
    ]);
}
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
import { Box, enlargeBox } from "services/ml/geom";
|
||||
import {
|
||||
FaceAlignment,
|
||||
FaceCrop,
|
||||
FaceCropConfig,
|
||||
FaceCropMethod,
|
||||
|
@ -6,8 +8,8 @@ import {
|
|||
FaceDetection,
|
||||
Versioned,
|
||||
} from "services/ml/types";
|
||||
import { getArcfaceAlignment } from "utils/machineLearning/faceAlign";
|
||||
import { getFaceCrop } from "utils/machineLearning/faceCrop";
|
||||
import { cropWithRotation } from "utils/image";
|
||||
import { getArcfaceAlignment } from "./arcfaceAlignmentService";
|
||||
|
||||
class ArcFaceCropService implements FaceCropService {
|
||||
public method: Versioned<FaceCropMethod>;
|
||||
|
@ -32,3 +34,27 @@ class ArcFaceCropService implements FaceCropService {
|
|||
}
|
||||
|
||||
export default new ArcFaceCropService();
|
||||
|
||||
/**
 * Crop the region around an aligned face out of an image.
 *
 * A square box of side `alignment.size` centered on `alignment.center` is
 * rounded, enlarged by a factor of `1 + 2 * config.padding`, rounded again,
 * and then cropped (with a rotation of 0) to at most `config.maxSize` in each
 * dimension.
 *
 * @param imageBitmap The image containing the face.
 * @param alignment The face alignment giving the center and size of the face.
 * @param config Crop configuration; `padding` and `maxSize` are used.
 * @returns The cropped image together with the padded box it was cropped from.
 */
export function getFaceCrop(
    imageBitmap: ImageBitmap,
    alignment: FaceAlignment,
    config: FaceCropConfig,
): FaceCrop {
    const { center, size } = alignment;
    const alignmentBox = new Box({
        x: center.x - size / 2,
        y: center.y - size / 2,
        width: size,
        height: size,
    }).round();

    // Padding is applied on every side, hence twice per dimension.
    const scaleForPadding = 1 + config.padding * 2;
    const paddedBox = enlargeBox(alignmentBox, scaleForPadding).round();

    const faceImageBitmap = cropWithRotation(imageBitmap, paddedBox, 0, {
        width: config.maxSize,
        height: config.maxSize,
    });

    return {
        image: faceImageBitmap,
        imageBox: paddedBox,
    };
}
|
||||
|
|
|
@ -1,22 +1,20 @@
|
|||
import { openCache } from "@/next/blob-cache";
|
||||
import log from "@/next/log";
|
||||
import mlIDbStorage from "services/ml/db";
|
||||
import {
|
||||
DetectedFace,
|
||||
Face,
|
||||
MLSyncContext,
|
||||
MLSyncFileContext,
|
||||
type FaceAlignment,
|
||||
type Versioned,
|
||||
} from "services/ml/types";
|
||||
import { imageBitmapToBlob } from "utils/image";
|
||||
import {
|
||||
areFaceIdsSame,
|
||||
extractFaceImagesToFloat32,
|
||||
import { imageBitmapToBlob, warpAffineFloat32List } from "utils/image";
|
||||
import ReaderService, {
|
||||
getFaceId,
|
||||
getLocalFile,
|
||||
getOriginalImageBitmap,
|
||||
isDifferentOrOld,
|
||||
} from "utils/machineLearning";
|
||||
import mlIDbStorage from "utils/storage/mlIDbStorage";
|
||||
import ReaderService from "./readerService";
|
||||
} from "./readerService";
|
||||
|
||||
class FaceService {
|
||||
async syncFileFaceDetections(
|
||||
|
@ -304,3 +302,58 @@ class FaceService {
|
|||
}
|
||||
|
||||
export default new FaceService();
|
||||
|
||||
/**
 * Check whether two face lists carry the same face IDs in the same order.
 *
 * @param ofFaces The first list of faces (may be nullish).
 * @param toFaces The second list of faces (may be nullish).
 * @returns `true` if both lists are nullish, or if both are arrays with equal
 * IDs at every index; `false` otherwise (including when only one is nullish).
 */
export function areFaceIdsSame(
    ofFaces: Array<Face>,
    toFaces: Array<Face>,
): boolean {
    // `== null` matches both null and undefined.
    if (ofFaces == null && toFaces == null) {
        return true;
    }
    return primitiveArrayEquals(
        ofFaces?.map((f) => f.id),
        toFaces?.map((f) => f.id),
    );
}
|
||||
|
||||
/**
 * Check whether two values are arrays of equal length whose elements are
 * strictly equal (`===`) at every index.
 *
 * @param a The first value.
 * @param b The second value.
 * @returns `false` if either value is not an array, otherwise whether the two
 * arrays match element by element.
 */
function primitiveArrayEquals(a: unknown, b: unknown): boolean {
    if (!Array.isArray(a) || !Array.isArray(b)) {
        return false;
    }
    if (a.length !== b.length) {
        return false;
    }
    return a.every((val, index) => val === b[index]);
}
|
||||
|
||||
/**
 * Check whether a versioned method is missing, differs from, or is older than
 * another one.
 *
 * @param method The method to check (may be nullish).
 * @param thanMethod The method to compare against.
 * @returns `true` if `method` is falsy, has a different `value`, or has a
 * lower `version` than `thanMethod`; `false` otherwise.
 */
export function isDifferentOrOld(
    method: Versioned<string>,
    thanMethod: Versioned<string>,
): boolean {
    if (!method) {
        return true;
    }
    return (
        method.value !== thanMethod.value ||
        method.version < thanMethod.version
    );
}
|
||||
|
||||
/**
 * Warp every aligned face out of an image into one contiguous float buffer.
 *
 * The buffer reserves `faceSize * faceSize * 3` floats per face. For the face
 * at index `i`, `warpAffineFloat32List` is asked to write starting at offset
 * `i * faceSize * faceSize * 3`.
 *
 * @param faceAlignments The alignments of the faces to extract.
 * @param faceSize The side length passed to the warp for each face.
 * @param image The image containing the faces.
 * @returns The buffer holding the data for all faces, in input order. It is
 * empty when `faceAlignments` is empty.
 */
async function extractFaceImagesToFloat32(
    faceAlignments: Array<FaceAlignment>,
    faceSize: number,
    image: ImageBitmap,
): Promise<Float32Array> {
    const floatsPerFace = faceSize * faceSize * 3;
    const faceData = new Float32Array(faceAlignments.length * floatsPerFace);
    faceAlignments.forEach((alignedFace, i) => {
        warpAffineFloat32List(
            image,
            alignedFace,
            faceSize,
            faceData,
            i * floatsPerFace,
        );
    });
    return faceData;
}
|
||||
|
|
|
@ -1,216 +0,0 @@
|
|||
import { haveWindow } from "@/next/env";
|
||||
import log from "@/next/log";
|
||||
import { ComlinkWorker } from "@/next/worker/comlink-worker";
|
||||
import { getDedicatedCryptoWorker } from "@ente/shared/crypto";
|
||||
import { DedicatedCryptoWorker } from "@ente/shared/crypto/internal/crypto.worker";
|
||||
import PQueue from "p-queue";
|
||||
import {
|
||||
BlurDetectionMethod,
|
||||
BlurDetectionService,
|
||||
ClusteringMethod,
|
||||
ClusteringService,
|
||||
Face,
|
||||
FaceAlignmentMethod,
|
||||
FaceAlignmentService,
|
||||
FaceCropMethod,
|
||||
FaceCropService,
|
||||
FaceDetectionMethod,
|
||||
FaceDetectionService,
|
||||
FaceEmbeddingMethod,
|
||||
FaceEmbeddingService,
|
||||
MLLibraryData,
|
||||
MLSyncConfig,
|
||||
MLSyncContext,
|
||||
} from "services/ml/types";
|
||||
import { EnteFile } from "types/file";
|
||||
import { logQueueStats } from "utils/machineLearning";
|
||||
import arcfaceAlignmentService from "./arcfaceAlignmentService";
|
||||
import arcfaceCropService from "./arcfaceCropService";
|
||||
import dbscanClusteringService from "./dbscanClusteringService";
|
||||
import hdbscanClusteringService from "./hdbscanClusteringService";
|
||||
import laplacianBlurDetectionService from "./laplacianBlurDetectionService";
|
||||
import mobileFaceNetEmbeddingService from "./mobileFaceNetEmbeddingService";
|
||||
import yoloFaceDetectionService from "./yoloFaceDetectionService";
|
||||
|
||||
/**
 * Static lookup from the method names used in {@link MLSyncConfig} to the
 * service singletons that implement them.
 *
 * Every getter throws an `Error` naming the offending method when the method
 * is not one of the values handled here.
 */
export class MLFactory {
    /** Resolve the face detection service for `method` ("YoloFace"). */
    public static getFaceDetectionService(
        method: FaceDetectionMethod,
    ): FaceDetectionService {
        if (method === "YoloFace") {
            return yoloFaceDetectionService;
        }

        throw new Error("Unknown face detection method: " + method);
    }

    /** Resolve the face crop service for `method` ("ArcFace"). */
    public static getFaceCropService(method: FaceCropMethod) {
        if (method === "ArcFace") {
            return arcfaceCropService;
        }

        throw new Error("Unknown face crop method: " + method);
    }

    /** Resolve the face alignment service for `method` ("ArcFace"). */
    public static getFaceAlignmentService(
        method: FaceAlignmentMethod,
    ): FaceAlignmentService {
        if (method === "ArcFace") {
            return arcfaceAlignmentService;
        }

        throw new Error("Unknown face alignment method: " + method);
    }

    /** Resolve the blur detection service for `method` ("Laplacian"). */
    public static getBlurDetectionService(
        method: BlurDetectionMethod,
    ): BlurDetectionService {
        if (method === "Laplacian") {
            return laplacianBlurDetectionService;
        }

        throw new Error("Unknown blur detection method: " + method);
    }

    /** Resolve the face embedding service for `method` ("MobileFaceNet"). */
    public static getFaceEmbeddingService(
        method: FaceEmbeddingMethod,
    ): FaceEmbeddingService {
        if (method === "MobileFaceNet") {
            return mobileFaceNetEmbeddingService;
        }

        throw new Error("Unknown face embedding method: " + method);
    }

    /** Resolve the clustering service for `method` ("Hdbscan" or "Dbscan"). */
    public static getClusteringService(
        method: ClusteringMethod,
    ): ClusteringService {
        if (method === "Hdbscan") {
            return hdbscanClusteringService;
        }
        if (method === "Dbscan") {
            return dbscanClusteringService;
        }

        throw new Error("Unknown clustering method: " + method);
    }

    /**
     * Create a fresh {@link LocalMLSyncContext} for one sync run.
     *
     * @param token Stored on the context as `token`.
     * @param userID Stored on the context as `userID`.
     * @param config Selects which services the context is wired with.
     * @param shouldUpdateMLVersion Stored on the context; defaults to `true`.
     */
    public static getMLSyncContext(
        token: string,
        userID: number,
        config: MLSyncConfig,
        shouldUpdateMLVersion: boolean = true,
    ) {
        return new LocalMLSyncContext(
            token,
            userID,
            config,
            shouldUpdateMLVersion,
        );
    }
}
|
||||
|
||||
/**
 * State shared by the steps of a single ML sync run: the credentials and
 * config it was started with, the services selected by that config, running
 * counters, the work queue, and a lazily filled pool of crypto workers.
 */
export class LocalMLSyncContext implements MLSyncContext {
    public token: string;
    public userID: number;
    public config: MLSyncConfig;
    public shouldUpdateMLVersion: boolean;

    public faceDetectionService: FaceDetectionService;
    public faceCropService: FaceCropService;
    public faceAlignmentService: FaceAlignmentService;
    public blurDetectionService: BlurDetectionService;
    public faceEmbeddingService: FaceEmbeddingService;
    public faceClusteringService: ClusteringService;

    public localFilesMap: Map<number, EnteFile>;
    public outOfSyncFiles: EnteFile[];
    public nSyncedFiles: number;
    public nSyncedFaces: number;
    public allSyncedFacesMap?: Map<number, Array<Face>>;

    public error?: Error;

    public mlLibraryData: MLLibraryData;

    public syncQueue: PQueue;
    // TODO: whether to limit concurrent downloads
    // private downloadQueue: PQueue;

    private concurrency: number;
    // Worker handles and their resolved remotes, indexed by the same slot.
    private comlinkCryptoWorker: Array<
        ComlinkWorker<typeof DedicatedCryptoWorker>
    >;
    private enteWorkers: Array<any>;

    /**
     * @param token Stored as `token`.
     * @param userID Stored as `userID`.
     * @param config Picks the service implementation for every ML stage.
     * @param shouldUpdateMLVersion Stored as-is; defaults to `true`.
     * @param concurrency Queue concurrency and worker pool size; when omitted
     * the value of `getConcurrency()` is used.
     */
    constructor(
        token: string,
        userID: number,
        config: MLSyncConfig,
        shouldUpdateMLVersion: boolean = true,
        concurrency?: number,
    ) {
        this.token = token;
        this.userID = userID;
        this.config = config;
        this.shouldUpdateMLVersion = shouldUpdateMLVersion;

        this.faceDetectionService = MLFactory.getFaceDetectionService(
            this.config.faceDetection.method,
        );
        this.faceCropService = MLFactory.getFaceCropService(
            this.config.faceCrop.method,
        );
        this.faceAlignmentService = MLFactory.getFaceAlignmentService(
            this.config.faceAlignment.method,
        );
        this.blurDetectionService = MLFactory.getBlurDetectionService(
            this.config.blurDetection.method,
        );
        this.faceEmbeddingService = MLFactory.getFaceEmbeddingService(
            this.config.faceEmbedding.method,
        );
        this.faceClusteringService = MLFactory.getClusteringService(
            this.config.faceClustering.method,
        );

        this.outOfSyncFiles = [];
        this.nSyncedFiles = 0;
        this.nSyncedFaces = 0;

        this.concurrency = concurrency ?? getConcurrency();

        log.info("Using concurrency: ", this.concurrency);
        // timeout is added on downloads
        // timeout on queue will keep the operation open till worker is terminated
        this.syncQueue = new PQueue({ concurrency: this.concurrency });
        logQueueStats(this.syncQueue, "sync");
        // this.downloadQueue = new PQueue({ concurrency: 1 });
        // logQueueStats(this.downloadQueue, 'download');

        this.comlinkCryptoWorker = new Array(this.concurrency);
        this.enteWorkers = new Array(this.concurrency);
    }

    /**
     * Return the crypto worker remote for slot `id % poolSize`, creating the
     * underlying worker the first time that slot is requested.
     */
    public async getEnteWorker(id: number): Promise<any> {
        const wid = id % this.enteWorkers.length;
        console.log("getEnteWorker: ", id, wid);

        // Decide whether to spawn based on the worker handle, which is stored
        // synchronously. The remote is only stored after an await, so keying
        // off it let two overlapping calls for the same slot each spawn a
        // worker, with the first handle overwritten and never terminated.
        if (!this.comlinkCryptoWorker[wid]) {
            this.comlinkCryptoWorker[wid] = getDedicatedCryptoWorker();
        }
        if (!this.enteWorkers[wid]) {
            this.enteWorkers[wid] = await this.comlinkCryptoWorker[wid].remote;
        }

        return this.enteWorkers[wid];
    }

    /**
     * Release the context: drop the local files map, wait for the sync queue
     * to go idle, detach its listeners and terminate every spawned worker.
     */
    public async dispose() {
        this.localFilesMap = undefined;
        await this.syncQueue.onIdle();
        this.syncQueue.removeAllListeners();
        for (const enteComlinkWorker of this.comlinkCryptoWorker) {
            // Slots that were never used are still empty.
            enteComlinkWorker?.terminate();
        }
    }
}
|
||||
|
||||
/**
 * Suggested number of parallel ML tasks: half the logical cores reported by
 * the browser, rounded up, and never less than 2.
 *
 * Returns `false` when there is no `window` (as reported by `haveWindow`).
 */
export const getConcurrency = () => {
    if (!haveWindow()) {
        return false;
    }
    const cores = navigator.hardwareConcurrency;
    // hardwareConcurrency is not implemented everywhere; an undefined value
    // would otherwise turn the whole expression into NaN.
    const usableCores = Number.isFinite(cores) ? cores : 0;
    return Math.max(2, Math.ceil(usableCores / 2));
};
|
|
@ -1,25 +1,330 @@
|
|||
import { haveWindow } from "@/next/env";
|
||||
import log from "@/next/log";
|
||||
import { ComlinkWorker } from "@/next/worker/comlink-worker";
|
||||
import { APPS } from "@ente/shared/apps/constants";
|
||||
import ComlinkCryptoWorker from "@ente/shared/crypto";
|
||||
import ComlinkCryptoWorker, {
|
||||
getDedicatedCryptoWorker,
|
||||
} from "@ente/shared/crypto";
|
||||
import { DedicatedCryptoWorker } from "@ente/shared/crypto/internal/crypto.worker";
|
||||
import { CustomError, parseUploadErrorCodes } from "@ente/shared/error";
|
||||
import { MAX_ML_SYNC_ERROR_COUNT } from "constants/mlConfig";
|
||||
import PQueue from "p-queue";
|
||||
import downloadManager from "services/download";
|
||||
import { putEmbedding } from "services/embeddingService";
|
||||
import { getLocalFiles } from "services/fileService";
|
||||
import mlIDbStorage, {
|
||||
ML_SEARCH_CONFIG_NAME,
|
||||
ML_SYNC_CONFIG_NAME,
|
||||
ML_SYNC_JOB_CONFIG_NAME,
|
||||
} from "services/ml/db";
|
||||
import {
|
||||
BlurDetectionMethod,
|
||||
BlurDetectionService,
|
||||
ClusteringMethod,
|
||||
ClusteringService,
|
||||
Face,
|
||||
FaceAlignmentMethod,
|
||||
FaceAlignmentService,
|
||||
FaceCropMethod,
|
||||
FaceCropService,
|
||||
FaceDetection,
|
||||
FaceDetectionMethod,
|
||||
FaceDetectionService,
|
||||
FaceEmbeddingMethod,
|
||||
FaceEmbeddingService,
|
||||
Landmark,
|
||||
MLLibraryData,
|
||||
MLSearchConfig,
|
||||
MLSyncConfig,
|
||||
MLSyncContext,
|
||||
MLSyncFileContext,
|
||||
MLSyncResult,
|
||||
MlFileData,
|
||||
} from "services/ml/types";
|
||||
import { EnteFile } from "types/file";
|
||||
import { getMLSyncConfig } from "utils/machineLearning/config";
|
||||
import { LocalFileMlDataToServerFileMl } from "utils/machineLearning/mldataMappers";
|
||||
import mlIDbStorage from "utils/storage/mlIDbStorage";
|
||||
import { isInternalUserForML } from "utils/user";
|
||||
import arcfaceAlignmentService from "./arcfaceAlignmentService";
|
||||
import arcfaceCropService from "./arcfaceCropService";
|
||||
import dbscanClusteringService from "./dbscanClusteringService";
|
||||
import FaceService from "./faceService";
|
||||
import { MLFactory } from "./machineLearningFactory";
|
||||
import hdbscanClusteringService from "./hdbscanClusteringService";
|
||||
import laplacianBlurDetectionService from "./laplacianBlurDetectionService";
|
||||
import type { JobConfig } from "./mlWorkManager";
|
||||
import mobileFaceNetEmbeddingService from "./mobileFaceNetEmbeddingService";
|
||||
import PeopleService from "./peopleService";
|
||||
import ReaderService from "./readerService";
|
||||
import yoloFaceDetectionService from "./yoloFaceDetectionService";
|
||||
|
||||
/**
 * Fallback scheduling parameters for the ML sync job, used by
 * `getMLSyncJobConfig` when nothing is stored.
 */
export const DEFAULT_ML_SYNC_JOB_CONFIG: JobConfig = {
    // Delay between runs, in seconds.
    intervalSec: 5,
    // TODO: finalize this after seeing effects on and from machine sleep
    maxItervalSec: 960,
    backoffMultiplier: 2,
};

/**
 * Fallback ML pipeline configuration, used by `getMLSyncConfig` when nothing
 * is stored. The `method` fields are the names `MLFactory` resolves.
 */
export const DEFAULT_ML_SYNC_CONFIG: MLSyncConfig = {
    batchSize: 200,
    imageSource: "Original",
    faceDetection: { method: "YoloFace" },
    faceCrop: {
        enabled: true,
        method: "ArcFace",
        padding: 0.25,
        maxSize: 256,
        blobOptions: { type: "image/jpeg", quality: 0.8 },
    },
    faceAlignment: { method: "ArcFace" },
    blurDetection: { method: "Laplacian", threshold: 15 },
    faceEmbedding: {
        method: "MobileFaceNet",
        faceSize: 112,
        generateTsne: true,
    },
    faceClustering: {
        method: "Hdbscan",
        minClusterSize: 3,
        minSamples: 5,
        clusterSelectionEpsilon: 0.6,
        clusterSelectionMethod: "leaf",
        minInputSize: 50,
        // maxDistanceInsideCluster: 0.4,
        generateDebugInfo: true,
    },
    mlVersion: 3,
};

/** Fallback ML search configuration: search is off. */
export const DEFAULT_ML_SEARCH_CONFIG: MLSearchConfig = { enabled: false };

// NOTE(review): presumably the number of failures tolerated per file before
// it is no longer retried - the consumer is outside this view.
export const MAX_ML_SYNC_ERROR_COUNT = 1;
|
||||
|
||||
/**
 * Look up the ML sync job config in the ML IndexedDB store, passing
 * `DEFAULT_ML_SYNC_JOB_CONFIG` as the default.
 */
export async function getMLSyncJobConfig() {
    const jobConfig = mlIDbStorage.getConfig(
        ML_SYNC_JOB_CONFIG_NAME,
        DEFAULT_ML_SYNC_JOB_CONFIG,
    );
    return jobConfig;
}
|
||||
|
||||
/**
 * Look up the ML sync config in the ML IndexedDB store, passing
 * `DEFAULT_ML_SYNC_CONFIG` as the default.
 */
export async function getMLSyncConfig() {
    const syncConfig = mlIDbStorage.getConfig(
        ML_SYNC_CONFIG_NAME,
        DEFAULT_ML_SYNC_CONFIG,
    );
    return syncConfig;
}
|
||||
|
||||
/**
 * Resolve the ML search config.
 *
 * Only users for whom `isInternalUserForML()` is true get the stored value;
 * everyone else always receives `DEFAULT_ML_SEARCH_CONFIG`.
 */
export async function getMLSearchConfig() {
    if (!isInternalUserForML()) {
        // Force disabled for everyone else while we finalize it to avoid
        // redundant reindexing for users.
        return DEFAULT_ML_SEARCH_CONFIG;
    }
    return mlIDbStorage.getConfig(
        ML_SEARCH_CONFIG_NAME,
        DEFAULT_ML_SEARCH_CONFIG,
    );
}
|
||||
|
||||
/** Store `newConfig` as the ML sync job config in the ML IndexedDB store. */
export async function updateMLSyncJobConfig(newConfig: JobConfig) {
    const stored = mlIDbStorage.putConfig(ML_SYNC_JOB_CONFIG_NAME, newConfig);
    return stored;
}
|
||||
|
||||
/** Store `newConfig` as the ML sync config in the ML IndexedDB store. */
export async function updateMLSyncConfig(newConfig: MLSyncConfig) {
    const stored = mlIDbStorage.putConfig(ML_SYNC_CONFIG_NAME, newConfig);
    return stored;
}
|
||||
|
||||
/** Store `newConfig` as the ML search config in the ML IndexedDB store. */
export async function updateMLSearchConfig(newConfig: MLSearchConfig) {
    const stored = mlIDbStorage.putConfig(ML_SEARCH_CONFIG_NAME, newConfig);
    return stored;
}
|
||||
|
||||
/**
 * Static lookup from the method names used in {@link MLSyncConfig} to the
 * service singletons that implement them.
 *
 * Every getter throws an `Error` naming the offending method when the method
 * is not one of the values handled here.
 */
export class MLFactory {
    /** Resolve the face detection service for `method` ("YoloFace"). */
    public static getFaceDetectionService(
        method: FaceDetectionMethod,
    ): FaceDetectionService {
        if (method === "YoloFace") {
            return yoloFaceDetectionService;
        }

        throw new Error("Unknown face detection method: " + method);
    }

    /** Resolve the face crop service for `method` ("ArcFace"). */
    public static getFaceCropService(method: FaceCropMethod) {
        if (method === "ArcFace") {
            return arcfaceCropService;
        }

        throw new Error("Unknown face crop method: " + method);
    }

    /** Resolve the face alignment service for `method` ("ArcFace"). */
    public static getFaceAlignmentService(
        method: FaceAlignmentMethod,
    ): FaceAlignmentService {
        if (method === "ArcFace") {
            return arcfaceAlignmentService;
        }

        throw new Error("Unknown face alignment method: " + method);
    }

    /** Resolve the blur detection service for `method` ("Laplacian"). */
    public static getBlurDetectionService(
        method: BlurDetectionMethod,
    ): BlurDetectionService {
        if (method === "Laplacian") {
            return laplacianBlurDetectionService;
        }

        throw new Error("Unknown blur detection method: " + method);
    }

    /** Resolve the face embedding service for `method` ("MobileFaceNet"). */
    public static getFaceEmbeddingService(
        method: FaceEmbeddingMethod,
    ): FaceEmbeddingService {
        if (method === "MobileFaceNet") {
            return mobileFaceNetEmbeddingService;
        }

        throw new Error("Unknown face embedding method: " + method);
    }

    /** Resolve the clustering service for `method` ("Hdbscan" or "Dbscan"). */
    public static getClusteringService(
        method: ClusteringMethod,
    ): ClusteringService {
        if (method === "Hdbscan") {
            return hdbscanClusteringService;
        }
        if (method === "Dbscan") {
            return dbscanClusteringService;
        }

        throw new Error("Unknown clustering method: " + method);
    }

    /**
     * Create a fresh {@link LocalMLSyncContext} for one sync run.
     *
     * @param token Stored on the context as `token`.
     * @param userID Stored on the context as `userID`.
     * @param config Selects which services the context is wired with.
     * @param shouldUpdateMLVersion Stored on the context; defaults to `true`.
     */
    public static getMLSyncContext(
        token: string,
        userID: number,
        config: MLSyncConfig,
        shouldUpdateMLVersion: boolean = true,
    ) {
        return new LocalMLSyncContext(
            token,
            userID,
            config,
            shouldUpdateMLVersion,
        );
    }
}
|
||||
|
||||
/**
 * State shared by the steps of a single ML sync run: the credentials and
 * config it was started with, the services selected by that config, running
 * counters, the work queue, and a lazily filled pool of crypto workers.
 */
export class LocalMLSyncContext implements MLSyncContext {
    public token: string;
    public userID: number;
    public config: MLSyncConfig;
    public shouldUpdateMLVersion: boolean;

    public faceDetectionService: FaceDetectionService;
    public faceCropService: FaceCropService;
    public faceAlignmentService: FaceAlignmentService;
    public blurDetectionService: BlurDetectionService;
    public faceEmbeddingService: FaceEmbeddingService;
    public faceClusteringService: ClusteringService;

    public localFilesMap: Map<number, EnteFile>;
    public outOfSyncFiles: EnteFile[];
    public nSyncedFiles: number;
    public nSyncedFaces: number;
    public allSyncedFacesMap?: Map<number, Array<Face>>;

    public error?: Error;

    public mlLibraryData: MLLibraryData;

    public syncQueue: PQueue;
    // TODO: whether to limit concurrent downloads
    // private downloadQueue: PQueue;

    private concurrency: number;
    // Worker handles and their resolved remotes, indexed by the same slot.
    private comlinkCryptoWorker: Array<
        ComlinkWorker<typeof DedicatedCryptoWorker>
    >;
    private enteWorkers: Array<any>;

    /**
     * @param token Stored as `token`.
     * @param userID Stored as `userID`.
     * @param config Picks the service implementation for every ML stage.
     * @param shouldUpdateMLVersion Stored as-is; defaults to `true`.
     * @param concurrency Queue concurrency and worker pool size; when omitted
     * the value of `getConcurrency()` is used.
     */
    constructor(
        token: string,
        userID: number,
        config: MLSyncConfig,
        shouldUpdateMLVersion: boolean = true,
        concurrency?: number,
    ) {
        this.token = token;
        this.userID = userID;
        this.config = config;
        this.shouldUpdateMLVersion = shouldUpdateMLVersion;

        this.faceDetectionService = MLFactory.getFaceDetectionService(
            this.config.faceDetection.method,
        );
        this.faceCropService = MLFactory.getFaceCropService(
            this.config.faceCrop.method,
        );
        this.faceAlignmentService = MLFactory.getFaceAlignmentService(
            this.config.faceAlignment.method,
        );
        this.blurDetectionService = MLFactory.getBlurDetectionService(
            this.config.blurDetection.method,
        );
        this.faceEmbeddingService = MLFactory.getFaceEmbeddingService(
            this.config.faceEmbedding.method,
        );
        this.faceClusteringService = MLFactory.getClusteringService(
            this.config.faceClustering.method,
        );

        this.outOfSyncFiles = [];
        this.nSyncedFiles = 0;
        this.nSyncedFaces = 0;

        this.concurrency = concurrency ?? getConcurrency();

        log.info("Using concurrency: ", this.concurrency);
        // timeout is added on downloads
        // timeout on queue will keep the operation open till worker is terminated
        this.syncQueue = new PQueue({ concurrency: this.concurrency });
        logQueueStats(this.syncQueue, "sync");
        // this.downloadQueue = new PQueue({ concurrency: 1 });
        // logQueueStats(this.downloadQueue, 'download');

        this.comlinkCryptoWorker = new Array(this.concurrency);
        this.enteWorkers = new Array(this.concurrency);
    }

    /**
     * Return the crypto worker remote for slot `id % poolSize`, creating the
     * underlying worker the first time that slot is requested.
     */
    public async getEnteWorker(id: number): Promise<any> {
        const wid = id % this.enteWorkers.length;
        console.log("getEnteWorker: ", id, wid);

        // Decide whether to spawn based on the worker handle, which is stored
        // synchronously. The remote is only stored after an await, so keying
        // off it let two overlapping calls for the same slot each spawn a
        // worker, with the first handle overwritten and never terminated.
        if (!this.comlinkCryptoWorker[wid]) {
            this.comlinkCryptoWorker[wid] = getDedicatedCryptoWorker();
        }
        if (!this.enteWorkers[wid]) {
            this.enteWorkers[wid] = await this.comlinkCryptoWorker[wid].remote;
        }

        return this.enteWorkers[wid];
    }

    /**
     * Release the context: drop the local files map, wait for the sync queue
     * to go idle, detach its listeners and terminate every spawned worker.
     */
    public async dispose() {
        this.localFilesMap = undefined;
        await this.syncQueue.onIdle();
        this.syncQueue.removeAllListeners();
        for (const enteComlinkWorker of this.comlinkCryptoWorker) {
            // Slots that were never used are still empty.
            enteComlinkWorker?.terminate();
        }
    }
}
|
||||
|
||||
/**
 * Suggested number of parallel ML tasks: half the logical cores reported by
 * the browser, rounded up, and never less than 2.
 *
 * Returns `false` when there is no `window` (as reported by `haveWindow`).
 */
export const getConcurrency = () => {
    if (!haveWindow()) {
        return false;
    }
    const cores = navigator.hardwareConcurrency;
    // hardwareConcurrency is not implemented everywhere; an undefined value
    // would otherwise turn the whole expression into NaN.
    const usableCores = Number.isFinite(cores) ? cores : 0;
    return Math.max(2, Math.ceil(usableCores / 2));
};
|
||||
|
||||
class MachineLearningService {
|
||||
private localSyncContext: Promise<MLSyncContext>;
|
||||
|
@ -445,3 +750,160 @@ class MachineLearningService {
|
|||
}
|
||||
|
||||
export default new MachineLearningService();
|
||||
|
||||
/** A {@link ServerFileMl} payload together with an `updatedAt` value. */
export interface FileML extends ServerFileMl {
    // NOTE(review): presumably a timestamp; the unit is not visible here.
    updatedAt: number;
}
|
||||
|
||||
/**
 * Per-file ML payload: the file id, optional image dimensions and the face
 * embeddings computed for the file.
 */
class ServerFileMl {
    fileID: number;
    height?: number;
    width?: number;
    faceEmbedding: ServerFaceEmbeddings;

    /**
     * Note the argument order: the embeddings come second, ahead of the
     * optional `height` and `width`.
     */
    constructor(
        fileID: number,
        faceEmbedding: ServerFaceEmbeddings,
        height?: number,
        width?: number,
    ) {
        this.fileID = fileID;
        this.height = height;
        this.width = width;
        this.faceEmbedding = faceEmbedding;
    }
}
|
||||
|
||||
/**
 * The faces found in one file plus a version number, with an optional
 * `client` string and `error` flag.
 */
class ServerFaceEmbeddings {
    faces: ServerFace[];
    version: number;
    client?: string;
    error?: boolean;

    constructor(
        faces: ServerFace[],
        version: number,
        client?: string,
        error?: boolean,
    ) {
        this.faces = faces;
        this.version = version;
        this.client = client;
        this.error = error;
    }
}
|
||||
|
||||
/**
 * One detected face: its id, embedding vector, detection geometry, detection
 * score and blur value.
 */
class ServerFace {
    faceID: string;
    embeddings: number[];
    detection: ServerDetection;
    score: number;
    blur: number;

    constructor(
        faceID: string,
        embeddings: number[],
        detection: ServerDetection,
        score: number,
        blur: number,
    ) {
        this.faceID = faceID;
        this.embeddings = embeddings;
        this.detection = detection;
        this.score = score;
        this.blur = blur;
    }
}
|
||||
|
||||
/** Detection geometry of a face: its bounding box and landmark points. */
class ServerDetection {
    box: ServerFaceBox;
    landmarks: Landmark[];

    constructor(box: ServerFaceBox, landmarks: Landmark[]) {
        this.box = box;
        this.landmarks = landmarks;
    }
}
|
||||
|
||||
/** Face bounding box given as its minimum corner plus width and height. */
class ServerFaceBox {
    xMin: number;
    yMin: number;
    width: number;
    height: number;

    constructor(xMin: number, yMin: number, width: number, height: number) {
        this.xMin = xMin;
        this.yMin = yMin;
        this.width = width;
        this.height = height;
    }
}
|
||||
|
||||
/**
 * Convert the locally stored ML data of a file into the server payload shape.
 *
 * @param localFileMlData Local ML result for one file.
 * @returns The payload, or `null` when the local data records a failure
 * (`errorCount > 0` together with a `lastErrorMessage`).
 */
function LocalFileMlDataToServerFileMl(
    localFileMlData: MlFileData,
): ServerFileMl {
    const recordsFailure =
        localFileMlData.errorCount > 0 &&
        localFileMlData.lastErrorMessage !== undefined;
    if (recordsFailure) {
        return null;
    }
    const imageDimensions = localFileMlData.imageDimensions;

    const faces: ServerFace[] = [];
    for (const face of localFileMlData.faces) {
        const faceID = face.id;
        const embedding = face.embedding;
        const score = face.detection.probability;
        const blur = face.blurValue;
        const detection: FaceDetection = face.detection;

        const { box, landmarks } = detection;
        const serverBox = new ServerFaceBox(
            box.x,
            box.y,
            box.width,
            box.height,
        );
        // Keep only the coordinates of each landmark.
        const serverLandmarks: Landmark[] = Array.from(
            landmarks,
            (landmark) => ({ x: landmark.x, y: landmark.y }) as Landmark,
        );

        faces.push(
            new ServerFace(
                faceID,
                // Copy the embedding into a plain number array.
                Array.from(embedding),
                new ServerDetection(serverBox, serverLandmarks),
                score,
                blur,
            ),
        );
    }

    // NOTE(review): the third constructor argument of ServerFaceEmbeddings is
    // `client`, yet the last error message is what gets passed - confirm this
    // is intended.
    const faceEmbeddings = new ServerFaceEmbeddings(
        faces,
        1,
        localFileMlData.lastErrorMessage,
    );
    return new ServerFileMl(
        localFileMlData.fileId,
        faceEmbeddings,
        imageDimensions.height,
        imageDimensions.width,
    );
}
|
||||
|
||||
/**
 * Attach logging listeners to `queue` for its "active", "idle" and "error"
 * events, labelling every message with `name`.
 */
export function logQueueStats(queue: PQueue, name: string) {
    const onActive = () =>
        log.info(
            `queuestats: ${name}: Active, Size: ${queue.size} Pending: ${queue.pending}`,
        );
    const onIdle = () => log.info(`queuestats: ${name}: Idle`);
    // Errors go to the console rather than through `log`.
    const onError = (error) =>
        console.error(`queuestats: ${name}: Error, `, error);

    queue.on("active", onActive);
    queue.on("idle", onIdle);
    queue.on("error", onError);
}
|
||||
|
|
|
@ -5,20 +5,110 @@ import { eventBus, Events } from "@ente/shared/events";
|
|||
import { getToken, getUserID } from "@ente/shared/storage/localStorage/helpers";
|
||||
import debounce from "debounce";
|
||||
import PQueue from "p-queue";
|
||||
import { getMLSyncJobConfig } from "services/machineLearning/machineLearningService";
|
||||
import mlIDbStorage from "services/ml/db";
|
||||
import { MLSyncResult } from "services/ml/types";
|
||||
import { JobResult } from "types/common/job";
|
||||
import { EnteFile } from "types/file";
|
||||
import { getDedicatedMLWorker } from "utils/comlink/ComlinkMLWorker";
|
||||
import { SimpleJob } from "utils/common/job";
|
||||
import { logQueueStats } from "utils/machineLearning";
|
||||
import { getMLSyncJobConfig } from "utils/machineLearning/config";
|
||||
import mlIDbStorage from "utils/storage/mlIDbStorage";
|
||||
import { DedicatedMLWorker } from "worker/ml.worker";
|
||||
import { logQueueStats } from "./machineLearningService";
|
||||
|
||||
// Timing knobs for live sync, in seconds per the `_SEC` suffix.
// NOTE(review): their consumers are outside this view - confirm usage there.
const LIVE_SYNC_IDLE_DEBOUNCE_SEC = 30;
const LIVE_SYNC_QUEUE_TIMEOUT_SEC = 300;
const LOCAL_FILES_UPDATED_DEBOUNCE_SEC = 30;
|
||||
|
||||
/** Lifecycle of a job: waiting on a timer, executing, or neither. */
export type JobState = "Scheduled" | "Running" | "NotScheduled";
|
||||
|
||||
/** Scheduling parameters for a repeating job. */
export interface JobConfig {
    /** Base delay between runs, in seconds. */
    intervalSec: number;
    /** Upper bound for the delay once backoff has been applied, in seconds. */
    maxItervalSec: number;
    /** Factor the delay is multiplied by when a run asks to back off. */
    backoffMultiplier: number;
}
|
||||
|
||||
/** Outcome reported by a job run. */
export interface JobResult {
    /** When true, the delay before the next run is lengthened. */
    shouldBackoff: boolean;
}
|
||||
|
||||
/**
 * Repeatedly runs an async callback on a timer, with exponential backoff.
 *
 * After each run the next one is scheduled `intervalSec` seconds later. A
 * result with `shouldBackoff` set multiplies the interval by
 * `backoffMultiplier`, capped at `maxItervalSec`; any other result resets it
 * to the configured base interval.
 */
export class SimpleJob<R extends JobResult> {
    private config: JobConfig;
    private runCallback: () => Promise<R>;
    private state: JobState;
    private stopped: boolean;
    private intervalSec: number;
    private nextTimeoutId: ReturnType<typeof setTimeout> | undefined;

    /**
     * @param config Base interval, maximum interval and backoff factor.
     * @param runCallback Work to perform on every run.
     */
    constructor(config: JobConfig, runCallback: () => Promise<R>) {
        this.config = config;
        this.runCallback = runCallback;
        this.state = "NotScheduled";
        this.stopped = true;
        this.intervalSec = this.config.intervalSec;
    }

    /** Drop any accumulated backoff and return to the base interval. */
    public resetInterval() {
        this.intervalSec = this.config.intervalSec;
    }

    /**
     * Start (or restart) the job with the base interval. When a run is in
     * progress nothing is scheduled here; that run schedules the next one
     * when it finishes.
     */
    public start() {
        this.stopped = false;
        this.resetInterval();
        if (this.state !== "Running") {
            this.scheduleNext();
        } else {
            log.info("Job already running, not scheduling");
        }
    }

    /** Arm the timer for the next run, replacing any pending one. */
    private scheduleNext() {
        if (this.state === "Scheduled" || this.nextTimeoutId) {
            this.clearScheduled();
        }

        this.nextTimeoutId = setTimeout(
            () => this.run(),
            this.intervalSec * 1000,
        );
        this.state = "Scheduled";
        log.info("Scheduled next job after: ", this.intervalSec);
    }

    /**
     * Execute the callback once, then schedule the next run unless the job
     * has been stopped. Errors from the callback are logged and do not
     * change the interval.
     */
    async run() {
        // Cancel a still-pending timer before dropping its handle. Without
        // this, calling run() directly while a run is scheduled leaves an
        // orphaned timer that fires later, even after stop(). When run() is
        // invoked by the timer itself this is a no-op.
        clearTimeout(this.nextTimeoutId);
        this.nextTimeoutId = undefined;
        this.state = "Running";

        try {
            const jobResult = await this.runCallback();
            if (jobResult && jobResult.shouldBackoff) {
                this.intervalSec = Math.min(
                    this.config.maxItervalSec,
                    this.intervalSec * this.config.backoffMultiplier,
                );
            } else {
                this.resetInterval();
            }
            log.info("Job completed");
        } catch (e) {
            console.error("Error while running Job: ", e);
        } finally {
            this.state = "NotScheduled";
            if (!this.stopped) {
                this.scheduleNext();
            }
        }
    }

    // currently client is responsible to terminate running job
    /** Stop scheduling further runs and cancel the pending timer, if any. */
    public stop() {
        this.stopped = true;
        this.clearScheduled();
    }

    /** Cancel the pending timer and mark the job as not scheduled. */
    private clearScheduled() {
        clearTimeout(this.nextTimeoutId);
        this.nextTimeoutId = undefined;
        this.state = "NotScheduled";
        log.info("Cleared next job");
    }
}
|
||||
|
||||
/** A {@link JobResult} that also carries the result of the ML sync run. */
export interface MLSyncJobResult extends JobResult {
    mlSyncResult: MLSyncResult;
}
|
||||
|
|
|
@ -1,14 +1,8 @@
|
|||
import log from "@/next/log";
|
||||
import mlIDbStorage from "services/ml/db";
|
||||
import { Face, MLSyncContext, Person } from "services/ml/types";
|
||||
import {
|
||||
findFirstIfSorted,
|
||||
getAllFacesFromMap,
|
||||
getLocalFile,
|
||||
getOriginalImageBitmap,
|
||||
isDifferentOrOld,
|
||||
} from "utils/machineLearning";
|
||||
import mlIDbStorage from "utils/storage/mlIDbStorage";
|
||||
import FaceService from "./faceService";
|
||||
import FaceService, { isDifferentOrOld } from "./faceService";
|
||||
import { getLocalFile, getOriginalImageBitmap } from "./readerService";
|
||||
|
||||
class PeopleService {
|
||||
async syncPeopleIndex(syncContext: MLSyncContext) {
|
||||
|
@ -92,3 +86,28 @@ class PeopleService {
|
|||
}
|
||||
|
||||
export default new PeopleService();
|
||||
|
||||
/**
 * Return the element that would come first if `elements` were sorted with
 * `comparator`, without actually sorting.
 *
 * Among elements that compare equal, the earliest one wins.
 *
 * @returns The first element in sort order, or `undefined` when `elements`
 * is missing or empty.
 */
function findFirstIfSorted<T>(
    elements: Array<T>,
    comparator: (a: T, b: T) => number,
) {
    if (!elements || elements.length < 1) {
        return;
    }

    let best = elements[0];
    for (const candidate of elements.slice(1)) {
        // Only a strictly smaller candidate replaces the current best.
        if (comparator(candidate, best) < 0) {
            best = candidate;
        }
    }
    return best;
}
|
||||
|
||||
/**
 * Flatten the per-key face arrays of `allFacesMap` into one array, in the
 * map's iteration order.
 */
function getAllFacesFromMap(allFacesMap: Map<number, Array<Face>>) {
    return Array.from(allFacesMap.values()).flat();
}
|
||||
|
|
|
@ -1,11 +1,18 @@
|
|||
import { FILE_TYPE } from "@/media/file-type";
|
||||
import { decodeLivePhoto } from "@/media/live-photo";
|
||||
import log from "@/next/log";
|
||||
import { MLSyncContext, MLSyncFileContext } from "services/ml/types";
|
||||
import PQueue from "p-queue";
|
||||
import DownloadManager from "services/download";
|
||||
import { getLocalFiles } from "services/fileService";
|
||||
import { Dimensions } from "services/ml/geom";
|
||||
import {
|
||||
getLocalFileImageBitmap,
|
||||
getOriginalImageBitmap,
|
||||
getThumbnailImageBitmap,
|
||||
} from "utils/machineLearning";
|
||||
DetectedFace,
|
||||
MLSyncContext,
|
||||
MLSyncFileContext,
|
||||
} from "services/ml/types";
|
||||
import { EnteFile } from "types/file";
|
||||
import { getRenderableImage } from "utils/file";
|
||||
import { clamp } from "utils/image";
|
||||
|
||||
class ReaderService {
|
||||
async getImageBitmap(
|
||||
|
@ -55,3 +62,95 @@ class ReaderService {
|
|||
}
|
||||
}
|
||||
export default new ReaderService();
|
||||
|
||||
/**
 * Find the local file with the given id.
 *
 * Fetches the full list from `getLocalFiles()` on every call.
 *
 * @returns The matching file, or `undefined` when none has that id.
 */
export async function getLocalFile(fileId: number) {
    const candidates = await getLocalFiles();
    return candidates.find((candidate) => candidate.id === fileId);
}
|
||||
|
||||
/**
 * Build the id of a detected face from its file id and bounding box.
 *
 * Each box edge is expressed as a fraction of the image size, clamped to
 * [0, 0.999999], formatted with five decimals, and stripped of its leading
 * "0." - giving `<fileId>_<xMin>_<yMin>_<xMax>_<yMax>`.
 */
export function getFaceId(detectedFace: DetectedFace, imageDims: Dimensions) {
    // NOTE(review): a fraction that rounds up to "1.00000" comes out as
    // "00000". Left as-is; changing the format would change existing ids.
    const encode = (fraction: number) =>
        clamp(fraction, 0.0, 0.999999).toFixed(5).substring(2);

    const { box } = detectedFace.detection;
    const xMin = encode(box.x / imageDims.width);
    const yMin = encode(box.y / imageDims.height);
    const xMax = encode((box.x + box.width) / imageDims.width);
    const yMax = encode((box.y + box.height) / imageDims.height);

    const rawFaceID = `${xMin}_${yMin}_${xMax}_${yMax}`;
    const faceID = `${detectedFace.fileId}_${rawFaceID}`;

    return faceID;
}
|
||||
|
||||
/** Decode `blob` into an ImageBitmap using `createImageBitmap`. */
async function getImageBlobBitmap(blob: Blob): Promise<ImageBitmap> {
    const bitmap = await createImageBitmap(blob);
    return bitmap;
}
|
||||
|
||||
/**
 * Fetch the original contents of `file` as a Blob.
 *
 * @param file File to fetch through `DownloadManager.getFile`.
 * @param queue When given, the fetch is submitted through this queue
 * instead of being started directly.
 */
async function getOriginalFile(file: EnteFile, queue?: PQueue) {
    let fileStream;
    if (!queue) {
        fileStream = await DownloadManager.getFile(file);
    } else {
        fileStream = await queue.add(() => DownloadManager.getFile(file));
    }
    // Wrap the stream in a Response to read it out as a Blob.
    return new Response(fileStream).blob();
}
|
||||
|
||||
/**
 * Fetch the original of `file` and convert it with `getRenderableImage`.
 *
 * Files whose type is `FILE_TYPE.IMAGE` are converted directly. Every other
 * type is decoded as a live photo first and its image component is converted.
 */
async function getOriginalConvertedFile(file: EnteFile, queue?: PQueue) {
    const fileBlob = await getOriginalFile(file, queue);
    if (file.metadata.fileType !== FILE_TYPE.IMAGE) {
        const { imageFileName, imageData } = await decodeLivePhoto(
            file.metadata.title,
            fileBlob,
        );
        return await getRenderableImage(imageFileName, new Blob([imageData]));
    }
    return await getRenderableImage(file.metadata.title, fileBlob);
}
|
||||
|
||||
/**
 * Fetch the original of `file`, convert it, and decode it into an
 * ImageBitmap.
 *
 * @param queue Optional queue the fetch is routed through.
 */
export async function getOriginalImageBitmap(file: EnteFile, queue?: PQueue) {
    const renderable = await getOriginalConvertedFile(file, queue);
    log.info("[MLService] Got file: ", file.id.toString());
    return getImageBlobBitmap(renderable);
}
|
||||
|
||||
/** Fetch the thumbnail of `file` and decode it into an ImageBitmap. */
export async function getThumbnailImageBitmap(file: EnteFile) {
    const thumbnailData = await DownloadManager.getThumbnail(file);
    log.info("[MLService] Got thumbnail: ", file.id.toString());

    const thumbnailBlob = new Blob([thumbnailData]);
    return getImageBlobBitmap(thumbnailBlob);
}
|
||||
|
||||
/**
 * Decode a file that is already available locally into an ImageBitmap.
 *
 * @param enteFile Supplies the title passed to `getRenderableImage`.
 * @param localFile The local file contents.
 */
export async function getLocalFileImageBitmap(
    enteFile: EnteFile,
    localFile: globalThis.File,
) {
    const renderable = await getRenderableImage(
        enteFile.metadata.title,
        localFile as Blob,
    );
    return getImageBlobBitmap(renderable);
}
|
||||
|
|
|
@ -323,14 +323,10 @@ function transformBox(box: Box, transform: Matrix) {
|
|||
const topLeft = transformPoint(box.topLeft, transform);
|
||||
const bottomRight = transformPoint(box.bottomRight, transform);
|
||||
|
||||
return newBoxFromPoints(topLeft.x, topLeft.y, bottomRight.x, bottomRight.y);
|
||||
}
|
||||
|
||||
function newBoxFromPoints(
|
||||
left: number,
|
||||
top: number,
|
||||
right: number,
|
||||
bottom: number,
|
||||
) {
|
||||
return boxFromBoundingBox({ left, top, right, bottom });
|
||||
return boxFromBoundingBox({
|
||||
left: topLeft.x,
|
||||
top: topLeft.y,
|
||||
right: bottomRight.x,
|
||||
bottom: bottomRight.y,
|
||||
});
|
||||
}
|
||||
|
|
|
@ -1,11 +1,5 @@
|
|||
import { haveWindow } from "@/next/env";
|
||||
import log from "@/next/log";
|
||||
import {
|
||||
DEFAULT_ML_SEARCH_CONFIG,
|
||||
DEFAULT_ML_SYNC_CONFIG,
|
||||
DEFAULT_ML_SYNC_JOB_CONFIG,
|
||||
MAX_ML_SYNC_ERROR_COUNT,
|
||||
} from "constants/mlConfig";
|
||||
import {
|
||||
DBSchema,
|
||||
IDBPDatabase,
|
||||
|
@ -15,6 +9,12 @@ import {
|
|||
openDB,
|
||||
} from "idb";
|
||||
import isElectron from "is-electron";
|
||||
import {
|
||||
DEFAULT_ML_SEARCH_CONFIG,
|
||||
DEFAULT_ML_SYNC_CONFIG,
|
||||
DEFAULT_ML_SYNC_JOB_CONFIG,
|
||||
MAX_ML_SYNC_ERROR_COUNT,
|
||||
} from "services/machineLearning/machineLearningService";
|
||||
import { Face, MLLibraryData, MlFileData, Person } from "services/ml/types";
|
||||
|
||||
export interface IndexStatus {
|
|
@ -76,3 +76,17 @@ export class Box implements IRect {
|
|||
return new Box({ x, y, width, height });
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Return a new box that shares its center with {@link box} but whose width
 * and height are scaled by {@link factor}.
 *
 * @param box The box to enlarge.
 * @param factor Scale applied to both dimensions (defaults to 1.5).
 */
export function enlargeBox(box: Box, factor: number = 1.5) {
    const midpoint = new Point(box.x + box.width / 2, box.y + box.height / 2);

    const extent = new Point(box.width, box.height);
    const halfExtent = new Point(
        (factor * extent.x) / 2,
        (factor * extent.y) / 2,
    );

    const left = midpoint.x - halfExtent.x;
    const top = midpoint.y - halfExtent.y;
    const right = midpoint.x + halfExtent.x;
    const bottom = midpoint.y + halfExtent.y;

    return boxFromBoundingBox({ left, top, right, bottom });
}
|
||||
|
|
|
@ -329,8 +329,3 @@ export interface MachineLearningWorker {
|
|||
|
||||
close(): void;
|
||||
}
|
||||
|
||||
/**
 * A CLIP embedding paired with the identifier of the model variant that it
 * is associated with.
 */
export interface ClipEmbedding {
    /** The embedding vector. */
    embedding: Float32Array;
    /** Which of the known CLIP model variants this embedding belongs to. */
    model: "ggml-clip" | "onnx-clip";
}
|
||||
|
|
|
@ -2,6 +2,8 @@ import { FILE_TYPE } from "@/media/file-type";
|
|||
import log from "@/next/log";
|
||||
import * as chrono from "chrono-node";
|
||||
import { t } from "i18next";
|
||||
import { getMLSyncConfig } from "services/machineLearning/machineLearningService";
|
||||
import mlIDbStorage from "services/ml/db";
|
||||
import { Person } from "services/ml/types";
|
||||
import { Collection } from "types/collection";
|
||||
import { EntityType, LocationTag, LocationTagData } from "types/entity";
|
||||
|
@ -16,12 +18,9 @@ import {
|
|||
} from "types/search";
|
||||
import ComlinkSearchWorker from "utils/comlink/ComlinkSearchWorker";
|
||||
import { getUniqueFiles } from "utils/file";
|
||||
import { getAllPeople } from "utils/machineLearning";
|
||||
import { getMLSyncConfig } from "utils/machineLearning/config";
|
||||
import { getFormattedDate } from "utils/search";
|
||||
import mlIDbStorage from "utils/storage/mlIDbStorage";
|
||||
import { clipService, computeClipMatchScore } from "./clip-service";
|
||||
import { getLocalEmbeddings } from "./embeddingService";
|
||||
import { localCLIPEmbeddings } from "./embeddingService";
|
||||
import { getLatestEntities } from "./entityService";
|
||||
import locationSearchService, { City } from "./locationSearchService";
|
||||
|
||||
|
@ -376,7 +375,7 @@ const searchClip = async (
|
|||
await clipService.getTextEmbeddingIfAvailable(searchPhrase);
|
||||
if (!textEmbedding) return undefined;
|
||||
|
||||
const imageEmbeddings = await getLocalEmbeddings();
|
||||
const imageEmbeddings = await localCLIPEmbeddings();
|
||||
const clipSearchResult = new Map<number, number>(
|
||||
(
|
||||
await Promise.all(
|
||||
|
@ -430,3 +429,14 @@ function convertSuggestionToSearchQuery(option: Suggestion): Search {
|
|||
return { clip: option.value as ClipSearchScores };
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Return the people stored in the local ML database, ordered by the number
 * of files they appear in (most files first).
 *
 * @param limit If provided, at most this many people are returned. When
 * omitted, everyone is returned.
 */
async function getAllPeople(limit?: number) {
    // The storage layer may hand back nothing; treat that as "no people".
    const people: Array<Person> = (await mlIDbStorage.getAllPeople()) ?? [];
    // Sort a copy so that the array returned by the storage layer is not
    // reordered behind its back.
    return [...people]
        .sort((p1, p2) => p2.files.length - p1.files.length)
        .slice(0, limit);
}
|
||||
|
|
|
@ -1,11 +0,0 @@
|
|||
/** The scheduling states that a job can be in. */
export type JobState = "Scheduled" | "Running" | "NotScheduled";

/** Timing parameters for a periodically run job. */
export interface JobConfig {
    /** Base interval between runs, in seconds. */
    intervalSec: number;
    /**
     * Upper bound, in seconds, for the interval between runs.
     *
     * NOTE(review): the property name is spelled "maxItervalSec" (sic); it is
     * part of the public shape so it is retained as is.
     */
    maxItervalSec: number;
    /** Multiplier applied to the interval when a run asks to back off. */
    backoffMultiplier: number;
}

/** The outcome reported by a single run of a job. */
export interface JobResult {
    /** True if the interval before the next run should be lengthened. */
    shouldBackoff: boolean;
}
|
|
@ -1,9 +1,9 @@
|
|||
/**
|
||||
* The embeddings models that we support.
|
||||
* The embeddings that we (the current client) know how to handle.
|
||||
*
|
||||
* This is an exhaustive set of values we pass when PUT-ting encrypted
|
||||
* embeddings on the server. However, we should be prepared to receive an
|
||||
* {@link EncryptedEmbedding} with a model value distinct from one of these.
|
||||
* {@link EncryptedEmbedding} with a model value different from these.
|
||||
*/
|
||||
export type EmbeddingModel = "onnx-clip" | "file-ml-clip-face";
|
||||
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
import { FILE_TYPE } from "@/media/file-type";
|
||||
import { City } from "services/locationSearchService";
|
||||
import { IndexStatus } from "services/ml/db";
|
||||
import { Person } from "services/ml/types";
|
||||
import { LocationTagData } from "types/entity";
|
||||
import { EnteFile } from "types/file";
|
||||
import { IndexStatus } from "utils/storage/mlIDbStorage";
|
||||
|
||||
export enum SuggestionType {
|
||||
DATE = "DATE",
|
||||
|
|
|
@ -1,82 +0,0 @@
|
|||
import log from "@/next/log";
|
||||
import { JobConfig, JobResult, JobState } from "types/common/job";
|
||||
|
||||
/**
 * A job that repeatedly invokes a callback, waiting `intervalSec` seconds
 * between invocations.
 *
 * When a run reports `shouldBackoff`, the wait is multiplied by
 * `backoffMultiplier` (capped at `maxItervalSec`); otherwise it is reset to
 * the configured base interval.
 */
export class SimpleJob<R extends JobResult> {
    private config: JobConfig;
    private runCallback: () => Promise<R>;
    private state: JobState;
    private stopped: boolean;
    private intervalSec: number;
    private nextTimeoutId: ReturnType<typeof setTimeout>;

    /**
     * @param config Timing parameters for the job.
     * @param runCallback The work to perform on each run.
     */
    constructor(config: JobConfig, runCallback: () => Promise<R>) {
        this.config = config;
        this.runCallback = runCallback;
        this.state = "NotScheduled";
        // The job stays idle until start() is called.
        this.stopped = true;
        this.intervalSec = this.config.intervalSec;
    }

    /** Discard any backoff and go back to the configured base interval. */
    public resetInterval() {
        this.intervalSec = this.config.intervalSec;
    }

    /**
     * Mark the job as active and schedule its next run (unless a run is
     * currently in progress, in which case the in-progress run will
     * reschedule itself when it finishes).
     */
    public start() {
        this.stopped = false;
        this.resetInterval();
        if (this.state === "Running") {
            log.info("Job already running, not scheduling");
            return;
        }
        this.scheduleNext();
    }

    /** Arm a timer for the next run, replacing any timer already pending. */
    private scheduleNext() {
        const hasPending = this.state === "Scheduled" || this.nextTimeoutId;
        if (hasPending) {
            this.clearScheduled();
        }

        const delayMs = this.intervalSec * 1000;
        this.nextTimeoutId = setTimeout(() => this.run(), delayMs);
        this.state = "Scheduled";
        log.info("Scheduled next job after: ", this.intervalSec);
    }

    /**
     * Perform one run of the job, adjust the interval based on its result,
     * and schedule the next run if the job has not been stopped meanwhile.
     */
    async run() {
        this.nextTimeoutId = undefined;
        this.state = "Running";

        try {
            const jobResult = await this.runCallback();
            if (jobResult?.shouldBackoff) {
                const backedOff =
                    this.intervalSec * this.config.backoffMultiplier;
                this.intervalSec = Math.min(
                    this.config.maxItervalSec,
                    backedOff,
                );
            } else {
                this.resetInterval();
            }
            log.info("Job completed");
        } catch (e) {
            // Errors are logged, and do not prevent the next run.
            console.error("Error while running Job: ", e);
        } finally {
            this.state = "NotScheduled";
            if (!this.stopped) {
                this.scheduleNext();
            }
        }
    }

    /**
     * Stop scheduling further runs.
     *
     * A run that is already in progress is not interrupted; terminating it
     * is currently the responsibility of the client.
     */
    public stop() {
        this.stopped = true;
        this.clearScheduled();
    }

    /** Cancel the pending timer (if any) and mark the job as not scheduled. */
    private clearScheduled() {
        clearTimeout(this.nextTimeoutId);
        this.nextTimeoutId = undefined;
        this.state = "NotScheduled";
        log.info("Cleared next job");
    }
}
|
|
@ -1,36 +0,0 @@
|
|||
import { Embedding } from "types/embedding";
|
||||
import { FileML } from "./machineLearning/mldataMappers";
|
||||
|
||||
/**
 * Keep, for each file ID, only the entry with the greatest `updatedAt`.
 *
 * Entries that are nullish or have a falsy `fileID` are skipped. When two
 * entries for the same file have equal `updatedAt`, the earlier one wins.
 * The result lists files in the order in which they were first encountered.
 */
function latestVersionByFileID<T extends { fileID: number; updatedAt: number }>(
    items: T[],
): T[] {
    const latestByFileID = new Map<number, T>();
    for (const item of items) {
        if (!item?.fileID) continue;
        const existing = latestByFileID.get(item.fileID);
        if (!existing || existing.updatedAt < item.updatedAt) {
            latestByFileID.set(item.fileID, item);
        }
    }
    return Array.from(latestByFileID.values());
}

/**
 * Return the most recently updated {@link Embedding} for each file.
 *
 * @param embeddings Embeddings, possibly with multiple entries per file.
 */
export const getLatestVersionEmbeddings = (embeddings: Embedding[]) =>
    latestVersionByFileID(embeddings);

/**
 * Return the most recently updated {@link FileML} for each file.
 *
 * @param embeddings File ML entries, possibly with multiple entries per file.
 */
export const getLatestVersionFileEmbeddings = (embeddings: FileML[]) =>
    latestVersionByFileID(embeddings);
|
|
@ -1,9 +1,8 @@
|
|||
// these utils only work in env where OffscreenCanvas is available
|
||||
|
||||
import { Matrix, inverse } from "ml-matrix";
|
||||
import { Box, Dimensions } from "services/ml/geom";
|
||||
import { Box, Dimensions, enlargeBox } from "services/ml/geom";
|
||||
import { FaceAlignment } from "services/ml/types";
|
||||
import { enlargeBox } from "utils/machineLearning";
|
||||
|
||||
export function normalizePixelBetween0And1(pixelValue: number) {
|
||||
return pixelValue / 255.0;
|
||||
|
|
|
@ -1,48 +0,0 @@
|
|||
import {
|
||||
DEFAULT_ML_SEARCH_CONFIG,
|
||||
DEFAULT_ML_SYNC_CONFIG,
|
||||
DEFAULT_ML_SYNC_JOB_CONFIG,
|
||||
} from "constants/mlConfig";
|
||||
import { MLSearchConfig, MLSyncConfig } from "services/ml/types";
|
||||
import { JobConfig } from "types/common/job";
|
||||
import mlIDbStorage, {
|
||||
ML_SEARCH_CONFIG_NAME,
|
||||
ML_SYNC_CONFIG_NAME,
|
||||
ML_SYNC_JOB_CONFIG_NAME,
|
||||
} from "utils/storage/mlIDbStorage";
|
||||
import { isInternalUserForML } from "utils/user";
|
||||
|
||||
/**
 * Read the ML sync job config from the ML DB, passing
 * DEFAULT_ML_SYNC_JOB_CONFIG as the default to the storage layer.
 */
export async function getMLSyncJobConfig() {
    const jobConfig = mlIDbStorage.getConfig(
        ML_SYNC_JOB_CONFIG_NAME,
        DEFAULT_ML_SYNC_JOB_CONFIG,
    );
    return jobConfig;
}
|
||||
|
||||
/**
 * Read the ML sync config from the ML DB, passing DEFAULT_ML_SYNC_CONFIG as
 * the default to the storage layer.
 */
export async function getMLSyncConfig() {
    const syncConfig = mlIDbStorage.getConfig(
        ML_SYNC_CONFIG_NAME,
        DEFAULT_ML_SYNC_CONFIG,
    );
    return syncConfig;
}
|
||||
|
||||
/**
 * Return the ML search config.
 *
 * Only internal users get the config stored in the ML DB; everyone else
 * always gets DEFAULT_ML_SEARCH_CONFIG.
 */
export async function getMLSearchConfig() {
    // Force disabled for everyone except internal users while we finalize
    // it, to avoid redundant reindexing for users.
    if (!isInternalUserForML()) {
        return DEFAULT_ML_SEARCH_CONFIG;
    }
    return mlIDbStorage.getConfig(
        ML_SEARCH_CONFIG_NAME,
        DEFAULT_ML_SEARCH_CONFIG,
    );
}
|
||||
|
||||
/**
 * Save {@link newConfig} as the ML sync job config in the ML DB.
 */
export async function updateMLSyncJobConfig(newConfig: JobConfig) {
    const saved = mlIDbStorage.putConfig(ML_SYNC_JOB_CONFIG_NAME, newConfig);
    return saved;
}
|
||||
|
||||
/**
 * Save {@link newConfig} as the ML sync config in the ML DB.
 */
export async function updateMLSyncConfig(newConfig: MLSyncConfig) {
    const saved = mlIDbStorage.putConfig(ML_SYNC_CONFIG_NAME, newConfig);
    return saved;
}
|
||||
|
||||
/**
 * Save {@link newConfig} as the ML search config in the ML DB.
 */
export async function updateMLSearchConfig(newConfig: MLSearchConfig) {
    const saved = mlIDbStorage.putConfig(ML_SEARCH_CONFIG_NAME, newConfig);
    return saved;
}
|
|
@ -1,87 +0,0 @@
|
|||
import { Matrix } from "ml-matrix";
|
||||
import { Point } from "services/ml/geom";
|
||||
import { FaceAlignment, FaceDetection } from "services/ml/types";
|
||||
import { getSimilarityTransformation } from "similarity-transformation";
|
||||
|
||||
const ARCFACE_LANDMARKS = [
|
||||
[38.2946, 51.6963],
|
||||
[73.5318, 51.5014],
|
||||
[56.0252, 71.7366],
|
||||
[56.1396, 92.2848],
|
||||
] as Array<[number, number]>;
|
||||
|
||||
const ARCFACE_LANDMARKS_FACE_SIZE = 112;
|
||||
|
||||
const ARC_FACE_5_LANDMARKS = [
|
||||
[38.2946, 51.6963],
|
||||
[73.5318, 51.5014],
|
||||
[56.0252, 71.7366],
|
||||
[41.5493, 92.3655],
|
||||
[70.7299, 92.2041],
|
||||
] as Array<[number, number]>;
|
||||
|
||||
/**
 * Compute the alignment of a detected face against the ArcFace reference
 * landmarks.
 *
 * The 5 point reference is used when the detection has exactly 5 landmarks,
 * otherwise the 4 point one. The reference landmarks are normalized by the
 * reference face size before being used.
 *
 * @param faceDetection The detection whose landmarks should be aligned.
 */
export function getArcfaceAlignment(
    faceDetection: FaceDetection,
): FaceAlignment {
    const referenceLandmarks =
        faceDetection.landmarks.length === 5
            ? ARC_FACE_5_LANDMARKS
            : ARCFACE_LANDMARKS;
    const normalizedReference = normalizeLandmarks(
        referenceLandmarks,
        ARCFACE_LANDMARKS_FACE_SIZE,
    );
    return getFaceAlignmentUsingSimilarityTransform(
        faceDetection,
        normalizedReference,
    );
}
|
||||
|
||||
/**
 * Compute a {@link FaceAlignment} by fitting a similarity transformation
 * that maps the detected landmarks onto {@link alignedLandmarks}.
 *
 * Only as many detected landmarks as there are aligned landmarks are used.
 *
 * @param faceDetection The detection providing the source landmarks.
 * @param alignedLandmarks The target landmark positions.
 */
function getFaceAlignmentUsingSimilarityTransform(
    faceDetection: FaceDetection,
    alignedLandmarks: Array<[number, number]>,
): FaceAlignment {
    const detectedCoordinates = faceDetection.landmarks
        .map((p) => [p.x, p.y])
        .slice(0, alignedLandmarks.length);
    // Both matrices are transposed, so each landmark becomes a column.
    const landmarksMat = new Matrix(detectedCoordinates).transpose();
    const alignedLandmarksMat = new Matrix(alignedLandmarks).transpose();

    const simTransform = getSimilarityTransformation(
        landmarksMat,
        alignedLandmarksMat,
    );

    const rotationScale = Matrix.mul(
        simTransform.rotation,
        simTransform.scale,
    );
    const translation = simTransform.translation;

    // Affine matrix in homogeneous form: [rotation * scale | translation].
    const affineMatrix = [
        [
            rotationScale.get(0, 0),
            rotationScale.get(0, 1),
            translation.get(0, 0),
        ],
        [
            rotationScale.get(1, 0),
            rotationScale.get(1, 1),
            translation.get(1, 0),
        ],
        [0, 0, 1],
    ];

    const size = 1 / simTransform.scale;
    // NOTE(review): the matrix `sub` / `mul` instance methods may operate in
    // place (confirm against ml-matrix); the order of the following
    // statements is deliberately kept as is.
    const meanTranslation = simTransform.toMean.sub(0.5).mul(size);
    const centerMat = simTransform.fromMean.sub(meanTranslation);
    const center = new Point(centerMat.get(0, 0), centerMat.get(1, 0));
    const rotation = -Math.atan2(
        simTransform.rotation.get(0, 1),
        simTransform.rotation.get(0, 0),
    );

    return { affineMatrix, center, size, rotation };
}
|
||||
|
||||
/**
 * Divide every coordinate of every landmark by {@link faceSize}.
 *
 * @param landmarks Landmarks as `[x, y]` pairs.
 * @param faceSize The divisor applied to each coordinate.
 * @returns A new array of scaled landmarks; the input is not modified.
 */
function normalizeLandmarks(
    landmarks: Array<[number, number]>,
    faceSize: number,
): Array<[number, number]> {
    const scaleDown = (coordinate: number) => coordinate / faceSize;
    const normalized = landmarks.map((pair) => pair.map(scaleDown));
    return normalized as Array<[number, number]>;
}
|
|
@ -1,28 +0,0 @@
|
|||
import { Box } from "services/ml/geom";
|
||||
import { FaceAlignment, FaceCrop, FaceCropConfig } from "services/ml/types";
|
||||
import { cropWithRotation } from "utils/image";
|
||||
import { enlargeBox } from ".";
|
||||
|
||||
/**
 * Crop the region of {@link imageBitmap} that contains an aligned face.
 *
 * A square box of side `alignment.size` centered at `alignment.center` is
 * rounded, enlarged by `1 + 2 * config.padding`, rounded again, and then
 * cropped (with no rotation) to at most `config.maxSize` on each side.
 *
 * @param imageBitmap The image containing the face.
 * @param alignment The alignment of the face within the image.
 * @param config Padding and maximum size for the crop.
 * @returns The cropped image, and the (padded) box it was cropped from.
 */
export function getFaceCrop(
    imageBitmap: ImageBitmap,
    alignment: FaceAlignment,
    config: FaceCropConfig,
): FaceCrop {
    const { center, size } = alignment;
    const alignmentBox = new Box({
        x: center.x - size / 2,
        y: center.y - size / 2,
        width: size,
        height: size,
    }).round();

    const paddingScale = 1 + config.padding * 2;
    const paddedBox = enlargeBox(alignmentBox, paddingScale).round();

    const outputSize = { width: config.maxSize, height: config.maxSize };
    const image = cropWithRotation(imageBitmap, paddedBox, 0, outputSize);

    return { image, imageBox: paddedBox };
}
|
|
@ -1,272 +0,0 @@
|
|||
import { FILE_TYPE } from "@/media/file-type";
|
||||
import { decodeLivePhoto } from "@/media/live-photo";
|
||||
import log from "@/next/log";
|
||||
import PQueue from "p-queue";
|
||||
import DownloadManager from "services/download";
|
||||
import { getLocalFiles } from "services/fileService";
|
||||
import { Box, Dimensions, Point, boxFromBoundingBox } from "services/ml/geom";
|
||||
import {
|
||||
DetectedFace,
|
||||
Face,
|
||||
FaceAlignment,
|
||||
MlFileData,
|
||||
Person,
|
||||
Versioned,
|
||||
} from "services/ml/types";
|
||||
import { EnteFile } from "types/file";
|
||||
import { getRenderableImage } from "utils/file";
|
||||
import { clamp, warpAffineFloat32List } from "utils/image";
|
||||
import mlIDbStorage from "utils/storage/mlIDbStorage";
|
||||
|
||||
/**
 * Return a new box that shares its center with {@link box} but whose width
 * and height are scaled by {@link factor}.
 *
 * @param box The box to enlarge.
 * @param factor Scale applied to both dimensions (defaults to 1.5).
 */
export function enlargeBox(box: Box, factor: number = 1.5) {
    const midpoint = new Point(box.x + box.width / 2, box.y + box.height / 2);

    const extent = new Point(box.width, box.height);
    const halfExtent = new Point(
        (factor * extent.x) / 2,
        (factor * extent.y) / 2,
    );

    const left = midpoint.x - halfExtent.x;
    const top = midpoint.y - halfExtent.y;
    const right = midpoint.x + halfExtent.x;
    const bottom = midpoint.y + halfExtent.y;

    return boxFromBoundingBox({ left, top, right, bottom });
}
|
||||
|
||||
/**
 * Flatten the per-key face lists of {@link allFacesMap} into a single array,
 * in the iteration order of the map.
 */
export function getAllFacesFromMap(allFacesMap: Map<number, Array<Face>>) {
    return Array.from(allFacesMap.values()).flat();
}
|
||||
|
||||
/**
 * Look up a file by its ID amongst the locally available files.
 *
 * @returns The matching file, or `undefined` if there is none.
 */
export async function getLocalFile(fileId: number) {
    const allLocalFiles = await getLocalFiles();
    const match = allLocalFiles.find(({ id }) => id === fileId);
    return match;
}
|
||||
|
||||
/**
 * Warp each aligned face out of {@link image} into a shared Float32Array.
 *
 * Each face occupies `faceSize * faceSize * 3` consecutive floats, in the
 * same order as {@link faceAlignments}.
 *
 * @param faceAlignments Alignments of the faces to extract.
 * @param faceSize Side, in pixels, of each (square) extracted face.
 * @param image The image containing the faces.
 */
export async function extractFaceImagesToFloat32(
    faceAlignments: Array<FaceAlignment>,
    faceSize: number,
    image: ImageBitmap,
): Promise<Float32Array> {
    const floatsPerFace = faceSize * faceSize * 3;
    const faceData = new Float32Array(faceAlignments.length * floatsPerFace);
    faceAlignments.forEach((alignedFace, index) => {
        warpAffineFloat32List(
            image,
            alignedFace,
            faceSize,
            faceData,
            index * floatsPerFace,
        );
    });
    return faceData;
}
|
||||
|
||||
/**
 * Derive an ID for a detected face from the ID of its file and the position
 * of its bounding box relative to the image dimensions.
 *
 * The ID has the form `<fileId>_<xMin>_<yMin>_<xMax>_<yMax>`, where each
 * coordinate is the clamped relative position formatted to 5 decimal places
 * with the leading "0." dropped.
 *
 * @param detectedFace The face for which an ID is needed.
 * @param imageDims Dimensions of the image in which the face was detected.
 */
export function getFaceId(detectedFace: DetectedFace, imageDims: Dimensions) {
    const { box } = detectedFace.detection;

    // NOTE(review): (0.999999).toFixed(5) is "1.00000", so values at the
    // upper clamp bound presumably encode as "00000" - confirm this is
    // intended before relying on the encoding being monotonic.
    const encode = (fraction: number) =>
        clamp(fraction, 0.0, 0.999999).toFixed(5).substring(2);

    const xMin = encode(box.x / imageDims.width);
    const yMin = encode(box.y / imageDims.height);
    const xMax = encode((box.x + box.width) / imageDims.width);
    const yMax = encode((box.y + box.height) / imageDims.height);

    return `${detectedFace.fileId}_${xMin}_${yMin}_${xMax}_${yMax}`;
}
|
||||
|
||||
/** Decode the image data in {@link blob} into an ImageBitmap. */
export async function getImageBlobBitmap(blob: Blob): Promise<ImageBitmap> {
    const bitmap = await createImageBitmap(blob);
    return bitmap;
}
|
||||
|
||||
/**
 * Download the original of {@link file} and return its contents as a Blob.
 *
 * @param file The file to download.
 * @param queue If provided, the download is run through this queue instead
 * of being started directly.
 */
async function getOriginalFile(file: EnteFile, queue?: PQueue) {
    const download = () => DownloadManager.getFile(file);
    const fileStream = queue ? await queue.add(download) : await download();
    return new Response(fileStream).blob();
}
|
||||
|
||||
/**
 * Download the original of {@link file} and convert it into a renderable
 * image.
 *
 * Files whose type is not IMAGE are decoded as live photos, and the image
 * component of the live photo is used.
 *
 * @param file The file to download and convert.
 * @param queue Optional queue, forwarded as is to getOriginalFile.
 */
async function getOriginalConvertedFile(file: EnteFile, queue?: PQueue) {
    const originalBlob = await getOriginalFile(file, queue);
    const { fileType, title } = file.metadata;

    if (fileType !== FILE_TYPE.IMAGE) {
        const livePhoto = await decodeLivePhoto(title, originalBlob);
        return await getRenderableImage(
            livePhoto.imageFileName,
            new Blob([livePhoto.imageData]),
        );
    }

    return await getRenderableImage(title, originalBlob);
}
|
||||
|
||||
/**
 * Obtain the original of {@link file} in a renderable form (via
 * getOriginalConvertedFile) and decode it into an ImageBitmap.
 *
 * @param file The file whose original should be decoded.
 * @param queue Optional queue, forwarded as is to getOriginalConvertedFile.
 */
export async function getOriginalImageBitmap(file: EnteFile, queue?: PQueue) {
    const renderableBlob = await getOriginalConvertedFile(file, queue);
    const fileIDString = file.id.toString();
    log.info("[MLService] Got file: ", fileIDString);
    const bitmap = getImageBlobBitmap(renderableBlob);
    return bitmap;
}
|
||||
|
||||
/**
 * Fetch the thumbnail of {@link file} through the DownloadManager, wrap the
 * returned data in a Blob and decode it into an ImageBitmap.
 */
export async function getThumbnailImageBitmap(file: EnteFile) {
    const thumbnailData = await DownloadManager.getThumbnail(file);
    log.info("[MLService] Got thumbnail: ", file.id.toString());

    const thumbnailBlob = new Blob([thumbnailData]);
    return getImageBlobBitmap(thumbnailBlob);
}
|
||||
|
||||
/**
 * Decode a file that is already available locally into an ImageBitmap.
 *
 * The local file is first passed through getRenderableImage, using the title
 * from the metadata of {@link enteFile} as the file name.
 *
 * @param enteFile The Ente file corresponding to the local file.
 * @param localFile The local (web) File holding the data.
 */
export async function getLocalFileImageBitmap(
    enteFile: EnteFile,
    localFile: globalThis.File,
) {
    const renderableBlob = await getRenderableImage(
        enteFile.metadata.title,
        localFile as Blob,
    );
    return getImageBlobBitmap(renderableBlob);
}
|
||||
|
||||
/**
 * Return the people that the faces in {@link file} have been assigned to.
 *
 * One entry is returned for each face that has a person ID, in the order of
 * the faces. Nothing is deduplicated, so a person can appear multiple times.
 * An empty array is returned if the file has no ML data, no faces, or no
 * faces with a person ID.
 *
 * The time taken by each of the two DB lookups is logged.
 */
export async function getPeopleList(file: EnteFile): Promise<Array<Person>> {
    const fileLookupStart = Date.now();
    const mlFileData: MlFileData = await mlIDbStorage.getFile(file.id);
    log.info(
        "getPeopleList:mlFilesStore:getItem",
        Date.now() - fileLookupStart,
        "ms",
    );

    const faces = mlFileData?.faces;
    if (!faces || faces.length < 1) {
        return [];
    }

    const peopleIds = faces
        .map((f) => f.personId)
        .filter((id) => id !== null && id !== undefined);
    if (peopleIds.length < 1) {
        return [];
    }

    const peopleLookupStart = Date.now();
    const peopleList = await Promise.all(
        peopleIds.map((id) => mlIDbStorage.getPerson(id) as Promise<Person>),
    );
    log.info(
        "getPeopleList:mlPeopleStore:getItems",
        Date.now() - peopleLookupStart,
        "ms",
    );

    return peopleList;
}
|
||||
|
||||
/**
 * Return the faces in {@link file} that have not been assigned to a person.
 *
 * @returns The faces whose `personId` is null or undefined. An empty array
 * (never `undefined`) is returned when the file has no ML data or no faces,
 * so that the result always matches the declared return type.
 */
export async function getUnidentifiedFaces(
    file: EnteFile,
): Promise<Array<Face>> {
    const mlFileData: MlFileData = await mlIDbStorage.getFile(file.id);

    const faces = mlFileData?.faces ?? [];
    return faces.filter(
        (f) => f.personId === null || f.personId === undefined,
    );
}
|
||||
|
||||
/**
 * Return the people stored in the local ML database, ordered by the number
 * of files they appear in (most files first).
 *
 * @param limit If provided, at most this many people are returned. When
 * omitted, everyone is returned.
 */
export async function getAllPeople(limit?: number) {
    // The storage layer may hand back nothing; treat that as "no people".
    const people: Array<Person> = (await mlIDbStorage.getAllPeople()) ?? [];
    // Sort a copy so that the array returned by the storage layer is not
    // reordered behind its back.
    return [...people]
        .sort((p1, p2) => p2.files.length - p1.files.length)
        .slice(0, limit);
}
|
||||
|
||||
/**
 * Find the element that would come first if {@link elements} were sorted
 * using {@link comparator}, without actually sorting.
 *
 * Amongst elements that compare equal, the one that occurs earliest wins.
 *
 * @param elements The elements to scan.
 * @param comparator Negative if `a` should come before `b`.
 * @returns The first element, or `undefined` if there are no elements.
 */
export function findFirstIfSorted<T>(
    elements: Array<T>,
    comparator: (a: T, b: T) => number,
) {
    if (!elements?.length) {
        return;
    }

    return elements.reduce((first, candidate) =>
        comparator(candidate, first) < 0 ? candidate : first,
    );
}
|
||||
|
||||
/**
 * Return true if {@link method} is missing, has a different value than
 * {@link thanMethod}, or has a lower version than it.
 */
export function isDifferentOrOld(
    method: Versioned<string>,
    thanMethod: Versioned<string>,
) {
    if (!method) {
        return true;
    }
    if (method.value !== thanMethod.value) {
        return true;
    }
    return method.version < thanMethod.version;
}
|
||||
|
||||
/**
 * Return true if both {@link a} and {@link b} are arrays of the same length
 * whose corresponding elements are strictly equal (`===`).
 *
 * Anything that is not an array is never considered equal.
 */
function primitiveArrayEquals(a, b) {
    if (!Array.isArray(a) || !Array.isArray(b)) {
        return false;
    }
    if (a.length !== b.length) {
        return false;
    }
    return a.every((val, index) => val === b[index]);
}
|
||||
|
||||
/**
 * Return true if the two lists of faces have the same face IDs in the same
 * order.
 *
 * Two missing (null or undefined) lists are considered the same; a missing
 * list is never the same as a present one, even if it is empty.
 */
export function areFaceIdsSame(ofFaces: Array<Face>, toFaces: Array<Face>) {
    // `== null` matches both null and undefined.
    if (ofFaces == null && toFaces == null) {
        return true;
    }

    const ofIds = ofFaces?.map((f) => f.id);
    const toIds = toFaces?.map((f) => f.id);
    return primitiveArrayEquals(ofIds, toIds);
}
|
||||
|
||||
/**
 * Attach listeners to {@link queue} that log when it becomes active or idle,
 * and when it emits an error.
 *
 * @param queue The queue to observe.
 * @param name A label to include in each log line.
 */
export function logQueueStats(queue: PQueue, name: string) {
    const onActive = () =>
        log.info(
            `queuestats: ${name}: Active, Size: ${queue.size} Pending: ${queue.pending}`,
        );
    const onIdle = () => log.info(`queuestats: ${name}: Idle`);
    const onError = (error) =>
        console.error(`queuestats: ${name}: Error, `, error);

    queue.on("active", onActive);
    queue.on("idle", onIdle);
    queue.on("error", onError);
}
|
|
@ -1,177 +0,0 @@
|
|||
import {
|
||||
ClipEmbedding,
|
||||
Face,
|
||||
FaceDetection,
|
||||
Landmark,
|
||||
MlFileData,
|
||||
} from "services/ml/types";
|
||||
|
||||
/**
 * A {@link ServerFileMl} that additionally carries an `updatedAt` value.
 */
export interface FileML extends ServerFileMl {
    /**
     * Used to pick the latest entry for a file (greater is newer).
     *
     * NOTE(review): presumably a timestamp; confirm the unit with the API.
     */
    updatedAt: number;
}
|
||||
|
||||
/**
 * The ML data for a file, in the shape in which it is exchanged with the
 * server.
 */
class ServerFileMl {
    /** ID of the file this data belongs to. */
    public fileID: number;
    /** Height of the image (optional). */
    public height?: number;
    /** Width of the image (optional). */
    public width?: number;
    /** The faces detected in the file, and their embeddings. */
    public faceEmbedding: ServerFaceEmbeddings;
    /** The CLIP embedding of the file (optional). */
    public clipEmbedding?: ClipEmbedding;

    /**
     * Note that the parameter order (embeddings before dimensions, height
     * before width) differs from the order of the fields.
     */
    public constructor(
        fileID: number,
        faceEmbedding: ServerFaceEmbeddings,
        clipEmbedding?: ClipEmbedding,
        height?: number,
        width?: number,
    ) {
        // The assignments are kept in field order (not parameter order) so
        // that the enumeration order of the properties stays the same.
        this.fileID = fileID;
        this.height = height;
        this.width = width;
        this.faceEmbedding = faceEmbedding;
        this.clipEmbedding = clipEmbedding;
    }
}
|
||||
|
||||
/**
 * The faces of a file, along with information about how they were produced.
 */
class ServerFaceEmbeddings {
    /**
     * @param faces The faces detected in the file.
     * @param version Version of this data.
     * @param client Optional identifier of the client that produced it.
     * @param error Optional flag indicating an error.
     */
    public constructor(
        public faces: ServerFace[],
        public version: number,
        public client?: string,
        public error?: boolean,
    ) {}
}
|
||||
|
||||
/**
 * A single detected face, in the shape in which it is exchanged with the
 * server.
 */
class ServerFace {
    /**
     * @param fileID ID of the file in which the face was detected.
     * @param faceID ID of the face.
     * @param embeddings The embedding of the face, as a plain number array.
     * @param detection Bounding box and landmarks of the face.
     * @param score The score of the detection.
     * @param blur The blur value of the face.
     * @param fileInfo Optional information about the containing image.
     */
    public constructor(
        public fileID: number,
        public faceID: string,
        public embeddings: number[],
        public detection: ServerDetection,
        public score: number,
        public blur: number,
        public fileInfo?: ServerFileInfo,
    ) {}
}
|
||||
|
||||
/** The dimensions of the image that a face was detected in. */
class ServerFileInfo {
    /**
     * @param imageWidth Width of the image (optional).
     * @param imageHeight Height of the image (optional).
     */
    public constructor(
        public imageWidth?: number,
        public imageHeight?: number,
    ) {}
}
|
||||
|
||||
/** The bounding box and landmarks of a detected face. */
class ServerDetection {
    /**
     * @param box The bounding box of the face.
     * @param landmarks The landmarks of the face.
     */
    public constructor(
        public box: ServerFaceBox,
        public landmarks: Landmark[],
    ) {}
}
|
||||
|
||||
/**
 * The bounding box of a face, described by its minimum x and y coordinates
 * and its width and height.
 */
class ServerFaceBox {
    /**
     * @param xMin Minimum x coordinate of the box.
     * @param yMin Minimum y coordinate of the box.
     * @param width Width of the box.
     * @param height Height of the box.
     */
    public constructor(
        public xMin: number,
        public yMin: number,
        public width: number,
        public height: number,
    ) {}
}
|
||||
|
||||
/**
 * Convert the locally stored ML data of a file into the shape in which it is
 * exchanged with the server.
 *
 * @param localFileMlData The local ML data of the file.
 * @returns The converted data, or `null` if the local data records an error
 * (a positive error count together with an error message).
 */
export function LocalFileMlDataToServerFileMl(
    localFileMlData: MlFileData,
): ServerFileMl {
    const hasRecordedError =
        localFileMlData.errorCount > 0 &&
        localFileMlData.lastErrorMessage !== undefined;
    if (hasRecordedError) {
        return null;
    }

    const { fileId, imageDimensions } = localFileMlData;
    // A single file info instance is shared by all the faces of the file.
    const fileInfo = new ServerFileInfo(
        imageDimensions.width,
        imageDimensions.height,
    );

    const faces: ServerFace[] = localFileMlData.faces.map((face: Face) => {
        const detection: FaceDetection = face.detection;
        const { box } = detection;
        const serverBox = new ServerFaceBox(
            box.x,
            box.y,
            box.width,
            box.height,
        );
        // Copy only the coordinates of each landmark.
        const serverLandmarks: Landmark[] = detection.landmarks.map(
            ({ x, y }) => ({ x, y }) as Landmark,
        );
        return new ServerFace(
            fileId,
            face.id,
            Array.from(face.embedding),
            new ServerDetection(serverBox, serverLandmarks),
            detection.probability,
            face.blurValue,
            fileInfo,
        );
    });

    // NOTE(review): the third parameter of ServerFaceEmbeddings is `client`,
    // but the last error message is what gets passed here - confirm whether
    // this is intended.
    const faceEmbeddings = new ServerFaceEmbeddings(
        faces,
        1,
        localFileMlData.lastErrorMessage,
    );
    return new ServerFileMl(
        fileId,
        faceEmbeddings,
        null,
        imageDimensions.height,
        imageDimensions.width,
    );
}
|
Loading…
Add table
Reference in a new issue