move face search to smart info repository

This commit is contained in:
mertalev 2023-11-24 12:57:38 -05:00
parent d7d11429f4
commit 8a8da5f5c8
No known key found for this signature in database
GPG key ID: 9181CD92C0A1C5E3
9 changed files with 71 additions and 71 deletions

View file

@ -13,6 +13,7 @@ import {
newMoveRepositoryMock, newMoveRepositoryMock,
newPersonRepositoryMock, newPersonRepositoryMock,
newSearchRepositoryMock, newSearchRepositoryMock,
newSmartInfoRepositoryMock,
newStorageRepositoryMock, newStorageRepositoryMock,
newSystemConfigRepositoryMock, newSystemConfigRepositoryMock,
personStub, personStub,
@ -27,6 +28,7 @@ import {
IMoveRepository, IMoveRepository,
IPersonRepository, IPersonRepository,
ISearchRepository, ISearchRepository,
ISmartInfoRepository,
IStorageRepository, IStorageRepository,
ISystemConfigRepository, ISystemConfigRepository,
WithoutProperty, WithoutProperty,
@ -70,8 +72,8 @@ describe(PersonService.name, () => {
let mediaMock: jest.Mocked<IMediaRepository>; let mediaMock: jest.Mocked<IMediaRepository>;
let moveMock: jest.Mocked<IMoveRepository>; let moveMock: jest.Mocked<IMoveRepository>;
let personMock: jest.Mocked<IPersonRepository>; let personMock: jest.Mocked<IPersonRepository>;
let searchMock: jest.Mocked<ISearchRepository>;
let storageMock: jest.Mocked<IStorageRepository>; let storageMock: jest.Mocked<IStorageRepository>;
let smartInfoMock: jest.Mocked<ISmartInfoRepository>;
let sut: PersonService; let sut: PersonService;
beforeEach(async () => { beforeEach(async () => {
@ -83,8 +85,8 @@ describe(PersonService.name, () => {
moveMock = newMoveRepositoryMock(); moveMock = newMoveRepositoryMock();
mediaMock = newMediaRepositoryMock(); mediaMock = newMediaRepositoryMock();
personMock = newPersonRepositoryMock(); personMock = newPersonRepositoryMock();
searchMock = newSearchRepositoryMock();
storageMock = newStorageRepositoryMock(); storageMock = newStorageRepositoryMock();
smartInfoMock = newSmartInfoRepositoryMock();
sut = new PersonService( sut = new PersonService(
accessMock, accessMock,
assetMock, assetMock,
@ -92,10 +94,10 @@ describe(PersonService.name, () => {
moveMock, moveMock,
mediaMock, mediaMock,
personMock, personMock,
searchMock,
configMock, configMock,
storageMock, storageMock,
jobMock, jobMock,
smartInfoMock
); );
mediaMock.crop.mockResolvedValue(croppedFace); mediaMock.crop.mockResolvedValue(croppedFace);
@ -591,7 +593,7 @@ describe(PersonService.name, () => {
it('should match existing people', async () => { it('should match existing people', async () => {
machineLearningMock.detectFaces.mockResolvedValue([detectFaceMock]); machineLearningMock.detectFaces.mockResolvedValue([detectFaceMock]);
personMock.searchByEmbedding.mockResolvedValue([faceStub.face1]); smartInfoMock.searchFaces.mockResolvedValue([faceStub.face1]);
assetMock.getByIds.mockResolvedValue([assetStub.image]); assetMock.getByIds.mockResolvedValue([assetStub.image]);
await sut.handleRecognizeFaces({ id: assetStub.image.id }); await sut.handleRecognizeFaces({ id: assetStub.image.id });
@ -610,7 +612,7 @@ describe(PersonService.name, () => {
it('should create a new person', async () => { it('should create a new person', async () => {
machineLearningMock.detectFaces.mockResolvedValue([detectFaceMock]); machineLearningMock.detectFaces.mockResolvedValue([detectFaceMock]);
personMock.searchByEmbedding.mockResolvedValue([]); smartInfoMock.searchFaces.mockResolvedValue([]);
personMock.create.mockResolvedValue(personStub.noName); personMock.create.mockResolvedValue(personStub.noName);
assetMock.getByIds.mockResolvedValue([assetStub.image]); assetMock.getByIds.mockResolvedValue([assetStub.image]);
personMock.createFace.mockResolvedValue(faceStub.primaryFace1); personMock.createFace.mockResolvedValue(faceStub.primaryFace1);

View file

@ -18,7 +18,7 @@ import {
IMediaRepository, IMediaRepository,
IMoveRepository, IMoveRepository,
IPersonRepository, IPersonRepository,
ISearchRepository, ISmartInfoRepository,
IStorageRepository, IStorageRepository,
ISystemConfigRepository, ISystemConfigRepository,
ImmichReadStream, ImmichReadStream,
@ -56,10 +56,10 @@ export class PersonService {
@Inject(IMoveRepository) moveRepository: IMoveRepository, @Inject(IMoveRepository) moveRepository: IMoveRepository,
@Inject(IMediaRepository) private mediaRepository: IMediaRepository, @Inject(IMediaRepository) private mediaRepository: IMediaRepository,
@Inject(IPersonRepository) private repository: IPersonRepository, @Inject(IPersonRepository) private repository: IPersonRepository,
@Inject(ISearchRepository) private searchRepository: ISearchRepository,
@Inject(ISystemConfigRepository) configRepository: ISystemConfigRepository, @Inject(ISystemConfigRepository) configRepository: ISystemConfigRepository,
@Inject(IStorageRepository) private storageRepository: IStorageRepository, @Inject(IStorageRepository) private storageRepository: IStorageRepository,
@Inject(IJobRepository) private jobRepository: IJobRepository, @Inject(IJobRepository) private jobRepository: IJobRepository,
@Inject(ISmartInfoRepository) private smartInfoRepository: ISmartInfoRepository,
) { ) {
this.access = AccessCore.create(accessRepository); this.access = AccessCore.create(accessRepository);
this.configCore = SystemConfigCore.create(configRepository); this.configCore = SystemConfigCore.create(configRepository);
@ -315,7 +315,7 @@ export class PersonService {
this.logger.verbose(faces.map((face) => ({ ...face, embedding: `vector(${face.embedding.length})` }))); this.logger.verbose(faces.map((face) => ({ ...face, embedding: `vector(${face.embedding.length})` })));
for (const { embedding, ...rest } of faces) { for (const { embedding, ...rest } of faces) {
const matches = await this.repository.searchByEmbedding({ const matches = await this.smartInfoRepository.searchFaces({
ownerId: asset.ownerId, ownerId: asset.ownerId,
embedding, embedding,
numResults: 1, numResults: 1,

View file

@ -25,15 +25,6 @@ export interface PersonStatistics {
assets: number; assets: number;
} }
export type Embedding = number[];
export interface EmbeddingSearch {
ownerId: string;
embedding: Embedding;
numResults: number;
maxDistance?: number;
}
export interface IPersonRepository { export interface IPersonRepository {
getAll(): Promise<PersonEntity[]>; getAll(): Promise<PersonEntity[]>;
getAllWithoutThumbnail(): Promise<PersonEntity[]>; getAllWithoutThumbnail(): Promise<PersonEntity[]>;
@ -51,7 +42,6 @@ export interface IPersonRepository {
delete(entity: PersonEntity): Promise<PersonEntity | null>; delete(entity: PersonEntity): Promise<PersonEntity | null>;
deleteAll(): Promise<number>; deleteAll(): Promise<number>;
getStatistics(personId: string): Promise<PersonStatistics>; getStatistics(personId: string): Promise<PersonStatistics>;
searchByEmbedding(search: EmbeddingSearch): Promise<AssetFaceEntity[]>;
getAllFaces(): Promise<AssetFaceEntity[]>; getAllFaces(): Promise<AssetFaceEntity[]>;
getFacesByIds(ids: AssetFaceId[]): Promise<AssetFaceEntity[]>; getFacesByIds(ids: AssetFaceId[]): Promise<AssetFaceEntity[]>;
getRandomFace(personId: string): Promise<AssetFaceEntity | null>; getRandomFace(personId: string): Promise<AssetFaceEntity | null>;

View file

@ -1,9 +1,18 @@
import { Embedding, EmbeddingSearch } from '@app/domain'; import { AssetEntity, AssetFaceEntity, SmartInfoEntity } from '@app/infra/entities';
import { AssetEntity, SmartInfoEntity } from '@app/infra/entities';
export const ISmartInfoRepository = 'ISmartInfoRepository'; export const ISmartInfoRepository = 'ISmartInfoRepository';
/** An embedding vector, represented as a plain array of numbers. */
export type Embedding = number[];
/**
 * Parameters accepted by the embedding-based searches declared on `ISmartInfoRepository`
 * (`searchCLIP` and `searchFaces`).
 */
export interface EmbeddingSearch {
/** Id of the owner the search is scoped to (`searchFaces` filters on `asset.ownerId`). */
ownerId: string;
/** Query vector that stored embeddings are compared against. */
embedding: Embedding;
/**
 * Maximum number of results to return. The repository implementation throws when this
 * fails `isValidInteger(numResults, { min: 1 })`.
 */
numResults: number;
/**
 * Optional upper bound on the computed match distance. Only `searchFaces` reads this field;
 * `searchCLIP` does not destructure it.
 */
maxDistance?: number;
}
export interface ISmartInfoRepository { export interface ISmartInfoRepository {
searchByEmbedding(search: EmbeddingSearch): Promise<AssetEntity[]>; searchCLIP(search: EmbeddingSearch): Promise<AssetEntity[]>;
searchFaces(search: EmbeddingSearch): Promise<AssetFaceEntity[]>;
upsert(smartInfo: Partial<SmartInfoEntity>, embedding?: Embedding): Promise<void>; upsert(smartInfo: Partial<SmartInfoEntity>, embedding?: Embedding): Promise<void>;
} }

View file

@ -67,7 +67,7 @@ export class SearchService {
{ text: query }, { text: query },
machineLearning.clip, machineLearning.clip,
); );
assets = await this.smartInfoRepository.searchByEmbedding({ ownerId: authUser.id, embedding, numResults: 100 }); assets = await this.smartInfoRepository.searchCLIP({ ownerId: authUser.id, embedding, numResults: 100 });
break; break;
case SearchStrategy.TEXT: case SearchStrategy.TEXT:
assets = await this.assetRepository.searchMetadata(query, authUser.id, { numResults: 250 }); assets = await this.assetRepository.searchMetadata(query, authUser.id, { numResults: 250 });

View file

@ -1,6 +1,5 @@
import { import {
AssetFaceId, AssetFaceId,
EmbeddingSearch,
IPersonRepository, IPersonRepository,
PersonNameSearchOptions, PersonNameSearchOptions,
PersonSearchOptions, PersonSearchOptions,
@ -11,20 +10,14 @@ import { InjectRepository } from '@nestjs/typeorm';
import { In, Repository } from 'typeorm'; import { In, Repository } from 'typeorm';
import { AssetEntity, AssetFaceEntity, PersonEntity } from '../entities'; import { AssetEntity, AssetFaceEntity, PersonEntity } from '../entities';
import { DummyValue, GenerateSql } from '../infra.util'; import { DummyValue, GenerateSql } from '../infra.util';
import { asVector, isValidInteger } from '../infra.utils'; import { asVector } from '../infra.utils';
export class PersonRepository implements IPersonRepository { export class PersonRepository implements IPersonRepository {
private readonly faceColumns: string[];
constructor( constructor(
@InjectRepository(AssetEntity) private assetRepository: Repository<AssetEntity>, @InjectRepository(AssetEntity) private assetRepository: Repository<AssetEntity>,
@InjectRepository(PersonEntity) private personRepository: Repository<PersonEntity>, @InjectRepository(PersonEntity) private personRepository: Repository<PersonEntity>,
@InjectRepository(AssetFaceEntity) private assetFaceRepository: Repository<AssetFaceEntity>, @InjectRepository(AssetFaceEntity) private assetFaceRepository: Repository<AssetFaceEntity>,
) { ) {}
this.faceColumns = this.assetFaceRepository.manager.connection
.getMetadata(AssetFaceEntity)
.ownColumns.map((column) => column.propertyName)
.filter((propertyName) => propertyName !== 'embedding');
}
/** /**
* Before reassigning faces, delete potential key violations * Before reassigning faces, delete potential key violations
@ -248,40 +241,4 @@ export class PersonRepository implements IPersonRepository {
async getRandomFace(personId: string): Promise<AssetFaceEntity | null> { async getRandomFace(personId: string): Promise<AssetFaceEntity | null> {
return this.assetFaceRepository.findOneBy({ personId }); return this.assetFaceRepository.findOneBy({ personId });
} }
async searchByEmbedding({
ownerId,
embedding,
numResults,
maxDistance,
}: EmbeddingSearch): Promise<AssetFaceEntity[]> {
if (!isValidInteger(numResults, { min: 1 })) {
throw new Error(`Invalid value for 'numResults': ${numResults}`);
}
let results: AssetFaceEntity[] = [];
await this.assetRepository.manager.transaction(async (manager) => {
await manager.query(`SET LOCAL vectors.k = '${numResults}'`);
const cte = manager
.createQueryBuilder(AssetFaceEntity, 'faces')
.select('1 + (faces.embedding <=> :embedding)', 'distance')
.innerJoin('faces.asset', 'asset')
.where('asset.ownerId = :ownerId')
.orderBy(`faces.embedding <=> :embedding`)
.setParameters({ ownerId, embedding: asVector(embedding) })
.limit(numResults);
this.faceColumns.forEach((col) => cte.addSelect(`faces.${col} AS "${col}"`));
results = await manager
.createQueryBuilder()
.select('res.*')
.addCommonTableExpression(cte, 'cte')
.from('cte', 'res')
.where('res.distance <= :maxDistance', { maxDistance })
.getRawMany();
});
return this.assetFaceRepository.create(results);
}
} }

View file

@ -3,24 +3,30 @@ import { Injectable, Logger } from '@nestjs/common';
import { InjectRepository } from '@nestjs/typeorm'; import { InjectRepository } from '@nestjs/typeorm';
import AsyncLock from 'async-lock'; import AsyncLock from 'async-lock';
import { Repository } from 'typeorm'; import { Repository } from 'typeorm';
import { AssetEntity, SmartInfoEntity, SmartSearchEntity } from '../entities'; import { AssetEntity, AssetFaceEntity, SmartInfoEntity, SmartSearchEntity } from '@app/infra/entities';
import { asVector, isValidInteger } from '../infra.utils'; import { asVector, isValidInteger } from '../infra.utils';
@Injectable() @Injectable()
export class SmartInfoRepository implements ISmartInfoRepository { export class SmartInfoRepository implements ISmartInfoRepository {
private logger = new Logger(SmartInfoRepository.name); private logger = new Logger(SmartInfoRepository.name);
private lock: AsyncLock; private lock: AsyncLock;
private readonly faceColumns: string[];
private curDimSize: number | undefined; private curDimSize: number | undefined;
constructor( constructor(
@InjectRepository(SmartInfoEntity) private repository: Repository<SmartInfoEntity>, @InjectRepository(SmartInfoEntity) private repository: Repository<SmartInfoEntity>,
@InjectRepository(AssetEntity) private assetRepository: Repository<AssetEntity>, @InjectRepository(AssetEntity) private assetRepository: Repository<AssetEntity>,
@InjectRepository(AssetFaceEntity) private assetFaceRepository: Repository<AssetFaceEntity>,
@InjectRepository(SmartSearchEntity) private smartSearchRepository: Repository<SmartSearchEntity>, @InjectRepository(SmartSearchEntity) private smartSearchRepository: Repository<SmartSearchEntity>,
) { ) {
this.lock = new AsyncLock(); this.lock = new AsyncLock();
this.faceColumns = this.assetFaceRepository.manager.connection
.getMetadata(AssetFaceEntity)
.ownColumns.map((column) => column.propertyName)
.filter((propertyName) => propertyName !== 'embedding');
} }
async searchByEmbedding({ ownerId, embedding, numResults }: EmbeddingSearch): Promise<AssetEntity[]> { async searchCLIP({ ownerId, embedding, numResults }: EmbeddingSearch): Promise<AssetEntity[]> {
if (!isValidInteger(numResults, { min: 1 })) { if (!isValidInteger(numResults, { min: 1 })) {
throw new Error(`Invalid value for 'numResults': ${numResults}`); throw new Error(`Invalid value for 'numResults': ${numResults}`);
} }
@ -42,6 +48,42 @@ export class SmartInfoRepository implements ISmartInfoRepository {
return results; return results;
} }
/**
 * Finds the stored faces whose embeddings are closest to `embedding`, restricted to faces
 * that belong to assets owned by `ownerId`.
 *
 * The nearest-neighbour query runs inside a transaction so that the `SET LOCAL vectors.k`
 * setting is issued on the same transaction as the search itself.
 *
 * @param search.ownerId - only faces joined to an asset with this owner id are considered
 * @param search.embedding - query vector; converted with `asVector` before being bound
 * @param search.numResults - maximum number of faces to return; must satisfy
 *   `isValidInteger(numResults, { min: 1 })`
 * @param search.maxDistance - optional upper bound on the computed `distance` column; when
 *   omitted, no distance filter is applied and the nearest `numResults` faces are returned
 * @returns face entities built from the raw rows via `assetFaceRepository.create`; only the
 *   columns listed in `faceColumns` (plus `distance`) are selected
 * @throws Error when `numResults` fails the integer check
 */
async searchFaces({
  ownerId,
  embedding,
  numResults,
  maxDistance,
}: EmbeddingSearch): Promise<AssetFaceEntity[]> {
  if (!isValidInteger(numResults, { min: 1 })) {
    throw new Error(`Invalid value for 'numResults': ${numResults}`);
  }

  const rows: AssetFaceEntity[] = await this.assetRepository.manager.transaction(async (manager) => {
    // numResults is interpolated into the statement text, so the integer check above must run first.
    await manager.query(`SET LOCAL vectors.k = '${numResults}'`);

    // Inner query: candidate faces ordered by embedding distance, limited to numResults.
    const cte = manager
      .createQueryBuilder(AssetFaceEntity, 'faces')
      .select('1 + (faces.embedding <=> :embedding)', 'distance')
      .innerJoin('faces.asset', 'asset')
      .where('asset.ownerId = :ownerId')
      .orderBy(`faces.embedding <=> :embedding`)
      .setParameters({ ownerId, embedding: asVector(embedding) })
      .limit(numResults);
    for (const col of this.faceColumns) {
      cte.addSelect(`faces.${col} AS "${col}"`);
    }

    const query = manager
      .createQueryBuilder()
      .select('res.*')
      .addCommonTableExpression(cte, 'cte')
      .from('cte', 'res');

    // maxDistance is optional in EmbeddingSearch. Only filter on it when a threshold was
    // supplied; otherwise the comparison would be made against a missing parameter value.
    if (maxDistance !== undefined) {
      query.where('res.distance <= :maxDistance', { maxDistance });
    }

    return query.getRawMany();
  });

  return this.assetFaceRepository.create(rows);
}
async upsert(smartInfo: Partial<SmartInfoEntity>, embedding?: Embedding): Promise<void> { async upsert(smartInfo: Partial<SmartInfoEntity>, embedding?: Embedding): Promise<void> {
await this.repository.upsert(smartInfo, { conflictPaths: ['assetId'] }); await this.repository.upsert(smartInfo, { conflictPaths: ['assetId'] });
if (!smartInfo.assetId || !embedding) return; if (!smartInfo.assetId || !embedding) return;

View file

@ -17,7 +17,6 @@ export const newPersonRepositoryMock = (): jest.Mocked<IPersonRepository> => {
delete: jest.fn(), delete: jest.fn(),
getStatistics: jest.fn(), getStatistics: jest.fn(),
searchByEmbedding: jest.fn(),
getAllFaces: jest.fn(), getAllFaces: jest.fn(),
getFacesByIds: jest.fn(), getFacesByIds: jest.fn(),
getRandomFace: jest.fn(), getRandomFace: jest.fn(),

View file

@ -2,7 +2,8 @@ import { ISmartInfoRepository } from '@app/domain';
export const newSmartInfoRepositoryMock = (): jest.Mocked<ISmartInfoRepository> => { export const newSmartInfoRepositoryMock = (): jest.Mocked<ISmartInfoRepository> => {
return { return {
searchByEmbedding: jest.fn(), searchCLIP: jest.fn(),
searchFaces: jest.fn(),
upsert: jest.fn(), upsert: jest.fn(),
}; };
}; };