Container.ts 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. import * as winston from 'winston'
  2. import Redis from 'ioredis'
  3. import { captureAWSv3Client } from 'aws-xray-sdk'
  4. import { SNSClient, SNSClientConfig } from '@aws-sdk/client-sns'
  5. import { SQSClient, SQSClientConfig } from '@aws-sdk/client-sqs'
  6. import { S3Client, S3ClientConfig } from '@aws-sdk/client-s3'
  7. import { Container } from 'inversify'
  8. import { Env } from './Env'
  9. import TYPES from './Types'
  10. import { UploadFileChunk } from '../Domain/UseCase/UploadFileChunk/UploadFileChunk'
  11. import { ValetTokenAuthMiddleware } from '../Infra/InversifyExpress/Middleware/ValetTokenAuthMiddleware'
  12. import { TokenDecoder, TokenDecoderInterface, ValetTokenData } from '@standardnotes/security'
  13. import { Timer, TimerInterface } from '@standardnotes/time'
  14. import { DomainEventFactoryInterface } from '../Domain/Event/DomainEventFactoryInterface'
  15. import { DomainEventFactory } from '../Domain/Event/DomainEventFactory'
  16. import {
  17. DirectCallDomainEventPublisher,
  18. DirectCallEventMessageHandler,
  19. SNSDomainEventPublisher,
  20. SQSDomainEventSubscriberFactory,
  21. SQSEventMessageHandler,
  22. SQSXRayEventMessageHandler,
  23. } from '@standardnotes/domain-events-infra'
  24. import { StreamDownloadFile } from '../Domain/UseCase/StreamDownloadFile/StreamDownloadFile'
  25. import { FileDownloaderInterface } from '../Domain/Services/FileDownloaderInterface'
  26. import { S3FileDownloader } from '../Infra/S3/S3FileDownloader'
  27. import { FileUploaderInterface } from '../Domain/Services/FileUploaderInterface'
  28. import { S3FileUploader } from '../Infra/S3/S3FileUploader'
  29. import { FSFileDownloader } from '../Infra/FS/FSFileDownloader'
  30. import { FSFileUploader } from '../Infra/FS/FSFileUploader'
  31. import { CreateUploadSession } from '../Domain/UseCase/CreateUploadSession/CreateUploadSession'
  32. import { FinishUploadSession } from '../Domain/UseCase/FinishUploadSession/FinishUploadSession'
  33. import { UploadRepositoryInterface } from '../Domain/Upload/UploadRepositoryInterface'
  34. import { RedisUploadRepository } from '../Infra/Redis/RedisUploadRepository'
  35. import { GetFileMetadata } from '../Domain/UseCase/GetFileMetadata/GetFileMetadata'
  36. import { FileRemoverInterface } from '../Domain/Services/FileRemoverInterface'
  37. import { S3FileRemover } from '../Infra/S3/S3FileRemover'
  38. import { FSFileRemover } from '../Infra/FS/FSFileRemover'
  39. import { RemoveFile } from '../Domain/UseCase/RemoveFile/RemoveFile'
  40. import {
  41. DomainEventHandlerInterface,
  42. DomainEventMessageHandlerInterface,
  43. DomainEventPublisherInterface,
  44. DomainEventSubscriberFactoryInterface,
  45. } from '@standardnotes/domain-events'
  46. import { MarkFilesToBeRemoved } from '../Domain/UseCase/MarkFilesToBeRemoved/MarkFilesToBeRemoved'
  47. import { AccountDeletionRequestedEventHandler } from '../Domain/Handler/AccountDeletionRequestedEventHandler'
  48. import { SharedSubscriptionInvitationCanceledEventHandler } from '../Domain/Handler/SharedSubscriptionInvitationCanceledEventHandler'
  49. import { InMemoryUploadRepository } from '../Infra/InMemory/InMemoryUploadRepository'
  50. import { Transform } from 'stream'
  51. import { FileMoverInterface } from '../Domain/Services/FileMoverInterface'
  52. import { S3FileMover } from '../Infra/S3/S3FileMover'
  53. import { FSFileMover } from '../Infra/FS/FSFileMover'
  54. import { MoveFile } from '../Domain/UseCase/MoveFile/MoveFile'
  55. import { SharedVaultValetTokenAuthMiddleware } from '../Infra/InversifyExpress/Middleware/SharedVaultValetTokenAuthMiddleware'
  56. export class ContainerConfigLoader {
  57. async load(configuration?: {
  58. directCallDomainEventPublisher?: DirectCallDomainEventPublisher
  59. logger?: Transform
  60. environmentOverrides?: { [name: string]: string }
  61. }): Promise<Container> {
  62. const directCallDomainEventPublisher =
  63. configuration?.directCallDomainEventPublisher ?? new DirectCallDomainEventPublisher()
  64. const env: Env = new Env(configuration?.environmentOverrides)
  65. env.load()
  66. const container = new Container()
  67. if (env.get('NEW_RELIC_ENABLED', true) === 'true') {
  68. await import('newrelic')
  69. }
  70. // env vars
  71. container.bind(TYPES.Files_VALET_TOKEN_SECRET).toConstantValue(env.get('VALET_TOKEN_SECRET'))
  72. container
  73. .bind(TYPES.Files_MAX_CHUNK_BYTES)
  74. .toConstantValue(env.get('MAX_CHUNK_BYTES', true) ? +env.get('MAX_CHUNK_BYTES', true) : 100000000)
  75. container.bind(TYPES.Files_VERSION).toConstantValue(env.get('VERSION', true) ?? 'development')
  76. container
  77. .bind(TYPES.Files_FILE_UPLOAD_PATH)
  78. .toConstantValue(env.get('FILE_UPLOAD_PATH', true) ?? `${__dirname}/../../uploads`)
  79. const isConfiguredForHomeServer = env.get('MODE', true) === 'home-server'
  80. const isConfiguredForSelfHosting = env.get('MODE', true) === 'self-hosted'
  81. const isConfiguredForInMemoryCache = env.get('CACHE_TYPE', true) === 'memory'
  82. const isConfiguredForAWSProduction = !isConfiguredForHomeServer && !isConfiguredForSelfHosting
  83. let logger: winston.Logger
  84. if (configuration?.logger) {
  85. logger = configuration.logger as winston.Logger
  86. } else {
  87. logger = this.createLogger({ env })
  88. }
  89. container.bind<winston.Logger>(TYPES.Files_Logger).toConstantValue(logger)
  90. container.bind<TimerInterface>(TYPES.Files_Timer).toConstantValue(new Timer())
  91. // services
  92. container
  93. .bind<TokenDecoderInterface<ValetTokenData>>(TYPES.Files_ValetTokenDecoder)
  94. .toConstantValue(new TokenDecoder<ValetTokenData>(container.get(TYPES.Files_VALET_TOKEN_SECRET)))
  95. container
  96. .bind<DomainEventFactoryInterface>(TYPES.Files_DomainEventFactory)
  97. .toConstantValue(new DomainEventFactory(container.get<TimerInterface>(TYPES.Files_Timer)))
  98. if (isConfiguredForInMemoryCache) {
  99. container
  100. .bind<UploadRepositoryInterface>(TYPES.Files_UploadRepository)
  101. .toConstantValue(new InMemoryUploadRepository(container.get(TYPES.Files_Timer)))
  102. } else {
  103. container.bind(TYPES.Files_REDIS_URL).toConstantValue(env.get('REDIS_URL'))
  104. const redisUrl = container.get(TYPES.Files_REDIS_URL) as string
  105. const isRedisInClusterMode = redisUrl.indexOf(',') > 0
  106. let redis
  107. if (isRedisInClusterMode) {
  108. redis = new Redis.Cluster(redisUrl.split(','))
  109. } else {
  110. redis = new Redis(redisUrl)
  111. }
  112. container.bind(TYPES.Files_Redis).toConstantValue(redis)
  113. container.bind<UploadRepositoryInterface>(TYPES.Files_UploadRepository).to(RedisUploadRepository)
  114. }
  115. if (isConfiguredForHomeServer) {
  116. container
  117. .bind<DomainEventPublisherInterface>(TYPES.Files_DomainEventPublisher)
  118. .toConstantValue(directCallDomainEventPublisher)
  119. } else {
  120. container.bind(TYPES.Files_S3_BUCKET_NAME).toConstantValue(env.get('S3_BUCKET_NAME', true))
  121. container.bind(TYPES.Files_S3_AWS_REGION).toConstantValue(env.get('S3_AWS_REGION', true))
  122. container.bind(TYPES.Files_SNS_TOPIC_ARN).toConstantValue(env.get('SNS_TOPIC_ARN'))
  123. container.bind(TYPES.Files_SNS_AWS_REGION).toConstantValue(env.get('SNS_AWS_REGION', true))
  124. container.bind(TYPES.Files_SQS_QUEUE_URL).toConstantValue(env.get('SQS_QUEUE_URL'))
  125. if (env.get('SNS_TOPIC_ARN', true)) {
  126. const snsConfig: SNSClientConfig = {
  127. apiVersion: 'latest',
  128. region: env.get('SNS_AWS_REGION', true),
  129. }
  130. if (env.get('SNS_ENDPOINT', true)) {
  131. snsConfig.endpoint = env.get('SNS_ENDPOINT', true)
  132. }
  133. if (env.get('SNS_ACCESS_KEY_ID', true) && env.get('SNS_SECRET_ACCESS_KEY', true)) {
  134. snsConfig.credentials = {
  135. accessKeyId: env.get('SNS_ACCESS_KEY_ID', true),
  136. secretAccessKey: env.get('SNS_SECRET_ACCESS_KEY', true),
  137. }
  138. }
  139. let snsClient = new SNSClient(snsConfig)
  140. if (isConfiguredForAWSProduction) {
  141. snsClient = captureAWSv3Client(snsClient)
  142. }
  143. container.bind<SNSClient>(TYPES.Files_SNS).toConstantValue(snsClient)
  144. }
  145. if (env.get('SQS_QUEUE_URL', true)) {
  146. const sqsConfig: SQSClientConfig = {
  147. region: env.get('SQS_AWS_REGION', true),
  148. }
  149. if (env.get('SQS_ENDPOINT', true)) {
  150. sqsConfig.endpoint = env.get('SQS_ENDPOINT', true)
  151. }
  152. if (env.get('SQS_ACCESS_KEY_ID', true) && env.get('SQS_SECRET_ACCESS_KEY', true)) {
  153. sqsConfig.credentials = {
  154. accessKeyId: env.get('SQS_ACCESS_KEY_ID', true),
  155. secretAccessKey: env.get('SQS_SECRET_ACCESS_KEY', true),
  156. }
  157. }
  158. let sqsClient = new SQSClient(sqsConfig)
  159. if (isConfiguredForAWSProduction) {
  160. sqsClient = captureAWSv3Client(sqsClient)
  161. }
  162. container.bind<SQSClient>(TYPES.Files_SQS).toConstantValue(sqsClient)
  163. }
  164. container
  165. .bind<DomainEventPublisherInterface>(TYPES.Files_DomainEventPublisher)
  166. .toConstantValue(
  167. new SNSDomainEventPublisher(container.get(TYPES.Files_SNS), container.get(TYPES.Files_SNS_TOPIC_ARN)),
  168. )
  169. }
  170. if (!isConfiguredForHomeServer && (env.get('S3_AWS_REGION', true) || env.get('S3_ENDPOINT', true))) {
  171. const s3Opts: S3ClientConfig = {
  172. apiVersion: 'latest',
  173. }
  174. if (env.get('S3_AWS_REGION', true)) {
  175. s3Opts.region = env.get('S3_AWS_REGION', true)
  176. }
  177. if (env.get('S3_ENDPOINT', true)) {
  178. s3Opts.endpoint = env.get('S3_ENDPOINT', true)
  179. }
  180. let s3Client = new S3Client(s3Opts)
  181. if (isConfiguredForAWSProduction) {
  182. s3Client = captureAWSv3Client(s3Client)
  183. }
  184. container.bind<S3Client>(TYPES.Files_S3).toConstantValue(s3Client)
  185. container.bind<FileDownloaderInterface>(TYPES.Files_FileDownloader).to(S3FileDownloader)
  186. container.bind<FileUploaderInterface>(TYPES.Files_FileUploader).to(S3FileUploader)
  187. container.bind<FileRemoverInterface>(TYPES.Files_FileRemover).to(S3FileRemover)
  188. container.bind<FileMoverInterface>(TYPES.Files_FileMover).to(S3FileMover)
  189. } else {
  190. container.bind<FileDownloaderInterface>(TYPES.Files_FileDownloader).to(FSFileDownloader)
  191. container
  192. .bind<FileUploaderInterface>(TYPES.Files_FileUploader)
  193. .toConstantValue(
  194. new FSFileUploader(container.get(TYPES.Files_FILE_UPLOAD_PATH), container.get(TYPES.Files_Logger)),
  195. )
  196. container
  197. .bind<FileRemoverInterface>(TYPES.Files_FileRemover)
  198. .toConstantValue(new FSFileRemover(container.get<string>(TYPES.Files_FILE_UPLOAD_PATH)))
  199. container.bind<FileMoverInterface>(TYPES.Files_FileMover).to(FSFileMover)
  200. }
  201. // use cases
  202. container.bind<UploadFileChunk>(TYPES.Files_UploadFileChunk).to(UploadFileChunk)
  203. container.bind<StreamDownloadFile>(TYPES.Files_StreamDownloadFile).to(StreamDownloadFile)
  204. container.bind<CreateUploadSession>(TYPES.Files_CreateUploadSession).to(CreateUploadSession)
  205. container
  206. .bind<FinishUploadSession>(TYPES.Files_FinishUploadSession)
  207. .toConstantValue(
  208. new FinishUploadSession(
  209. container.get(TYPES.Files_FileUploader),
  210. container.get(TYPES.Files_UploadRepository),
  211. container.get(TYPES.Files_DomainEventPublisher),
  212. container.get(TYPES.Files_DomainEventFactory),
  213. ),
  214. )
  215. container
  216. .bind<GetFileMetadata>(TYPES.Files_GetFileMetadata)
  217. .toConstantValue(
  218. new GetFileMetadata(
  219. container.get<FileDownloaderInterface>(TYPES.Files_FileDownloader),
  220. container.get<winston.Logger>(TYPES.Files_Logger),
  221. ),
  222. )
  223. container.bind<RemoveFile>(TYPES.Files_RemoveFile).to(RemoveFile)
  224. container
  225. .bind<MoveFile>(TYPES.Files_MoveFile)
  226. .toConstantValue(
  227. new MoveFile(
  228. container.get<GetFileMetadata>(TYPES.Files_GetFileMetadata),
  229. container.get<FileMoverInterface>(TYPES.Files_FileMover),
  230. container.get<DomainEventPublisherInterface>(TYPES.Files_DomainEventPublisher),
  231. container.get<DomainEventFactoryInterface>(TYPES.Files_DomainEventFactory),
  232. container.get<winston.Logger>(TYPES.Files_Logger),
  233. ),
  234. )
  235. container.bind<MarkFilesToBeRemoved>(TYPES.Files_MarkFilesToBeRemoved).to(MarkFilesToBeRemoved)
  236. // middleware
  237. container.bind<ValetTokenAuthMiddleware>(TYPES.Files_ValetTokenAuthMiddleware).to(ValetTokenAuthMiddleware)
  238. container
  239. .bind<SharedVaultValetTokenAuthMiddleware>(TYPES.Files_SharedVaultValetTokenAuthMiddleware)
  240. .to(SharedVaultValetTokenAuthMiddleware)
  241. // Handlers
  242. container
  243. .bind<AccountDeletionRequestedEventHandler>(TYPES.Files_AccountDeletionRequestedEventHandler)
  244. .toConstantValue(
  245. new AccountDeletionRequestedEventHandler(
  246. container.get<MarkFilesToBeRemoved>(TYPES.Files_MarkFilesToBeRemoved),
  247. container.get<DomainEventPublisherInterface>(TYPES.Files_DomainEventPublisher),
  248. container.get<DomainEventFactoryInterface>(TYPES.Files_DomainEventFactory),
  249. container.get<winston.Logger>(TYPES.Files_Logger),
  250. ),
  251. )
  252. container
  253. .bind<SharedSubscriptionInvitationCanceledEventHandler>(
  254. TYPES.Files_SharedSubscriptionInvitationCanceledEventHandler,
  255. )
  256. .toConstantValue(
  257. new SharedSubscriptionInvitationCanceledEventHandler(
  258. container.get<MarkFilesToBeRemoved>(TYPES.Files_MarkFilesToBeRemoved),
  259. container.get<DomainEventPublisherInterface>(TYPES.Files_DomainEventPublisher),
  260. container.get<DomainEventFactoryInterface>(TYPES.Files_DomainEventFactory),
  261. container.get<winston.Logger>(TYPES.Files_Logger),
  262. ),
  263. )
  264. const eventHandlers: Map<string, DomainEventHandlerInterface> = new Map([
  265. ['ACCOUNT_DELETION_REQUESTED', container.get(TYPES.Files_AccountDeletionRequestedEventHandler)],
  266. [
  267. 'SHARED_SUBSCRIPTION_INVITATION_CANCELED',
  268. container.get(TYPES.Files_SharedSubscriptionInvitationCanceledEventHandler),
  269. ],
  270. ])
  271. if (isConfiguredForHomeServer) {
  272. const directCallEventMessageHandler = new DirectCallEventMessageHandler(
  273. eventHandlers,
  274. container.get(TYPES.Files_Logger),
  275. )
  276. directCallDomainEventPublisher.register(directCallEventMessageHandler)
  277. container
  278. .bind<DomainEventMessageHandlerInterface>(TYPES.Files_DomainEventMessageHandler)
  279. .toConstantValue(directCallEventMessageHandler)
  280. } else {
  281. container
  282. .bind<DomainEventMessageHandlerInterface>(TYPES.Files_DomainEventMessageHandler)
  283. .toConstantValue(
  284. env.get('NEW_RELIC_ENABLED', true) === 'true'
  285. ? new SQSXRayEventMessageHandler(eventHandlers, container.get(TYPES.Files_Logger))
  286. : new SQSEventMessageHandler(eventHandlers, container.get(TYPES.Files_Logger)),
  287. )
  288. container
  289. .bind<DomainEventSubscriberFactoryInterface>(TYPES.Files_DomainEventSubscriberFactory)
  290. .toConstantValue(
  291. new SQSDomainEventSubscriberFactory(
  292. container.get(TYPES.Files_SQS),
  293. container.get(TYPES.Files_SQS_QUEUE_URL),
  294. container.get(TYPES.Files_DomainEventMessageHandler),
  295. ),
  296. )
  297. }
  298. logger.debug('Configuration complete')
  299. return container
  300. }
  301. createLogger({ env }: { env: Env }): winston.Logger {
  302. return winston.createLogger({
  303. level: env.get('LOG_LEVEL', true) || 'info',
  304. format: winston.format.combine(winston.format.splat(), winston.format.json()),
  305. transports: [new winston.transports.Console({ level: env.get('LOG_LEVEL', true) || 'info' })],
  306. defaultMeta: { service: 'files' },
  307. })
  308. }
  309. }