浏览代码

move face search to smart info repository

mertalev 1 年之前
父节点
当前提交
8a8da5f5c8

+ 7 - 5
server/src/domain/person/person.service.spec.ts

@@ -13,6 +13,7 @@ import {
   newMoveRepositoryMock,
   newPersonRepositoryMock,
   newSearchRepositoryMock,
+  newSmartInfoRepositoryMock,
   newStorageRepositoryMock,
   newSystemConfigRepositoryMock,
   personStub,
@@ -27,6 +28,7 @@ import {
   IMoveRepository,
   IPersonRepository,
   ISearchRepository,
+  ISmartInfoRepository,
   IStorageRepository,
   ISystemConfigRepository,
   WithoutProperty,
@@ -70,8 +72,8 @@ describe(PersonService.name, () => {
   let mediaMock: jest.Mocked<IMediaRepository>;
   let moveMock: jest.Mocked<IMoveRepository>;
   let personMock: jest.Mocked<IPersonRepository>;
-  let searchMock: jest.Mocked<ISearchRepository>;
   let storageMock: jest.Mocked<IStorageRepository>;
+  let smartInfoMock: jest.Mocked<ISmartInfoRepository>;
   let sut: PersonService;
 
   beforeEach(async () => {
@@ -83,8 +85,8 @@ describe(PersonService.name, () => {
     moveMock = newMoveRepositoryMock();
     mediaMock = newMediaRepositoryMock();
     personMock = newPersonRepositoryMock();
-    searchMock = newSearchRepositoryMock();
     storageMock = newStorageRepositoryMock();
+    smartInfoMock = newSmartInfoRepositoryMock();
     sut = new PersonService(
       accessMock,
       assetMock,
@@ -92,10 +94,10 @@ describe(PersonService.name, () => {
       moveMock,
       mediaMock,
       personMock,
-      searchMock,
       configMock,
       storageMock,
       jobMock,
+      smartInfoMock
     );
 
     mediaMock.crop.mockResolvedValue(croppedFace);
@@ -591,7 +593,7 @@ describe(PersonService.name, () => {
 
     it('should match existing people', async () => {
       machineLearningMock.detectFaces.mockResolvedValue([detectFaceMock]);
-      personMock.searchByEmbedding.mockResolvedValue([faceStub.face1]);
+      smartInfoMock.searchFaces.mockResolvedValue([faceStub.face1]);
       assetMock.getByIds.mockResolvedValue([assetStub.image]);
       await sut.handleRecognizeFaces({ id: assetStub.image.id });
 
@@ -610,7 +612,7 @@ describe(PersonService.name, () => {
 
     it('should create a new person', async () => {
       machineLearningMock.detectFaces.mockResolvedValue([detectFaceMock]);
-      personMock.searchByEmbedding.mockResolvedValue([]);
+      smartInfoMock.searchFaces.mockResolvedValue([]);
       personMock.create.mockResolvedValue(personStub.noName);
       assetMock.getByIds.mockResolvedValue([assetStub.image]);
       personMock.createFace.mockResolvedValue(faceStub.primaryFace1);

+ 3 - 3
server/src/domain/person/person.service.ts

@@ -18,7 +18,7 @@ import {
   IMediaRepository,
   IMoveRepository,
   IPersonRepository,
-  ISearchRepository,
+  ISmartInfoRepository,
   IStorageRepository,
   ISystemConfigRepository,
   ImmichReadStream,
@@ -56,10 +56,10 @@ export class PersonService {
     @Inject(IMoveRepository) moveRepository: IMoveRepository,
     @Inject(IMediaRepository) private mediaRepository: IMediaRepository,
     @Inject(IPersonRepository) private repository: IPersonRepository,
-    @Inject(ISearchRepository) private searchRepository: ISearchRepository,
     @Inject(ISystemConfigRepository) configRepository: ISystemConfigRepository,
     @Inject(IStorageRepository) private storageRepository: IStorageRepository,
     @Inject(IJobRepository) private jobRepository: IJobRepository,
+    @Inject(ISmartInfoRepository) private smartInfoRepository: ISmartInfoRepository,
   ) {
     this.access = AccessCore.create(accessRepository);
     this.configCore = SystemConfigCore.create(configRepository);
@@ -315,7 +315,7 @@ export class PersonService {
     this.logger.verbose(faces.map((face) => ({ ...face, embedding: `vector(${face.embedding.length})` })));
 
     for (const { embedding, ...rest } of faces) {
-      const matches = await this.repository.searchByEmbedding({
+      const matches = await this.smartInfoRepository.searchFaces({
         ownerId: asset.ownerId,
         embedding,
         numResults: 1,

+ 0 - 10
server/src/domain/repositories/person.repository.ts

@@ -25,15 +25,6 @@ export interface PersonStatistics {
   assets: number;
 }
 
-export type Embedding = number[]; 
-
-export interface EmbeddingSearch {
-  ownerId: string;
-  embedding: Embedding;
-  numResults: number;
-  maxDistance?: number;
-}
-
 export interface IPersonRepository {
   getAll(): Promise<PersonEntity[]>;
   getAllWithoutThumbnail(): Promise<PersonEntity[]>;
@@ -51,7 +42,6 @@ export interface IPersonRepository {
   delete(entity: PersonEntity): Promise<PersonEntity | null>;
   deleteAll(): Promise<number>;
   getStatistics(personId: string): Promise<PersonStatistics>;
-  searchByEmbedding(search: EmbeddingSearch): Promise<AssetFaceEntity[]>;
   getAllFaces(): Promise<AssetFaceEntity[]>;
   getFacesByIds(ids: AssetFaceId[]): Promise<AssetFaceEntity[]>;
   getRandomFace(personId: string): Promise<AssetFaceEntity | null>;

+ 12 - 3
server/src/domain/repositories/smart-info.repository.ts

@@ -1,9 +1,18 @@
-import { Embedding, EmbeddingSearch } from '@app/domain';
-import { AssetEntity, SmartInfoEntity } from '@app/infra/entities';
+import { AssetEntity, AssetFaceEntity, SmartInfoEntity } from '@app/infra/entities';
 
 export const ISmartInfoRepository = 'ISmartInfoRepository';
 
+export type Embedding = number[]; 
+
+export interface EmbeddingSearch {
+  ownerId: string;
+  embedding: Embedding;
+  numResults: number;
+  maxDistance?: number;
+}
+
 export interface ISmartInfoRepository {
-  searchByEmbedding(search: EmbeddingSearch): Promise<AssetEntity[]>;
+  searchCLIP(search: EmbeddingSearch): Promise<AssetEntity[]>;
+  searchFaces(search: EmbeddingSearch): Promise<AssetFaceEntity[]>;
   upsert(smartInfo: Partial<SmartInfoEntity>, embedding?: Embedding): Promise<void>;
 }

+ 1 - 1
server/src/domain/search/search.service.ts

@@ -67,7 +67,7 @@ export class SearchService {
           { text: query },
           machineLearning.clip,
         );
-        assets = await this.smartInfoRepository.searchByEmbedding({ ownerId: authUser.id, embedding, numResults: 100 });
+        assets = await this.smartInfoRepository.searchCLIP({ ownerId: authUser.id, embedding, numResults: 100 });
         break;
       case SearchStrategy.TEXT:
         assets = await this.assetRepository.searchMetadata(query, authUser.id, { numResults: 250 });

+ 2 - 45
server/src/infra/repositories/person.repository.ts

@@ -1,6 +1,5 @@
 import {
   AssetFaceId,
-  EmbeddingSearch,
   IPersonRepository,
   PersonNameSearchOptions,
   PersonSearchOptions,
@@ -11,20 +10,14 @@ import { InjectRepository } from '@nestjs/typeorm';
 import { In, Repository } from 'typeorm';
 import { AssetEntity, AssetFaceEntity, PersonEntity } from '../entities';
 import { DummyValue, GenerateSql } from '../infra.util';
-import { asVector, isValidInteger } from '../infra.utils';
+import { asVector } from '../infra.utils';
 
 export class PersonRepository implements IPersonRepository {
-  private readonly faceColumns: string[];
   constructor(
     @InjectRepository(AssetEntity) private assetRepository: Repository<AssetEntity>,
     @InjectRepository(PersonEntity) private personRepository: Repository<PersonEntity>,
     @InjectRepository(AssetFaceEntity) private assetFaceRepository: Repository<AssetFaceEntity>,
-  ) {
-    this.faceColumns = this.assetFaceRepository.manager.connection
-      .getMetadata(AssetFaceEntity)
-      .ownColumns.map((column) => column.propertyName)
-      .filter((propertyName) => propertyName !== 'embedding');
-  }
+  ) {}
 
   /**
    * Before reassigning faces, delete potential key violations
@@ -248,40 +241,4 @@ export class PersonRepository implements IPersonRepository {
   async getRandomFace(personId: string): Promise<AssetFaceEntity | null> {
     return this.assetFaceRepository.findOneBy({ personId });
   }
-
-  async searchByEmbedding({
-    ownerId,
-    embedding,
-    numResults,
-    maxDistance,
-  }: EmbeddingSearch): Promise<AssetFaceEntity[]> {
-    if (!isValidInteger(numResults, { min: 1 })) {
-      throw new Error(`Invalid value for 'numResults': ${numResults}`);
-    }
-
-    let results: AssetFaceEntity[] = [];
-    await this.assetRepository.manager.transaction(async (manager) => {
-      await manager.query(`SET LOCAL vectors.k = '${numResults}'`);
-      const cte = manager
-        .createQueryBuilder(AssetFaceEntity, 'faces')
-        .select('1 + (faces.embedding <=> :embedding)', 'distance')
-        .innerJoin('faces.asset', 'asset')
-        .where('asset.ownerId = :ownerId')
-        .orderBy(`faces.embedding <=> :embedding`)
-        .setParameters({ ownerId, embedding: asVector(embedding) })
-        .limit(numResults);
-
-      this.faceColumns.forEach((col) => cte.addSelect(`faces.${col} AS "${col}"`));
-
-      results = await manager
-        .createQueryBuilder()
-        .select('res.*')
-        .addCommonTableExpression(cte, 'cte')
-        .from('cte', 'res')
-        .where('res.distance <= :maxDistance', { maxDistance })
-        .getRawMany();
-    });
-
-    return this.assetFaceRepository.create(results);
-  }
 }

+ 44 - 2
server/src/infra/repositories/smart-info.repository.ts

@@ -3,24 +3,30 @@ import { Injectable, Logger } from '@nestjs/common';
 import { InjectRepository } from '@nestjs/typeorm';
 import AsyncLock from 'async-lock';
 import { Repository } from 'typeorm';
-import { AssetEntity, SmartInfoEntity, SmartSearchEntity } from '../entities';
+import { AssetEntity, AssetFaceEntity, SmartInfoEntity, SmartSearchEntity } from '@app/infra/entities';
 import { asVector, isValidInteger } from '../infra.utils';
 
 @Injectable()
 export class SmartInfoRepository implements ISmartInfoRepository {
   private logger = new Logger(SmartInfoRepository.name);
   private lock: AsyncLock;
+  private readonly faceColumns: string[];
   private curDimSize: number | undefined;
 
   constructor(
     @InjectRepository(SmartInfoEntity) private repository: Repository<SmartInfoEntity>,
     @InjectRepository(AssetEntity) private assetRepository: Repository<AssetEntity>,
+    @InjectRepository(AssetFaceEntity) private assetFaceRepository: Repository<AssetFaceEntity>,
     @InjectRepository(SmartSearchEntity) private smartSearchRepository: Repository<SmartSearchEntity>,
   ) {
     this.lock = new AsyncLock();
+    this.faceColumns = this.assetFaceRepository.manager.connection
+      .getMetadata(AssetFaceEntity)
+      .ownColumns.map((column) => column.propertyName)
+      .filter((propertyName) => propertyName !== 'embedding');
   }
 
-  async searchByEmbedding({ ownerId, embedding, numResults }: EmbeddingSearch): Promise<AssetEntity[]> {
+  async searchCLIP({ ownerId, embedding, numResults }: EmbeddingSearch): Promise<AssetEntity[]> {
     if (!isValidInteger(numResults, { min: 1 })) {
       throw new Error(`Invalid value for 'numResults': ${numResults}`);
     }
@@ -42,6 +48,42 @@ export class SmartInfoRepository implements ISmartInfoRepository {
     return results;
   }
 
+  async searchFaces({
+    ownerId,
+    embedding,
+    numResults,
+    maxDistance,
+  }: EmbeddingSearch): Promise<AssetFaceEntity[]> {
+    if (!isValidInteger(numResults, { min: 1 })) {
+      throw new Error(`Invalid value for 'numResults': ${numResults}`);
+    }
+
+    let results: AssetFaceEntity[] = [];
+    await this.assetRepository.manager.transaction(async (manager) => {
+      await manager.query(`SET LOCAL vectors.k = '${numResults}'`);
+      const cte = manager
+        .createQueryBuilder(AssetFaceEntity, 'faces')
+        .select('1 + (faces.embedding <=> :embedding)', 'distance')
+        .innerJoin('faces.asset', 'asset')
+        .where('asset.ownerId = :ownerId')
+        .orderBy(`faces.embedding <=> :embedding`)
+        .setParameters({ ownerId, embedding: asVector(embedding) })
+        .limit(numResults);
+
+      this.faceColumns.forEach((col) => cte.addSelect(`faces.${col} AS "${col}"`));
+
+      results = await manager
+        .createQueryBuilder()
+        .select('res.*')
+        .addCommonTableExpression(cte, 'cte')
+        .from('cte', 'res')
+        .where('res.distance <= :maxDistance', { maxDistance })
+        .getRawMany();
+    });
+
+    return this.assetFaceRepository.create(results);
+  }
+
   async upsert(smartInfo: Partial<SmartInfoEntity>, embedding?: Embedding): Promise<void> {
     await this.repository.upsert(smartInfo, { conflictPaths: ['assetId'] });
     if (!smartInfo.assetId || !embedding) return;

+ 0 - 1
server/test/repositories/person.repository.mock.ts

@@ -17,7 +17,6 @@ export const newPersonRepositoryMock = (): jest.Mocked<IPersonRepository> => {
     delete: jest.fn(),
 
     getStatistics: jest.fn(),
-    searchByEmbedding: jest.fn(),
     getAllFaces: jest.fn(),
     getFacesByIds: jest.fn(),
     getRandomFace: jest.fn(),

+ 2 - 1
server/test/repositories/smart-info.repository.mock.ts

@@ -2,7 +2,8 @@ import { ISmartInfoRepository } from '@app/domain';
 
 export const newSmartInfoRepositoryMock = (): jest.Mocked<ISmartInfoRepository> => {
   return {
-    searchByEmbedding: jest.fn(),
+    searchCLIP: jest.fn(),
+    searchFaces: jest.fn(),
     upsert: jest.fn(),
   };
 };