Pārlūkot izejas kodu

move prefilter logic to function

mertalev 1 gadu atpakaļ
vecāks
revīzija
13d7222a72

+ 2 - 5
server/src/immich/main.ts

@@ -1,5 +1,5 @@
 import { envName, getLogLevels, isDev, serverVersion } from '@app/domain';
-import { RedisIoAdapter, dataSource } from '@app/infra';
+import { RedisIoAdapter, enablePrefilter } from '@app/infra';
 import { Logger } from '@nestjs/common';
 import { NestFactory } from '@nestjs/core';
 import { NestExpressApplication } from '@nestjs/platform-express';
@@ -29,10 +29,7 @@ export async function bootstrap() {
   app.useStaticAssets('www');
   app.use(indexFallback(excludePaths));
 
-  if (!dataSource.isInitialized) {
-    await dataSource.initialize();
-  }
-  await dataSource.query(`SET vectors.enable_prefilter = on`);
+  await enablePrefilter();
 
   const server = await app.listen(port);
   server.requestTimeout = 30 * 60 * 1000;

+ 7 - 0
server/src/infra/database.config.ts

@@ -25,3 +25,10 @@ export const databaseConfig: PostgresConnectionOptions = {
 
 // this export is used by TypeORM commands in package.json#scripts
 export const dataSource = new DataSource(databaseConfig);
+
+export async function enablePrefilter() {
+  if (!dataSource.isInitialized) {
+    await dataSource.initialize();
+  }
+  await dataSource.query(`SET vectors.enable_prefilter = on`);
+}

+ 15 - 10
server/src/infra/repositories/smart-info.repository.ts

@@ -1,16 +1,12 @@
-import {
-  Embedding,
-  EmbeddingSearch,
-  ISmartInfoRepository,
-} from '@app/domain';
+import { Embedding, EmbeddingSearch, ISmartInfoRepository } from '@app/domain';
 import { getCLIPModelInfo } from '@app/domain/smart-info/smart-info.constant';
 import { DatabaseLock, RequireLock, asyncLock } from '@app/infra';
 import { AssetEntity, AssetFaceEntity, SmartInfoEntity, SmartSearchEntity } from '@app/infra/entities';
 import { Injectable, Logger } from '@nestjs/common';
 import { InjectRepository } from '@nestjs/typeorm';
 import { Repository } from 'typeorm';
-import { asVector, isValidInteger } from '../infra.utils';
 import { DummyValue, GenerateSql } from '../infra.util';
+import { asVector, isValidInteger } from '../infra.utils';
 
 @Injectable()
 export class SmartInfoRepository implements ISmartInfoRepository {
@@ -38,7 +34,9 @@ export class SmartInfoRepository implements ISmartInfoRepository {
     }
   }
 
-  @GenerateSql({ params: [{ ownerId: DummyValue.UUID, embedding: Array.from({ length: 512 }, Math.random), numResults: 100 }] })
+  @GenerateSql({
+    params: [{ ownerId: DummyValue.UUID, embedding: Array.from({ length: 512 }, Math.random), numResults: 100 }],
+  })
   async searchCLIP({ ownerId, embedding, numResults }: EmbeddingSearch): Promise<AssetEntity[]> {
     if (!isValidInteger(numResults, { min: 1 })) {
       throw new Error(`Invalid value for 'numResults': ${numResults}`);
@@ -61,7 +59,16 @@ export class SmartInfoRepository implements ISmartInfoRepository {
     return results;
   }
 
-  @GenerateSql({ params: [{ ownerId: DummyValue.UUID, embedding: Array.from({ length: 512 }, Math.random), numResults: 100, maxDistance: 0.6 }] })
+  @GenerateSql({
+    params: [
+      {
+        ownerId: DummyValue.UUID,
+        embedding: Array.from({ length: 512 }, Math.random),
+        numResults: 100,
+        maxDistance: 0.6,
+      },
+    ],
+  })
   async searchFaces({ ownerId, embedding, numResults, maxDistance }: EmbeddingSearch): Promise<AssetFaceEntity[]> {
     if (!isValidInteger(numResults, { min: 1 })) {
       throw new Error(`Invalid value for 'numResults': ${numResults}`);
@@ -79,8 +86,6 @@ export class SmartInfoRepository implements ISmartInfoRepository {
         .setParameters({ ownerId, embedding: asVector(embedding) })
         .limit(numResults);
 
-      // this.faceColumns.forEach((col) => cte.addSelect(`faces.${col} AS "${col}"`));
-
       results = await manager
         .createQueryBuilder()
         .select('res.*')

+ 1 - 1
server/src/infra/sql-generator/index.ts

@@ -1,3 +1,4 @@
+import { ISystemConfigRepository } from '@app/domain';
 import { INestApplication } from '@nestjs/common';
 import { Reflector } from '@nestjs/core';
 import { Test } from '@nestjs/testing';
@@ -26,7 +27,6 @@ import {
   UserTokenRepository,
 } from '../repositories';
 import { SqlLogger } from './sql.logger';
-import { ISystemConfigRepository } from '@app/domain';
 
 const reflector = new Reflector();
 const repositories = [

+ 2 - 1
server/src/microservices/main.ts

@@ -1,5 +1,5 @@
 import { envName, getLogLevels, serverVersion } from '@app/domain';
-import { RedisIoAdapter } from '@app/infra';
+import { RedisIoAdapter, enablePrefilter } from '@app/infra';
 import { Logger } from '@nestjs/common';
 import { NestFactory } from '@nestjs/core';
 import { AppService } from './app.service';
@@ -12,6 +12,7 @@ export async function bootstrap() {
   const app = await NestFactory.create(MicroservicesModule, { logger: getLogLevels() });
 
   app.useWebSocketAdapter(new RedisIoAdapter(app));
+  await enablePrefilter();
 
   await app.get(AppService).init();
   await app.listen(port);