Browse Source

feat: send event to client upon items change on server (#941)

* feat(websockets): persist connections in mysql

* fix: add sending event to client upon items changed on server

* fix payload

* fix: add cathcing errors

* fix: send changed items event only on a 10% dice roll
Karol Sójko 1 year ago
parent
commit
69b404f5d4
47 changed files with 774 additions and 419 deletions
  1. 1 1
      .pnp.cjs
  2. 7 0
      packages/domain-events/src/Domain/Event/ItemsChangedOnServerEvent.ts
  3. 5 0
      packages/domain-events/src/Domain/Event/ItemsChangedOnServerEventPayload.ts
  4. 1 0
      packages/domain-events/src/Domain/Event/WebSocketMessageRequestedEventPayload.ts
  5. 2 0
      packages/domain-events/src/Domain/index.ts
  6. 20 7
      packages/syncing-server/src/Bootstrap/Container.ts
  7. 1 0
      packages/syncing-server/src/Bootstrap/Types.ts
  8. 25 1
      packages/syncing-server/src/Domain/Event/DomainEventFactory.ts
  9. 11 1
      packages/syncing-server/src/Domain/Event/DomainEventFactoryInterface.ts
  10. 34 1
      packages/syncing-server/src/Domain/UseCase/Syncing/SaveItems/SaveItems.spec.ts
  11. 39 0
      packages/syncing-server/src/Domain/UseCase/Syncing/SaveItems/SaveItems.ts
  12. 19 0
      packages/syncing-server/src/Domain/UseCase/Syncing/SendEventToClient/SendEventToClient.spec.ts
  13. 17 12
      packages/syncing-server/src/Domain/UseCase/Syncing/SendEventToClient/SendEventToClient.ts
  14. 1 0
      packages/syncing-server/src/Domain/UseCase/Syncing/SendEventToClient/SendEventToClientDTO.ts
  15. 9 0
      packages/websockets/.env.sample
  16. 1 1
      packages/websockets/bin/server.ts
  17. 1 1
      packages/websockets/bin/worker.ts
  18. 17 0
      packages/websockets/migrations/mysql/1701087671322-initial-database.ts
  19. 1 1
      packages/websockets/package.json
  20. 58 26
      packages/websockets/src/Bootstrap/Container.ts
  21. 84 0
      packages/websockets/src/Bootstrap/DataSource.ts
  22. 7 0
      packages/websockets/src/Bootstrap/MigrationsDataSource.ts
  23. 6 3
      packages/websockets/src/Bootstrap/Types.ts
  24. 0 3
      packages/websockets/src/Client/ClientMessengerInterface.ts
  25. 0 99
      packages/websockets/src/Controller/ApiGatewayAuthMiddleware.spec.ts
  26. 0 28
      packages/websockets/src/Controller/WebSocketsController.spec.ts
  27. 0 29
      packages/websockets/src/Controller/WebSocketsController.ts
  28. 13 0
      packages/websockets/src/Domain/Connection/Connection.ts
  29. 8 0
      packages/websockets/src/Domain/Connection/ConnectionProps.ts
  30. 11 9
      packages/websockets/src/Domain/Handler/WebSocketMessageRequestedEventHandler.ts
  31. 36 5
      packages/websockets/src/Domain/UseCase/AddWebSocketsConnection/AddWebSocketsConnection.spec.ts
  32. 33 7
      packages/websockets/src/Domain/UseCase/AddWebSocketsConnection/AddWebSocketsConnection.ts
  33. 1 0
      packages/websockets/src/Domain/UseCase/AddWebSocketsConnection/AddWebSocketsConnectionDTO.ts
  34. 99 0
      packages/websockets/src/Domain/UseCase/SendMessageToClient/SendMessageToClient.spec.ts
  35. 57 0
      packages/websockets/src/Domain/UseCase/SendMessageToClient/SendMessageToClient.ts
  36. 5 0
      packages/websockets/src/Domain/UseCase/SendMessageToClient/SendMessageToClientDTO.ts
  37. 5 2
      packages/websockets/src/Domain/WebSockets/WebSocketsConnectionRepositoryInterface.ts
  38. 6 4
      packages/websockets/src/Infra/InversifyExpressUtils/AnnotatedWebSocketsController.ts
  39. 2 2
      packages/websockets/src/Infra/InversifyExpressUtils/Middleware/ApiGatewayAuthMiddleware.ts
  40. 0 44
      packages/websockets/src/Infra/Redis/RedisWebSocketsConnectionRepository.spec.ts
  41. 0 28
      packages/websockets/src/Infra/Redis/RedisWebSocketsConnectionRepository.ts
  42. 42 0
      packages/websockets/src/Infra/TypeORM/SQLConnection.ts
  43. 38 0
      packages/websockets/src/Infra/TypeORM/SQLConnectionRepository.ts
  44. 0 49
      packages/websockets/src/Infra/WebSockets/WebSocketsClientMessenger.ts
  45. 0 54
      packages/websockets/src/Infra/WebSockets/WebSocketsClientService.spec.ts
  46. 50 0
      packages/websockets/src/Mapping/SQL/ConnectionPersistenceMapper.ts
  47. 1 1
      yarn.lock

+ 1 - 1
.pnp.cjs

@@ -7045,13 +7045,13 @@ const RAW_RUNTIME_STATE =
           ["@standardnotes/websockets-server", "workspace:packages/websockets"],\
           ["@aws-sdk/client-apigatewaymanagementapi", "npm:3.427.0"],\
           ["@aws-sdk/client-sqs", "npm:3.427.0"],\
-          ["@standardnotes/api", "npm:1.26.26"],\
           ["@standardnotes/common", "workspace:packages/common"],\
           ["@standardnotes/domain-core", "workspace:packages/domain-core"],\
           ["@standardnotes/domain-events", "workspace:packages/domain-events"],\
           ["@standardnotes/domain-events-infra", "workspace:packages/domain-events-infra"],\
           ["@standardnotes/responses", "npm:1.13.27"],\
           ["@standardnotes/security", "workspace:packages/security"],\
+          ["@standardnotes/time", "workspace:packages/time"],\
           ["@types/cors", "npm:2.8.13"],\
           ["@types/express", "npm:4.17.17"],\
           ["@types/ioredis", "npm:5.0.0"],\

+ 7 - 0
packages/domain-events/src/Domain/Event/ItemsChangedOnServerEvent.ts

@@ -0,0 +1,7 @@
+import { DomainEventInterface } from './DomainEventInterface'
+import { ItemsChangedOnServerEventPayload } from './ItemsChangedOnServerEventPayload'
+
+export interface ItemsChangedOnServerEvent extends DomainEventInterface {
+  type: 'ITEMS_CHANGED_ON_SERVER'
+  payload: ItemsChangedOnServerEventPayload
+}

+ 5 - 0
packages/domain-events/src/Domain/Event/ItemsChangedOnServerEventPayload.ts

@@ -0,0 +1,5 @@
+export interface ItemsChangedOnServerEventPayload {
+  userUuid: string
+  sessionUuid: string
+  timestamp: number
+}

+ 1 - 0
packages/domain-events/src/Domain/Event/WebSocketMessageRequestedEventPayload.ts

@@ -1,4 +1,5 @@
 export interface WebSocketMessageRequestedEventPayload {
   userUuid: string
   message: string
+  originatingSessionUuid?: string
 }

+ 2 - 0
packages/domain-events/src/Domain/index.ts

@@ -40,6 +40,8 @@ export * from './Event/ItemRemovedFromSharedVaultEvent'
 export * from './Event/ItemRemovedFromSharedVaultEventPayload'
 export * from './Event/ItemRevisionCreationRequestedEvent'
 export * from './Event/ItemRevisionCreationRequestedEventPayload'
+export * from './Event/ItemsChangedOnServerEvent'
+export * from './Event/ItemsChangedOnServerEventPayload'
 export * from './Event/ListedAccountCreatedEvent'
 export * from './Event/ListedAccountCreatedEventPayload'
 export * from './Event/ListedAccountDeletedEvent'

+ 20 - 7
packages/syncing-server/src/Bootstrap/Container.ts

@@ -114,7 +114,13 @@ import { GetSharedVaults } from '../Domain/UseCase/SharedVaults/GetSharedVaults/
 import { CreateSharedVault } from '../Domain/UseCase/SharedVaults/CreateSharedVault/CreateSharedVault'
 import { DeleteSharedVault } from '../Domain/UseCase/SharedVaults/DeleteSharedVault/DeleteSharedVault'
 import { CreateSharedVaultFileValetToken } from '../Domain/UseCase/SharedVaults/CreateSharedVaultFileValetToken/CreateSharedVaultFileValetToken'
-import { SharedVaultValetTokenData, TokenEncoder, TokenEncoderInterface } from '@standardnotes/security'
+import {
+  DeterministicSelector,
+  SelectorInterface,
+  SharedVaultValetTokenData,
+  TokenEncoder,
+  TokenEncoderInterface,
+} from '@standardnotes/security'
 import { SharedVaultHttpRepresentation } from '../Mapping/Http/SharedVaultHttpRepresentation'
 import { SharedVaultHttpMapper } from '../Mapping/Http/SharedVaultHttpMapper'
 import { SharedVaultInviteHttpRepresentation } from '../Mapping/Http/SharedVaultInviteHttpRepresentation'
@@ -200,6 +206,10 @@ export class ContainerConfigLoader {
     }
     container.bind<winston.Logger>(TYPES.Sync_Logger).toConstantValue(logger)
 
+    container
+      .bind<SelectorInterface<number>>(TYPES.Sync_NumberSelector)
+      .toConstantValue(new DeterministicSelector<number>())
+
     const appDataSource = new AppDataSource({ env, runMigrations: this.mode === 'server' })
     await appDataSource.initialize()
 
@@ -601,12 +611,15 @@ export class ContainerConfigLoader {
       .bind<SaveItems>(TYPES.Sync_SaveItems)
       .toConstantValue(
         new SaveItems(
-          container.get(TYPES.Sync_ItemSaveValidator),
-          container.get(TYPES.Sync_SQLItemRepository),
-          container.get(TYPES.Sync_Timer),
-          container.get(TYPES.Sync_SaveNewItem),
-          container.get(TYPES.Sync_UpdateExistingItem),
-          container.get(TYPES.Sync_Logger),
+          container.get<ItemSaveValidatorInterface>(TYPES.Sync_ItemSaveValidator),
+          container.get<ItemRepositoryInterface>(TYPES.Sync_SQLItemRepository),
+          container.get<TimerInterface>(TYPES.Sync_Timer),
+          container.get<SaveNewItem>(TYPES.Sync_SaveNewItem),
+          container.get<UpdateExistingItem>(TYPES.Sync_UpdateExistingItem),
+          container.get<SendEventToClient>(TYPES.Sync_SendEventToClient),
+          container.get<DomainEventFactoryInterface>(TYPES.Sync_DomainEventFactory),
+          container.get<SelectorInterface<number>>(TYPES.Sync_NumberSelector),
+          container.get<Logger>(TYPES.Sync_Logger),
         ),
       )
     container

+ 1 - 0
packages/syncing-server/src/Bootstrap/Types.ts

@@ -6,6 +6,7 @@ const TYPES = {
   Sync_SQS: Symbol.for('Sync_SQS'),
   Sync_S3: Symbol.for('Sync_S3'),
   Sync_Env: Symbol.for('Sync_Env'),
+  Sync_NumberSelector: Symbol.for('Sync_NumberSelector'),
   // Repositories
   Sync_SQLItemRepository: Symbol.for('Sync_SQLItemRepository'),
   Sync_SharedVaultRepository: Symbol.for('Sync_SharedVaultRepository'),

+ 25 - 1
packages/syncing-server/src/Domain/Event/DomainEventFactory.ts

@@ -7,6 +7,7 @@ import {
   ItemDumpedEvent,
   ItemRemovedFromSharedVaultEvent,
   ItemRevisionCreationRequestedEvent,
+  ItemsChangedOnServerEvent,
   MessageSentToUserEvent,
   NotificationAddedForUserEvent,
   RevisionsCopyRequestedEvent,
@@ -23,6 +24,25 @@ import { DomainEventFactoryInterface } from './DomainEventFactoryInterface'
 export class DomainEventFactory implements DomainEventFactoryInterface {
   constructor(private timer: TimerInterface) {}
 
+  createItemsChangedOnServerEvent(dto: {
+    userUuid: string
+    sessionUuid: string
+    timestamp: number
+  }): ItemsChangedOnServerEvent {
+    return {
+      type: 'ITEMS_CHANGED_ON_SERVER',
+      createdAt: this.timer.getUTCDate(),
+      meta: {
+        correlation: {
+          userIdentifier: dto.userUuid,
+          userIdentifierType: 'uuid',
+        },
+        origin: DomainEventService.SyncingServer,
+      },
+      payload: dto,
+    }
+  }
+
   createAccountDeletionVerificationPassedEvent(dto: {
     userUuid: string
     email: string
@@ -207,7 +227,11 @@ export class DomainEventFactory implements DomainEventFactoryInterface {
     }
   }
 
-  createWebSocketMessageRequestedEvent(dto: { userUuid: string; message: string }): WebSocketMessageRequestedEvent {
+  createWebSocketMessageRequestedEvent(dto: {
+    userUuid: string
+    message: string
+    originatingSessionUuid?: string
+  }): WebSocketMessageRequestedEvent {
     return {
       type: 'WEB_SOCKET_MESSAGE_REQUESTED',
       createdAt: this.timer.getUTCDate(),

+ 11 - 1
packages/syncing-server/src/Domain/Event/DomainEventFactoryInterface.ts

@@ -5,6 +5,7 @@ import {
   ItemDumpedEvent,
   ItemRemovedFromSharedVaultEvent,
   ItemRevisionCreationRequestedEvent,
+  ItemsChangedOnServerEvent,
   MessageSentToUserEvent,
   NotificationAddedForUserEvent,
   RevisionsCopyRequestedEvent,
@@ -17,7 +18,16 @@ import {
 } from '@standardnotes/domain-events'
 
 export interface DomainEventFactoryInterface {
-  createWebSocketMessageRequestedEvent(dto: { userUuid: string; message: string }): WebSocketMessageRequestedEvent
+  createWebSocketMessageRequestedEvent(dto: {
+    userUuid: string
+    message: string
+    originatingSessionUuid?: string
+  }): WebSocketMessageRequestedEvent
+  createItemsChangedOnServerEvent(dto: {
+    userUuid: string
+    sessionUuid: string
+    timestamp: number
+  }): ItemsChangedOnServerEvent
   createUserInvitedToSharedVaultEvent(dto: {
     invite: {
       uuid: string

+ 34 - 1
packages/syncing-server/src/Domain/UseCase/Syncing/SaveItems/SaveItems.spec.ts

@@ -8,6 +8,10 @@ import { Logger } from 'winston'
 import { ContentType, Dates, Result, Timestamps, Uuid } from '@standardnotes/domain-core'
 import { ItemHash } from '../../../Item/ItemHash'
 import { Item } from '../../../Item/Item'
+import { SendEventToClient } from '../SendEventToClient/SendEventToClient'
+import { DomainEventFactoryInterface } from '../../../Event/DomainEventFactoryInterface'
+import { ItemsChangedOnServerEvent } from '@standardnotes/domain-events'
+import { SelectorInterface } from '@standardnotes/security'
 
 describe('SaveItems', () => {
   let itemSaveValidator: ItemSaveValidatorInterface
@@ -18,11 +22,35 @@ describe('SaveItems', () => {
   let logger: Logger
   let itemHash1: ItemHash
   let savedItem: Item
+  let sendEventToClient: SendEventToClient
+  let domainEventFactory: DomainEventFactoryInterface
+  let deterministicSelector: SelectorInterface<number>
 
   const createUseCase = () =>
-    new SaveItems(itemSaveValidator, itemRepository, timer, saveNewItem, updateExistingItem, logger)
+    new SaveItems(
+      itemSaveValidator,
+      itemRepository,
+      timer,
+      saveNewItem,
+      updateExistingItem,
+      sendEventToClient,
+      domainEventFactory,
+      deterministicSelector,
+      logger,
+    )
 
   beforeEach(() => {
+    deterministicSelector = {} as jest.Mocked<SelectorInterface<number>>
+    deterministicSelector.select = jest.fn().mockReturnValue(1)
+
+    sendEventToClient = {} as jest.Mocked<SendEventToClient>
+    sendEventToClient.execute = jest.fn().mockReturnValue(Result.ok())
+
+    domainEventFactory = {} as jest.Mocked<DomainEventFactoryInterface>
+    domainEventFactory.createItemsChangedOnServerEvent = jest
+      .fn()
+      .mockReturnValue({} as jest.Mocked<ItemsChangedOnServerEvent>)
+
     itemSaveValidator = {} as jest.Mocked<ItemSaveValidatorInterface>
     itemSaveValidator.validate = jest.fn().mockResolvedValue({ passed: true })
 
@@ -92,6 +120,7 @@ describe('SaveItems', () => {
       userUuid: 'user-uuid',
       sessionUuid: 'session-uuid',
     })
+    expect(sendEventToClient.execute).toHaveBeenCalled()
   })
 
   it('should mark items as conflicts if saving new item fails', async () => {
@@ -115,6 +144,7 @@ describe('SaveItems', () => {
         type: 'uuid_conflict',
       },
     ])
+    expect(sendEventToClient.execute).not.toHaveBeenCalled()
   })
 
   it('should mark items as conflicts if saving new item throws an error', async () => {
@@ -197,6 +227,8 @@ describe('SaveItems', () => {
   })
 
   it('should update existing items', async () => {
+    deterministicSelector.select = jest.fn().mockReturnValue(0)
+
     const useCase = createUseCase()
 
     itemRepository.findByUuid = jest.fn().mockResolvedValue(savedItem)
@@ -217,6 +249,7 @@ describe('SaveItems', () => {
       sessionUuid: 'session-uuid',
       performingUserUuid: '00000000-0000-0000-0000-000000000000',
     })
+    expect(sendEventToClient.execute).not.toHaveBeenCalled()
   })
 
   it('should mark items as conflicts if updating existing item fails', async () => {

+ 39 - 0
packages/syncing-server/src/Domain/UseCase/Syncing/SaveItems/SaveItems.ts

@@ -11,6 +11,9 @@ import { ItemSaveValidatorInterface } from '../../../Item/SaveValidator/ItemSave
 import { SaveNewItem } from '../SaveNewItem/SaveNewItem'
 import { UpdateExistingItem } from '../UpdateExistingItem/UpdateExistingItem'
 import { ItemRepositoryInterface } from '../../../Item/ItemRepositoryInterface'
+import { SendEventToClient } from '../SendEventToClient/SendEventToClient'
+import { DomainEventFactoryInterface } from '../../../Event/DomainEventFactoryInterface'
+import { SelectorInterface } from '@standardnotes/security'
 
 export class SaveItems implements UseCaseInterface<SaveItemsResult> {
   private readonly SYNC_TOKEN_VERSION = 2
@@ -21,6 +24,9 @@ export class SaveItems implements UseCaseInterface<SaveItemsResult> {
     private timer: TimerInterface,
     private saveNewItem: SaveNewItem,
     private updateExistingItem: UpdateExistingItem,
+    private sendEventToClient: SendEventToClient,
+    private domainEventFactory: DomainEventFactoryInterface,
+    private deterministicSelector: SelectorInterface<number>,
     private logger: Logger,
   ) {}
 
@@ -133,6 +139,8 @@ export class SaveItems implements UseCaseInterface<SaveItemsResult> {
 
     const syncToken = this.calculateSyncToken(lastUpdatedTimestamp, savedItems)
 
+    await this.notifyOtherClientsOfTheUserThatItemsChanged(dto, savedItems, lastUpdatedTimestamp)
+
     return Result.ok({
       savedItems,
       conflicts,
@@ -140,6 +148,37 @@ export class SaveItems implements UseCaseInterface<SaveItemsResult> {
     })
   }
 
+  private async notifyOtherClientsOfTheUserThatItemsChanged(
+    dto: SaveItemsDTO,
+    savedItems: Item[],
+    lastUpdatedTimestamp: number,
+  ): Promise<void> {
+    if (savedItems.length === 0 || !dto.sessionUuid) {
+      return
+    }
+
+    const tenPercentSpreadArray = Array.from(Array(10).keys())
+    const diceRoll = this.deterministicSelector.select(dto.userUuid, tenPercentSpreadArray)
+    if (diceRoll !== 1) {
+      return
+    }
+
+    const itemsChangedEvent = this.domainEventFactory.createItemsChangedOnServerEvent({
+      userUuid: dto.userUuid,
+      sessionUuid: dto.sessionUuid,
+      timestamp: lastUpdatedTimestamp,
+    })
+    const result = await this.sendEventToClient.execute({
+      userUuid: dto.userUuid,
+      originatingSessionUuid: dto.sessionUuid,
+      event: itemsChangedEvent,
+    })
+    /* istanbul ignore next */
+    if (result.isFailed()) {
+      this.logger.error(`[${dto.userUuid}] Sending items changed event to client failed. Error: ${result.getError()}`)
+    }
+  }
+
   private calculateSyncToken(lastUpdatedTimestamp: number, savedItems: Array<Item>): string {
     if (savedItems.length) {
       const sortedItems = savedItems.sort((itemA: Item, itemB: Item) => {

+ 19 - 0
packages/syncing-server/src/Domain/UseCase/Syncing/SendEventToClient/SendEventToClient.spec.ts

@@ -17,6 +17,8 @@ describe('SendEventToClient', () => {
   beforeEach(() => {
     logger = {} as jest.Mocked<Logger>
     logger.info = jest.fn()
+    logger.debug = jest.fn()
+    logger.error = jest.fn()
 
     domainEventFactory = {} as jest.Mocked<DomainEventFactoryInterface>
     domainEventFactory.createWebSocketMessageRequestedEvent = jest
@@ -58,4 +60,21 @@ describe('SendEventToClient', () => {
 
     expect(result.isFailed()).toBe(true)
   })
+
+  it('should return a failed result if error is thrown', async () => {
+    const useCase = createUseCase()
+
+    domainEventFactory.createWebSocketMessageRequestedEvent = jest.fn().mockImplementation(() => {
+      throw new Error('test')
+    })
+
+    const result = await useCase.execute({
+      userUuid: '00000000-0000-0000-0000-000000000000',
+      event: {
+        type: 'test',
+      } as jest.Mocked<DomainEventInterface>,
+    })
+
+    expect(result.isFailed()).toBe(true)
+  })
 })

+ 17 - 12
packages/syncing-server/src/Domain/UseCase/Syncing/SendEventToClient/SendEventToClient.ts

@@ -13,21 +13,26 @@ export class SendEventToClient implements UseCaseInterface<void> {
   ) {}
 
   async execute(dto: SendEventToClientDTO): Promise<Result<void>> {
-    const userUuidOrError = Uuid.create(dto.userUuid)
-    if (userUuidOrError.isFailed()) {
-      return Result.fail(userUuidOrError.getError())
-    }
-    const userUuid = userUuidOrError.getValue()
+    try {
+      const userUuidOrError = Uuid.create(dto.userUuid)
+      if (userUuidOrError.isFailed()) {
+        return Result.fail(userUuidOrError.getError())
+      }
+      const userUuid = userUuidOrError.getValue()
 
-    this.logger.info(`[WebSockets] Requesting message ${dto.event.type} to user ${dto.userUuid}`)
+      this.logger.debug(`[WebSockets] Requesting message ${dto.event.type} to user ${dto.userUuid}`)
 
-    const event = this.domainEventFactory.createWebSocketMessageRequestedEvent({
-      userUuid: userUuid.value,
-      message: JSON.stringify(dto.event),
-    })
+      const event = this.domainEventFactory.createWebSocketMessageRequestedEvent({
+        userUuid: userUuid.value,
+        message: JSON.stringify(dto.event),
+        originatingSessionUuid: dto.originatingSessionUuid,
+      })
 
-    await this.domainEventPublisher.publish(event)
+      await this.domainEventPublisher.publish(event)
 
-    return Result.ok()
+      return Result.ok()
+    } catch (error) {
+      return Result.fail(`Failed to send event to client: ${(error as Error).message}`)
+    }
   }
 }

+ 1 - 0
packages/syncing-server/src/Domain/UseCase/Syncing/SendEventToClient/SendEventToClientDTO.ts

@@ -3,4 +3,5 @@ import { DomainEventInterface } from '@standardnotes/domain-events'
 export interface SendEventToClientDTO {
   userUuid: string
   event: DomainEventInterface
+  originatingSessionUuid?: string
 }

+ 9 - 0
packages/websockets/.env.sample

@@ -8,6 +8,15 @@ AUTH_JWT_SECRET=auth_jwt_secret
 
 REDIS_URL=redis://cache
 
+DB_HOST=127.0.0.1
+DB_REPLICA_HOST=127.0.0.1
+DB_PORT=3306
+DB_USERNAME=websockets
+DB_PASSWORD=changeme123
+DB_DATABASE=websockets
+DB_DEBUG_LEVEL=all # "all" | "query" | "schema" | "error" | "warn" | "info" | "log" | "migration"
+DB_TYPE=mysql
+
 SNS_TOPIC_ARN=
 SNS_AWS_REGION=
 SQS_QUEUE_URL=

+ 1 - 1
packages/websockets/bin/server.ts

@@ -12,7 +12,7 @@ import { ContainerConfigLoader } from '../src/Bootstrap/Container'
 import TYPES from '../src/Bootstrap/Types'
 import { Env } from '../src/Bootstrap/Env'
 
-const container = new ContainerConfigLoader()
+const container = new ContainerConfigLoader('server')
 void container.load().then((container) => {
   const env: Env = new Env()
   env.load()

+ 1 - 1
packages/websockets/bin/worker.ts

@@ -7,7 +7,7 @@ import TYPES from '../src/Bootstrap/Types'
 import { Env } from '../src/Bootstrap/Env'
 import { DomainEventSubscriberInterface } from '@standardnotes/domain-events'
 
-const container = new ContainerConfigLoader()
+const container = new ContainerConfigLoader('worker')
 void container.load().then((container) => {
   const env: Env = new Env()
   env.load()

+ 17 - 0
packages/websockets/migrations/mysql/1701087671322-initial-database.ts

@@ -0,0 +1,17 @@
+import { MigrationInterface, QueryRunner } from 'typeorm'
+
+export class InitialDatabase1701087671322 implements MigrationInterface {
+  name = 'InitialDatabase1701087671322'
+
+  public async up(queryRunner: QueryRunner): Promise<void> {
+    await queryRunner.query(
+      'CREATE TABLE `connections` (`uuid` varchar(36) NOT NULL, `user_uuid` varchar(36) NOT NULL, `session_uuid` varchar(36) NOT NULL, `connection_id` varchar(255) NOT NULL, `created_at_timestamp` bigint NOT NULL, `updated_at_timestamp` bigint NOT NULL, INDEX `index_connections_on_user_uuid` (`user_uuid`), UNIQUE INDEX `index_connections_on_connection_id` (`connection_id`), PRIMARY KEY (`uuid`)) ENGINE=InnoDB',
+    )
+  }
+
+  public async down(queryRunner: QueryRunner): Promise<void> {
+    await queryRunner.query('DROP INDEX `index_connections_on_connection_id` ON `connections`')
+    await queryRunner.query('DROP INDEX `index_connections_on_user_uuid` ON `connections`')
+    await queryRunner.query('DROP TABLE `connections`')
+  }
+}

+ 1 - 1
packages/websockets/package.json

@@ -29,13 +29,13 @@
   "dependencies": {
     "@aws-sdk/client-apigatewaymanagementapi": "^3.427.0",
     "@aws-sdk/client-sqs": "^3.427.0",
-    "@standardnotes/api": "^1.26.26",
     "@standardnotes/common": "workspace:^",
     "@standardnotes/domain-core": "workspace:^",
     "@standardnotes/domain-events": "workspace:^",
     "@standardnotes/domain-events-infra": "workspace:^",
     "@standardnotes/responses": "^1.13.27",
     "@standardnotes/security": "workspace:^",
+    "@standardnotes/time": "workspace:^",
     "cors": "2.8.5",
     "dotenv": "^16.0.1",
     "express": "^4.18.2",

+ 58 - 26
packages/websockets/src/Bootstrap/Container.ts

@@ -1,5 +1,4 @@
 import * as winston from 'winston'
-import Redis from 'ioredis'
 import { SQSClient, SQSClientConfig } from '@aws-sdk/client-sqs'
 import { ApiGatewayManagementApiClient } from '@aws-sdk/client-apigatewaymanagementapi'
 import { Container } from 'inversify'
@@ -8,16 +7,14 @@ import {
   DomainEventMessageHandlerInterface,
   DomainEventSubscriberInterface,
 } from '@standardnotes/domain-events'
+import { TimerInterface, Timer } from '@standardnotes/time'
 import { Env } from './Env'
 import TYPES from './Types'
 import { WebSocketsConnectionRepositoryInterface } from '../Domain/WebSockets/WebSocketsConnectionRepositoryInterface'
-import { RedisWebSocketsConnectionRepository } from '../Infra/Redis/RedisWebSocketsConnectionRepository'
 import { AddWebSocketsConnection } from '../Domain/UseCase/AddWebSocketsConnection/AddWebSocketsConnection'
 import { RemoveWebSocketsConnection } from '../Domain/UseCase/RemoveWebSocketsConnection/RemoveWebSocketsConnection'
-import { WebSocketsClientMessenger } from '../Infra/WebSockets/WebSocketsClientMessenger'
 import { SQSDomainEventSubscriber, SQSEventMessageHandler } from '@standardnotes/domain-events-infra'
-import { ApiGatewayAuthMiddleware } from '../Controller/ApiGatewayAuthMiddleware'
-
+import { ApiGatewayAuthMiddleware } from '../Infra/InversifyExpressUtils/Middleware/ApiGatewayAuthMiddleware'
 import {
   CrossServiceTokenData,
   TokenDecoder,
@@ -27,29 +24,25 @@ import {
   WebSocketConnectionTokenData,
 } from '@standardnotes/security'
 import { CreateWebSocketConnectionToken } from '../Domain/UseCase/CreateWebSocketConnectionToken/CreateWebSocketConnectionToken'
-import { WebSocketsController } from '../Controller/WebSocketsController'
-import { WebSocketServerInterface } from '@standardnotes/api'
-import { ClientMessengerInterface } from '../Client/ClientMessengerInterface'
 import { WebSocketMessageRequestedEventHandler } from '../Domain/Handler/WebSocketMessageRequestedEventHandler'
+import { SQLConnectionRepository } from '../Infra/TypeORM/SQLConnectionRepository'
+import { Connection } from '../Domain/Connection/Connection'
+import { SQLConnection } from '../Infra/TypeORM/SQLConnection'
+import { MapperInterface } from '@standardnotes/domain-core'
+import { Repository } from 'typeorm'
+import { ConnectionPersistenceMapper } from '../Mapping/SQL/ConnectionPersistenceMapper'
+import { AppDataSource } from './DataSource'
+import { SendMessageToClient } from '../Domain/UseCase/SendMessageToClient/SendMessageToClient'
 
 export class ContainerConfigLoader {
+  constructor(private mode: 'server' | 'worker' = 'server') {}
+
   async load(): Promise<Container> {
     const env: Env = new Env()
     env.load()
 
     const container = new Container()
 
-    const redisUrl = env.get('REDIS_URL')
-    const isRedisInClusterMode = redisUrl.indexOf(',') > 0
-    let redis
-    if (isRedisInClusterMode) {
-      redis = new Redis.Cluster(redisUrl.split(','))
-    } else {
-      redis = new Redis(redisUrl)
-    }
-
-    container.bind(TYPES.Redis).toConstantValue(redis)
-
     const winstonFormatters = [winston.format.splat(), winston.format.json()]
 
     const logger = winston.createLogger({
@@ -59,6 +52,13 @@ export class ContainerConfigLoader {
     })
     container.bind<winston.Logger>(TYPES.Logger).toConstantValue(logger)
 
+    const appDataSource = new AppDataSource({ env, runMigrations: this.mode === 'server' })
+    await appDataSource.initialize()
+
+    logger.debug('Database initialized')
+
+    container.bind<TimerInterface>(TYPES.Timer).toConstantValue(new Timer())
+
     if (env.get('SQS_QUEUE_URL', true)) {
       const sqsConfig: SQSClientConfig = {
         region: env.get('SQS_AWS_REGION', true),
@@ -83,14 +83,26 @@ export class ContainerConfigLoader {
         region: env.get('API_GATEWAY_AWS_REGION', true) ?? 'us-east-1',
       }),
     )
+    // Mappers
+    container
+      .bind<MapperInterface<Connection, SQLConnection>>(TYPES.ConnectionPersistenceMapper)
+      .toConstantValue(new ConnectionPersistenceMapper())
 
-    // Controller
-    container.bind<WebSocketServerInterface>(TYPES.WebSocketsController).to(WebSocketsController)
+    // ORM
+    container
+      .bind<Repository<SQLConnection>>(TYPES.ORMConnectionRepository)
+      .toConstantValue(appDataSource.getRepository(SQLConnection))
 
     // Repositories
     container
       .bind<WebSocketsConnectionRepositoryInterface>(TYPES.WebSocketsConnectionRepository)
-      .to(RedisWebSocketsConnectionRepository)
+      .toConstantValue(
+        new SQLConnectionRepository(
+          container.get<Repository<SQLConnection>>(TYPES.ORMConnectionRepository),
+          container.get<MapperInterface<Connection, SQLConnection>>(TYPES.ConnectionPersistenceMapper),
+          container.get<winston.Logger>(TYPES.Logger),
+        ),
+      )
 
     // Middleware
     container.bind<ApiGatewayAuthMiddleware>(TYPES.ApiGatewayAuthMiddleware).to(ApiGatewayAuthMiddleware)
@@ -103,21 +115,42 @@ export class ContainerConfigLoader {
     container
       .bind(TYPES.WEB_SOCKET_CONNECTION_TOKEN_TTL)
       .toConstantValue(+env.get('WEB_SOCKET_CONNECTION_TOKEN_TTL', true))
-    container.bind(TYPES.REDIS_URL).toConstantValue(env.get('REDIS_URL'))
     container.bind(TYPES.SQS_QUEUE_URL).toConstantValue(env.get('SQS_QUEUE_URL'))
     container.bind(TYPES.VERSION).toConstantValue(env.get('VERSION'))
 
     // use cases
-    container.bind<AddWebSocketsConnection>(TYPES.AddWebSocketsConnection).to(AddWebSocketsConnection)
+    container
+      .bind<AddWebSocketsConnection>(TYPES.AddWebSocketsConnection)
+      .toConstantValue(
+        new AddWebSocketsConnection(
+          container.get<WebSocketsConnectionRepositoryInterface>(TYPES.WebSocketsConnectionRepository),
+          container.get<TimerInterface>(TYPES.Timer),
+          container.get<winston.Logger>(TYPES.Logger),
+        ),
+      )
     container.bind<RemoveWebSocketsConnection>(TYPES.RemoveWebSocketsConnection).to(RemoveWebSocketsConnection)
     container
       .bind<CreateWebSocketConnectionToken>(TYPES.CreateWebSocketConnectionToken)
       .to(CreateWebSocketConnectionToken)
+    container
+      .bind<SendMessageToClient>(TYPES.SendMessageToClient)
+      .toConstantValue(
+        new SendMessageToClient(
+          container.get<WebSocketsConnectionRepositoryInterface>(TYPES.WebSocketsConnectionRepository),
+          container.get<ApiGatewayManagementApiClient>(TYPES.WebSockets_ApiGatewayManagementApiClient),
+          container.get<winston.Logger>(TYPES.Logger),
+        ),
+      )
 
     // Handlers
     container
       .bind<WebSocketMessageRequestedEventHandler>(TYPES.WebSocketMessageRequestedEventHandler)
-      .to(WebSocketMessageRequestedEventHandler)
+      .toConstantValue(
+        new WebSocketMessageRequestedEventHandler(
+          container.get<SendMessageToClient>(TYPES.SendMessageToClient),
+          container.get<winston.Logger>(TYPES.Logger),
+        ),
+      )
 
     // Services
     container
@@ -128,7 +161,6 @@ export class ContainerConfigLoader {
       .toConstantValue(
         new TokenEncoder<WebSocketConnectionTokenData>(container.get(TYPES.WEB_SOCKET_CONNECTION_TOKEN_SECRET)),
       )
-    container.bind<ClientMessengerInterface>(TYPES.WebSocketsClientMessenger).to(WebSocketsClientMessenger)
 
     const eventHandlers: Map<string, DomainEventHandlerInterface> = new Map([
       ['WEB_SOCKET_MESSAGE_REQUESTED', container.get(TYPES.WebSocketMessageRequestedEventHandler)],

+ 84 - 0
packages/websockets/src/Bootstrap/DataSource.ts

@@ -0,0 +1,84 @@
+import { DataSource, EntityTarget, LoggerOptions, ObjectLiteral, Repository } from 'typeorm'
+import { MysqlConnectionOptions } from 'typeorm/driver/mysql/MysqlConnectionOptions'
+import { Env } from './Env'
+import { SQLConnection } from '../Infra/TypeORM/SQLConnection'
+
+export class AppDataSource {
+  private _dataSource: DataSource | undefined
+
+  constructor(
+    private configuration: {
+      env: Env
+      runMigrations: boolean
+    },
+  ) {}
+
+  getRepository<Entity extends ObjectLiteral>(target: EntityTarget<Entity>): Repository<Entity> {
+    if (!this._dataSource) {
+      throw new Error('DataSource not initialized')
+    }
+
+    return this._dataSource.getRepository(target)
+  }
+
+  async initialize(): Promise<void> {
+    await this.dataSource.initialize()
+  }
+
+  get dataSource(): DataSource {
+    this.configuration.env.load()
+
+    const maxQueryExecutionTime = this.configuration.env.get('DB_MAX_QUERY_EXECUTION_TIME', true)
+      ? +this.configuration.env.get('DB_MAX_QUERY_EXECUTION_TIME', true)
+      : 45_000
+
+    const commonDataSourceOptions = {
+      maxQueryExecutionTime,
+      entities: [SQLConnection],
+      migrations: [`${__dirname}/../../migrations/mysql/*.js`],
+      migrationsRun: this.configuration.runMigrations,
+      logging: <LoggerOptions>this.configuration.env.get('DB_DEBUG_LEVEL', true) ?? 'info',
+    }
+
+    const inReplicaMode = this.configuration.env.get('DB_REPLICA_HOST', true) ? true : false
+
+    const replicationConfig = {
+      master: {
+        host: this.configuration.env.get('DB_HOST'),
+        port: parseInt(this.configuration.env.get('DB_PORT')),
+        username: this.configuration.env.get('DB_USERNAME'),
+        password: this.configuration.env.get('DB_PASSWORD'),
+        database: this.configuration.env.get('DB_DATABASE'),
+      },
+      slaves: [
+        {
+          host: this.configuration.env.get('DB_REPLICA_HOST', true),
+          port: parseInt(this.configuration.env.get('DB_PORT')),
+          username: this.configuration.env.get('DB_USERNAME'),
+          password: this.configuration.env.get('DB_PASSWORD'),
+          database: this.configuration.env.get('DB_DATABASE'),
+        },
+      ],
+      removeNodeErrorCount: 10,
+      restoreNodeTimeout: 5,
+    }
+
+    const mySQLDataSourceOptions: MysqlConnectionOptions = {
+      ...commonDataSourceOptions,
+      type: 'mysql',
+      charset: 'utf8mb4',
+      supportBigNumbers: true,
+      bigNumberStrings: false,
+      replication: inReplicaMode ? replicationConfig : undefined,
+      host: inReplicaMode ? undefined : this.configuration.env.get('DB_HOST'),
+      port: inReplicaMode ? undefined : parseInt(this.configuration.env.get('DB_PORT')),
+      username: inReplicaMode ? undefined : this.configuration.env.get('DB_USERNAME'),
+      password: inReplicaMode ? undefined : this.configuration.env.get('DB_PASSWORD'),
+      database: inReplicaMode ? undefined : this.configuration.env.get('DB_DATABASE'),
+    }
+
+    this._dataSource = new DataSource(mySQLDataSourceOptions)
+
+    return this._dataSource
+  }
+}

+ 7 - 0
packages/websockets/src/Bootstrap/MigrationsDataSource.ts

@@ -0,0 +1,7 @@
+import { AppDataSource } from './DataSource'
+import { Env } from './Env'
+
+const env: Env = new Env()
+env.load()
+
+export const MigrationsDataSource = new AppDataSource({ env, runMigrations: true }).dataSource

+ 6 - 3
packages/websockets/src/Bootstrap/Types.ts

@@ -1,10 +1,12 @@
 const TYPES = {
   Logger: Symbol.for('Logger'),
-  Redis: Symbol.for('Redis'),
+  Timer: Symbol.for('Timer'),
   SQS: Symbol.for('SQS'),
   WebSockets_ApiGatewayManagementApiClient: Symbol.for('WebSockets_ApiGatewayManagementApiClient'),
-  // Controller
-  WebSocketsController: Symbol.for('WebSocketsController'),
+  // Mappers
+  ConnectionPersistenceMapper: Symbol.for('ConnectionPersistenceMapper'),
+  // ORM
+  ORMConnectionRepository: Symbol.for('ORMConnectionRepository'),
   // Repositories
   WebSocketsConnectionRepository: Symbol.for('WebSocketsConnectionRepository'),
   // Middleware
@@ -22,6 +24,7 @@ const TYPES = {
   AddWebSocketsConnection: Symbol.for('AddWebSocketsConnection'),
   RemoveWebSocketsConnection: Symbol.for('RemoveWebSocketsConnection'),
   CreateWebSocketConnectionToken: Symbol.for('CreateWebSocketConnectionToken'),
+  SendMessageToClient: Symbol.for('SendMessageToClient'),
   // Handlers
   WebSocketMessageRequestedEventHandler: Symbol.for('WebSocketMessageRequestedEventHandler'),
   // Services

+ 0 - 3
packages/websockets/src/Client/ClientMessengerInterface.ts

@@ -1,3 +0,0 @@
-export interface ClientMessengerInterface {
-  send(userUuid: string, message: string): Promise<void>
-}

+ 0 - 99
packages/websockets/src/Controller/ApiGatewayAuthMiddleware.spec.ts

@@ -1,99 +0,0 @@
-import 'reflect-metadata'
-
-import { ApiGatewayAuthMiddleware } from './ApiGatewayAuthMiddleware'
-import { NextFunction, Request, Response } from 'express'
-import { Logger } from 'winston'
-import { CrossServiceTokenData, TokenDecoderInterface } from '@standardnotes/security'
-import { RoleName } from '@standardnotes/domain-core'
-
-describe('ApiGatewayAuthMiddleware', () => {
-  let tokenDecoder: TokenDecoderInterface<CrossServiceTokenData>
-  let request: Request
-  let response: Response
-  let next: NextFunction
-
-  const logger = {
-    debug: jest.fn(),
-  } as unknown as jest.Mocked<Logger>
-
-  const createMiddleware = () => new ApiGatewayAuthMiddleware(tokenDecoder, logger)
-
-  beforeEach(() => {
-    tokenDecoder = {} as jest.Mocked<TokenDecoderInterface<CrossServiceTokenData>>
-    tokenDecoder.decodeToken = jest.fn().mockReturnValue({
-      user: {
-        uuid: '1-2-3',
-        email: 'test@test.te',
-      },
-      roles: [
-        {
-          uuid: 'a-b-c',
-          name: RoleName.NAMES.CoreUser,
-        },
-      ],
-    })
-
-    request = {
-      headers: {},
-    } as jest.Mocked<Request>
-    response = {
-      locals: {},
-    } as jest.Mocked<Response>
-    response.status = jest.fn().mockReturnThis()
-    response.send = jest.fn()
-    next = jest.fn()
-  })
-
-  it('should authorize user', async () => {
-    request.headers['x-auth-token'] = 'auth-jwt-token'
-
-    await createMiddleware().handler(request, response, next)
-
-    expect(response.locals.user).toEqual({
-      uuid: '1-2-3',
-      email: 'test@test.te',
-    })
-    expect(response.locals.roles).toEqual([
-      {
-        uuid: 'a-b-c',
-        name: RoleName.NAMES.CoreUser,
-      },
-    ])
-
-    expect(next).toHaveBeenCalled()
-  })
-
-  it('should not authorize if request is missing auth jwt token in headers', async () => {
-    await createMiddleware().handler(request, response, next)
-
-    expect(response.status).toHaveBeenCalledWith(401)
-    expect(next).not.toHaveBeenCalled()
-  })
-
-  it('should not authorize if auth jwt token is malformed', async () => {
-    request.headers['x-auth-token'] = 'auth-jwt-token'
-
-    tokenDecoder.decodeToken = jest.fn().mockReturnValue(undefined)
-
-    await createMiddleware().handler(request, response, next)
-
-    expect(response.status).toHaveBeenCalledWith(401)
-    expect(next).not.toHaveBeenCalled()
-  })
-
-  it('should pass the error to next middleware if one occurres', async () => {
-    request.headers['x-auth-token'] = 'auth-jwt-token'
-
-    const error = new Error('Ooops')
-
-    tokenDecoder.decodeToken = jest.fn().mockImplementation(() => {
-      throw error
-    })
-
-    await createMiddleware().handler(request, response, next)
-
-    expect(response.status).not.toHaveBeenCalled()
-
-    expect(next).toHaveBeenCalledWith(error)
-  })
-})

+ 0 - 28
packages/websockets/src/Controller/WebSocketsController.spec.ts

@@ -1,28 +0,0 @@
-import 'reflect-metadata'
-
-import { WebSocketsController } from './WebSocketsController'
-import { CreateWebSocketConnectionToken } from '../Domain/UseCase/CreateWebSocketConnectionToken/CreateWebSocketConnectionToken'
-
-describe('WebSocketsController', () => {
-  let createWebSocketConnectionToken: CreateWebSocketConnectionToken
-
-  const createController = () => new WebSocketsController(createWebSocketConnectionToken)
-
-  beforeEach(() => {
-    createWebSocketConnectionToken = {} as jest.Mocked<CreateWebSocketConnectionToken>
-    createWebSocketConnectionToken.execute = jest.fn().mockReturnValue({ token: 'foobar' })
-  })
-
-  it('should create a web sockets connection token', async () => {
-    const response = await createController().createConnectionToken({ userUuid: '1-2-3' })
-
-    expect(response).toEqual({
-      status: 200,
-      data: { token: 'foobar' },
-    })
-
-    expect(createWebSocketConnectionToken.execute).toHaveBeenCalledWith({
-      userUuid: '1-2-3',
-    })
-  })
-})

+ 0 - 29
packages/websockets/src/Controller/WebSocketsController.ts

@@ -1,29 +0,0 @@
-import { HttpStatusCode, HttpResponse } from '@standardnotes/responses'
-import {
-  WebSocketConnectionTokenRequestParams,
-  WebSocketConnectionTokenResponseBody,
-  WebSocketServerInterface,
-} from '@standardnotes/api'
-import { inject, injectable } from 'inversify'
-
-import TYPES from '../Bootstrap/Types'
-import { CreateWebSocketConnectionToken } from '../Domain/UseCase/CreateWebSocketConnectionToken/CreateWebSocketConnectionToken'
-
-@injectable()
-export class WebSocketsController implements WebSocketServerInterface {
-  constructor(
-    @inject(TYPES.CreateWebSocketConnectionToken)
-    private createWebSocketConnectionToken: CreateWebSocketConnectionToken,
-  ) {}
-
-  async createConnectionToken(
-    params: WebSocketConnectionTokenRequestParams,
-  ): Promise<HttpResponse<WebSocketConnectionTokenResponseBody>> {
-    const result = await this.createWebSocketConnectionToken.execute({ userUuid: params.userUuid as string })
-
-    return {
-      status: HttpStatusCode.Success,
-      data: result,
-    }
-  }
-}

+ 13 - 0
packages/websockets/src/Domain/Connection/Connection.ts

@@ -0,0 +1,13 @@
+import { Entity, Result, UniqueEntityId } from '@standardnotes/domain-core'
+
+import { ConnectionProps } from './ConnectionProps'
+
+export class Connection extends Entity<ConnectionProps> {
+  private constructor(props: ConnectionProps, id?: UniqueEntityId) {
+    super(props, id)
+  }
+
+  static create(props: ConnectionProps, id?: UniqueEntityId): Result<Connection> {
+    return Result.ok<Connection>(new Connection(props, id))
+  }
+}

+ 8 - 0
packages/websockets/src/Domain/Connection/ConnectionProps.ts

@@ -0,0 +1,8 @@
+import { Timestamps, Uuid } from '@standardnotes/domain-core'
+
+export interface ConnectionProps {
+  userUuid: Uuid
+  sessionUuid: Uuid
+  connectionId: string
+  timestamps: Timestamps
+}

+ 11 - 9
packages/websockets/src/Domain/Handler/WebSocketMessageRequestedEventHandler.ts

@@ -1,20 +1,22 @@
 import { DomainEventHandlerInterface, WebSocketMessageRequestedEvent } from '@standardnotes/domain-events'
-import { inject, injectable } from 'inversify'
 import { Logger } from 'winston'
+import { SendMessageToClient } from '../UseCase/SendMessageToClient/SendMessageToClient'
 
-import TYPES from '../../Bootstrap/Types'
-import { ClientMessengerInterface } from '../../Client/ClientMessengerInterface'
-
-@injectable()
 export class WebSocketMessageRequestedEventHandler implements DomainEventHandlerInterface {
   constructor(
-    @inject(TYPES.WebSocketsClientMessenger) private webSocketsClientMessenger: ClientMessengerInterface,
-    @inject(TYPES.Logger) private logger: Logger,
+    private sendMessageToClient: SendMessageToClient,
+    private logger: Logger,
   ) {}
 
   async handle(event: WebSocketMessageRequestedEvent): Promise<void> {
-    this.logger.debug(`Sending message to user ${event.payload.userUuid}`)
+    const result = await this.sendMessageToClient.execute({
+      userUuid: event.payload.userUuid,
+      message: event.payload.message,
+      originatingSessionUuid: event.payload.originatingSessionUuid,
+    })
 
-    await this.webSocketsClientMessenger.send(event.payload.userUuid, event.payload.message)
+    if (result.isFailed()) {
+      this.logger.error(`Could not send message to user ${event.payload.userUuid}. Error: ${result.getError()}`)
+    }
   }
 }

+ 36 - 5
packages/websockets/src/Domain/UseCase/AddWebSocketsConnection/AddWebSocketsConnection.spec.ts

@@ -1,14 +1,15 @@
-import 'reflect-metadata'
 import { Logger } from 'winston'
+import { TimerInterface } from '@standardnotes/time'
 import { WebSocketsConnectionRepositoryInterface } from '../../WebSockets/WebSocketsConnectionRepositoryInterface'
 
 import { AddWebSocketsConnection } from './AddWebSocketsConnection'
 
 describe('AddWebSocketsConnection', () => {
   let webSocketsConnectionRepository: WebSocketsConnectionRepositoryInterface
+  let timer: TimerInterface
   let logger: Logger
 
-  const createUseCase = () => new AddWebSocketsConnection(webSocketsConnectionRepository, logger)
+  const createUseCase = () => new AddWebSocketsConnection(webSocketsConnectionRepository, timer, logger)
 
   beforeEach(() => {
     webSocketsConnectionRepository = {} as jest.Mocked<WebSocketsConnectionRepositoryInterface>
@@ -17,12 +18,18 @@ describe('AddWebSocketsConnection', () => {
     logger = {} as jest.Mocked<Logger>
     logger.debug = jest.fn()
     logger.error = jest.fn()
+
+    timer = {} as jest.Mocked<TimerInterface>
+    timer.getTimestampInMicroseconds = jest.fn().mockReturnValue(123)
   })
 
   it('should save a web sockets connection for a user for further communication', async () => {
-    const result = await createUseCase().execute({ userUuid: '1-2-3', connectionId: '2-3-4' })
+    const result = await createUseCase().execute({
+      userUuid: '00000000-0000-0000-0000-000000000000',
+      sessionUuid: '00000000-0000-0000-0000-000000000000',
+      connectionId: '2-3-4',
+    })
 
-    expect(webSocketsConnectionRepository.saveConnection).toHaveBeenCalledWith('1-2-3', '2-3-4')
     expect(result.isFailed()).toBe(false)
   })
 
@@ -31,7 +38,31 @@ describe('AddWebSocketsConnection', () => {
       .fn()
       .mockRejectedValueOnce(new Error('Could not save connection'))
 
-    const result = await createUseCase().execute({ userUuid: '1-2-3', connectionId: '2-3-4' })
+    const result = await createUseCase().execute({
+      userUuid: '00000000-0000-0000-0000-000000000000',
+      sessionUuid: '00000000-0000-0000-0000-000000000000',
+      connectionId: '2-3-4',
+    })
+
+    expect(result.isFailed()).toBe(true)
+  })
+
+  it('should return failure if the user uuid is invalid', async () => {
+    const result = await createUseCase().execute({
+      userUuid: 'invalid',
+      sessionUuid: '00000000-0000-0000-0000-000000000000',
+      connectionId: '2-3-4',
+    })
+
+    expect(result.isFailed()).toBe(true)
+  })
+
+  it('should return error if the session uuid is invalid', async () => {
+    const result = await createUseCase().execute({
+      userUuid: '00000000-0000-0000-0000-000000000000',
+      sessionUuid: 'invalid',
+      connectionId: '2-3-4',
+    })
 
     expect(result.isFailed()).toBe(true)
   })

+ 33 - 7
packages/websockets/src/Domain/UseCase/AddWebSocketsConnection/AddWebSocketsConnection.ts

@@ -1,24 +1,50 @@
-import { inject, injectable } from 'inversify'
 import { Logger } from 'winston'
-import { Result, UseCaseInterface } from '@standardnotes/domain-core'
+import { Result, Timestamps, UseCaseInterface, Uuid } from '@standardnotes/domain-core'
+import { TimerInterface } from '@standardnotes/time'
 
-import TYPES from '../../../Bootstrap/Types'
 import { WebSocketsConnectionRepositoryInterface } from '../../WebSockets/WebSocketsConnectionRepositoryInterface'
 import { AddWebSocketsConnectionDTO } from './AddWebSocketsConnectionDTO'
+import { Connection } from '../../Connection/Connection'
 
-@injectable()
 export class AddWebSocketsConnection implements UseCaseInterface<void> {
   constructor(
-    @inject(TYPES.WebSocketsConnectionRepository)
     private webSocketsConnectionRepository: WebSocketsConnectionRepositoryInterface,
-    @inject(TYPES.Logger) private logger: Logger,
+    private timer: TimerInterface,
+    private logger: Logger,
   ) {}
 
   async execute(dto: AddWebSocketsConnectionDTO): Promise<Result<void>> {
     try {
       this.logger.debug(`Persisting connection ${dto.connectionId} for user ${dto.userUuid}`)
 
-      await this.webSocketsConnectionRepository.saveConnection(dto.userUuid, dto.connectionId)
+      const userUuidOrError = Uuid.create(dto.userUuid)
+      if (userUuidOrError.isFailed()) {
+        return Result.fail(userUuidOrError.getError())
+      }
+      const userUuid = userUuidOrError.getValue()
+
+      const sessionUuidOrError = Uuid.create(dto.sessionUuid)
+      if (sessionUuidOrError.isFailed()) {
+        return Result.fail(sessionUuidOrError.getError())
+      }
+      const sessionUuid = sessionUuidOrError.getValue()
+
+      const connectionOrError = Connection.create({
+        userUuid,
+        sessionUuid,
+        connectionId: dto.connectionId,
+        timestamps: Timestamps.create(
+          this.timer.getTimestampInMicroseconds(),
+          this.timer.getTimestampInMicroseconds(),
+        ).getValue(),
+      })
+      /* istanbul ignore next */
+      if (connectionOrError.isFailed()) {
+        return Result.fail(connectionOrError.getError())
+      }
+      const connection = connectionOrError.getValue()
+
+      await this.webSocketsConnectionRepository.saveConnection(connection)
 
       return Result.ok()
     } catch (error) {

+ 1 - 0
packages/websockets/src/Domain/UseCase/AddWebSocketsConnection/AddWebSocketsConnectionDTO.ts

@@ -1,4 +1,5 @@
 export type AddWebSocketsConnectionDTO = {
   userUuid: string
+  sessionUuid: string
   connectionId: string
 }

+ 99 - 0
packages/websockets/src/Domain/UseCase/SendMessageToClient/SendMessageToClient.spec.ts

@@ -0,0 +1,99 @@
+import { ApiGatewayManagementApiClient } from '@aws-sdk/client-apigatewaymanagementapi'
+import { WebSocketsConnectionRepositoryInterface } from '../../WebSockets/WebSocketsConnectionRepositoryInterface'
+import { SendMessageToClient } from './SendMessageToClient'
+import { Logger } from 'winston'
+import { Connection } from '../../Connection/Connection'
+import { Timestamps, Uuid } from '@standardnotes/domain-core'
+
+describe('SendMessageToClient', () => {
+  let webSocketsConnectionRepository: WebSocketsConnectionRepositoryInterface
+  let apiGatewayManagementClient: ApiGatewayManagementApiClient
+  let logger: Logger
+
+  const createUseCase = () =>
+    new SendMessageToClient(webSocketsConnectionRepository, apiGatewayManagementClient, logger)
+
+  beforeEach(() => {
+    const connection = Connection.create({
+      userUuid: Uuid.create('00000000-0000-0000-0000-000000000000').getValue(),
+      connectionId: 'connection-id',
+      sessionUuid: Uuid.create('00000000-0000-0000-0000-000000000000').getValue(),
+      timestamps: Timestamps.create(123, 123).getValue(),
+    }).getValue()
+
+    webSocketsConnectionRepository = {} as jest.Mocked<WebSocketsConnectionRepositoryInterface>
+    webSocketsConnectionRepository.findAllByUserUuid = jest.fn().mockResolvedValue([connection])
+
+    apiGatewayManagementClient = {} as jest.Mocked<ApiGatewayManagementApiClient>
+    apiGatewayManagementClient.send = jest.fn().mockResolvedValue({ $metadata: { httpStatusCode: 200 } })
+
+    logger = {} as jest.Mocked<Logger>
+    logger.debug = jest.fn()
+    logger.error = jest.fn()
+  })
+
+  it('sends message to all connections for a user', async () => {
+    const useCase = createUseCase()
+
+    const result = await useCase.execute({
+      userUuid: '00000000-0000-0000-0000-000000000000',
+      message: 'message',
+    })
+
+    expect(result.isFailed()).toBe(false)
+    expect(apiGatewayManagementClient.send).toHaveBeenCalledTimes(1)
+  })
+
+  it('does not send message to originating session', async () => {
+    const useCase = createUseCase()
+
+    const result = await useCase.execute({
+      userUuid: '00000000-0000-0000-0000-000000000000',
+      message: 'message',
+      originatingSessionUuid: '00000000-0000-0000-0000-000000000000',
+    })
+
+    expect(result.isFailed()).toBe(false)
+    expect(apiGatewayManagementClient.send).toHaveBeenCalledTimes(0)
+  })
+
+  it('returns error if sending message fails', async () => {
+    apiGatewayManagementClient.send = jest.fn().mockRejectedValue(new Error('error'))
+
+    const useCase = createUseCase()
+
+    const result = await useCase.execute({
+      userUuid: '00000000-0000-0000-0000-000000000000',
+      message: 'message',
+    })
+
+    expect(result.isFailed()).toBe(true)
+    expect(result.getError()).toBe(
+      'Could not send message to connection connection-id for user 00000000-0000-0000-0000-000000000000. Error: error',
+    )
+  })
+
+  it('returns error if the user uuid is invalid', async () => {
+    const useCase = createUseCase()
+
+    const result = await useCase.execute({
+      userUuid: 'invalid',
+      message: 'message',
+    })
+
+    expect(result.isFailed()).toBe(true)
+  })
+
+  it('return error if sending the message does not return a 200 status code', async () => {
+    apiGatewayManagementClient.send = jest.fn().mockResolvedValue({ $metadata: { httpStatusCode: 500 } })
+
+    const useCase = createUseCase()
+
+    const result = await useCase.execute({
+      userUuid: '00000000-0000-0000-0000-000000000000',
+      message: 'message',
+    })
+
+    expect(result.isFailed()).toBe(true)
+  })
+})

+ 57 - 0
packages/websockets/src/Domain/UseCase/SendMessageToClient/SendMessageToClient.ts

@@ -0,0 +1,57 @@
+import { Result, UseCaseInterface, Uuid } from '@standardnotes/domain-core'
+import { ApiGatewayManagementApiClient, PostToConnectionCommand } from '@aws-sdk/client-apigatewaymanagementapi'
+import { Logger } from 'winston'
+
+import { SendMessageToClientDTO } from './SendMessageToClientDTO'
+import { WebSocketsConnectionRepositoryInterface } from '../../WebSockets/WebSocketsConnectionRepositoryInterface'
+
+export class SendMessageToClient implements UseCaseInterface<void> {
+  constructor(
+    private webSocketsConnectionRepository: WebSocketsConnectionRepositoryInterface,
+    private apiGatewayManagementClient: ApiGatewayManagementApiClient,
+    private logger: Logger,
+  ) {}
+
+  async execute(dto: SendMessageToClientDTO): Promise<Result<void>> {
+    const userUuidOrError = Uuid.create(dto.userUuid)
+    if (userUuidOrError.isFailed()) {
+      return Result.fail(userUuidOrError.getError())
+    }
+    const userUuid = userUuidOrError.getValue()
+
+    const userConnections = await this.webSocketsConnectionRepository.findAllByUserUuid(userUuid)
+
+    for (const connection of userConnections) {
+      if (dto.originatingSessionUuid && connection.props.sessionUuid.value === dto.originatingSessionUuid) {
+        continue
+      }
+
+      this.logger.debug(`Sending message to connection ${connection.props.connectionId} for user ${userUuid.value}`)
+
+      const requestParams = {
+        ConnectionId: connection.props.connectionId,
+        Data: dto.message,
+      }
+
+      const command = new PostToConnectionCommand(requestParams)
+
+      try {
+        const response = await this.apiGatewayManagementClient.send(command)
+
+        if (response.$metadata.httpStatusCode !== 200) {
+          return Result.fail(
+            `Could not send message to connection ${connection.props.connectionId} for user ${userUuid.value}. Response status code: ${response.$metadata.httpStatusCode}`,
+          )
+        }
+      } catch (error) {
+        return Result.fail(
+          `Could not send message to connection ${connection.props.connectionId} for user ${userUuid.value}. Error: ${
+            (error as Error).message
+          }`,
+        )
+      }
+    }
+
+    return Result.ok()
+  }
+}

+ 5 - 0
packages/websockets/src/Domain/UseCase/SendMessageToClient/SendMessageToClientDTO.ts

@@ -0,0 +1,5 @@
+export interface SendMessageToClientDTO {
+  userUuid: string
+  message: string
+  originatingSessionUuid?: string
+}

+ 5 - 2
packages/websockets/src/Domain/WebSockets/WebSocketsConnectionRepositoryInterface.ts

@@ -1,5 +1,8 @@
+import { Uuid } from '@standardnotes/domain-core'
+import { Connection } from '../Connection/Connection'
+
 export interface WebSocketsConnectionRepositoryInterface {
-  findAllByUserUuid(userUuid: string): Promise<string[]>
-  saveConnection(userUuid: string, connectionId: string): Promise<void>
+  findAllByUserUuid(userUuid: Uuid): Promise<Connection[]>
+  saveConnection(connection: Connection): Promise<void>
   removeConnection(connectionId: string): Promise<void>
 }

+ 6 - 4
packages/websockets/src/Infra/InversifyExpressUtils/AnnotatedWebSocketsController.ts

@@ -1,4 +1,3 @@
-import { WebSocketServerInterface } from '@standardnotes/api'
 import { Request, Response } from 'express'
 import { inject } from 'inversify'
 import {
@@ -12,24 +11,26 @@ import {
 import TYPES from '../../Bootstrap/Types'
 import { AddWebSocketsConnection } from '../../Domain/UseCase/AddWebSocketsConnection/AddWebSocketsConnection'
 import { RemoveWebSocketsConnection } from '../../Domain/UseCase/RemoveWebSocketsConnection/RemoveWebSocketsConnection'
+import { CreateWebSocketConnectionToken } from '../../Domain/UseCase/CreateWebSocketConnectionToken/CreateWebSocketConnectionToken'
 
 @controller('/sockets')
 export class AnnotatedWebSocketsController extends BaseHttpController {
   constructor(
     @inject(TYPES.AddWebSocketsConnection) private addWebSocketsConnection: AddWebSocketsConnection,
     @inject(TYPES.RemoveWebSocketsConnection) private removeWebSocketsConnection: RemoveWebSocketsConnection,
-    @inject(TYPES.WebSocketsController) private webSocketsController: WebSocketServerInterface,
+    @inject(TYPES.CreateWebSocketConnectionToken)
+    private createWebSocketConnectionToken: CreateWebSocketConnectionToken,
   ) {
     super()
   }
 
   @httpPost('/tokens', TYPES.ApiGatewayAuthMiddleware)
   async createConnectionToken(_request: Request, response: Response): Promise<results.JsonResult> {
-    const result = await this.webSocketsController.createConnectionToken({
+    const result = await this.createWebSocketConnectionToken.execute({
       userUuid: response.locals.user.uuid,
     })
 
-    return this.json(result.data, result.status)
+    return this.json(result)
   }
 
   @httpPost('/connections/:connectionId', TYPES.ApiGatewayAuthMiddleware)
@@ -39,6 +40,7 @@ export class AnnotatedWebSocketsController extends BaseHttpController {
   ): Promise<results.OkResult | results.BadRequestResult> {
     const result = await this.addWebSocketsConnection.execute({
       userUuid: response.locals.user.uuid,
+      sessionUuid: response.locals.session.uuid,
       connectionId: request.params.connectionId,
     })
 

+ 2 - 2
packages/websockets/src/Controller/ApiGatewayAuthMiddleware.ts → packages/websockets/src/Infra/InversifyExpressUtils/Middleware/ApiGatewayAuthMiddleware.ts

@@ -3,7 +3,7 @@ import { NextFunction, Request, Response } from 'express'
 import { inject, injectable } from 'inversify'
 import { BaseMiddleware } from 'inversify-express-utils'
 import { Logger } from 'winston'
-import TYPES from '../Bootstrap/Types'
+import TYPES from '../../../Bootstrap/Types'
 
 @injectable()
 export class ApiGatewayAuthMiddleware extends BaseMiddleware {
@@ -33,7 +33,7 @@ export class ApiGatewayAuthMiddleware extends BaseMiddleware {
         request.headers['x-auth-token'] as string,
       )
 
-      if (token === undefined) {
+      if (token === undefined || token.session === undefined) {
         this.logger.debug('ApiGatewayAuthMiddleware authentication failure.')
 
         response.status(401).send({

+ 0 - 44
packages/websockets/src/Infra/Redis/RedisWebSocketsConnectionRepository.spec.ts

@@ -1,44 +0,0 @@
-import 'reflect-metadata'
-
-import * as IORedis from 'ioredis'
-
-import { RedisWebSocketsConnectionRepository } from './RedisWebSocketsConnectionRepository'
-
-describe('RedisWebSocketsConnectionRepository', () => {
-  let redisClient: IORedis.Redis
-
-  const createRepository = () => new RedisWebSocketsConnectionRepository(redisClient)
-
-  beforeEach(() => {
-    redisClient = {} as jest.Mocked<IORedis.Redis>
-    redisClient.sadd = jest.fn()
-    redisClient.set = jest.fn()
-    redisClient.get = jest.fn()
-    redisClient.srem = jest.fn()
-    redisClient.del = jest.fn()
-    redisClient.smembers = jest.fn()
-  })
-
-  it('should save a connection to set of user connections', async () => {
-    await createRepository().saveConnection('1-2-3', '2-3-4')
-
-    expect(redisClient.sadd).toHaveBeenCalledWith('ws_user_connections:1-2-3', '2-3-4')
-    expect(redisClient.set).toHaveBeenCalledWith('ws_connection:2-3-4', '1-2-3')
-  })
-
-  it('should remove a connection from the set of user connections', async () => {
-    redisClient.get = jest.fn().mockReturnValue('1-2-3')
-
-    await createRepository().removeConnection('2-3-4')
-
-    expect(redisClient.srem).toHaveBeenCalledWith('ws_user_connections:1-2-3', '2-3-4')
-    expect(redisClient.del).toHaveBeenCalledWith('ws_connection:2-3-4')
-  })
-
-  it('should return all connections for a user uuid', async () => {
-    const userUuid = '1-2-3'
-
-    await createRepository().findAllByUserUuid(userUuid)
-    expect(redisClient.smembers).toHaveBeenCalledWith(`ws_user_connections:${userUuid}`)
-  })
-})

+ 0 - 28
packages/websockets/src/Infra/Redis/RedisWebSocketsConnectionRepository.ts

@@ -1,28 +0,0 @@
-import * as IORedis from 'ioredis'
-import { inject, injectable } from 'inversify'
-import TYPES from '../../Bootstrap/Types'
-import { WebSocketsConnectionRepositoryInterface } from '../../Domain/WebSockets/WebSocketsConnectionRepositoryInterface'
-
-@injectable()
-export class RedisWebSocketsConnectionRepository implements WebSocketsConnectionRepositoryInterface {
-  private readonly WEB_SOCKETS_USER_CONNECTIONS_PREFIX = 'ws_user_connections'
-  private readonly WEB_SOCKETS_CONNETION_PREFIX = 'ws_connection'
-
-  constructor(@inject(TYPES.Redis) private redisClient: IORedis.Redis) {}
-
-  async findAllByUserUuid(userUuid: string): Promise<string[]> {
-    return await this.redisClient.smembers(`${this.WEB_SOCKETS_USER_CONNECTIONS_PREFIX}:${userUuid}`)
-  }
-
-  async removeConnection(connectionId: string): Promise<void> {
-    const userUuid = await this.redisClient.get(`${this.WEB_SOCKETS_CONNETION_PREFIX}:${connectionId}`)
-
-    await this.redisClient.srem(`${this.WEB_SOCKETS_USER_CONNECTIONS_PREFIX}:${userUuid}`, connectionId)
-    await this.redisClient.del(`${this.WEB_SOCKETS_CONNETION_PREFIX}:${connectionId}`)
-  }
-
-  async saveConnection(userUuid: string, connectionId: string): Promise<void> {
-    await this.redisClient.set(`${this.WEB_SOCKETS_CONNETION_PREFIX}:${connectionId}`, userUuid)
-    await this.redisClient.sadd(`${this.WEB_SOCKETS_USER_CONNECTIONS_PREFIX}:${userUuid}`, connectionId)
-  }
-}

+ 42 - 0
packages/websockets/src/Infra/TypeORM/SQLConnection.ts

@@ -0,0 +1,42 @@
+import { Column, Entity, Index, PrimaryGeneratedColumn } from 'typeorm'
+
+@Entity({ name: 'connections' })
+export class SQLConnection {
+  @PrimaryGeneratedColumn('uuid')
+  declare uuid: string
+
+  @Column({
+    name: 'user_uuid',
+    type: 'varchar',
+    length: 36,
+  })
+  @Index('index_connections_on_user_uuid')
+  declare userUuid: string
+
+  @Column({
+    name: 'session_uuid',
+    type: 'varchar',
+    length: 36,
+  })
+  declare sessionUuid: string
+
+  @Column({
+    name: 'connection_id',
+    type: 'varchar',
+    length: 255,
+  })
+  @Index('index_connections_on_connection_id', { unique: true })
+  declare connectionId: string
+
+  @Column({
+    name: 'created_at_timestamp',
+    type: 'bigint',
+  })
+  declare createdAtTimestamp: number
+
+  @Column({
+    name: 'updated_at_timestamp',
+    type: 'bigint',
+  })
+  declare updatedAtTimestamp: number
+}

+ 38 - 0
packages/websockets/src/Infra/TypeORM/SQLConnectionRepository.ts

@@ -0,0 +1,38 @@
+import { Repository } from 'typeorm'
+import { WebSocketsConnectionRepositoryInterface } from '../../Domain/WebSockets/WebSocketsConnectionRepositoryInterface'
+import { SQLConnection } from './SQLConnection'
+import { MapperInterface, Uuid } from '@standardnotes/domain-core'
+import { Connection } from '../../Domain/Connection/Connection'
+import { Logger } from 'winston'
+
+export class SQLConnectionRepository implements WebSocketsConnectionRepositoryInterface {
+  constructor(
+    protected ormRepository: Repository<SQLConnection>,
+    protected mapper: MapperInterface<Connection, SQLConnection>,
+    protected logger: Logger,
+  ) {}
+
+  async findAllByUserUuid(userUuid: Uuid): Promise<Connection[]> {
+    const persistence = await this.ormRepository
+      .createQueryBuilder()
+      .where('user_uuid = :userUuid', { userUuid: userUuid.value })
+      .getMany()
+
+    return persistence.map((p) => this.mapper.toDomain(p))
+  }
+
+  async saveConnection(connection: Connection): Promise<void> {
+    const persistence = this.mapper.toProjection(connection)
+
+    await this.ormRepository.save(persistence)
+  }
+
+  async removeConnection(connectionId: string): Promise<void> {
+    await this.ormRepository
+      .createQueryBuilder()
+      .delete()
+      .from(SQLConnection)
+      .where('connection_id = :connectionId', { connectionId })
+      .execute()
+  }
+}

+ 0 - 49
packages/websockets/src/Infra/WebSockets/WebSocketsClientMessenger.ts

@@ -1,49 +0,0 @@
-import { ApiGatewayManagementApiClient, PostToConnectionCommand } from '@aws-sdk/client-apigatewaymanagementapi'
-import { Logger } from 'winston'
-import { inject, injectable } from 'inversify'
-
-import TYPES from '../../Bootstrap/Types'
-import { WebSocketsConnectionRepositoryInterface } from '../../Domain/WebSockets/WebSocketsConnectionRepositoryInterface'
-import { ClientMessengerInterface } from '../../Client/ClientMessengerInterface'
-
-@injectable()
-export class WebSocketsClientMessenger implements ClientMessengerInterface {
-  constructor(
-    @inject(TYPES.WebSocketsConnectionRepository)
-    private webSocketsConnectionRepository: WebSocketsConnectionRepositoryInterface,
-    @inject(TYPES.WebSockets_ApiGatewayManagementApiClient)
-    private apiGatewayManagementClient: ApiGatewayManagementApiClient,
-    @inject(TYPES.Logger) private logger: Logger,
-  ) {}
-
-  async send(userUuid: string, message: string): Promise<void> {
-    const userConnections = await this.webSocketsConnectionRepository.findAllByUserUuid(userUuid)
-
-    for (const connectionUuid of userConnections) {
-      this.logger.debug(`Sending message to connection ${connectionUuid} for user ${userUuid}`)
-
-      const requestParams = {
-        ConnectionId: connectionUuid,
-        Data: message,
-      }
-
-      const command = new PostToConnectionCommand(requestParams)
-
-      try {
-        const response = await this.apiGatewayManagementClient.send(command)
-
-        if (response.$metadata.httpStatusCode !== 200) {
-          this.logger.error(
-            `Could not send message to connection ${connectionUuid} for user ${userUuid}. Response status code: ${response.$metadata.httpStatusCode}`,
-          )
-        }
-      } catch (error) {
-        this.logger.error(
-          `Could not send message to connection ${connectionUuid} for user ${userUuid}. Error: ${
-            (error as Error).message
-          }`,
-        )
-      }
-    }
-  }
-}

+ 0 - 54
packages/websockets/src/Infra/WebSockets/WebSocketsClientService.spec.ts

@@ -1,54 +0,0 @@
-import 'reflect-metadata'
-
-import { ApiGatewayManagementApiClient } from '@aws-sdk/client-apigatewaymanagementapi'
-
-import { WebSocketsConnectionRepositoryInterface } from '../../Domain/WebSockets/WebSocketsConnectionRepositoryInterface'
-import { Logger } from 'winston'
-
-import { WebSocketsClientMessenger } from './WebSocketsClientMessenger'
-
-describe('WebSocketsClientMessenger', () => {
-  let connectionIds: string[]
-  let webSocketsConnectionRepository: WebSocketsConnectionRepositoryInterface
-  let apiGatewayManagementClient: ApiGatewayManagementApiClient
-  let logger: Logger
-
-  const createService = () =>
-    new WebSocketsClientMessenger(webSocketsConnectionRepository, apiGatewayManagementClient, logger)
-
-  beforeEach(() => {
-    connectionIds = ['1', '2']
-
-    webSocketsConnectionRepository = {} as jest.Mocked<WebSocketsConnectionRepositoryInterface>
-    webSocketsConnectionRepository.findAllByUserUuid = jest.fn().mockReturnValue(connectionIds)
-
-    apiGatewayManagementClient = {} as jest.Mocked<ApiGatewayManagementApiClient>
-    apiGatewayManagementClient.send = jest.fn().mockReturnValue({ $metadata: { httpStatusCode: 200 } })
-
-    logger = {} as jest.Mocked<Logger>
-    logger.debug = jest.fn()
-    logger.error = jest.fn()
-  })
-
-  it('should send a message to all user connections', async () => {
-    await createService().send('1-2-3', 'message')
-
-    expect(apiGatewayManagementClient.send).toHaveBeenCalledTimes(connectionIds.length)
-  })
-
-  it('should log an error if message could not be sent', async () => {
-    apiGatewayManagementClient.send = jest.fn().mockReturnValue({ $metadata: { httpStatusCode: 500 } })
-
-    await createService().send('1-2-3', 'message')
-
-    expect(logger.error).toHaveBeenCalledTimes(connectionIds.length)
-  })
-
-  it('should log an error if message sending throws error', async () => {
-    apiGatewayManagementClient.send = jest.fn().mockRejectedValue(new Error('error'))
-
-    await createService().send('1-2-3', 'message')
-
-    expect(logger.error).toHaveBeenCalledTimes(connectionIds.length)
-  })
-})

+ 50 - 0
packages/websockets/src/Mapping/SQL/ConnectionPersistenceMapper.ts

@@ -0,0 +1,50 @@
+import { MapperInterface, Timestamps, Uuid } from '@standardnotes/domain-core'
+
+import { Connection } from '../../Domain/Connection/Connection'
+import { SQLConnection } from '../../Infra/TypeORM/SQLConnection'
+
+export class ConnectionPersistenceMapper implements MapperInterface<Connection, SQLConnection> {
+  toDomain(projection: SQLConnection): Connection {
+    const userUuidOrError = Uuid.create(projection.userUuid)
+    if (userUuidOrError.isFailed()) {
+      throw new Error(`Failed to create connection from projection: ${userUuidOrError.getError()}`)
+    }
+    const userUuid = userUuidOrError.getValue()
+
+    const sessionUuidOrError = Uuid.create(projection.sessionUuid)
+    if (sessionUuidOrError.isFailed()) {
+      throw new Error(`Failed to create connection from projection: ${sessionUuidOrError.getError()}`)
+    }
+    const sessionUuid = sessionUuidOrError.getValue()
+
+    const timestampsOrError = Timestamps.create(projection.createdAtTimestamp, projection.updatedAtTimestamp)
+    if (timestampsOrError.isFailed()) {
+      throw new Error(`Failed to create connection from projection: ${timestampsOrError.getError()}`)
+    }
+    const timestamps = timestampsOrError.getValue()
+
+    const connectionOrError = Connection.create({
+      userUuid,
+      sessionUuid,
+      connectionId: projection.connectionId,
+      timestamps,
+    })
+    if (connectionOrError.isFailed()) {
+      throw new Error(`Failed to create connection from projection: ${connectionOrError.getError()}`)
+    }
+
+    return connectionOrError.getValue()
+  }
+
+  toProjection(domain: Connection): SQLConnection {
+    const projection = new SQLConnection()
+
+    projection.userUuid = domain.props.userUuid.value
+    projection.sessionUuid = domain.props.sessionUuid.value
+    projection.connectionId = domain.props.connectionId
+    projection.createdAtTimestamp = domain.props.timestamps.createdAt
+    projection.updatedAtTimestamp = domain.props.timestamps.updatedAt
+
+    return projection
+  }
+}

+ 1 - 1
yarn.lock

@@ -5887,13 +5887,13 @@ __metadata:
   dependencies:
     "@aws-sdk/client-apigatewaymanagementapi": "npm:^3.427.0"
     "@aws-sdk/client-sqs": "npm:^3.427.0"
-    "@standardnotes/api": "npm:^1.26.26"
     "@standardnotes/common": "workspace:^"
     "@standardnotes/domain-core": "workspace:^"
     "@standardnotes/domain-events": "workspace:^"
     "@standardnotes/domain-events-infra": "workspace:^"
     "@standardnotes/responses": "npm:^1.13.27"
     "@standardnotes/security": "workspace:^"
+    "@standardnotes/time": "workspace:^"
     "@types/cors": "npm:^2.8.9"
     "@types/express": "npm:^4.17.14"
     "@types/ioredis": "npm:^5.0.0"