ソースを参照

feat(syncing-server): limit shared vaults creation based on role (#687)

* feat(syncing-server): limit shared vaults creation based on role

* fix: add role names emptyness validation

* fix: roles passing to response locals
Karol Sójko 1 年間 前
コミット
19b8921f28

+ 0 - 5
packages/api-gateway/src/Controller/AuthMiddleware.ts

@@ -1,5 +1,4 @@
 import { CrossServiceTokenData } from '@standardnotes/security'
-import { RoleName } from '@standardnotes/domain-core'
 import { TimerInterface } from '@standardnotes/time'
 import { NextFunction, Request, Response } from 'express'
 import { BaseMiddleware } from 'inversify-express-utils'
@@ -51,10 +50,6 @@ export abstract class AuthMiddleware extends BaseMiddleware {
 
       const decodedToken = <CrossServiceTokenData>verify(crossServiceToken, this.jwtSecret, { algorithms: ['HS256'] })
 
-      response.locals.freeUser =
-        decodedToken.roles.length === 1 &&
-        decodedToken.roles.find((role) => role.name === RoleName.NAMES.CoreUser) !== undefined
-
       if (this.crossServiceTokenCacheTTL && !crossServiceTokenFetchedFromCache) {
         await this.crossServiceTokenCache.set({
           authorizationHeaderValue: authHeaderValue,

+ 0 - 4
packages/api-gateway/src/Controller/WebSocketAuthMiddleware.ts

@@ -1,5 +1,4 @@
 import { CrossServiceTokenData } from '@standardnotes/security'
-import { RoleName } from '@standardnotes/domain-core'
 import { NextFunction, Request, Response } from 'express'
 import { inject, injectable } from 'inversify'
 import { BaseMiddleware } from 'inversify-express-utils'
@@ -60,9 +59,6 @@ export class WebSocketAuthMiddleware extends BaseMiddleware {
 
       const decodedToken = <CrossServiceTokenData>verify(crossServiceToken, this.jwtSecret, { algorithms: ['HS256'] })
 
-      response.locals.freeUser =
-        decodedToken.roles.length === 1 &&
-        decodedToken.roles.find((role) => role.name === RoleName.NAMES.CoreUser) !== undefined
       response.locals.user = decodedToken.user
       response.locals.roles = decodedToken.roles
     } catch (error) {

+ 1 - 0
packages/syncing-server/src/Domain/SharedVault/SharedVaultRepositoryInterface.ts

@@ -3,6 +3,7 @@ import { SharedVault } from './SharedVault'
 
 export interface SharedVaultRepositoryInterface {
   findByUuid(uuid: Uuid): Promise<SharedVault | null>
+  countByUserUuid(userUuid: Uuid): Promise<number>
   findByUuids(uuids: Uuid[], lastSyncTime?: number): Promise<SharedVault[]>
   save(sharedVault: SharedVault): Promise<void>
   remove(sharedVault: SharedVault): Promise<void>

+ 45 - 1
packages/syncing-server/src/Domain/UseCase/SharedVaults/CreateSharedVault/CreateSharedVault.spec.ts

@@ -1,5 +1,5 @@
 import { TimerInterface } from '@standardnotes/time'
-import { Result } from '@standardnotes/domain-core'
+import { Result, RoleName } from '@standardnotes/domain-core'
 
 import { SharedVaultRepositoryInterface } from '../../../SharedVault/SharedVaultRepositoryInterface'
 import { AddUserToSharedVault } from '../AddUserToSharedVault/AddUserToSharedVault'
@@ -29,12 +29,25 @@ describe('CreateSharedVault', () => {
 
     const result = await useCase.execute({
       userUuid: 'invalid-uuid',
+      userRoleNames: [RoleName.NAMES.ProUser],
     })
 
     expect(result.isFailed()).toBe(true)
     expect(result.getError()).toBe('Given value is not a valid uuid: invalid-uuid')
   })
 
+  it('should return a failure result if the user role names are empty', async () => {
+    const useCase = createUseCase()
+
+    const result = await useCase.execute({
+      userUuid: '00000000-0000-0000-0000-000000000000',
+      userRoleNames: [],
+    })
+
+    expect(result.isFailed()).toBe(true)
+    expect(result.getError()).toBe('Given value is empty: ')
+  })
+
   it('should return a failure result if the shared vault could not be created', async () => {
     const useCase = createUseCase()
 
@@ -45,6 +58,7 @@ describe('CreateSharedVault', () => {
 
     const result = await useCase.execute({
       userUuid: '00000000-0000-0000-0000-000000000000',
+      userRoleNames: [RoleName.NAMES.ProUser],
     })
 
     expect(result.isFailed()).toBe(true)
@@ -60,6 +74,7 @@ describe('CreateSharedVault', () => {
 
     const result = await useCase.execute({
       userUuid: '00000000-0000-0000-0000-000000000000',
+      userRoleNames: [RoleName.NAMES.ProUser],
     })
 
     expect(result.isFailed()).toBe(true)
@@ -71,6 +86,7 @@ describe('CreateSharedVault', () => {
 
     await useCase.execute({
       userUuid: '00000000-0000-0000-0000-000000000000',
+      userRoleNames: [RoleName.NAMES.ProUser],
     })
 
     expect(addUserToSharedVault.execute).toHaveBeenCalledWith({
@@ -80,4 +96,32 @@ describe('CreateSharedVault', () => {
     })
     expect(sharedVaultRepository.save).toHaveBeenCalled()
   })
+
+  it('should return a failure result if a plus user has reached the limit of shared vaults', async () => {
+    const useCase = createUseCase()
+
+    sharedVaultRepository.countByUserUuid = jest.fn().mockResolvedValue(3)
+
+    const result = await useCase.execute({
+      userUuid: '00000000-0000-0000-0000-000000000000',
+      userRoleNames: [RoleName.NAMES.PlusUser],
+    })
+
+    expect(result.isFailed()).toBe(true)
+    expect(result.getError()).toBe('You have reached the limit of shared vaults for your account.')
+  })
+
+  it('should return a failure result if a core user has reached the limit of shared vaults', async () => {
+    const useCase = createUseCase()
+
+    sharedVaultRepository.countByUserUuid = jest.fn().mockResolvedValue(1)
+
+    const result = await useCase.execute({
+      userUuid: '00000000-0000-0000-0000-000000000000',
+      userRoleNames: [RoleName.NAMES.CoreUser],
+    })
+
+    expect(result.isFailed()).toBe(true)
+    expect(result.getError()).toBe('You have reached the limit of shared vaults for your account.')
+  })
 })

+ 26 - 1
packages/syncing-server/src/Domain/UseCase/SharedVaults/CreateSharedVault/CreateSharedVault.ts

@@ -1,4 +1,4 @@
-import { Result, Timestamps, UseCaseInterface, Uuid } from '@standardnotes/domain-core'
+import { Result, RoleName, Timestamps, UseCaseInterface, Uuid, Validator } from '@standardnotes/domain-core'
 import { CreateSharedVaultResult } from './CreateSharedVaultResult'
 import { CreateSharedVaultDTO } from './CreateSharedVaultDTO'
 import { TimerInterface } from '@standardnotes/time'
@@ -22,6 +22,19 @@ export class CreateSharedVault implements UseCaseInterface<CreateSharedVaultResu
     }
     const userUuid = userUuidOrError.getValue()
 
+    const userRoleNamesValidationResult = Validator.isNotEmpty(dto.userRoleNames)
+    if (userRoleNamesValidationResult.isFailed()) {
+      return Result.fail(userRoleNamesValidationResult.getError())
+    }
+
+    const userSharedVaultLimit = this.getUserSharedVaultLimit(dto.userRoleNames)
+    if (userSharedVaultLimit !== undefined) {
+      const userSharedVaultCount = await this.sharedVaultRepository.countByUserUuid(userUuid)
+      if (userSharedVaultCount >= userSharedVaultLimit) {
+        return Result.fail('You have reached the limit of shared vaults for your account.')
+      }
+    }
+
     const timestamps = Timestamps.create(
       this.timer.getTimestampInMicroseconds(),
       this.timer.getTimestampInMicroseconds(),
@@ -52,4 +65,16 @@ export class CreateSharedVault implements UseCaseInterface<CreateSharedVaultResu
 
     return Result.ok({ sharedVault, sharedVaultUser })
   }
+
+  private getUserSharedVaultLimit(userRoleNames: string[]): number | undefined {
+    if (userRoleNames.includes(RoleName.NAMES.ProUser)) {
+      return undefined
+    }
+
+    if (userRoleNames.includes(RoleName.NAMES.PlusUser)) {
+      return 3
+    }
+
+    return 1
+  }
 }

+ 1 - 0
packages/syncing-server/src/Domain/UseCase/SharedVaults/CreateSharedVault/CreateSharedVaultDTO.ts

@@ -1,3 +1,4 @@
 export interface CreateSharedVaultDTO {
   userUuid: string
+  userRoleNames: string[]
 }

+ 0 - 3
packages/syncing-server/src/Domain/UseCase/Syncing/CheckIntegrity/CheckIntegrity.spec.ts

@@ -45,7 +45,6 @@ describe('CheckIntegrity', () => {
   it('should return an empty result if there are no integrity mismatches', async () => {
     const result = await createUseCase().execute({
       userUuid: '1-2-3',
-      freeUser: false,
       integrityPayloads: [
         {
           uuid: '1-2-3',
@@ -71,7 +70,6 @@ describe('CheckIntegrity', () => {
   it('should return a mismatch item that has a different update at timemstap', async () => {
     const result = await createUseCase().execute({
       userUuid: '1-2-3',
-      freeUser: false,
       integrityPayloads: [
         {
           uuid: '1-2-3',
@@ -102,7 +100,6 @@ describe('CheckIntegrity', () => {
   it('should return a mismatch item that is missing on the client side', async () => {
     const result = await createUseCase().execute({
       userUuid: '1-2-3',
-      freeUser: false,
       integrityPayloads: [
         {
           uuid: '1-2-3',

+ 0 - 1
packages/syncing-server/src/Domain/UseCase/Syncing/CheckIntegrity/CheckIntegrityDTO.ts

@@ -3,5 +3,4 @@ import { IntegrityPayload } from '@standardnotes/responses'
 export type CheckIntegrityDTO = {
   userUuid: string
   integrityPayloads: IntegrityPayload[]
-  freeUser: boolean
 }

+ 0 - 1
packages/syncing-server/src/Infra/InversifyExpressUtils/Base/BaseItemsController.ts

@@ -91,7 +91,6 @@ export class BaseItemsController extends BaseHttpController {
     const result = await this.checkIntegrity.execute({
       userUuid: response.locals.user.uuid,
       integrityPayloads,
-      freeUser: response.locals.freeUser,
     })
 
     if (result.isFailed()) {

+ 2 - 0
packages/syncing-server/src/Infra/InversifyExpressUtils/Base/BaseSharedVaultsController.ts

@@ -2,6 +2,7 @@ import { Request, Response } from 'express'
 import { BaseHttpController, results } from 'inversify-express-utils'
 import { HttpStatusCode } from '@standardnotes/responses'
 import { ControllerContainerInterface, MapperInterface } from '@standardnotes/domain-core'
+import { Role } from '@standardnotes/security'
 
 import { GetSharedVaults } from '../../../Domain/UseCase/SharedVaults/GetSharedVaults/GetSharedVaults'
 import { SharedVault } from '../../../Domain/SharedVault/SharedVault'
@@ -59,6 +60,7 @@ export class BaseSharedVaultsController extends BaseHttpController {
   async createSharedVault(_request: Request, response: Response): Promise<results.JsonResult> {
     const result = await this.createSharedVaultUseCase.execute({
       userUuid: response.locals.user.uuid,
+      userRoleNames: response.locals.roles.map((role: Role) => role.name),
     })
 
     if (result.isFailed()) {

+ 8 - 5
packages/syncing-server/src/Infra/InversifyExpressUtils/Middleware/InversifyExpressAuthMiddleware.spec.ts

@@ -61,10 +61,12 @@ describe('InversifyExpressAuthMiddleware', () => {
     await createMiddleware().handler(request, response, next)
 
     expect(response.locals.user).toEqual({ uuid: '123' })
-    expect(response.locals.roleNames).toEqual(['CORE_USER', 'PRO_USER'])
+    expect(response.locals.roles).toEqual([
+      { uuid: '1-2-3', name: RoleName.NAMES.CoreUser },
+      { uuid: '2-3-4', name: RoleName.NAMES.ProUser },
+    ])
     expect(response.locals.session).toEqual({ uuid: '234' })
     expect(response.locals.readOnlyAccess).toBeFalsy()
-    expect(response.locals.freeUser).toEqual(false)
 
     expect(next).toHaveBeenCalled()
   })
@@ -90,8 +92,6 @@ describe('InversifyExpressAuthMiddleware', () => {
 
     await createMiddleware().handler(request, response, next)
 
-    expect(response.locals.freeUser).toEqual(true)
-
     expect(next).toHaveBeenCalled()
   })
 
@@ -124,7 +124,10 @@ describe('InversifyExpressAuthMiddleware', () => {
     await createMiddleware().handler(request, response, next)
 
     expect(response.locals.user).toEqual({ uuid: '123' })
-    expect(response.locals.roleNames).toEqual(['CORE_USER', 'PRO_USER'])
+    expect(response.locals.roles).toEqual([
+      { uuid: '1-2-3', name: RoleName.NAMES.CoreUser },
+      { uuid: '2-3-4', name: RoleName.NAMES.ProUser },
+    ])
     expect(response.locals.session).toEqual({ uuid: '234', readonly_access: true })
     expect(response.locals.readOnlyAccess).toBeTruthy()
 

+ 1 - 4
packages/syncing-server/src/Infra/InversifyExpressUtils/Middleware/InversifyExpressAuthMiddleware.ts

@@ -3,7 +3,6 @@ import { BaseMiddleware } from 'inversify-express-utils'
 import { verify } from 'jsonwebtoken'
 import { CrossServiceTokenData } from '@standardnotes/security'
 import * as winston from 'winston'
-import { RoleName } from '@standardnotes/domain-core'
 
 export class InversifyExpressAuthMiddleware extends BaseMiddleware {
   constructor(private authJWTSecret: string, private logger: winston.Logger) {
@@ -23,9 +22,7 @@ export class InversifyExpressAuthMiddleware extends BaseMiddleware {
       const decodedToken = <CrossServiceTokenData>verify(authToken, this.authJWTSecret, { algorithms: ['HS256'] })
 
       response.locals.user = decodedToken.user
-      response.locals.roleNames = decodedToken.roles.map((role) => role.name)
-      response.locals.freeUser =
-        response.locals.roleNames.length === 1 && response.locals.roleNames[0] === RoleName.NAMES.CoreUser
+      response.locals.roles = decodedToken.roles
       response.locals.session = decodedToken.session
       response.locals.readOnlyAccess = decodedToken.session?.readonly_access ?? false
 

+ 11 - 0
packages/syncing-server/src/Infra/TypeORM/TypeORMSharedVaultRepository.ts

@@ -11,6 +11,17 @@ export class TypeORMSharedVaultRepository implements SharedVaultRepositoryInterf
     private mapper: MapperInterface<SharedVault, TypeORMSharedVault>,
   ) {}
 
+  async countByUserUuid(userUuid: Uuid): Promise<number> {
+    const count = await this.ormRepository
+      .createQueryBuilder('shared_vault')
+      .where('shared_vault.user_uuid = :userUuid', {
+        userUuid: userUuid.value,
+      })
+      .getCount()
+
+    return count
+  }
+
   async findByUuids(uuids: Uuid[], lastSyncTime?: number | undefined): Promise<SharedVault[]> {
     const queryBuilder = this.ormRepository
       .createQueryBuilder('shared_vault')