Container.ts 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. import * as winston from 'winston'
  2. // eslint-disable-next-line @typescript-eslint/no-var-requires
  3. const axios = require('axios')
  4. import { AxiosInstance } from 'axios'
  5. import Redis from 'ioredis'
  6. import { SQSClient, SQSClientConfig } from '@aws-sdk/client-sqs'
  7. import { ApiGatewayManagementApiClient } from '@aws-sdk/client-apigatewaymanagementapi'
  8. import { Container } from 'inversify'
  9. import {
  10. DomainEventHandlerInterface,
  11. DomainEventMessageHandlerInterface,
  12. DomainEventSubscriberFactoryInterface,
  13. } from '@standardnotes/domain-events'
  14. import { Env } from './Env'
  15. import TYPES from './Types'
  16. import { WebSocketsConnectionRepositoryInterface } from '../Domain/WebSockets/WebSocketsConnectionRepositoryInterface'
  17. import { RedisWebSocketsConnectionRepository } from '../Infra/Redis/RedisWebSocketsConnectionRepository'
  18. import { AddWebSocketsConnection } from '../Domain/UseCase/AddWebSocketsConnection/AddWebSocketsConnection'
  19. import { RemoveWebSocketsConnection } from '../Domain/UseCase/RemoveWebSocketsConnection/RemoveWebSocketsConnection'
  20. import { WebSocketsClientMessenger } from '../Infra/WebSockets/WebSocketsClientMessenger'
  21. import {
  22. OpenTelemetrySDK,
  23. OpenTelemetrySDKInterface,
  24. SQSDomainEventSubscriberFactory,
  25. SQSOpenTelemetryEventMessageHandler,
  26. } from '@standardnotes/domain-events-infra'
  27. import { ApiGatewayAuthMiddleware } from '../Controller/ApiGatewayAuthMiddleware'
  28. import {
  29. CrossServiceTokenData,
  30. TokenDecoder,
  31. TokenDecoderInterface,
  32. TokenEncoder,
  33. TokenEncoderInterface,
  34. WebSocketConnectionTokenData,
  35. } from '@standardnotes/security'
  36. import { CreateWebSocketConnectionToken } from '../Domain/UseCase/CreateWebSocketConnectionToken/CreateWebSocketConnectionToken'
  37. import { WebSocketsController } from '../Controller/WebSocketsController'
  38. import { WebSocketServerInterface } from '@standardnotes/api'
  39. import { ClientMessengerInterface } from '../Client/ClientMessengerInterface'
  40. import { WebSocketMessageRequestedEventHandler } from '../Domain/Handler/WebSocketMessageRequestedEventHandler'
  41. import { ServiceIdentifier } from '@standardnotes/domain-core'
  /**
   * Builds the Inversify container for the websockets service.
   *
   * The `mode` passed to the constructor only selects which service name is reported to
   * OpenTelemetry (`Websockets` for `'server'`, `WebsocketsWorker` for `'worker'`); every other
   * binding is registered identically in both modes.
   */
  export class ContainerConfigLoader {
    constructor(private readonly mode: 'server' | 'worker' = 'server') {}

    /**
     * Loads the environment and registers every binding on a fresh container.
     *
     * The helpers are invoked in a fixed order on purpose: several of them resolve values
     * (`container.get(...)`) that an earlier helper has bound.
     *
     * @returns the fully configured container
     */
    async load(): Promise<Container> {
      const env: Env = new Env()
      env.load()

      const container = new Container()

      this.bindOpenTelemetry(container)
      this.bindRedis(container, env)
      await this.bindLogger(container, env)
      this.bindAwsClients(container, env)
      this.bindApplicationComponents(container, env)
      this.bindServices(container)
      this.bindDomainEventSubscription(container)

      return container
    }

    /** Registers the OpenTelemetry SDK under the service name matching the current mode. */
    private bindOpenTelemetry(container: Container): void {
      const serviceName =
        this.mode === 'server' ? ServiceIdentifier.NAMES.Websockets : ServiceIdentifier.NAMES.WebsocketsWorker

      container
        .bind<OpenTelemetrySDKInterface>(TYPES.WebSockets_OpenTelemetrySDK)
        .toConstantValue(new OpenTelemetrySDK(serviceName))
    }

    /**
     * Registers the Redis client. A `REDIS_URL` containing a comma (not in first position) is
     * treated as a comma-separated list of cluster nodes; anything else is a single-node URL.
     */
    private bindRedis(container: Container, env: Env): void {
      const redisUrl = env.get('REDIS_URL')
      const isRedisInClusterMode = redisUrl.indexOf(',') > 0

      const redis = isRedisInClusterMode
        ? new Redis.Cluster(ContainerConfigLoader.parseRedisClusterNodes(redisUrl))
        : new Redis(redisUrl)

      container.bind(TYPES.Redis).toConstantValue(redis)
    }

    /**
     * Splits a comma-separated node list, trimming whitespace and dropping empty entries so that
     * values such as `"a, b"` or `"a,b,"` do not yield blank or padded startup nodes.
     */
    private static parseRedisClusterNodes(redisUrl: string): string[] {
      return redisUrl
        .split(',')
        .map((node) => node.trim())
        .filter((node) => node.length > 0)
    }

    /**
     * Registers the winston logger. When `NEW_RELIC_ENABLED` equals `'true'` the New Relic agent is
     * imported and its winston enricher is appended to the log formatters.
     */
    private async bindLogger(container: Container, env: Env): Promise<void> {
      const winstonFormatters = [winston.format.splat(), winston.format.json()]

      if (env.get('NEW_RELIC_ENABLED', true) === 'true') {
        await import('newrelic')
        // eslint-disable-next-line @typescript-eslint/no-var-requires
        const newrelicFormatter = require('@newrelic/winston-enricher')
        const newrelicWinstonFormatter = newrelicFormatter(winston)
        winstonFormatters.push(newrelicWinstonFormatter())
      }

      // Read once: the same level applies to the logger and to its console transport.
      const logLevel = env.get('LOG_LEVEL', true) || 'info'

      const logger = winston.createLogger({
        level: logLevel,
        format: winston.format.combine(...winstonFormatters),
        transports: [new winston.transports.Console({ level: logLevel })],
      })

      container.bind<winston.Logger>(TYPES.Logger).toConstantValue(logger)
    }

    /**
     * Registers the SQS client (only when `SQS_QUEUE_URL` is set), the websockets API URL and the
     * API Gateway management client.
     */
    private bindAwsClients(container: Container, env: Env): void {
      // NOTE(review): TYPES.SQS is bound only inside this condition, yet
      // bindDomainEventSubscription resolves it unconditionally — confirm that running without
      // SQS_QUEUE_URL is not an expected configuration.
      if (env.get('SQS_QUEUE_URL', true)) {
        const sqsConfig: SQSClientConfig = {
          region: env.get('SQS_AWS_REGION', true),
        }

        const sqsEndpoint = env.get('SQS_ENDPOINT', true)
        if (sqsEndpoint) {
          sqsConfig.endpoint = sqsEndpoint
        }

        // Explicit credentials are applied only when both halves of the key pair are present.
        const accessKeyId = env.get('SQS_ACCESS_KEY_ID', true)
        const secretAccessKey = env.get('SQS_SECRET_ACCESS_KEY', true)
        if (accessKeyId && secretAccessKey) {
          sqsConfig.credentials = { accessKeyId, secretAccessKey }
        }

        container.bind<SQSClient>(TYPES.SQS).toConstantValue(new SQSClient(sqsConfig))
      }

      container.bind(TYPES.WEBSOCKETS_API_URL).toConstantValue(env.get('WEBSOCKETS_API_URL', true))

      container.bind<ApiGatewayManagementApiClient>(TYPES.WebSockets_ApiGatewayManagementApiClient).toConstantValue(
        new ApiGatewayManagementApiClient({
          endpoint: container.get(TYPES.WEBSOCKETS_API_URL),
          region: env.get('API_GATEWAY_AWS_REGION', true) ?? 'us-east-1',
        }),
      )
    }

    /**
     * Registers the controller, repository, middleware, environment-derived constants, use cases
     * and event handlers.
     */
    private bindApplicationComponents(container: Container, env: Env): void {
      // Controller
      container.bind<WebSocketServerInterface>(TYPES.WebSocketsController).to(WebSocketsController)

      // Repositories
      container
        .bind<WebSocketsConnectionRepositoryInterface>(TYPES.WebSocketsConnectionRepository)
        .to(RedisWebSocketsConnectionRepository)

      // Middleware
      container.bind<ApiGatewayAuthMiddleware>(TYPES.ApiGatewayAuthMiddleware).to(ApiGatewayAuthMiddleware)

      // env vars
      container.bind(TYPES.AUTH_JWT_SECRET).toConstantValue(env.get('AUTH_JWT_SECRET'))
      container
        .bind(TYPES.WEB_SOCKET_CONNECTION_TOKEN_SECRET)
        .toConstantValue(env.get('WEB_SOCKET_CONNECTION_TOKEN_SECRET', true))
      // NOTE(review): unary plus coerces the raw value to a number; if the variable can be absent
      // here the bound TTL would be NaN — confirm against Env.get and the consumers of this value.
      container
        .bind(TYPES.WEB_SOCKET_CONNECTION_TOKEN_TTL)
        .toConstantValue(+env.get('WEB_SOCKET_CONNECTION_TOKEN_TTL', true))
      container.bind(TYPES.REDIS_URL).toConstantValue(env.get('REDIS_URL'))
      container.bind(TYPES.SQS_QUEUE_URL).toConstantValue(env.get('SQS_QUEUE_URL'))
      container.bind(TYPES.NEW_RELIC_ENABLED).toConstantValue(env.get('NEW_RELIC_ENABLED', true))
      container.bind(TYPES.VERSION).toConstantValue(env.get('VERSION'))

      // use cases
      container.bind<AddWebSocketsConnection>(TYPES.AddWebSocketsConnection).to(AddWebSocketsConnection)
      container.bind<RemoveWebSocketsConnection>(TYPES.RemoveWebSocketsConnection).to(RemoveWebSocketsConnection)
      container
        .bind<CreateWebSocketConnectionToken>(TYPES.CreateWebSocketConnectionToken)
        .to(CreateWebSocketConnectionToken)

      // Handlers
      container
        .bind<WebSocketMessageRequestedEventHandler>(TYPES.WebSocketMessageRequestedEventHandler)
        .to(WebSocketMessageRequestedEventHandler)
    }

    /**
     * Registers the HTTP client, the token decoder/encoder (built from the secrets bound
     * earlier) and the client messenger.
     */
    private bindServices(container: Container): void {
      // Services
      container.bind<AxiosInstance>(TYPES.HTTPClient).toConstantValue(axios.create())

      container
        .bind<TokenDecoderInterface<CrossServiceTokenData>>(TYPES.CrossServiceTokenDecoder)
        .toConstantValue(new TokenDecoder<CrossServiceTokenData>(container.get(TYPES.AUTH_JWT_SECRET)))

      container
        .bind<TokenEncoderInterface<WebSocketConnectionTokenData>>(TYPES.WebSocketConnectionTokenEncoder)
        .toConstantValue(
          new TokenEncoder<WebSocketConnectionTokenData>(container.get(TYPES.WEB_SOCKET_CONNECTION_TOKEN_SECRET)),
        )

      container.bind<ClientMessengerInterface>(TYPES.WebSocketsClientMessenger).to(WebSocketsClientMessenger)
    }

    /**
     * Registers the domain event message handler and the SQS subscriber factory. The handler
     * for `WEB_SOCKET_MESSAGE_REQUESTED` is resolved from the container at this point, so all of
     * its dependencies must already be bound.
     */
    private bindDomainEventSubscription(container: Container): void {
      const eventHandlers: Map<string, DomainEventHandlerInterface> = new Map([
        ['WEB_SOCKET_MESSAGE_REQUESTED', container.get(TYPES.WebSocketMessageRequestedEventHandler)],
      ])

      container
        .bind<DomainEventMessageHandlerInterface>(TYPES.DomainEventMessageHandler)
        .toConstantValue(new SQSOpenTelemetryEventMessageHandler(eventHandlers, container.get(TYPES.Logger)))

      container
        .bind<DomainEventSubscriberFactoryInterface>(TYPES.DomainEventSubscriberFactory)
        .toConstantValue(
          new SQSDomainEventSubscriberFactory(
            container.get(TYPES.SQS),
            container.get(TYPES.SQS_QUEUE_URL),
            container.get(TYPES.DomainEventMessageHandler),
          ),
        )
    }
  }