Bladeren bron

refactor(server): use UserService (#1309)

* refactor: communication gateway

* refactor: share strategy

* refactor: communication module
Jason Rasmussen 2 jaren geleden
bovenliggende
commit
92ca447f33

+ 10 - 48
server/apps/immich/src/api-v1/communication/communication.gateway.ts

@@ -1,70 +1,32 @@
-import { OnGatewayConnection, OnGatewayDisconnect, WebSocketGateway, WebSocketServer } from '@nestjs/websockets';
-import { Socket, Server } from 'socket.io';
-import { ImmichJwtService, JwtValidationResult } from '../../modules/immich-jwt/immich-jwt.service';
 import { Logger } from '@nestjs/common';
-import { InjectRepository } from '@nestjs/typeorm';
-import { UserEntity } from '@app/infra';
-import { Repository } from 'typeorm';
-import cookieParser from 'cookie';
-import { IMMICH_ACCESS_COOKIE } from '../../constants/jwt.constant';
+import { OnGatewayConnection, OnGatewayDisconnect, WebSocketGateway, WebSocketServer } from '@nestjs/websockets';
+import { Server, Socket } from 'socket.io';
+import { ImmichJwtService } from '../../modules/immich-jwt/immich-jwt.service';
 
 @WebSocketGateway({ cors: true })
 export class CommunicationGateway implements OnGatewayConnection, OnGatewayDisconnect {
-  constructor(
-    private immichJwtService: ImmichJwtService,
+  private logger = new Logger(CommunicationGateway.name);
 
-    @InjectRepository(UserEntity)
-    private userRepository: Repository<UserEntity>,
-  ) {}
+  constructor(private immichJwtService: ImmichJwtService) {}
 
   @WebSocketServer() server!: Server;
 
   handleDisconnect(client: Socket) {
     client.leave(client.nsp.name);
-
-    Logger.log(`Client ${client.id} disconnected from Websocket`, 'WebsocketConnectionEvent');
+    this.logger.log(`Client ${client.id} disconnected from Websocket`);
   }
 
   async handleConnection(client: Socket) {
     try {
-      Logger.log(`New websocket connection: ${client.id}`, 'WebsocketConnectionEvent');
-      let accessToken = '';
+      this.logger.log(`New websocket connection: ${client.id}`);
 
-      if (client.handshake.headers.cookie != undefined) {
-        const cookies = cookieParser.parse(client.handshake.headers.cookie);
-        if (cookies[IMMICH_ACCESS_COOKIE]) {
-          accessToken = cookies[IMMICH_ACCESS_COOKIE];
-        } else {
-          client.emit('error', 'unauthorized');
-          client.disconnect();
-          return;
-        }
-      } else if (client.handshake.headers.authorization != undefined) {
-        accessToken = client.handshake.headers.authorization.split(' ')[1];
+      const user = await this.immichJwtService.validateSocket(client);
+      if (user) {
+        client.join(user.id);
       } else {
         client.emit('error', 'unauthorized');
         client.disconnect();
-        return;
       }
-
-      const res: JwtValidationResult = accessToken
-        ? await this.immichJwtService.validateToken(accessToken)
-        : { status: false, userId: null };
-
-      if (!res.status || res.userId == null) {
-        client.emit('error', 'unauthorized');
-        client.disconnect();
-        return;
-      }
-
-      const user = await this.userRepository.findOne({ where: { id: res.userId } });
-      if (!user) {
-        client.emit('error', 'unauthorized');
-        client.disconnect();
-        return;
-      }
-
-      client.join(user.id);
     } catch (e) {
       // Logger.error(`Error establish websocket conneciton ${e}`, 'HandleWebscoketConnection');
     }

+ 2 - 8
server/apps/immich/src/api-v1/communication/communication.module.ts

@@ -1,16 +1,10 @@
 import { Module } from '@nestjs/common';
-import { CommunicationService } from './communication.service';
 import { CommunicationGateway } from './communication.gateway';
 import { ImmichJwtModule } from '../../modules/immich-jwt/immich-jwt.module';
-import { ImmichJwtService } from '../../modules/immich-jwt/immich-jwt.service';
-import { JwtModule } from '@nestjs/jwt';
-import { jwtConfig } from '../../config/jwt.config';
-import { TypeOrmModule } from '@nestjs/typeorm';
-import { UserEntity } from '@app/infra';
 
 @Module({
-  imports: [TypeOrmModule.forFeature([UserEntity]), ImmichJwtModule, JwtModule.register(jwtConfig)],
-  providers: [CommunicationGateway, CommunicationService, ImmichJwtService],
+  imports: [ImmichJwtModule],
+  providers: [CommunicationGateway],
   exports: [CommunicationGateway],
 })
 export class CommunicationModule {}

+ 0 - 4
server/apps/immich/src/api-v1/communication/communication.service.ts

@@ -1,4 +0,0 @@
-import { Injectable } from '@nestjs/common';
-
-@Injectable()
-export class CommunicationService {}

+ 8 - 16
server/apps/immich/src/api-v1/share/share.core.ts

@@ -32,7 +32,7 @@ export class ShareCore {
     }
   }
 
-  async getSharedLinks(userId: string): Promise<SharedLinkEntity[]> {
+  getSharedLinks(userId: string): Promise<SharedLinkEntity[]> {
     return this.sharedLinkRepository.get(userId);
   }
 
@@ -46,27 +46,19 @@ export class ShareCore {
     return await this.sharedLinkRepository.remove(link);
   }
 
-  async getSharedLinkById(id: string): Promise<SharedLinkEntity> {
-    const link = await this.sharedLinkRepository.getById(id);
-    if (!link) {
-      throw new BadRequestException('Shared link not found');
-    }
-
-    return link;
+  getSharedLinkById(id: string): Promise<SharedLinkEntity | null> {
+    return this.sharedLinkRepository.getById(id);
   }
 
-  async getSharedLinkByKey(key: string): Promise<SharedLinkEntity> {
-    const link = await this.sharedLinkRepository.getByKey(key);
-
-    if (!link) {
-      throw new BadRequestException();
-    }
-
-    return link;
+  getSharedLinkByKey(key: string): Promise<SharedLinkEntity | null> {
+    return this.sharedLinkRepository.getByKey(key);
   }
 
   async updateAssetsInSharedLink(sharedLinkId: string, assets: AssetEntity[]) {
     const link = await this.getSharedLinkById(sharedLinkId);
+    if (!link) {
+      throw new BadRequestException('Shared link not found');
+    }
 
     link.assets = assets;
 

+ 38 - 5
server/apps/immich/src/api-v1/share/share.service.ts

@@ -1,4 +1,12 @@
-import { ForbiddenException, Inject, Injectable, Logger } from '@nestjs/common';
+import {
+  BadRequestException,
+  ForbiddenException,
+  Inject,
+  Injectable,
+  Logger,
+  UnauthorizedException,
+} from '@nestjs/common';
+import { UserService } from '@app/domain';
 import { AuthUserDto } from '../../decorators/auth-user.decorator';
 import { EditSharedLinkDto } from './dto/edit-shared-link.dto';
 import { mapSharedLinkToResponseDto, SharedLinkResponseDto } from './response-dto/shared-link-response.dto';
@@ -13,9 +21,31 @@ export class ShareService {
   constructor(
     @Inject(ISharedLinkRepository)
     sharedLinkRepository: ISharedLinkRepository,
+    private userService: UserService,
   ) {
     this.shareCore = new ShareCore(sharedLinkRepository);
   }
+
+  async validate(key: string): Promise<AuthUserDto> {
+    const link = await this.shareCore.getSharedLinkByKey(key);
+    if (link) {
+      if (!link.expiresAt || new Date(link.expiresAt) > new Date()) {
+        const user = await this.userService.getUserById(link.userId).catch(() => null);
+        if (user) {
+          return {
+            id: user.id,
+            email: user.email,
+            isAdmin: user.isAdmin,
+            isPublicUser: true,
+            sharedLinkId: link.id,
+            isAllowUpload: link.allowUpload,
+          };
+        }
+      }
+    }
+    throw new UnauthorizedException();
+  }
+
   async getAll(authUser: AuthUserDto): Promise<SharedLinkResponseDto[]> {
     const links = await this.shareCore.getSharedLinks(authUser.id);
     return links.map(mapSharedLinkToResponseDto);
@@ -26,13 +56,14 @@ export class ShareService {
       throw new ForbiddenException();
     }
 
-    const link = await this.shareCore.getSharedLinkById(authUser.sharedLinkId);
-
-    return mapSharedLinkToResponseDto(link);
+    return this.getById(authUser.sharedLinkId);
   }
 
   async getById(id: string): Promise<SharedLinkResponseDto> {
     const link = await this.shareCore.getSharedLinkById(id);
+    if (!link) {
+      throw new BadRequestException('Shared link not found');
+    }
     return mapSharedLinkToResponseDto(link);
   }
 
@@ -43,12 +74,14 @@ export class ShareService {
 
   async getByKey(key: string): Promise<SharedLinkResponseDto> {
     const link = await this.shareCore.getSharedLinkByKey(key);
+    if (!link) {
+      throw new BadRequestException('Shared link not found');
+    }
     return mapSharedLinkToResponseDto(link);
   }
 
   async edit(id: string, authUser: AuthUserDto, dto: EditSharedLinkDto) {
     const link = await this.shareCore.updateSharedLink(id, authUser.id, dto);
-
     return mapSharedLinkToResponseDto(link);
   }
 }

+ 1 - 3
server/apps/immich/src/modules/immich-jwt/immich-jwt.module.ts

@@ -3,15 +3,13 @@ import { ImmichJwtService } from './immich-jwt.service';
 import { JwtModule } from '@nestjs/jwt';
 import { jwtConfig } from '../../config/jwt.config';
 import { JwtStrategy } from './strategies/jwt.strategy';
-import { TypeOrmModule } from '@nestjs/typeorm';
-import { UserEntity } from '@app/infra';
 import { APIKeyModule } from '../../api-v1/api-key/api-key.module';
 import { APIKeyStrategy } from './strategies/api-key.strategy';
 import { ShareModule } from '../../api-v1/share/share.module';
 import { PublicShareStrategy } from './strategies/public-share.strategy';
 
 @Module({
-  imports: [JwtModule.register(jwtConfig), TypeOrmModule.forFeature([UserEntity]), APIKeyModule, ShareModule],
+  imports: [JwtModule.register(jwtConfig), APIKeyModule, ShareModule],
   providers: [ImmichJwtService, JwtStrategy, APIKeyStrategy, PublicShareStrategy],
   exports: [ImmichJwtService],
 })

+ 12 - 6
server/apps/immich/src/modules/immich-jwt/immich-jwt.service.spec.ts

@@ -5,9 +5,11 @@ import { UserEntity } from '@app/infra';
 import { LoginResponseDto } from '../../api-v1/auth/response-dto/login-response.dto';
 import { AuthType } from '../../constants/jwt.constant';
 import { ImmichJwtService } from './immich-jwt.service';
+import { UserService } from '@app/domain';
 
 describe('ImmichJwtService', () => {
   let jwtServiceMock: jest.Mocked<JwtService>;
+  let userServiceMock: jest.Mocked<UserService>;
   let sut: ImmichJwtService;
 
   beforeEach(() => {
@@ -16,7 +18,11 @@ describe('ImmichJwtService', () => {
       verifyAsync: jest.fn(),
     } as unknown as jest.Mocked<JwtService>;
 
-    sut = new ImmichJwtService(jwtServiceMock);
+    userServiceMock = {
+      getUserById: jest.fn(),
+    } as unknown as jest.Mocked<UserService>;
+
+    sut = new ImmichJwtService(jwtServiceMock, userServiceMock);
   });
 
   afterEach(() => {
@@ -102,7 +108,7 @@ describe('ImmichJwtService', () => {
       const request = {
         headers: {},
       } as Request;
-      const token = sut.extractJwtFromHeader(request);
+      const token = sut.extractJwtFromHeader(request.headers);
       expect(token).toBe(null);
     });
 
@@ -119,15 +125,15 @@ describe('ImmichJwtService', () => {
         },
       } as Request;
 
-      expect(sut.extractJwtFromHeader(upper)).toBe('token');
-      expect(sut.extractJwtFromHeader(lower)).toBe('token');
+      expect(sut.extractJwtFromHeader(upper.headers)).toBe('token');
+      expect(sut.extractJwtFromHeader(lower.headers)).toBe('token');
     });
   });
 
   describe('extracJwtFromCookie', () => {
     it('should handle no cookie', () => {
       const request = {} as Request;
-      const token = sut.extractJwtFromCookie(request);
+      const token = sut.extractJwtFromCookie(request.cookies);
       expect(token).toBe(null);
     });
 
@@ -137,7 +143,7 @@ describe('ImmichJwtService', () => {
           immich_access_token: 'cookie',
         },
       } as Request;
-      const token = sut.extractJwtFromCookie(request);
+      const token = sut.extractJwtFromCookie(request.cookies);
       expect(token).toBe('cookie');
     });
   });

+ 33 - 12
server/apps/immich/src/modules/immich-jwt/immich-jwt.service.ts

@@ -1,10 +1,13 @@
 import { UserEntity } from '@app/infra';
 import { Injectable, Logger } from '@nestjs/common';
 import { JwtService } from '@nestjs/jwt';
-import { Request } from 'express';
+import { IncomingHttpHeaders } from 'http';
 import { JwtPayloadDto } from '../../api-v1/auth/dto/jwt-payload.dto';
 import { LoginResponseDto, mapLoginResponse } from '../../api-v1/auth/response-dto/login-response.dto';
 import { AuthType, IMMICH_ACCESS_COOKIE, IMMICH_AUTH_TYPE_COOKIE, jwtSecret } from '../../constants/jwt.constant';
+import { Socket } from 'socket.io';
+import cookieParser from 'cookie';
+import { UserResponseDto, UserService } from '@app/domain';
 
 export type JwtValidationResult = {
   status: boolean;
@@ -13,7 +16,7 @@ export type JwtValidationResult = {
 
 @Injectable()
 export class ImmichJwtService {
-  constructor(private jwtService: JwtService) {}
+  constructor(private jwtService: JwtService, private userService: UserService) {}
 
   public getCookieNames() {
     return [IMMICH_ACCESS_COOKIE, IMMICH_AUTH_TYPE_COOKIE];
@@ -51,20 +54,38 @@ export class ImmichJwtService {
     }
   }
 
-  public extractJwtFromHeader(req: Request) {
-    if (
-      req.headers.authorization &&
-      (req.headers.authorization.split(' ')[0] === 'Bearer' || req.headers.authorization.split(' ')[0] === 'bearer')
-    ) {
-      const accessToken = req.headers.authorization.split(' ')[1];
-      return accessToken;
+  public extractJwtFromHeader(headers: IncomingHttpHeaders) {
+    if (!headers.authorization) {
+      return null;
+    }
+    const [type, accessToken] = headers.authorization.split(' ');
+    if (type.toLowerCase() !== 'bearer') {
+      return null;
     }
 
-    return null;
+    return accessToken;
+  }
+
+  public extractJwtFromCookie(cookies: Record<string, string>) {
+    return cookies?.[IMMICH_ACCESS_COOKIE] || null;
   }
 
-  public extractJwtFromCookie(req: Request) {
-    return req.cookies?.[IMMICH_ACCESS_COOKIE] || null;
+  public async validateSocket(client: Socket): Promise<UserResponseDto | null> {
+    const headers = client.handshake.headers;
+    const accessToken =
+      this.extractJwtFromCookie(cookieParser.parse(headers.cookie || '')) || this.extractJwtFromHeader(headers);
+
+    if (accessToken) {
+      const { userId, status } = await this.validateToken(accessToken);
+      if (userId && status) {
+        const user = await this.userService.getUserById(userId).catch(() => null);
+        if (user) {
+          return user;
+        }
+      }
+    }
+
+    return null;
   }
 
   private async generateToken(payload: JwtPayloadDto) {

+ 5 - 12
server/apps/immich/src/modules/immich-jwt/strategies/jwt.strategy.ts

@@ -1,9 +1,7 @@
-import { UserEntity } from '@app/infra';
 import { Injectable, UnauthorizedException } from '@nestjs/common';
 import { PassportStrategy } from '@nestjs/passport';
-import { InjectRepository } from '@nestjs/typeorm';
 import { ExtractJwt, Strategy, StrategyOptions } from 'passport-jwt';
-import { Repository } from 'typeorm';
+import { UserService } from '@app/domain';
 import { JwtPayloadDto } from '../../../api-v1/auth/dto/jwt-payload.dto';
 import { jwtSecret } from '../../../constants/jwt.constant';
 import { AuthUserDto } from '../../../decorators/auth-user.decorator';
@@ -13,15 +11,11 @@ export const JWT_STRATEGY = 'jwt';
 
 @Injectable()
 export class JwtStrategy extends PassportStrategy(Strategy, JWT_STRATEGY) {
-  constructor(
-    @InjectRepository(UserEntity)
-    private usersRepository: Repository<UserEntity>,
-    immichJwtService: ImmichJwtService,
-  ) {
+  constructor(private userService: UserService, immichJwtService: ImmichJwtService) {
     super({
       jwtFromRequest: ExtractJwt.fromExtractors([
-        immichJwtService.extractJwtFromCookie,
-        immichJwtService.extractJwtFromHeader,
+        (req) => immichJwtService.extractJwtFromCookie(req.cookies),
+        (req) => immichJwtService.extractJwtFromHeader(req.headers),
       ]),
       ignoreExpiration: false,
       secretOrKey: jwtSecret,
@@ -30,8 +24,7 @@ export class JwtStrategy extends PassportStrategy(Strategy, JWT_STRATEGY) {
 
   async validate(payload: JwtPayloadDto): Promise<AuthUserDto> {
     const { userId } = payload;
-    const user = await this.usersRepository.findOne({ where: { id: userId } });
-
+    const user = await this.userService.getUserById(userId).catch(() => null);
     if (!user) {
       throw new UnauthorizedException('Failure to validate JWT payload');
     }

+ 3 - 33
server/apps/immich/src/modules/immich-jwt/strategies/public-share.strategy.ts

@@ -1,9 +1,6 @@
-import { UserEntity } from '@app/infra';
-import { Injectable, UnauthorizedException } from '@nestjs/common';
+import { Injectable } from '@nestjs/common';
 import { PassportStrategy } from '@nestjs/passport';
-import { InjectRepository } from '@nestjs/typeorm';
 import { IStrategyOptions, Strategy } from 'passport-http-header-strategy';
-import { Repository } from 'typeorm';
 import { ShareService } from '../../../api-v1/share/share.service';
 import { AuthUserDto } from '../../../decorators/auth-user.decorator';
 
@@ -16,38 +13,11 @@ const options: IStrategyOptions = {
 
 @Injectable()
 export class PublicShareStrategy extends PassportStrategy(Strategy, PUBLIC_SHARE_STRATEGY) {
-  constructor(
-    private shareService: ShareService,
-    @InjectRepository(UserEntity)
-    private usersRepository: Repository<UserEntity>,
-  ) {
+  constructor(private shareService: ShareService) {
     super(options);
   }
 
   async validate(key: string): Promise<AuthUserDto> {
-    const validatedLink = await this.shareService.getByKey(key);
-
-    if (validatedLink.expiresAt) {
-      const now = new Date().getTime();
-      const expiresAt = new Date(validatedLink.expiresAt).getTime();
-
-      if (now > expiresAt) {
-        throw new UnauthorizedException('Expired link');
-      }
-    }
-
-    const user = await this.usersRepository.findOne({ where: { id: validatedLink.userId } });
-
-    if (!user) {
-      throw new UnauthorizedException('Failure to validate public share payload');
-    }
-
-    let publicUser = new AuthUserDto();
-    publicUser = user;
-    publicUser.isPublicUser = true;
-    publicUser.sharedLinkId = validatedLink.id;
-    publicUser.isAllowUpload = validatedLink.allowUpload;
-
-    return publicUser;
+    return this.shareService.validate(key);
   }
 }