move face search to smart info repository
This commit is contained in:
parent
d7d11429f4
commit
8a8da5f5c8
9 changed files with 71 additions and 71 deletions
|
@ -13,6 +13,7 @@ import {
|
||||||
newMoveRepositoryMock,
|
newMoveRepositoryMock,
|
||||||
newPersonRepositoryMock,
|
newPersonRepositoryMock,
|
||||||
newSearchRepositoryMock,
|
newSearchRepositoryMock,
|
||||||
|
newSmartInfoRepositoryMock,
|
||||||
newStorageRepositoryMock,
|
newStorageRepositoryMock,
|
||||||
newSystemConfigRepositoryMock,
|
newSystemConfigRepositoryMock,
|
||||||
personStub,
|
personStub,
|
||||||
|
@ -27,6 +28,7 @@ import {
|
||||||
IMoveRepository,
|
IMoveRepository,
|
||||||
IPersonRepository,
|
IPersonRepository,
|
||||||
ISearchRepository,
|
ISearchRepository,
|
||||||
|
ISmartInfoRepository,
|
||||||
IStorageRepository,
|
IStorageRepository,
|
||||||
ISystemConfigRepository,
|
ISystemConfigRepository,
|
||||||
WithoutProperty,
|
WithoutProperty,
|
||||||
|
@ -70,8 +72,8 @@ describe(PersonService.name, () => {
|
||||||
let mediaMock: jest.Mocked<IMediaRepository>;
|
let mediaMock: jest.Mocked<IMediaRepository>;
|
||||||
let moveMock: jest.Mocked<IMoveRepository>;
|
let moveMock: jest.Mocked<IMoveRepository>;
|
||||||
let personMock: jest.Mocked<IPersonRepository>;
|
let personMock: jest.Mocked<IPersonRepository>;
|
||||||
let searchMock: jest.Mocked<ISearchRepository>;
|
|
||||||
let storageMock: jest.Mocked<IStorageRepository>;
|
let storageMock: jest.Mocked<IStorageRepository>;
|
||||||
|
let smartInfoMock: jest.Mocked<ISmartInfoRepository>;
|
||||||
let sut: PersonService;
|
let sut: PersonService;
|
||||||
|
|
||||||
beforeEach(async () => {
|
beforeEach(async () => {
|
||||||
|
@ -83,8 +85,8 @@ describe(PersonService.name, () => {
|
||||||
moveMock = newMoveRepositoryMock();
|
moveMock = newMoveRepositoryMock();
|
||||||
mediaMock = newMediaRepositoryMock();
|
mediaMock = newMediaRepositoryMock();
|
||||||
personMock = newPersonRepositoryMock();
|
personMock = newPersonRepositoryMock();
|
||||||
searchMock = newSearchRepositoryMock();
|
|
||||||
storageMock = newStorageRepositoryMock();
|
storageMock = newStorageRepositoryMock();
|
||||||
|
smartInfoMock = newSmartInfoRepositoryMock();
|
||||||
sut = new PersonService(
|
sut = new PersonService(
|
||||||
accessMock,
|
accessMock,
|
||||||
assetMock,
|
assetMock,
|
||||||
|
@ -92,10 +94,10 @@ describe(PersonService.name, () => {
|
||||||
moveMock,
|
moveMock,
|
||||||
mediaMock,
|
mediaMock,
|
||||||
personMock,
|
personMock,
|
||||||
searchMock,
|
|
||||||
configMock,
|
configMock,
|
||||||
storageMock,
|
storageMock,
|
||||||
jobMock,
|
jobMock,
|
||||||
|
smartInfoMock
|
||||||
);
|
);
|
||||||
|
|
||||||
mediaMock.crop.mockResolvedValue(croppedFace);
|
mediaMock.crop.mockResolvedValue(croppedFace);
|
||||||
|
@ -591,7 +593,7 @@ describe(PersonService.name, () => {
|
||||||
|
|
||||||
it('should match existing people', async () => {
|
it('should match existing people', async () => {
|
||||||
machineLearningMock.detectFaces.mockResolvedValue([detectFaceMock]);
|
machineLearningMock.detectFaces.mockResolvedValue([detectFaceMock]);
|
||||||
personMock.searchByEmbedding.mockResolvedValue([faceStub.face1]);
|
smartInfoMock.searchFaces.mockResolvedValue([faceStub.face1]);
|
||||||
assetMock.getByIds.mockResolvedValue([assetStub.image]);
|
assetMock.getByIds.mockResolvedValue([assetStub.image]);
|
||||||
await sut.handleRecognizeFaces({ id: assetStub.image.id });
|
await sut.handleRecognizeFaces({ id: assetStub.image.id });
|
||||||
|
|
||||||
|
@ -610,7 +612,7 @@ describe(PersonService.name, () => {
|
||||||
|
|
||||||
it('should create a new person', async () => {
|
it('should create a new person', async () => {
|
||||||
machineLearningMock.detectFaces.mockResolvedValue([detectFaceMock]);
|
machineLearningMock.detectFaces.mockResolvedValue([detectFaceMock]);
|
||||||
personMock.searchByEmbedding.mockResolvedValue([]);
|
smartInfoMock.searchFaces.mockResolvedValue([]);
|
||||||
personMock.create.mockResolvedValue(personStub.noName);
|
personMock.create.mockResolvedValue(personStub.noName);
|
||||||
assetMock.getByIds.mockResolvedValue([assetStub.image]);
|
assetMock.getByIds.mockResolvedValue([assetStub.image]);
|
||||||
personMock.createFace.mockResolvedValue(faceStub.primaryFace1);
|
personMock.createFace.mockResolvedValue(faceStub.primaryFace1);
|
||||||
|
|
|
@ -18,7 +18,7 @@ import {
|
||||||
IMediaRepository,
|
IMediaRepository,
|
||||||
IMoveRepository,
|
IMoveRepository,
|
||||||
IPersonRepository,
|
IPersonRepository,
|
||||||
ISearchRepository,
|
ISmartInfoRepository,
|
||||||
IStorageRepository,
|
IStorageRepository,
|
||||||
ISystemConfigRepository,
|
ISystemConfigRepository,
|
||||||
ImmichReadStream,
|
ImmichReadStream,
|
||||||
|
@ -56,10 +56,10 @@ export class PersonService {
|
||||||
@Inject(IMoveRepository) moveRepository: IMoveRepository,
|
@Inject(IMoveRepository) moveRepository: IMoveRepository,
|
||||||
@Inject(IMediaRepository) private mediaRepository: IMediaRepository,
|
@Inject(IMediaRepository) private mediaRepository: IMediaRepository,
|
||||||
@Inject(IPersonRepository) private repository: IPersonRepository,
|
@Inject(IPersonRepository) private repository: IPersonRepository,
|
||||||
@Inject(ISearchRepository) private searchRepository: ISearchRepository,
|
|
||||||
@Inject(ISystemConfigRepository) configRepository: ISystemConfigRepository,
|
@Inject(ISystemConfigRepository) configRepository: ISystemConfigRepository,
|
||||||
@Inject(IStorageRepository) private storageRepository: IStorageRepository,
|
@Inject(IStorageRepository) private storageRepository: IStorageRepository,
|
||||||
@Inject(IJobRepository) private jobRepository: IJobRepository,
|
@Inject(IJobRepository) private jobRepository: IJobRepository,
|
||||||
|
@Inject(ISmartInfoRepository) private smartInfoRepository: ISmartInfoRepository,
|
||||||
) {
|
) {
|
||||||
this.access = AccessCore.create(accessRepository);
|
this.access = AccessCore.create(accessRepository);
|
||||||
this.configCore = SystemConfigCore.create(configRepository);
|
this.configCore = SystemConfigCore.create(configRepository);
|
||||||
|
@ -315,7 +315,7 @@ export class PersonService {
|
||||||
this.logger.verbose(faces.map((face) => ({ ...face, embedding: `vector(${face.embedding.length})` })));
|
this.logger.verbose(faces.map((face) => ({ ...face, embedding: `vector(${face.embedding.length})` })));
|
||||||
|
|
||||||
for (const { embedding, ...rest } of faces) {
|
for (const { embedding, ...rest } of faces) {
|
||||||
const matches = await this.repository.searchByEmbedding({
|
const matches = await this.smartInfoRepository.searchFaces({
|
||||||
ownerId: asset.ownerId,
|
ownerId: asset.ownerId,
|
||||||
embedding,
|
embedding,
|
||||||
numResults: 1,
|
numResults: 1,
|
||||||
|
|
|
@ -25,15 +25,6 @@ export interface PersonStatistics {
|
||||||
assets: number;
|
assets: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
export type Embedding = number[];
|
|
||||||
|
|
||||||
export interface EmbeddingSearch {
|
|
||||||
ownerId: string;
|
|
||||||
embedding: Embedding;
|
|
||||||
numResults: number;
|
|
||||||
maxDistance?: number;
|
|
||||||
}
|
|
||||||
|
|
||||||
export interface IPersonRepository {
|
export interface IPersonRepository {
|
||||||
getAll(): Promise<PersonEntity[]>;
|
getAll(): Promise<PersonEntity[]>;
|
||||||
getAllWithoutThumbnail(): Promise<PersonEntity[]>;
|
getAllWithoutThumbnail(): Promise<PersonEntity[]>;
|
||||||
|
@ -51,7 +42,6 @@ export interface IPersonRepository {
|
||||||
delete(entity: PersonEntity): Promise<PersonEntity | null>;
|
delete(entity: PersonEntity): Promise<PersonEntity | null>;
|
||||||
deleteAll(): Promise<number>;
|
deleteAll(): Promise<number>;
|
||||||
getStatistics(personId: string): Promise<PersonStatistics>;
|
getStatistics(personId: string): Promise<PersonStatistics>;
|
||||||
searchByEmbedding(search: EmbeddingSearch): Promise<AssetFaceEntity[]>;
|
|
||||||
getAllFaces(): Promise<AssetFaceEntity[]>;
|
getAllFaces(): Promise<AssetFaceEntity[]>;
|
||||||
getFacesByIds(ids: AssetFaceId[]): Promise<AssetFaceEntity[]>;
|
getFacesByIds(ids: AssetFaceId[]): Promise<AssetFaceEntity[]>;
|
||||||
getRandomFace(personId: string): Promise<AssetFaceEntity | null>;
|
getRandomFace(personId: string): Promise<AssetFaceEntity | null>;
|
||||||
|
|
|
@ -1,9 +1,18 @@
|
||||||
import { Embedding, EmbeddingSearch } from '@app/domain';
|
import { AssetEntity, AssetFaceEntity, SmartInfoEntity } from '@app/infra/entities';
|
||||||
import { AssetEntity, SmartInfoEntity } from '@app/infra/entities';
|
|
||||||
|
|
||||||
export const ISmartInfoRepository = 'ISmartInfoRepository';
|
export const ISmartInfoRepository = 'ISmartInfoRepository';
|
||||||
|
|
||||||
|
export type Embedding = number[];
|
||||||
|
|
||||||
|
export interface EmbeddingSearch {
|
||||||
|
ownerId: string;
|
||||||
|
embedding: Embedding;
|
||||||
|
numResults: number;
|
||||||
|
maxDistance?: number;
|
||||||
|
}
|
||||||
|
|
||||||
export interface ISmartInfoRepository {
|
export interface ISmartInfoRepository {
|
||||||
searchByEmbedding(search: EmbeddingSearch): Promise<AssetEntity[]>;
|
searchCLIP(search: EmbeddingSearch): Promise<AssetEntity[]>;
|
||||||
|
searchFaces(search: EmbeddingSearch): Promise<AssetFaceEntity[]>;
|
||||||
upsert(smartInfo: Partial<SmartInfoEntity>, embedding?: Embedding): Promise<void>;
|
upsert(smartInfo: Partial<SmartInfoEntity>, embedding?: Embedding): Promise<void>;
|
||||||
}
|
}
|
||||||
|
|
|
@ -67,7 +67,7 @@ export class SearchService {
|
||||||
{ text: query },
|
{ text: query },
|
||||||
machineLearning.clip,
|
machineLearning.clip,
|
||||||
);
|
);
|
||||||
assets = await this.smartInfoRepository.searchByEmbedding({ ownerId: authUser.id, embedding, numResults: 100 });
|
assets = await this.smartInfoRepository.searchCLIP({ ownerId: authUser.id, embedding, numResults: 100 });
|
||||||
break;
|
break;
|
||||||
case SearchStrategy.TEXT:
|
case SearchStrategy.TEXT:
|
||||||
assets = await this.assetRepository.searchMetadata(query, authUser.id, { numResults: 250 });
|
assets = await this.assetRepository.searchMetadata(query, authUser.id, { numResults: 250 });
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
import {
|
import {
|
||||||
AssetFaceId,
|
AssetFaceId,
|
||||||
EmbeddingSearch,
|
|
||||||
IPersonRepository,
|
IPersonRepository,
|
||||||
PersonNameSearchOptions,
|
PersonNameSearchOptions,
|
||||||
PersonSearchOptions,
|
PersonSearchOptions,
|
||||||
|
@ -11,20 +10,14 @@ import { InjectRepository } from '@nestjs/typeorm';
|
||||||
import { In, Repository } from 'typeorm';
|
import { In, Repository } from 'typeorm';
|
||||||
import { AssetEntity, AssetFaceEntity, PersonEntity } from '../entities';
|
import { AssetEntity, AssetFaceEntity, PersonEntity } from '../entities';
|
||||||
import { DummyValue, GenerateSql } from '../infra.util';
|
import { DummyValue, GenerateSql } from '../infra.util';
|
||||||
import { asVector, isValidInteger } from '../infra.utils';
|
import { asVector } from '../infra.utils';
|
||||||
|
|
||||||
export class PersonRepository implements IPersonRepository {
|
export class PersonRepository implements IPersonRepository {
|
||||||
private readonly faceColumns: string[];
|
|
||||||
constructor(
|
constructor(
|
||||||
@InjectRepository(AssetEntity) private assetRepository: Repository<AssetEntity>,
|
@InjectRepository(AssetEntity) private assetRepository: Repository<AssetEntity>,
|
||||||
@InjectRepository(PersonEntity) private personRepository: Repository<PersonEntity>,
|
@InjectRepository(PersonEntity) private personRepository: Repository<PersonEntity>,
|
||||||
@InjectRepository(AssetFaceEntity) private assetFaceRepository: Repository<AssetFaceEntity>,
|
@InjectRepository(AssetFaceEntity) private assetFaceRepository: Repository<AssetFaceEntity>,
|
||||||
) {
|
) {}
|
||||||
this.faceColumns = this.assetFaceRepository.manager.connection
|
|
||||||
.getMetadata(AssetFaceEntity)
|
|
||||||
.ownColumns.map((column) => column.propertyName)
|
|
||||||
.filter((propertyName) => propertyName !== 'embedding');
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Before reassigning faces, delete potential key violations
|
* Before reassigning faces, delete potential key violations
|
||||||
|
@ -248,40 +241,4 @@ export class PersonRepository implements IPersonRepository {
|
||||||
async getRandomFace(personId: string): Promise<AssetFaceEntity | null> {
|
async getRandomFace(personId: string): Promise<AssetFaceEntity | null> {
|
||||||
return this.assetFaceRepository.findOneBy({ personId });
|
return this.assetFaceRepository.findOneBy({ personId });
|
||||||
}
|
}
|
||||||
|
|
||||||
async searchByEmbedding({
|
|
||||||
ownerId,
|
|
||||||
embedding,
|
|
||||||
numResults,
|
|
||||||
maxDistance,
|
|
||||||
}: EmbeddingSearch): Promise<AssetFaceEntity[]> {
|
|
||||||
if (!isValidInteger(numResults, { min: 1 })) {
|
|
||||||
throw new Error(`Invalid value for 'numResults': ${numResults}`);
|
|
||||||
}
|
|
||||||
|
|
||||||
let results: AssetFaceEntity[] = [];
|
|
||||||
await this.assetRepository.manager.transaction(async (manager) => {
|
|
||||||
await manager.query(`SET LOCAL vectors.k = '${numResults}'`);
|
|
||||||
const cte = manager
|
|
||||||
.createQueryBuilder(AssetFaceEntity, 'faces')
|
|
||||||
.select('1 + (faces.embedding <=> :embedding)', 'distance')
|
|
||||||
.innerJoin('faces.asset', 'asset')
|
|
||||||
.where('asset.ownerId = :ownerId')
|
|
||||||
.orderBy(`faces.embedding <=> :embedding`)
|
|
||||||
.setParameters({ ownerId, embedding: asVector(embedding) })
|
|
||||||
.limit(numResults);
|
|
||||||
|
|
||||||
this.faceColumns.forEach((col) => cte.addSelect(`faces.${col} AS "${col}"`));
|
|
||||||
|
|
||||||
results = await manager
|
|
||||||
.createQueryBuilder()
|
|
||||||
.select('res.*')
|
|
||||||
.addCommonTableExpression(cte, 'cte')
|
|
||||||
.from('cte', 'res')
|
|
||||||
.where('res.distance <= :maxDistance', { maxDistance })
|
|
||||||
.getRawMany();
|
|
||||||
});
|
|
||||||
|
|
||||||
return this.assetFaceRepository.create(results);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,24 +3,30 @@ import { Injectable, Logger } from '@nestjs/common';
|
||||||
import { InjectRepository } from '@nestjs/typeorm';
|
import { InjectRepository } from '@nestjs/typeorm';
|
||||||
import AsyncLock from 'async-lock';
|
import AsyncLock from 'async-lock';
|
||||||
import { Repository } from 'typeorm';
|
import { Repository } from 'typeorm';
|
||||||
import { AssetEntity, SmartInfoEntity, SmartSearchEntity } from '../entities';
|
import { AssetEntity, AssetFaceEntity, SmartInfoEntity, SmartSearchEntity } from '@app/infra/entities';
|
||||||
import { asVector, isValidInteger } from '../infra.utils';
|
import { asVector, isValidInteger } from '../infra.utils';
|
||||||
|
|
||||||
@Injectable()
|
@Injectable()
|
||||||
export class SmartInfoRepository implements ISmartInfoRepository {
|
export class SmartInfoRepository implements ISmartInfoRepository {
|
||||||
private logger = new Logger(SmartInfoRepository.name);
|
private logger = new Logger(SmartInfoRepository.name);
|
||||||
private lock: AsyncLock;
|
private lock: AsyncLock;
|
||||||
|
private readonly faceColumns: string[];
|
||||||
private curDimSize: number | undefined;
|
private curDimSize: number | undefined;
|
||||||
|
|
||||||
constructor(
|
constructor(
|
||||||
@InjectRepository(SmartInfoEntity) private repository: Repository<SmartInfoEntity>,
|
@InjectRepository(SmartInfoEntity) private repository: Repository<SmartInfoEntity>,
|
||||||
@InjectRepository(AssetEntity) private assetRepository: Repository<AssetEntity>,
|
@InjectRepository(AssetEntity) private assetRepository: Repository<AssetEntity>,
|
||||||
|
@InjectRepository(AssetFaceEntity) private assetFaceRepository: Repository<AssetFaceEntity>,
|
||||||
@InjectRepository(SmartSearchEntity) private smartSearchRepository: Repository<SmartSearchEntity>,
|
@InjectRepository(SmartSearchEntity) private smartSearchRepository: Repository<SmartSearchEntity>,
|
||||||
) {
|
) {
|
||||||
this.lock = new AsyncLock();
|
this.lock = new AsyncLock();
|
||||||
|
this.faceColumns = this.assetFaceRepository.manager.connection
|
||||||
|
.getMetadata(AssetFaceEntity)
|
||||||
|
.ownColumns.map((column) => column.propertyName)
|
||||||
|
.filter((propertyName) => propertyName !== 'embedding');
|
||||||
}
|
}
|
||||||
|
|
||||||
async searchByEmbedding({ ownerId, embedding, numResults }: EmbeddingSearch): Promise<AssetEntity[]> {
|
async searchCLIP({ ownerId, embedding, numResults }: EmbeddingSearch): Promise<AssetEntity[]> {
|
||||||
if (!isValidInteger(numResults, { min: 1 })) {
|
if (!isValidInteger(numResults, { min: 1 })) {
|
||||||
throw new Error(`Invalid value for 'numResults': ${numResults}`);
|
throw new Error(`Invalid value for 'numResults': ${numResults}`);
|
||||||
}
|
}
|
||||||
|
@ -42,6 +48,42 @@ export class SmartInfoRepository implements ISmartInfoRepository {
|
||||||
return results;
|
return results;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async searchFaces({
|
||||||
|
ownerId,
|
||||||
|
embedding,
|
||||||
|
numResults,
|
||||||
|
maxDistance,
|
||||||
|
}: EmbeddingSearch): Promise<AssetFaceEntity[]> {
|
||||||
|
if (!isValidInteger(numResults, { min: 1 })) {
|
||||||
|
throw new Error(`Invalid value for 'numResults': ${numResults}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
let results: AssetFaceEntity[] = [];
|
||||||
|
await this.assetRepository.manager.transaction(async (manager) => {
|
||||||
|
await manager.query(`SET LOCAL vectors.k = '${numResults}'`);
|
||||||
|
const cte = manager
|
||||||
|
.createQueryBuilder(AssetFaceEntity, 'faces')
|
||||||
|
.select('1 + (faces.embedding <=> :embedding)', 'distance')
|
||||||
|
.innerJoin('faces.asset', 'asset')
|
||||||
|
.where('asset.ownerId = :ownerId')
|
||||||
|
.orderBy(`faces.embedding <=> :embedding`)
|
||||||
|
.setParameters({ ownerId, embedding: asVector(embedding) })
|
||||||
|
.limit(numResults);
|
||||||
|
|
||||||
|
this.faceColumns.forEach((col) => cte.addSelect(`faces.${col} AS "${col}"`));
|
||||||
|
|
||||||
|
results = await manager
|
||||||
|
.createQueryBuilder()
|
||||||
|
.select('res.*')
|
||||||
|
.addCommonTableExpression(cte, 'cte')
|
||||||
|
.from('cte', 'res')
|
||||||
|
.where('res.distance <= :maxDistance', { maxDistance })
|
||||||
|
.getRawMany();
|
||||||
|
});
|
||||||
|
|
||||||
|
return this.assetFaceRepository.create(results);
|
||||||
|
}
|
||||||
|
|
||||||
async upsert(smartInfo: Partial<SmartInfoEntity>, embedding?: Embedding): Promise<void> {
|
async upsert(smartInfo: Partial<SmartInfoEntity>, embedding?: Embedding): Promise<void> {
|
||||||
await this.repository.upsert(smartInfo, { conflictPaths: ['assetId'] });
|
await this.repository.upsert(smartInfo, { conflictPaths: ['assetId'] });
|
||||||
if (!smartInfo.assetId || !embedding) return;
|
if (!smartInfo.assetId || !embedding) return;
|
||||||
|
|
|
@ -17,7 +17,6 @@ export const newPersonRepositoryMock = (): jest.Mocked<IPersonRepository> => {
|
||||||
delete: jest.fn(),
|
delete: jest.fn(),
|
||||||
|
|
||||||
getStatistics: jest.fn(),
|
getStatistics: jest.fn(),
|
||||||
searchByEmbedding: jest.fn(),
|
|
||||||
getAllFaces: jest.fn(),
|
getAllFaces: jest.fn(),
|
||||||
getFacesByIds: jest.fn(),
|
getFacesByIds: jest.fn(),
|
||||||
getRandomFace: jest.fn(),
|
getRandomFace: jest.fn(),
|
||||||
|
|
|
@ -2,7 +2,8 @@ import { ISmartInfoRepository } from '@app/domain';
|
||||||
|
|
||||||
export const newSmartInfoRepositoryMock = (): jest.Mocked<ISmartInfoRepository> => {
|
export const newSmartInfoRepositoryMock = (): jest.Mocked<ISmartInfoRepository> => {
|
||||||
return {
|
return {
|
||||||
searchByEmbedding: jest.fn(),
|
searchCLIP: jest.fn(),
|
||||||
|
searchFaces: jest.fn(),
|
||||||
upsert: jest.fn(),
|
upsert: jest.fn(),
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
Loading…
Reference in a new issue