|
@@ -3,24 +3,30 @@ import { Injectable, Logger } from '@nestjs/common';
|
|
|
import { InjectRepository } from '@nestjs/typeorm';
|
|
|
import AsyncLock from 'async-lock';
|
|
|
import { Repository } from 'typeorm';
|
|
|
-import { AssetEntity, SmartInfoEntity, SmartSearchEntity } from '../entities';
|
|
|
+import { AssetEntity, AssetFaceEntity, SmartInfoEntity, SmartSearchEntity } from '@app/infra/entities';
|
|
|
import { asVector, isValidInteger } from '../infra.utils';
|
|
|
|
|
|
@Injectable()
|
|
|
export class SmartInfoRepository implements ISmartInfoRepository {
|
|
|
private logger = new Logger(SmartInfoRepository.name);
|
|
|
private lock: AsyncLock;
|
|
|
+ private readonly faceColumns: string[];
|
|
|
private curDimSize: number | undefined;
|
|
|
|
|
|
constructor(
|
|
|
@InjectRepository(SmartInfoEntity) private repository: Repository<SmartInfoEntity>,
|
|
|
@InjectRepository(AssetEntity) private assetRepository: Repository<AssetEntity>,
|
|
|
+ @InjectRepository(AssetFaceEntity) private assetFaceRepository: Repository<AssetFaceEntity>,
|
|
|
@InjectRepository(SmartSearchEntity) private smartSearchRepository: Repository<SmartSearchEntity>,
|
|
|
) {
|
|
|
this.lock = new AsyncLock();
|
|
|
+ this.faceColumns = this.assetFaceRepository.manager.connection
|
|
|
+ .getMetadata(AssetFaceEntity)
|
|
|
+ .ownColumns.map((column) => column.propertyName)
|
|
|
+ .filter((propertyName) => propertyName !== 'embedding');
|
|
|
}
|
|
|
|
|
|
- async searchByEmbedding({ ownerId, embedding, numResults }: EmbeddingSearch): Promise<AssetEntity[]> {
|
|
|
+ async searchCLIP({ ownerId, embedding, numResults }: EmbeddingSearch): Promise<AssetEntity[]> {
|
|
|
if (!isValidInteger(numResults, { min: 1 })) {
|
|
|
throw new Error(`Invalid value for 'numResults': ${numResults}`);
|
|
|
}
|
|
@@ -42,6 +48,42 @@ export class SmartInfoRepository implements ISmartInfoRepository {
|
|
|
return results;
|
|
|
}
|
|
|
|
|
|
+ async searchFaces({
|
|
|
+ ownerId,
|
|
|
+ embedding,
|
|
|
+ numResults,
|
|
|
+ maxDistance,
|
|
|
+ }: EmbeddingSearch): Promise<AssetFaceEntity[]> {
|
|
|
+ if (!isValidInteger(numResults, { min: 1 })) {
|
|
|
+ throw new Error(`Invalid value for 'numResults': ${numResults}`);
|
|
|
+ }
|
|
|
+
|
|
|
+ let results: AssetFaceEntity[] = [];
|
|
|
+ await this.assetRepository.manager.transaction(async (manager) => {
|
|
|
+ await manager.query(`SET LOCAL vectors.k = '${numResults}'`);
|
|
|
+ const cte = manager
|
|
|
+ .createQueryBuilder(AssetFaceEntity, 'faces')
|
|
|
+ .select('1 + (faces.embedding <=> :embedding)', 'distance')
|
|
|
+ .innerJoin('faces.asset', 'asset')
|
|
|
+ .where('asset.ownerId = :ownerId')
|
|
|
+ .orderBy(`faces.embedding <=> :embedding`)
|
|
|
+ .setParameters({ ownerId, embedding: asVector(embedding) })
|
|
|
+ .limit(numResults);
|
|
|
+
|
|
|
+ this.faceColumns.forEach((col) => cte.addSelect(`faces.${col} AS "${col}"`));
|
|
|
+
|
|
|
+ results = await manager
|
|
|
+ .createQueryBuilder()
|
|
|
+ .select('res.*')
|
|
|
+ .addCommonTableExpression(cte, 'cte')
|
|
|
+ .from('cte', 'res')
|
|
|
+ .where('res.distance <= :maxDistance', { maxDistance })
|
|
|
+ .getRawMany();
|
|
|
+ });
|
|
|
+
|
|
|
+ return this.assetFaceRepository.create(results);
|
|
|
+ }
|
|
|
+
|
|
|
async upsert(smartInfo: Partial<SmartInfoEntity>, embedding?: Embedding): Promise<void> {
|
|
|
await this.repository.upsert(smartInfo, { conflictPaths: ['assetId'] });
|
|
|
if (!smartInfo.assetId || !embedding) return;
|