Remove object detection related code

This commit is contained in:
Manav Rathi 2024-04-11 11:06:04 +05:30
parent b6e1c4d3da
commit 1ad5cb83f9
No known key found for this signature in database
6 changed files with 9 additions and 286 deletions

View file

@ -1,51 +0,0 @@
import Box from "@mui/material/Box";
import { Chip } from "components/Chip";
import { Legend } from "components/PhotoViewer/styledComponents/Legend";
import { t } from "i18next";
import { useEffect, useState } from "react";
import { EnteFile } from "types/file";
import mlIDbStorage from "utils/storage/mlIDbStorage";
export function ObjectLabelList(props: {
file: EnteFile;
updateMLDataIndex: number;
}) {
const [objects, setObjects] = useState<Array<string>>([]);
useEffect(() => {
let didCancel = false;
const main = async () => {
const objects = await mlIDbStorage.getAllObjectsMap();
const uniqueObjectNames = [
...new Set(
(objects.get(props.file.id) ?? []).map(
(object) => object.detection.class,
),
),
];
!didCancel && setObjects(uniqueObjectNames);
};
main();
return () => {
didCancel = true;
};
}, [props.file, props.updateMLDataIndex]);
if (objects.length === 0) return <></>;
return (
<div>
<Legend sx={{ pb: 1, display: "block" }}>{t("OBJECTS")}</Legend>
<Box
display={"flex"}
gap={1}
flexWrap="wrap"
justifyContent={"flex-start"}
alignItems={"flex-start"}
>
{objects.map((object) => (
<Chip key={object}>{object}</Chip>
))}
</Box>
</div>
);
}

View file

@ -10,7 +10,6 @@ import TextSnippetOutlined from "@mui/icons-material/TextSnippetOutlined";
import { Box, DialogProps, Link, Stack, styled } from "@mui/material";
import { Chip } from "components/Chip";
import { EnteDrawer } from "components/EnteDrawer";
import { ObjectLabelList } from "components/MachineLearning/ObjectList";
import {
PhotoPeopleList,
UnidentifiedFaces,
@ -344,10 +343,6 @@ export function FileInfo({
file={file}
updateMLDataIndex={updateMLDataIndex}
/>
<ObjectLabelList
file={file}
updateMLDataIndex={updateMLDataIndex}
/>
</>
)}
</Stack>

View file

@ -1,146 +0,0 @@
import log from "@/next/log";
import {
DetectedObject,
MLSyncContext,
MLSyncFileContext,
Thing,
} from "types/machineLearning";
import {
getAllObjectsFromMap,
getObjectId,
isDifferentOrOld,
} from "utils/machineLearning";
import mlIDbStorage from "utils/storage/mlIDbStorage";
import ReaderService from "./readerService";
class ObjectService {
async syncFileObjectDetections(
syncContext: MLSyncContext,
fileContext: MLSyncFileContext,
) {
const startTime = Date.now();
const { oldMlFile, newMlFile } = fileContext;
if (
!isDifferentOrOld(
oldMlFile?.objectDetectionMethod,
syncContext.objectDetectionService.method,
) &&
!isDifferentOrOld(
oldMlFile?.sceneDetectionMethod,
syncContext.sceneDetectionService.method,
) &&
oldMlFile?.imageSource === syncContext.config.imageSource
) {
newMlFile.objects = oldMlFile?.objects;
newMlFile.imageSource = oldMlFile.imageSource;
newMlFile.imageDimensions = oldMlFile.imageDimensions;
newMlFile.objectDetectionMethod = oldMlFile.objectDetectionMethod;
newMlFile.sceneDetectionMethod = oldMlFile.sceneDetectionMethod;
return;
}
newMlFile.objectDetectionMethod =
syncContext.objectDetectionService.method;
newMlFile.sceneDetectionMethod =
syncContext.sceneDetectionService.method;
fileContext.newDetection = true;
const imageBitmap = await ReaderService.getImageBitmap(
syncContext,
fileContext,
);
const objectDetections =
await syncContext.objectDetectionService.detectObjects(
imageBitmap,
syncContext.config.objectDetection.maxNumBoxes,
syncContext.config.objectDetection.minScore,
);
objectDetections.push(
...(await syncContext.sceneDetectionService.detectScenes(
imageBitmap,
syncContext.config.sceneDetection.minScore,
)),
);
// log.info('3 TF Memory stats: ',JSON.stringify(tf.memory()));
// TODO: reenable faces filtering based on width
const detectedObjects = objectDetections?.map((detection) => {
return {
fileID: fileContext.enteFile.id,
detection,
} as DetectedObject;
});
newMlFile.objects = detectedObjects?.map((detectedObject) => ({
...detectedObject,
id: getObjectId(detectedObject, newMlFile.imageDimensions),
className: detectedObject.detection.class,
}));
// ?.filter((f) =>
// f.box.width > syncContext.config.faceDetection.minFaceSize
// );
log.info(
`object detection time taken ${fileContext.enteFile.id}`,
Date.now() - startTime,
"ms",
);
log.info("[MLService] Detected Objects: ", newMlFile.objects?.length);
}
async getAllSyncedObjectsMap(syncContext: MLSyncContext) {
if (syncContext.allSyncedObjectsMap) {
return syncContext.allSyncedObjectsMap;
}
syncContext.allSyncedObjectsMap = await mlIDbStorage.getAllObjectsMap();
return syncContext.allSyncedObjectsMap;
}
public async clusterThings(syncContext: MLSyncContext): Promise<Thing[]> {
const allObjectsMap = await this.getAllSyncedObjectsMap(syncContext);
const allObjects = getAllObjectsFromMap(allObjectsMap);
const objectClusters = new Map<string, number[]>();
allObjects.map((object) => {
if (!objectClusters.has(object.detection.class)) {
objectClusters.set(object.detection.class, []);
}
const objectsInCluster = objectClusters.get(object.detection.class);
objectsInCluster.push(object.fileID);
});
return [...objectClusters.entries()].map(([className, files], id) => ({
id,
name: className,
files,
}));
}
async syncThingsIndex(syncContext: MLSyncContext) {
const filesVersion = await mlIDbStorage.getIndexVersion("files");
log.info("things", await mlIDbStorage.getIndexVersion("things"));
if (filesVersion <= (await mlIDbStorage.getIndexVersion("things"))) {
log.info(
"[MLService] Skipping people index as already synced to latest version",
);
return;
}
const things = await this.clusterThings(syncContext);
if (!things || things.length < 1) {
return;
}
await mlIDbStorage.clearAllThings();
for (const thing of things) {
await mlIDbStorage.putThing(thing);
}
await mlIDbStorage.setIndexVersion("things", filesVersion);
}
async getAllThings() {
return await mlIDbStorage.getAllThings();
}
}
export default new ObjectService();

View file

@ -161,22 +161,6 @@ export interface ObjectDetection {
score: number;
}
export interface DetectedObject {
fileID: number;
detection: ObjectDetection;
}
export interface RealWorldObject extends DetectedObject {
id: string;
className: string;
}
export interface Thing {
id: number;
name: string;
files: Array<number>;
}
export interface WordGroup {
word: string;
files: Array<number>;
@ -185,7 +169,6 @@ export interface WordGroup {
export interface MlFileData {
fileId: number;
faces?: Face[];
objects?: RealWorldObject[];
imageSource?: ImageType;
imageDimensions?: Dimensions;
faceDetectionMethod?: Versioned<FaceDetectionMethod>;

View file

@ -15,13 +15,11 @@ import { Dimensions } from "types/image";
import {
AlignedFace,
DetectedFace,
DetectedObject,
Face,
FaceAlignment,
FaceImageBlob,
MlFileData,
Person,
RealWorldObject,
Versioned,
} from "types/machineLearning";
import { getRenderableImage } from "utils/file";
@ -196,12 +194,6 @@ export function getAllFacesFromMap(allFacesMap: Map<number, Array<Face>>) {
return allFaces;
}
export function getAllObjectsFromMap(
allObjectsMap: Map<number, Array<RealWorldObject>>,
) {
return [...allObjectsMap.values()].flat();
}
export async function getLocalFile(fileId: number) {
const localFiles = await getLocalFiles();
return localFiles.find((f) => f.id === fileId);
@ -312,26 +304,6 @@ export function getFaceId(detectedFace: DetectedFace, imageDims: Dimensions) {
return faceID;
}
export function getObjectId(
detectedObject: DetectedObject,
imageDims: Dimensions,
) {
const imgDimPoint = new Point(imageDims.width, imageDims.height);
const objectCenterPoint = new Point(
detectedObject.detection.bbox[2] / 2,
detectedObject.detection.bbox[3] / 2,
);
const gridPt = objectCenterPoint
.mul(new Point(100, 100))
.div(imgDimPoint)
.floor()
.bound(0, 99);
const gridPaddedX = leftFillNum(gridPt.x, 2, 0);
const gridPaddedY = leftFillNum(gridPt.y, 2, 0);
return `${detectedObject.fileID}-${gridPaddedX}-${gridPaddedY}`;
}
export async function getTFImage(blob): Promise<tf.Tensor3D> {
const imageBitmap = await createImageBitmap(blob);
const tfImage = tf.browser.fromPixels(imageBitmap);

View file

@ -15,14 +15,7 @@ import {
openDB,
} from "idb";
import isElectron from "is-electron";
import {
Face,
MLLibraryData,
MlFileData,
Person,
RealWorldObject,
Thing,
} from "types/machineLearning";
import { Face, MLLibraryData, MlFileData, Person } from "types/machineLearning";
import { IndexStatus } from "types/machineLearning/ui";
interface Config {}
@ -42,9 +35,11 @@ interface MLDb extends DBSchema {
key: number;
value: Person;
};
// Unused, we only retain this is the schema so that we can delete it during
// migration.
things: {
key: number;
value: Thing;
value: unknown;
};
versions: {
key: string;
@ -72,7 +67,7 @@ class MLIDbStorage {
}
private openDB(): Promise<IDBPDatabase<MLDb>> {
return openDB<MLDb>(MLDATA_DB_NAME, 3, {
return openDB<MLDb>(MLDATA_DB_NAME, 4, {
terminated: async () => {
log.error("ML Indexed DB terminated");
this._db = undefined;
@ -128,6 +123,10 @@ class MLIDbStorage {
.objectStore("configs")
.add(DEFAULT_ML_SEARCH_CONFIG, ML_SEARCH_CONFIG_NAME);
}
if (oldVersion < 4) {
db.deleteObjectStore("things");
}
log.info(
`Ml DB upgraded to version: ${newVersion} from version: ${oldVersion}`,
);
@ -299,21 +298,6 @@ class MLIDbStorage {
log.info("updateFaces", Date.now() - startTime, "ms");
}
public async getAllObjectsMap() {
const startTime = Date.now();
const db = await this.db;
const allFiles = await db.getAll("files");
const allObjectsMap = new Map<number, Array<RealWorldObject>>();
allFiles.forEach(
(mlFileData) =>
mlFileData.objects &&
allObjectsMap.set(mlFileData.fileId, mlFileData.objects),
);
log.info("allObjectsMap", Date.now() - startTime, "ms");
return allObjectsMap;
}
public async getPerson(id: number) {
const db = await this.db;
return db.get("people", id);
@ -334,20 +318,6 @@ class MLIDbStorage {
return db.clear("people");
}
public async getAllThings() {
const db = await this.db;
return db.getAll("things");
}
public async putThing(thing: Thing) {
const db = await this.db;
return db.put("things", thing);
}
public async clearAllThings() {
const db = await this.db;
return db.clear("things");
}
public async getIndexVersion(index: string) {
const db = await this.db;
return db.get("versions", index);