Forráskód Böngészése

fix(auth): adding user roles upon renewal of shared subscription (#1012)

* fix(auth): adding user roles upon renewal of shared subscription

* feat(auth): add procedure to fix roles on shared subscriptions
Karol Sójko 1 éve
szülő
commit
26b13ed6d4

+ 67 - 0
packages/auth/bin/fix_roles.ts

@@ -0,0 +1,67 @@
+import 'reflect-metadata'
+
+import { Logger } from 'winston'
+import * as dayjs from 'dayjs'
+import * as utc from 'dayjs/plugin/utc'
+import { Uuid } from '@standardnotes/domain-core'
+
+import { ContainerConfigLoader } from '../src/Bootstrap/Container'
+import TYPES from '../src/Bootstrap/Types'
+import { Env } from '../src/Bootstrap/Env'
+import { UserSubscriptionRepositoryInterface } from '../src/Domain/Subscription/UserSubscriptionRepositoryInterface'
+import { RoleServiceInterface } from '../src/Domain/Role/RoleServiceInterface'
+import { UserSubscriptionType } from '../src/Domain/Subscription/UserSubscriptionType'
+import { UserRepositoryInterface } from '../src/Domain/User/UserRepositoryInterface'
+
+const fixRoles = async (
+  userRepository: UserRepositoryInterface,
+  userSubscriptionRepository: UserSubscriptionRepositoryInterface,
+  roleService: RoleServiceInterface,
+): Promise<void> => {
+  const subscriptions = await userSubscriptionRepository.findActiveByType(UserSubscriptionType.Shared)
+
+  for (const subscription of subscriptions) {
+    const userUuidOrError = Uuid.create(subscription.userUuid)
+    if (userUuidOrError.isFailed()) {
+      continue
+    }
+    const userUuid = userUuidOrError.getValue()
+
+    const user = await userRepository.findOneByUuid(userUuid)
+    if (!user) {
+      continue
+    }
+
+    await roleService.addUserRoleBasedOnSubscription(user, subscription.planName)
+  }
+}
+
+const container = new ContainerConfigLoader('worker')
+void container.load().then((container) => {
+  dayjs.extend(utc)
+
+  const env: Env = new Env()
+  env.load()
+
+  const logger: Logger = container.get(TYPES.Auth_Logger)
+
+  logger.info('Starting roles fix for shared subscriptions...')
+
+  const userRepository = container.get<UserRepositoryInterface>(TYPES.Auth_UserRepository)
+  const userSubscriptionRepository = container.get<UserSubscriptionRepositoryInterface>(
+    TYPES.Auth_UserSubscriptionRepository,
+  )
+  const roleService = container.get<RoleServiceInterface>(TYPES.Auth_RoleService)
+
+  Promise.resolve(fixRoles(userRepository, userSubscriptionRepository, roleService))
+    .then(() => {
+      logger.info('Finished fixing roles for shared subscriptions')
+
+      process.exit(0)
+    })
+    .catch((error) => {
+      logger.error(`Error while fixing roles for shared subscriptions: ${(error as Error).message}`)
+
+      process.exit(1)
+    })
+})

+ 11 - 0
packages/auth/docker/entrypoint-fix-roles.js

@@ -0,0 +1,11 @@
+'use strict'
+
+const path = require('path')
+
+const pnp = require(path.normalize(path.resolve(__dirname, '../../..', '.pnp.cjs'))).setup()
+
+const index = require(path.normalize(path.resolve(__dirname, '../dist/bin/fix_roles.js')))
+
+Object.defineProperty(exports, '__esModule', { value: true })
+
+exports.default = index

+ 4 - 0
packages/auth/docker/entrypoint.sh

@@ -38,6 +38,10 @@ case "$COMMAND" in
     exec node docker/entrypoint-fix-quota.js $EMAIL
     ;;
 
+  'fix-roles' )
+    exec node docker/entrypoint-fix-roles.js
+    ;;
+
   'delete-accounts' )
     FILE_NAME=$1 && shift 1
     MODE=$1 && shift 1

+ 1 - 0
packages/auth/src/Bootstrap/Container.ts

@@ -1284,6 +1284,7 @@ export class ContainerConfigLoader {
           ),
           container.get<UserSubscriptionRepositoryInterface>(TYPES.Auth_UserSubscriptionRepository),
           container.get<UserRepositoryInterface>(TYPES.Auth_UserRepository),
+          container.get<RoleServiceInterface>(TYPES.Auth_RoleService),
           container.get<winston.Logger>(TYPES.Auth_Logger),
         ),
       )

+ 1 - 0
packages/auth/src/Domain/Subscription/UserSubscriptionRepositoryInterface.ts

@@ -10,6 +10,7 @@ export interface UserSubscriptionRepositoryInterface {
   findByUserUuid(userUuid: string): Promise<UserSubscription[]>
   countByPlanName(planNames: SubscriptionPlanName[]): Promise<number>
   findByPlanName(planNames: SubscriptionPlanName[], offset: number, limit: number): Promise<UserSubscription[]>
+  findActiveByType(type: UserSubscriptionType): Promise<UserSubscription[]>
   findOneByUserUuidAndSubscriptionId(userUuid: string, subscriptionId: number): Promise<UserSubscription | null>
   findBySubscriptionIdAndType(subscriptionId: number, type: UserSubscriptionType): Promise<UserSubscription[]>
   findBySubscriptionId(subscriptionId: number): Promise<UserSubscription[]>

+ 62 - 1
packages/auth/src/Domain/UseCase/RenewSharedSubscriptions/RenewSharedSubscriptions.spec.ts

@@ -8,6 +8,7 @@ import { SharedSubscriptionInvitation } from '../../SharedSubscription/SharedSub
 import { InviteeIdentifierType } from '../../SharedSubscription/InviteeIdentifierType'
 import { User } from '../../User/User'
 import { InvitationStatus } from '../../SharedSubscription/InvitationStatus'
+import { RoleServiceInterface } from '../../Role/RoleServiceInterface'
 
 describe('RenewSharedSubscriptions', () => {
   let listSharedSubscriptionInvitations: ListSharedSubscriptionInvitations
@@ -17,6 +18,7 @@ describe('RenewSharedSubscriptions', () => {
   let logger: Logger
   let sharedSubscriptionInvitation: SharedSubscriptionInvitation
   let user: User
+  let roleService: RoleServiceInterface
 
   const createUseCase = () =>
     new RenewSharedSubscriptions(
@@ -24,6 +26,7 @@ describe('RenewSharedSubscriptions', () => {
       sharedSubscriptionInvitationRepository,
       userSubscriptionRepository,
       userRepository,
+      roleService,
       logger,
     )
 
@@ -48,8 +51,12 @@ describe('RenewSharedSubscriptions', () => {
     userSubscriptionRepository = {} as jest.Mocked<UserSubscriptionRepositoryInterface>
     userSubscriptionRepository.save = jest.fn()
 
+    roleService = {} as jest.Mocked<RoleServiceInterface>
+    roleService.addUserRoleBasedOnSubscription = jest.fn()
+
     userRepository = {} as jest.Mocked<UserRepositoryInterface>
     userRepository.findOneByUsernameOrEmail = jest.fn().mockReturnValue(user)
+    userRepository.findOneByUuid = jest.fn().mockReturnValue(user)
 
     logger = {} as jest.Mocked<Logger>
     logger.error = jest.fn()
@@ -71,7 +78,7 @@ describe('RenewSharedSubscriptions', () => {
     expect(userSubscriptionRepository.save).toBeCalledTimes(1)
   })
 
-  it('should log error if user not found', async () => {
+  it('should log error if user not found by email', async () => {
     userRepository.findOneByUsernameOrEmail = jest.fn().mockReturnValue(null)
 
     const useCase = createUseCase()
@@ -88,6 +95,42 @@ describe('RenewSharedSubscriptions', () => {
     expect(logger.error).toBeCalledTimes(1)
   })
 
+  it('should log error if user not found by uuid', async () => {
+    sharedSubscriptionInvitation.inviteeIdentifierType = InviteeIdentifierType.Uuid
+    sharedSubscriptionInvitation.inviteeIdentifier = '00000000-0000-0000-0000-000000000000'
+    userRepository.findOneByUuid = jest.fn().mockReturnValue(null)
+
+    const useCase = createUseCase()
+
+    const result = await useCase.execute({
+      inviterEmail: 'inviter@test.te',
+      newSubscriptionId: 123,
+      newSubscriptionName: 'test',
+      newSubscriptionExpiresAt: 123,
+      timestamp: 123,
+    })
+
+    expect(result.isFailed()).toBeFalsy()
+    expect(logger.error).toBeCalledTimes(1)
+  })
+
+  it('should log error if user not found by unknown identifier type', async () => {
+    sharedSubscriptionInvitation.inviteeIdentifierType = 'unknown' as InviteeIdentifierType
+
+    const useCase = createUseCase()
+
+    const result = await useCase.execute({
+      inviterEmail: 'inviter@test.te',
+      newSubscriptionId: 123,
+      newSubscriptionName: 'test',
+      newSubscriptionExpiresAt: 123,
+      timestamp: 123,
+    })
+
+    expect(result.isFailed()).toBeFalsy()
+    expect(logger.error).toBeCalledTimes(1)
+  })
+
   it('should log error if error occurs', async () => {
     userRepository.findOneByUsernameOrEmail = jest.fn().mockImplementation(() => {
       throw new Error('test')
@@ -125,6 +168,24 @@ describe('RenewSharedSubscriptions', () => {
     expect(logger.error).toBeCalledTimes(1)
   })
 
+  it('should log error if uuid is invalid', async () => {
+    sharedSubscriptionInvitation.inviteeIdentifierType = InviteeIdentifierType.Uuid
+    sharedSubscriptionInvitation.inviteeIdentifier = 'invalid'
+
+    const useCase = createUseCase()
+
+    const result = await useCase.execute({
+      inviterEmail: 'inviter@test.te',
+      newSubscriptionId: 123,
+      newSubscriptionName: 'test',
+      newSubscriptionExpiresAt: 123,
+      timestamp: 123,
+    })
+
+    expect(result.isFailed()).toBeFalsy()
+    expect(logger.error).toBeCalledTimes(1)
+  })
+
   it('should renew shared subscription for invitations by user uuid', async () => {
     sharedSubscriptionInvitation.inviteeIdentifierType = InviteeIdentifierType.Uuid
     sharedSubscriptionInvitation.inviteeIdentifier = '00000000-0000-0000-0000-000000000000'

+ 17 - 10
packages/auth/src/Domain/UseCase/RenewSharedSubscriptions/RenewSharedSubscriptions.ts

@@ -1,4 +1,4 @@
-import { Result, UseCaseInterface, Username } from '@standardnotes/domain-core'
+import { Result, UseCaseInterface, Username, Uuid } from '@standardnotes/domain-core'
 import { Logger } from 'winston'
 
 import { RenewSharedSubscriptionsDTO } from './RenewSharedSubscriptionsDTO'
@@ -10,6 +10,8 @@ import { UserSubscriptionType } from '../../Subscription/UserSubscriptionType'
 import { UserSubscriptionRepositoryInterface } from '../../Subscription/UserSubscriptionRepositoryInterface'
 import { UserRepositoryInterface } from '../../User/UserRepositoryInterface'
 import { InviteeIdentifierType } from '../../SharedSubscription/InviteeIdentifierType'
+import { RoleServiceInterface } from '../../Role/RoleServiceInterface'
+import { User } from '../../User/User'
 
 export class RenewSharedSubscriptions implements UseCaseInterface<void> {
   constructor(
@@ -17,6 +19,7 @@ export class RenewSharedSubscriptions implements UseCaseInterface<void> {
     private sharedSubscriptionInvitationRepository: SharedSubscriptionInvitationRepositoryInterface,
     private userSubscriptionRepository: UserSubscriptionRepositoryInterface,
     private userRepository: UserRepositoryInterface,
+    private roleService: RoleServiceInterface,
     private logger: Logger,
   ) {}
 
@@ -31,8 +34,8 @@ export class RenewSharedSubscriptions implements UseCaseInterface<void> {
 
     for (const invitation of acceptedInvitations) {
       try {
-        const userUuid = await this.getInviteeUserUuid(invitation.inviteeIdentifier, invitation.inviteeIdentifierType)
-        if (userUuid === null) {
+        const user = await this.getInviteeUserUuid(invitation.inviteeIdentifier, invitation.inviteeIdentifierType)
+        if (user === null) {
           this.logger.error(
             `[SUBSCRIPTION: ${dto.newSubscriptionId}] Could not renew shared subscription for invitation: ${invitation.uuid}: Could not find user with identifier: ${invitation.inviteeIdentifier}`,
           )
@@ -42,11 +45,13 @@ export class RenewSharedSubscriptions implements UseCaseInterface<void> {
         await this.createSharedSubscription({
           subscriptionId: dto.newSubscriptionId,
           subscriptionName: dto.newSubscriptionName,
-          userUuid,
+          userUuid: user.uuid,
           timestamp: dto.timestamp,
           subscriptionExpiresAt: dto.newSubscriptionExpiresAt,
         })
 
+        await this.roleService.addUserRoleBasedOnSubscription(user, dto.newSubscriptionName)
+
         invitation.subscriptionId = dto.newSubscriptionId
         invitation.updatedAt = dto.timestamp
 
@@ -83,7 +88,7 @@ export class RenewSharedSubscriptions implements UseCaseInterface<void> {
     return this.userSubscriptionRepository.save(subscription)
   }
 
-  private async getInviteeUserUuid(inviteeIdentifier: string, inviteeIdentifierType: string): Promise<string | null> {
+  private async getInviteeUserUuid(inviteeIdentifier: string, inviteeIdentifierType: string): Promise<User | null> {
     if (inviteeIdentifierType === InviteeIdentifierType.Email) {
       const usernameOrError = Username.create(inviteeIdentifier)
       if (usernameOrError.isFailed()) {
@@ -91,14 +96,16 @@ export class RenewSharedSubscriptions implements UseCaseInterface<void> {
       }
       const username = usernameOrError.getValue()
 
-      const user = await this.userRepository.findOneByUsernameOrEmail(username)
-      if (user === null) {
+      return this.userRepository.findOneByUsernameOrEmail(username)
+    } else if (inviteeIdentifierType === InviteeIdentifierType.Uuid) {
+      const uuidOrError = Uuid.create(inviteeIdentifier)
+      if (uuidOrError.isFailed()) {
         return null
       }
-
-      return user.uuid
+      const uuid = uuidOrError.getValue()
+      return this.userRepository.findOneByUuid(uuid)
     }
 
-    return inviteeIdentifier
+    return null
   }
 }

+ 10 - 1
packages/auth/src/Infra/TypeORM/TypeORMUserSubscriptionRepository.ts

@@ -1,3 +1,4 @@
+import { SubscriptionPlanName } from '@standardnotes/domain-core'
 import { TimerInterface } from '@standardnotes/time'
 import { inject, injectable } from 'inversify'
 import { Repository } from 'typeorm'
@@ -6,7 +7,6 @@ import TYPES from '../../Bootstrap/Types'
 import { UserSubscription } from '../../Domain/Subscription/UserSubscription'
 import { UserSubscriptionRepositoryInterface } from '../../Domain/Subscription/UserSubscriptionRepositoryInterface'
 import { UserSubscriptionType } from '../../Domain/Subscription/UserSubscriptionType'
-import { SubscriptionPlanName } from '@standardnotes/domain-core'
 
 @injectable()
 export class TypeORMUserSubscriptionRepository implements UserSubscriptionRepositoryInterface {
@@ -16,6 +16,15 @@ export class TypeORMUserSubscriptionRepository implements UserSubscriptionReposi
     @inject(TYPES.Auth_Timer) private timer: TimerInterface,
   ) {}
 
+  async findActiveByType(type: UserSubscriptionType): Promise<UserSubscription[]> {
+    return await this.ormRepository
+      .createQueryBuilder()
+      .where('ends_at > :timestamp', { timestamp: this.timer.getTimestampInMicroseconds() })
+      .andWhere('subscription_type = :type', { type })
+      .orderBy('created_at', 'ASC')
+      .getMany()
+  }
+
   async countByPlanName(planNames: SubscriptionPlanName[]): Promise<number> {
     return await this.ormRepository
       .createQueryBuilder()