Browse Source

Basic websocket and auth for newt

Owen Schwartz 8 months ago
parent
commit
e5e78ff1bf

+ 80 - 0
server/auth/newt.ts

@@ -0,0 +1,80 @@
+export * from "./verifySession";
+export * from "./unauthorizedResponse";
+
+import {
+    encodeHexLowerCase,
+} from "@oslojs/encoding";
+import { sha256 } from "@oslojs/crypto/sha2";
+import { Newt, newts, newtSessions, NewtSession } from "@server/db/schema";
+import db from "@server/db";
+import { eq } from "drizzle-orm";
+import config from "@server/config";
+
+export const SESSION_COOKIE_NAME = "session";
+export const SESSION_COOKIE_EXPIRES = 1000 * 60 * 60 * 24 * 30;
+export const SECURE_COOKIES = config.server.secure_cookies;
+export const COOKIE_DOMAIN =
+    "." + new URL(config.app.base_url).hostname.split(".").slice(-2).join(".");
+
+export async function createNewtSession(
+    token: string,
+    newtId: string,
+): Promise<NewtSession> {
+    const sessionId = encodeHexLowerCase(
+        sha256(new TextEncoder().encode(token)),
+    );
+    const session: NewtSession = {
+        sessionId: sessionId,
+        newtId,
+        expiresAt: new Date(Date.now() + SESSION_COOKIE_EXPIRES).getTime(),
+    };
+    await db.insert(newtSessions).values(session);
+    return session;
+}
+
+export async function validateNewtSessionToken(
+    token: string,
+): Promise<SessionValidationResult> {
+    const sessionId = encodeHexLowerCase(
+        sha256(new TextEncoder().encode(token)),
+    );
+    const result = await db
+        .select({ newt: newts, session: newtSessions })
+        .from(newtSessions)
+        .innerJoin(newts, eq(newtSessions.newtId, newts.newtId))
+        .where(eq(newtSessions.sessionId, sessionId));
+    if (result.length < 1) {
+        return { session: null, newt: null };
+    }
+    const { newt, session } = result[0];
+    if (Date.now() >= session.expiresAt) {
+        await db
+            .delete(newtSessions)
+            .where(eq(newtSessions.sessionId, session.sessionId));
+        return { session: null, newt: null };
+    }
+    if (Date.now() >= session.expiresAt - (SESSION_COOKIE_EXPIRES / 2)) {
+        session.expiresAt = new Date(
+            Date.now() + SESSION_COOKIE_EXPIRES,
+        ).getTime();
+        await db
+            .update(newtSessions)
+            .set({
+                expiresAt: session.expiresAt,
+            })
+            .where(eq(newtSessions.sessionId, session.sessionId));
+    }
+    return { session, newt };
+}
+
+export async function invalidateNewtSession(sessionId: string): Promise<void> {
+    await db.delete(newtSessions).where(eq(newtSessions.sessionId, sessionId));
+}
+
+export async function invalidateAllNewtSessions(newtId: string): Promise<void> {
+    await db.delete(newtSessions).where(eq(newtSessions.newtId, newtId));
+}
+
+export type SessionValidationResult =
+    | { session: NewtSession; newt: Newt }
+    | { session: null; newt: null };

+ 16 - 0
server/db/schema.ts

@@ -73,6 +73,12 @@ export const users = sqliteTable("user", {
     dateCreated: text("dateCreated").notNull(),
 });
 
+export const newts = sqliteTable("newt", {
+    newtId: text("id").primaryKey(),
+    secretHash: text("secretHash").notNull(),
+    dateCreated: text("dateCreated").notNull(),
+});
+
 export const twoFactorBackupCodes = sqliteTable("twoFactorBackupCodes", {
     codeId: integer("id").primaryKey({ autoIncrement: true }),
     userId: text("userId")
@@ -89,6 +95,14 @@ export const sessions = sqliteTable("session", {
     expiresAt: integer("expiresAt").notNull(),
 });
 
+export const newtSessions = sqliteTable("newtSession", {
+    sessionId: text("id").primaryKey(),
+    newtId: text("newtId")
+        .notNull()
+        .references(() => newts.newtId, { onDelete: "cascade" }),
+    expiresAt: integer("expiresAt").notNull(),
+});
+
 export const userOrgs = sqliteTable("userOrgs", {
     userId: text("userId")
         .notNull()
@@ -227,6 +241,8 @@ export type Resource = InferSelectModel<typeof resources>;
 export type ExitNode = InferSelectModel<typeof exitNodes>;
 export type Target = InferSelectModel<typeof targets>;
 export type Session = InferSelectModel<typeof sessions>;
+export type Newt = InferSelectModel<typeof newts>;
+export type NewtSession = InferSelectModel<typeof newtSessions>;
 export type EmailVerificationCode = InferSelectModel<
     typeof emailVerificationCodes
 >;

+ 1 - 0
server/routers/auth/index.ts

@@ -12,6 +12,7 @@ export * from "./verifyTargetAccess";
 export * from "./verifyRoleAccess";
 export * from "./verifyUserAccess";
 export * from "./verifyAdmin";
+// export * from "./verifySuperUser";
 export * from "./verifyEmail";
 export * from "./requestEmailVerificationCode";
 export * from "./changePassword";

+ 115 - 0
server/routers/auth/newtGetToken.ts

@@ -0,0 +1,115 @@
+import { verify } from "@node-rs/argon2";
+import {
+    createSession,
+    generateSessionToken,
+    verifySession,
+} from "@server/auth";
+import db from "@server/db";
+import { newts } from "@server/db/schema";
+import HttpCode from "@server/types/HttpCode";
+import response from "@server/utils/response";
+import { eq } from "drizzle-orm";
+import { NextFunction, Request, Response } from "express";
+import createHttpError from "http-errors";
+import { z } from "zod";
+import { fromError } from "zod-validation-error";
+import config from "@server/config";
+import { validateNewtSessionToken } from "@server/auth/newt";
+
+export const newtGetTokenBodySchema = z.object({
+    newtId: z.string().email(),
+    secret: z.string(),
+    token: z.string().optional(),
+});
+
+export type NewtGetTokenBody = z.infer<typeof newtGetTokenBodySchema>;
+
+export async function newtGetToken(
+    req: Request,
+    res: Response,
+    next: NextFunction
+): Promise<any> {
+    const parsedBody = newtGetTokenBodySchema.safeParse(req.body);
+
+    if (!parsedBody.success) {
+        return next(
+            createHttpError(
+                HttpCode.BAD_REQUEST,
+                fromError(parsedBody.error).toString()
+            )
+        );
+    }
+
+    const { newtId, secret, token } = parsedBody.data;
+
+    try {
+        if (token) {
+            const { session, newt } = await validateNewtSessionToken(
+                token
+            );
+            if (session) {
+                return response<null>(res, {
+                    data: null,
+                    success: true,
+                    error: false,
+                    message: "Token session already valid",
+                    status: HttpCode.OK,
+                });
+            }
+        }
+
+        const existingNewtRes = await db
+            .select()
+            .from(newts)
+            .where(eq(newts.newtId, newtId));
+        if (!existingNewtRes || !existingNewtRes.length) {
+            return next(
+                createHttpError(
+                    HttpCode.BAD_REQUEST,
+                    "No newt found with that newtId"
+                )
+            );
+        }
+
+        const existingNewt = existingNewtRes[0];
+
+        const validSecret = await verify(
+            existingNewt.secretHash,
+            secret,
+            {
+                memoryCost: 19456,
+                timeCost: 2,
+                outputLen: 32,
+                parallelism: 1,
+            }
+        );
+        if (!validSecret) {
+            return next(
+                createHttpError(
+                    HttpCode.BAD_REQUEST,
+                    "Secret is incorrect"
+                )
+            );
+        }
+
+        const resToken = generateSessionToken();
+        await createSession(resToken, existingNewt.newtId);
+
+        return response<{ token: string }>(res, {
+            data: {
+                token: resToken
+            },
+            success: true,
+            error: false,
+            message: "Token created successfully",
+            status: HttpCode.OK,
+        });
+    } catch (e) {
+        return next(
+            createHttpError(
+                HttpCode.INTERNAL_SERVER_ERROR,
+                "Failed to authenticate newt"
+            )
+        );
+    }
+}

+ 1 - 1
server/routers/auth/sendEmailVerificationCode.ts

@@ -4,8 +4,8 @@ import db from "@server/db";
 import { users, emailVerificationCodes } from "@server/db/schema";
 import { eq } from "drizzle-orm";
 import { sendEmail } from "@server/emails";
-import VerifyEmail from "@server/emails/templates/VerifyEmailCode";
 import config from "@server/config";
+import VerifyEmail from "@server/emails/templates/verifyEmailCode";
 
 export async function sendEmailVerificationCode(
     email: string,

+ 0 - 1
server/routers/external.ts

@@ -19,7 +19,6 @@ import {
     verifyResourceAccess,
     verifyTargetAccess,
     verifyRoleAccess,
-    verifyUserInRole,
     verifyUserAccess,
 } from "./auth";
 import { verifyUserHasAction } from "./auth/verifyUserHasAction";

+ 125 - 103
server/routers/ws.ts

@@ -3,20 +3,25 @@ import { Server as HttpServer } from 'http';
 import { WebSocket, WebSocketServer } from 'ws';
 import { IncomingMessage } from 'http';
 import { Socket } from 'net';
+import { Newt, newts, NewtSession } from '@server/db/schema';
+import { eq } from 'drizzle-orm';
+import db from '@server/db';
+import { newtGetToken } from './auth';
+import { validateNewtSessionToken } from '@server/auth/newt';
 
 // Custom interfaces
 interface WebSocketRequest extends IncomingMessage {
-  token?: string;
+    token?: string;
 }
 
 interface AuthenticatedWebSocket extends WebSocket {
-  userId?: string;
-  isAlive?: boolean;
+    newt?: Newt;
+    isAlive?: boolean;
 }
 
 interface TokenPayload {
-  userId: string;
-  // Add other token payload properties as needed
+    newt: Newt;
+    session: NewtSession;
 }
 
 const router: Router = Router();
@@ -24,121 +29,138 @@ const wss: WebSocketServer = new WebSocketServer({ noServer: true });
 
 // Token verification middleware
 const verifyToken = async (token: string): Promise<TokenPayload | null> => {
-  try {
-    // This is where you'd implement your token verification logic
-    // For example, verify JWT, check against database, etc.
-    // Return the token payload if valid, null if invalid
-    return { userId: 'dummy-user-id' }; // Placeholder return
-  } catch (error) {
-    console.error('Token verification failed:', error);
-    return null;
-  }
+    try {
+
+        const { session, newt } = await validateNewtSessionToken(
+            token
+        );
+
+        if (!session || !newt) {
+            return null;
+        }
+
+        const existingNewt = await db
+            .select()
+            .from(newts)
+            .where(eq(newts.newtId, newt.newtId));
+
+        if (!existingNewt || !existingNewt[0]) {
+            return null;
+        }
+
+        return { newt: existingNewt[0], session };
+    } catch (error) {
+        console.error('Token verification failed:', error);
+        return null;
+    }
 };
 
 // Handle WebSocket upgrade requests
 router.get('/ws', (req: Request, res: Response) => {
-  // WebSocket upgrade will be handled by the server
-  res.status(200).send('WebSocket endpoint');
+    // WebSocket upgrade will be handled by the server
+    res.status(200).send('WebSocket endpoint');
 });
 
+router.get('/ws/auth/newtGetToken', newtGetToken);
+
 // Set up WebSocket server handling
 const handleWSUpgrade = (server: HttpServer): void => {
-  server.on('upgrade', async (request: WebSocketRequest, socket: Socket, head: Buffer) => {
-    try {
-      // Extract token from query parameters or headers
-      const token = request.url?.includes('?')
-        ? new URLSearchParams(request.url.split('?')[1]).get('token') || ''
-        : request.headers['sec-websocket-protocol'];
-
-      if (!token) {
-        socket.write('HTTP/1.1 401 Unauthorized\r\n\r\n');
-        socket.destroy();
-        return;
-      }
-
-      // Verify the token
-      const tokenPayload = await verifyToken(token);
-      if (!tokenPayload) {
-        socket.write('HTTP/1.1 401 Unauthorized\r\n\r\n');
-        socket.destroy();
-        return;
-      }
-
-      // Store token payload data in the request for later use
-      request.token = token;
-
-      wss.handleUpgrade(request, socket, head, (ws: AuthenticatedWebSocket) => {
-        // Attach user data to the WebSocket instance
-        ws.userId = tokenPayload.userId;
-        ws.isAlive = true;
-        wss.emit('connection', ws, request);
-      });
-    } catch (error) {
-      console.error('Upgrade error:', error);
-      socket.write('HTTP/1.1 500 Internal Server Error\r\n\r\n');
-      socket.destroy();
-    }
-  });
+    server.on('upgrade', async (request: WebSocketRequest, socket: Socket, head: Buffer) => {
+        try {
+            // Extract token from query parameters or headers
+            const token = request.url?.includes('?')
+                ? new URLSearchParams(request.url.split('?')[1]).get('token') || ''
+                : request.headers['sec-websocket-protocol'];
+
+            if (!token) {
+                socket.write('HTTP/1.1 401 Unauthorized\r\n\r\n');
+                socket.destroy();
+                return;
+            }
+
+            // Verify the token
+            const tokenPayload = await verifyToken(token);
+            if (!tokenPayload) {
+                socket.write('HTTP/1.1 401 Unauthorized\r\n\r\n');
+                socket.destroy();
+                return;
+            }
+
+            // Store token payload data in the request for later use
+            request.token = token;
+
+            wss.handleUpgrade(request, socket, head, (ws: AuthenticatedWebSocket) => {
+                // Attach newt data to the WebSocket instance
+                ws.newt = tokenPayload.newt;
+                ws.isAlive = true;
+                wss.emit('connection', ws, request);
+            });
+        } catch (error) {
+            console.error('Upgrade error:', error);
+            socket.write('HTTP/1.1 500 Internal Server Error\r\n\r\n');
+            socket.destroy();
+        }
+    });
 };
 
 // WebSocket message interface
 interface WSMessage {
-  type: string;
-  data: any;
+    type: string;
+    data: any;
 }
 
 // WebSocket connection handler
 wss.on('connection', (ws: AuthenticatedWebSocket, request: WebSocketRequest) => {
-  console.log(`Client connected - User ID: ${ws.userId}`);
-
-  // Set up ping-pong for connection health check
-  const pingInterval = setInterval(() => {
-    if (ws.isAlive === false) {
-      clearInterval(pingInterval);
-      return ws.terminate();
-    }
-    ws.isAlive = false;
-    ws.ping();
-  }, 30000);
-
-  // Handle pong response
-  ws.on('pong', () => {
-    ws.isAlive = true;
-  });
-
-  // Set up message handler
-  ws.on('message', (data) => {
-    try {
-      const message: WSMessage = JSON.parse(data.toString());
-      console.log('Received:', message);
-
-      // Echo the message back
-      ws.send(JSON.stringify({
-        type: 'echo',
-        data: message
-      }));
-    } catch (error) {
-      console.error('Message parsing error:', error);
-      ws.send(JSON.stringify({
-        type: 'error',
-        data: 'Invalid message format'
-      }));
-    }
-  });
-
-  // Handle client disconnect
-  ws.on('close', () => {
-    clearInterval(pingInterval);
-    console.log(`Client disconnected - User ID: ${ws.userId}`);
-  });
-
-  // Handle errors
-  ws.on('error', (error: Error) => {
-    console.error('WebSocket error:', error);
-  });
+    console.log(`Client connected - Newt ID: ${ws.newt?.newtId}`);
+
+    // Set up ping-pong for connection health check
+    const pingInterval = setInterval(() => {
+        if (ws.isAlive === false) {
+            clearInterval(pingInterval);
+            return ws.terminate();
+        }
+        ws.isAlive = false;
+        ws.ping();
+    }, 30000);
+
+    // Handle pong response
+    ws.on('pong', () => {
+        ws.isAlive = true;
+    });
+
+    // Set up message handler
+    ws.on('message', (data) => {
+        try {
+            const message: WSMessage = JSON.parse(data.toString());
+            console.log('Received:', message);
+
+            // Echo the message back
+            ws.send(JSON.stringify({
+                type: 'echo',
+                data: message
+            }));
+        } catch (error) {
+            console.error('Message parsing error:', error);
+            ws.send(JSON.stringify({
+                type: 'error',
+                data: 'Invalid message format'
+            }));
+        }
+    });
+
+    // Handle client disconnect
+    ws.on('close', () => {
+        clearInterval(pingInterval);
+        console.log(`Client disconnected - Newt ID: ${ws.newt?.newtId}`);
+    });
+
+    // Handle errors
+    ws.on('error', (error: Error) => {
+        console.error('WebSocket error:', error);
+    });
 });
 
 export {
-  router,
-  handleWSUpgrade
+    router,
+    handleWSUpgrade
 };