Container.ts 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341
  1. import * as winston from 'winston'
  2. import Redis from 'ioredis'
  3. import { captureAWSv3Client } from 'aws-xray-sdk'
  4. import { SNSClient, SNSClientConfig } from '@aws-sdk/client-sns'
  5. import { SQSClient, SQSClientConfig } from '@aws-sdk/client-sqs'
  6. import { S3Client, S3ClientConfig } from '@aws-sdk/client-s3'
  7. import { Container } from 'inversify'
  8. import { Env } from './Env'
  9. import TYPES from './Types'
  10. import { UploadFileChunk } from '../Domain/UseCase/UploadFileChunk/UploadFileChunk'
  11. import { ValetTokenAuthMiddleware } from '../Infra/InversifyExpress/Middleware/ValetTokenAuthMiddleware'
  12. import { TokenDecoder, TokenDecoderInterface, ValetTokenData } from '@standardnotes/security'
  13. import { Timer, TimerInterface } from '@standardnotes/time'
  14. import { DomainEventFactoryInterface } from '../Domain/Event/DomainEventFactoryInterface'
  15. import { DomainEventFactory } from '../Domain/Event/DomainEventFactory'
  16. import {
  17. DirectCallDomainEventPublisher,
  18. DirectCallEventMessageHandler,
  19. SNSDomainEventPublisher,
  20. SQSDomainEventSubscriberFactory,
  21. SQSEventMessageHandler,
  22. SQSXRayEventMessageHandler,
  23. } from '@standardnotes/domain-events-infra'
  24. import { StreamDownloadFile } from '../Domain/UseCase/StreamDownloadFile/StreamDownloadFile'
  25. import { FileDownloaderInterface } from '../Domain/Services/FileDownloaderInterface'
  26. import { S3FileDownloader } from '../Infra/S3/S3FileDownloader'
  27. import { FileUploaderInterface } from '../Domain/Services/FileUploaderInterface'
  28. import { S3FileUploader } from '../Infra/S3/S3FileUploader'
  29. import { FSFileDownloader } from '../Infra/FS/FSFileDownloader'
  30. import { FSFileUploader } from '../Infra/FS/FSFileUploader'
  31. import { CreateUploadSession } from '../Domain/UseCase/CreateUploadSession/CreateUploadSession'
  32. import { FinishUploadSession } from '../Domain/UseCase/FinishUploadSession/FinishUploadSession'
  33. import { UploadRepositoryInterface } from '../Domain/Upload/UploadRepositoryInterface'
  34. import { RedisUploadRepository } from '../Infra/Redis/RedisUploadRepository'
  35. import { GetFileMetadata } from '../Domain/UseCase/GetFileMetadata/GetFileMetadata'
  36. import { FileRemoverInterface } from '../Domain/Services/FileRemoverInterface'
  37. import { S3FileRemover } from '../Infra/S3/S3FileRemover'
  38. import { FSFileRemover } from '../Infra/FS/FSFileRemover'
  39. import { RemoveFile } from '../Domain/UseCase/RemoveFile/RemoveFile'
  40. import {
  41. DomainEventHandlerInterface,
  42. DomainEventMessageHandlerInterface,
  43. DomainEventPublisherInterface,
  44. DomainEventSubscriberFactoryInterface,
  45. } from '@standardnotes/domain-events'
  46. import { MarkFilesToBeRemoved } from '../Domain/UseCase/MarkFilesToBeRemoved/MarkFilesToBeRemoved'
  47. import { AccountDeletionRequestedEventHandler } from '../Domain/Handler/AccountDeletionRequestedEventHandler'
  48. import { SharedSubscriptionInvitationCanceledEventHandler } from '../Domain/Handler/SharedSubscriptionInvitationCanceledEventHandler'
  49. import { InMemoryUploadRepository } from '../Infra/InMemory/InMemoryUploadRepository'
  50. import { Transform } from 'stream'
  51. import { FileMoverInterface } from '../Domain/Services/FileMoverInterface'
  52. import { S3FileMover } from '../Infra/S3/S3FileMover'
  53. import { FSFileMover } from '../Infra/FS/FSFileMover'
  54. import { MoveFile } from '../Domain/UseCase/MoveFile/MoveFile'
  55. import { SharedVaultValetTokenAuthMiddleware } from '../Infra/InversifyExpress/Middleware/SharedVaultValetTokenAuthMiddleware'
  56. import { ServiceIdentifier } from '@standardnotes/domain-core'
  57. export class ContainerConfigLoader {
  58. async load(configuration?: {
  59. directCallDomainEventPublisher?: DirectCallDomainEventPublisher
  60. logger?: Transform
  61. environmentOverrides?: { [name: string]: string }
  62. }): Promise<Container> {
  63. const directCallDomainEventPublisher =
  64. configuration?.directCallDomainEventPublisher ?? new DirectCallDomainEventPublisher()
  65. const env: Env = new Env(configuration?.environmentOverrides)
  66. env.load()
  67. const container = new Container()
  68. if (env.get('NEW_RELIC_ENABLED', true) === 'true') {
  69. await import('newrelic')
  70. }
  71. // env vars
  72. container.bind(TYPES.Files_VALET_TOKEN_SECRET).toConstantValue(env.get('VALET_TOKEN_SECRET'))
  73. container
  74. .bind(TYPES.Files_MAX_CHUNK_BYTES)
  75. .toConstantValue(env.get('MAX_CHUNK_BYTES', true) ? +env.get('MAX_CHUNK_BYTES', true) : 100000000)
  76. container.bind(TYPES.Files_VERSION).toConstantValue(env.get('VERSION', true) ?? 'development')
  77. container
  78. .bind(TYPES.Files_FILE_UPLOAD_PATH)
  79. .toConstantValue(env.get('FILE_UPLOAD_PATH', true) ?? `${__dirname}/../../uploads`)
  80. const isConfiguredForHomeServer = env.get('MODE', true) === 'home-server'
  81. const isConfiguredForSelfHosting = env.get('MODE', true) === 'self-hosted'
  82. const isConfiguredForInMemoryCache = env.get('CACHE_TYPE', true) === 'memory'
  83. const isConfiguredForAWSProduction = !isConfiguredForHomeServer && !isConfiguredForSelfHosting
  84. let logger: winston.Logger
  85. if (configuration?.logger) {
  86. logger = configuration.logger as winston.Logger
  87. } else {
  88. logger = this.createLogger({ env })
  89. }
  90. container.bind<winston.Logger>(TYPES.Files_Logger).toConstantValue(logger)
  91. container.bind<TimerInterface>(TYPES.Files_Timer).toConstantValue(new Timer())
  92. // services
  93. container
  94. .bind<TokenDecoderInterface<ValetTokenData>>(TYPES.Files_ValetTokenDecoder)
  95. .toConstantValue(new TokenDecoder<ValetTokenData>(container.get(TYPES.Files_VALET_TOKEN_SECRET)))
  96. container
  97. .bind<DomainEventFactoryInterface>(TYPES.Files_DomainEventFactory)
  98. .toConstantValue(new DomainEventFactory(container.get<TimerInterface>(TYPES.Files_Timer)))
  99. if (isConfiguredForInMemoryCache) {
  100. container
  101. .bind<UploadRepositoryInterface>(TYPES.Files_UploadRepository)
  102. .toConstantValue(new InMemoryUploadRepository(container.get(TYPES.Files_Timer)))
  103. } else {
  104. container.bind(TYPES.Files_REDIS_URL).toConstantValue(env.get('REDIS_URL'))
  105. const redisUrl = container.get(TYPES.Files_REDIS_URL) as string
  106. const isRedisInClusterMode = redisUrl.indexOf(',') > 0
  107. let redis
  108. if (isRedisInClusterMode) {
  109. redis = new Redis.Cluster(redisUrl.split(','))
  110. } else {
  111. redis = new Redis(redisUrl)
  112. }
  113. container.bind(TYPES.Files_Redis).toConstantValue(redis)
  114. container.bind<UploadRepositoryInterface>(TYPES.Files_UploadRepository).to(RedisUploadRepository)
  115. }
  116. if (isConfiguredForHomeServer) {
  117. container
  118. .bind<DomainEventPublisherInterface>(TYPES.Files_DomainEventPublisher)
  119. .toConstantValue(directCallDomainEventPublisher)
  120. } else {
  121. container.bind(TYPES.Files_S3_BUCKET_NAME).toConstantValue(env.get('S3_BUCKET_NAME', true))
  122. container.bind(TYPES.Files_S3_AWS_REGION).toConstantValue(env.get('S3_AWS_REGION', true))
  123. container.bind(TYPES.Files_SNS_TOPIC_ARN).toConstantValue(env.get('SNS_TOPIC_ARN'))
  124. container.bind(TYPES.Files_SNS_AWS_REGION).toConstantValue(env.get('SNS_AWS_REGION', true))
  125. container.bind(TYPES.Files_SQS_QUEUE_URL).toConstantValue(env.get('SQS_QUEUE_URL'))
  126. if (env.get('SNS_TOPIC_ARN', true)) {
  127. const snsConfig: SNSClientConfig = {
  128. apiVersion: 'latest',
  129. region: env.get('SNS_AWS_REGION', true),
  130. }
  131. if (env.get('SNS_ENDPOINT', true)) {
  132. snsConfig.endpoint = env.get('SNS_ENDPOINT', true)
  133. }
  134. if (env.get('SNS_ACCESS_KEY_ID', true) && env.get('SNS_SECRET_ACCESS_KEY', true)) {
  135. snsConfig.credentials = {
  136. accessKeyId: env.get('SNS_ACCESS_KEY_ID', true),
  137. secretAccessKey: env.get('SNS_SECRET_ACCESS_KEY', true),
  138. }
  139. }
  140. let snsClient = new SNSClient(snsConfig)
  141. if (isConfiguredForAWSProduction) {
  142. snsClient = captureAWSv3Client(snsClient)
  143. }
  144. container.bind<SNSClient>(TYPES.Files_SNS).toConstantValue(snsClient)
  145. }
  146. if (env.get('SQS_QUEUE_URL', true)) {
  147. const sqsConfig: SQSClientConfig = {
  148. region: env.get('SQS_AWS_REGION', true),
  149. }
  150. if (env.get('SQS_ENDPOINT', true)) {
  151. sqsConfig.endpoint = env.get('SQS_ENDPOINT', true)
  152. }
  153. if (env.get('SQS_ACCESS_KEY_ID', true) && env.get('SQS_SECRET_ACCESS_KEY', true)) {
  154. sqsConfig.credentials = {
  155. accessKeyId: env.get('SQS_ACCESS_KEY_ID', true),
  156. secretAccessKey: env.get('SQS_SECRET_ACCESS_KEY', true),
  157. }
  158. }
  159. let sqsClient = new SQSClient(sqsConfig)
  160. if (isConfiguredForAWSProduction) {
  161. sqsClient = captureAWSv3Client(sqsClient)
  162. }
  163. container.bind<SQSClient>(TYPES.Files_SQS).toConstantValue(sqsClient)
  164. }
  165. container
  166. .bind<DomainEventPublisherInterface>(TYPES.Files_DomainEventPublisher)
  167. .toConstantValue(
  168. new SNSDomainEventPublisher(container.get(TYPES.Files_SNS), container.get(TYPES.Files_SNS_TOPIC_ARN)),
  169. )
  170. }
  171. if (!isConfiguredForHomeServer && (env.get('S3_AWS_REGION', true) || env.get('S3_ENDPOINT', true))) {
  172. const s3Opts: S3ClientConfig = {
  173. apiVersion: 'latest',
  174. }
  175. if (env.get('S3_AWS_REGION', true)) {
  176. s3Opts.region = env.get('S3_AWS_REGION', true)
  177. }
  178. if (env.get('S3_ENDPOINT', true)) {
  179. s3Opts.endpoint = env.get('S3_ENDPOINT', true)
  180. }
  181. let s3Client = new S3Client(s3Opts)
  182. if (isConfiguredForAWSProduction) {
  183. s3Client = captureAWSv3Client(s3Client)
  184. }
  185. container.bind<S3Client>(TYPES.Files_S3).toConstantValue(s3Client)
  186. container.bind<FileDownloaderInterface>(TYPES.Files_FileDownloader).to(S3FileDownloader)
  187. container.bind<FileUploaderInterface>(TYPES.Files_FileUploader).to(S3FileUploader)
  188. container.bind<FileRemoverInterface>(TYPES.Files_FileRemover).to(S3FileRemover)
  189. container.bind<FileMoverInterface>(TYPES.Files_FileMover).to(S3FileMover)
  190. } else {
  191. container.bind<FileDownloaderInterface>(TYPES.Files_FileDownloader).to(FSFileDownloader)
  192. container
  193. .bind<FileUploaderInterface>(TYPES.Files_FileUploader)
  194. .toConstantValue(
  195. new FSFileUploader(container.get(TYPES.Files_FILE_UPLOAD_PATH), container.get(TYPES.Files_Logger)),
  196. )
  197. container
  198. .bind<FileRemoverInterface>(TYPES.Files_FileRemover)
  199. .toConstantValue(new FSFileRemover(container.get<string>(TYPES.Files_FILE_UPLOAD_PATH)))
  200. container.bind<FileMoverInterface>(TYPES.Files_FileMover).to(FSFileMover)
  201. }
  202. // use cases
  203. container.bind<UploadFileChunk>(TYPES.Files_UploadFileChunk).to(UploadFileChunk)
  204. container.bind<StreamDownloadFile>(TYPES.Files_StreamDownloadFile).to(StreamDownloadFile)
  205. container.bind<CreateUploadSession>(TYPES.Files_CreateUploadSession).to(CreateUploadSession)
  206. container
  207. .bind<FinishUploadSession>(TYPES.Files_FinishUploadSession)
  208. .toConstantValue(
  209. new FinishUploadSession(
  210. container.get(TYPES.Files_FileUploader),
  211. container.get(TYPES.Files_UploadRepository),
  212. container.get(TYPES.Files_DomainEventPublisher),
  213. container.get(TYPES.Files_DomainEventFactory),
  214. ),
  215. )
  216. container
  217. .bind<GetFileMetadata>(TYPES.Files_GetFileMetadata)
  218. .toConstantValue(
  219. new GetFileMetadata(
  220. container.get<FileDownloaderInterface>(TYPES.Files_FileDownloader),
  221. container.get<winston.Logger>(TYPES.Files_Logger),
  222. ),
  223. )
  224. container.bind<RemoveFile>(TYPES.Files_RemoveFile).to(RemoveFile)
  225. container
  226. .bind<MoveFile>(TYPES.Files_MoveFile)
  227. .toConstantValue(
  228. new MoveFile(
  229. container.get<GetFileMetadata>(TYPES.Files_GetFileMetadata),
  230. container.get<FileMoverInterface>(TYPES.Files_FileMover),
  231. container.get<DomainEventPublisherInterface>(TYPES.Files_DomainEventPublisher),
  232. container.get<DomainEventFactoryInterface>(TYPES.Files_DomainEventFactory),
  233. container.get<winston.Logger>(TYPES.Files_Logger),
  234. ),
  235. )
  236. container.bind<MarkFilesToBeRemoved>(TYPES.Files_MarkFilesToBeRemoved).to(MarkFilesToBeRemoved)
  237. // middleware
  238. container.bind<ValetTokenAuthMiddleware>(TYPES.Files_ValetTokenAuthMiddleware).to(ValetTokenAuthMiddleware)
  239. container
  240. .bind<SharedVaultValetTokenAuthMiddleware>(TYPES.Files_SharedVaultValetTokenAuthMiddleware)
  241. .to(SharedVaultValetTokenAuthMiddleware)
  242. // Handlers
  243. container
  244. .bind<AccountDeletionRequestedEventHandler>(TYPES.Files_AccountDeletionRequestedEventHandler)
  245. .toConstantValue(
  246. new AccountDeletionRequestedEventHandler(
  247. container.get<MarkFilesToBeRemoved>(TYPES.Files_MarkFilesToBeRemoved),
  248. container.get<DomainEventPublisherInterface>(TYPES.Files_DomainEventPublisher),
  249. container.get<DomainEventFactoryInterface>(TYPES.Files_DomainEventFactory),
  250. container.get<winston.Logger>(TYPES.Files_Logger),
  251. ),
  252. )
  253. container
  254. .bind<SharedSubscriptionInvitationCanceledEventHandler>(
  255. TYPES.Files_SharedSubscriptionInvitationCanceledEventHandler,
  256. )
  257. .toConstantValue(
  258. new SharedSubscriptionInvitationCanceledEventHandler(
  259. container.get<MarkFilesToBeRemoved>(TYPES.Files_MarkFilesToBeRemoved),
  260. container.get<DomainEventPublisherInterface>(TYPES.Files_DomainEventPublisher),
  261. container.get<DomainEventFactoryInterface>(TYPES.Files_DomainEventFactory),
  262. container.get<winston.Logger>(TYPES.Files_Logger),
  263. ),
  264. )
  265. const eventHandlers: Map<string, DomainEventHandlerInterface> = new Map([
  266. ['ACCOUNT_DELETION_REQUESTED', container.get(TYPES.Files_AccountDeletionRequestedEventHandler)],
  267. [
  268. 'SHARED_SUBSCRIPTION_INVITATION_CANCELED',
  269. container.get(TYPES.Files_SharedSubscriptionInvitationCanceledEventHandler),
  270. ],
  271. ])
  272. if (isConfiguredForHomeServer) {
  273. const directCallEventMessageHandler = new DirectCallEventMessageHandler(
  274. eventHandlers,
  275. container.get(TYPES.Files_Logger),
  276. )
  277. directCallDomainEventPublisher.register(directCallEventMessageHandler)
  278. container
  279. .bind<DomainEventMessageHandlerInterface>(TYPES.Files_DomainEventMessageHandler)
  280. .toConstantValue(directCallEventMessageHandler)
  281. } else {
  282. container
  283. .bind<DomainEventMessageHandlerInterface>(TYPES.Files_DomainEventMessageHandler)
  284. .toConstantValue(
  285. env.get('NEW_RELIC_ENABLED', true) === 'true'
  286. ? new SQSXRayEventMessageHandler(
  287. ServiceIdentifier.NAMES.FilesWorker,
  288. eventHandlers,
  289. container.get(TYPES.Files_Logger),
  290. )
  291. : new SQSEventMessageHandler(eventHandlers, container.get(TYPES.Files_Logger)),
  292. )
  293. container
  294. .bind<DomainEventSubscriberFactoryInterface>(TYPES.Files_DomainEventSubscriberFactory)
  295. .toConstantValue(
  296. new SQSDomainEventSubscriberFactory(
  297. container.get(TYPES.Files_SQS),
  298. container.get(TYPES.Files_SQS_QUEUE_URL),
  299. container.get(TYPES.Files_DomainEventMessageHandler),
  300. ),
  301. )
  302. }
  303. logger.debug('Configuration complete')
  304. return container
  305. }
  306. createLogger({ env }: { env: Env }): winston.Logger {
  307. return winston.createLogger({
  308. level: env.get('LOG_LEVEL', true) || 'info',
  309. format: winston.format.combine(winston.format.splat(), winston.format.json()),
  310. transports: [new winston.transports.Console({ level: env.get('LOG_LEVEL', true) || 'info' })],
  311. defaultMeta: { service: 'files' },
  312. })
  313. }
  314. }