diff --git a/server/src/infra/infra.utils.ts b/server/src/infra/infra.utils.ts index 37a682559..d1ae51e78 100644 --- a/server/src/infra/infra.utils.ts +++ b/server/src/infra/infra.utils.ts @@ -33,5 +33,10 @@ export async function paginate( return { items, hasNextPage }; } -export const asVector = (embedding: number[], escape = false) => - escape ? `'[${embedding.join(',')}]'` : `[${embedding.join(',')}]`; +export const asVector = (embedding: number[], quote = false) => + quote ? `'[${embedding.join(',')}]'` : `[${embedding.join(',')}]`; + +export const isValidInteger = (value: number, options: {min?: number, max?: number}): boolean => { + const { min = Number.MIN_SAFE_INTEGER, max = Number.MAX_SAFE_INTEGER } = options; + return Number.isInteger(value) && value >= min && value <= max; +} diff --git a/server/src/infra/repositories/person.repository.ts b/server/src/infra/repositories/person.repository.ts index 50f689ae1..4af805ab2 100644 --- a/server/src/infra/repositories/person.repository.ts +++ b/server/src/infra/repositories/person.repository.ts @@ -9,10 +9,10 @@ import { } from '@app/domain'; import { InjectRepository } from '@nestjs/typeorm'; import { In, Repository } from 'typeorm'; +import { dataSource } from '..'; import { AssetEntity, AssetFaceEntity, PersonEntity } from '../entities'; import { DummyValue, GenerateSql } from '../infra.util'; -import { asVector } from '../infra.utils'; -import { dataSource } from '..'; +import { asVector, isValidInteger } from '../infra.utils'; export class PersonRepository implements IPersonRepository { private readonly faceColumns: string[]; @@ -22,8 +22,8 @@ export class PersonRepository implements IPersonRepository { @InjectRepository(AssetFaceEntity) private assetFaceRepository: Repository, ) { this.faceColumns = this.assetFaceRepository.manager.connection - .getMetadata(AssetFaceEntity).ownColumns - .map((column) => column.propertyName) + .getMetadata(AssetFaceEntity) + .ownColumns.map((column) => column.propertyName) .filter((propertyName) => propertyName !== 'embedding'); } @@ -248,24 +248,39 @@ export class PersonRepository implements IPersonRepository { return this.assetFaceRepository.findOneBy({ personId }); } - async searchByEmbedding({ ownerId, embedding, numResults, maxDistance }: EmbeddingSearch): Promise { - const cte = this.assetFaceRepository.createQueryBuilder('faces') - .select('1 + (faces.embedding <=> :embedding)', 'distance') - .innerJoin('faces.asset', 'asset') - .where('asset.ownerId = :ownerId') - .orderBy(`faces.embedding <=> :embedding`) - .setParameters({ownerId, embedding: asVector(embedding)}) - .limit(numResults); + async searchByEmbedding({ + ownerId, + embedding, + numResults, + maxDistance, + }: EmbeddingSearch): Promise { + if (!isValidInteger(numResults, { min: 1 })) { + throw new Error(`Invalid value for 'numResults': ${numResults}`); + } - this.faceColumns.forEach((col) => cte.addSelect(`faces.${col} AS "${col}"`)); + let results: AssetFaceEntity[] = []; + this.assetRepository.manager.transaction(async (manager) => { + await manager.query(`SET LOCAL vectors.k = '${numResults}'`); + const cte = manager + .createQueryBuilder(AssetFaceEntity, 'faces') + .select('1 + (faces.embedding <=> :embedding)', 'distance') + .innerJoin('faces.asset', 'asset') + .where('asset.ownerId = :ownerId') + .orderBy(`faces.embedding <=> :embedding`) + .setParameters({ ownerId, embedding: asVector(embedding) }) + .limit(numResults); - const res = await dataSource.createQueryBuilder() - .select('res.*') - .addCommonTableExpression(cte, 'cte') - .from('cte', 'res') - .where('res.distance <= :maxDistance', { maxDistance }) - .getRawMany(); + this.faceColumns.forEach((col) => cte.addSelect(`faces.${col} AS "${col}"`)); - return this.assetFaceRepository.create(res); + results = await dataSource + .createQueryBuilder() + .select('res.*') + .addCommonTableExpression(cte, 'cte') + .from('cte', 'res') + .where('res.distance <= :maxDistance', { maxDistance }) + .getRawMany(); + }); + + return this.assetFaceRepository.create(results); } } diff --git a/server/src/infra/repositories/smart-info.repository.ts b/server/src/infra/repositories/smart-info.repository.ts index 7c91187cb..f19bd7f35 100644 --- a/server/src/infra/repositories/smart-info.repository.ts +++ b/server/src/infra/repositories/smart-info.repository.ts @@ -4,7 +4,7 @@ import { InjectRepository } from '@nestjs/typeorm'; import AsyncLock from 'async-lock'; import { Repository } from 'typeorm'; import { AssetEntity, SmartInfoEntity, SmartSearchEntity } from '../entities'; -import { asVector } from '../infra.utils'; +import { asVector, isValidInteger } from '../infra.utils'; @Injectable() export class SmartInfoRepository implements ISmartInfoRepository { @@ -21,23 +21,25 @@ export class SmartInfoRepository implements ISmartInfoRepository { } async searchByEmbedding({ ownerId, embedding, numResults }: EmbeddingSearch): Promise { - const query: string = this.assetRepository - .createQueryBuilder('a') - .innerJoin('a.smartSearch', 's') - .where('a.ownerId = :ownerId') - .leftJoinAndSelect('a.exifInfo', 'e') - .orderBy('s.embedding <=> :embedding') - .setParameters({ embedding: asVector(embedding), ownerId }) - .limit(numResults) - .getSql(); + if (!isValidInteger(numResults, { min: 1 })) { + throw new Error(`Invalid value for 'numResults': ${numResults}`); + } - const queryWithK = ` - BEGIN; - SET LOCAL vectors.k = ${numResults}; - ${query}; - COMMIT; - `; - return this.assetRepository.create(await this.assetRepository.manager.query(queryWithK)); + let results: AssetEntity[] = []; + this.assetRepository.manager.transaction(async (manager) => { + await manager.query(`SET LOCAL vectors.k = '${numResults}'`); + results = await manager + .createQueryBuilder(AssetEntity, 'a') + .innerJoin('a.smartSearch', 's') + .where('a.ownerId = :ownerId') + .leftJoinAndSelect('a.exifInfo', 'e') + .orderBy('s.embedding <=> :embedding') + .setParameters({ ownerId, embedding: asVector(embedding) }) + .limit(numResults) + .getMany(); + }); + + return results; } async upsert(smartInfo: Partial, embedding?: Embedding): Promise { @@ -70,6 +72,10 @@ export class SmartInfoRepository implements ISmartInfoRepository { * this does not parameterize the query because it is not possible to parameterize the column type */ private async updateDimSize(dimSize: number): Promise { + if (!isValidInteger(dimSize, { min: 1, max: 2 ** 16 })) { + throw new Error(`Invalid CLIP dimension size: ${dimSize}`); + } + await this.lock.acquire('updateDimSizeLock', async () => { if (this.curDimSize === dimSize) return;