Prechádzať zdrojové kódy

fix(auth): move tracing sessions to session creation instead of cross service token creation

Karol Sójko 2 rokov pred
rodič
commit
5255cfbb25

+ 139 - 0
packages/auth/src/Domain/Session/SessionService.spec.ts

@@ -15,6 +15,10 @@ import { SettingServiceInterface } from '../Setting/SettingServiceInterface'
 import { LogSessionUserAgentOption } from '@standardnotes/settings'
 import { Setting } from '../Setting/Setting'
 import { CryptoNode } from '@standardnotes/sncrypto-node'
+import { UserSubscriptionRepositoryInterface } from '../Subscription/UserSubscriptionRepositoryInterface'
+import { TraceSession } from '../UseCase/TraceSession/TraceSession'
+import { UserSubscription } from '../Subscription/UserSubscription'
+import { Result } from '@standardnotes/domain-core'
 
 describe('SessionService', () => {
   let sessionRepository: SessionRepositoryInterface
@@ -28,6 +32,8 @@ describe('SessionService', () => {
   let timer: TimerInterface
   let logger: winston.Logger
   let cryptoNode: CryptoNode
+  let traceSession: TraceSession
+  let userSubscriptionRepository: UserSubscriptionRepositoryInterface
 
   const createService = () =>
     new SessionService(
@@ -41,6 +47,8 @@ describe('SessionService', () => {
       234,
       settingService,
       cryptoNode,
+      traceSession,
+      userSubscriptionRepository,
     )
 
   beforeEach(() => {
@@ -106,6 +114,14 @@ describe('SessionService', () => {
     cryptoNode = {} as jest.Mocked<CryptoNode>
     cryptoNode.generateRandomKey = jest.fn().mockReturnValue('foo bar')
     cryptoNode.base64URLEncode = jest.fn().mockReturnValue('foobar')
+
+    traceSession = {} as jest.Mocked<TraceSession>
+    traceSession.execute = jest.fn()
+
+    userSubscriptionRepository = {} as jest.Mocked<UserSubscriptionRepositoryInterface>
+    userSubscriptionRepository.findOneByUserUuid = jest.fn().mockReturnValue({
+      planName: 'PRO_PLAN',
+    } as jest.Mocked<UserSubscription>)
   })
 
   it('should mark a revoked session as received', async () => {
@@ -204,6 +220,129 @@ describe('SessionService', () => {
     })
   })
 
+  it('should trace a session', async () => {
+    const user = {} as jest.Mocked<User>
+    user.uuid = '123'
+    user.email = 'test@test.te'
+
+    await createService().createNewSessionForUser({
+      user,
+      apiVersion: '003',
+      userAgent: 'Google Chrome',
+      readonlyAccess: false,
+    })
+
+    expect(traceSession.execute).toHaveBeenCalledWith({
+      userUuid: '123',
+      username: 'test@test.te',
+      subscriptionPlanName: 'PRO_PLAN',
+    })
+  })
+
+  it('should trace a session without a subscription', async () => {
+    userSubscriptionRepository.findOneByUserUuid = jest.fn().mockReturnValue(null)
+    const user = {} as jest.Mocked<User>
+    user.uuid = '123'
+    user.email = 'test@test.te'
+
+    await createService().createNewSessionForUser({
+      user,
+      apiVersion: '003',
+      userAgent: 'Google Chrome',
+      readonlyAccess: false,
+    })
+
+    expect(traceSession.execute).toHaveBeenCalledWith({
+      userUuid: '123',
+      username: 'test@test.te',
+      subscriptionPlanName: null,
+    })
+  })
+
+  it('should create a session if tracing session throws an error', async () => {
+    traceSession.execute = jest.fn().mockRejectedValue(new Error('foo bar'))
+    userSubscriptionRepository.findOneByUserUuid = jest.fn().mockReturnValue(null)
+    const user = {} as jest.Mocked<User>
+    user.uuid = '123'
+    user.email = 'test@test.te'
+
+    const sessionPayload = await createService().createNewSessionForUser({
+      user,
+      apiVersion: '003',
+      userAgent: 'Google Chrome',
+      readonlyAccess: false,
+    })
+
+    expect(traceSession.execute).toHaveBeenCalledWith({
+      userUuid: '123',
+      username: 'test@test.te',
+      subscriptionPlanName: null,
+    })
+    expect(sessionPayload).toEqual({
+      access_expiration: 123,
+      access_token: expect.any(String),
+      refresh_expiration: 123,
+      refresh_token: expect.any(String),
+      readonly_access: false,
+    })
+  })
+
+  it('should create a session if tracing session throws an error', async () => {
+    traceSession.execute = jest.fn().mockRejectedValue(new Error('foo bar'))
+    userSubscriptionRepository.findOneByUserUuid = jest.fn().mockReturnValue(null)
+    const user = {} as jest.Mocked<User>
+    user.uuid = '123'
+    user.email = 'test@test.te'
+
+    const sessionPayload = await createService().createNewSessionForUser({
+      user,
+      apiVersion: '003',
+      userAgent: 'Google Chrome',
+      readonlyAccess: false,
+    })
+
+    expect(traceSession.execute).toHaveBeenCalledWith({
+      userUuid: '123',
+      username: 'test@test.te',
+      subscriptionPlanName: null,
+    })
+    expect(sessionPayload).toEqual({
+      access_expiration: 123,
+      access_token: expect.any(String),
+      refresh_expiration: 123,
+      refresh_token: expect.any(String),
+      readonly_access: false,
+    })
+  })
+
+  it('should create a session if tracing session fails', async () => {
+    traceSession.execute = jest.fn().mockReturnValue(Result.fail('Oops'))
+    userSubscriptionRepository.findOneByUserUuid = jest.fn().mockReturnValue(null)
+    const user = {} as jest.Mocked<User>
+    user.uuid = '123'
+    user.email = 'test@test.te'
+
+    const sessionPayload = await createService().createNewSessionForUser({
+      user,
+      apiVersion: '003',
+      userAgent: 'Google Chrome',
+      readonlyAccess: false,
+    })
+
+    expect(traceSession.execute).toHaveBeenCalledWith({
+      userUuid: '123',
+      username: 'test@test.te',
+      subscriptionPlanName: null,
+    })
+    expect(sessionPayload).toEqual({
+      access_expiration: 123,
+      access_token: expect.any(String),
+      refresh_expiration: 123,
+      refresh_token: expect.any(String),
+      readonly_access: false,
+    })
+  })
+
   it('should create new ephemeral session for a user', async () => {
     const user = {} as jest.Mocked<User>
     user.uuid = '123'

+ 24 - 6
packages/auth/src/Domain/Session/SessionService.ts

@@ -1,10 +1,14 @@
 import * as crypto from 'crypto'
-import * as winston from 'winston'
 import * as dayjs from 'dayjs'
 import { UAParser } from 'ua-parser-js'
 import { inject, injectable } from 'inversify'
 import { v4 as uuidv4 } from 'uuid'
 import { TimerInterface } from '@standardnotes/time'
+import { Logger } from 'winston'
+import { LogSessionUserAgentOption, SettingName } from '@standardnotes/settings'
+import { SessionBody } from '@standardnotes/responses'
+import { Uuid } from '@standardnotes/common'
+import { CryptoNode } from '@standardnotes/sncrypto-node'
 
 import TYPES from '../../Bootstrap/Types'
 import { Session } from './Session'
@@ -16,10 +20,8 @@ import { EphemeralSession } from './EphemeralSession'
 import { RevokedSession } from './RevokedSession'
 import { RevokedSessionRepositoryInterface } from './RevokedSessionRepositoryInterface'
 import { SettingServiceInterface } from '../Setting/SettingServiceInterface'
-import { LogSessionUserAgentOption, SettingName } from '@standardnotes/settings'
-import { SessionBody } from '@standardnotes/responses'
-import { Uuid } from '@standardnotes/common'
-import { CryptoNode } from '@standardnotes/sncrypto-node'
+import { TraceSession } from '../UseCase/TraceSession/TraceSession'
+import { UserSubscriptionRepositoryInterface } from '../Subscription/UserSubscriptionRepositoryInterface'
 
 @injectable()
 export class SessionService implements SessionServiceInterface {
@@ -31,11 +33,13 @@ export class SessionService implements SessionServiceInterface {
     @inject(TYPES.RevokedSessionRepository) private revokedSessionRepository: RevokedSessionRepositoryInterface,
     @inject(TYPES.DeviceDetector) private deviceDetector: UAParser,
     @inject(TYPES.Timer) private timer: TimerInterface,
-    @inject(TYPES.Logger) private logger: winston.Logger,
+    @inject(TYPES.Logger) private logger: Logger,
     @inject(TYPES.ACCESS_TOKEN_AGE) private accessTokenAge: number,
     @inject(TYPES.REFRESH_TOKEN_AGE) private refreshTokenAge: number,
     @inject(TYPES.SettingService) private settingService: SettingServiceInterface,
     @inject(TYPES.CryptoNode) private cryptoNode: CryptoNode,
+    @inject(TYPES.TraceSession) private traceSession: TraceSession,
+    @inject(TYPES.UserSubscriptionRepository) private userSubscriptionRepository: UserSubscriptionRepositoryInterface,
   ) {}
 
   async createNewSessionForUser(dto: {
@@ -53,6 +57,20 @@ export class SessionService implements SessionServiceInterface {
 
     await this.sessionRepository.save(session)
 
+    try {
+      const userSubscription = await this.userSubscriptionRepository.findOneByUserUuid(dto.user.uuid)
+      const traceSessionResult = await this.traceSession.execute({
+        userUuid: dto.user.uuid,
+        username: dto.user.email,
+        subscriptionPlanName: userSubscription ? userSubscription.planName : null,
+      })
+      if (traceSessionResult.isFailed()) {
+        this.logger.error(traceSessionResult.getError())
+      }
+    } catch (error) {
+      this.logger.error(`Could not trace session while creating cross service token.: ${(error as Error).message}`)
+    }
+
     return sessionPayload
   }
 

+ 1 - 158
packages/auth/src/Domain/UseCase/CreateCrossServiceToken/CreateCrossServiceToken.spec.ts

@@ -8,10 +8,6 @@ import { Role } from '../../Role/Role'
 import { UserRepositoryInterface } from '../../User/UserRepositoryInterface'
 
 import { CreateCrossServiceToken } from './CreateCrossServiceToken'
-import { RoleToSubscriptionMapInterface } from '../../Role/RoleToSubscriptionMapInterface'
-import { TraceSession } from '../TraceSession/TraceSession'
-import { Logger } from 'winston'
-import { Result, RoleName, SubscriptionPlanName } from '@standardnotes/domain-core'
 
 describe('CreateCrossServiceToken', () => {
   let userProjector: ProjectorInterface<User>
@@ -19,9 +15,6 @@ describe('CreateCrossServiceToken', () => {
   let roleProjector: ProjectorInterface<Role>
   let tokenEncoder: TokenEncoderInterface<CrossServiceTokenData>
   let userRepository: UserRepositoryInterface
-  let roleToSubscriptionMap: RoleToSubscriptionMapInterface
-  let traceSession: TraceSession
-  let logger: Logger
   const jwtTTL = 60
 
   let session: Session
@@ -29,17 +22,7 @@ describe('CreateCrossServiceToken', () => {
   let role: Role
 
   const createUseCase = () =>
-    new CreateCrossServiceToken(
-      userProjector,
-      sessionProjector,
-      roleProjector,
-      tokenEncoder,
-      userRepository,
-      jwtTTL,
-      roleToSubscriptionMap,
-      traceSession,
-      logger,
-    )
+    new CreateCrossServiceToken(userProjector, sessionProjector, roleProjector, tokenEncoder, userRepository, jwtTTL)
 
   beforeEach(() => {
     session = {} as jest.Mocked<Session>
@@ -65,19 +48,6 @@ describe('CreateCrossServiceToken', () => {
 
     userRepository = {} as jest.Mocked<UserRepositoryInterface>
     userRepository.findOneByUuid = jest.fn().mockReturnValue(user)
-
-    roleToSubscriptionMap = {} as jest.Mocked<RoleToSubscriptionMapInterface>
-    roleToSubscriptionMap.filterSubscriptionRoles = jest.fn().mockReturnValue([RoleName.NAMES.PlusUser])
-    roleToSubscriptionMap.getSubscriptionNameForRoleName = jest
-      .fn()
-      .mockReturnValue(SubscriptionPlanName.NAMES.PlusPlan)
-
-    traceSession = {} as jest.Mocked<TraceSession>
-    traceSession.execute = jest.fn()
-
-    logger = {} as jest.Mocked<Logger>
-    logger.error = jest.fn()
-    logger.debug = jest.fn()
   })
 
   it('should create a cross service token for user', async () => {
@@ -86,11 +56,6 @@ describe('CreateCrossServiceToken', () => {
       session,
     })
 
-    expect(traceSession.execute).toHaveBeenCalledWith({
-      userUuid: '1-2-3',
-      username: 'test@test.te',
-      subscriptionPlanName: 'PLUS_PLAN',
-    })
     expect(tokenEncoder.encodeExpirableToken).toHaveBeenCalledWith(
       {
         roles: [
@@ -169,126 +134,4 @@ describe('CreateCrossServiceToken', () => {
 
     expect(caughtError).not.toBeNull()
   })
-
-  it('should trace session without a subscription role', async () => {
-    roleToSubscriptionMap.filterSubscriptionRoles = jest.fn().mockReturnValue([])
-
-    await createUseCase().execute({
-      user,
-      session,
-    })
-
-    expect(traceSession.execute).toHaveBeenCalledWith({
-      userUuid: '1-2-3',
-      username: 'test@test.te',
-      subscriptionPlanName: null,
-    })
-    expect(tokenEncoder.encodeExpirableToken).toHaveBeenCalledWith(
-      {
-        roles: [
-          {
-            name: 'role1',
-            uuid: '1-3-4',
-          },
-        ],
-        session: {
-          test: 'test',
-        },
-        user: {
-          email: 'test@test.te',
-          uuid: '1-2-3',
-        },
-      },
-      60,
-    )
-  })
-
-  it('should trace session without a subscription', async () => {
-    roleToSubscriptionMap.getSubscriptionNameForRoleName = jest.fn().mockReturnValue(undefined)
-
-    await createUseCase().execute({
-      user,
-      session,
-    })
-
-    expect(traceSession.execute).toHaveBeenCalledWith({
-      userUuid: '1-2-3',
-      username: 'test@test.te',
-      subscriptionPlanName: null,
-    })
-    expect(tokenEncoder.encodeExpirableToken).toHaveBeenCalledWith(
-      {
-        roles: [
-          {
-            name: 'role1',
-            uuid: '1-3-4',
-          },
-        ],
-        session: {
-          test: 'test',
-        },
-        user: {
-          email: 'test@test.te',
-          uuid: '1-2-3',
-        },
-      },
-      60,
-    )
-  })
-
-  it('should create token if tracing session throws an error', async () => {
-    traceSession.execute = jest.fn().mockRejectedValue(new Error('test'))
-
-    await createUseCase().execute({
-      user,
-      session,
-    })
-
-    expect(tokenEncoder.encodeExpirableToken).toHaveBeenCalledWith(
-      {
-        roles: [
-          {
-            name: 'role1',
-            uuid: '1-3-4',
-          },
-        ],
-        session: {
-          test: 'test',
-        },
-        user: {
-          email: 'test@test.te',
-          uuid: '1-2-3',
-        },
-      },
-      60,
-    )
-  })
-
-  it('should create token if tracing session fails', async () => {
-    traceSession.execute = jest.fn().mockReturnValue(Result.fail('Ooops'))
-
-    await createUseCase().execute({
-      user,
-      session,
-    })
-
-    expect(tokenEncoder.encodeExpirableToken).toHaveBeenCalledWith(
-      {
-        roles: [
-          {
-            name: 'role1',
-            uuid: '1-3-4',
-          },
-        ],
-        session: {
-          test: 'test',
-        },
-        user: {
-          email: 'test@test.te',
-          uuid: '1-2-3',
-        },
-      },
-      60,
-    )
-  })
 })

+ 0 - 32
packages/auth/src/Domain/UseCase/CreateCrossServiceToken/CreateCrossServiceToken.ts

@@ -1,16 +1,13 @@
 import { RoleName } from '@standardnotes/common'
 import { TokenEncoderInterface, CrossServiceTokenData } from '@standardnotes/security'
 import { inject, injectable } from 'inversify'
-import { Logger } from 'winston'
 
 import TYPES from '../../../Bootstrap/Types'
 import { ProjectorInterface } from '../../../Projection/ProjectorInterface'
 import { Role } from '../../Role/Role'
-import { RoleToSubscriptionMapInterface } from '../../Role/RoleToSubscriptionMapInterface'
 import { Session } from '../../Session/Session'
 import { User } from '../../User/User'
 import { UserRepositoryInterface } from '../../User/UserRepositoryInterface'
-import { TraceSession } from '../TraceSession/TraceSession'
 import { UseCaseInterface } from '../UseCaseInterface'
 
 import { CreateCrossServiceTokenDTO } from './CreateCrossServiceTokenDTO'
@@ -25,9 +22,6 @@ export class CreateCrossServiceToken implements UseCaseInterface {
     @inject(TYPES.CrossServiceTokenEncoder) private tokenEncoder: TokenEncoderInterface<CrossServiceTokenData>,
     @inject(TYPES.UserRepository) private userRepository: UserRepositoryInterface,
     @inject(TYPES.AUTH_JWT_TTL) private jwtTTL: number,
-    @inject(TYPES.RoleToSubscriptionMap) private roleToSubscriptionMap: RoleToSubscriptionMapInterface,
-    @inject(TYPES.TraceSession) private traceSession: TraceSession,
-    @inject(TYPES.Logger) private logger: Logger,
   ) {}
 
   async execute(dto: CreateCrossServiceTokenDTO): Promise<CreateCrossServiceTokenResponse> {
@@ -51,19 +45,6 @@ export class CreateCrossServiceToken implements UseCaseInterface {
       authTokenData.session = this.projectSession(dto.session)
     }
 
-    try {
-      const traceSessionResult = await this.traceSession.execute({
-        userUuid: user.uuid,
-        username: user.email,
-        subscriptionPlanName: this.getSubscriptionNameFromRoles(roles),
-      })
-      if (traceSessionResult.isFailed()) {
-        this.logger.error(traceSessionResult.getError())
-      }
-    } catch (error) {
-      this.logger.debug(`Could not trace session while creating cross service token.: ${(error as Error).message}`)
-    }
-
     return {
       token: this.tokenEncoder.encodeExpirableToken(authTokenData, this.jwtTTL),
     }
@@ -100,17 +81,4 @@ export class CreateCrossServiceToken implements UseCaseInterface {
   private projectRoles(roles: Array<Role>): Array<{ uuid: string; name: RoleName }> {
     return roles.map((role) => <{ uuid: string; name: RoleName }>this.roleProjector.projectSimple(role))
   }
-
-  private getSubscriptionNameFromRoles(roles: Array<Role>): string | null {
-    const nonSubscriptionRoles = this.roleToSubscriptionMap.filterSubscriptionRoles(roles)
-    if (nonSubscriptionRoles.length === 0) {
-      return null
-    }
-
-    const subscriptionName = this.roleToSubscriptionMap.getSubscriptionNameForRoleName(
-      nonSubscriptionRoles[0].name as RoleName,
-    )
-
-    return subscriptionName === undefined ? null : subscriptionName
-  }
 }