Browse Source

facial recognition fixes

mertalev 1 năm trước cách đây
mục cha
commit
5cb3312589

+ 2 - 1
server/src/infra/infra.utils.ts

@@ -33,4 +33,5 @@ export async function paginate<Entity extends ObjectLiteral>(
   return { items, hasNextPage };
 }
 
-export const asVector = (embedding: number[]) => `[${embedding.join(',')}]`;
+export const asVector = (embedding: number[], escape = false) =>
+  escape ? `'[${embedding.join(',')}]'` : `[${embedding.join(',')}]`;

+ 4 - 1
server/src/infra/migrations/vector/1699746198141-UsePgVectors.ts

@@ -19,7 +19,10 @@ export class UsePgVectors1699746198141 implements MigrationInterface {
     await queryRunner.query('DROP EXTENSION IF EXISTS vectors');
     await queryRunner.query('CREATE EXTENSION vectors');
 
-    await queryRunner.query(`ALTER TABLE asset_faces ALTER COLUMN embedding TYPE vector(${faceDimSize})`);
+    await queryRunner.query(`
+      ALTER TABLE asset_faces 
+        ALTER COLUMN embedding SET NOT NULL,
+        ALTER COLUMN embedding TYPE vector(${faceDimSize})`);
     await queryRunner.query(`
       CREATE TABLE smart_search (
         "assetId"  uuid PRIMARY KEY NOT NULL REFERENCES assets(id) ON DELETE CASCADE,

+ 16 - 15
server/src/infra/repositories/person.repository.ts

@@ -12,6 +12,7 @@ import { In, Repository } from 'typeorm';
 import { AssetEntity, AssetFaceEntity, PersonEntity } from '../entities';
 import { DummyValue, GenerateSql } from '../infra.util';
 import { asVector } from '../infra.utils';
+import { dataSource } from '..';
 
 export class PersonRepository implements IPersonRepository {
   constructor(
@@ -222,11 +223,7 @@ export class PersonRepository implements IPersonRepository {
       throw new Error('Person ID is required to create a face');
     }
     const { embedding, ...face } = entity;
-    await this.assetFaceRepository.save(face);
-    await this.assetFaceRepository.manager.query(
-      `UPDATE "asset_faces" SET "embedding" = ${asVector(embedding)} WHERE "assetId" = $1 AND "personId" = $2`,
-      [entity.assetId, entity.personId],
-    );
+    await this.assetFaceRepository.insert({ ...face, embedding: () => asVector(embedding, true) });
     return this.assetFaceRepository.findOneByOrFail({ assetId: entity.assetId, personId: entity.personId });
   }
 
@@ -245,18 +242,22 @@ export class PersonRepository implements IPersonRepository {
     return this.assetFaceRepository.findOneBy({ personId });
   }
 
-  searchByEmbedding({ ownerId, embedding, numResults, maxDistance }: EmbeddingSearch): Promise<AssetFaceEntity[]> {
-    let query = this.assetFaceRepository
-      .createQueryBuilder('faces')
+  async searchByEmbedding({ ownerId, embedding, numResults, maxDistance }: EmbeddingSearch): Promise<AssetFaceEntity[]> {
+    const cte = this.assetFaceRepository.createQueryBuilder('faces')
+      .select('1 + (faces.embedding <=> :embedding)', 'distance')
       .leftJoinAndSelect('faces.asset', 'asset')
-      .where('asset.ownerId = :ownerId', { ownerId })
-      .orderBy(`faces.embedding <=> ${asVector(embedding)}`)
+      .where('asset.ownerId = :ownerId')
+      .orderBy(`faces.embedding <=> :embedding`)
+      .setParameters({ownerId, embedding: asVector(embedding)})
       .limit(numResults);
+    
+    const res = await dataSource.createQueryBuilder()
+      .select('res.*')
+      .addCommonTableExpression(cte, 'cte')
+      .from('cte', 'res')
+      .where('res.distance <= :maxDistance', { maxDistance })
+      .getRawMany();
 
-    if (maxDistance) {
-      query = query.andWhere(`(faces.embedding <=> ${asVector(embedding)}) <= :maxDistance`, { maxDistance });
-    }
-
-    return query.getMany();
+    return this.assetFaceRepository.create(res);
   }
 }

+ 30 - 17
server/src/infra/repositories/smart-info.repository.ts

@@ -13,28 +13,30 @@ export class SmartInfoRepository implements ISmartInfoRepository {
   private curDimSize: number | undefined;
 
   constructor(
-      @InjectRepository(SmartInfoEntity) private repository: Repository<SmartInfoEntity>,
-      @InjectRepository(AssetEntity) private assetRepository: Repository<AssetEntity>,
-        @InjectRepository(SmartSearchEntity) private smartSearchRepository: Repository<SmartSearchEntity>) {
+    @InjectRepository(SmartInfoEntity) private repository: Repository<SmartInfoEntity>,
+    @InjectRepository(AssetEntity) private assetRepository: Repository<AssetEntity>,
+    @InjectRepository(SmartSearchEntity) private smartSearchRepository: Repository<SmartSearchEntity>,
+  ) {
     this.lock = new AsyncLock();
   }
 
   async searchByEmbedding({ ownerId, embedding, numResults }: EmbeddingSearch): Promise<AssetEntity[]> {
-    const query: string = this.assetRepository.createQueryBuilder('a')
-    .innerJoin('a.smartSearch', 's')
-    .where('a.ownerId = :ownerId')
-    .leftJoinAndSelect('a.exifInfo', 'e')
-    .orderBy('s.embedding <=> :embedding')
-    .setParameters({ embedding: asVector(embedding), ownerId })
-    .limit(numResults)
-    .getSql();
+    const query: string = this.assetRepository
+      .createQueryBuilder('a')
+      .innerJoin('a.smartSearch', 's')
+      .where('a.ownerId = :ownerId')
+      .leftJoinAndSelect('a.exifInfo', 'e')
+      .orderBy('s.embedding <=> :embedding')
+      .setParameters({ embedding: asVector(embedding), ownerId })
+      .limit(numResults)
+      .getSql();
 
     const queryWithK = `
       BEGIN;
       SET LOCAL vectors.k = ${numResults};
       ${query};
       COMMIT;
-    `
+    `;
     return this.assetRepository.create(await this.assetRepository.manager.query(queryWithK));
   }
 
@@ -51,9 +53,15 @@ export class SmartInfoRepository implements ISmartInfoRepository {
   }
 
   private async upsertEmbedding(assetId: string, embedding: number[]): Promise<void> {
-    await this.smartSearchRepository.manager.query(
-      `INSERT INTO smart_search ($1, $2) ON CONFLICT ("assetId") SET embedding = $2`,
-      [assetId, asVector(embedding)],
+    if (this.lock.isBusy('updateDimSizeLock')) {
+      this.logger.log('Waiting for CLIP dimension size update to finish');
+      await new Promise((resolve) => setTimeout(resolve, 1000));
+      return this.upsertEmbedding(assetId, embedding);
+    }
+
+    await this.smartSearchRepository.upsert(
+      { assetId, embedding: () => asVector(embedding, true) },
+      { conflictPaths: ['assetId'] },
     );
   }
 
@@ -67,7 +75,8 @@ export class SmartInfoRepository implements ISmartInfoRepository {
 
       this.logger.log(`Updating CLIP dimension size to ${dimSize}`);
 
-      await this.smartSearchRepository.manager.query(`
+      try {
+        await this.smartSearchRepository.manager.query(`
         BEGIN;
 
         ALTER TABLE smart_search
@@ -75,7 +84,7 @@ export class SmartInfoRepository implements ISmartInfoRepository {
         ADD COLUMN embedding vector(${dimSize});
 
         CREATE INDEX clip_index ON smart_search
-        USING vectors (embedding dot_ops) WITH (options = $$
+        USING vectors (embedding cosine_ops) WITH (options = $$
         [indexing.hnsw]
         m = 16
         ef_construction = 300
@@ -83,6 +92,10 @@ export class SmartInfoRepository implements ISmartInfoRepository {
 
         COMMIT;
       `);
+      } catch (err) {
+        this.logger.error(`Failed to update CLIP dimension size to ${dimSize}: ${err}`);
+        this.smartSearchRepository.manager.query('ROLLBACK');
+      }
 
       this.curDimSize = dimSize;
       this.logger.log(`Successfully updated CLIP dimension size to ${dimSize}`);