Browse Source

feat(auth): add creating cross service token in exchange for web socket connection token

Karol Sójko 2 years ago
parent
commit
965ae79414

+ 2 - 0
packages/api-gateway/src/Bootstrap/Container.ts

@@ -23,6 +23,7 @@ import { SubscriptionTokenAuthMiddleware } from '../Controller/SubscriptionToken
 import { StatisticsMiddleware } from '../Controller/StatisticsMiddleware'
 import { StatisticsMiddleware } from '../Controller/StatisticsMiddleware'
 import { CrossServiceTokenCacheInterface } from '../Service/Cache/CrossServiceTokenCacheInterface'
 import { CrossServiceTokenCacheInterface } from '../Service/Cache/CrossServiceTokenCacheInterface'
 import { RedisCrossServiceTokenCache } from '../Infra/Redis/RedisCrossServiceTokenCache'
 import { RedisCrossServiceTokenCache } from '../Infra/Redis/RedisCrossServiceTokenCache'
+import { WebSocketAuthMiddleware } from '../Controller/WebSocketAuthMiddleware'
 
 
 // eslint-disable-next-line @typescript-eslint/no-var-requires
 // eslint-disable-next-line @typescript-eslint/no-var-requires
 const newrelicFormatter = require('@newrelic/winston-enricher')
 const newrelicFormatter = require('@newrelic/winston-enricher')
@@ -85,6 +86,7 @@ export class ContainerConfigLoader {
 
 
     // Middleware
     // Middleware
     container.bind<AuthMiddleware>(TYPES.AuthMiddleware).to(AuthMiddleware)
     container.bind<AuthMiddleware>(TYPES.AuthMiddleware).to(AuthMiddleware)
+    container.bind<WebSocketAuthMiddleware>(TYPES.WebSocketAuthMiddleware).to(WebSocketAuthMiddleware)
     container
     container
       .bind<SubscriptionTokenAuthMiddleware>(TYPES.SubscriptionTokenAuthMiddleware)
       .bind<SubscriptionTokenAuthMiddleware>(TYPES.SubscriptionTokenAuthMiddleware)
       .to(SubscriptionTokenAuthMiddleware)
       .to(SubscriptionTokenAuthMiddleware)

+ 1 - 0
packages/api-gateway/src/Bootstrap/Types.ts

@@ -18,6 +18,7 @@ const TYPES = {
   // Middleware
   // Middleware
   StatisticsMiddleware: Symbol.for('StatisticsMiddleware'),
   StatisticsMiddleware: Symbol.for('StatisticsMiddleware'),
   AuthMiddleware: Symbol.for('AuthMiddleware'),
   AuthMiddleware: Symbol.for('AuthMiddleware'),
+  WebSocketAuthMiddleware: Symbol.for('WebSocketAuthMiddleware'),
   SubscriptionTokenAuthMiddleware: Symbol.for('SubscriptionTokenAuthMiddleware'),
   SubscriptionTokenAuthMiddleware: Symbol.for('SubscriptionTokenAuthMiddleware'),
   // Services
   // Services
   HTTPService: Symbol.for('HTTPService'),
   HTTPService: Symbol.for('HTTPService'),

+ 95 - 0
packages/api-gateway/src/Controller/WebSocketAuthMiddleware.ts

@@ -0,0 +1,95 @@
+import { CrossServiceTokenData } from '@standardnotes/security'
+import { RoleName } from '@standardnotes/common'
+import { NextFunction, Request, Response } from 'express'
+import { inject, injectable } from 'inversify'
+import { BaseMiddleware } from 'inversify-express-utils'
+import { verify } from 'jsonwebtoken'
+import { AxiosError, AxiosInstance } from 'axios'
+import { Logger } from 'winston'
+
+import TYPES from '../Bootstrap/Types'
+
+@injectable()
+export class WebSocketAuthMiddleware extends BaseMiddleware {
+  constructor(
+    @inject(TYPES.HTTPClient) private httpClient: AxiosInstance,
+    @inject(TYPES.AUTH_SERVER_URL) private authServerUrl: string,
+    @inject(TYPES.AUTH_JWT_SECRET) private jwtSecret: string,
+    @inject(TYPES.Logger) private logger: Logger,
+  ) {
+    super()
+  }
+
+  async handler(request: Request, response: Response, next: NextFunction): Promise<void> {
+    const authHeaderValue = request.headers.authorization as string
+
+    if (!authHeaderValue) {
+      response.status(401).send({
+        error: {
+          tag: 'invalid-auth',
+          message: 'Invalid login credentials.',
+        },
+      })
+
+      return
+    }
+
+    try {
+      const authResponse = await this.httpClient.request({
+        method: 'POST',
+        headers: {
+          Authorization: authHeaderValue,
+          Accept: 'application/json',
+        },
+        validateStatus: (status: number) => {
+          return status >= 200 && status < 500
+        },
+        url: `${this.authServerUrl}/sockets/tokens/validate`,
+      })
+
+      if (authResponse.status > 200) {
+        response.setHeader('content-type', authResponse.headers['content-type'])
+        response.status(authResponse.status).send(authResponse.data)
+
+        return
+      }
+
+      const crossServiceToken = authResponse.data.authToken
+
+      response.locals.authToken = crossServiceToken
+
+      const decodedToken = <CrossServiceTokenData>verify(crossServiceToken, this.jwtSecret, { algorithms: ['HS256'] })
+
+      response.locals.freeUser =
+        decodedToken.roles.length === 1 &&
+        decodedToken.roles.find((role) => role.name === RoleName.CoreUser) !== undefined
+      response.locals.userUuid = decodedToken.user.uuid
+      response.locals.roles = decodedToken.roles
+    } catch (error) {
+      const errorMessage = (error as AxiosError).isAxiosError
+        ? JSON.stringify((error as AxiosError).response?.data)
+        : (error as Error).message
+
+      this.logger.error(
+        `Could not pass the request to ${this.authServerUrl}/sockets/tokens/validate on underlying service: ${errorMessage}`,
+      )
+
+      this.logger.debug('Response error: %O', (error as AxiosError).response ?? error)
+
+      if ((error as AxiosError).response?.headers['content-type']) {
+        response.setHeader('content-type', (error as AxiosError).response?.headers['content-type'] as string)
+      }
+
+      const errorCode =
+        (error as AxiosError).isAxiosError && !isNaN(+((error as AxiosError).code as string))
+          ? +((error as AxiosError).code as string)
+          : 500
+
+      response.status(errorCode).send(errorMessage)
+
+      return
+    }
+
+    return next()
+  }
+}

+ 1 - 1
packages/api-gateway/src/Controller/v1/WebSocketsController.ts

@@ -20,7 +20,7 @@ export class WebSocketsController extends BaseHttpController {
     await this.httpService.callAuthServer(request, response, 'sockets/tokens', request.body)
     await this.httpService.callAuthServer(request, response, 'sockets/tokens', request.body)
   }
   }
 
 
-  @httpPost('/', TYPES.AuthMiddleware)
+  @httpPost('/', TYPES.WebSocketAuthMiddleware)
   async createWebSocketConnection(request: Request, response: Response): Promise<void> {
   async createWebSocketConnection(request: Request, response: Response): Promise<void> {
     if (!request.headers.connectionid) {
     if (!request.headers.connectionid) {
       this.logger.error('Could not create a websocket connection. Missing connection id header.')
       this.logger.error('Could not create a websocket connection. Missing connection id header.')

+ 7 - 0
packages/auth/src/Bootstrap/Container.ts

@@ -212,6 +212,7 @@ import { SubscriptionInvitesController } from '../Controller/SubscriptionInvites
 import { CreateWebSocketConnectionToken } from '../Domain/UseCase/CreateWebSocketConnectionToken/CreateWebSocketConnectionToken'
 import { CreateWebSocketConnectionToken } from '../Domain/UseCase/CreateWebSocketConnectionToken/CreateWebSocketConnectionToken'
 import { WebSocketsController } from '../Controller/WebSocketsController'
 import { WebSocketsController } from '../Controller/WebSocketsController'
 import { WebSocketServerInterface } from '@standardnotes/api'
 import { WebSocketServerInterface } from '@standardnotes/api'
+import { CreateCrossServiceToken } from '../Domain/UseCase/CreateCrossServiceToken/CreateCrossServiceToken'
 
 
 // eslint-disable-next-line @typescript-eslint/no-var-requires
 // eslint-disable-next-line @typescript-eslint/no-var-requires
 const newrelicFormatter = require('@newrelic/winston-enricher')
 const newrelicFormatter = require('@newrelic/winston-enricher')
@@ -462,6 +463,7 @@ export class ContainerConfigLoader {
     container
     container
       .bind<CreateWebSocketConnectionToken>(TYPES.CreateWebSocketConnectionToken)
       .bind<CreateWebSocketConnectionToken>(TYPES.CreateWebSocketConnectionToken)
       .to(CreateWebSocketConnectionToken)
       .to(CreateWebSocketConnectionToken)
+    container.bind<CreateCrossServiceToken>(TYPES.CreateCrossServiceToken).to(CreateCrossServiceToken)
 
 
     // Handlers
     // Handlers
     container.bind<UserRegisteredEventHandler>(TYPES.UserRegisteredEventHandler).to(UserRegisteredEventHandler)
     container.bind<UserRegisteredEventHandler>(TYPES.UserRegisteredEventHandler).to(UserRegisteredEventHandler)
@@ -534,6 +536,11 @@ export class ContainerConfigLoader {
     container
     container
       .bind<TokenDecoderInterface<OfflineUserTokenData>>(TYPES.OfflineUserTokenDecoder)
       .bind<TokenDecoderInterface<OfflineUserTokenData>>(TYPES.OfflineUserTokenDecoder)
       .toConstantValue(new TokenDecoder<OfflineUserTokenData>(container.get(TYPES.AUTH_JWT_SECRET)))
       .toConstantValue(new TokenDecoder<OfflineUserTokenData>(container.get(TYPES.AUTH_JWT_SECRET)))
+    container
+      .bind<TokenDecoderInterface<WebSocketConnectionTokenData>>(TYPES.WebSocketConnectionTokenDecoder)
+      .toConstantValue(
+        new TokenDecoder<WebSocketConnectionTokenData>(container.get(TYPES.WEB_SOCKET_CONNECTION_TOKEN_SECRET)),
+      )
     container
     container
       .bind<TokenEncoderInterface<OfflineUserTokenData>>(TYPES.OfflineUserTokenEncoder)
       .bind<TokenEncoderInterface<OfflineUserTokenData>>(TYPES.OfflineUserTokenEncoder)
       .toConstantValue(new TokenEncoder<OfflineUserTokenData>(container.get(TYPES.AUTH_JWT_SECRET)))
       .toConstantValue(new TokenEncoder<OfflineUserTokenData>(container.get(TYPES.AUTH_JWT_SECRET)))

+ 2 - 0
packages/auth/src/Bootstrap/Types.ts

@@ -129,6 +129,7 @@ const TYPES = {
   GetUserAnalyticsId: Symbol.for('GetUserAnalyticsId'),
   GetUserAnalyticsId: Symbol.for('GetUserAnalyticsId'),
   VerifyPredicate: Symbol.for('VerifyPredicate'),
   VerifyPredicate: Symbol.for('VerifyPredicate'),
   CreateWebSocketConnectionToken: Symbol.for('CreateWebSocketConnectionToken'),
   CreateWebSocketConnectionToken: Symbol.for('CreateWebSocketConnectionToken'),
+  CreateCrossServiceToken: Symbol.for('CreateCrossServiceToken'),
   // Handlers
   // Handlers
   UserRegisteredEventHandler: Symbol.for('UserRegisteredEventHandler'),
   UserRegisteredEventHandler: Symbol.for('UserRegisteredEventHandler'),
   AccountDeletionRequestedEventHandler: Symbol.for('AccountDeletionRequestedEventHandler'),
   AccountDeletionRequestedEventHandler: Symbol.for('AccountDeletionRequestedEventHandler'),
@@ -171,6 +172,7 @@ const TYPES = {
   SessionTokenEncoder: Symbol.for('SessionTokenEncoder'),
   SessionTokenEncoder: Symbol.for('SessionTokenEncoder'),
   ValetTokenEncoder: Symbol.for('ValetTokenEncoder'),
   ValetTokenEncoder: Symbol.for('ValetTokenEncoder'),
   WebSocketConnectionTokenEncoder: Symbol.for('WebSocketConnectionTokenEncoder'),
   WebSocketConnectionTokenEncoder: Symbol.for('WebSocketConnectionTokenEncoder'),
+  WebSocketConnectionTokenDecoder: Symbol.for('WebSocketConnectionTokenDecoder'),
   AuthenticationMethodResolver: Symbol.for('AuthenticationMethodResolver'),
   AuthenticationMethodResolver: Symbol.for('AuthenticationMethodResolver'),
   DomainEventPublisher: Symbol.for('DomainEventPublisher'),
   DomainEventPublisher: Symbol.for('DomainEventPublisher'),
   DomainEventSubscriberFactory: Symbol.for('DomainEventSubscriberFactory'),
   DomainEventSubscriberFactory: Symbol.for('DomainEventSubscriberFactory'),

+ 5 - 102
packages/auth/src/Controller/SessionsController.spec.ts

@@ -9,43 +9,25 @@ import { ProjectorInterface } from '../Projection/ProjectorInterface'
 import { GetActiveSessionsForUser } from '../Domain/UseCase/GetActiveSessionsForUser'
 import { GetActiveSessionsForUser } from '../Domain/UseCase/GetActiveSessionsForUser'
 import { AuthenticateRequest } from '../Domain/UseCase/AuthenticateRequest'
 import { AuthenticateRequest } from '../Domain/UseCase/AuthenticateRequest'
 import { User } from '../Domain/User/User'
 import { User } from '../Domain/User/User'
-import { Role } from '../Domain/Role/Role'
-import { CrossServiceTokenData, TokenEncoderInterface } from '@standardnotes/security'
-import { GetUserAnalyticsId } from '../Domain/UseCase/GetUserAnalyticsId/GetUserAnalyticsId'
+import { CreateCrossServiceToken } from '../Domain/UseCase/CreateCrossServiceToken/CreateCrossServiceToken'
 
 
 describe('SessionsController', () => {
 describe('SessionsController', () => {
   let getActiveSessionsForUser: GetActiveSessionsForUser
   let getActiveSessionsForUser: GetActiveSessionsForUser
   let authenticateRequest: AuthenticateRequest
   let authenticateRequest: AuthenticateRequest
-  let userProjector: ProjectorInterface<User>
-  let tokenEncoder: TokenEncoderInterface<CrossServiceTokenData>
-  const jwtTTL = 60
   let sessionProjector: ProjectorInterface<Session>
   let sessionProjector: ProjectorInterface<Session>
-  let roleProjector: ProjectorInterface<Role>
   let session: Session
   let session: Session
   let request: express.Request
   let request: express.Request
   let response: express.Response
   let response: express.Response
   let user: User
   let user: User
-  let role: Role
-  let getUserAnalyticsId: GetUserAnalyticsId
+  let createCrossServiceToken: CreateCrossServiceToken
 
 
   const createController = () =>
   const createController = () =>
-    new SessionsController(
-      getActiveSessionsForUser,
-      authenticateRequest,
-      userProjector,
-      sessionProjector,
-      roleProjector,
-      tokenEncoder,
-      getUserAnalyticsId,
-      true,
-      jwtTTL,
-    )
+    new SessionsController(getActiveSessionsForUser, authenticateRequest, sessionProjector, createCrossServiceToken)
 
 
   beforeEach(() => {
   beforeEach(() => {
     session = {} as jest.Mocked<Session>
     session = {} as jest.Mocked<Session>
 
 
     user = {} as jest.Mocked<User>
     user = {} as jest.Mocked<User>
-    user.roles = Promise.resolve([role])
 
 
     getActiveSessionsForUser = {} as jest.Mocked<GetActiveSessionsForUser>
     getActiveSessionsForUser = {} as jest.Mocked<GetActiveSessionsForUser>
     getActiveSessionsForUser.execute = jest.fn().mockReturnValue({ sessions: [session] })
     getActiveSessionsForUser.execute = jest.fn().mockReturnValue({ sessions: [session] })
@@ -53,21 +35,11 @@ describe('SessionsController', () => {
     authenticateRequest = {} as jest.Mocked<AuthenticateRequest>
     authenticateRequest = {} as jest.Mocked<AuthenticateRequest>
     authenticateRequest.execute = jest.fn()
     authenticateRequest.execute = jest.fn()
 
 
-    userProjector = {} as jest.Mocked<ProjectorInterface<User>>
-    userProjector.projectSimple = jest.fn().mockReturnValue({ bar: 'baz' })
-
-    roleProjector = {} as jest.Mocked<ProjectorInterface<Role>>
-    roleProjector.projectSimple = jest.fn().mockReturnValue({ name: 'role1', uuid: '1-3-4' })
-
     sessionProjector = {} as jest.Mocked<ProjectorInterface<Session>>
     sessionProjector = {} as jest.Mocked<ProjectorInterface<Session>>
     sessionProjector.projectCustom = jest.fn().mockReturnValue({ foo: 'bar' })
     sessionProjector.projectCustom = jest.fn().mockReturnValue({ foo: 'bar' })
-    sessionProjector.projectSimple = jest.fn().mockReturnValue({ test: 'test' })
-
-    tokenEncoder = {} as jest.Mocked<TokenEncoderInterface<CrossServiceTokenData>>
-    tokenEncoder.encodeExpirableToken = jest.fn().mockReturnValue('foobar')
 
 
-    getUserAnalyticsId = {} as jest.Mocked<GetUserAnalyticsId>
-    getUserAnalyticsId.execute = jest.fn().mockReturnValue({ analyticsId: 123 })
+    createCrossServiceToken = {} as jest.Mocked<CreateCrossServiceToken>
+    createCrossServiceToken.execute = jest.fn().mockReturnValue({ token: 'foobar' })
 
 
     request = {
     request = {
       params: {},
       params: {},
@@ -114,75 +86,6 @@ describe('SessionsController', () => {
     const httpResponseContent = await result.content.readAsStringAsync()
     const httpResponseContent = await result.content.readAsStringAsync()
     const httpResponseJSON = JSON.parse(httpResponseContent)
     const httpResponseJSON = JSON.parse(httpResponseContent)
 
 
-    expect(tokenEncoder.encodeExpirableToken).toHaveBeenCalledWith(
-      {
-        analyticsId: 123,
-        roles: [
-          {
-            name: 'role1',
-            uuid: '1-3-4',
-          },
-        ],
-        session: {
-          test: 'test',
-        },
-        user: {
-          bar: 'baz',
-        },
-      },
-      60,
-    )
-
-    expect(httpResponseJSON.authToken).toEqual('foobar')
-  })
-
-  it('should validate a session from an incoming request - disabled analytics', async () => {
-    authenticateRequest.execute = jest.fn().mockReturnValue({
-      success: true,
-      user,
-      session,
-    })
-
-    request.headers.authorization = 'test'
-
-    const controller = new SessionsController(
-      getActiveSessionsForUser,
-      authenticateRequest,
-      userProjector,
-      sessionProjector,
-      roleProjector,
-      tokenEncoder,
-      getUserAnalyticsId,
-      false,
-      jwtTTL,
-    )
-
-    const httpResponse = await controller.validate(request)
-
-    expect(httpResponse).toBeInstanceOf(results.JsonResult)
-
-    const result = await httpResponse.executeAsync()
-    const httpResponseContent = await result.content.readAsStringAsync()
-    const httpResponseJSON = JSON.parse(httpResponseContent)
-
-    expect(tokenEncoder.encodeExpirableToken).toHaveBeenCalledWith(
-      {
-        roles: [
-          {
-            name: 'role1',
-            uuid: '1-3-4',
-          },
-        ],
-        session: {
-          test: 'test',
-        },
-        user: {
-          bar: 'baz',
-        },
-      },
-      60,
-    )
-
     expect(httpResponseJSON.authToken).toEqual('foobar')
     expect(httpResponseJSON.authToken).toEqual('foobar')
   })
   })
 
 

+ 7 - 60
packages/auth/src/Controller/SessionsController.ts

@@ -12,26 +12,18 @@ import TYPES from '../Bootstrap/Types'
 import { Session } from '../Domain/Session/Session'
 import { Session } from '../Domain/Session/Session'
 import { AuthenticateRequest } from '../Domain/UseCase/AuthenticateRequest'
 import { AuthenticateRequest } from '../Domain/UseCase/AuthenticateRequest'
 import { GetActiveSessionsForUser } from '../Domain/UseCase/GetActiveSessionsForUser'
 import { GetActiveSessionsForUser } from '../Domain/UseCase/GetActiveSessionsForUser'
-import { Role } from '../Domain/Role/Role'
 import { User } from '../Domain/User/User'
 import { User } from '../Domain/User/User'
 import { ProjectorInterface } from '../Projection/ProjectorInterface'
 import { ProjectorInterface } from '../Projection/ProjectorInterface'
 import { SessionProjector } from '../Projection/SessionProjector'
 import { SessionProjector } from '../Projection/SessionProjector'
-import { CrossServiceTokenData, TokenEncoderInterface } from '@standardnotes/security'
-import { RoleName } from '@standardnotes/common'
-import { GetUserAnalyticsId } from '../Domain/UseCase/GetUserAnalyticsId/GetUserAnalyticsId'
+import { CreateCrossServiceToken } from '../Domain/UseCase/CreateCrossServiceToken/CreateCrossServiceToken'
 
 
 @controller('/sessions')
 @controller('/sessions')
 export class SessionsController extends BaseHttpController {
 export class SessionsController extends BaseHttpController {
   constructor(
   constructor(
     @inject(TYPES.GetActiveSessionsForUser) private getActiveSessionsForUser: GetActiveSessionsForUser,
     @inject(TYPES.GetActiveSessionsForUser) private getActiveSessionsForUser: GetActiveSessionsForUser,
     @inject(TYPES.AuthenticateRequest) private authenticateRequest: AuthenticateRequest,
     @inject(TYPES.AuthenticateRequest) private authenticateRequest: AuthenticateRequest,
-    @inject(TYPES.UserProjector) private userProjector: ProjectorInterface<User>,
     @inject(TYPES.SessionProjector) private sessionProjector: ProjectorInterface<Session>,
     @inject(TYPES.SessionProjector) private sessionProjector: ProjectorInterface<Session>,
-    @inject(TYPES.RoleProjector) private roleProjector: ProjectorInterface<Role>,
-    @inject(TYPES.CrossServiceTokenEncoder) private tokenEncoder: TokenEncoderInterface<CrossServiceTokenData>,
-    @inject(TYPES.GetUserAnalyticsId) private getUserAnalyticsId: GetUserAnalyticsId,
-    @inject(TYPES.ANALYTICS_ENABLED) private analyticsEnabled: boolean,
-    @inject(TYPES.AUTH_JWT_TTL) private jwtTTL: number,
+    @inject(TYPES.CreateCrossServiceToken) private createCrossServiceToken: CreateCrossServiceToken,
   ) {
   ) {
     super()
     super()
   }
   }
@@ -56,25 +48,12 @@ export class SessionsController extends BaseHttpController {
 
 
     const user = authenticateRequestResponse.user as User
     const user = authenticateRequestResponse.user as User
 
 
-    const roles = await user.roles
-
-    const authTokenData: CrossServiceTokenData = {
-      user: this.projectUser(user),
-      roles: this.projectRoles(roles),
-    }
-
-    if (this.analyticsEnabled) {
-      const { analyticsId } = await this.getUserAnalyticsId.execute({ userUuid: user.uuid })
-      authTokenData.analyticsId = analyticsId
-    }
-
-    if (authenticateRequestResponse.session !== undefined) {
-      authTokenData.session = this.projectSession(authenticateRequestResponse.session)
-    }
-
-    const authToken = this.tokenEncoder.encodeExpirableToken(authTokenData, this.jwtTTL)
+    const result = await this.createCrossServiceToken.execute({
+      user,
+      session: authenticateRequestResponse.session,
+    })
 
 
-    return this.json({ authToken })
+    return this.json({ authToken: result.token })
   }
   }
 
 
   @httpGet('/', TYPES.AuthMiddleware, TYPES.SessionMiddleware)
   @httpGet('/', TYPES.AuthMiddleware, TYPES.SessionMiddleware)
@@ -93,36 +72,4 @@ export class SessionsController extends BaseHttpController {
       ),
       ),
     )
     )
   }
   }
-
-  private projectUser(user: User): { uuid: string; email: string } {
-    return <{ uuid: string; email: string }>this.userProjector.projectSimple(user)
-  }
-
-  private projectSession(session: Session): {
-    uuid: string
-    api_version: string
-    created_at: string
-    updated_at: string
-    device_info: string
-    readonly_access: boolean
-    access_expiration: string
-    refresh_expiration: string
-  } {
-    return <
-      {
-        uuid: string
-        api_version: string
-        created_at: string
-        updated_at: string
-        device_info: string
-        readonly_access: boolean
-        access_expiration: string
-        refresh_expiration: string
-      }
-    >this.sessionProjector.projectSimple(session)
-  }
-
-  private projectRoles(roles: Array<Role>): Array<{ uuid: string; name: RoleName }> {
-    return roles.map((role) => <{ uuid: string; name: RoleName }>this.roleProjector.projectSimple(role))
-  }
 }
 }

+ 173 - 0
packages/auth/src/Domain/UseCase/CreateCrossServiceToken/CreateCrossServiceToken.spec.ts

@@ -0,0 +1,173 @@
+import 'reflect-metadata'
+
+import { TokenEncoderInterface, CrossServiceTokenData } from '@standardnotes/security'
+import { ProjectorInterface } from '../../../Projection/ProjectorInterface'
+import { Session } from '../../Session/Session'
+import { User } from '../../User/User'
+import { Role } from '../../Role/Role'
+import { UserRepositoryInterface } from '../../User/UserRepositoryInterface'
+import { GetUserAnalyticsId } from '../GetUserAnalyticsId/GetUserAnalyticsId'
+
+import { CreateCrossServiceToken } from './CreateCrossServiceToken'
+
+describe('CreateCrossServiceToken', () => {
+  let userProjector: ProjectorInterface<User>
+  let sessionProjector: ProjectorInterface<Session>
+  let roleProjector: ProjectorInterface<Role>
+  let tokenEncoder: TokenEncoderInterface<CrossServiceTokenData>
+  let getUserAnalyticsId: GetUserAnalyticsId
+  let userRepository: UserRepositoryInterface
+  const jwtTTL = 60
+
+  let session: Session
+  let user: User
+  let role: Role
+
+  const createUseCase = (analyticsEnabled = true) =>
+    new CreateCrossServiceToken(
+      userProjector,
+      sessionProjector,
+      roleProjector,
+      tokenEncoder,
+      getUserAnalyticsId,
+      userRepository,
+      analyticsEnabled,
+      jwtTTL,
+    )
+
+  beforeEach(() => {
+    session = {} as jest.Mocked<Session>
+
+    user = {} as jest.Mocked<User>
+    user.roles = Promise.resolve([role])
+
+    userProjector = {} as jest.Mocked<ProjectorInterface<User>>
+    userProjector.projectSimple = jest.fn().mockReturnValue({ bar: 'baz' })
+
+    roleProjector = {} as jest.Mocked<ProjectorInterface<Role>>
+    roleProjector.projectSimple = jest.fn().mockReturnValue({ name: 'role1', uuid: '1-3-4' })
+
+    sessionProjector = {} as jest.Mocked<ProjectorInterface<Session>>
+    sessionProjector.projectCustom = jest.fn().mockReturnValue({ foo: 'bar' })
+    sessionProjector.projectSimple = jest.fn().mockReturnValue({ test: 'test' })
+
+    tokenEncoder = {} as jest.Mocked<TokenEncoderInterface<CrossServiceTokenData>>
+    tokenEncoder.encodeExpirableToken = jest.fn().mockReturnValue('foobar')
+
+    getUserAnalyticsId = {} as jest.Mocked<GetUserAnalyticsId>
+    getUserAnalyticsId.execute = jest.fn().mockReturnValue({ analyticsId: 123 })
+
+    userRepository = {} as jest.Mocked<UserRepositoryInterface>
+    userRepository.findOneByUuid = jest.fn().mockReturnValue(user)
+  })
+
+  it('should create a cross service token for user', async () => {
+    await createUseCase().execute({
+      user,
+      session,
+    })
+
+    expect(tokenEncoder.encodeExpirableToken).toHaveBeenCalledWith(
+      {
+        analyticsId: 123,
+        roles: [
+          {
+            name: 'role1',
+            uuid: '1-3-4',
+          },
+        ],
+        session: {
+          test: 'test',
+        },
+        user: {
+          bar: 'baz',
+        },
+      },
+      60,
+    )
+  })
+
+  it('should create a cross service token for user - analytics disabled', async () => {
+    await createUseCase(false).execute({
+      user,
+      session,
+    })
+
+    expect(tokenEncoder.encodeExpirableToken).toHaveBeenCalledWith(
+      {
+        roles: [
+          {
+            name: 'role1',
+            uuid: '1-3-4',
+          },
+        ],
+        session: {
+          test: 'test',
+        },
+        user: {
+          bar: 'baz',
+        },
+      },
+      60,
+    )
+  })
+
+  it('should create a cross service token for user without a session', async () => {
+    await createUseCase().execute({
+      user,
+    })
+
+    expect(tokenEncoder.encodeExpirableToken).toHaveBeenCalledWith(
+      {
+        analyticsId: 123,
+        roles: [
+          {
+            name: 'role1',
+            uuid: '1-3-4',
+          },
+        ],
+        user: {
+          bar: 'baz',
+        },
+      },
+      60,
+    )
+  })
+
+  it('should create a cross service token for user by user uuid', async () => {
+    await createUseCase().execute({
+      userUuid: '1-2-3',
+    })
+
+    expect(tokenEncoder.encodeExpirableToken).toHaveBeenCalledWith(
+      {
+        analyticsId: 123,
+        roles: [
+          {
+            name: 'role1',
+            uuid: '1-3-4',
+          },
+        ],
+        user: {
+          bar: 'baz',
+        },
+      },
+      60,
+    )
+  })
+
+  it('should throw an error if user does not exist', async () => {
+    userRepository.findOneByUuid = jest.fn().mockReturnValue(null)
+
+    let caughtError = null
+    try {
+      await createUseCase().execute({
+        userUuid: '1-2-3',
+      })
+    } catch (error) {
+      caughtError = error
+    }
+
+    expect(caughtError).not.toBeNull()
+  })
+})

+ 91 - 0
packages/auth/src/Domain/UseCase/CreateCrossServiceToken/CreateCrossServiceToken.ts

@@ -0,0 +1,91 @@
+import { RoleName } from '@standardnotes/common'
+import { TokenEncoderInterface, CrossServiceTokenData } from '@standardnotes/security'
+
+import { inject, injectable } from 'inversify'
+import TYPES from '../../../Bootstrap/Types'
+import { ProjectorInterface } from '../../../Projection/ProjectorInterface'
+import { Role } from '../../Role/Role'
+import { Session } from '../../Session/Session'
+import { User } from '../../User/User'
+import { UserRepositoryInterface } from '../../User/UserRepositoryInterface'
+import { GetUserAnalyticsId } from '../GetUserAnalyticsId/GetUserAnalyticsId'
+import { UseCaseInterface } from '../UseCaseInterface'
+import { CreateCrossServiceTokenDTO } from './CreateCrossServiceTokenDTO'
+import { CreateCrossServiceTokenResponse } from './CreateCrossServiceTokenResponse'
+
+@injectable()
+export class CreateCrossServiceToken implements UseCaseInterface {
+  constructor(
+    @inject(TYPES.UserProjector) private userProjector: ProjectorInterface<User>,
+    @inject(TYPES.SessionProjector) private sessionProjector: ProjectorInterface<Session>,
+    @inject(TYPES.RoleProjector) private roleProjector: ProjectorInterface<Role>,
+    @inject(TYPES.CrossServiceTokenEncoder) private tokenEncoder: TokenEncoderInterface<CrossServiceTokenData>,
+    @inject(TYPES.GetUserAnalyticsId) private getUserAnalyticsId: GetUserAnalyticsId,
+    @inject(TYPES.UserRepository) private userRepository: UserRepositoryInterface,
+    @inject(TYPES.ANALYTICS_ENABLED) private analyticsEnabled: boolean,
+    @inject(TYPES.AUTH_JWT_TTL) private jwtTTL: number,
+  ) {}
+
+  async execute(dto: CreateCrossServiceTokenDTO): Promise<CreateCrossServiceTokenResponse> {
+    let user: User | undefined | null = dto.user
+    if (user === undefined && dto.userUuid !== undefined) {
+      user = await this.userRepository.findOneByUuid(dto.userUuid)
+    }
+
+    if (!user) {
+      throw new Error(`Could not find user with uuid ${dto.userUuid}`)
+    }
+
+    const roles = await user.roles
+
+    const authTokenData: CrossServiceTokenData = {
+      user: this.projectUser(user),
+      roles: this.projectRoles(roles),
+    }
+
+    if (this.analyticsEnabled) {
+      const { analyticsId } = await this.getUserAnalyticsId.execute({ userUuid: user.uuid })
+      authTokenData.analyticsId = analyticsId
+    }
+
+    if (dto.session !== undefined) {
+      authTokenData.session = this.projectSession(dto.session)
+    }
+
+    return {
+      token: this.tokenEncoder.encodeExpirableToken(authTokenData, this.jwtTTL),
+    }
+  }
+
+  private projectUser(user: User): { uuid: string; email: string } {
+    return <{ uuid: string; email: string }>this.userProjector.projectSimple(user)
+  }
+
+  private projectSession(session: Session): {
+    uuid: string
+    api_version: string
+    created_at: string
+    updated_at: string
+    device_info: string
+    readonly_access: boolean
+    access_expiration: string
+    refresh_expiration: string
+  } {
+    return <
+      {
+        uuid: string
+        api_version: string
+        created_at: string
+        updated_at: string
+        device_info: string
+        readonly_access: boolean
+        access_expiration: string
+        refresh_expiration: string
+      }
+    >this.sessionProjector.projectSimple(session)
+  }
+
+  private projectRoles(roles: Array<Role>): Array<{ uuid: string; name: RoleName }> {
+    return roles.map((role) => <{ uuid: string; name: RoleName }>this.roleProjector.projectSimple(role))
+  }
+}

+ 13 - 0
packages/auth/src/Domain/UseCase/CreateCrossServiceToken/CreateCrossServiceTokenDTO.ts

@@ -0,0 +1,13 @@
+import { Either, Uuid } from '@standardnotes/common'
+import { Session } from '../../Session/Session'
+import { User } from '../../User/User'
+
+export type CreateCrossServiceTokenDTO = Either<
+  {
+    user: User
+    session?: Session
+  },
+  {
+    userUuid: Uuid
+  }
+>

+ 3 - 0
packages/auth/src/Domain/UseCase/CreateCrossServiceToken/CreateCrossServiceTokenResponse.ts

@@ -0,0 +1,3 @@
+export type CreateCrossServiceTokenResponse = {
+  token: string
+}

+ 41 - 0
packages/auth/src/Infra/InversifyExpressUtils/InversifyExpressWebSocketsController.ts

@@ -1,4 +1,6 @@
 import { WebSocketServerInterface } from '@standardnotes/api'
 import { WebSocketServerInterface } from '@standardnotes/api'
+import { ErrorTag } from '@standardnotes/common'
+import { TokenDecoderInterface, WebSocketConnectionTokenData } from '@standardnotes/security'
 import { Request, Response } from 'express'
 import { Request, Response } from 'express'
 import { inject } from 'inversify'
 import { inject } from 'inversify'
 import {
 import {
@@ -11,6 +13,7 @@ import {
 } from 'inversify-express-utils'
 } from 'inversify-express-utils'
 import TYPES from '../../Bootstrap/Types'
 import TYPES from '../../Bootstrap/Types'
 import { AddWebSocketsConnection } from '../../Domain/UseCase/AddWebSocketsConnection/AddWebSocketsConnection'
 import { AddWebSocketsConnection } from '../../Domain/UseCase/AddWebSocketsConnection/AddWebSocketsConnection'
+import { CreateCrossServiceToken } from '../../Domain/UseCase/CreateCrossServiceToken/CreateCrossServiceToken'
 import { RemoveWebSocketsConnection } from '../../Domain/UseCase/RemoveWebSocketsConnection/RemoveWebSocketsConnection'
 import { RemoveWebSocketsConnection } from '../../Domain/UseCase/RemoveWebSocketsConnection/RemoveWebSocketsConnection'
 
 
 @controller('/sockets')
 @controller('/sockets')
@@ -18,7 +21,10 @@ export class InversifyExpressWebSocketsController extends BaseHttpController {
   constructor(
   constructor(
     @inject(TYPES.AddWebSocketsConnection) private addWebSocketsConnection: AddWebSocketsConnection,
     @inject(TYPES.AddWebSocketsConnection) private addWebSocketsConnection: AddWebSocketsConnection,
     @inject(TYPES.RemoveWebSocketsConnection) private removeWebSocketsConnection: RemoveWebSocketsConnection,
     @inject(TYPES.RemoveWebSocketsConnection) private removeWebSocketsConnection: RemoveWebSocketsConnection,
+    @inject(TYPES.CreateCrossServiceToken) private createCrossServiceToken: CreateCrossServiceToken,
     @inject(TYPES.WebSocketsController) private webSocketsController: WebSocketServerInterface,
     @inject(TYPES.WebSocketsController) private webSocketsController: WebSocketServerInterface,
+    @inject(TYPES.WebSocketConnectionTokenDecoder)
+    private tokenDecoder: TokenDecoderInterface<WebSocketConnectionTokenData>,
   ) {
   ) {
     super()
     super()
   }
   }
@@ -53,4 +59,39 @@ export class InversifyExpressWebSocketsController extends BaseHttpController {
 
 
     return this.json(result)
     return this.json(result)
   }
   }
+
+  @httpPost('/tokens/validate')
+  async validateToken(request: Request): Promise<results.JsonResult> {
+    if (!request.headers.authorization) {
+      return this.json(
+        {
+          error: {
+            tag: ErrorTag.AuthInvalid,
+            message: 'Invalid authorization token.',
+          },
+        },
+        401,
+      )
+    }
+
+    const token: WebSocketConnectionTokenData | undefined = this.tokenDecoder.decodeToken(request.headers.authorization)
+
+    if (token === undefined) {
+      return this.json(
+        {
+          error: {
+            tag: ErrorTag.AuthInvalid,
+            message: 'Invalid authorization token.',
+          },
+        },
+        401,
+      )
+    }
+
+    const result = await this.createCrossServiceToken.execute({
+      userUuid: token.userUuid,
+    })
+
+    return this.json({ authToken: result.token })
+  }
 }
 }