AuthMiddleware.ts 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. import { CrossServiceTokenData } from '@standardnotes/security'
  2. import { RoleName } from '@standardnotes/domain-core'
  3. import { TimerInterface } from '@standardnotes/time'
  4. import { NextFunction, Request, Response } from 'express'
  5. import { inject, injectable } from 'inversify'
  6. import { BaseMiddleware } from 'inversify-express-utils'
  7. import { verify } from 'jsonwebtoken'
  8. import { AxiosError } from 'axios'
  9. import { Logger } from 'winston'
  10. import { TYPES } from '../Bootstrap/Types'
  11. import { CrossServiceTokenCacheInterface } from '../Service/Cache/CrossServiceTokenCacheInterface'
  12. import { ServiceProxyInterface } from '../Service/Http/ServiceProxyInterface'
  /**
   * Express middleware that authenticates a request from its `Authorization` header.
   *
   * The header value is exchanged for a cross-service JWT, either from the cross-service token
   * cache (when a cache TTL is configured) or by calling `validateSession` on the service proxy.
   * The JWT is verified with the configured secret and, on success, the following
   * `response.locals` entries are populated before `next()` is called:
   * `authToken`, `freeUser`, `user` and `roles`.
   *
   * On failure the middleware answers the request itself and does not call `next()`.
   */
  @injectable()
  export class AuthMiddleware extends BaseMiddleware {
    constructor(
      @inject(TYPES.ServiceProxy) private serviceProxy: ServiceProxyInterface,
      @inject(TYPES.AUTH_JWT_SECRET) private jwtSecret: string,
      @inject(TYPES.CROSS_SERVICE_TOKEN_CACHE_TTL) private crossServiceTokenCacheTTL: number,
      @inject(TYPES.CrossServiceTokenCache) private crossServiceTokenCache: CrossServiceTokenCacheInterface,
      @inject(TYPES.Timer) private timer: TimerInterface,
      @inject(TYPES.Logger) private logger: Logger,
    ) {
      super()
    }

    /**
     * Authenticates the request and either forwards it via `next()` or sends an error response.
     *
     * @param request - Incoming request; only `headers.authorization` is read.
     * @param response - Response whose `locals` receive the token data, or which carries the error.
     * @param next - Invoked only when the token was obtained and verified successfully.
     */
    async handler(request: Request, response: Response, next: NextFunction): Promise<void> {
      const authHeaderValue = request.headers.authorization as string
      if (!authHeaderValue) {
        response.status(401).send({
          error: {
            tag: 'invalid-auth',
            message: 'Invalid login credentials.',
          },
        })

        return
      }

      try {
        let crossServiceTokenFetchedFromCache = true
        let crossServiceToken = null

        // A falsy TTL disables the cache entirely: nothing is read from it and nothing is written to it.
        if (this.crossServiceTokenCacheTTL) {
          crossServiceToken = await this.crossServiceTokenCache.get(authHeaderValue)
        }

        if (crossServiceToken === null) {
          const authResponse = await this.serviceProxy.validateSession(authHeaderValue)

          // Any status above 200 is relayed to the client unchanged.
          // NOTE(review): this also relays other 2xx statuses and assumes `headers.contentType`
          // is always set by the service proxy - confirm against ServiceProxyInterface.
          if (authResponse.status > 200) {
            response.setHeader('content-type', authResponse.headers.contentType)
            response.status(authResponse.status).send(authResponse.data)

            return
          }

          crossServiceToken = (authResponse.data as { authToken: string }).authToken
          crossServiceTokenFetchedFromCache = false
        }

        response.locals.authToken = crossServiceToken

        const decodedToken = verify(crossServiceToken, this.jwtSecret, {
          algorithms: ['HS256'],
        }) as CrossServiceTokenData

        // A user counts as "free" when the core user role is the only role on the token.
        response.locals.freeUser =
          decodedToken.roles.length === 1 &&
          decodedToken.roles.some((role) => role.name === RoleName.NAMES.CoreUser)

        // Only tokens freshly obtained from the underlying service are written to the cache.
        if (this.crossServiceTokenCacheTTL && !crossServiceTokenFetchedFromCache) {
          await this.crossServiceTokenCache.set({
            authorizationHeaderValue: authHeaderValue,
            encodedCrossServiceToken: crossServiceToken,
            expiresAtInSeconds: this.getCrossServiceTokenCacheExpireTimestamp(decodedToken),
            userUuid: decodedToken.user.uuid,
          })
        }

        response.locals.user = decodedToken.user
        response.locals.roles = decodedToken.roles
      } catch (error) {
        const axiosError = error as AxiosError
        const errorMessage = axiosError.isAxiosError
          ? JSON.stringify(axiosError.response?.data)
          : (error as Error).message

        // JSON.stringify(undefined) evaluates to undefined (axios error without a response, e.g. a
        // network failure), so the log falls back to the error's own message instead of "undefined".
        this.logger.error(
          `Could not pass the request to sessions/validate on underlying service: ${
            errorMessage ?? axiosError.message
          }`,
        )
        this.logger.debug('Response error: %O', axiosError.response ?? error)

        const upstreamContentType = axiosError.response?.headers['content-type']
        if (upstreamContentType) {
          response.setHeader('content-type', upstreamContentType as string)
        }

        response.status(this.getErrorStatusCode(error)).send(errorMessage)

        return
      }

      return next()
    }

    /**
     * Picks the HTTP status to answer with when obtaining or verifying the token failed.
     *
     * For axios errors the status of the upstream response is preferred. `AxiosError.code` holds a
     * string identifier such as `ECONNABORTED` rather than an HTTP status, so it is only honoured
     * as a fallback when it happens to be numeric. Anything else results in 500.
     *
     * @param error - The value caught by the handler.
     * @returns An integer status between 100 and 599.
     */
    private getErrorStatusCode(error: unknown): number {
      const fallbackStatus = 500
      const axiosError = error as AxiosError
      if (!axiosError.isAxiosError) {
        return fallbackStatus
      }

      const candidates = [axiosError.response?.status, Number(axiosError.code)]
      for (const candidate of candidates) {
        // The range guard keeps values such as 0 or NaN from ever reaching response.status().
        if (candidate !== undefined && Number.isInteger(candidate) && candidate >= 100 && candidate <= 599) {
          return candidate
        }
      }

      return fallbackStatus
    }

    /**
     * Computes the timestamp (in seconds) at which a cached cross-service token must expire.
     *
     * @param token - Decoded cross-service token data.
     * @returns The current time plus the configured TTL when the token carries no session;
     * otherwise the earliest of that value and the session's access and refresh expirations.
     */
    private getCrossServiceTokenCacheExpireTimestamp(token: CrossServiceTokenData): number {
      const crossServiceTokenDefaultCacheExpiration = this.timer.getTimestampInSeconds() + this.crossServiceTokenCacheTTL

      if (token.session === undefined) {
        return crossServiceTokenDefaultCacheExpiration
      }

      const sessionAccessExpiration = this.timer.convertStringDateToSeconds(token.session.access_expiration)
      const sessionRefreshExpiration = this.timer.convertStringDateToSeconds(token.session.refresh_expiration)

      // The cache entry must never outlive the session it was issued for.
      return Math.min(crossServiceTokenDefaultCacheExpiration, sessionAccessExpiration, sessionRefreshExpiration)
    }
  }