Browse Source

feat: block file operations during transition (#856)

* feat: block file operations during transition

* fix: tracing sessions

* fix fs file removal

* fix: checking if directory exists before listing files

* fix: removing shared vault user on auth side
Karol Sójko 1 year ago
parent
commit
676cf36f8d
23 changed files with 179 additions and 60 deletions
  1. 4 4
      packages/auth/src/Bootstrap/Container.ts
  2. 0 9
      packages/auth/src/Domain/Handler/AccountDeletionRequestedEventHandler.ts
  3. 1 1
      packages/auth/src/Domain/Session/SessionTraceRepositoryInterface.ts
  4. 14 1
      packages/auth/src/Domain/UseCase/CreateValetToken/CreateValetToken.spec.ts
  5. 7 0
      packages/auth/src/Domain/UseCase/CreateValetToken/CreateValetToken.ts
  6. 7 7
      packages/auth/src/Domain/UseCase/TraceSession/TraceSession.spec.ts
  7. 1 1
      packages/auth/src/Domain/UseCase/TraceSession/TraceSession.ts
  8. 8 5
      packages/auth/src/Infra/TypeORM/TypeORMSessionTraceRepository.ts
  9. 4 5
      packages/auth/src/Mapping/SessionTracePersistenceMapper.ts
  10. 19 3
      packages/files/src/Bootstrap/Container.ts
  11. 1 1
      packages/files/src/Domain/File/RemovedFileDescription.ts
  12. 13 7
      packages/files/src/Domain/Handler/AccountDeletionRequestedEventHandler.ts
  13. 15 7
      packages/files/src/Domain/Handler/SharedSubscriptionInvitationCanceledEventHandler.ts
  14. 4 0
      packages/files/src/Domain/Handler/SharedVaultRemovedEventHandler.ts
  15. 1 1
      packages/files/src/Domain/Services/FileRemoverInterface.ts
  16. 31 7
      packages/files/src/Infra/FS/FSFileRemover.ts
  17. 24 0
      packages/files/src/Infra/InversifyExpress/Middleware/ValetTokenAuthMiddleware.spec.ts
  18. 13 0
      packages/files/src/Infra/InversifyExpress/Middleware/ValetTokenAuthMiddleware.ts
  19. 1 1
      packages/files/src/Infra/S3/S3FileRemover.ts
  20. 1 0
      packages/security/src/Domain/Token/ValetTokenData.ts
  21. 5 0
      packages/time/src/Domain/Time/Timer.spec.ts
  22. 4 0
      packages/time/src/Domain/Time/Timer.ts
  23. 1 0
      packages/time/src/Domain/Time/TimerInterface.ts

+ 4 - 4
packages/auth/src/Bootstrap/Container.ts

@@ -369,7 +369,7 @@ export class ContainerConfigLoader {
     // Mapping
     container
       .bind<MapperInterface<SessionTrace, TypeORMSessionTrace>>(TYPES.Auth_SessionTracePersistenceMapper)
-      .toConstantValue(new SessionTracePersistenceMapper())
+      .toConstantValue(new SessionTracePersistenceMapper(container.get<TimerInterface>(TYPES.Auth_Timer)))
     container
       .bind<MapperInterface<Authenticator, TypeORMAuthenticator>>(TYPES.Auth_AuthenticatorPersistenceMapper)
       .toConstantValue(new AuthenticatorPersistenceMapper())
@@ -458,8 +458,9 @@ export class ContainerConfigLoader {
       .bind<SessionTraceRepositoryInterface>(TYPES.Auth_SessionTraceRepository)
       .toConstantValue(
         new TypeORMSessionTraceRepository(
-          container.get(TYPES.Auth_ORMSessionTraceRepository),
-          container.get(TYPES.Auth_SessionTracePersistenceMapper),
+          container.get<Repository<TypeORMSessionTrace>>(TYPES.Auth_ORMSessionTraceRepository),
+          container.get<MapperInterface<SessionTrace, TypeORMSessionTrace>>(TYPES.Auth_SessionTracePersistenceMapper),
+          container.get<TimerInterface>(TYPES.Auth_Timer),
         ),
       )
     container
@@ -1011,7 +1012,6 @@ export class ContainerConfigLoader {
           container.get<SessionRepositoryInterface>(TYPES.Auth_SessionRepository),
           container.get<EphemeralSessionRepositoryInterface>(TYPES.Auth_EphemeralSessionRepository),
           container.get<RevokedSessionRepositoryInterface>(TYPES.Auth_RevokedSessionRepository),
-          container.get<RemoveSharedVaultUser>(TYPES.Auth_RemoveSharedVaultUser),
           container.get<winston.Logger>(TYPES.Auth_Logger),
         ),
       )

+ 0 - 9
packages/auth/src/Domain/Handler/AccountDeletionRequestedEventHandler.ts

@@ -6,7 +6,6 @@ import { EphemeralSessionRepositoryInterface } from '../Session/EphemeralSession
 import { RevokedSessionRepositoryInterface } from '../Session/RevokedSessionRepositoryInterface'
 import { SessionRepositoryInterface } from '../Session/SessionRepositoryInterface'
 import { UserRepositoryInterface } from '../User/UserRepositoryInterface'
-import { RemoveSharedVaultUser } from '../UseCase/RemoveSharedVaultUser/RemoveSharedVaultUser'
 
 export class AccountDeletionRequestedEventHandler implements DomainEventHandlerInterface {
   constructor(
@@ -14,7 +13,6 @@ export class AccountDeletionRequestedEventHandler implements DomainEventHandlerI
     private sessionRepository: SessionRepositoryInterface,
     private ephemeralSessionRepository: EphemeralSessionRepositoryInterface,
     private revokedSessionRepository: RevokedSessionRepositoryInterface,
-    private removeSharedVaultUser: RemoveSharedVaultUser,
     private logger: Logger,
   ) {}
 
@@ -37,13 +35,6 @@ export class AccountDeletionRequestedEventHandler implements DomainEventHandlerI
 
     await this.removeSessions(userUuid.value)
 
-    const result = await this.removeSharedVaultUser.execute({
-      userUuid: userUuid.value,
-    })
-    if (result.isFailed()) {
-      this.logger.error(`Could not remove shared vault user: ${result.getError()}`)
-    }
-
     await this.userRepository.remove(user)
 
     this.logger.info(`Finished account cleanup for user: ${userUuid.value}`)

+ 1 - 1
packages/auth/src/Domain/Session/SessionTraceRepositoryInterface.ts

@@ -3,7 +3,7 @@ import { SubscriptionPlanName, Uuid } from '@standardnotes/domain-core'
 import { SessionTrace } from './SessionTrace'
 
 export interface SessionTraceRepositoryInterface {
-  save(sessionTrace: SessionTrace): Promise<void>
+  insert(sessionTrace: SessionTrace): Promise<void>
   removeExpiredBefore(date: Date): Promise<void>
   findOneByUserUuidAndDate(userUuid: Uuid, date: Date): Promise<SessionTrace | null>
   countByDate(date: Date): Promise<number>

+ 14 - 1
packages/auth/src/Domain/UseCase/CreateValetToken/CreateValetToken.spec.ts

@@ -1,14 +1,17 @@
 import 'reflect-metadata'
 
+import { TransitionStatus } from '@standardnotes/domain-core'
+import { TimerInterface } from '@standardnotes/time'
 import { TokenEncoderInterface, ValetTokenData, ValetTokenOperation } from '@standardnotes/security'
+
 import { CreateValetToken } from './CreateValetToken'
-import { TimerInterface } from '@standardnotes/time'
 import { UserSubscription } from '../../Subscription/UserSubscription'
 import { SubscriptionSettingServiceInterface } from '../../Setting/SubscriptionSettingServiceInterface'
 import { User } from '../../User/User'
 import { UserSubscriptionType } from '../../Subscription/UserSubscriptionType'
 import { SubscriptionSettingsAssociationServiceInterface } from '../../Setting/SubscriptionSettingsAssociationServiceInterface'
 import { UserSubscriptionServiceInterface } from '../../Subscription/UserSubscriptionServiceInterface'
+import { TransitionStatusRepositoryInterface } from '../../Transition/TransitionStatusRepositoryInterface'
 
 describe('CreateValetToken', () => {
   let tokenEncoder: TokenEncoderInterface<ValetTokenData>
@@ -20,6 +23,7 @@ describe('CreateValetToken', () => {
   let regularSubscription: UserSubscription
   let sharedSubscription: UserSubscription
   let user: User
+  let transitionStatusRepository: TransitionStatusRepositoryInterface
 
   const createUseCase = () =>
     new CreateValetToken(
@@ -29,6 +33,7 @@ describe('CreateValetToken', () => {
       userSubscriptionService,
       timer,
       valetTokenTTL,
+      transitionStatusRepository,
     )
 
   beforeEach(() => {
@@ -66,6 +71,11 @@ describe('CreateValetToken', () => {
 
     timer = {} as jest.Mocked<TimerInterface>
     timer.getTimestampInMicroseconds = jest.fn().mockReturnValue(100)
+
+    transitionStatusRepository = {} as jest.Mocked<TransitionStatusRepositoryInterface>
+    transitionStatusRepository.getStatus = jest
+      .fn()
+      .mockReturnValue(TransitionStatus.create(TransitionStatus.STATUSES.Verified).getValue())
   })
 
   it('should create a read valet token', async () => {
@@ -166,6 +176,7 @@ describe('CreateValetToken', () => {
       {
         sharedSubscriptionUuid: undefined,
         regularSubscriptionUuid: '1-2-3',
+        ongoingTransition: false,
         permittedOperation: 'write',
         permittedResources: [
           {
@@ -206,6 +217,7 @@ describe('CreateValetToken', () => {
       {
         sharedSubscriptionUuid: '2-3-4',
         regularSubscriptionUuid: '1-2-3',
+        ongoingTransition: false,
         permittedOperation: 'write',
         permittedResources: [
           {
@@ -266,6 +278,7 @@ describe('CreateValetToken', () => {
       {
         sharedSubscriptionUuid: undefined,
         regularSubscriptionUuid: '1-2-3',
+        ongoingTransition: false,
         permittedOperation: 'write',
         permittedResources: [
           {

+ 7 - 0
packages/auth/src/Domain/UseCase/CreateValetToken/CreateValetToken.ts

@@ -13,6 +13,8 @@ import { CreateValetTokenDTO } from './CreateValetTokenDTO'
 import { SubscriptionSettingsAssociationServiceInterface } from '../../Setting/SubscriptionSettingsAssociationServiceInterface'
 import { UserSubscriptionServiceInterface } from '../../Subscription/UserSubscriptionServiceInterface'
 import { CreateValetTokenPayload } from '../../ValetToken/CreateValetTokenPayload'
+import { TransitionStatusRepositoryInterface } from '../../Transition/TransitionStatusRepositoryInterface'
+import { TransitionStatus } from '@standardnotes/domain-core'
 
 @injectable()
 export class CreateValetToken implements UseCaseInterface {
@@ -25,6 +27,8 @@ export class CreateValetToken implements UseCaseInterface {
     @inject(TYPES.Auth_UserSubscriptionService) private userSubscriptionService: UserSubscriptionServiceInterface,
     @inject(TYPES.Auth_Timer) private timer: TimerInterface,
     @inject(TYPES.Auth_VALET_TOKEN_TTL) private valetTokenTTL: number,
+    @inject(TYPES.Auth_TransitionStatusRepository)
+    private transitionStatusRepository: TransitionStatusRepositoryInterface,
   ) {}
 
   async execute(dto: CreateValetTokenDTO): Promise<CreateValetTokenResponseData> {
@@ -83,6 +87,8 @@ export class CreateValetToken implements UseCaseInterface {
       sharedSubscriptionUuid = sharedSubscription.uuid
     }
 
+    const transitionStatus = await this.transitionStatusRepository.getStatus(userUuid, 'items')
+
     const tokenData: ValetTokenData = {
       userUuid: dto.userUuid,
       permittedOperation: dto.operation,
@@ -91,6 +97,7 @@ export class CreateValetToken implements UseCaseInterface {
       uploadBytesLimit,
       sharedSubscriptionUuid,
       regularSubscriptionUuid: regularSubscription.uuid,
+      ongoingTransition: transitionStatus?.value === TransitionStatus.STATUSES.InProgress,
     }
 
     const valetToken = this.tokenEncoder.encodeExpirableToken(tokenData, this.valetTokenTTL)

+ 7 - 7
packages/auth/src/Domain/UseCase/TraceSession/TraceSession.spec.ts

@@ -15,7 +15,7 @@ describe('TraceSession', () => {
   beforeEach(() => {
     sessionTraceRepository = {} as jest.Mocked<SessionTraceRepositoryInterface>
     sessionTraceRepository.findOneByUserUuidAndDate = jest.fn().mockReturnValue(null)
-    sessionTraceRepository.save = jest.fn()
+    sessionTraceRepository.insert = jest.fn()
 
     timer = {} as jest.Mocked<TimerInterface>
     timer.getUTCDateNDaysAhead = jest.fn().mockReturnValue(new Date())
@@ -30,7 +30,7 @@ describe('TraceSession', () => {
 
     expect(result.isFailed()).toBe(false)
     expect(result.getValue().props.userUuid.value).toEqual('0702b137-4f5c-438a-915e-8f8b46572ce5')
-    expect(sessionTraceRepository.save).toHaveBeenCalledTimes(1)
+    expect(sessionTraceRepository.insert).toHaveBeenCalledTimes(1)
   })
 
   it('should not save a session trace if one already exists for the same user and date', async () => {
@@ -43,7 +43,7 @@ describe('TraceSession', () => {
     })
 
     expect(result.isFailed()).toBe(false)
-    expect(sessionTraceRepository.save).not.toHaveBeenCalled()
+    expect(sessionTraceRepository.insert).not.toHaveBeenCalled()
   })
 
   it('should return an error if userUuid is invalid', async () => {
@@ -54,7 +54,7 @@ describe('TraceSession', () => {
     })
 
     expect(result.isFailed()).toBe(true)
-    expect(sessionTraceRepository.save).not.toHaveBeenCalled()
+    expect(sessionTraceRepository.insert).not.toHaveBeenCalled()
   })
 
   it('should return an error if username is invalid', async () => {
@@ -65,7 +65,7 @@ describe('TraceSession', () => {
     })
 
     expect(result.isFailed()).toBe(true)
-    expect(sessionTraceRepository.save).not.toHaveBeenCalled()
+    expect(sessionTraceRepository.insert).not.toHaveBeenCalled()
   })
 
   it('should return an error if subscriptionPlanName is invalid', async () => {
@@ -76,7 +76,7 @@ describe('TraceSession', () => {
     })
 
     expect(result.isFailed()).toBe(true)
-    expect(sessionTraceRepository.save).not.toHaveBeenCalled()
+    expect(sessionTraceRepository.insert).not.toHaveBeenCalled()
   })
 
   it('should not save a session trace if creating of the session trace fails', async () => {
@@ -90,7 +90,7 @@ describe('TraceSession', () => {
     })
 
     expect(result.isFailed()).toBe(true)
-    expect(sessionTraceRepository.save).not.toHaveBeenCalled()
+    expect(sessionTraceRepository.insert).not.toHaveBeenCalled()
 
     mock.mockRestore()
   })

+ 1 - 1
packages/auth/src/Domain/UseCase/TraceSession/TraceSession.ts

@@ -53,7 +53,7 @@ export class TraceSession implements UseCaseInterface<SessionTrace> {
     }
     const sessionTrace = sessionTraceOrError.getValue()
 
-    await this.sessionTraceRepository.save(sessionTrace)
+    await this.sessionTraceRepository.insert(sessionTrace)
 
     return Result.ok<SessionTrace>(sessionTrace)
   }

+ 8 - 5
packages/auth/src/Infra/TypeORM/TypeORMSessionTraceRepository.ts

@@ -1,5 +1,7 @@
 import { MapperInterface, SubscriptionPlanName, Uuid } from '@standardnotes/domain-core'
+import { TimerInterface } from '@standardnotes/time'
 import { Repository } from 'typeorm'
+
 import { SessionTrace } from '../../Domain/Session/SessionTrace'
 import { SessionTraceRepositoryInterface } from '../../Domain/Session/SessionTraceRepositoryInterface'
 import { TypeORMSessionTrace } from './TypeORMSessionTrace'
@@ -8,13 +10,14 @@ export class TypeORMSessionTraceRepository implements SessionTraceRepositoryInte
   constructor(
     private ormRepository: Repository<TypeORMSessionTrace>,
     private mapper: MapperInterface<SessionTrace, TypeORMSessionTrace>,
+    private timer: TimerInterface,
   ) {}
 
   async countByDateAndSubscriptionPlanName(date: Date, subscriptionPlanName: SubscriptionPlanName): Promise<number> {
     return this.ormRepository
       .createQueryBuilder('trace')
       .where('trace.creation_date = :creationDate', {
-        creationDate: `${date.getFullYear()}-${date.getMonth() + 1}-${date.getDate()}`,
+        creationDate: this.timer.convertDateToFormattedString(date, 'YYYY-MM-DD'),
       })
       .andWhere('trace.subscription_plan_name = :subscriptionPlanName', {
         subscriptionPlanName: subscriptionPlanName.value,
@@ -26,7 +29,7 @@ export class TypeORMSessionTraceRepository implements SessionTraceRepositoryInte
     return this.ormRepository
       .createQueryBuilder('trace')
       .where('trace.creation_date = :creationDate', {
-        creationDate: `${date.getFullYear()}-${date.getMonth() + 1}-${date.getDate()}`,
+        creationDate: this.timer.convertDateToFormattedString(date, 'YYYY-MM-DD'),
       })
       .getCount()
   }
@@ -44,7 +47,7 @@ export class TypeORMSessionTraceRepository implements SessionTraceRepositoryInte
       .createQueryBuilder('trace')
       .where('trace.user_uuid = :userUuid AND trace.creation_date = :creationDate', {
         userUuid: userUuid.value,
-        creationDate: `${date.getFullYear()}-${date.getMonth() + 1}-${date.getDate()}`,
+        creationDate: this.timer.convertDateToFormattedString(date, 'YYYY-MM-DD'),
       })
       .getOne()
 
@@ -55,9 +58,9 @@ export class TypeORMSessionTraceRepository implements SessionTraceRepositoryInte
     return this.mapper.toDomain(typeOrm)
   }
 
-  async save(sessionTrace: SessionTrace): Promise<void> {
+  async insert(sessionTrace: SessionTrace): Promise<void> {
     const persistence = this.mapper.toProjection(sessionTrace)
 
-    await this.ormRepository.save(persistence)
+    await this.ormRepository.insert(persistence)
   }
 }

+ 4 - 5
packages/auth/src/Mapping/SessionTracePersistenceMapper.ts

@@ -1,8 +1,11 @@
 import { MapperInterface, SubscriptionPlanName, UniqueEntityId, Username, Uuid } from '@standardnotes/domain-core'
 import { SessionTrace } from '../Domain/Session/SessionTrace'
 import { TypeORMSessionTrace } from '../Infra/TypeORM/TypeORMSessionTrace'
+import { TimerInterface } from '@standardnotes/time'
 
 export class SessionTracePersistenceMapper implements MapperInterface<SessionTrace, TypeORMSessionTrace> {
+  constructor(private timer: TimerInterface) {}
+
   toDomain(projection: TypeORMSessionTrace): SessionTrace {
     const userUuidOrError = Uuid.create(projection.userUuid)
     if (userUuidOrError.isFailed()) {
@@ -50,11 +53,7 @@ export class SessionTracePersistenceMapper implements MapperInterface<SessionTra
     typeOrm.username = domain.props.username.value
     typeOrm.subscriptionPlanName = domain.props.subscriptionPlanName ? domain.props.subscriptionPlanName.value : null
     typeOrm.createdAt = domain.props.createdAt
-    typeOrm.creationDate = new Date(
-      domain.props.createdAt.getFullYear(),
-      domain.props.createdAt.getMonth(),
-      domain.props.createdAt.getDate(),
-    )
+    typeOrm.creationDate = new Date(this.timer.convertDateToFormattedString(domain.props.createdAt, 'YYYY-MM-DD'))
     typeOrm.expiresAt = domain.props.expiresAt
 
     return typeOrm

+ 19 - 3
packages/files/src/Bootstrap/Container.ts

@@ -198,7 +198,9 @@ export class ContainerConfigLoader {
         .toConstantValue(
           new FSFileUploader(container.get(TYPES.Files_FILE_UPLOAD_PATH), container.get(TYPES.Files_Logger)),
         )
-      container.bind<FileRemoverInterface>(TYPES.Files_FileRemover).to(FSFileRemover)
+      container
+        .bind<FileRemoverInterface>(TYPES.Files_FileRemover)
+        .toConstantValue(new FSFileRemover(container.get<string>(TYPES.Files_FILE_UPLOAD_PATH)))
       container.bind<FileMoverInterface>(TYPES.Files_FileMover).to(FSFileMover)
     }
 
@@ -247,12 +249,26 @@ export class ContainerConfigLoader {
     // Handlers
     container
       .bind<AccountDeletionRequestedEventHandler>(TYPES.Files_AccountDeletionRequestedEventHandler)
-      .to(AccountDeletionRequestedEventHandler)
+      .toConstantValue(
+        new AccountDeletionRequestedEventHandler(
+          container.get<MarkFilesToBeRemoved>(TYPES.Files_MarkFilesToBeRemoved),
+          container.get<DomainEventPublisherInterface>(TYPES.Files_DomainEventPublisher),
+          container.get<DomainEventFactoryInterface>(TYPES.Files_DomainEventFactory),
+          container.get<winston.Logger>(TYPES.Files_Logger),
+        ),
+      )
     container
       .bind<SharedSubscriptionInvitationCanceledEventHandler>(
         TYPES.Files_SharedSubscriptionInvitationCanceledEventHandler,
       )
-      .to(SharedSubscriptionInvitationCanceledEventHandler)
+      .toConstantValue(
+        new SharedSubscriptionInvitationCanceledEventHandler(
+          container.get<MarkFilesToBeRemoved>(TYPES.Files_MarkFilesToBeRemoved),
+          container.get<DomainEventPublisherInterface>(TYPES.Files_DomainEventPublisher),
+          container.get<DomainEventFactoryInterface>(TYPES.Files_DomainEventFactory),
+          container.get<winston.Logger>(TYPES.Files_Logger),
+        ),
+      )
 
     const eventHandlers: Map<string, DomainEventHandlerInterface> = new Map([
       ['ACCOUNT_DELETION_REQUESTED', container.get(TYPES.Files_AccountDeletionRequestedEventHandler)],

+ 1 - 1
packages/files/src/Domain/File/RemovedFileDescription.ts

@@ -1,5 +1,5 @@
 export type RemovedFileDescription = {
-  userUuid: string
+  userOrSharedVaultUuid: string
   filePath: string
   fileName: string
   fileByteSize: number

+ 13 - 7
packages/files/src/Domain/Handler/AccountDeletionRequestedEventHandler.ts

@@ -3,18 +3,17 @@ import {
   DomainEventHandlerInterface,
   DomainEventPublisherInterface,
 } from '@standardnotes/domain-events'
-import { inject, injectable } from 'inversify'
+import { Logger } from 'winston'
 
-import TYPES from '../../Bootstrap/Types'
 import { DomainEventFactoryInterface } from '../Event/DomainEventFactoryInterface'
 import { MarkFilesToBeRemoved } from '../UseCase/MarkFilesToBeRemoved/MarkFilesToBeRemoved'
 
-@injectable()
 export class AccountDeletionRequestedEventHandler implements DomainEventHandlerInterface {
   constructor(
-    @inject(TYPES.Files_MarkFilesToBeRemoved) private markFilesToBeRemoved: MarkFilesToBeRemoved,
-    @inject(TYPES.Files_DomainEventPublisher) private domainEventPublisher: DomainEventPublisherInterface,
-    @inject(TYPES.Files_DomainEventFactory) private domainEventFactory: DomainEventFactoryInterface,
+    private markFilesToBeRemoved: MarkFilesToBeRemoved,
+    private domainEventPublisher: DomainEventPublisherInterface,
+    private domainEventFactory: DomainEventFactoryInterface,
+    private logger: Logger,
   ) {}
 
   async handle(event: AccountDeletionRequestedEvent): Promise<void> {
@@ -27,16 +26,23 @@ export class AccountDeletionRequestedEventHandler implements DomainEventHandlerI
     })
 
     if (result.isFailed()) {
+      this.logger.error(`Could not mark files for removal for user ${event.payload.userUuid}: ${result.getError()}`)
+
       return
     }
 
     const filesRemoved = result.getValue()
 
+    this.logger.debug(`Marked ${filesRemoved.length} files for removal for user ${event.payload.userUuid}`)
+
     for (const fileRemoved of filesRemoved) {
       await this.domainEventPublisher.publish(
         this.domainEventFactory.createFileRemovedEvent({
           regularSubscriptionUuid: event.payload.regularSubscriptionUuid,
-          ...fileRemoved,
+          userUuid: fileRemoved.userOrSharedVaultUuid,
+          filePath: fileRemoved.filePath,
+          fileName: fileRemoved.fileName,
+          fileByteSize: fileRemoved.fileByteSize,
         }),
       )
     }

+ 15 - 7
packages/files/src/Domain/Handler/SharedSubscriptionInvitationCanceledEventHandler.ts

@@ -3,18 +3,17 @@ import {
   DomainEventHandlerInterface,
   DomainEventPublisherInterface,
 } from '@standardnotes/domain-events'
-import { inject, injectable } from 'inversify'
+import { Logger } from 'winston'
 
-import TYPES from '../../Bootstrap/Types'
 import { DomainEventFactoryInterface } from '../Event/DomainEventFactoryInterface'
 import { MarkFilesToBeRemoved } from '../UseCase/MarkFilesToBeRemoved/MarkFilesToBeRemoved'
 
-@injectable()
 export class SharedSubscriptionInvitationCanceledEventHandler implements DomainEventHandlerInterface {
   constructor(
-    @inject(TYPES.Files_MarkFilesToBeRemoved) private markFilesToBeRemoved: MarkFilesToBeRemoved,
-    @inject(TYPES.Files_DomainEventPublisher) private domainEventPublisher: DomainEventPublisherInterface,
-    @inject(TYPES.Files_DomainEventFactory) private domainEventFactory: DomainEventFactoryInterface,
+    private markFilesToBeRemoved: MarkFilesToBeRemoved,
+    private domainEventPublisher: DomainEventPublisherInterface,
+    private domainEventFactory: DomainEventFactoryInterface,
+    private logger: Logger,
   ) {}
 
   async handle(event: SharedSubscriptionInvitationCanceledEvent): Promise<void> {
@@ -27,16 +26,25 @@ export class SharedSubscriptionInvitationCanceledEventHandler implements DomainE
     })
 
     if (result.isFailed()) {
+      this.logger.error(
+        `Could not mark files to be removed for invitee: ${event.payload.inviteeIdentifier}: ${result.getError()}`,
+      )
+
       return
     }
 
     const filesRemoved = result.getValue()
 
+    this.logger.debug(`Marked ${filesRemoved.length} files for removal for invitee ${event.payload.inviteeIdentifier}`)
+
     for (const fileRemoved of filesRemoved) {
       await this.domainEventPublisher.publish(
         this.domainEventFactory.createFileRemovedEvent({
           regularSubscriptionUuid: event.payload.inviterSubscriptionUuid,
-          ...fileRemoved,
+          userUuid: fileRemoved.userOrSharedVaultUuid,
+          filePath: fileRemoved.filePath,
+          fileName: fileRemoved.fileName,
+          fileByteSize: fileRemoved.fileByteSize,
         }),
       )
     }

+ 4 - 0
packages/files/src/Domain/Handler/SharedVaultRemovedEventHandler.ts

@@ -29,6 +29,10 @@ export class SharedVaultRemovedEventHandler implements DomainEventHandlerInterfa
 
     const filesRemoved = result.getValue()
 
+    this.logger.debug(
+      `Marked ${filesRemoved.length} files for removal for shared vault ${event.payload.sharedVaultUuid}`,
+    )
+
     for (const fileRemoved of filesRemoved) {
       await this.domainEventPublisher.publish(
         this.domainEventFactory.createSharedVaultFileRemovedEvent({

+ 1 - 1
packages/files/src/Domain/Services/FileRemoverInterface.ts

@@ -2,5 +2,5 @@ import { RemovedFileDescription } from '../File/RemovedFileDescription'
 
 export interface FileRemoverInterface {
   remove(filePath: string): Promise<number>
-  markFilesToBeRemoved(userUuid: string): Promise<Array<RemovedFileDescription>>
+  markFilesToBeRemoved(userOrSharedVaultUuid: string): Promise<Array<RemovedFileDescription>>
 }

+ 31 - 7
packages/files/src/Infra/FS/FSFileRemover.ts

@@ -1,18 +1,42 @@
-import { inject, injectable } from 'inversify'
 import { promises } from 'fs'
 
 import { FileRemoverInterface } from '../../Domain/Services/FileRemoverInterface'
 import { RemovedFileDescription } from '../../Domain/File/RemovedFileDescription'
-import TYPES from '../../Bootstrap/Types'
 
-@injectable()
 export class FSFileRemover implements FileRemoverInterface {
-  constructor(@inject(TYPES.Files_FILE_UPLOAD_PATH) private fileUploadPath: string) {}
+  constructor(private fileUploadPath: string) {}
 
-  async markFilesToBeRemoved(userUuid: string): Promise<Array<RemovedFileDescription>> {
-    await promises.rmdir(`${this.fileUploadPath}/${userUuid}`)
+  async markFilesToBeRemoved(userOrSharedVaultUuid: string): Promise<Array<RemovedFileDescription>> {
+    const removedFileDescriptions: RemovedFileDescription[] = []
 
-    return []
+    let directoryExists: boolean
+    try {
+      await promises.access(`${this.fileUploadPath}/${userOrSharedVaultUuid}`)
+      directoryExists = true
+    } catch (error) {
+      directoryExists = false
+    }
+
+    if (!directoryExists) {
+      return []
+    }
+
+    const files = await promises.readdir(`${this.fileUploadPath}/${userOrSharedVaultUuid}`, { withFileTypes: true })
+
+    for (const file of files) {
+      const filePath = `${this.fileUploadPath}/${userOrSharedVaultUuid}/${file.name}`
+
+      const fileByteSize = await this.remove(`${userOrSharedVaultUuid}/${file.name}`)
+
+      removedFileDescriptions.push({
+        filePath,
+        fileByteSize,
+        userOrSharedVaultUuid,
+        fileName: file.name,
+      })
+    }
+
+    return removedFileDescriptions
   }
 
   async remove(filePath: string): Promise<number> {

+ 24 - 0
packages/files/src/Infra/InversifyExpress/Middleware/ValetTokenAuthMiddleware.spec.ts

@@ -13,6 +13,7 @@ describe('ValetTokenAuthMiddleware', () => {
 
   const logger = {
     debug: jest.fn(),
+    error: jest.fn(),
   } as unknown as jest.Mocked<Logger>
 
   const createMiddleware = () => new ValetTokenAuthMiddleware(tokenDecoder, logger)
@@ -222,4 +223,27 @@ describe('ValetTokenAuthMiddleware', () => {
 
     expect(next).toHaveBeenCalledWith(error)
   })
+
+  it('should throw an error if the valet token indicates an ongoing transition', async () => {
+    request.headers['x-valet-token'] = 'valet-token'
+
+    tokenDecoder.decodeToken = jest.fn().mockReturnValue({
+      userUuid: '1-2-3',
+      permittedResources: [
+        {
+          remoteIdentifier: '00000000-0000-0000-0000-000000000000',
+          unencryptedFileSize: 30,
+        },
+      ],
+      permittedOperation: 'write',
+      uploadBytesLimit: -1,
+      uploadBytesUsed: 80,
+      ongoingTransition: true,
+    })
+
+    await createMiddleware().handler(request, response, next)
+
+    expect(response.status).toHaveBeenCalledWith(500)
+    expect(next).not.toHaveBeenCalled()
+  })
 })

+ 13 - 0
packages/files/src/Infra/InversifyExpress/Middleware/ValetTokenAuthMiddleware.ts

@@ -46,6 +46,19 @@ export class ValetTokenAuthMiddleware extends BaseMiddleware {
         return
       }
 
+      if (valetTokenData.ongoingTransition === true) {
+        this.logger.error(`Cannot perform file operations for user ${valetTokenData.userUuid} during transition`)
+
+        response.status(500).send({
+          error: {
+            tag: 'ongoing-transition',
+            message: 'Cannot perform file operations during transition',
+          },
+        })
+
+        return
+      }
+
       for (const resource of valetTokenData.permittedResources) {
         const resourceUuidOrError = Uuid.create(resource.remoteIdentifier)
         if (resourceUuidOrError.isFailed()) {

+ 1 - 1
packages/files/src/Infra/S3/S3FileRemover.ts

@@ -59,7 +59,7 @@ export class S3FileRemover implements FileRemoverInterface {
         fileByteSize: file.Size as number,
         fileName: file.Key.replace(`${userUuid}/`, ''),
         filePath: file.Key,
-        userUuid,
+        userOrSharedVaultUuid: userUuid,
       })
     }
 

+ 1 - 0
packages/security/src/Domain/Token/ValetTokenData.ts

@@ -11,4 +11,5 @@ export type ValetTokenData = {
   }>
   uploadBytesUsed: number
   uploadBytesLimit: number
+  ongoingTransition?: boolean
 }

+ 5 - 0
packages/time/src/Domain/Time/Timer.spec.ts

@@ -76,6 +76,11 @@ describe('Timer', () => {
     expect(isoString).toEqual('2021-03-29T08:00:05.000Z')
   })
 
+  it('should convert a date to formatted string', () => {
+    const isoString = createTimer().convertDateToFormattedString(new Date(Date.UTC(2021, 2, 29, 8, 0, 5)), 'YYYY-MM-DD')
+    expect(isoString).toEqual('2021-03-29')
+  })
+
   it('should convert a string date to microseconds', () => {
     const timestamp = createTimer().convertStringDateToMicroseconds('2021-03-29 08:00:05.233Z')
     expect(timestamp).toEqual(1617004805233000)

+ 4 - 0
packages/time/src/Domain/Time/Timer.ts

@@ -95,6 +95,10 @@ export class Timer implements TimerInterface {
     return dayjs.utc(date).toISOString()
   }
 
+  convertDateToFormattedString(date: Date, format: string): string {
+    return dayjs.utc(date).format(format)
+  }
+
   dateWasNDaysAgo(date: Date): number {
     return dayjs.utc().diff(date, 'days')
   }

+ 1 - 0
packages/time/src/Domain/Time/TimerInterface.ts

@@ -12,6 +12,7 @@ export interface TimerInterface {
   convertDateToMilliseconds(date: Date): number
   convertDateToMicroseconds(date: Date): number
   convertDateToISOString(date: Date): string
+  convertDateToFormattedString(date: Date, format: string): string
   convertStringDateToDate(date: string): Date
   convertStringDateToMicroseconds(date: string): number
   convertStringDateToMilliseconds(date: string): number