|
@@ -9,10 +9,10 @@ import {
|
|
|
} from '@app/domain';
|
|
|
import { InjectRepository } from '@nestjs/typeorm';
|
|
|
import { In, Repository } from 'typeorm';
|
|
|
+import { dataSource } from '..';
|
|
|
import { AssetEntity, AssetFaceEntity, PersonEntity } from '../entities';
|
|
|
import { DummyValue, GenerateSql } from '../infra.util';
|
|
|
-import { asVector } from '../infra.utils';
|
|
|
-import { dataSource } from '..';
|
|
|
+import { asVector, isValidInteger } from '../infra.utils';
|
|
|
|
|
|
export class PersonRepository implements IPersonRepository {
|
|
|
private readonly faceColumns: string[];
|
|
@@ -22,8 +22,8 @@ export class PersonRepository implements IPersonRepository {
|
|
|
@InjectRepository(AssetFaceEntity) private assetFaceRepository: Repository<AssetFaceEntity>,
|
|
|
) {
|
|
|
this.faceColumns = this.assetFaceRepository.manager.connection
|
|
|
- .getMetadata(AssetFaceEntity).ownColumns
|
|
|
- .map((column) => column.propertyName)
|
|
|
+ .getMetadata(AssetFaceEntity)
|
|
|
+ .ownColumns.map((column) => column.propertyName)
|
|
|
.filter((propertyName) => propertyName !== 'embedding');
|
|
|
}
|
|
|
|
|
@@ -248,24 +248,39 @@ export class PersonRepository implements IPersonRepository {
|
|
|
return this.assetFaceRepository.findOneBy({ personId });
|
|
|
}
|
|
|
|
|
|
- async searchByEmbedding({ ownerId, embedding, numResults, maxDistance }: EmbeddingSearch): Promise<AssetFaceEntity[]> {
|
|
|
- const cte = this.assetFaceRepository.createQueryBuilder('faces')
|
|
|
- .select('1 + (faces.embedding <=> :embedding)', 'distance')
|
|
|
- .innerJoin('faces.asset', 'asset')
|
|
|
- .where('asset.ownerId = :ownerId')
|
|
|
- .orderBy(`faces.embedding <=> :embedding`)
|
|
|
- .setParameters({ownerId, embedding: asVector(embedding)})
|
|
|
- .limit(numResults);
|
|
|
-
|
|
|
- this.faceColumns.forEach((col) => cte.addSelect(`faces.${col} AS "${col}"`));
|
|
|
-
|
|
|
- const res = await dataSource.createQueryBuilder()
|
|
|
- .select('res.*')
|
|
|
- .addCommonTableExpression(cte, 'cte')
|
|
|
- .from('cte', 'res')
|
|
|
- .where('res.distance <= :maxDistance', { maxDistance })
|
|
|
- .getRawMany();
|
|
|
+ async searchByEmbedding({
|
|
|
+ ownerId,
|
|
|
+ embedding,
|
|
|
+ numResults,
|
|
|
+ maxDistance,
|
|
|
+ }: EmbeddingSearch): Promise<AssetFaceEntity[]> {
|
|
|
+ if (!isValidInteger(numResults, { min: 1 })) {
|
|
|
+ throw new Error(`Invalid value for 'numResults': ${numResults}`);
|
|
|
+ }
|
|
|
+
|
|
|
+ let results: AssetFaceEntity[] = [];
|
|
|
+ this.assetRepository.manager.transaction(async (manager) => {
|
|
|
+ await manager.query(`SET LOCAL vectors.k = '${numResults}'`);
|
|
|
+ const cte = manager
|
|
|
+ .createQueryBuilder(AssetFaceEntity, 'faces')
|
|
|
+ .select('1 + (faces.embedding <=> :embedding)', 'distance')
|
|
|
+ .innerJoin('faces.asset', 'asset')
|
|
|
+ .where('asset.ownerId = :ownerId')
|
|
|
+ .orderBy(`faces.embedding <=> :embedding`)
|
|
|
+ .setParameters({ ownerId, embedding: asVector(embedding) })
|
|
|
+ .limit(numResults);
|
|
|
+
|
|
|
+ this.faceColumns.forEach((col) => cte.addSelect(`faces.${col} AS "${col}"`));
|
|
|
+
|
|
|
+ results = await dataSource
|
|
|
+ .createQueryBuilder()
|
|
|
+ .select('res.*')
|
|
|
+ .addCommonTableExpression(cte, 'cte')
|
|
|
+ .from('cte', 'res')
|
|
|
+ .where('res.distance <= :maxDistance', { maxDistance })
|
|
|
+ .getRawMany();
|
|
|
+ });
|
|
|
|
|
|
- return this.assetFaceRepository.create(res);
|
|
|
+ return this.assetFaceRepository.create(results);
|
|
|
}
|
|
|
}
|