[web] ML purge - Part 2/x (#1725)

Manav Rathi 2024-05-14 17:00:28 +05:30 committed by GitHub
commit bbfd2ae640
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
34 changed files with 1148 additions and 1242 deletions

View file

@ -5,8 +5,8 @@ import { t } from "i18next";
import { AppContext } from "pages/_app";
import { useContext } from "react";
import { components } from "react-select";
import { IndexStatus } from "services/ml/db";
import { Suggestion, SuggestionType } from "types/search";
import { IndexStatus } from "utils/storage/mlIDbStorage";
const { Menu } = components;

View file

@ -9,7 +9,7 @@ import { t } from "i18next";
import { useRouter } from "next/router";
import { AppContext } from "pages/_app";
import { useContext, useState } from "react";
// import mlIDbStorage from 'utils/storage/mlIDbStorage';
// import mlIDbStorage from 'services/ml/db';
import {
configurePasskeyRecovery,
isPasskeyRecoveryEnabled,

View file

@ -3,9 +3,9 @@ import { Skeleton, styled } from "@mui/material";
import { Legend } from "components/PhotoViewer/styledComponents/Legend";
import { t } from "i18next";
import React, { useEffect, useState } from "react";
import { Face, Person } from "services/ml/types";
import mlIDbStorage from "services/ml/db";
import { Face, Person, type MlFileData } from "services/ml/types";
import { EnteFile } from "types/file";
import { getPeopleList, getUnidentifiedFaces } from "utils/machineLearning";
const FaceChipContainer = styled("div")`
display: flex;
@ -194,3 +194,45 @@ const FaceCropImageView: React.FC<FaceCropImageViewProps> = ({
<Skeleton variant="circular" height={120} width={120} />
);
};
async function getPeopleList(file: EnteFile): Promise<Array<Person>> {
let startTime = Date.now();
const mlFileData: MlFileData = await mlIDbStorage.getFile(file.id);
log.info(
"getPeopleList:mlFilesStore:getItem",
Date.now() - startTime,
"ms",
);
if (!mlFileData?.faces || mlFileData.faces.length < 1) {
return [];
}
const peopleIds = mlFileData.faces
.filter((f) => f.personId !== null && f.personId !== undefined)
.map((f) => f.personId);
if (!peopleIds || peopleIds.length < 1) {
return [];
}
// log.info("peopleIds: ", peopleIds);
startTime = Date.now();
const peoplePromises = peopleIds.map(
(p) => mlIDbStorage.getPerson(p) as Promise<Person>,
);
const peopleList = await Promise.all(peoplePromises);
log.info(
"getPeopleList:mlPeopleStore:getItems",
Date.now() - startTime,
"ms",
);
// log.info("peopleList: ", peopleList);
return peopleList;
}
async function getUnidentifiedFaces(file: EnteFile): Promise<Array<Face>> {
const mlFileData: MlFileData = await mlIDbStorage.getFile(file.id);
return mlFileData?.faces?.filter(
(f) => f.personId === null || f.personId === undefined,
);
}

View file

@ -1,56 +0,0 @@
import { MLSearchConfig, MLSyncConfig } from "services/ml/types";
import { JobConfig } from "types/common/job";
export const DEFAULT_ML_SYNC_JOB_CONFIG: JobConfig = {
intervalSec: 5,
// TODO: finalize this after seeing effects on and from machine sleep
maxItervalSec: 960,
backoffMultiplier: 2,
};
export const DEFAULT_ML_SYNC_CONFIG: MLSyncConfig = {
batchSize: 200,
imageSource: "Original",
faceDetection: {
method: "YoloFace",
},
faceCrop: {
enabled: true,
method: "ArcFace",
padding: 0.25,
maxSize: 256,
blobOptions: {
type: "image/jpeg",
quality: 0.8,
},
},
faceAlignment: {
method: "ArcFace",
},
blurDetection: {
method: "Laplacian",
threshold: 15,
},
faceEmbedding: {
method: "MobileFaceNet",
faceSize: 112,
generateTsne: true,
},
faceClustering: {
method: "Hdbscan",
minClusterSize: 3,
minSamples: 5,
clusterSelectionEpsilon: 0.6,
clusterSelectionMethod: "leaf",
minInputSize: 50,
// maxDistanceInsideCluster: 0.4,
generateDebugInfo: true,
},
mlVersion: 3,
};
export const DEFAULT_ML_SEARCH_CONFIG: MLSearchConfig = {
enabled: false,
};
export const MAX_ML_SYNC_ERROR_COUNT = 1;

View file

@ -53,6 +53,10 @@ import { createContext, useEffect, useRef, useState } from "react";
import LoadingBar from "react-top-loading-bar";
import DownloadManager from "services/download";
import exportService, { resumeExportsIfNeeded } from "services/export";
import {
getMLSearchConfig,
updateMLSearchConfig,
} from "services/machineLearning/machineLearningService";
import mlWorkManager from "services/machineLearning/mlWorkManager";
import {
getFamilyPortalRedirectURL,
@ -64,10 +68,6 @@ import {
NotificationAttributes,
SetNotificationAttributes,
} from "types/Notification";
import {
getMLSearchConfig,
updateMLSearchConfig,
} from "utils/machineLearning/config";
import {
getUpdateAvailableForDownloadMessage,
getUpdateReadyToInstallMessage,

View file

@ -84,7 +84,10 @@ import {
getSectionSummaries,
} from "services/collectionService";
import downloadManager from "services/download";
import { syncEmbeddings, syncFileEmbeddings } from "services/embeddingService";
import {
syncCLIPEmbeddings,
syncFaceEmbeddings,
} from "services/embeddingService";
import { syncEntities } from "services/entityService";
import { getLocalFiles, syncFiles } from "services/fileService";
import locationSearchService from "services/locationSearchService";
@ -130,6 +133,7 @@ import {
} from "utils/file";
import { isArchivedFile } from "utils/magicMetadata";
import { getSessionExpiredMessage } from "utils/ui";
import { isInternalUserForML } from "utils/user";
import { getLocalFamilyData } from "utils/user/family";
export const DeadCenter = styled("div")`
@ -698,10 +702,10 @@ export default function Gallery() {
await syncTrash(collections, setTrashedFiles);
await syncEntities();
await syncMapEnabled();
await syncEmbeddings();
await syncCLIPEmbeddings();
const electron = globalThis.electron;
if (electron) {
await syncFileEmbeddings();
if (isInternalUserForML() && electron) {
await syncFaceEmbeddings();
}
if (clipService.isPlatformSupported()) {
void clipService.scheduleImageEmbeddingExtraction();

View file

@ -11,7 +11,7 @@ import { Embedding } from "types/embedding";
import { EnteFile } from "types/file";
import { getPersonalFiles } from "utils/file";
import downloadManager from "./download";
import { getLocalEmbeddings, putEmbedding } from "./embeddingService";
import { localCLIPEmbeddings, putEmbedding } from "./embeddingService";
import { getAllLocalFiles, getLocalFiles } from "./fileService";
/** Status of CLIP indexing on the images in the user's local library. */
@ -195,7 +195,7 @@ class CLIPService {
return;
}
const localFiles = getPersonalFiles(await getAllLocalFiles(), user);
const existingEmbeddings = await getLocalEmbeddings();
const existingEmbeddings = await localCLIPEmbeddings();
const pendingFiles = await getNonClipEmbeddingExtractedFiles(
localFiles,
existingEmbeddings,
@ -394,7 +394,7 @@ export const computeClipMatchScore = async (
const initialIndexingStatus = async (): Promise<CLIPIndexingStatus> => {
const user = getData(LS_KEYS.USER);
if (!user) throw new Error("Orphan CLIP indexing without a login");
const allEmbeddings = await getLocalEmbeddings();
const allEmbeddings = await localCLIPEmbeddings();
const localFiles = getPersonalFiles(await getLocalFiles(), user);
const pendingFiles = await getNonClipEmbeddingExtractedFiles(
localFiles,

View file

@ -7,6 +7,7 @@ import HTTPService from "@ente/shared/network/HTTPService";
import { getEndpoint } from "@ente/shared/network/api";
import localForage from "@ente/shared/storage/localForage";
import { getToken } from "@ente/shared/storage/localStorage/helpers";
import { FileML } from "services/machineLearning/machineLearningService";
import type {
Embedding,
EmbeddingModel,
@ -15,31 +16,30 @@ import type {
PutEmbeddingRequest,
} from "types/embedding";
import { EnteFile } from "types/file";
import {
getLatestVersionEmbeddings,
getLatestVersionFileEmbeddings,
} from "utils/embedding";
import { FileML } from "utils/machineLearning/mldataMappers";
import { getLocalCollections } from "./collectionService";
import { getAllLocalFiles } from "./fileService";
import { getLocalTrashedFiles } from "./trashService";
const ENDPOINT = getEndpoint();
const DIFF_LIMIT = 500;
const EMBEDDINGS_TABLE_V1 = "embeddings";
const EMBEDDINGS_TABLE = "embeddings_v2";
/** Local storage key suffix for embedding sync times */
const embeddingSyncTimeLSKeySuffix = "embedding_sync_time";
/** Local storage key for CLIP embeddings. */
const clipEmbeddingsLSKey = "embeddings_v2";
const FILE_EMBEDING_TABLE = "file_embeddings";
const EMBEDDING_SYNC_TIME_TABLE = "embedding_sync_time";
export const getAllLocalEmbeddings = async () => {
/** Return all CLIP embeddings that we have available locally. */
export const localCLIPEmbeddings = async () =>
(await storedCLIPEmbeddings()).filter(({ model }) => model === "onnx-clip");
const storedCLIPEmbeddings = async () => {
const embeddings: Array<Embedding> =
await localForage.getItem<Embedding[]>(EMBEDDINGS_TABLE);
await localForage.getItem<Embedding[]>(clipEmbeddingsLSKey);
if (!embeddings) {
await localForage.removeItem(EMBEDDINGS_TABLE_V1);
await localForage.removeItem(EMBEDDING_SYNC_TIME_TABLE);
await localForage.setItem(EMBEDDINGS_TABLE, []);
// Migrate
await localForage.removeItem("embeddings");
await localForage.removeItem("embedding_sync_time");
await localForage.setItem(clipEmbeddingsLSKey, []);
return [];
}
return embeddings;
@ -54,15 +54,10 @@ export const getFileMLEmbeddings = async (): Promise<FileML[]> => {
return embeddings;
};
export const getLocalEmbeddings = async () => {
const embeddings = await getAllLocalEmbeddings();
return embeddings.filter((embedding) => embedding.model === "onnx-clip");
};
const getModelEmbeddingSyncTime = async (model: EmbeddingModel) => {
return (
(await localForage.getItem<number>(
`${model}-${EMBEDDING_SYNC_TIME_TABLE}`,
`${model}-${embeddingSyncTimeLSKeySuffix}`,
)) ?? 0
);
};
@ -71,13 +66,17 @@ const setModelEmbeddingSyncTime = async (
model: EmbeddingModel,
time: number,
) => {
await localForage.setItem(`${model}-${EMBEDDING_SYNC_TIME_TABLE}`, time);
await localForage.setItem(`${model}-${embeddingSyncTimeLSKeySuffix}`, time);
};
export const syncEmbeddings = async () => {
const models: EmbeddingModel[] = ["onnx-clip"];
/**
 * Fetch new CLIP embeddings from the server and save them locally. Also prune
 * local embeddings for files that no longer exist locally.
*/
export const syncCLIPEmbeddings = async () => {
const model: EmbeddingModel = "onnx-clip";
try {
let allEmbeddings = await getAllLocalEmbeddings();
let allEmbeddings = await storedCLIPEmbeddings();
const localFiles = await getAllLocalFiles();
const hiddenAlbums = await getLocalCollections("hidden");
const localTrashFiles = await getLocalTrashedFiles();
@ -89,79 +88,75 @@ export const syncEmbeddings = async () => {
await cleanupDeletedEmbeddings(
allLocalFiles,
allEmbeddings,
EMBEDDINGS_TABLE,
clipEmbeddingsLSKey,
);
log.info(`Syncing embeddings localCount: ${allEmbeddings.length}`);
for (const model of models) {
let modelLastSinceTime = await getModelEmbeddingSyncTime(model);
log.info(
`Syncing ${model} model's embeddings sinceTime: ${modelLastSinceTime}`,
);
let response: GetEmbeddingDiffResponse;
do {
response = await getEmbeddingsDiff(modelLastSinceTime, model);
if (!response.diff?.length) {
return;
}
const newEmbeddings = await Promise.all(
response.diff.map(async (embedding) => {
try {
const {
encryptedEmbedding,
decryptionHeader,
...rest
} = embedding;
const worker =
await ComlinkCryptoWorker.getInstance();
const fileKey = fileIdToKeyMap.get(
embedding.fileID,
);
if (!fileKey) {
throw Error(CustomError.FILE_NOT_FOUND);
}
const decryptedData = await worker.decryptEmbedding(
encryptedEmbedding,
decryptionHeader,
fileIdToKeyMap.get(embedding.fileID),
);
return {
...rest,
embedding: decryptedData,
} as Embedding;
} catch (e) {
let hasHiddenAlbums = false;
if (e.message === CustomError.FILE_NOT_FOUND) {
hasHiddenAlbums = hiddenAlbums?.length > 0;
}
log.error(
`decryptEmbedding failed for file (hasHiddenAlbums: ${hasHiddenAlbums})`,
e,
);
let modelLastSinceTime = await getModelEmbeddingSyncTime(model);
log.info(
`Syncing ${model} model's embeddings sinceTime: ${modelLastSinceTime}`,
);
let response: GetEmbeddingDiffResponse;
do {
response = await getEmbeddingsDiff(modelLastSinceTime, model);
if (!response.diff?.length) {
return;
}
const newEmbeddings = await Promise.all(
response.diff.map(async (embedding) => {
try {
const {
encryptedEmbedding,
decryptionHeader,
...rest
} = embedding;
const worker = await ComlinkCryptoWorker.getInstance();
const fileKey = fileIdToKeyMap.get(embedding.fileID);
if (!fileKey) {
throw Error(CustomError.FILE_NOT_FOUND);
}
}),
);
allEmbeddings = getLatestVersionEmbeddings([
...allEmbeddings,
...newEmbeddings,
]);
if (response.diff.length) {
modelLastSinceTime = response.diff.slice(-1)[0].updatedAt;
}
await localForage.setItem(EMBEDDINGS_TABLE, allEmbeddings);
await setModelEmbeddingSyncTime(model, modelLastSinceTime);
log.info(
`Syncing embeddings syncedEmbeddingsCount: ${allEmbeddings.length}`,
);
} while (response.diff.length === DIFF_LIMIT);
}
const decryptedData = await worker.decryptEmbedding(
encryptedEmbedding,
decryptionHeader,
fileIdToKeyMap.get(embedding.fileID),
);
return {
...rest,
embedding: decryptedData,
} as Embedding;
} catch (e) {
let hasHiddenAlbums = false;
if (e.message === CustomError.FILE_NOT_FOUND) {
hasHiddenAlbums = hiddenAlbums?.length > 0;
}
log.error(
`decryptEmbedding failed for file (hasHiddenAlbums: ${hasHiddenAlbums})`,
e,
);
}
}),
);
allEmbeddings = getLatestVersionEmbeddings([
...allEmbeddings,
...newEmbeddings,
]);
if (response.diff.length) {
modelLastSinceTime = response.diff.slice(-1)[0].updatedAt;
}
await localForage.setItem(clipEmbeddingsLSKey, allEmbeddings);
await setModelEmbeddingSyncTime(model, modelLastSinceTime);
log.info(
`Syncing embeddings syncedEmbeddingsCount: ${allEmbeddings.length}`,
);
} while (response.diff.length === DIFF_LIMIT);
} catch (e) {
log.error("Sync embeddings failed", e);
}
};
export const syncFileEmbeddings = async () => {
const models: EmbeddingModel[] = ["file-ml-clip-face"];
export const syncFaceEmbeddings = async () => {
const model: EmbeddingModel = "file-ml-clip-face";
try {
let allEmbeddings: FileML[] = await getFileMLEmbeddings();
const localFiles = await getAllLocalFiles();
@ -178,69 +173,99 @@ export const syncFileEmbeddings = async () => {
FILE_EMBEDING_TABLE,
);
log.info(`Syncing embeddings localCount: ${allEmbeddings.length}`);
for (const model of models) {
let modelLastSinceTime = await getModelEmbeddingSyncTime(model);
log.info(
`Syncing ${model} model's embeddings sinceTime: ${modelLastSinceTime}`,
);
let response: GetEmbeddingDiffResponse;
do {
response = await getEmbeddingsDiff(modelLastSinceTime, model);
if (!response.diff?.length) {
return;
}
const newEmbeddings = await Promise.all(
response.diff.map(async (embedding) => {
try {
const worker =
await ComlinkCryptoWorker.getInstance();
const fileKey = fileIdToKeyMap.get(
embedding.fileID,
);
if (!fileKey) {
throw Error(CustomError.FILE_NOT_FOUND);
}
const decryptedData = await worker.decryptMetadata(
embedding.encryptedEmbedding,
embedding.decryptionHeader,
fileIdToKeyMap.get(embedding.fileID),
);
return {
...decryptedData,
updatedAt: embedding.updatedAt,
} as unknown as FileML;
} catch (e) {
let hasHiddenAlbums = false;
if (e.message === CustomError.FILE_NOT_FOUND) {
hasHiddenAlbums = hiddenAlbums?.length > 0;
}
log.error(
`decryptEmbedding failed for file (hasHiddenAlbums: ${hasHiddenAlbums})`,
e,
);
let modelLastSinceTime = await getModelEmbeddingSyncTime(model);
log.info(
`Syncing ${model} model's embeddings sinceTime: ${modelLastSinceTime}`,
);
let response: GetEmbeddingDiffResponse;
do {
response = await getEmbeddingsDiff(modelLastSinceTime, model);
if (!response.diff?.length) {
return;
}
const newEmbeddings = await Promise.all(
response.diff.map(async (embedding) => {
try {
const worker = await ComlinkCryptoWorker.getInstance();
const fileKey = fileIdToKeyMap.get(embedding.fileID);
if (!fileKey) {
throw Error(CustomError.FILE_NOT_FOUND);
}
}),
);
allEmbeddings = getLatestVersionFileEmbeddings([
...allEmbeddings,
...newEmbeddings,
]);
if (response.diff.length) {
modelLastSinceTime = response.diff.slice(-1)[0].updatedAt;
}
await localForage.setItem(FILE_EMBEDING_TABLE, allEmbeddings);
await setModelEmbeddingSyncTime(model, modelLastSinceTime);
log.info(
`Syncing embeddings syncedEmbeddingsCount: ${allEmbeddings.length}`,
);
} while (response.diff.length === DIFF_LIMIT);
}
const decryptedData = await worker.decryptMetadata(
embedding.encryptedEmbedding,
embedding.decryptionHeader,
fileIdToKeyMap.get(embedding.fileID),
);
return {
...decryptedData,
updatedAt: embedding.updatedAt,
} as unknown as FileML;
} catch (e) {
let hasHiddenAlbums = false;
if (e.message === CustomError.FILE_NOT_FOUND) {
hasHiddenAlbums = hiddenAlbums?.length > 0;
}
log.error(
`decryptEmbedding failed for file (hasHiddenAlbums: ${hasHiddenAlbums})`,
e,
);
}
}),
);
allEmbeddings = getLatestVersionFileEmbeddings([
...allEmbeddings,
...newEmbeddings,
]);
if (response.diff.length) {
modelLastSinceTime = response.diff.slice(-1)[0].updatedAt;
}
await localForage.setItem(FILE_EMBEDING_TABLE, allEmbeddings);
await setModelEmbeddingSyncTime(model, modelLastSinceTime);
log.info(
`Syncing embeddings syncedEmbeddingsCount: ${allEmbeddings.length}`,
);
} while (response.diff.length === DIFF_LIMIT);
} catch (e) {
log.error("Sync embeddings failed", e);
}
};
const getLatestVersionEmbeddings = (embeddings: Embedding[]) => {
const latestVersionEntities = new Map<number, Embedding>();
embeddings.forEach((embedding) => {
if (!embedding?.fileID) {
return;
}
const existingEmbeddings = latestVersionEntities.get(embedding.fileID);
if (
!existingEmbeddings ||
existingEmbeddings.updatedAt < embedding.updatedAt
) {
latestVersionEntities.set(embedding.fileID, embedding);
}
});
return Array.from(latestVersionEntities.values());
};
const getLatestVersionFileEmbeddings = (embeddings: FileML[]) => {
const latestVersionEntities = new Map<number, FileML>();
embeddings.forEach((embedding) => {
if (!embedding?.fileID) {
return;
}
const existingEmbeddings = latestVersionEntities.get(embedding.fileID);
if (
!existingEmbeddings ||
existingEmbeddings.updatedAt < embedding.updatedAt
) {
latestVersionEntities.set(embedding.fileID, embedding);
}
});
return Array.from(latestVersionEntities.values());
};
export const getEmbeddingsDiff = async (
sinceTime: number,
model: EmbeddingModel,
@ -251,7 +276,7 @@ export const getEmbeddingsDiff = async (
return;
}
const response = await HTTPService.get(
`${ENDPOINT}/embeddings/diff`,
`${getEndpoint()}/embeddings/diff`,
{
sinceTime,
limit: DIFF_LIMIT,
@ -280,7 +305,7 @@ export const putEmbedding = async (
throw Error(CustomError.TOKEN_MISSING);
}
const resp = await HTTPService.put(
`${ENDPOINT}/embeddings`,
`${getEndpoint()}/embeddings`,
putEmbeddingReq,
null,
{
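For orientation, the embedding sync above follows a paginated diff-pull: request entries newer than a saved sinceTime, merge them in, persist the newest updatedAt as the new sync time, and repeat while pages come back full. A minimal sketch of just that skeleton (pullDiffsSince is a hypothetical name; getEmbeddingsDiff, GetEmbeddingDiffResponse, EmbeddingModel, and DIFF_LIMIT are the names used in this file; decryption and persistence are elided):

async function pullDiffsSince(model: EmbeddingModel, sinceTime: number) {
    let lastSinceTime = sinceTime;
    let response: GetEmbeddingDiffResponse;
    do {
        response = await getEmbeddingsDiff(lastSinceTime, model);
        if (!response.diff?.length) break;
        // ...decrypt response.diff and merge it into the local store here...
        lastSinceTime = response.diff.slice(-1)[0].updatedAt;
        // A full page (DIFF_LIMIT entries) means more may be pending.
    } while (response.diff.length === DIFF_LIMIT);
    return lastSinceTime;
}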

View file

@ -1,3 +1,5 @@
import { Matrix } from "ml-matrix";
import { Point } from "services/ml/geom";
import {
FaceAlignment,
FaceAlignmentMethod,
@ -5,7 +7,7 @@ import {
FaceDetection,
Versioned,
} from "services/ml/types";
import { getArcfaceAlignment } from "utils/machineLearning/faceAlign";
import { getSimilarityTransformation } from "similarity-transformation";
class ArcfaceAlignmentService implements FaceAlignmentService {
public method: Versioned<FaceAlignmentMethod>;
@ -23,3 +25,86 @@ class ArcfaceAlignmentService implements FaceAlignmentService {
}
export default new ArcfaceAlignmentService();
const ARCFACE_LANDMARKS = [
[38.2946, 51.6963],
[73.5318, 51.5014],
[56.0252, 71.7366],
[56.1396, 92.2848],
] as Array<[number, number]>;
const ARCFACE_LANDMARKS_FACE_SIZE = 112;
const ARC_FACE_5_LANDMARKS = [
[38.2946, 51.6963],
[73.5318, 51.5014],
[56.0252, 71.7366],
[41.5493, 92.3655],
[70.7299, 92.2041],
] as Array<[number, number]>;
export function getArcfaceAlignment(
faceDetection: FaceDetection,
): FaceAlignment {
const landmarkCount = faceDetection.landmarks.length;
return getFaceAlignmentUsingSimilarityTransform(
faceDetection,
normalizeLandmarks(
landmarkCount === 5 ? ARC_FACE_5_LANDMARKS : ARCFACE_LANDMARKS,
ARCFACE_LANDMARKS_FACE_SIZE,
),
);
}
function getFaceAlignmentUsingSimilarityTransform(
faceDetection: FaceDetection,
alignedLandmarks: Array<[number, number]>,
// alignmentMethod: Versioned<FaceAlignmentMethod>
): FaceAlignment {
const landmarksMat = new Matrix(
faceDetection.landmarks
.map((p) => [p.x, p.y])
.slice(0, alignedLandmarks.length),
).transpose();
const alignedLandmarksMat = new Matrix(alignedLandmarks).transpose();
const simTransform = getSimilarityTransformation(
landmarksMat,
alignedLandmarksMat,
);
const RS = Matrix.mul(simTransform.rotation, simTransform.scale);
const TR = simTransform.translation;
const affineMatrix = [
[RS.get(0, 0), RS.get(0, 1), TR.get(0, 0)],
[RS.get(1, 0), RS.get(1, 1), TR.get(1, 0)],
[0, 0, 1],
];
const size = 1 / simTransform.scale;
const meanTranslation = simTransform.toMean.sub(0.5).mul(size);
const centerMat = simTransform.fromMean.sub(meanTranslation);
const center = new Point(centerMat.get(0, 0), centerMat.get(1, 0));
const rotation = -Math.atan2(
simTransform.rotation.get(0, 1),
simTransform.rotation.get(0, 0),
);
// log.info({ affineMatrix, meanTranslation, centerMat, center, toMean: simTransform.toMean, fromMean: simTransform.fromMean, size });
return {
affineMatrix,
center,
size,
rotation,
};
}
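// Reading the math above: RS folds the similarity transform's rotation and
// isotropic scale into one 2x2 block, TR is the translation column, and
// together they form the 2x3 affine warp (plus the homogeneous [0, 0, 1] row).
// The rotation angle is recovered as -atan2(R[0][1], R[0][0]), and
// size = 1 / scale re-expresses the transform as a square crop side length
// in source-image coordinates.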
function normalizeLandmarks(
landmarks: Array<[number, number]>,
faceSize: number,
): Array<[number, number]> {
return landmarks.map((landmark) =>
landmark.map((p) => p / faceSize),
) as Array<[number, number]>;
}

View file

@ -1,4 +1,6 @@
import { Box, enlargeBox } from "services/ml/geom";
import {
FaceAlignment,
FaceCrop,
FaceCropConfig,
FaceCropMethod,
@ -6,8 +8,8 @@ import {
FaceDetection,
Versioned,
} from "services/ml/types";
import { getArcfaceAlignment } from "utils/machineLearning/faceAlign";
import { getFaceCrop } from "utils/machineLearning/faceCrop";
import { cropWithRotation } from "utils/image";
import { getArcfaceAlignment } from "./arcfaceAlignmentService";
class ArcFaceCropService implements FaceCropService {
public method: Versioned<FaceCropMethod>;
@ -32,3 +34,27 @@ class ArcFaceCropService implements FaceCropService {
}
export default new ArcFaceCropService();
export function getFaceCrop(
imageBitmap: ImageBitmap,
alignment: FaceAlignment,
config: FaceCropConfig,
): FaceCrop {
const alignmentBox = new Box({
x: alignment.center.x - alignment.size / 2,
y: alignment.center.y - alignment.size / 2,
width: alignment.size,
height: alignment.size,
}).round();
const scaleForPadding = 1 + config.padding * 2;
const paddedBox = enlargeBox(alignmentBox, scaleForPadding).round();
const faceImageBitmap = cropWithRotation(imageBitmap, paddedBox, 0, {
width: config.maxSize,
height: config.maxSize,
});
return {
image: faceImageBitmap,
imageBox: paddedBox,
};
}

View file

@ -1,22 +1,20 @@
import { openCache } from "@/next/blob-cache";
import log from "@/next/log";
import mlIDbStorage from "services/ml/db";
import {
DetectedFace,
Face,
MLSyncContext,
MLSyncFileContext,
type FaceAlignment,
type Versioned,
} from "services/ml/types";
import { imageBitmapToBlob } from "utils/image";
import {
areFaceIdsSame,
extractFaceImagesToFloat32,
import { imageBitmapToBlob, warpAffineFloat32List } from "utils/image";
import ReaderService, {
getFaceId,
getLocalFile,
getOriginalImageBitmap,
isDifferentOrOld,
} from "utils/machineLearning";
import mlIDbStorage from "utils/storage/mlIDbStorage";
import ReaderService from "./readerService";
} from "./readerService";
class FaceService {
async syncFileFaceDetections(
@ -304,3 +302,58 @@ class FaceService {
}
export default new FaceService();
export function areFaceIdsSame(ofFaces: Array<Face>, toFaces: Array<Face>) {
if (
(ofFaces === null || ofFaces === undefined) &&
(toFaces === null || toFaces === undefined)
) {
return true;
}
return primitiveArrayEquals(
ofFaces?.map((f) => f.id),
toFaces?.map((f) => f.id),
);
}
function primitiveArrayEquals(a, b) {
return (
Array.isArray(a) &&
Array.isArray(b) &&
a.length === b.length &&
a.every((val, index) => val === b[index])
);
}
export function isDifferentOrOld(
method: Versioned<string>,
thanMethod: Versioned<string>,
) {
return (
!method ||
method.value !== thanMethod.value ||
method.version < thanMethod.version
);
}
async function extractFaceImagesToFloat32(
faceAlignments: Array<FaceAlignment>,
faceSize: number,
image: ImageBitmap,
): Promise<Float32Array> {
const faceData = new Float32Array(
faceAlignments.length * faceSize * faceSize * 3,
);
for (let i = 0; i < faceAlignments.length; i++) {
const alignedFace = faceAlignments[i];
const faceDataOffset = i * faceSize * faceSize * 3;
warpAffineFloat32List(
image,
alignedFace,
faceSize,
faceData,
faceDataOffset,
);
}
return faceData;
}

View file

@ -1,216 +0,0 @@
import { haveWindow } from "@/next/env";
import log from "@/next/log";
import { ComlinkWorker } from "@/next/worker/comlink-worker";
import { getDedicatedCryptoWorker } from "@ente/shared/crypto";
import { DedicatedCryptoWorker } from "@ente/shared/crypto/internal/crypto.worker";
import PQueue from "p-queue";
import {
BlurDetectionMethod,
BlurDetectionService,
ClusteringMethod,
ClusteringService,
Face,
FaceAlignmentMethod,
FaceAlignmentService,
FaceCropMethod,
FaceCropService,
FaceDetectionMethod,
FaceDetectionService,
FaceEmbeddingMethod,
FaceEmbeddingService,
MLLibraryData,
MLSyncConfig,
MLSyncContext,
} from "services/ml/types";
import { EnteFile } from "types/file";
import { logQueueStats } from "utils/machineLearning";
import arcfaceAlignmentService from "./arcfaceAlignmentService";
import arcfaceCropService from "./arcfaceCropService";
import dbscanClusteringService from "./dbscanClusteringService";
import hdbscanClusteringService from "./hdbscanClusteringService";
import laplacianBlurDetectionService from "./laplacianBlurDetectionService";
import mobileFaceNetEmbeddingService from "./mobileFaceNetEmbeddingService";
import yoloFaceDetectionService from "./yoloFaceDetectionService";
export class MLFactory {
public static getFaceDetectionService(
method: FaceDetectionMethod,
): FaceDetectionService {
if (method === "YoloFace") {
return yoloFaceDetectionService;
}
throw Error("Unknon face detection method: " + method);
}
public static getFaceCropService(method: FaceCropMethod) {
if (method === "ArcFace") {
return arcfaceCropService;
}
throw Error("Unknon face crop method: " + method);
}
public static getFaceAlignmentService(
method: FaceAlignmentMethod,
): FaceAlignmentService {
if (method === "ArcFace") {
return arcfaceAlignmentService;
}
throw Error("Unknon face alignment method: " + method);
}
public static getBlurDetectionService(
method: BlurDetectionMethod,
): BlurDetectionService {
if (method === "Laplacian") {
return laplacianBlurDetectionService;
}
throw Error("Unknon blur detection method: " + method);
}
public static getFaceEmbeddingService(
method: FaceEmbeddingMethod,
): FaceEmbeddingService {
if (method === "MobileFaceNet") {
return mobileFaceNetEmbeddingService;
}
throw Error("Unknon face embedding method: " + method);
}
public static getClusteringService(
method: ClusteringMethod,
): ClusteringService {
if (method === "Hdbscan") {
return hdbscanClusteringService;
}
if (method === "Dbscan") {
return dbscanClusteringService;
}
throw Error("Unknon clustering method: " + method);
}
public static getMLSyncContext(
token: string,
userID: number,
config: MLSyncConfig,
shouldUpdateMLVersion: boolean = true,
) {
return new LocalMLSyncContext(
token,
userID,
config,
shouldUpdateMLVersion,
);
}
}
export class LocalMLSyncContext implements MLSyncContext {
public token: string;
public userID: number;
public config: MLSyncConfig;
public shouldUpdateMLVersion: boolean;
public faceDetectionService: FaceDetectionService;
public faceCropService: FaceCropService;
public faceAlignmentService: FaceAlignmentService;
public blurDetectionService: BlurDetectionService;
public faceEmbeddingService: FaceEmbeddingService;
public faceClusteringService: ClusteringService;
public localFilesMap: Map<number, EnteFile>;
public outOfSyncFiles: EnteFile[];
public nSyncedFiles: number;
public nSyncedFaces: number;
public allSyncedFacesMap?: Map<number, Array<Face>>;
public error?: Error;
public mlLibraryData: MLLibraryData;
public syncQueue: PQueue;
// TODO: whether to limit concurrent downloads
// private downloadQueue: PQueue;
private concurrency: number;
private comlinkCryptoWorker: Array<
ComlinkWorker<typeof DedicatedCryptoWorker>
>;
private enteWorkers: Array<any>;
constructor(
token: string,
userID: number,
config: MLSyncConfig,
shouldUpdateMLVersion: boolean = true,
concurrency?: number,
) {
this.token = token;
this.userID = userID;
this.config = config;
this.shouldUpdateMLVersion = shouldUpdateMLVersion;
this.faceDetectionService = MLFactory.getFaceDetectionService(
this.config.faceDetection.method,
);
this.faceCropService = MLFactory.getFaceCropService(
this.config.faceCrop.method,
);
this.faceAlignmentService = MLFactory.getFaceAlignmentService(
this.config.faceAlignment.method,
);
this.blurDetectionService = MLFactory.getBlurDetectionService(
this.config.blurDetection.method,
);
this.faceEmbeddingService = MLFactory.getFaceEmbeddingService(
this.config.faceEmbedding.method,
);
this.faceClusteringService = MLFactory.getClusteringService(
this.config.faceClustering.method,
);
this.outOfSyncFiles = [];
this.nSyncedFiles = 0;
this.nSyncedFaces = 0;
this.concurrency = concurrency ?? getConcurrency();
log.info("Using concurrency: ", this.concurrency);
// a timeout is added on the downloads themselves; a timeout on the queue
// would keep the operation open till the worker is terminated
this.syncQueue = new PQueue({ concurrency: this.concurrency });
logQueueStats(this.syncQueue, "sync");
// this.downloadQueue = new PQueue({ concurrency: 1 });
// logQueueStats(this.downloadQueue, 'download');
this.comlinkCryptoWorker = new Array(this.concurrency);
this.enteWorkers = new Array(this.concurrency);
}
public async getEnteWorker(id: number): Promise<any> {
const wid = id % this.enteWorkers.length;
console.log("getEnteWorker: ", id, wid);
if (!this.enteWorkers[wid]) {
this.comlinkCryptoWorker[wid] = getDedicatedCryptoWorker();
this.enteWorkers[wid] = await this.comlinkCryptoWorker[wid].remote;
}
return this.enteWorkers[wid];
}
public async dispose() {
this.localFilesMap = undefined;
await this.syncQueue.onIdle();
this.syncQueue.removeAllListeners();
for (const enteComlinkWorker of this.comlinkCryptoWorker) {
enteComlinkWorker?.terminate();
}
}
}
export const getConcurrency = () =>
haveWindow() && Math.max(2, Math.ceil(navigator.hardwareConcurrency / 2));

View file

@ -1,25 +1,330 @@
import { haveWindow } from "@/next/env";
import log from "@/next/log";
import { ComlinkWorker } from "@/next/worker/comlink-worker";
import { APPS } from "@ente/shared/apps/constants";
import ComlinkCryptoWorker from "@ente/shared/crypto";
import ComlinkCryptoWorker, {
getDedicatedCryptoWorker,
} from "@ente/shared/crypto";
import { DedicatedCryptoWorker } from "@ente/shared/crypto/internal/crypto.worker";
import { CustomError, parseUploadErrorCodes } from "@ente/shared/error";
import { MAX_ML_SYNC_ERROR_COUNT } from "constants/mlConfig";
import PQueue from "p-queue";
import downloadManager from "services/download";
import { putEmbedding } from "services/embeddingService";
import { getLocalFiles } from "services/fileService";
import mlIDbStorage, {
ML_SEARCH_CONFIG_NAME,
ML_SYNC_CONFIG_NAME,
ML_SYNC_JOB_CONFIG_NAME,
} from "services/ml/db";
import {
BlurDetectionMethod,
BlurDetectionService,
ClusteringMethod,
ClusteringService,
Face,
FaceAlignmentMethod,
FaceAlignmentService,
FaceCropMethod,
FaceCropService,
FaceDetection,
FaceDetectionMethod,
FaceDetectionService,
FaceEmbeddingMethod,
FaceEmbeddingService,
Landmark,
MLLibraryData,
MLSearchConfig,
MLSyncConfig,
MLSyncContext,
MLSyncFileContext,
MLSyncResult,
MlFileData,
} from "services/ml/types";
import { EnteFile } from "types/file";
import { getMLSyncConfig } from "utils/machineLearning/config";
import { LocalFileMlDataToServerFileMl } from "utils/machineLearning/mldataMappers";
import mlIDbStorage from "utils/storage/mlIDbStorage";
import { isInternalUserForML } from "utils/user";
import arcfaceAlignmentService from "./arcfaceAlignmentService";
import arcfaceCropService from "./arcfaceCropService";
import dbscanClusteringService from "./dbscanClusteringService";
import FaceService from "./faceService";
import { MLFactory } from "./machineLearningFactory";
import hdbscanClusteringService from "./hdbscanClusteringService";
import laplacianBlurDetectionService from "./laplacianBlurDetectionService";
import type { JobConfig } from "./mlWorkManager";
import mobileFaceNetEmbeddingService from "./mobileFaceNetEmbeddingService";
import PeopleService from "./peopleService";
import ReaderService from "./readerService";
import yoloFaceDetectionService from "./yoloFaceDetectionService";
export const DEFAULT_ML_SYNC_JOB_CONFIG: JobConfig = {
intervalSec: 5,
// TODO: finalize this after seeing effects on and from machine sleep
maxItervalSec: 960,
backoffMultiplier: 2,
};
export const DEFAULT_ML_SYNC_CONFIG: MLSyncConfig = {
batchSize: 200,
imageSource: "Original",
faceDetection: {
method: "YoloFace",
},
faceCrop: {
enabled: true,
method: "ArcFace",
padding: 0.25,
maxSize: 256,
blobOptions: {
type: "image/jpeg",
quality: 0.8,
},
},
faceAlignment: {
method: "ArcFace",
},
blurDetection: {
method: "Laplacian",
threshold: 15,
},
faceEmbedding: {
method: "MobileFaceNet",
faceSize: 112,
generateTsne: true,
},
faceClustering: {
method: "Hdbscan",
minClusterSize: 3,
minSamples: 5,
clusterSelectionEpsilon: 0.6,
clusterSelectionMethod: "leaf",
minInputSize: 50,
// maxDistanceInsideCluster: 0.4,
generateDebugInfo: true,
},
mlVersion: 3,
};
export const DEFAULT_ML_SEARCH_CONFIG: MLSearchConfig = {
enabled: false,
};
export const MAX_ML_SYNC_ERROR_COUNT = 1;
export async function getMLSyncJobConfig() {
return mlIDbStorage.getConfig(
ML_SYNC_JOB_CONFIG_NAME,
DEFAULT_ML_SYNC_JOB_CONFIG,
);
}
export async function getMLSyncConfig() {
return mlIDbStorage.getConfig(ML_SYNC_CONFIG_NAME, DEFAULT_ML_SYNC_CONFIG);
}
export async function getMLSearchConfig() {
if (isInternalUserForML()) {
return mlIDbStorage.getConfig(
ML_SEARCH_CONFIG_NAME,
DEFAULT_ML_SEARCH_CONFIG,
);
}
// Force disabled for everyone else while we finalize it to avoid redundant
// reindexing for users.
return DEFAULT_ML_SEARCH_CONFIG;
}
export async function updateMLSyncJobConfig(newConfig: JobConfig) {
return mlIDbStorage.putConfig(ML_SYNC_JOB_CONFIG_NAME, newConfig);
}
export async function updateMLSyncConfig(newConfig: MLSyncConfig) {
return mlIDbStorage.putConfig(ML_SYNC_CONFIG_NAME, newConfig);
}
export async function updateMLSearchConfig(newConfig: MLSearchConfig) {
return mlIDbStorage.putConfig(ML_SEARCH_CONFIG_NAME, newConfig);
}
export class MLFactory {
public static getFaceDetectionService(
method: FaceDetectionMethod,
): FaceDetectionService {
if (method === "YoloFace") {
return yoloFaceDetectionService;
}
throw Error("Unknon face detection method: " + method);
}
public static getFaceCropService(method: FaceCropMethod) {
if (method === "ArcFace") {
return arcfaceCropService;
}
throw Error("Unknon face crop method: " + method);
}
public static getFaceAlignmentService(
method: FaceAlignmentMethod,
): FaceAlignmentService {
if (method === "ArcFace") {
return arcfaceAlignmentService;
}
throw Error("Unknon face alignment method: " + method);
}
public static getBlurDetectionService(
method: BlurDetectionMethod,
): BlurDetectionService {
if (method === "Laplacian") {
return laplacianBlurDetectionService;
}
throw Error("Unknon blur detection method: " + method);
}
public static getFaceEmbeddingService(
method: FaceEmbeddingMethod,
): FaceEmbeddingService {
if (method === "MobileFaceNet") {
return mobileFaceNetEmbeddingService;
}
throw Error("Unknon face embedding method: " + method);
}
public static getClusteringService(
method: ClusteringMethod,
): ClusteringService {
if (method === "Hdbscan") {
return hdbscanClusteringService;
}
if (method === "Dbscan") {
return dbscanClusteringService;
}
throw Error("Unknon clustering method: " + method);
}
public static getMLSyncContext(
token: string,
userID: number,
config: MLSyncConfig,
shouldUpdateMLVersion: boolean = true,
) {
return new LocalMLSyncContext(
token,
userID,
config,
shouldUpdateMLVersion,
);
}
}
export class LocalMLSyncContext implements MLSyncContext {
public token: string;
public userID: number;
public config: MLSyncConfig;
public shouldUpdateMLVersion: boolean;
public faceDetectionService: FaceDetectionService;
public faceCropService: FaceCropService;
public faceAlignmentService: FaceAlignmentService;
public blurDetectionService: BlurDetectionService;
public faceEmbeddingService: FaceEmbeddingService;
public faceClusteringService: ClusteringService;
public localFilesMap: Map<number, EnteFile>;
public outOfSyncFiles: EnteFile[];
public nSyncedFiles: number;
public nSyncedFaces: number;
public allSyncedFacesMap?: Map<number, Array<Face>>;
public error?: Error;
public mlLibraryData: MLLibraryData;
public syncQueue: PQueue;
// TODO: whether to limit concurrent downloads
// private downloadQueue: PQueue;
private concurrency: number;
private comlinkCryptoWorker: Array<
ComlinkWorker<typeof DedicatedCryptoWorker>
>;
private enteWorkers: Array<any>;
constructor(
token: string,
userID: number,
config: MLSyncConfig,
shouldUpdateMLVersion: boolean = true,
concurrency?: number,
) {
this.token = token;
this.userID = userID;
this.config = config;
this.shouldUpdateMLVersion = shouldUpdateMLVersion;
this.faceDetectionService = MLFactory.getFaceDetectionService(
this.config.faceDetection.method,
);
this.faceCropService = MLFactory.getFaceCropService(
this.config.faceCrop.method,
);
this.faceAlignmentService = MLFactory.getFaceAlignmentService(
this.config.faceAlignment.method,
);
this.blurDetectionService = MLFactory.getBlurDetectionService(
this.config.blurDetection.method,
);
this.faceEmbeddingService = MLFactory.getFaceEmbeddingService(
this.config.faceEmbedding.method,
);
this.faceClusteringService = MLFactory.getClusteringService(
this.config.faceClustering.method,
);
this.outOfSyncFiles = [];
this.nSyncedFiles = 0;
this.nSyncedFaces = 0;
this.concurrency = concurrency ?? getConcurrency();
log.info("Using concurrency: ", this.concurrency);
// a timeout is added on the downloads themselves; a timeout on the queue
// would keep the operation open till the worker is terminated
this.syncQueue = new PQueue({ concurrency: this.concurrency });
logQueueStats(this.syncQueue, "sync");
// this.downloadQueue = new PQueue({ concurrency: 1 });
// logQueueStats(this.downloadQueue, 'download');
this.comlinkCryptoWorker = new Array(this.concurrency);
this.enteWorkers = new Array(this.concurrency);
}
public async getEnteWorker(id: number): Promise<any> {
const wid = id % this.enteWorkers.length;
console.log("getEnteWorker: ", id, wid);
if (!this.enteWorkers[wid]) {
this.comlinkCryptoWorker[wid] = getDedicatedCryptoWorker();
this.enteWorkers[wid] = await this.comlinkCryptoWorker[wid].remote;
}
return this.enteWorkers[wid];
}
public async dispose() {
this.localFilesMap = undefined;
await this.syncQueue.onIdle();
this.syncQueue.removeAllListeners();
for (const enteComlinkWorker of this.comlinkCryptoWorker) {
enteComlinkWorker?.terminate();
}
}
}
export const getConcurrency = () =>
haveWindow() && Math.max(2, Math.ceil(navigator.hardwareConcurrency / 2));
class MachineLearningService {
private localSyncContext: Promise<MLSyncContext>;
@ -445,3 +750,160 @@ class MachineLearningService {
}
export default new MachineLearningService();
export interface FileML extends ServerFileMl {
updatedAt: number;
}
class ServerFileMl {
public fileID: number;
public height?: number;
public width?: number;
public faceEmbedding: ServerFaceEmbeddings;
public constructor(
fileID: number,
faceEmbedding: ServerFaceEmbeddings,
height?: number,
width?: number,
) {
this.fileID = fileID;
this.height = height;
this.width = width;
this.faceEmbedding = faceEmbedding;
}
}
class ServerFaceEmbeddings {
public faces: ServerFace[];
public version: number;
public client?: string;
public error?: boolean;
public constructor(
faces: ServerFace[],
version: number,
client?: string,
error?: boolean,
) {
this.faces = faces;
this.version = version;
this.client = client;
this.error = error;
}
}
class ServerFace {
public faceID: string;
public embeddings: number[];
public detection: ServerDetection;
public score: number;
public blur: number;
public constructor(
faceID: string,
embeddings: number[],
detection: ServerDetection,
score: number,
blur: number,
) {
this.faceID = faceID;
this.embeddings = embeddings;
this.detection = detection;
this.score = score;
this.blur = blur;
}
}
class ServerDetection {
public box: ServerFaceBox;
public landmarks: Landmark[];
public constructor(box: ServerFaceBox, landmarks: Landmark[]) {
this.box = box;
this.landmarks = landmarks;
}
}
class ServerFaceBox {
public xMin: number;
public yMin: number;
public width: number;
public height: number;
public constructor(
xMin: number,
yMin: number,
width: number,
height: number,
) {
this.xMin = xMin;
this.yMin = yMin;
this.width = width;
this.height = height;
}
}
function LocalFileMlDataToServerFileMl(
localFileMlData: MlFileData,
): ServerFileMl {
if (
localFileMlData.errorCount > 0 &&
localFileMlData.lastErrorMessage !== undefined
) {
return null;
}
const imageDimensions = localFileMlData.imageDimensions;
const faces: ServerFace[] = [];
for (let i = 0; i < localFileMlData.faces.length; i++) {
const face: Face = localFileMlData.faces[i];
const faceID = face.id;
const embedding = face.embedding;
const score = face.detection.probability;
const blur = face.blurValue;
const detection: FaceDetection = face.detection;
const box = detection.box;
const landmarks = detection.landmarks;
const newBox = new ServerFaceBox(box.x, box.y, box.width, box.height);
const newLandmarks: Landmark[] = [];
for (let j = 0; j < landmarks.length; j++) {
newLandmarks.push({
x: landmarks[j].x,
y: landmarks[j].y,
} as Landmark);
}
const newFaceObject = new ServerFace(
faceID,
Array.from(embedding),
new ServerDetection(newBox, newLandmarks),
score,
blur,
);
faces.push(newFaceObject);
}
const faceEmbeddings = new ServerFaceEmbeddings(
faces,
1,
localFileMlData.lastErrorMessage,
);
return new ServerFileMl(
localFileMlData.fileId,
faceEmbeddings,
imageDimensions.height,
imageDimensions.width,
);
}
export function logQueueStats(queue: PQueue, name: string) {
queue.on("active", () =>
log.info(
`queuestats: ${name}: Active, Size: ${queue.size} Pending: ${queue.pending}`,
),
);
queue.on("idle", () => log.info(`queuestats: ${name}: Idle`));
queue.on("error", (error) =>
console.error(`queuestats: ${name}: Error, `, error),
);
}

View file

@ -5,20 +5,110 @@ import { eventBus, Events } from "@ente/shared/events";
import { getToken, getUserID } from "@ente/shared/storage/localStorage/helpers";
import debounce from "debounce";
import PQueue from "p-queue";
import { getMLSyncJobConfig } from "services/machineLearning/machineLearningService";
import mlIDbStorage from "services/ml/db";
import { MLSyncResult } from "services/ml/types";
import { JobResult } from "types/common/job";
import { EnteFile } from "types/file";
import { getDedicatedMLWorker } from "utils/comlink/ComlinkMLWorker";
import { SimpleJob } from "utils/common/job";
import { logQueueStats } from "utils/machineLearning";
import { getMLSyncJobConfig } from "utils/machineLearning/config";
import mlIDbStorage from "utils/storage/mlIDbStorage";
import { DedicatedMLWorker } from "worker/ml.worker";
import { logQueueStats } from "./machineLearningService";
const LIVE_SYNC_IDLE_DEBOUNCE_SEC = 30;
const LIVE_SYNC_QUEUE_TIMEOUT_SEC = 300;
const LOCAL_FILES_UPDATED_DEBOUNCE_SEC = 30;
export type JobState = "Scheduled" | "Running" | "NotScheduled";
export interface JobConfig {
intervalSec: number;
maxItervalSec: number;
backoffMultiplier: number;
}
export interface JobResult {
shouldBackoff: boolean;
}
export class SimpleJob<R extends JobResult> {
private config: JobConfig;
private runCallback: () => Promise<R>;
private state: JobState;
private stopped: boolean;
private intervalSec: number;
private nextTimeoutId: ReturnType<typeof setTimeout>;
constructor(config: JobConfig, runCallback: () => Promise<R>) {
this.config = config;
this.runCallback = runCallback;
this.state = "NotScheduled";
this.stopped = true;
this.intervalSec = this.config.intervalSec;
}
public resetInterval() {
this.intervalSec = this.config.intervalSec;
}
public start() {
this.stopped = false;
this.resetInterval();
if (this.state !== "Running") {
this.scheduleNext();
} else {
log.info("Job already running, not scheduling");
}
}
private scheduleNext() {
if (this.state === "Scheduled" || this.nextTimeoutId) {
this.clearScheduled();
}
this.nextTimeoutId = setTimeout(
() => this.run(),
this.intervalSec * 1000,
);
this.state = "Scheduled";
log.info("Scheduled next job after: ", this.intervalSec);
}
async run() {
this.nextTimeoutId = undefined;
this.state = "Running";
try {
const jobResult = await this.runCallback();
if (jobResult && jobResult.shouldBackoff) {
this.intervalSec = Math.min(
this.config.maxItervalSec,
this.intervalSec * this.config.backoffMultiplier,
);
} else {
this.resetInterval();
}
log.info("Job completed");
} catch (e) {
console.error("Error while running Job: ", e);
} finally {
this.state = "NotScheduled";
!this.stopped && this.scheduleNext();
}
}
// currently the client is responsible for terminating a running job
public stop() {
this.stopped = true;
this.clearScheduled();
}
private clearScheduled() {
clearTimeout(this.nextTimeoutId);
this.nextTimeoutId = undefined;
this.state = "NotScheduled";
log.info("Cleared next job");
}
}
export interface MLSyncJobResult extends JobResult {
mlSyncResult: MLSyncResult;
}
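For context, a minimal usage sketch of the SimpleJob class moved into this file (the config values mirror DEFAULT_ML_SYNC_JOB_CONFIG; the callback body is illustrative):

const config: JobConfig = {
    intervalSec: 5,
    maxItervalSec: 960,
    backoffMultiplier: 2,
};
const job = new SimpleJob<JobResult>(config, async () => {
    const didWork = false; // placeholder: do one unit of work here
    // Returning shouldBackoff: true multiplies the next interval by
    // backoffMultiplier (capped at maxItervalSec); false resets it.
    return { shouldBackoff: !didWork };
});
job.start(); // first run fires after intervalSec seconds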

View file

@ -1,14 +1,8 @@
import log from "@/next/log";
import mlIDbStorage from "services/ml/db";
import { Face, MLSyncContext, Person } from "services/ml/types";
import {
findFirstIfSorted,
getAllFacesFromMap,
getLocalFile,
getOriginalImageBitmap,
isDifferentOrOld,
} from "utils/machineLearning";
import mlIDbStorage from "utils/storage/mlIDbStorage";
import FaceService from "./faceService";
import FaceService, { isDifferentOrOld } from "./faceService";
import { getLocalFile, getOriginalImageBitmap } from "./readerService";
class PeopleService {
async syncPeopleIndex(syncContext: MLSyncContext) {
@ -92,3 +86,28 @@ class PeopleService {
}
export default new PeopleService();
function findFirstIfSorted<T>(
elements: Array<T>,
comparator: (a: T, b: T) => number,
) {
if (!elements || elements.length < 1) {
return;
}
let first = elements[0];
for (let i = 1; i < elements.length; i++) {
const comp = comparator(elements[i], first);
if (comp < 0) {
first = elements[i];
}
}
return first;
}
function getAllFacesFromMap(allFacesMap: Map<number, Array<Face>>) {
const allFaces = [...allFacesMap.values()].flat();
return allFaces;
}

View file

@ -1,11 +1,18 @@
import { FILE_TYPE } from "@/media/file-type";
import { decodeLivePhoto } from "@/media/live-photo";
import log from "@/next/log";
import { MLSyncContext, MLSyncFileContext } from "services/ml/types";
import PQueue from "p-queue";
import DownloadManager from "services/download";
import { getLocalFiles } from "services/fileService";
import { Dimensions } from "services/ml/geom";
import {
getLocalFileImageBitmap,
getOriginalImageBitmap,
getThumbnailImageBitmap,
} from "utils/machineLearning";
DetectedFace,
MLSyncContext,
MLSyncFileContext,
} from "services/ml/types";
import { EnteFile } from "types/file";
import { getRenderableImage } from "utils/file";
import { clamp } from "utils/image";
class ReaderService {
async getImageBitmap(
@ -55,3 +62,95 @@ class ReaderService {
}
}
export default new ReaderService();
export async function getLocalFile(fileId: number) {
const localFiles = await getLocalFiles();
return localFiles.find((f) => f.id === fileId);
}
export function getFaceId(detectedFace: DetectedFace, imageDims: Dimensions) {
const xMin = clamp(
detectedFace.detection.box.x / imageDims.width,
0.0,
0.999999,
)
.toFixed(5)
.substring(2);
const yMin = clamp(
detectedFace.detection.box.y / imageDims.height,
0.0,
0.999999,
)
.toFixed(5)
.substring(2);
const xMax = clamp(
(detectedFace.detection.box.x + detectedFace.detection.box.width) /
imageDims.width,
0.0,
0.999999,
)
.toFixed(5)
.substring(2);
const yMax = clamp(
(detectedFace.detection.box.y + detectedFace.detection.box.height) /
imageDims.height,
0.0,
0.999999,
)
.toFixed(5)
.substring(2);
const rawFaceID = `${xMin}_${yMin}_${xMax}_${yMax}`;
const faceID = `${detectedFace.fileId}_${rawFaceID}`;
return faceID;
}
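// For example (hypothetical numbers): a 100x100 detection box at (40, 50) in a
// 1000x1000 image normalizes to [0.04, 0.05, 0.14, 0.15], so for fileId 123
// the resulting faceID would be "123_04000_05000_14000_15000".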
async function getImageBlobBitmap(blob: Blob): Promise<ImageBitmap> {
return await createImageBitmap(blob);
}
async function getOriginalFile(file: EnteFile, queue?: PQueue) {
let fileStream;
if (queue) {
fileStream = await queue.add(() => DownloadManager.getFile(file));
} else {
fileStream = await DownloadManager.getFile(file);
}
return new Response(fileStream).blob();
}
async function getOriginalConvertedFile(file: EnteFile, queue?: PQueue) {
const fileBlob = await getOriginalFile(file, queue);
if (file.metadata.fileType === FILE_TYPE.IMAGE) {
return await getRenderableImage(file.metadata.title, fileBlob);
} else {
const { imageFileName, imageData } = await decodeLivePhoto(
file.metadata.title,
fileBlob,
);
return await getRenderableImage(imageFileName, new Blob([imageData]));
}
}
export async function getOriginalImageBitmap(file: EnteFile, queue?: PQueue) {
const fileBlob = await getOriginalConvertedFile(file, queue);
log.info("[MLService] Got file: ", file.id.toString());
return getImageBlobBitmap(fileBlob);
}
export async function getThumbnailImageBitmap(file: EnteFile) {
const thumb = await DownloadManager.getThumbnail(file);
log.info("[MLService] Got thumbnail: ", file.id.toString());
return getImageBlobBitmap(new Blob([thumb]));
}
export async function getLocalFileImageBitmap(
enteFile: EnteFile,
localFile: globalThis.File,
) {
let fileBlob = localFile as Blob;
fileBlob = await getRenderableImage(enteFile.metadata.title, fileBlob);
return getImageBlobBitmap(fileBlob);
}

View file

@ -323,14 +323,10 @@ function transformBox(box: Box, transform: Matrix) {
const topLeft = transformPoint(box.topLeft, transform);
const bottomRight = transformPoint(box.bottomRight, transform);
return newBoxFromPoints(topLeft.x, topLeft.y, bottomRight.x, bottomRight.y);
}
function newBoxFromPoints(
left: number,
top: number,
right: number,
bottom: number,
) {
return boxFromBoundingBox({ left, top, right, bottom });
return boxFromBoundingBox({
left: topLeft.x,
top: topLeft.y,
right: bottomRight.x,
bottom: bottomRight.y,
});
}

View file

@ -1,11 +1,5 @@
import { haveWindow } from "@/next/env";
import log from "@/next/log";
import {
DEFAULT_ML_SEARCH_CONFIG,
DEFAULT_ML_SYNC_CONFIG,
DEFAULT_ML_SYNC_JOB_CONFIG,
MAX_ML_SYNC_ERROR_COUNT,
} from "constants/mlConfig";
import {
DBSchema,
IDBPDatabase,
@ -15,6 +9,12 @@ import {
openDB,
} from "idb";
import isElectron from "is-electron";
import {
DEFAULT_ML_SEARCH_CONFIG,
DEFAULT_ML_SYNC_CONFIG,
DEFAULT_ML_SYNC_JOB_CONFIG,
MAX_ML_SYNC_ERROR_COUNT,
} from "services/machineLearning/machineLearningService";
import { Face, MLLibraryData, MlFileData, Person } from "services/ml/types";
export interface IndexStatus {

View file

@ -76,3 +76,17 @@ export class Box implements IRect {
return new Box({ x, y, width, height });
}
}
export function enlargeBox(box: Box, factor: number = 1.5) {
const center = new Point(box.x + box.width / 2, box.y + box.height / 2);
const size = new Point(box.width, box.height);
const newHalfSize = new Point((factor * size.x) / 2, (factor * size.y) / 2);
return boxFromBoundingBox({
left: center.x - newHalfSize.x,
top: center.y - newHalfSize.y,
right: center.x + newHalfSize.x,
bottom: center.y + newHalfSize.y,
});
}
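// For example: enlargeBox(new Box({ x: 10, y: 10, width: 100, height: 100 }), 1.5)
// keeps the center at (60, 60) and scales both sides by 1.5, returning the
// equivalent of { x: -15, y: -15, width: 150, height: 150 }.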

View file

@ -329,8 +329,3 @@ export interface MachineLearningWorker {
close(): void;
}
export interface ClipEmbedding {
embedding: Float32Array;
model: "ggml-clip" | "onnx-clip";
}

View file

@ -2,6 +2,8 @@ import { FILE_TYPE } from "@/media/file-type";
import log from "@/next/log";
import * as chrono from "chrono-node";
import { t } from "i18next";
import { getMLSyncConfig } from "services/machineLearning/machineLearningService";
import mlIDbStorage from "services/ml/db";
import { Person } from "services/ml/types";
import { Collection } from "types/collection";
import { EntityType, LocationTag, LocationTagData } from "types/entity";
@ -16,12 +18,9 @@ import {
} from "types/search";
import ComlinkSearchWorker from "utils/comlink/ComlinkSearchWorker";
import { getUniqueFiles } from "utils/file";
import { getAllPeople } from "utils/machineLearning";
import { getMLSyncConfig } from "utils/machineLearning/config";
import { getFormattedDate } from "utils/search";
import mlIDbStorage from "utils/storage/mlIDbStorage";
import { clipService, computeClipMatchScore } from "./clip-service";
import { getLocalEmbeddings } from "./embeddingService";
import { localCLIPEmbeddings } from "./embeddingService";
import { getLatestEntities } from "./entityService";
import locationSearchService, { City } from "./locationSearchService";
@ -376,7 +375,7 @@ const searchClip = async (
await clipService.getTextEmbeddingIfAvailable(searchPhrase);
if (!textEmbedding) return undefined;
const imageEmbeddings = await getLocalEmbeddings();
const imageEmbeddings = await localCLIPEmbeddings();
const clipSearchResult = new Map<number, number>(
(
await Promise.all(
@ -430,3 +429,14 @@ function convertSuggestionToSearchQuery(option: Suggestion): Search {
return { clip: option.value as ClipSearchScores };
}
}
async function getAllPeople(limit: number = undefined) {
let people: Array<Person> = await mlIDbStorage.getAllPeople();
// await mlPeopleStore.iterate<Person, void>((person) => {
// people.push(person);
// });
people = people ?? [];
return people
.sort((p1, p2) => p2.files.length - p1.files.length)
.slice(0, limit);
}

View file

@ -1,11 +0,0 @@
export type JobState = "Scheduled" | "Running" | "NotScheduled";
export interface JobConfig {
intervalSec: number;
maxItervalSec: number;
backoffMultiplier: number;
}
export interface JobResult {
shouldBackoff: boolean;
}

View file

@ -1,9 +1,9 @@
/**
 * The embedding models that we support.
 * The embeddings that we (the current client) know how to handle.
*
* This is an exhaustive set of values we pass when PUT-ting encrypted
* embeddings on the server. However, we should be prepared to receive an
* {@link EncryptedEmbedding} with a model value distinct from one of these.
* {@link EncryptedEmbedding} with a model value different from these.
*/
export type EmbeddingModel = "onnx-clip" | "file-ml-clip-face";
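A minimal sketch of the defensive handling this implies (isKnownModel is a hypothetical helper; localCLIPEmbeddings in embeddingService does the equivalent by filtering on "onnx-clip"):

const isKnownModel = (model: string): model is EmbeddingModel =>
    model === "onnx-clip" || model === "file-ml-clip-face";

// Skip, rather than error on, embeddings from models we don't recognize:
// embeddings.filter(({ model }) => isKnownModel(model));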

View file

@ -1,9 +1,9 @@
import { FILE_TYPE } from "@/media/file-type";
import { City } from "services/locationSearchService";
import { IndexStatus } from "services/ml/db";
import { Person } from "services/ml/types";
import { LocationTagData } from "types/entity";
import { EnteFile } from "types/file";
import { IndexStatus } from "utils/storage/mlIDbStorage";
export enum SuggestionType {
DATE = "DATE",

View file

@ -1,82 +0,0 @@
import log from "@/next/log";
import { JobConfig, JobResult, JobState } from "types/common/job";
export class SimpleJob<R extends JobResult> {
private config: JobConfig;
private runCallback: () => Promise<R>;
private state: JobState;
private stopped: boolean;
private intervalSec: number;
private nextTimeoutId: ReturnType<typeof setTimeout>;
constructor(config: JobConfig, runCallback: () => Promise<R>) {
this.config = config;
this.runCallback = runCallback;
this.state = "NotScheduled";
this.stopped = true;
this.intervalSec = this.config.intervalSec;
}
public resetInterval() {
this.intervalSec = this.config.intervalSec;
}
public start() {
this.stopped = false;
this.resetInterval();
if (this.state !== "Running") {
this.scheduleNext();
} else {
log.info("Job already running, not scheduling");
}
}
private scheduleNext() {
if (this.state === "Scheduled" || this.nextTimeoutId) {
this.clearScheduled();
}
this.nextTimeoutId = setTimeout(
() => this.run(),
this.intervalSec * 1000,
);
this.state = "Scheduled";
log.info("Scheduled next job after: ", this.intervalSec);
}
async run() {
this.nextTimeoutId = undefined;
this.state = "Running";
try {
const jobResult = await this.runCallback();
if (jobResult && jobResult.shouldBackoff) {
this.intervalSec = Math.min(
this.config.maxItervalSec,
this.intervalSec * this.config.backoffMultiplier,
);
} else {
this.resetInterval();
}
log.info("Job completed");
} catch (e) {
console.error("Error while running Job: ", e);
} finally {
this.state = "NotScheduled";
!this.stopped && this.scheduleNext();
}
}
// currently the client is responsible for terminating a running job
public stop() {
this.stopped = true;
this.clearScheduled();
}
private clearScheduled() {
clearTimeout(this.nextTimeoutId);
this.nextTimeoutId = undefined;
this.state = "NotScheduled";
log.info("Cleared next job");
}
}
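
For context, a hedged usage sketch of the class being deleted; the config values and the runOneSyncBatch helper are illustrative, not from the codebase:

declare function runOneSyncBatch(): Promise<boolean>;

const syncJob = new SimpleJob<JobResult>(
    { intervalSec: 5, maxItervalSec: 960, backoffMultiplier: 2 },
    async () => {
        const didWork = await runOneSyncBatch();
        // Back off when there was nothing to do; reset to the base interval otherwise.
        return { shouldBackoff: !didWork };
    },
);

syncJob.start(); // schedules the first run after intervalSec
// ... later, e.g. on logout; per the comment above, the client must stop it:
syncJob.stop();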

View file

@ -1,36 +0,0 @@
import { Embedding } from "types/embedding";
import { FileML } from "./machineLearning/mldataMappers";
export const getLatestVersionEmbeddings = (embeddings: Embedding[]) => {
const latestVersionEntities = new Map<number, Embedding>();
embeddings.forEach((embedding) => {
if (!embedding?.fileID) {
return;
}
const existingEmbeddings = latestVersionEntities.get(embedding.fileID);
if (
!existingEmbeddings ||
existingEmbeddings.updatedAt < embedding.updatedAt
) {
latestVersionEntities.set(embedding.fileID, embedding);
}
});
return Array.from(latestVersionEntities.values());
};
export const getLatestVersionFileEmbeddings = (embeddings: FileML[]) => {
const latestVersionEntities = new Map<number, FileML>();
embeddings.forEach((embedding) => {
if (!embedding?.fileID) {
return;
}
const existingEmbeddings = latestVersionEntities.get(embedding.fileID);
if (
!existingEmbeddings ||
existingEmbeddings.updatedAt < embedding.updatedAt
) {
latestVersionEntities.set(embedding.fileID, embedding);
}
});
return Array.from(latestVersionEntities.values());
};
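
The two helpers are the same keep-newest-per-fileID reduction modulo the element type; a generic version (an editorial sketch, not something the codebase had) makes the shared shape explicit:

function latestByFileID<T extends { fileID: number; updatedAt: number }>(
    embeddings: T[],
): T[] {
    const latest = new Map<number, T>();
    for (const e of embeddings) {
        if (!e?.fileID) continue;
        const existing = latest.get(e.fileID);
        // Keep only the embedding with the largest updatedAt per fileID.
        if (!existing || existing.updatedAt < e.updatedAt) {
            latest.set(e.fileID, e);
        }
    }
    return [...latest.values()];
}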

View file

@ -1,9 +1,8 @@
// These utils only work in environments where OffscreenCanvas is available.
import { Matrix, inverse } from "ml-matrix";
import { Box, Dimensions } from "services/ml/geom";
import { Box, Dimensions, enlargeBox } from "services/ml/geom";
import { FaceAlignment } from "services/ml/types";
import { enlargeBox } from "utils/machineLearning";
export function normalizePixelBetween0And1(pixelValue: number) {
return pixelValue / 255.0;

View file

@ -1,48 +0,0 @@
import {
DEFAULT_ML_SEARCH_CONFIG,
DEFAULT_ML_SYNC_CONFIG,
DEFAULT_ML_SYNC_JOB_CONFIG,
} from "constants/mlConfig";
import { MLSearchConfig, MLSyncConfig } from "services/ml/types";
import { JobConfig } from "types/common/job";
import mlIDbStorage, {
ML_SEARCH_CONFIG_NAME,
ML_SYNC_CONFIG_NAME,
ML_SYNC_JOB_CONFIG_NAME,
} from "utils/storage/mlIDbStorage";
import { isInternalUserForML } from "utils/user";
export async function getMLSyncJobConfig() {
return mlIDbStorage.getConfig(
ML_SYNC_JOB_CONFIG_NAME,
DEFAULT_ML_SYNC_JOB_CONFIG,
);
}
export async function getMLSyncConfig() {
return mlIDbStorage.getConfig(ML_SYNC_CONFIG_NAME, DEFAULT_ML_SYNC_CONFIG);
}
export async function getMLSearchConfig() {
if (isInternalUserForML()) {
return mlIDbStorage.getConfig(
ML_SEARCH_CONFIG_NAME,
DEFAULT_ML_SEARCH_CONFIG,
);
}
// Force disabled for everyone else while we finalize it to avoid redundant
// reindexing for users.
return DEFAULT_ML_SEARCH_CONFIG;
}
export async function updateMLSyncJobConfig(newConfig: JobConfig) {
return mlIDbStorage.putConfig(ML_SYNC_JOB_CONFIG_NAME, newConfig);
}
export async function updateMLSyncConfig(newConfig: MLSyncConfig) {
return mlIDbStorage.putConfig(ML_SYNC_CONFIG_NAME, newConfig);
}
export async function updateMLSearchConfig(newConfig: MLSearchConfig) {
return mlIDbStorage.putConfig(ML_SEARCH_CONFIG_NAME, newConfig);
}
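
A small usage sketch for these helpers (the batchSize tweak and the call site are hypothetical):

async function doubleSyncBatchSize() {
    // Read the persisted config, tweak it, and write it back to IndexedDB.
    const config = await getMLSyncConfig();
    await updateMLSyncConfig({ ...config, batchSize: config.batchSize * 2 });
}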

View file

@ -1,87 +0,0 @@
import { Matrix } from "ml-matrix";
import { Point } from "services/ml/geom";
import { FaceAlignment, FaceDetection } from "services/ml/types";
import { getSimilarityTransformation } from "similarity-transformation";
const ARCFACE_LANDMARKS = [
[38.2946, 51.6963],
[73.5318, 51.5014],
[56.0252, 71.7366],
[56.1396, 92.2848],
] as Array<[number, number]>;
const ARCFACE_LANDMARKS_FACE_SIZE = 112;
const ARC_FACE_5_LANDMARKS = [
[38.2946, 51.6963],
[73.5318, 51.5014],
[56.0252, 71.7366],
[41.5493, 92.3655],
[70.7299, 92.2041],
] as Array<[number, number]>;
export function getArcfaceAlignment(
faceDetection: FaceDetection,
): FaceAlignment {
const landmarkCount = faceDetection.landmarks.length;
return getFaceAlignmentUsingSimilarityTransform(
faceDetection,
normalizeLandmarks(
landmarkCount === 5 ? ARC_FACE_5_LANDMARKS : ARCFACE_LANDMARKS,
ARCFACE_LANDMARKS_FACE_SIZE,
),
);
}
function getFaceAlignmentUsingSimilarityTransform(
faceDetection: FaceDetection,
alignedLandmarks: Array<[number, number]>,
// alignmentMethod: Versioned<FaceAlignmentMethod>
): FaceAlignment {
const landmarksMat = new Matrix(
faceDetection.landmarks
.map((p) => [p.x, p.y])
.slice(0, alignedLandmarks.length),
).transpose();
const alignedLandmarksMat = new Matrix(alignedLandmarks).transpose();
const simTransform = getSimilarityTransformation(
landmarksMat,
alignedLandmarksMat,
);
const RS = Matrix.mul(simTransform.rotation, simTransform.scale);
const TR = simTransform.translation;
const affineMatrix = [
[RS.get(0, 0), RS.get(0, 1), TR.get(0, 0)],
[RS.get(1, 0), RS.get(1, 1), TR.get(1, 0)],
[0, 0, 1],
];
const size = 1 / simTransform.scale;
const meanTranslation = simTransform.toMean.sub(0.5).mul(size);
const centerMat = simTransform.fromMean.sub(meanTranslation);
const center = new Point(centerMat.get(0, 0), centerMat.get(1, 0));
const rotation = -Math.atan2(
simTransform.rotation.get(0, 1),
simTransform.rotation.get(0, 0),
);
// log.info({ affineMatrix, meanTranslation, centerMat, center, toMean: simTransform.toMean, fromMean: simTransform.fromMean, size });
return {
affineMatrix,
center,
size,
rotation,
};
}
function normalizeLandmarks(
landmarks: Array<[number, number]>,
faceSize: number,
): Array<[number, number]> {
return landmarks.map((landmark) =>
landmark.map((p) => p / faceSize),
) as Array<[number, number]>;
}
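
To make the normalization concrete (plain arithmetic, not code from the commit):

// normalizeLandmarks divides the canonical template by the 112 px face size,
// mapping each landmark into [0, 1] coordinates, e.g. for the first one:
const normalizedLeftEye = [38.2946 / 112, 51.6963 / 112]; // ≈ [0.3419, 0.4616]
// getFaceAlignmentUsingSimilarityTransform then solves for the similarity
// transform that maps the detected landmarks onto this normalized template.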

View file

@ -1,28 +0,0 @@
import { Box } from "services/ml/geom";
import { FaceAlignment, FaceCrop, FaceCropConfig } from "services/ml/types";
import { cropWithRotation } from "utils/image";
import { enlargeBox } from ".";
export function getFaceCrop(
imageBitmap: ImageBitmap,
alignment: FaceAlignment,
config: FaceCropConfig,
): FaceCrop {
const alignmentBox = new Box({
x: alignment.center.x - alignment.size / 2,
y: alignment.center.y - alignment.size / 2,
width: alignment.size,
height: alignment.size,
}).round();
const scaleForPadding = 1 + config.padding * 2;
const paddedBox = enlargeBox(alignmentBox, scaleForPadding).round();
const faceImageBitmap = cropWithRotation(imageBitmap, paddedBox, 0, {
width: config.maxSize,
height: config.maxSize,
});
return {
image: faceImageBitmap,
imageBox: paddedBox,
};
}
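
The padding arithmetic, spelled out with an illustrative value:

// With config.padding = 0.25, scaleForPadding = 1 + 0.25 * 2 = 1.5, so a
// 100×100 alignment box grows to a 150×150 padded box before the crop
// (the cropped output is then capped at config.maxSize per side).
const scaleForPadding = 1 + 0.25 * 2; // 1.5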

View file

@ -1,272 +0,0 @@
import { FILE_TYPE } from "@/media/file-type";
import { decodeLivePhoto } from "@/media/live-photo";
import log from "@/next/log";
import PQueue from "p-queue";
import DownloadManager from "services/download";
import { getLocalFiles } from "services/fileService";
import { Box, Dimensions, Point, boxFromBoundingBox } from "services/ml/geom";
import {
DetectedFace,
Face,
FaceAlignment,
MlFileData,
Person,
Versioned,
} from "services/ml/types";
import { EnteFile } from "types/file";
import { getRenderableImage } from "utils/file";
import { clamp, warpAffineFloat32List } from "utils/image";
import mlIDbStorage from "utils/storage/mlIDbStorage";
export function enlargeBox(box: Box, factor: number = 1.5) {
const center = new Point(box.x + box.width / 2, box.y + box.height / 2);
const size = new Point(box.width, box.height);
const newHalfSize = new Point((factor * size.x) / 2, (factor * size.y) / 2);
return boxFromBoundingBox({
left: center.x - newHalfSize.x,
top: center.y - newHalfSize.y,
right: center.x + newHalfSize.x,
bottom: center.y + newHalfSize.y,
});
}
export function getAllFacesFromMap(allFacesMap: Map<number, Array<Face>>) {
const allFaces = [...allFacesMap.values()].flat();
return allFaces;
}
export async function getLocalFile(fileId: number) {
const localFiles = await getLocalFiles();
return localFiles.find((f) => f.id === fileId);
}
export async function extractFaceImagesToFloat32(
faceAlignments: Array<FaceAlignment>,
faceSize: number,
image: ImageBitmap,
): Promise<Float32Array> {
const faceData = new Float32Array(
faceAlignments.length * faceSize * faceSize * 3,
);
for (let i = 0; i < faceAlignments.length; i++) {
const alignedFace = faceAlignments[i];
const faceDataOffset = i * faceSize * faceSize * 3;
warpAffineFloat32List(
image,
alignedFace,
faceSize,
faceData,
faceDataOffset,
);
}
return faceData;
}
export function getFaceId(detectedFace: DetectedFace, imageDims: Dimensions) {
const xMin = clamp(
detectedFace.detection.box.x / imageDims.width,
0.0,
0.999999,
)
.toFixed(5)
.substring(2);
const yMin = clamp(
detectedFace.detection.box.y / imageDims.height,
0.0,
0.999999,
)
.toFixed(5)
.substring(2);
const xMax = clamp(
(detectedFace.detection.box.x + detectedFace.detection.box.width) /
imageDims.width,
0.0,
0.999999,
)
.toFixed(5)
.substring(2);
const yMax = clamp(
(detectedFace.detection.box.y + detectedFace.detection.box.height) /
imageDims.height,
0.0,
0.999999,
)
.toFixed(5)
.substring(2);
const rawFaceID = `${xMin}_${yMin}_${xMax}_${yMax}`;
const faceID = `${detectedFace.fileId}_${rawFaceID}`;
return faceID;
}
export async function getImageBlobBitmap(blob: Blob): Promise<ImageBitmap> {
return await createImageBitmap(blob);
}
async function getOriginalFile(file: EnteFile, queue?: PQueue) {
let fileStream;
if (queue) {
fileStream = await queue.add(() => DownloadManager.getFile(file));
} else {
fileStream = await DownloadManager.getFile(file);
}
return new Response(fileStream).blob();
}
async function getOriginalConvertedFile(file: EnteFile, queue?: PQueue) {
const fileBlob = await getOriginalFile(file, queue);
if (file.metadata.fileType === FILE_TYPE.IMAGE) {
return await getRenderableImage(file.metadata.title, fileBlob);
} else {
const { imageFileName, imageData } = await decodeLivePhoto(
file.metadata.title,
fileBlob,
);
return await getRenderableImage(imageFileName, new Blob([imageData]));
}
}
export async function getOriginalImageBitmap(file: EnteFile, queue?: PQueue) {
const fileBlob = await getOriginalConvertedFile(file, queue);
log.info("[MLService] Got file: ", file.id.toString());
return getImageBlobBitmap(fileBlob);
}
export async function getThumbnailImageBitmap(file: EnteFile) {
const thumb = await DownloadManager.getThumbnail(file);
log.info("[MLService] Got thumbnail: ", file.id.toString());
return getImageBlobBitmap(new Blob([thumb]));
}
export async function getLocalFileImageBitmap(
enteFile: EnteFile,
localFile: globalThis.File,
) {
let fileBlob = localFile as Blob;
fileBlob = await getRenderableImage(enteFile.metadata.title, fileBlob);
return getImageBlobBitmap(fileBlob);
}
export async function getPeopleList(file: EnteFile): Promise<Array<Person>> {
let startTime = Date.now();
const mlFileData: MlFileData = await mlIDbStorage.getFile(file.id);
log.info(
"getPeopleList:mlFilesStore:getItem",
Date.now() - startTime,
"ms",
);
if (!mlFileData?.faces || mlFileData.faces.length < 1) {
return [];
}
const peopleIds = mlFileData.faces
.filter((f) => f.personId !== null && f.personId !== undefined)
.map((f) => f.personId);
if (!peopleIds || peopleIds.length < 1) {
return [];
}
// log.info("peopleIds: ", peopleIds);
startTime = Date.now();
const peoplePromises = peopleIds.map(
(p) => mlIDbStorage.getPerson(p) as Promise<Person>,
);
const peopleList = await Promise.all(peoplePromises);
log.info(
"getPeopleList:mlPeopleStore:getItems",
Date.now() - startTime,
"ms",
);
// log.info("peopleList: ", peopleList);
return peopleList;
}
export async function getUnidentifiedFaces(
file: EnteFile,
): Promise<Array<Face>> {
const mlFileData: MlFileData = await mlIDbStorage.getFile(file.id);
return mlFileData?.faces?.filter(
(f) => f.personId === null || f.personId === undefined,
);
}
export async function getAllPeople(limit: number = undefined) {
let people: Array<Person> = await mlIDbStorage.getAllPeople();
// await mlPeopleStore.iterate<Person, void>((person) => {
// people.push(person);
// });
people = people ?? [];
return people
.sort((p1, p2) => p2.files.length - p1.files.length)
.slice(0, limit);
}
export function findFirstIfSorted<T>(
elements: Array<T>,
comparator: (a: T, b: T) => number,
) {
if (!elements || elements.length < 1) {
return;
}
let first = elements[0];
for (let i = 1; i < elements.length; i++) {
const comp = comparator(elements[i], first);
if (comp < 0) {
first = elements[i];
}
}
return first;
}
export function isDifferentOrOld(
method: Versioned<string>,
thanMethod: Versioned<string>,
) {
return (
!method ||
method.value !== thanMethod.value ||
method.version < thanMethod.version
);
}
function primitiveArrayEquals(a, b) {
return (
Array.isArray(a) &&
Array.isArray(b) &&
a.length === b.length &&
a.every((val, index) => val === b[index])
);
}
export function areFaceIdsSame(ofFaces: Array<Face>, toFaces: Array<Face>) {
if (
(ofFaces === null || ofFaces === undefined) &&
(toFaces === null || toFaces === undefined)
) {
return true;
}
return primitiveArrayEquals(
ofFaces?.map((f) => f.id),
toFaces?.map((f) => f.id),
);
}
export function logQueueStats(queue: PQueue, name: string) {
queue.on("active", () =>
log.info(
`queuestats: ${name}: Active, Size: ${queue.size} Pending: ${queue.pending}`,
),
);
queue.on("idle", () => log.info(`queuestats: ${name}: Idle`));
queue.on("error", (error) =>
console.error(`queuestats: ${name}: Error, `, error),
);
}
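
Two of the helpers above reward a worked example; the numbers and versions below are hypothetical:

// getFaceId packs the normalized detection box into a string: each edge is
// clamped to [0, 0.999999], written to five decimals, and the leading "0."
// is dropped by substring(2). For fileId 123 and a box (x=100, y=50,
// w=200, h=200) in a 1000×1000 image:
//   xMin -> "10000", yMin -> "05000", xMax -> "30000", yMax -> "25000"
// so the faceID is "123_10000_05000_30000_25000".

// isDifferentOrOld is what lets a pipeline stage re-run when its method
// changes or its version bumps:
const stale = isDifferentOrOld(
    { value: "YoloFace", version: 1 },
    { value: "YoloFace", version: 2 },
); // true: same method, but an older version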

View file

@ -1,177 +0,0 @@
import {
ClipEmbedding,
Face,
FaceDetection,
Landmark,
MlFileData,
} from "services/ml/types";
export interface FileML extends ServerFileMl {
updatedAt: number;
}
class ServerFileMl {
public fileID: number;
public height?: number;
public width?: number;
public faceEmbedding: ServerFaceEmbeddings;
public clipEmbedding?: ClipEmbedding;
public constructor(
fileID: number,
faceEmbedding: ServerFaceEmbeddings,
clipEmbedding?: ClipEmbedding,
height?: number,
width?: number,
) {
this.fileID = fileID;
this.height = height;
this.width = width;
this.faceEmbedding = faceEmbedding;
this.clipEmbedding = clipEmbedding;
}
}
class ServerFaceEmbeddings {
public faces: ServerFace[];
public version: number;
public client?: string;
public error?: boolean;
public constructor(
faces: ServerFace[],
version: number,
client?: string,
error?: boolean,
) {
this.faces = faces;
this.version = version;
this.client = client;
this.error = error;
}
}
class ServerFace {
public fileID: number;
public faceID: string;
public embeddings: number[];
public detection: ServerDetection;
public score: number;
public blur: number;
public fileInfo?: ServerFileInfo;
public constructor(
fileID: number,
faceID: string,
embeddings: number[],
detection: ServerDetection,
score: number,
blur: number,
fileInfo?: ServerFileInfo,
) {
this.fileID = fileID;
this.faceID = faceID;
this.embeddings = embeddings;
this.detection = detection;
this.score = score;
this.blur = blur;
this.fileInfo = fileInfo;
}
}
class ServerFileInfo {
public imageWidth?: number;
public imageHeight?: number;
public constructor(imageWidth?: number, imageHeight?: number) {
this.imageWidth = imageWidth;
this.imageHeight = imageHeight;
}
}
class ServerDetection {
public box: ServerFaceBox;
public landmarks: Landmark[];
public constructor(box: ServerFaceBox, landmarks: Landmark[]) {
this.box = box;
this.landmarks = landmarks;
}
}
class ServerFaceBox {
public xMin: number;
public yMin: number;
public width: number;
public height: number;
public constructor(
xMin: number,
yMin: number,
width: number,
height: number,
) {
this.xMin = xMin;
this.yMin = yMin;
this.width = width;
this.height = height;
}
}
export function LocalFileMlDataToServerFileMl(
localFileMlData: MlFileData,
): ServerFileMl {
if (
localFileMlData.errorCount > 0 &&
localFileMlData.lastErrorMessage !== undefined
) {
return null;
}
const imageDimensions = localFileMlData.imageDimensions;
const fileInfo = new ServerFileInfo(
imageDimensions.width,
imageDimensions.height,
);
const faces: ServerFace[] = [];
for (let i = 0; i < localFileMlData.faces.length; i++) {
const face: Face = localFileMlData.faces[i];
const faceID = face.id;
const embedding = face.embedding;
const score = face.detection.probability;
const blur = face.blurValue;
const detection: FaceDetection = face.detection;
const box = detection.box;
const landmarks = detection.landmarks;
const newBox = new ServerFaceBox(box.x, box.y, box.width, box.height);
const newLandmarks: Landmark[] = [];
for (let j = 0; j < landmarks.length; j++) {
newLandmarks.push({
x: landmarks[j].x,
y: landmarks[j].y,
} as Landmark);
}
const newFaceObject = new ServerFace(
localFileMlData.fileId,
faceID,
Array.from(embedding),
new ServerDetection(newBox, newLandmarks),
score,
blur,
fileInfo,
);
faces.push(newFaceObject);
}
const faceEmbeddings = new ServerFaceEmbeddings(
faces,
1,
localFileMlData.lastErrorMessage,
);
return new ServerFileMl(
localFileMlData.fileId,
faceEmbeddings,
null,
imageDimensions.height,
imageDimensions.width,
);
}
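
A hedged sketch of a call site for this mapper (hypothetical; the upload step is paraphrased):

declare const mlFileData: MlFileData;

// Files whose last indexing attempt errored map to null and should be
// skipped; otherwise the result is presumably encrypted and PUT under the
// "file-ml-clip-face" embedding model, with FileML adding updatedAt.
const serverMl = LocalFileMlDataToServerFileMl(mlFileData);
if (serverMl) {
    // ...encrypt and upload serverMl here.
}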