WebSocketAuthMiddleware.ts 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. import { CrossServiceTokenData } from '@standardnotes/security'
  2. import { RoleName } from '@standardnotes/domain-core'
  3. import { NextFunction, Request, Response } from 'express'
  4. import { inject, injectable } from 'inversify'
  5. import { BaseMiddleware } from 'inversify-express-utils'
  6. import { verify } from 'jsonwebtoken'
  7. import { AxiosError, AxiosInstance } from 'axios'
  8. import { Logger } from 'winston'
  9. import { TYPES } from '../Bootstrap/Types'
  /**
   * inversify-express-utils middleware that authenticates a request by delegating to the
   * auth server's `/sockets/tokens/validate` endpoint.
   *
   * The request's `Authorization` header is POSTed to that endpoint. On a 200 response the
   * returned cross-service token is verified with the shared secret (HS256 only), and
   * `authToken`, `user`, `roles` and `freeUser` are placed on `response.locals` before
   * `next()` is called. On any other outcome an error response is sent and `next()` is not
   * called.
   */
  @injectable()
  export class WebSocketAuthMiddleware extends BaseMiddleware {
    constructor(
      @inject(TYPES.HTTPClient) private readonly httpClient: AxiosInstance,
      @inject(TYPES.AUTH_SERVER_URL) private readonly authServerUrl: string,
      @inject(TYPES.AUTH_JWT_SECRET) private readonly jwtSecret: string,
      @inject(TYPES.Logger) private readonly logger: Logger,
    ) {
      super()
    }

    /**
     * @param request - incoming request; its `Authorization` header is forwarded verbatim
     * @param response - used for error replies and to expose decoded token data via `locals`
     * @param next - invoked only after the token has been validated and verified
     */
    async handler(request: Request, response: Response, next: NextFunction): Promise<void> {
      const authHeaderValue = request.headers.authorization
      if (!authHeaderValue) {
        response.status(401).send({
          error: {
            tag: 'invalid-auth',
            message: 'Invalid login credentials.',
          },
        })

        return
      }

      try {
        const authResponse = await this.httpClient.request({
          method: 'POST',
          headers: {
            Authorization: authHeaderValue,
            Accept: 'application/json',
          },
          // 4xx responses resolve (rather than reject) so the auth server's rejection can be
          // relayed to the client as-is below; statuses outside 200-499 reject into `catch`.
          validateStatus: (status: number) => status >= 200 && status < 500,
          url: this.validationUrl(),
        })

        // Only a 200 carries a token; anything else (e.g. a 401) is relayed unchanged.
        if (authResponse.status !== 200) {
          WebSocketAuthMiddleware.copyContentType(response, authResponse.headers['content-type'])
          response.status(authResponse.status).send(authResponse.data)

          return
        }

        const crossServiceToken = authResponse.data.authToken
        const decodedToken = verify(crossServiceToken, this.jwtSecret, {
          algorithms: ['HS256'],
        }) as CrossServiceTokenData

        response.locals.authToken = crossServiceToken
        // A "free" user holds exactly one role, and that role is CoreUser.
        response.locals.freeUser =
          decodedToken.roles.length === 1 &&
          decodedToken.roles.some((role) => role.name === RoleName.NAMES.CoreUser)
        response.locals.user = decodedToken.user
        response.locals.roles = decodedToken.roles
      } catch (error) {
        this.replyWithValidationError(error, response)

        return
      }

      return next()
    }

    /** Full URL of the auth server's socket-token validation endpoint. */
    private validationUrl(): string {
      return `${this.authServerUrl}/sockets/tokens/validate`
    }

    /**
     * Logs a failed validation attempt and sends an error response to the client.
     *
     * For Axios errors, the upstream `content-type` (if any) is relayed and the JSON-stringified
     * upstream body is used as the message. Other errors use their message. The status is
     * chosen by `errorStatusCode`.
     */
    private replyWithValidationError(error: unknown, response: Response): void {
      const axiosError = WebSocketAuthMiddleware.isAxiosError(error) ? error : undefined
      const errorMessage = WebSocketAuthMiddleware.describeError(error)

      this.logger.error(`Could not pass the request to ${this.validationUrl()} on underlying service: ${errorMessage}`)
      this.logger.debug('Response error: %O', axiosError?.response ?? error)

      WebSocketAuthMiddleware.copyContentType(response, axiosError?.response?.headers['content-type'])
      response.status(WebSocketAuthMiddleware.errorStatusCode(axiosError)).send(errorMessage)
    }

    /**
     * Sets `content-type` on the outgoing response only when a value is present.
     * Node's `setHeader` throws on an `undefined` value, so an unguarded call would turn a
     * relayed 4xx without that header into a 500.
     */
    private static copyContentType(response: Response, contentType: unknown): void {
      if (contentType) {
        response.setHeader('content-type', contentType as string)
      }
    }

    /** Narrows an arbitrary thrown value to an Axios error without dereferencing non-objects. */
    private static isAxiosError(error: unknown): error is AxiosError {
      return typeof error === 'object' && error !== null && Boolean((error as AxiosError).isAxiosError)
    }

    /**
     * Builds a human-readable message for a thrown value.
     * For Axios errors the upstream body is preferred; when there is no response (e.g. a
     * network failure) `JSON.stringify` yields undefined, so the Axios message is used instead.
     */
    private static describeError(error: unknown): string {
      if (WebSocketAuthMiddleware.isAxiosError(error)) {
        return JSON.stringify(error.response?.data) ?? error.message
      }

      return error instanceof Error ? error.message : String(error)
    }

    /**
     * Chooses the HTTP status for an error reply.
     * A numeric Axios `code` is propagated only when it is a valid HTTP status (100-599).
     * Axios codes are normally symbolic (e.g. 'ECONNREFUSED'). A null or empty code would
     * otherwise coerce to 0, an invalid status. Everything else maps to 500.
     */
    private static errorStatusCode(error: AxiosError | undefined): number {
      const numericCode = Number(error?.code)

      return Number.isInteger(numericCode) && numericCode >= 100 && numericCode <= 599 ? numericCode : 500
    }
  }