From 69b404f5d45f32530ebadbdbbec01d4e335dbbe9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karol=20S=C3=B3jko?= Date: Tue, 28 Nov 2023 09:31:42 +0100 Subject: [PATCH] feat: send event to client upon items change on server (#941) * feat(websockets): persist connections in mysql * fix: add sending event to client upon items changed on server * fix payload * fix: add cathcing errors * fix: send changed items event only on a 10% dice roll --- .pnp.cjs | 2 +- .../Domain/Event/ItemsChangedOnServerEvent.ts | 7 ++ .../Event/ItemsChangedOnServerEventPayload.ts | 5 + .../WebSocketMessageRequestedEventPayload.ts | 1 + packages/domain-events/src/Domain/index.ts | 2 + .../syncing-server/src/Bootstrap/Container.ts | 27 +++-- .../syncing-server/src/Bootstrap/Types.ts | 1 + .../src/Domain/Event/DomainEventFactory.ts | 26 ++++- .../Event/DomainEventFactoryInterface.ts | 12 ++- .../Syncing/SaveItems/SaveItems.spec.ts | 35 ++++++- .../UseCase/Syncing/SaveItems/SaveItems.ts | 39 ++++++++ .../SendEventToClient.spec.ts | 19 ++++ .../SendEventToClient/SendEventToClient.ts | 35 ++++--- .../SendEventToClient/SendEventToClientDTO.ts | 1 + packages/websockets/.env.sample | 9 ++ packages/websockets/bin/server.ts | 2 +- packages/websockets/bin/worker.ts | 2 +- .../mysql/1701087671322-initial-database.ts | 17 ++++ packages/websockets/package.json | 2 +- .../websockets/src/Bootstrap/Container.ts | 84 +++++++++++----- .../websockets/src/Bootstrap/DataSource.ts | 84 ++++++++++++++++ .../src/Bootstrap/MigrationsDataSource.ts | 7 ++ packages/websockets/src/Bootstrap/Types.ts | 9 +- .../src/Client/ClientMessengerInterface.ts | 3 - .../ApiGatewayAuthMiddleware.spec.ts | 99 ------------------- .../Controller/WebSocketsController.spec.ts | 28 ------ .../src/Controller/WebSocketsController.ts | 29 ------ .../src/Domain/Connection/Connection.ts | 13 +++ .../src/Domain/Connection/ConnectionProps.ts | 8 ++ .../WebSocketMessageRequestedEventHandler.ts | 20 ++-- .../AddWebSocketsConnection.spec.ts | 41 +++++++- .../AddWebSocketsConnection.ts | 40 ++++++-- .../AddWebSocketsConnectionDTO.ts | 1 + .../SendMessageToClient.spec.ts | 99 +++++++++++++++++++ .../SendMessageToClient.ts | 57 +++++++++++ .../SendMessageToClientDTO.ts | 5 + ...WebSocketsConnectionRepositoryInterface.ts | 7 +- .../AnnotatedWebSocketsController.ts | 10 +- .../Middleware}/ApiGatewayAuthMiddleware.ts | 4 +- ...edisWebSocketsConnectionRepository.spec.ts | 44 --------- .../RedisWebSocketsConnectionRepository.ts | 28 ------ .../src/Infra/TypeORM/SQLConnection.ts | 42 ++++++++ .../Infra/TypeORM/SQLConnectionRepository.ts | 38 +++++++ .../WebSockets/WebSocketsClientMessenger.ts | 49 --------- .../WebSocketsClientService.spec.ts | 54 ---------- .../SQL/ConnectionPersistenceMapper.ts | 50 ++++++++++ yarn.lock | 2 +- 47 files changed, 777 insertions(+), 422 deletions(-) create mode 100644 packages/domain-events/src/Domain/Event/ItemsChangedOnServerEvent.ts create mode 100644 packages/domain-events/src/Domain/Event/ItemsChangedOnServerEventPayload.ts create mode 100644 packages/websockets/migrations/mysql/1701087671322-initial-database.ts create mode 100644 packages/websockets/src/Bootstrap/DataSource.ts create mode 100644 packages/websockets/src/Bootstrap/MigrationsDataSource.ts delete mode 100644 packages/websockets/src/Client/ClientMessengerInterface.ts delete mode 100644 packages/websockets/src/Controller/ApiGatewayAuthMiddleware.spec.ts delete mode 100644 packages/websockets/src/Controller/WebSocketsController.spec.ts delete mode 100644 packages/websockets/src/Controller/WebSocketsController.ts create mode 100644 packages/websockets/src/Domain/Connection/Connection.ts create mode 100644 packages/websockets/src/Domain/Connection/ConnectionProps.ts create mode 100644 packages/websockets/src/Domain/UseCase/SendMessageToClient/SendMessageToClient.spec.ts create mode 100644 packages/websockets/src/Domain/UseCase/SendMessageToClient/SendMessageToClient.ts create mode 100644 packages/websockets/src/Domain/UseCase/SendMessageToClient/SendMessageToClientDTO.ts rename packages/websockets/src/{Controller => Infra/InversifyExpressUtils/Middleware}/ApiGatewayAuthMiddleware.ts (93%) delete mode 100644 packages/websockets/src/Infra/Redis/RedisWebSocketsConnectionRepository.spec.ts delete mode 100644 packages/websockets/src/Infra/Redis/RedisWebSocketsConnectionRepository.ts create mode 100644 packages/websockets/src/Infra/TypeORM/SQLConnection.ts create mode 100644 packages/websockets/src/Infra/TypeORM/SQLConnectionRepository.ts delete mode 100644 packages/websockets/src/Infra/WebSockets/WebSocketsClientMessenger.ts delete mode 100644 packages/websockets/src/Infra/WebSockets/WebSocketsClientService.spec.ts create mode 100644 packages/websockets/src/Mapping/SQL/ConnectionPersistenceMapper.ts diff --git a/.pnp.cjs b/.pnp.cjs index df7e5b169..bcdf47165 100755 --- a/.pnp.cjs +++ b/.pnp.cjs @@ -7045,13 +7045,13 @@ const RAW_RUNTIME_STATE = ["@standardnotes/websockets-server", "workspace:packages/websockets"],\ ["@aws-sdk/client-apigatewaymanagementapi", "npm:3.427.0"],\ ["@aws-sdk/client-sqs", "npm:3.427.0"],\ - ["@standardnotes/api", "npm:1.26.26"],\ ["@standardnotes/common", "workspace:packages/common"],\ ["@standardnotes/domain-core", "workspace:packages/domain-core"],\ ["@standardnotes/domain-events", "workspace:packages/domain-events"],\ ["@standardnotes/domain-events-infra", "workspace:packages/domain-events-infra"],\ ["@standardnotes/responses", "npm:1.13.27"],\ ["@standardnotes/security", "workspace:packages/security"],\ + ["@standardnotes/time", "workspace:packages/time"],\ ["@types/cors", "npm:2.8.13"],\ ["@types/express", "npm:4.17.17"],\ ["@types/ioredis", "npm:5.0.0"],\ diff --git a/packages/domain-events/src/Domain/Event/ItemsChangedOnServerEvent.ts b/packages/domain-events/src/Domain/Event/ItemsChangedOnServerEvent.ts new file mode 100644 index 000000000..81f6147e1 --- /dev/null +++ b/packages/domain-events/src/Domain/Event/ItemsChangedOnServerEvent.ts @@ -0,0 +1,7 @@ +import { DomainEventInterface } from './DomainEventInterface' +import { ItemsChangedOnServerEventPayload } from './ItemsChangedOnServerEventPayload' + +export interface ItemsChangedOnServerEvent extends DomainEventInterface { + type: 'ITEMS_CHANGED_ON_SERVER' + payload: ItemsChangedOnServerEventPayload +} diff --git a/packages/domain-events/src/Domain/Event/ItemsChangedOnServerEventPayload.ts b/packages/domain-events/src/Domain/Event/ItemsChangedOnServerEventPayload.ts new file mode 100644 index 000000000..e45a7b75d --- /dev/null +++ b/packages/domain-events/src/Domain/Event/ItemsChangedOnServerEventPayload.ts @@ -0,0 +1,5 @@ +export interface ItemsChangedOnServerEventPayload { + userUuid: string + sessionUuid: string + timestamp: number +} diff --git a/packages/domain-events/src/Domain/Event/WebSocketMessageRequestedEventPayload.ts b/packages/domain-events/src/Domain/Event/WebSocketMessageRequestedEventPayload.ts index 028696733..460c51c93 100644 --- a/packages/domain-events/src/Domain/Event/WebSocketMessageRequestedEventPayload.ts +++ b/packages/domain-events/src/Domain/Event/WebSocketMessageRequestedEventPayload.ts @@ -1,4 +1,5 @@ export interface WebSocketMessageRequestedEventPayload { userUuid: string message: string + originatingSessionUuid?: string } diff --git a/packages/domain-events/src/Domain/index.ts b/packages/domain-events/src/Domain/index.ts index dd83f60f1..efb6755b7 100644 --- a/packages/domain-events/src/Domain/index.ts +++ b/packages/domain-events/src/Domain/index.ts @@ -40,6 +40,8 @@ export * from './Event/ItemRemovedFromSharedVaultEvent' export * from './Event/ItemRemovedFromSharedVaultEventPayload' export * from './Event/ItemRevisionCreationRequestedEvent' export * from './Event/ItemRevisionCreationRequestedEventPayload' +export * from './Event/ItemsChangedOnServerEvent' +export * from './Event/ItemsChangedOnServerEventPayload' export * from './Event/ListedAccountCreatedEvent' export * from './Event/ListedAccountCreatedEventPayload' export * from './Event/ListedAccountDeletedEvent' diff --git a/packages/syncing-server/src/Bootstrap/Container.ts b/packages/syncing-server/src/Bootstrap/Container.ts index c14abe88b..ae80fb67c 100644 --- a/packages/syncing-server/src/Bootstrap/Container.ts +++ b/packages/syncing-server/src/Bootstrap/Container.ts @@ -114,7 +114,13 @@ import { GetSharedVaults } from '../Domain/UseCase/SharedVaults/GetSharedVaults/ import { CreateSharedVault } from '../Domain/UseCase/SharedVaults/CreateSharedVault/CreateSharedVault' import { DeleteSharedVault } from '../Domain/UseCase/SharedVaults/DeleteSharedVault/DeleteSharedVault' import { CreateSharedVaultFileValetToken } from '../Domain/UseCase/SharedVaults/CreateSharedVaultFileValetToken/CreateSharedVaultFileValetToken' -import { SharedVaultValetTokenData, TokenEncoder, TokenEncoderInterface } from '@standardnotes/security' +import { + DeterministicSelector, + SelectorInterface, + SharedVaultValetTokenData, + TokenEncoder, + TokenEncoderInterface, +} from '@standardnotes/security' import { SharedVaultHttpRepresentation } from '../Mapping/Http/SharedVaultHttpRepresentation' import { SharedVaultHttpMapper } from '../Mapping/Http/SharedVaultHttpMapper' import { SharedVaultInviteHttpRepresentation } from '../Mapping/Http/SharedVaultInviteHttpRepresentation' @@ -200,6 +206,10 @@ export class ContainerConfigLoader { } container.bind(TYPES.Sync_Logger).toConstantValue(logger) + container + .bind>(TYPES.Sync_NumberSelector) + .toConstantValue(new DeterministicSelector()) + const appDataSource = new AppDataSource({ env, runMigrations: this.mode === 'server' }) await appDataSource.initialize() @@ -601,12 +611,15 @@ export class ContainerConfigLoader { .bind(TYPES.Sync_SaveItems) .toConstantValue( new SaveItems( - container.get(TYPES.Sync_ItemSaveValidator), - container.get(TYPES.Sync_SQLItemRepository), - container.get(TYPES.Sync_Timer), - container.get(TYPES.Sync_SaveNewItem), - container.get(TYPES.Sync_UpdateExistingItem), - container.get(TYPES.Sync_Logger), + container.get(TYPES.Sync_ItemSaveValidator), + container.get(TYPES.Sync_SQLItemRepository), + container.get(TYPES.Sync_Timer), + container.get(TYPES.Sync_SaveNewItem), + container.get(TYPES.Sync_UpdateExistingItem), + container.get(TYPES.Sync_SendEventToClient), + container.get(TYPES.Sync_DomainEventFactory), + container.get>(TYPES.Sync_NumberSelector), + container.get(TYPES.Sync_Logger), ), ) container diff --git a/packages/syncing-server/src/Bootstrap/Types.ts b/packages/syncing-server/src/Bootstrap/Types.ts index cef5463be..7fcb9f056 100644 --- a/packages/syncing-server/src/Bootstrap/Types.ts +++ b/packages/syncing-server/src/Bootstrap/Types.ts @@ -6,6 +6,7 @@ const TYPES = { Sync_SQS: Symbol.for('Sync_SQS'), Sync_S3: Symbol.for('Sync_S3'), Sync_Env: Symbol.for('Sync_Env'), + Sync_NumberSelector: Symbol.for('Sync_NumberSelector'), // Repositories Sync_SQLItemRepository: Symbol.for('Sync_SQLItemRepository'), Sync_SharedVaultRepository: Symbol.for('Sync_SharedVaultRepository'), diff --git a/packages/syncing-server/src/Domain/Event/DomainEventFactory.ts b/packages/syncing-server/src/Domain/Event/DomainEventFactory.ts index 599423dcf..cbdd1b36f 100644 --- a/packages/syncing-server/src/Domain/Event/DomainEventFactory.ts +++ b/packages/syncing-server/src/Domain/Event/DomainEventFactory.ts @@ -7,6 +7,7 @@ import { ItemDumpedEvent, ItemRemovedFromSharedVaultEvent, ItemRevisionCreationRequestedEvent, + ItemsChangedOnServerEvent, MessageSentToUserEvent, NotificationAddedForUserEvent, RevisionsCopyRequestedEvent, @@ -23,6 +24,25 @@ import { DomainEventFactoryInterface } from './DomainEventFactoryInterface' export class DomainEventFactory implements DomainEventFactoryInterface { constructor(private timer: TimerInterface) {} + createItemsChangedOnServerEvent(dto: { + userUuid: string + sessionUuid: string + timestamp: number + }): ItemsChangedOnServerEvent { + return { + type: 'ITEMS_CHANGED_ON_SERVER', + createdAt: this.timer.getUTCDate(), + meta: { + correlation: { + userIdentifier: dto.userUuid, + userIdentifierType: 'uuid', + }, + origin: DomainEventService.SyncingServer, + }, + payload: dto, + } + } + createAccountDeletionVerificationPassedEvent(dto: { userUuid: string email: string @@ -207,7 +227,11 @@ export class DomainEventFactory implements DomainEventFactoryInterface { } } - createWebSocketMessageRequestedEvent(dto: { userUuid: string; message: string }): WebSocketMessageRequestedEvent { + createWebSocketMessageRequestedEvent(dto: { + userUuid: string + message: string + originatingSessionUuid?: string + }): WebSocketMessageRequestedEvent { return { type: 'WEB_SOCKET_MESSAGE_REQUESTED', createdAt: this.timer.getUTCDate(), diff --git a/packages/syncing-server/src/Domain/Event/DomainEventFactoryInterface.ts b/packages/syncing-server/src/Domain/Event/DomainEventFactoryInterface.ts index dfed376a3..83e4d0979 100644 --- a/packages/syncing-server/src/Domain/Event/DomainEventFactoryInterface.ts +++ b/packages/syncing-server/src/Domain/Event/DomainEventFactoryInterface.ts @@ -5,6 +5,7 @@ import { ItemDumpedEvent, ItemRemovedFromSharedVaultEvent, ItemRevisionCreationRequestedEvent, + ItemsChangedOnServerEvent, MessageSentToUserEvent, NotificationAddedForUserEvent, RevisionsCopyRequestedEvent, @@ -17,7 +18,16 @@ import { } from '@standardnotes/domain-events' export interface DomainEventFactoryInterface { - createWebSocketMessageRequestedEvent(dto: { userUuid: string; message: string }): WebSocketMessageRequestedEvent + createWebSocketMessageRequestedEvent(dto: { + userUuid: string + message: string + originatingSessionUuid?: string + }): WebSocketMessageRequestedEvent + createItemsChangedOnServerEvent(dto: { + userUuid: string + sessionUuid: string + timestamp: number + }): ItemsChangedOnServerEvent createUserInvitedToSharedVaultEvent(dto: { invite: { uuid: string diff --git a/packages/syncing-server/src/Domain/UseCase/Syncing/SaveItems/SaveItems.spec.ts b/packages/syncing-server/src/Domain/UseCase/Syncing/SaveItems/SaveItems.spec.ts index 6abd0d56d..973ca213c 100644 --- a/packages/syncing-server/src/Domain/UseCase/Syncing/SaveItems/SaveItems.spec.ts +++ b/packages/syncing-server/src/Domain/UseCase/Syncing/SaveItems/SaveItems.spec.ts @@ -8,6 +8,10 @@ import { Logger } from 'winston' import { ContentType, Dates, Result, Timestamps, Uuid } from '@standardnotes/domain-core' import { ItemHash } from '../../../Item/ItemHash' import { Item } from '../../../Item/Item' +import { SendEventToClient } from '../SendEventToClient/SendEventToClient' +import { DomainEventFactoryInterface } from '../../../Event/DomainEventFactoryInterface' +import { ItemsChangedOnServerEvent } from '@standardnotes/domain-events' +import { SelectorInterface } from '@standardnotes/security' describe('SaveItems', () => { let itemSaveValidator: ItemSaveValidatorInterface @@ -18,11 +22,35 @@ describe('SaveItems', () => { let logger: Logger let itemHash1: ItemHash let savedItem: Item + let sendEventToClient: SendEventToClient + let domainEventFactory: DomainEventFactoryInterface + let deterministicSelector: SelectorInterface const createUseCase = () => - new SaveItems(itemSaveValidator, itemRepository, timer, saveNewItem, updateExistingItem, logger) + new SaveItems( + itemSaveValidator, + itemRepository, + timer, + saveNewItem, + updateExistingItem, + sendEventToClient, + domainEventFactory, + deterministicSelector, + logger, + ) beforeEach(() => { + deterministicSelector = {} as jest.Mocked> + deterministicSelector.select = jest.fn().mockReturnValue(1) + + sendEventToClient = {} as jest.Mocked + sendEventToClient.execute = jest.fn().mockReturnValue(Result.ok()) + + domainEventFactory = {} as jest.Mocked + domainEventFactory.createItemsChangedOnServerEvent = jest + .fn() + .mockReturnValue({} as jest.Mocked) + itemSaveValidator = {} as jest.Mocked itemSaveValidator.validate = jest.fn().mockResolvedValue({ passed: true }) @@ -92,6 +120,7 @@ describe('SaveItems', () => { userUuid: 'user-uuid', sessionUuid: 'session-uuid', }) + expect(sendEventToClient.execute).toHaveBeenCalled() }) it('should mark items as conflicts if saving new item fails', async () => { @@ -115,6 +144,7 @@ describe('SaveItems', () => { type: 'uuid_conflict', }, ]) + expect(sendEventToClient.execute).not.toHaveBeenCalled() }) it('should mark items as conflicts if saving new item throws an error', async () => { @@ -197,6 +227,8 @@ describe('SaveItems', () => { }) it('should update existing items', async () => { + deterministicSelector.select = jest.fn().mockReturnValue(0) + const useCase = createUseCase() itemRepository.findByUuid = jest.fn().mockResolvedValue(savedItem) @@ -217,6 +249,7 @@ describe('SaveItems', () => { sessionUuid: 'session-uuid', performingUserUuid: '00000000-0000-0000-0000-000000000000', }) + expect(sendEventToClient.execute).not.toHaveBeenCalled() }) it('should mark items as conflicts if updating existing item fails', async () => { diff --git a/packages/syncing-server/src/Domain/UseCase/Syncing/SaveItems/SaveItems.ts b/packages/syncing-server/src/Domain/UseCase/Syncing/SaveItems/SaveItems.ts index 6703baeb8..96b66e3bb 100644 --- a/packages/syncing-server/src/Domain/UseCase/Syncing/SaveItems/SaveItems.ts +++ b/packages/syncing-server/src/Domain/UseCase/Syncing/SaveItems/SaveItems.ts @@ -11,6 +11,9 @@ import { ItemSaveValidatorInterface } from '../../../Item/SaveValidator/ItemSave import { SaveNewItem } from '../SaveNewItem/SaveNewItem' import { UpdateExistingItem } from '../UpdateExistingItem/UpdateExistingItem' import { ItemRepositoryInterface } from '../../../Item/ItemRepositoryInterface' +import { SendEventToClient } from '../SendEventToClient/SendEventToClient' +import { DomainEventFactoryInterface } from '../../../Event/DomainEventFactoryInterface' +import { SelectorInterface } from '@standardnotes/security' export class SaveItems implements UseCaseInterface { private readonly SYNC_TOKEN_VERSION = 2 @@ -21,6 +24,9 @@ export class SaveItems implements UseCaseInterface { private timer: TimerInterface, private saveNewItem: SaveNewItem, private updateExistingItem: UpdateExistingItem, + private sendEventToClient: SendEventToClient, + private domainEventFactory: DomainEventFactoryInterface, + private deterministicSelector: SelectorInterface, private logger: Logger, ) {} @@ -133,6 +139,8 @@ export class SaveItems implements UseCaseInterface { const syncToken = this.calculateSyncToken(lastUpdatedTimestamp, savedItems) + await this.notifyOtherClientsOfTheUserThatItemsChanged(dto, savedItems, lastUpdatedTimestamp) + return Result.ok({ savedItems, conflicts, @@ -140,6 +148,37 @@ export class SaveItems implements UseCaseInterface { }) } + private async notifyOtherClientsOfTheUserThatItemsChanged( + dto: SaveItemsDTO, + savedItems: Item[], + lastUpdatedTimestamp: number, + ): Promise { + if (savedItems.length === 0 || !dto.sessionUuid) { + return + } + + const tenPercentSpreadArray = Array.from(Array(10).keys()) + const diceRoll = this.deterministicSelector.select(dto.userUuid, tenPercentSpreadArray) + if (diceRoll !== 1) { + return + } + + const itemsChangedEvent = this.domainEventFactory.createItemsChangedOnServerEvent({ + userUuid: dto.userUuid, + sessionUuid: dto.sessionUuid, + timestamp: lastUpdatedTimestamp, + }) + const result = await this.sendEventToClient.execute({ + userUuid: dto.userUuid, + originatingSessionUuid: dto.sessionUuid, + event: itemsChangedEvent, + }) + /* istanbul ignore next */ + if (result.isFailed()) { + this.logger.error(`[${dto.userUuid}] Sending items changed event to client failed. Error: ${result.getError()}`) + } + } + private calculateSyncToken(lastUpdatedTimestamp: number, savedItems: Array): string { if (savedItems.length) { const sortedItems = savedItems.sort((itemA: Item, itemB: Item) => { diff --git a/packages/syncing-server/src/Domain/UseCase/Syncing/SendEventToClient/SendEventToClient.spec.ts b/packages/syncing-server/src/Domain/UseCase/Syncing/SendEventToClient/SendEventToClient.spec.ts index 345139c6c..611daaa16 100644 --- a/packages/syncing-server/src/Domain/UseCase/Syncing/SendEventToClient/SendEventToClient.spec.ts +++ b/packages/syncing-server/src/Domain/UseCase/Syncing/SendEventToClient/SendEventToClient.spec.ts @@ -17,6 +17,8 @@ describe('SendEventToClient', () => { beforeEach(() => { logger = {} as jest.Mocked logger.info = jest.fn() + logger.debug = jest.fn() + logger.error = jest.fn() domainEventFactory = {} as jest.Mocked domainEventFactory.createWebSocketMessageRequestedEvent = jest @@ -58,4 +60,21 @@ describe('SendEventToClient', () => { expect(result.isFailed()).toBe(true) }) + + it('should return a failed result if error is thrown', async () => { + const useCase = createUseCase() + + domainEventFactory.createWebSocketMessageRequestedEvent = jest.fn().mockImplementation(() => { + throw new Error('test') + }) + + const result = await useCase.execute({ + userUuid: '00000000-0000-0000-0000-000000000000', + event: { + type: 'test', + } as jest.Mocked, + }) + + expect(result.isFailed()).toBe(true) + }) }) diff --git a/packages/syncing-server/src/Domain/UseCase/Syncing/SendEventToClient/SendEventToClient.ts b/packages/syncing-server/src/Domain/UseCase/Syncing/SendEventToClient/SendEventToClient.ts index e52394db8..3d6d30ed2 100644 --- a/packages/syncing-server/src/Domain/UseCase/Syncing/SendEventToClient/SendEventToClient.ts +++ b/packages/syncing-server/src/Domain/UseCase/Syncing/SendEventToClient/SendEventToClient.ts @@ -13,21 +13,26 @@ export class SendEventToClient implements UseCaseInterface { ) {} async execute(dto: SendEventToClientDTO): Promise> { - const userUuidOrError = Uuid.create(dto.userUuid) - if (userUuidOrError.isFailed()) { - return Result.fail(userUuidOrError.getError()) + try { + const userUuidOrError = Uuid.create(dto.userUuid) + if (userUuidOrError.isFailed()) { + return Result.fail(userUuidOrError.getError()) + } + const userUuid = userUuidOrError.getValue() + + this.logger.debug(`[WebSockets] Requesting message ${dto.event.type} to user ${dto.userUuid}`) + + const event = this.domainEventFactory.createWebSocketMessageRequestedEvent({ + userUuid: userUuid.value, + message: JSON.stringify(dto.event), + originatingSessionUuid: dto.originatingSessionUuid, + }) + + await this.domainEventPublisher.publish(event) + + return Result.ok() + } catch (error) { + return Result.fail(`Failed to send event to client: ${(error as Error).message}`) } - const userUuid = userUuidOrError.getValue() - - this.logger.info(`[WebSockets] Requesting message ${dto.event.type} to user ${dto.userUuid}`) - - const event = this.domainEventFactory.createWebSocketMessageRequestedEvent({ - userUuid: userUuid.value, - message: JSON.stringify(dto.event), - }) - - await this.domainEventPublisher.publish(event) - - return Result.ok() } } diff --git a/packages/syncing-server/src/Domain/UseCase/Syncing/SendEventToClient/SendEventToClientDTO.ts b/packages/syncing-server/src/Domain/UseCase/Syncing/SendEventToClient/SendEventToClientDTO.ts index afab531a5..db3ace4a6 100644 --- a/packages/syncing-server/src/Domain/UseCase/Syncing/SendEventToClient/SendEventToClientDTO.ts +++ b/packages/syncing-server/src/Domain/UseCase/Syncing/SendEventToClient/SendEventToClientDTO.ts @@ -3,4 +3,5 @@ import { DomainEventInterface } from '@standardnotes/domain-events' export interface SendEventToClientDTO { userUuid: string event: DomainEventInterface + originatingSessionUuid?: string } diff --git a/packages/websockets/.env.sample b/packages/websockets/.env.sample index d6665a946..eccd6d3f5 100644 --- a/packages/websockets/.env.sample +++ b/packages/websockets/.env.sample @@ -8,6 +8,15 @@ AUTH_JWT_SECRET=auth_jwt_secret REDIS_URL=redis://cache +DB_HOST=127.0.0.1 +DB_REPLICA_HOST=127.0.0.1 +DB_PORT=3306 +DB_USERNAME=websockets +DB_PASSWORD=changeme123 +DB_DATABASE=websockets +DB_DEBUG_LEVEL=all # "all" | "query" | "schema" | "error" | "warn" | "info" | "log" | "migration" +DB_TYPE=mysql + SNS_TOPIC_ARN= SNS_AWS_REGION= SQS_QUEUE_URL= diff --git a/packages/websockets/bin/server.ts b/packages/websockets/bin/server.ts index f9cf6af5c..eca13d614 100644 --- a/packages/websockets/bin/server.ts +++ b/packages/websockets/bin/server.ts @@ -12,7 +12,7 @@ import { ContainerConfigLoader } from '../src/Bootstrap/Container' import TYPES from '../src/Bootstrap/Types' import { Env } from '../src/Bootstrap/Env' -const container = new ContainerConfigLoader() +const container = new ContainerConfigLoader('server') void container.load().then((container) => { const env: Env = new Env() env.load() diff --git a/packages/websockets/bin/worker.ts b/packages/websockets/bin/worker.ts index 389dcbaf8..11a82393b 100644 --- a/packages/websockets/bin/worker.ts +++ b/packages/websockets/bin/worker.ts @@ -7,7 +7,7 @@ import TYPES from '../src/Bootstrap/Types' import { Env } from '../src/Bootstrap/Env' import { DomainEventSubscriberInterface } from '@standardnotes/domain-events' -const container = new ContainerConfigLoader() +const container = new ContainerConfigLoader('worker') void container.load().then((container) => { const env: Env = new Env() env.load() diff --git a/packages/websockets/migrations/mysql/1701087671322-initial-database.ts b/packages/websockets/migrations/mysql/1701087671322-initial-database.ts new file mode 100644 index 000000000..a1e114b27 --- /dev/null +++ b/packages/websockets/migrations/mysql/1701087671322-initial-database.ts @@ -0,0 +1,17 @@ +import { MigrationInterface, QueryRunner } from 'typeorm' + +export class InitialDatabase1701087671322 implements MigrationInterface { + name = 'InitialDatabase1701087671322' + + public async up(queryRunner: QueryRunner): Promise { + await queryRunner.query( + 'CREATE TABLE `connections` (`uuid` varchar(36) NOT NULL, `user_uuid` varchar(36) NOT NULL, `session_uuid` varchar(36) NOT NULL, `connection_id` varchar(255) NOT NULL, `created_at_timestamp` bigint NOT NULL, `updated_at_timestamp` bigint NOT NULL, INDEX `index_connections_on_user_uuid` (`user_uuid`), UNIQUE INDEX `index_connections_on_connection_id` (`connection_id`), PRIMARY KEY (`uuid`)) ENGINE=InnoDB', + ) + } + + public async down(queryRunner: QueryRunner): Promise { + await queryRunner.query('DROP INDEX `index_connections_on_connection_id` ON `connections`') + await queryRunner.query('DROP INDEX `index_connections_on_user_uuid` ON `connections`') + await queryRunner.query('DROP TABLE `connections`') + } +} diff --git a/packages/websockets/package.json b/packages/websockets/package.json index 8ee61b602..983be44a8 100644 --- a/packages/websockets/package.json +++ b/packages/websockets/package.json @@ -29,13 +29,13 @@ "dependencies": { "@aws-sdk/client-apigatewaymanagementapi": "^3.427.0", "@aws-sdk/client-sqs": "^3.427.0", - "@standardnotes/api": "^1.26.26", "@standardnotes/common": "workspace:^", "@standardnotes/domain-core": "workspace:^", "@standardnotes/domain-events": "workspace:^", "@standardnotes/domain-events-infra": "workspace:^", "@standardnotes/responses": "^1.13.27", "@standardnotes/security": "workspace:^", + "@standardnotes/time": "workspace:^", "cors": "2.8.5", "dotenv": "^16.0.1", "express": "^4.18.2", diff --git a/packages/websockets/src/Bootstrap/Container.ts b/packages/websockets/src/Bootstrap/Container.ts index 54edd5393..777f2487b 100644 --- a/packages/websockets/src/Bootstrap/Container.ts +++ b/packages/websockets/src/Bootstrap/Container.ts @@ -1,5 +1,4 @@ import * as winston from 'winston' -import Redis from 'ioredis' import { SQSClient, SQSClientConfig } from '@aws-sdk/client-sqs' import { ApiGatewayManagementApiClient } from '@aws-sdk/client-apigatewaymanagementapi' import { Container } from 'inversify' @@ -8,16 +7,14 @@ import { DomainEventMessageHandlerInterface, DomainEventSubscriberInterface, } from '@standardnotes/domain-events' +import { TimerInterface, Timer } from '@standardnotes/time' import { Env } from './Env' import TYPES from './Types' import { WebSocketsConnectionRepositoryInterface } from '../Domain/WebSockets/WebSocketsConnectionRepositoryInterface' -import { RedisWebSocketsConnectionRepository } from '../Infra/Redis/RedisWebSocketsConnectionRepository' import { AddWebSocketsConnection } from '../Domain/UseCase/AddWebSocketsConnection/AddWebSocketsConnection' import { RemoveWebSocketsConnection } from '../Domain/UseCase/RemoveWebSocketsConnection/RemoveWebSocketsConnection' -import { WebSocketsClientMessenger } from '../Infra/WebSockets/WebSocketsClientMessenger' import { SQSDomainEventSubscriber, SQSEventMessageHandler } from '@standardnotes/domain-events-infra' -import { ApiGatewayAuthMiddleware } from '../Controller/ApiGatewayAuthMiddleware' - +import { ApiGatewayAuthMiddleware } from '../Infra/InversifyExpressUtils/Middleware/ApiGatewayAuthMiddleware' import { CrossServiceTokenData, TokenDecoder, @@ -27,29 +24,25 @@ import { WebSocketConnectionTokenData, } from '@standardnotes/security' import { CreateWebSocketConnectionToken } from '../Domain/UseCase/CreateWebSocketConnectionToken/CreateWebSocketConnectionToken' -import { WebSocketsController } from '../Controller/WebSocketsController' -import { WebSocketServerInterface } from '@standardnotes/api' -import { ClientMessengerInterface } from '../Client/ClientMessengerInterface' import { WebSocketMessageRequestedEventHandler } from '../Domain/Handler/WebSocketMessageRequestedEventHandler' +import { SQLConnectionRepository } from '../Infra/TypeORM/SQLConnectionRepository' +import { Connection } from '../Domain/Connection/Connection' +import { SQLConnection } from '../Infra/TypeORM/SQLConnection' +import { MapperInterface } from '@standardnotes/domain-core' +import { Repository } from 'typeorm' +import { ConnectionPersistenceMapper } from '../Mapping/SQL/ConnectionPersistenceMapper' +import { AppDataSource } from './DataSource' +import { SendMessageToClient } from '../Domain/UseCase/SendMessageToClient/SendMessageToClient' export class ContainerConfigLoader { + constructor(private mode: 'server' | 'worker' = 'server') {} + async load(): Promise { const env: Env = new Env() env.load() const container = new Container() - const redisUrl = env.get('REDIS_URL') - const isRedisInClusterMode = redisUrl.indexOf(',') > 0 - let redis - if (isRedisInClusterMode) { - redis = new Redis.Cluster(redisUrl.split(',')) - } else { - redis = new Redis(redisUrl) - } - - container.bind(TYPES.Redis).toConstantValue(redis) - const winstonFormatters = [winston.format.splat(), winston.format.json()] const logger = winston.createLogger({ @@ -59,6 +52,13 @@ export class ContainerConfigLoader { }) container.bind(TYPES.Logger).toConstantValue(logger) + const appDataSource = new AppDataSource({ env, runMigrations: this.mode === 'server' }) + await appDataSource.initialize() + + logger.debug('Database initialized') + + container.bind(TYPES.Timer).toConstantValue(new Timer()) + if (env.get('SQS_QUEUE_URL', true)) { const sqsConfig: SQSClientConfig = { region: env.get('SQS_AWS_REGION', true), @@ -83,14 +83,26 @@ export class ContainerConfigLoader { region: env.get('API_GATEWAY_AWS_REGION', true) ?? 'us-east-1', }), ) + // Mappers + container + .bind>(TYPES.ConnectionPersistenceMapper) + .toConstantValue(new ConnectionPersistenceMapper()) - // Controller - container.bind(TYPES.WebSocketsController).to(WebSocketsController) + // ORM + container + .bind>(TYPES.ORMConnectionRepository) + .toConstantValue(appDataSource.getRepository(SQLConnection)) // Repositories container .bind(TYPES.WebSocketsConnectionRepository) - .to(RedisWebSocketsConnectionRepository) + .toConstantValue( + new SQLConnectionRepository( + container.get>(TYPES.ORMConnectionRepository), + container.get>(TYPES.ConnectionPersistenceMapper), + container.get(TYPES.Logger), + ), + ) // Middleware container.bind(TYPES.ApiGatewayAuthMiddleware).to(ApiGatewayAuthMiddleware) @@ -103,21 +115,42 @@ export class ContainerConfigLoader { container .bind(TYPES.WEB_SOCKET_CONNECTION_TOKEN_TTL) .toConstantValue(+env.get('WEB_SOCKET_CONNECTION_TOKEN_TTL', true)) - container.bind(TYPES.REDIS_URL).toConstantValue(env.get('REDIS_URL')) container.bind(TYPES.SQS_QUEUE_URL).toConstantValue(env.get('SQS_QUEUE_URL')) container.bind(TYPES.VERSION).toConstantValue(env.get('VERSION')) // use cases - container.bind(TYPES.AddWebSocketsConnection).to(AddWebSocketsConnection) + container + .bind(TYPES.AddWebSocketsConnection) + .toConstantValue( + new AddWebSocketsConnection( + container.get(TYPES.WebSocketsConnectionRepository), + container.get(TYPES.Timer), + container.get(TYPES.Logger), + ), + ) container.bind(TYPES.RemoveWebSocketsConnection).to(RemoveWebSocketsConnection) container .bind(TYPES.CreateWebSocketConnectionToken) .to(CreateWebSocketConnectionToken) + container + .bind(TYPES.SendMessageToClient) + .toConstantValue( + new SendMessageToClient( + container.get(TYPES.WebSocketsConnectionRepository), + container.get(TYPES.WebSockets_ApiGatewayManagementApiClient), + container.get(TYPES.Logger), + ), + ) // Handlers container .bind(TYPES.WebSocketMessageRequestedEventHandler) - .to(WebSocketMessageRequestedEventHandler) + .toConstantValue( + new WebSocketMessageRequestedEventHandler( + container.get(TYPES.SendMessageToClient), + container.get(TYPES.Logger), + ), + ) // Services container @@ -128,7 +161,6 @@ export class ContainerConfigLoader { .toConstantValue( new TokenEncoder(container.get(TYPES.WEB_SOCKET_CONNECTION_TOKEN_SECRET)), ) - container.bind(TYPES.WebSocketsClientMessenger).to(WebSocketsClientMessenger) const eventHandlers: Map = new Map([ ['WEB_SOCKET_MESSAGE_REQUESTED', container.get(TYPES.WebSocketMessageRequestedEventHandler)], diff --git a/packages/websockets/src/Bootstrap/DataSource.ts b/packages/websockets/src/Bootstrap/DataSource.ts new file mode 100644 index 000000000..735471894 --- /dev/null +++ b/packages/websockets/src/Bootstrap/DataSource.ts @@ -0,0 +1,84 @@ +import { DataSource, EntityTarget, LoggerOptions, ObjectLiteral, Repository } from 'typeorm' +import { MysqlConnectionOptions } from 'typeorm/driver/mysql/MysqlConnectionOptions' +import { Env } from './Env' +import { SQLConnection } from '../Infra/TypeORM/SQLConnection' + +export class AppDataSource { + private _dataSource: DataSource | undefined + + constructor( + private configuration: { + env: Env + runMigrations: boolean + }, + ) {} + + getRepository(target: EntityTarget): Repository { + if (!this._dataSource) { + throw new Error('DataSource not initialized') + } + + return this._dataSource.getRepository(target) + } + + async initialize(): Promise { + await this.dataSource.initialize() + } + + get dataSource(): DataSource { + this.configuration.env.load() + + const maxQueryExecutionTime = this.configuration.env.get('DB_MAX_QUERY_EXECUTION_TIME', true) + ? +this.configuration.env.get('DB_MAX_QUERY_EXECUTION_TIME', true) + : 45_000 + + const commonDataSourceOptions = { + maxQueryExecutionTime, + entities: [SQLConnection], + migrations: [`${__dirname}/../../migrations/mysql/*.js`], + migrationsRun: this.configuration.runMigrations, + logging: this.configuration.env.get('DB_DEBUG_LEVEL', true) ?? 'info', + } + + const inReplicaMode = this.configuration.env.get('DB_REPLICA_HOST', true) ? true : false + + const replicationConfig = { + master: { + host: this.configuration.env.get('DB_HOST'), + port: parseInt(this.configuration.env.get('DB_PORT')), + username: this.configuration.env.get('DB_USERNAME'), + password: this.configuration.env.get('DB_PASSWORD'), + database: this.configuration.env.get('DB_DATABASE'), + }, + slaves: [ + { + host: this.configuration.env.get('DB_REPLICA_HOST', true), + port: parseInt(this.configuration.env.get('DB_PORT')), + username: this.configuration.env.get('DB_USERNAME'), + password: this.configuration.env.get('DB_PASSWORD'), + database: this.configuration.env.get('DB_DATABASE'), + }, + ], + removeNodeErrorCount: 10, + restoreNodeTimeout: 5, + } + + const mySQLDataSourceOptions: MysqlConnectionOptions = { + ...commonDataSourceOptions, + type: 'mysql', + charset: 'utf8mb4', + supportBigNumbers: true, + bigNumberStrings: false, + replication: inReplicaMode ? replicationConfig : undefined, + host: inReplicaMode ? undefined : this.configuration.env.get('DB_HOST'), + port: inReplicaMode ? undefined : parseInt(this.configuration.env.get('DB_PORT')), + username: inReplicaMode ? undefined : this.configuration.env.get('DB_USERNAME'), + password: inReplicaMode ? undefined : this.configuration.env.get('DB_PASSWORD'), + database: inReplicaMode ? undefined : this.configuration.env.get('DB_DATABASE'), + } + + this._dataSource = new DataSource(mySQLDataSourceOptions) + + return this._dataSource + } +} diff --git a/packages/websockets/src/Bootstrap/MigrationsDataSource.ts b/packages/websockets/src/Bootstrap/MigrationsDataSource.ts new file mode 100644 index 000000000..60a1bb2ea --- /dev/null +++ b/packages/websockets/src/Bootstrap/MigrationsDataSource.ts @@ -0,0 +1,7 @@ +import { AppDataSource } from './DataSource' +import { Env } from './Env' + +const env: Env = new Env() +env.load() + +export const MigrationsDataSource = new AppDataSource({ env, runMigrations: true }).dataSource diff --git a/packages/websockets/src/Bootstrap/Types.ts b/packages/websockets/src/Bootstrap/Types.ts index 3d815ac2c..78207f3e4 100644 --- a/packages/websockets/src/Bootstrap/Types.ts +++ b/packages/websockets/src/Bootstrap/Types.ts @@ -1,10 +1,12 @@ const TYPES = { Logger: Symbol.for('Logger'), - Redis: Symbol.for('Redis'), + Timer: Symbol.for('Timer'), SQS: Symbol.for('SQS'), WebSockets_ApiGatewayManagementApiClient: Symbol.for('WebSockets_ApiGatewayManagementApiClient'), - // Controller - WebSocketsController: Symbol.for('WebSocketsController'), + // Mappers + ConnectionPersistenceMapper: Symbol.for('ConnectionPersistenceMapper'), + // ORM + ORMConnectionRepository: Symbol.for('ORMConnectionRepository'), // Repositories WebSocketsConnectionRepository: Symbol.for('WebSocketsConnectionRepository'), // Middleware @@ -22,6 +24,7 @@ const TYPES = { AddWebSocketsConnection: Symbol.for('AddWebSocketsConnection'), RemoveWebSocketsConnection: Symbol.for('RemoveWebSocketsConnection'), CreateWebSocketConnectionToken: Symbol.for('CreateWebSocketConnectionToken'), + SendMessageToClient: Symbol.for('SendMessageToClient'), // Handlers WebSocketMessageRequestedEventHandler: Symbol.for('WebSocketMessageRequestedEventHandler'), // Services diff --git a/packages/websockets/src/Client/ClientMessengerInterface.ts b/packages/websockets/src/Client/ClientMessengerInterface.ts deleted file mode 100644 index d952ec6cc..000000000 --- a/packages/websockets/src/Client/ClientMessengerInterface.ts +++ /dev/null @@ -1,3 +0,0 @@ -export interface ClientMessengerInterface { - send(userUuid: string, message: string): Promise -} diff --git a/packages/websockets/src/Controller/ApiGatewayAuthMiddleware.spec.ts b/packages/websockets/src/Controller/ApiGatewayAuthMiddleware.spec.ts deleted file mode 100644 index 0b50305bf..000000000 --- a/packages/websockets/src/Controller/ApiGatewayAuthMiddleware.spec.ts +++ /dev/null @@ -1,99 +0,0 @@ -import 'reflect-metadata' - -import { ApiGatewayAuthMiddleware } from './ApiGatewayAuthMiddleware' -import { NextFunction, Request, Response } from 'express' -import { Logger } from 'winston' -import { CrossServiceTokenData, TokenDecoderInterface } from '@standardnotes/security' -import { RoleName } from '@standardnotes/domain-core' - -describe('ApiGatewayAuthMiddleware', () => { - let tokenDecoder: TokenDecoderInterface - let request: Request - let response: Response - let next: NextFunction - - const logger = { - debug: jest.fn(), - } as unknown as jest.Mocked - - const createMiddleware = () => new ApiGatewayAuthMiddleware(tokenDecoder, logger) - - beforeEach(() => { - tokenDecoder = {} as jest.Mocked> - tokenDecoder.decodeToken = jest.fn().mockReturnValue({ - user: { - uuid: '1-2-3', - email: 'test@test.te', - }, - roles: [ - { - uuid: 'a-b-c', - name: RoleName.NAMES.CoreUser, - }, - ], - }) - - request = { - headers: {}, - } as jest.Mocked - response = { - locals: {}, - } as jest.Mocked - response.status = jest.fn().mockReturnThis() - response.send = jest.fn() - next = jest.fn() - }) - - it('should authorize user', async () => { - request.headers['x-auth-token'] = 'auth-jwt-token' - - await createMiddleware().handler(request, response, next) - - expect(response.locals.user).toEqual({ - uuid: '1-2-3', - email: 'test@test.te', - }) - expect(response.locals.roles).toEqual([ - { - uuid: 'a-b-c', - name: RoleName.NAMES.CoreUser, - }, - ]) - - expect(next).toHaveBeenCalled() - }) - - it('should not authorize if request is missing auth jwt token in headers', async () => { - await createMiddleware().handler(request, response, next) - - expect(response.status).toHaveBeenCalledWith(401) - expect(next).not.toHaveBeenCalled() - }) - - it('should not authorize if auth jwt token is malformed', async () => { - request.headers['x-auth-token'] = 'auth-jwt-token' - - tokenDecoder.decodeToken = jest.fn().mockReturnValue(undefined) - - await createMiddleware().handler(request, response, next) - - expect(response.status).toHaveBeenCalledWith(401) - expect(next).not.toHaveBeenCalled() - }) - - it('should pass the error to next middleware if one occurres', async () => { - request.headers['x-auth-token'] = 'auth-jwt-token' - - const error = new Error('Ooops') - - tokenDecoder.decodeToken = jest.fn().mockImplementation(() => { - throw error - }) - - await createMiddleware().handler(request, response, next) - - expect(response.status).not.toHaveBeenCalled() - - expect(next).toHaveBeenCalledWith(error) - }) -}) diff --git a/packages/websockets/src/Controller/WebSocketsController.spec.ts b/packages/websockets/src/Controller/WebSocketsController.spec.ts deleted file mode 100644 index 26ff9b223..000000000 --- a/packages/websockets/src/Controller/WebSocketsController.spec.ts +++ /dev/null @@ -1,28 +0,0 @@ -import 'reflect-metadata' - -import { WebSocketsController } from './WebSocketsController' -import { CreateWebSocketConnectionToken } from '../Domain/UseCase/CreateWebSocketConnectionToken/CreateWebSocketConnectionToken' - -describe('WebSocketsController', () => { - let createWebSocketConnectionToken: CreateWebSocketConnectionToken - - const createController = () => new WebSocketsController(createWebSocketConnectionToken) - - beforeEach(() => { - createWebSocketConnectionToken = {} as jest.Mocked - createWebSocketConnectionToken.execute = jest.fn().mockReturnValue({ token: 'foobar' }) - }) - - it('should create a web sockets connection token', async () => { - const response = await createController().createConnectionToken({ userUuid: '1-2-3' }) - - expect(response).toEqual({ - status: 200, - data: { token: 'foobar' }, - }) - - expect(createWebSocketConnectionToken.execute).toHaveBeenCalledWith({ - userUuid: '1-2-3', - }) - }) -}) diff --git a/packages/websockets/src/Controller/WebSocketsController.ts b/packages/websockets/src/Controller/WebSocketsController.ts deleted file mode 100644 index e3b3509de..000000000 --- a/packages/websockets/src/Controller/WebSocketsController.ts +++ /dev/null @@ -1,29 +0,0 @@ -import { HttpStatusCode, HttpResponse } from '@standardnotes/responses' -import { - WebSocketConnectionTokenRequestParams, - WebSocketConnectionTokenResponseBody, - WebSocketServerInterface, -} from '@standardnotes/api' -import { inject, injectable } from 'inversify' - -import TYPES from '../Bootstrap/Types' -import { CreateWebSocketConnectionToken } from '../Domain/UseCase/CreateWebSocketConnectionToken/CreateWebSocketConnectionToken' - -@injectable() -export class WebSocketsController implements WebSocketServerInterface { - constructor( - @inject(TYPES.CreateWebSocketConnectionToken) - private createWebSocketConnectionToken: CreateWebSocketConnectionToken, - ) {} - - async createConnectionToken( - params: WebSocketConnectionTokenRequestParams, - ): Promise> { - const result = await this.createWebSocketConnectionToken.execute({ userUuid: params.userUuid as string }) - - return { - status: HttpStatusCode.Success, - data: result, - } - } -} diff --git a/packages/websockets/src/Domain/Connection/Connection.ts b/packages/websockets/src/Domain/Connection/Connection.ts new file mode 100644 index 000000000..e0aac6243 --- /dev/null +++ b/packages/websockets/src/Domain/Connection/Connection.ts @@ -0,0 +1,13 @@ +import { Entity, Result, UniqueEntityId } from '@standardnotes/domain-core' + +import { ConnectionProps } from './ConnectionProps' + +export class Connection extends Entity { + private constructor(props: ConnectionProps, id?: UniqueEntityId) { + super(props, id) + } + + static create(props: ConnectionProps, id?: UniqueEntityId): Result { + return Result.ok(new Connection(props, id)) + } +} diff --git a/packages/websockets/src/Domain/Connection/ConnectionProps.ts b/packages/websockets/src/Domain/Connection/ConnectionProps.ts new file mode 100644 index 000000000..60aae1bc2 --- /dev/null +++ b/packages/websockets/src/Domain/Connection/ConnectionProps.ts @@ -0,0 +1,8 @@ +import { Timestamps, Uuid } from '@standardnotes/domain-core' + +export interface ConnectionProps { + userUuid: Uuid + sessionUuid: Uuid + connectionId: string + timestamps: Timestamps +} diff --git a/packages/websockets/src/Domain/Handler/WebSocketMessageRequestedEventHandler.ts b/packages/websockets/src/Domain/Handler/WebSocketMessageRequestedEventHandler.ts index a72245c02..849be62a2 100644 --- a/packages/websockets/src/Domain/Handler/WebSocketMessageRequestedEventHandler.ts +++ b/packages/websockets/src/Domain/Handler/WebSocketMessageRequestedEventHandler.ts @@ -1,20 +1,22 @@ import { DomainEventHandlerInterface, WebSocketMessageRequestedEvent } from '@standardnotes/domain-events' -import { inject, injectable } from 'inversify' import { Logger } from 'winston' +import { SendMessageToClient } from '../UseCase/SendMessageToClient/SendMessageToClient' -import TYPES from '../../Bootstrap/Types' -import { ClientMessengerInterface } from '../../Client/ClientMessengerInterface' - -@injectable() export class WebSocketMessageRequestedEventHandler implements DomainEventHandlerInterface { constructor( - @inject(TYPES.WebSocketsClientMessenger) private webSocketsClientMessenger: ClientMessengerInterface, - @inject(TYPES.Logger) private logger: Logger, + private sendMessageToClient: SendMessageToClient, + private logger: Logger, ) {} async handle(event: WebSocketMessageRequestedEvent): Promise { - this.logger.debug(`Sending message to user ${event.payload.userUuid}`) + const result = await this.sendMessageToClient.execute({ + userUuid: event.payload.userUuid, + message: event.payload.message, + originatingSessionUuid: event.payload.originatingSessionUuid, + }) - await this.webSocketsClientMessenger.send(event.payload.userUuid, event.payload.message) + if (result.isFailed()) { + this.logger.error(`Could not send message to user ${event.payload.userUuid}. Error: ${result.getError()}`) + } } } diff --git a/packages/websockets/src/Domain/UseCase/AddWebSocketsConnection/AddWebSocketsConnection.spec.ts b/packages/websockets/src/Domain/UseCase/AddWebSocketsConnection/AddWebSocketsConnection.spec.ts index 3ad7fd93f..cae791d19 100644 --- a/packages/websockets/src/Domain/UseCase/AddWebSocketsConnection/AddWebSocketsConnection.spec.ts +++ b/packages/websockets/src/Domain/UseCase/AddWebSocketsConnection/AddWebSocketsConnection.spec.ts @@ -1,14 +1,15 @@ -import 'reflect-metadata' import { Logger } from 'winston' +import { TimerInterface } from '@standardnotes/time' import { WebSocketsConnectionRepositoryInterface } from '../../WebSockets/WebSocketsConnectionRepositoryInterface' import { AddWebSocketsConnection } from './AddWebSocketsConnection' describe('AddWebSocketsConnection', () => { let webSocketsConnectionRepository: WebSocketsConnectionRepositoryInterface + let timer: TimerInterface let logger: Logger - const createUseCase = () => new AddWebSocketsConnection(webSocketsConnectionRepository, logger) + const createUseCase = () => new AddWebSocketsConnection(webSocketsConnectionRepository, timer, logger) beforeEach(() => { webSocketsConnectionRepository = {} as jest.Mocked @@ -17,12 +18,18 @@ describe('AddWebSocketsConnection', () => { logger = {} as jest.Mocked logger.debug = jest.fn() logger.error = jest.fn() + + timer = {} as jest.Mocked + timer.getTimestampInMicroseconds = jest.fn().mockReturnValue(123) }) it('should save a web sockets connection for a user for further communication', async () => { - const result = await createUseCase().execute({ userUuid: '1-2-3', connectionId: '2-3-4' }) + const result = await createUseCase().execute({ + userUuid: '00000000-0000-0000-0000-000000000000', + sessionUuid: '00000000-0000-0000-0000-000000000000', + connectionId: '2-3-4', + }) - expect(webSocketsConnectionRepository.saveConnection).toHaveBeenCalledWith('1-2-3', '2-3-4') expect(result.isFailed()).toBe(false) }) @@ -31,7 +38,31 @@ describe('AddWebSocketsConnection', () => { .fn() .mockRejectedValueOnce(new Error('Could not save connection')) - const result = await createUseCase().execute({ userUuid: '1-2-3', connectionId: '2-3-4' }) + const result = await createUseCase().execute({ + userUuid: '00000000-0000-0000-0000-000000000000', + sessionUuid: '00000000-0000-0000-0000-000000000000', + connectionId: '2-3-4', + }) + + expect(result.isFailed()).toBe(true) + }) + + it('should return failure if the user uuid is invalid', async () => { + const result = await createUseCase().execute({ + userUuid: 'invalid', + sessionUuid: '00000000-0000-0000-0000-000000000000', + connectionId: '2-3-4', + }) + + expect(result.isFailed()).toBe(true) + }) + + it('should return error if the session uuid is invalid', async () => { + const result = await createUseCase().execute({ + userUuid: '00000000-0000-0000-0000-000000000000', + sessionUuid: 'invalid', + connectionId: '2-3-4', + }) expect(result.isFailed()).toBe(true) }) diff --git a/packages/websockets/src/Domain/UseCase/AddWebSocketsConnection/AddWebSocketsConnection.ts b/packages/websockets/src/Domain/UseCase/AddWebSocketsConnection/AddWebSocketsConnection.ts index 60a0d704a..85e48aaf0 100644 --- a/packages/websockets/src/Domain/UseCase/AddWebSocketsConnection/AddWebSocketsConnection.ts +++ b/packages/websockets/src/Domain/UseCase/AddWebSocketsConnection/AddWebSocketsConnection.ts @@ -1,24 +1,50 @@ -import { inject, injectable } from 'inversify' import { Logger } from 'winston' -import { Result, UseCaseInterface } from '@standardnotes/domain-core' +import { Result, Timestamps, UseCaseInterface, Uuid } from '@standardnotes/domain-core' +import { TimerInterface } from '@standardnotes/time' -import TYPES from '../../../Bootstrap/Types' import { WebSocketsConnectionRepositoryInterface } from '../../WebSockets/WebSocketsConnectionRepositoryInterface' import { AddWebSocketsConnectionDTO } from './AddWebSocketsConnectionDTO' +import { Connection } from '../../Connection/Connection' -@injectable() export class AddWebSocketsConnection implements UseCaseInterface { constructor( - @inject(TYPES.WebSocketsConnectionRepository) private webSocketsConnectionRepository: WebSocketsConnectionRepositoryInterface, - @inject(TYPES.Logger) private logger: Logger, + private timer: TimerInterface, + private logger: Logger, ) {} async execute(dto: AddWebSocketsConnectionDTO): Promise> { try { this.logger.debug(`Persisting connection ${dto.connectionId} for user ${dto.userUuid}`) - await this.webSocketsConnectionRepository.saveConnection(dto.userUuid, dto.connectionId) + const userUuidOrError = Uuid.create(dto.userUuid) + if (userUuidOrError.isFailed()) { + return Result.fail(userUuidOrError.getError()) + } + const userUuid = userUuidOrError.getValue() + + const sessionUuidOrError = Uuid.create(dto.sessionUuid) + if (sessionUuidOrError.isFailed()) { + return Result.fail(sessionUuidOrError.getError()) + } + const sessionUuid = sessionUuidOrError.getValue() + + const connectionOrError = Connection.create({ + userUuid, + sessionUuid, + connectionId: dto.connectionId, + timestamps: Timestamps.create( + this.timer.getTimestampInMicroseconds(), + this.timer.getTimestampInMicroseconds(), + ).getValue(), + }) + /* istanbul ignore next */ + if (connectionOrError.isFailed()) { + return Result.fail(connectionOrError.getError()) + } + const connection = connectionOrError.getValue() + + await this.webSocketsConnectionRepository.saveConnection(connection) return Result.ok() } catch (error) { diff --git a/packages/websockets/src/Domain/UseCase/AddWebSocketsConnection/AddWebSocketsConnectionDTO.ts b/packages/websockets/src/Domain/UseCase/AddWebSocketsConnection/AddWebSocketsConnectionDTO.ts index 6a9208d31..d8dd342e1 100644 --- a/packages/websockets/src/Domain/UseCase/AddWebSocketsConnection/AddWebSocketsConnectionDTO.ts +++ b/packages/websockets/src/Domain/UseCase/AddWebSocketsConnection/AddWebSocketsConnectionDTO.ts @@ -1,4 +1,5 @@ export type AddWebSocketsConnectionDTO = { userUuid: string + sessionUuid: string connectionId: string } diff --git a/packages/websockets/src/Domain/UseCase/SendMessageToClient/SendMessageToClient.spec.ts b/packages/websockets/src/Domain/UseCase/SendMessageToClient/SendMessageToClient.spec.ts new file mode 100644 index 000000000..67e681403 --- /dev/null +++ b/packages/websockets/src/Domain/UseCase/SendMessageToClient/SendMessageToClient.spec.ts @@ -0,0 +1,99 @@ +import { ApiGatewayManagementApiClient } from '@aws-sdk/client-apigatewaymanagementapi' +import { WebSocketsConnectionRepositoryInterface } from '../../WebSockets/WebSocketsConnectionRepositoryInterface' +import { SendMessageToClient } from './SendMessageToClient' +import { Logger } from 'winston' +import { Connection } from '../../Connection/Connection' +import { Timestamps, Uuid } from '@standardnotes/domain-core' + +describe('SendMessageToClient', () => { + let webSocketsConnectionRepository: WebSocketsConnectionRepositoryInterface + let apiGatewayManagementClient: ApiGatewayManagementApiClient + let logger: Logger + + const createUseCase = () => + new SendMessageToClient(webSocketsConnectionRepository, apiGatewayManagementClient, logger) + + beforeEach(() => { + const connection = Connection.create({ + userUuid: Uuid.create('00000000-0000-0000-0000-000000000000').getValue(), + connectionId: 'connection-id', + sessionUuid: Uuid.create('00000000-0000-0000-0000-000000000000').getValue(), + timestamps: Timestamps.create(123, 123).getValue(), + }).getValue() + + webSocketsConnectionRepository = {} as jest.Mocked + webSocketsConnectionRepository.findAllByUserUuid = jest.fn().mockResolvedValue([connection]) + + apiGatewayManagementClient = {} as jest.Mocked + apiGatewayManagementClient.send = jest.fn().mockResolvedValue({ $metadata: { httpStatusCode: 200 } }) + + logger = {} as jest.Mocked + logger.debug = jest.fn() + logger.error = jest.fn() + }) + + it('sends message to all connections for a user', async () => { + const useCase = createUseCase() + + const result = await useCase.execute({ + userUuid: '00000000-0000-0000-0000-000000000000', + message: 'message', + }) + + expect(result.isFailed()).toBe(false) + expect(apiGatewayManagementClient.send).toHaveBeenCalledTimes(1) + }) + + it('does not send message to originating session', async () => { + const useCase = createUseCase() + + const result = await useCase.execute({ + userUuid: '00000000-0000-0000-0000-000000000000', + message: 'message', + originatingSessionUuid: '00000000-0000-0000-0000-000000000000', + }) + + expect(result.isFailed()).toBe(false) + expect(apiGatewayManagementClient.send).toHaveBeenCalledTimes(0) + }) + + it('returns error if sending message fails', async () => { + apiGatewayManagementClient.send = jest.fn().mockRejectedValue(new Error('error')) + + const useCase = createUseCase() + + const result = await useCase.execute({ + userUuid: '00000000-0000-0000-0000-000000000000', + message: 'message', + }) + + expect(result.isFailed()).toBe(true) + expect(result.getError()).toBe( + 'Could not send message to connection connection-id for user 00000000-0000-0000-0000-000000000000. Error: error', + ) + }) + + it('returns error if the user uuid is invalid', async () => { + const useCase = createUseCase() + + const result = await useCase.execute({ + userUuid: 'invalid', + message: 'message', + }) + + expect(result.isFailed()).toBe(true) + }) + + it('return error if sending the message does not return a 200 status code', async () => { + apiGatewayManagementClient.send = jest.fn().mockResolvedValue({ $metadata: { httpStatusCode: 500 } }) + + const useCase = createUseCase() + + const result = await useCase.execute({ + userUuid: '00000000-0000-0000-0000-000000000000', + message: 'message', + }) + + expect(result.isFailed()).toBe(true) + }) +}) diff --git a/packages/websockets/src/Domain/UseCase/SendMessageToClient/SendMessageToClient.ts b/packages/websockets/src/Domain/UseCase/SendMessageToClient/SendMessageToClient.ts new file mode 100644 index 000000000..ce27d3548 --- /dev/null +++ b/packages/websockets/src/Domain/UseCase/SendMessageToClient/SendMessageToClient.ts @@ -0,0 +1,57 @@ +import { Result, UseCaseInterface, Uuid } from '@standardnotes/domain-core' +import { ApiGatewayManagementApiClient, PostToConnectionCommand } from '@aws-sdk/client-apigatewaymanagementapi' +import { Logger } from 'winston' + +import { SendMessageToClientDTO } from './SendMessageToClientDTO' +import { WebSocketsConnectionRepositoryInterface } from '../../WebSockets/WebSocketsConnectionRepositoryInterface' + +export class SendMessageToClient implements UseCaseInterface { + constructor( + private webSocketsConnectionRepository: WebSocketsConnectionRepositoryInterface, + private apiGatewayManagementClient: ApiGatewayManagementApiClient, + private logger: Logger, + ) {} + + async execute(dto: SendMessageToClientDTO): Promise> { + const userUuidOrError = Uuid.create(dto.userUuid) + if (userUuidOrError.isFailed()) { + return Result.fail(userUuidOrError.getError()) + } + const userUuid = userUuidOrError.getValue() + + const userConnections = await this.webSocketsConnectionRepository.findAllByUserUuid(userUuid) + + for (const connection of userConnections) { + if (dto.originatingSessionUuid && connection.props.sessionUuid.value === dto.originatingSessionUuid) { + continue + } + + this.logger.debug(`Sending message to connection ${connection.props.connectionId} for user ${userUuid.value}`) + + const requestParams = { + ConnectionId: connection.props.connectionId, + Data: dto.message, + } + + const command = new PostToConnectionCommand(requestParams) + + try { + const response = await this.apiGatewayManagementClient.send(command) + + if (response.$metadata.httpStatusCode !== 200) { + return Result.fail( + `Could not send message to connection ${connection.props.connectionId} for user ${userUuid.value}. Response status code: ${response.$metadata.httpStatusCode}`, + ) + } + } catch (error) { + return Result.fail( + `Could not send message to connection ${connection.props.connectionId} for user ${userUuid.value}. Error: ${ + (error as Error).message + }`, + ) + } + } + + return Result.ok() + } +} diff --git a/packages/websockets/src/Domain/UseCase/SendMessageToClient/SendMessageToClientDTO.ts b/packages/websockets/src/Domain/UseCase/SendMessageToClient/SendMessageToClientDTO.ts new file mode 100644 index 000000000..fa05d5f3e --- /dev/null +++ b/packages/websockets/src/Domain/UseCase/SendMessageToClient/SendMessageToClientDTO.ts @@ -0,0 +1,5 @@ +export interface SendMessageToClientDTO { + userUuid: string + message: string + originatingSessionUuid?: string +} diff --git a/packages/websockets/src/Domain/WebSockets/WebSocketsConnectionRepositoryInterface.ts b/packages/websockets/src/Domain/WebSockets/WebSocketsConnectionRepositoryInterface.ts index 2f59ea765..74d528fab 100644 --- a/packages/websockets/src/Domain/WebSockets/WebSocketsConnectionRepositoryInterface.ts +++ b/packages/websockets/src/Domain/WebSockets/WebSocketsConnectionRepositoryInterface.ts @@ -1,5 +1,8 @@ +import { Uuid } from '@standardnotes/domain-core' +import { Connection } from '../Connection/Connection' + export interface WebSocketsConnectionRepositoryInterface { - findAllByUserUuid(userUuid: string): Promise - saveConnection(userUuid: string, connectionId: string): Promise + findAllByUserUuid(userUuid: Uuid): Promise + saveConnection(connection: Connection): Promise removeConnection(connectionId: string): Promise } diff --git a/packages/websockets/src/Infra/InversifyExpressUtils/AnnotatedWebSocketsController.ts b/packages/websockets/src/Infra/InversifyExpressUtils/AnnotatedWebSocketsController.ts index 7fec54062..a7e7131dd 100644 --- a/packages/websockets/src/Infra/InversifyExpressUtils/AnnotatedWebSocketsController.ts +++ b/packages/websockets/src/Infra/InversifyExpressUtils/AnnotatedWebSocketsController.ts @@ -1,4 +1,3 @@ -import { WebSocketServerInterface } from '@standardnotes/api' import { Request, Response } from 'express' import { inject } from 'inversify' import { @@ -12,24 +11,26 @@ import { import TYPES from '../../Bootstrap/Types' import { AddWebSocketsConnection } from '../../Domain/UseCase/AddWebSocketsConnection/AddWebSocketsConnection' import { RemoveWebSocketsConnection } from '../../Domain/UseCase/RemoveWebSocketsConnection/RemoveWebSocketsConnection' +import { CreateWebSocketConnectionToken } from '../../Domain/UseCase/CreateWebSocketConnectionToken/CreateWebSocketConnectionToken' @controller('/sockets') export class AnnotatedWebSocketsController extends BaseHttpController { constructor( @inject(TYPES.AddWebSocketsConnection) private addWebSocketsConnection: AddWebSocketsConnection, @inject(TYPES.RemoveWebSocketsConnection) private removeWebSocketsConnection: RemoveWebSocketsConnection, - @inject(TYPES.WebSocketsController) private webSocketsController: WebSocketServerInterface, + @inject(TYPES.CreateWebSocketConnectionToken) + private createWebSocketConnectionToken: CreateWebSocketConnectionToken, ) { super() } @httpPost('/tokens', TYPES.ApiGatewayAuthMiddleware) async createConnectionToken(_request: Request, response: Response): Promise { - const result = await this.webSocketsController.createConnectionToken({ + const result = await this.createWebSocketConnectionToken.execute({ userUuid: response.locals.user.uuid, }) - return this.json(result.data, result.status) + return this.json(result) } @httpPost('/connections/:connectionId', TYPES.ApiGatewayAuthMiddleware) @@ -39,6 +40,7 @@ export class AnnotatedWebSocketsController extends BaseHttpController { ): Promise { const result = await this.addWebSocketsConnection.execute({ userUuid: response.locals.user.uuid, + sessionUuid: response.locals.session.uuid, connectionId: request.params.connectionId, }) diff --git a/packages/websockets/src/Controller/ApiGatewayAuthMiddleware.ts b/packages/websockets/src/Infra/InversifyExpressUtils/Middleware/ApiGatewayAuthMiddleware.ts similarity index 93% rename from packages/websockets/src/Controller/ApiGatewayAuthMiddleware.ts rename to packages/websockets/src/Infra/InversifyExpressUtils/Middleware/ApiGatewayAuthMiddleware.ts index 2665df1c3..c8c4cecff 100644 --- a/packages/websockets/src/Controller/ApiGatewayAuthMiddleware.ts +++ b/packages/websockets/src/Infra/InversifyExpressUtils/Middleware/ApiGatewayAuthMiddleware.ts @@ -3,7 +3,7 @@ import { NextFunction, Request, Response } from 'express' import { inject, injectable } from 'inversify' import { BaseMiddleware } from 'inversify-express-utils' import { Logger } from 'winston' -import TYPES from '../Bootstrap/Types' +import TYPES from '../../../Bootstrap/Types' @injectable() export class ApiGatewayAuthMiddleware extends BaseMiddleware { @@ -33,7 +33,7 @@ export class ApiGatewayAuthMiddleware extends BaseMiddleware { request.headers['x-auth-token'] as string, ) - if (token === undefined) { + if (token === undefined || token.session === undefined) { this.logger.debug('ApiGatewayAuthMiddleware authentication failure.') response.status(401).send({ diff --git a/packages/websockets/src/Infra/Redis/RedisWebSocketsConnectionRepository.spec.ts b/packages/websockets/src/Infra/Redis/RedisWebSocketsConnectionRepository.spec.ts deleted file mode 100644 index 4db514670..000000000 --- a/packages/websockets/src/Infra/Redis/RedisWebSocketsConnectionRepository.spec.ts +++ /dev/null @@ -1,44 +0,0 @@ -import 'reflect-metadata' - -import * as IORedis from 'ioredis' - -import { RedisWebSocketsConnectionRepository } from './RedisWebSocketsConnectionRepository' - -describe('RedisWebSocketsConnectionRepository', () => { - let redisClient: IORedis.Redis - - const createRepository = () => new RedisWebSocketsConnectionRepository(redisClient) - - beforeEach(() => { - redisClient = {} as jest.Mocked - redisClient.sadd = jest.fn() - redisClient.set = jest.fn() - redisClient.get = jest.fn() - redisClient.srem = jest.fn() - redisClient.del = jest.fn() - redisClient.smembers = jest.fn() - }) - - it('should save a connection to set of user connections', async () => { - await createRepository().saveConnection('1-2-3', '2-3-4') - - expect(redisClient.sadd).toHaveBeenCalledWith('ws_user_connections:1-2-3', '2-3-4') - expect(redisClient.set).toHaveBeenCalledWith('ws_connection:2-3-4', '1-2-3') - }) - - it('should remove a connection from the set of user connections', async () => { - redisClient.get = jest.fn().mockReturnValue('1-2-3') - - await createRepository().removeConnection('2-3-4') - - expect(redisClient.srem).toHaveBeenCalledWith('ws_user_connections:1-2-3', '2-3-4') - expect(redisClient.del).toHaveBeenCalledWith('ws_connection:2-3-4') - }) - - it('should return all connections for a user uuid', async () => { - const userUuid = '1-2-3' - - await createRepository().findAllByUserUuid(userUuid) - expect(redisClient.smembers).toHaveBeenCalledWith(`ws_user_connections:${userUuid}`) - }) -}) diff --git a/packages/websockets/src/Infra/Redis/RedisWebSocketsConnectionRepository.ts b/packages/websockets/src/Infra/Redis/RedisWebSocketsConnectionRepository.ts deleted file mode 100644 index 5e080475d..000000000 --- a/packages/websockets/src/Infra/Redis/RedisWebSocketsConnectionRepository.ts +++ /dev/null @@ -1,28 +0,0 @@ -import * as IORedis from 'ioredis' -import { inject, injectable } from 'inversify' -import TYPES from '../../Bootstrap/Types' -import { WebSocketsConnectionRepositoryInterface } from '../../Domain/WebSockets/WebSocketsConnectionRepositoryInterface' - -@injectable() -export class RedisWebSocketsConnectionRepository implements WebSocketsConnectionRepositoryInterface { - private readonly WEB_SOCKETS_USER_CONNECTIONS_PREFIX = 'ws_user_connections' - private readonly WEB_SOCKETS_CONNETION_PREFIX = 'ws_connection' - - constructor(@inject(TYPES.Redis) private redisClient: IORedis.Redis) {} - - async findAllByUserUuid(userUuid: string): Promise { - return await this.redisClient.smembers(`${this.WEB_SOCKETS_USER_CONNECTIONS_PREFIX}:${userUuid}`) - } - - async removeConnection(connectionId: string): Promise { - const userUuid = await this.redisClient.get(`${this.WEB_SOCKETS_CONNETION_PREFIX}:${connectionId}`) - - await this.redisClient.srem(`${this.WEB_SOCKETS_USER_CONNECTIONS_PREFIX}:${userUuid}`, connectionId) - await this.redisClient.del(`${this.WEB_SOCKETS_CONNETION_PREFIX}:${connectionId}`) - } - - async saveConnection(userUuid: string, connectionId: string): Promise { - await this.redisClient.set(`${this.WEB_SOCKETS_CONNETION_PREFIX}:${connectionId}`, userUuid) - await this.redisClient.sadd(`${this.WEB_SOCKETS_USER_CONNECTIONS_PREFIX}:${userUuid}`, connectionId) - } -} diff --git a/packages/websockets/src/Infra/TypeORM/SQLConnection.ts b/packages/websockets/src/Infra/TypeORM/SQLConnection.ts new file mode 100644 index 000000000..b128ff0d3 --- /dev/null +++ b/packages/websockets/src/Infra/TypeORM/SQLConnection.ts @@ -0,0 +1,42 @@ +import { Column, Entity, Index, PrimaryGeneratedColumn } from 'typeorm' + +@Entity({ name: 'connections' }) +export class SQLConnection { + @PrimaryGeneratedColumn('uuid') + declare uuid: string + + @Column({ + name: 'user_uuid', + type: 'varchar', + length: 36, + }) + @Index('index_connections_on_user_uuid') + declare userUuid: string + + @Column({ + name: 'session_uuid', + type: 'varchar', + length: 36, + }) + declare sessionUuid: string + + @Column({ + name: 'connection_id', + type: 'varchar', + length: 255, + }) + @Index('index_connections_on_connection_id', { unique: true }) + declare connectionId: string + + @Column({ + name: 'created_at_timestamp', + type: 'bigint', + }) + declare createdAtTimestamp: number + + @Column({ + name: 'updated_at_timestamp', + type: 'bigint', + }) + declare updatedAtTimestamp: number +} diff --git a/packages/websockets/src/Infra/TypeORM/SQLConnectionRepository.ts b/packages/websockets/src/Infra/TypeORM/SQLConnectionRepository.ts new file mode 100644 index 000000000..edb684a3a --- /dev/null +++ b/packages/websockets/src/Infra/TypeORM/SQLConnectionRepository.ts @@ -0,0 +1,38 @@ +import { Repository } from 'typeorm' +import { WebSocketsConnectionRepositoryInterface } from '../../Domain/WebSockets/WebSocketsConnectionRepositoryInterface' +import { SQLConnection } from './SQLConnection' +import { MapperInterface, Uuid } from '@standardnotes/domain-core' +import { Connection } from '../../Domain/Connection/Connection' +import { Logger } from 'winston' + +export class SQLConnectionRepository implements WebSocketsConnectionRepositoryInterface { + constructor( + protected ormRepository: Repository, + protected mapper: MapperInterface, + protected logger: Logger, + ) {} + + async findAllByUserUuid(userUuid: Uuid): Promise { + const persistence = await this.ormRepository + .createQueryBuilder() + .where('user_uuid = :userUuid', { userUuid: userUuid.value }) + .getMany() + + return persistence.map((p) => this.mapper.toDomain(p)) + } + + async saveConnection(connection: Connection): Promise { + const persistence = this.mapper.toProjection(connection) + + await this.ormRepository.save(persistence) + } + + async removeConnection(connectionId: string): Promise { + await this.ormRepository + .createQueryBuilder() + .delete() + .from(SQLConnection) + .where('connection_id = :connectionId', { connectionId }) + .execute() + } +} diff --git a/packages/websockets/src/Infra/WebSockets/WebSocketsClientMessenger.ts b/packages/websockets/src/Infra/WebSockets/WebSocketsClientMessenger.ts deleted file mode 100644 index c3019e301..000000000 --- a/packages/websockets/src/Infra/WebSockets/WebSocketsClientMessenger.ts +++ /dev/null @@ -1,49 +0,0 @@ -import { ApiGatewayManagementApiClient, PostToConnectionCommand } from '@aws-sdk/client-apigatewaymanagementapi' -import { Logger } from 'winston' -import { inject, injectable } from 'inversify' - -import TYPES from '../../Bootstrap/Types' -import { WebSocketsConnectionRepositoryInterface } from '../../Domain/WebSockets/WebSocketsConnectionRepositoryInterface' -import { ClientMessengerInterface } from '../../Client/ClientMessengerInterface' - -@injectable() -export class WebSocketsClientMessenger implements ClientMessengerInterface { - constructor( - @inject(TYPES.WebSocketsConnectionRepository) - private webSocketsConnectionRepository: WebSocketsConnectionRepositoryInterface, - @inject(TYPES.WebSockets_ApiGatewayManagementApiClient) - private apiGatewayManagementClient: ApiGatewayManagementApiClient, - @inject(TYPES.Logger) private logger: Logger, - ) {} - - async send(userUuid: string, message: string): Promise { - const userConnections = await this.webSocketsConnectionRepository.findAllByUserUuid(userUuid) - - for (const connectionUuid of userConnections) { - this.logger.debug(`Sending message to connection ${connectionUuid} for user ${userUuid}`) - - const requestParams = { - ConnectionId: connectionUuid, - Data: message, - } - - const command = new PostToConnectionCommand(requestParams) - - try { - const response = await this.apiGatewayManagementClient.send(command) - - if (response.$metadata.httpStatusCode !== 200) { - this.logger.error( - `Could not send message to connection ${connectionUuid} for user ${userUuid}. Response status code: ${response.$metadata.httpStatusCode}`, - ) - } - } catch (error) { - this.logger.error( - `Could not send message to connection ${connectionUuid} for user ${userUuid}. Error: ${ - (error as Error).message - }`, - ) - } - } - } -} diff --git a/packages/websockets/src/Infra/WebSockets/WebSocketsClientService.spec.ts b/packages/websockets/src/Infra/WebSockets/WebSocketsClientService.spec.ts deleted file mode 100644 index e2329825f..000000000 --- a/packages/websockets/src/Infra/WebSockets/WebSocketsClientService.spec.ts +++ /dev/null @@ -1,54 +0,0 @@ -import 'reflect-metadata' - -import { ApiGatewayManagementApiClient } from '@aws-sdk/client-apigatewaymanagementapi' - -import { WebSocketsConnectionRepositoryInterface } from '../../Domain/WebSockets/WebSocketsConnectionRepositoryInterface' -import { Logger } from 'winston' - -import { WebSocketsClientMessenger } from './WebSocketsClientMessenger' - -describe('WebSocketsClientMessenger', () => { - let connectionIds: string[] - let webSocketsConnectionRepository: WebSocketsConnectionRepositoryInterface - let apiGatewayManagementClient: ApiGatewayManagementApiClient - let logger: Logger - - const createService = () => - new WebSocketsClientMessenger(webSocketsConnectionRepository, apiGatewayManagementClient, logger) - - beforeEach(() => { - connectionIds = ['1', '2'] - - webSocketsConnectionRepository = {} as jest.Mocked - webSocketsConnectionRepository.findAllByUserUuid = jest.fn().mockReturnValue(connectionIds) - - apiGatewayManagementClient = {} as jest.Mocked - apiGatewayManagementClient.send = jest.fn().mockReturnValue({ $metadata: { httpStatusCode: 200 } }) - - logger = {} as jest.Mocked - logger.debug = jest.fn() - logger.error = jest.fn() - }) - - it('should send a message to all user connections', async () => { - await createService().send('1-2-3', 'message') - - expect(apiGatewayManagementClient.send).toHaveBeenCalledTimes(connectionIds.length) - }) - - it('should log an error if message could not be sent', async () => { - apiGatewayManagementClient.send = jest.fn().mockReturnValue({ $metadata: { httpStatusCode: 500 } }) - - await createService().send('1-2-3', 'message') - - expect(logger.error).toHaveBeenCalledTimes(connectionIds.length) - }) - - it('should log an error if message sending throws error', async () => { - apiGatewayManagementClient.send = jest.fn().mockRejectedValue(new Error('error')) - - await createService().send('1-2-3', 'message') - - expect(logger.error).toHaveBeenCalledTimes(connectionIds.length) - }) -}) diff --git a/packages/websockets/src/Mapping/SQL/ConnectionPersistenceMapper.ts b/packages/websockets/src/Mapping/SQL/ConnectionPersistenceMapper.ts new file mode 100644 index 000000000..a14cb763e --- /dev/null +++ b/packages/websockets/src/Mapping/SQL/ConnectionPersistenceMapper.ts @@ -0,0 +1,50 @@ +import { MapperInterface, Timestamps, Uuid } from '@standardnotes/domain-core' + +import { Connection } from '../../Domain/Connection/Connection' +import { SQLConnection } from '../../Infra/TypeORM/SQLConnection' + +export class ConnectionPersistenceMapper implements MapperInterface { + toDomain(projection: SQLConnection): Connection { + const userUuidOrError = Uuid.create(projection.userUuid) + if (userUuidOrError.isFailed()) { + throw new Error(`Failed to create connection from projection: ${userUuidOrError.getError()}`) + } + const userUuid = userUuidOrError.getValue() + + const sessionUuidOrError = Uuid.create(projection.sessionUuid) + if (sessionUuidOrError.isFailed()) { + throw new Error(`Failed to create connection from projection: ${sessionUuidOrError.getError()}`) + } + const sessionUuid = sessionUuidOrError.getValue() + + const timestampsOrError = Timestamps.create(projection.createdAtTimestamp, projection.updatedAtTimestamp) + if (timestampsOrError.isFailed()) { + throw new Error(`Failed to create connection from projection: ${timestampsOrError.getError()}`) + } + const timestamps = timestampsOrError.getValue() + + const connectionOrError = Connection.create({ + userUuid, + sessionUuid, + connectionId: projection.connectionId, + timestamps, + }) + if (connectionOrError.isFailed()) { + throw new Error(`Failed to create connection from projection: ${connectionOrError.getError()}`) + } + + return connectionOrError.getValue() + } + + toProjection(domain: Connection): SQLConnection { + const projection = new SQLConnection() + + projection.userUuid = domain.props.userUuid.value + projection.sessionUuid = domain.props.sessionUuid.value + projection.connectionId = domain.props.connectionId + projection.createdAtTimestamp = domain.props.timestamps.createdAt + projection.updatedAtTimestamp = domain.props.timestamps.updatedAt + + return projection + } +} diff --git a/yarn.lock b/yarn.lock index 34315c750..fc4b623b0 100644 --- a/yarn.lock +++ b/yarn.lock @@ -5887,13 +5887,13 @@ __metadata: dependencies: "@aws-sdk/client-apigatewaymanagementapi": "npm:^3.427.0" "@aws-sdk/client-sqs": "npm:^3.427.0" - "@standardnotes/api": "npm:^1.26.26" "@standardnotes/common": "workspace:^" "@standardnotes/domain-core": "workspace:^" "@standardnotes/domain-events": "workspace:^" "@standardnotes/domain-events-infra": "workspace:^" "@standardnotes/responses": "npm:^1.13.27" "@standardnotes/security": "workspace:^" + "@standardnotes/time": "workspace:^" "@types/cors": "npm:^2.8.9" "@types/express": "npm:^4.17.14" "@types/ioredis": "npm:^5.0.0"