瀏覽代碼

fixed parameterization

mertalev 1 年之前
父節點
當前提交
6894c89644

+ 7 - 2
server/src/infra/infra.utils.ts

@@ -33,5 +33,10 @@ export async function paginate<Entity extends ObjectLiteral>(
   return { items, hasNextPage };
 }
 
-export const asVector = (embedding: number[], escape = false) =>
-  escape ? `'[${embedding.join(',')}]'` : `[${embedding.join(',')}]`;
+export const asVector = (embedding: number[], quote = false) =>
+  quote ? `'[${embedding.join(',')}]'` : `[${embedding.join(',')}]`;
+
+export const isValidInteger = (value: number, options: {min?: number, max?: number}): boolean => {
+  const { min = Number.MIN_SAFE_INTEGER, max = Number.MAX_SAFE_INTEGER } = options;
+  return Number.isInteger(value) && value >= min && value <= max;
+}

+ 37 - 22
server/src/infra/repositories/person.repository.ts

@@ -9,10 +9,10 @@ import {
 } from '@app/domain';
 import { InjectRepository } from '@nestjs/typeorm';
 import { In, Repository } from 'typeorm';
+import { dataSource } from '..';
 import { AssetEntity, AssetFaceEntity, PersonEntity } from '../entities';
 import { DummyValue, GenerateSql } from '../infra.util';
-import { asVector } from '../infra.utils';
-import { dataSource } from '..';
+import { asVector, isValidInteger } from '../infra.utils';
 
 export class PersonRepository implements IPersonRepository {
   private readonly faceColumns: string[];
@@ -22,8 +22,8 @@ export class PersonRepository implements IPersonRepository {
     @InjectRepository(AssetFaceEntity) private assetFaceRepository: Repository<AssetFaceEntity>,
   ) {
     this.faceColumns = this.assetFaceRepository.manager.connection
-      .getMetadata(AssetFaceEntity).ownColumns
-      .map((column) => column.propertyName)
+      .getMetadata(AssetFaceEntity)
+      .ownColumns.map((column) => column.propertyName)
       .filter((propertyName) => propertyName !== 'embedding');
   }
 
@@ -248,24 +248,39 @@ export class PersonRepository implements IPersonRepository {
     return this.assetFaceRepository.findOneBy({ personId });
   }
 
-  async searchByEmbedding({ ownerId, embedding, numResults, maxDistance }: EmbeddingSearch): Promise<AssetFaceEntity[]> {
-    const cte = this.assetFaceRepository.createQueryBuilder('faces')
-      .select('1 + (faces.embedding <=> :embedding)', 'distance')
-      .innerJoin('faces.asset', 'asset')
-      .where('asset.ownerId = :ownerId')
-      .orderBy(`faces.embedding <=> :embedding`)
-      .setParameters({ownerId, embedding: asVector(embedding)})
-      .limit(numResults);
-
-    this.faceColumns.forEach((col) => cte.addSelect(`faces.${col} AS "${col}"`));
-
-    const res = await dataSource.createQueryBuilder()
-      .select('res.*')
-      .addCommonTableExpression(cte, 'cte')
-      .from('cte', 'res')
-      .where('res.distance <= :maxDistance', { maxDistance })
-      .getRawMany();
+  async searchByEmbedding({
+    ownerId,
+    embedding,
+    numResults,
+    maxDistance,
+  }: EmbeddingSearch): Promise<AssetFaceEntity[]> {
+    if (!isValidInteger(numResults, { min: 1 })) {
+      throw new Error(`Invalid value for 'numResults': ${numResults}`);
+    }
+
+    let results: AssetFaceEntity[] = [];
+    this.assetRepository.manager.transaction(async (manager) => {
+      await manager.query(`SET LOCAL vectors.k = '${numResults}'`);
+      const cte = manager
+        .createQueryBuilder(AssetFaceEntity, 'faces')
+        .select('1 + (faces.embedding <=> :embedding)', 'distance')
+        .innerJoin('faces.asset', 'asset')
+        .where('asset.ownerId = :ownerId')
+        .orderBy(`faces.embedding <=> :embedding`)
+        .setParameters({ ownerId, embedding: asVector(embedding) })
+        .limit(numResults);
+
+      this.faceColumns.forEach((col) => cte.addSelect(`faces.${col} AS "${col}"`));
+
+      results = await dataSource
+        .createQueryBuilder()
+        .select('res.*')
+        .addCommonTableExpression(cte, 'cte')
+        .from('cte', 'res')
+        .where('res.distance <= :maxDistance', { maxDistance })
+        .getRawMany();
+    });
 
-    return this.assetFaceRepository.create(res);
+    return this.assetFaceRepository.create(results);
   }
 }

+ 24 - 18
server/src/infra/repositories/smart-info.repository.ts

@@ -4,7 +4,7 @@ import { InjectRepository } from '@nestjs/typeorm';
 import AsyncLock from 'async-lock';
 import { Repository } from 'typeorm';
 import { AssetEntity, SmartInfoEntity, SmartSearchEntity } from '../entities';
-import { asVector } from '../infra.utils';
+import { asVector, isValidInteger } from '../infra.utils';
 
 @Injectable()
 export class SmartInfoRepository implements ISmartInfoRepository {
@@ -21,23 +21,25 @@ export class SmartInfoRepository implements ISmartInfoRepository {
   }
 
   async searchByEmbedding({ ownerId, embedding, numResults }: EmbeddingSearch): Promise<AssetEntity[]> {
-    const query: string = this.assetRepository
-      .createQueryBuilder('a')
-      .innerJoin('a.smartSearch', 's')
-      .where('a.ownerId = :ownerId')
-      .leftJoinAndSelect('a.exifInfo', 'e')
-      .orderBy('s.embedding <=> :embedding')
-      .setParameters({ embedding: asVector(embedding), ownerId })
-      .limit(numResults)
-      .getSql();
-
-    const queryWithK = `
-      BEGIN;
-      SET LOCAL vectors.k = ${numResults};
-      ${query};
-      COMMIT;
-    `;
-    return this.assetRepository.create(await this.assetRepository.manager.query(queryWithK));
+    if (!isValidInteger(numResults, { min: 1 })) {
+      throw new Error(`Invalid value for 'numResults': ${numResults}`);
+    }
+
+    let results: AssetEntity[] = [];
+    this.assetRepository.manager.transaction(async (manager) => {
+      await manager.query(`SET LOCAL vectors.k = '${numResults}'`);
+      results = await manager
+        .createQueryBuilder(AssetEntity, 'a')
+        .innerJoin('a.smartSearch', 's')
+        .where('a.ownerId = :ownerId')
+        .leftJoinAndSelect('a.exifInfo', 'e')
+        .orderBy('s.embedding <=> :embedding')
+        .setParameters({ ownerId, embedding: asVector(embedding) })
+        .limit(numResults)
+        .getMany();
+    });
+
+    return results;
   }
 
   async upsert(smartInfo: Partial<SmartInfoEntity>, embedding?: Embedding): Promise<void> {
@@ -70,6 +72,10 @@ export class SmartInfoRepository implements ISmartInfoRepository {
    * this does not parameterize the query because it is not possible to parameterize the column type
    */
   private async updateDimSize(dimSize: number): Promise<void> {
+    if (!isValidInteger(dimSize, { min: 1, max: 2 ** 16 })) {
+      throw new Error(`Invalid CLIP dimension size: ${dimSize}`);
+    }
+
     await this.lock.acquire('updateDimSizeLock', async () => {
       if (this.curDimSize === dimSize) return;