浏览代码

fix: pass session uuid to websockets token

Karol Sójko 1 年之前
父节点
当前提交
bcd1d830e6

+ 1 - 0
packages/auth/src/Bootstrap/Container.ts

@@ -1195,6 +1195,7 @@ export class ContainerConfigLoader {
           container.get<GetRegularSubscriptionForUser>(TYPES.Auth_GetRegularSubscriptionForUser),
           container.get<GetSubscriptionSetting>(TYPES.Auth_GetSubscriptionSetting),
           container.get<SharedVaultUserRepositoryInterface>(TYPES.Auth_SharedVaultUserRepository),
+          container.get<GetActiveSessionsForUser>(TYPES.Auth_GetActiveSessionsForUser),
         ),
       )
     container.bind<ProcessUserRequest>(TYPES.Auth_ProcessUserRequest).to(ProcessUserRequest)

+ 69 - 0
packages/auth/src/Domain/UseCase/CreateCrossServiceToken/CreateCrossServiceToken.spec.ts

@@ -22,6 +22,7 @@ import { GetRegularSubscriptionForUser } from '../GetRegularSubscriptionForUser/
 import { UserSubscription } from '../../Subscription/UserSubscription'
 import { SubscriptionSetting } from '../../Setting/SubscriptionSetting'
 import { EncryptionVersion } from '../../Encryption/EncryptionVersion'
+import { GetActiveSessionsForUser } from '../GetActiveSessionsForUser'
 
 describe('CreateCrossServiceToken', () => {
   let userProjector: ProjectorInterface<User>
@@ -32,6 +33,7 @@ describe('CreateCrossServiceToken', () => {
   let getRegularSubscription: GetRegularSubscriptionForUser
   let getSubscriptionSetting: GetSubscriptionSetting
   let sharedVaultUserRepository: SharedVaultUserRepositoryInterface
+  let getActiveSessionsForUser: GetActiveSessionsForUser
   const jwtTTL = 60
 
   let session: Session
@@ -49,11 +51,15 @@ describe('CreateCrossServiceToken', () => {
       getRegularSubscription,
       getSubscriptionSetting,
       sharedVaultUserRepository,
+      getActiveSessionsForUser,
     )
 
   beforeEach(() => {
     session = {} as jest.Mocked<Session>
 
+    getActiveSessionsForUser = {} as jest.Mocked<GetActiveSessionsForUser>
+    getActiveSessionsForUser.execute = jest.fn().mockReturnValue({ sessions: [session] })
+
     user = {
       uuid: '00000000-0000-0000-0000-000000000000',
       email: 'test@test.te',
@@ -195,6 +201,69 @@ describe('CreateCrossServiceToken', () => {
     )
   })
 
+  it('should create a cross service token for a user and a specific session', async () => {
+    await createUseCase().execute({
+      userUuid: '00000000-0000-0000-0000-000000000000',
+      sessionUuid: '00000000-0000-0000-0000-000000000000',
+    })
+
+    expect(tokenEncoder.encodeExpirableToken).toHaveBeenCalledWith(
+      {
+        roles: [
+          {
+            name: 'role1',
+            uuid: '1-3-4',
+          },
+        ],
+        belongs_to_shared_vaults: [
+          {
+            shared_vault_uuid: '00000000-0000-0000-0000-000000000000',
+            permission: 'read',
+          },
+        ],
+        session: {
+          test: 'test',
+        },
+        user: {
+          email: 'test@test.te',
+          uuid: '00000000-0000-0000-0000-000000000000',
+        },
+      },
+      60,
+    )
+  })
+
+  it('should create a cross service token for a user and specific session if the session is missing', async () => {
+    getActiveSessionsForUser.execute = jest.fn().mockReturnValue({ sessions: [] })
+
+    await createUseCase().execute({
+      userUuid: '00000000-0000-0000-0000-000000000000',
+      sessionUuid: '00000000-0000-0000-0000-000000000000',
+    })
+
+    expect(tokenEncoder.encodeExpirableToken).toHaveBeenCalledWith(
+      {
+        roles: [
+          {
+            name: 'role1',
+            uuid: '1-3-4',
+          },
+        ],
+        belongs_to_shared_vaults: [
+          {
+            shared_vault_uuid: '00000000-0000-0000-0000-000000000000',
+            permission: 'read',
+          },
+        ],
+        user: {
+          email: 'test@test.te',
+          uuid: '00000000-0000-0000-0000-000000000000',
+        },
+      },
+      60,
+    )
+  })
+
   it('should throw an error if user does not exist', async () => {
     userRepository.findOneByUuid = jest.fn().mockReturnValue(null)
 

+ 10 - 0
packages/auth/src/Domain/UseCase/CreateCrossServiceToken/CreateCrossServiceToken.ts

@@ -11,6 +11,7 @@ import { CreateCrossServiceTokenDTO } from './CreateCrossServiceTokenDTO'
 import { SharedVaultUserRepositoryInterface } from '../../SharedVault/SharedVaultUserRepositoryInterface'
 import { GetSubscriptionSetting } from '../GetSubscriptionSetting/GetSubscriptionSetting'
 import { GetRegularSubscriptionForUser } from '../GetRegularSubscriptionForUser/GetRegularSubscriptionForUser'
+import { GetActiveSessionsForUser } from '../GetActiveSessionsForUser'
 
 export class CreateCrossServiceToken implements UseCaseInterface<string> {
   constructor(
@@ -23,6 +24,7 @@ export class CreateCrossServiceToken implements UseCaseInterface<string> {
     private getRegularSubscription: GetRegularSubscriptionForUser,
     private getSubscriptionSettingUseCase: GetSubscriptionSetting,
     private sharedVaultUserRepository: SharedVaultUserRepositoryInterface,
+    private getActiveSessions: GetActiveSessionsForUser,
   ) {}
 
   async execute(dto: CreateCrossServiceTokenDTO): Promise<Result<string>> {
@@ -84,6 +86,14 @@ export class CreateCrossServiceToken implements UseCaseInterface<string> {
 
     if (dto.session !== undefined) {
       authTokenData.session = this.projectSession(dto.session)
+    } else if (dto.sessionUuid !== undefined) {
+      const activeSessionsResponse = await this.getActiveSessions.execute({
+        userUuid: user.uuid,
+        sessionUuid: dto.sessionUuid,
+      })
+      if (activeSessionsResponse.sessions.length) {
+        authTokenData.session = this.projectSession(activeSessionsResponse.sessions[0])
+      }
     }
 
     return Result.ok(this.tokenEncoder.encodeExpirableToken(authTokenData, this.jwtTTL))

+ 1 - 0
packages/auth/src/Domain/UseCase/CreateCrossServiceToken/CreateCrossServiceTokenDTO.ts

@@ -10,5 +10,6 @@ export type CreateCrossServiceTokenDTO = Either<
   },
   {
     userUuid: string
+    sessionUuid?: string
   }
 >

+ 6 - 0
packages/auth/src/Domain/UseCase/GetActiveSessionsForUser.spec.ts

@@ -65,4 +65,10 @@ describe('GetActiveSessionsForUser', () => {
 
     expect(sessionRepository.findAllByRefreshExpirationAndUserUuid).toHaveBeenCalledWith('1-2-3')
   })
+
+  it('should get a single session for a user', async () => {
+    expect(await createUseCase().execute({ userUuid: '1-2-3', sessionUuid: '2-3-4' })).toEqual({
+      sessions: [session2],
+    })
+  })
 })

+ 20 - 6
packages/auth/src/Domain/UseCase/GetActiveSessionsForUser.ts

@@ -5,6 +5,7 @@ import { SessionRepositoryInterface } from '../Session/SessionRepositoryInterfac
 import { GetActiveSessionsForUserDTO } from './GetActiveSessionsForUserDTO'
 import { GetActiveSessionsForUserResponse } from './GetActiveSessionsForUserResponse'
 import { UseCaseInterface } from './UseCaseInterface'
+import { Session } from '../Session/Session'
 
 @injectable()
 export class GetActiveSessionsForUser implements UseCaseInterface {
@@ -18,13 +19,26 @@ export class GetActiveSessionsForUser implements UseCaseInterface {
     const ephemeralSessions = await this.ephemeralSessionRepository.findAllByUserUuid(dto.userUuid)
     const sessions = await this.sessionRepository.findAllByRefreshExpirationAndUserUuid(dto.userUuid)
 
-    return {
-      sessions: sessions.concat(ephemeralSessions).sort((a, b) => {
-        const dateA = a.refreshExpiration instanceof Date ? a.refreshExpiration : new Date(a.refreshExpiration)
-        const dateB = b.refreshExpiration instanceof Date ? b.refreshExpiration : new Date(b.refreshExpiration)
+    const activeSessions = sessions.concat(ephemeralSessions).sort((a, b) => {
+      const dateA = a.refreshExpiration instanceof Date ? a.refreshExpiration : new Date(a.refreshExpiration)
+      const dateB = b.refreshExpiration instanceof Date ? b.refreshExpiration : new Date(b.refreshExpiration)
+
+      return dateB.getTime() - dateA.getTime()
+    })
 
-        return dateB.getTime() - dateA.getTime()
-      }),
+    if (dto.sessionUuid) {
+      let sessions: Session[] = []
+      const session = activeSessions.find((session) => session.uuid === dto.sessionUuid)
+      if (session) {
+        sessions = [session]
+      }
+      return {
+        sessions,
+      }
+    }
+
+    return {
+      sessions: activeSessions,
     }
   }
 }

+ 1 - 0
packages/auth/src/Domain/UseCase/GetActiveSessionsForUserDTO.ts

@@ -1,3 +1,4 @@
 export type GetActiveSessionsForUserDTO = {
   userUuid: string
+  sessionUuid?: string
 }

+ 1 - 0
packages/security/src/Domain/Token/WebSocketConnectionToken.ts

@@ -1,3 +1,4 @@
 export type WebSocketConnectionTokenData = {
   userUuid: string
+  sessionUuid: string
 }

+ 2 - 2
packages/websockets/src/Domain/UseCase/CreateWebSocketConnectionToken/CreateWebSocketConnection.spec.ts

@@ -16,10 +16,10 @@ describe('CreateWebSocketConnection', () => {
   })
 
   it('should create a web socket connection token', async () => {
-    const result = await createUseCase().execute({ userUuid: '1-2-3' })
+    const result = await createUseCase().execute({ userUuid: '1-2-3', sessionUuid: '4-5-6' })
 
     expect(result.token).toEqual('foobar')
 
-    expect(tokenEncoder.encodeExpirableToken).toHaveBeenCalledWith({ userUuid: '1-2-3' }, 30)
+    expect(tokenEncoder.encodeExpirableToken).toHaveBeenCalledWith({ userUuid: '1-2-3', sessionUuid: '4-5-6' }, 30)
   })
 })

+ 1 - 0
packages/websockets/src/Domain/UseCase/CreateWebSocketConnectionToken/CreateWebSocketConnectionDTO.ts

@@ -1,3 +1,4 @@
 export type CreateWebSocketConnectionDTO = {
   userUuid: string
+  sessionUuid: string
 }

+ 1 - 0
packages/websockets/src/Domain/UseCase/CreateWebSocketConnectionToken/CreateWebSocketConnectionToken.ts

@@ -17,6 +17,7 @@ export class CreateWebSocketConnectionToken implements UseCaseInterface {
   async execute(dto: CreateWebSocketConnectionDTO): Promise<CreateWebSocketConnectionResponse> {
     const data: WebSocketConnectionTokenData = {
       userUuid: dto.userUuid,
+      sessionUuid: dto.sessionUuid,
     }
 
     return {

+ 1 - 0
packages/websockets/src/Infra/InversifyExpressUtils/AnnotatedWebSocketsController.ts

@@ -28,6 +28,7 @@ export class AnnotatedWebSocketsController extends BaseHttpController {
   async createConnectionToken(_request: Request, response: Response): Promise<results.JsonResult> {
     const result = await this.createWebSocketConnectionToken.execute({
       userUuid: response.locals.user.uuid,
+      sessionUuid: response.locals.session.uuid,
     })
 
     return this.json(result)