Selaa lähdekoodia

Move things around; rename to olm

Owen 4 kuukautta sitten
vanhempi
commit
e112fcba29

+ 72 - 0
server/auth/sessions/olm.ts

@@ -0,0 +1,72 @@
+import {
+    encodeHexLowerCase,
+} from "@oslojs/encoding";
+import { sha256 } from "@oslojs/crypto/sha2";
+import { Olm, olms, olmSessions, OlmSession } from "@server/db/schema";
+import db from "@server/db";
+import { eq } from "drizzle-orm";
+
+export const EXPIRES = 1000 * 60 * 60 * 24 * 30;
+
+export async function createOlmSession(
+    token: string,
+    olmId: string,
+): Promise<OlmSession> {
+    const sessionId = encodeHexLowerCase(
+        sha256(new TextEncoder().encode(token)),
+    );
+    const session: OlmSession = {
+        sessionId: sessionId,
+        olmId,
+        expiresAt: new Date(Date.now() + EXPIRES).getTime(),
+    };
+    await db.insert(olmSessions).values(session);
+    return session;
+}
+
+export async function validateOlmSessionToken(
+    token: string,
+): Promise<SessionValidationResult> {
+    const sessionId = encodeHexLowerCase(
+        sha256(new TextEncoder().encode(token)),
+    );
+    const result = await db
+        .select({ olm: olms, session: olmSessions })
+        .from(olmSessions)
+        .innerJoin(olms, eq(olmSessions.olmId, olms.olmId))
+        .where(eq(olmSessions.sessionId, sessionId));
+    if (result.length < 1) {
+        return { session: null, olm: null };
+    }
+    const { olm, session } = result[0];
+    if (Date.now() >= session.expiresAt) {
+        await db
+            .delete(olmSessions)
+            .where(eq(olmSessions.sessionId, session.sessionId));
+        return { session: null, olm: null };
+    }
+    if (Date.now() >= session.expiresAt - (EXPIRES / 2)) {
+        session.expiresAt = new Date(
+            Date.now() + EXPIRES,
+        ).getTime();
+        await db
+            .update(olmSessions)
+            .set({
+                expiresAt: session.expiresAt,
+            })
+            .where(eq(olmSessions.sessionId, session.sessionId));
+    }
+    return { session, olm };
+}
+
+export async function invalidateOlmSession(sessionId: string): Promise<void> {
+    await db.delete(olmSessions).where(eq(olmSessions.sessionId, sessionId));
+}
+
+export async function invalidateAllOlmSessions(olmId: string): Promise<void> {
+    await db.delete(olmSessions).where(eq(olmSessions.olmId, olmId));
+}
+
+export type SessionValidationResult =
+    | { session: OlmSession; olm: Olm }
+    | { session: null; olm: null };

+ 6 - 6
server/db/schema.ts

@@ -114,8 +114,8 @@ export const newts = sqliteTable("newt", {
     })
     })
 });
 });
 
 
-export const clients = sqliteTable("clients", {
-    clientId: text("id").primaryKey(),
+export const olms = sqliteTable("olms", {
+    olmId: text("id").primaryKey(),
     secretHash: text("secretHash").notNull(),
     secretHash: text("secretHash").notNull(),
     dateCreated: text("dateCreated").notNull(),
     dateCreated: text("dateCreated").notNull(),
     siteId: integer("siteId").references(() => sites.siteId, {
     siteId: integer("siteId").references(() => sites.siteId, {
@@ -156,9 +156,9 @@ export const newtSessions = sqliteTable("newtSession", {
     expiresAt: integer("expiresAt").notNull()
     expiresAt: integer("expiresAt").notNull()
 });
 });
 
 
-export const clientSessions = sqliteTable("clientSession", {
+export const olmSessions = sqliteTable("clientSession", {
     sessionId: text("id").primaryKey(),
     sessionId: text("id").primaryKey(),
-    clientId: text("clientId")
+    olmId: text("olmId")
         .notNull()
         .notNull()
         .references(() => newts.newtId, { onDelete: "cascade" }),
         .references(() => newts.newtId, { onDelete: "cascade" }),
     expiresAt: integer("expiresAt").notNull()
     expiresAt: integer("expiresAt").notNull()
@@ -425,8 +425,8 @@ export type Target = InferSelectModel<typeof targets>;
 export type Session = InferSelectModel<typeof sessions>;
 export type Session = InferSelectModel<typeof sessions>;
 export type Newt = InferSelectModel<typeof newts>;
 export type Newt = InferSelectModel<typeof newts>;
 export type NewtSession = InferSelectModel<typeof newtSessions>;
 export type NewtSession = InferSelectModel<typeof newtSessions>;
-export type Client = InferSelectModel<typeof clients>;
-export type ClientSession = InferSelectModel<typeof clientSessions>;
+export type Olm = InferSelectModel<typeof olms>;
+export type OlmSession = InferSelectModel<typeof olmSessions>;
 export type EmailVerificationCode = InferSelectModel<
 export type EmailVerificationCode = InferSelectModel<
     typeof emailVerificationCodes
     typeof emailVerificationCodes
 >;
 >;

+ 0 - 1
server/routers/client/index.ts

@@ -1 +0,0 @@
-export * from "./pickClientDefaults";

+ 2 - 2
server/routers/external.ts

@@ -7,7 +7,7 @@ import * as target from "./target";
 import * as user from "./user";
 import * as user from "./user";
 import * as auth from "./auth";
 import * as auth from "./auth";
 import * as role from "./role";
 import * as role from "./role";
-import * as client from "./client";
+import * as olm from "./olm";
 import * as accessToken from "./accessToken";
 import * as accessToken from "./accessToken";
 import HttpCode from "@server/types/HttpCode";
 import HttpCode from "@server/types/HttpCode";
 import {
 import {
@@ -100,7 +100,7 @@ authenticated.get(
     "/site/:siteId/pick-client-defaults",
     "/site/:siteId/pick-client-defaults",
     verifyOrgAccess,
     verifyOrgAccess,
     verifyUserHasAction(ActionsEnum.createClient),
     verifyUserHasAction(ActionsEnum.createClient),
-    client.pickClientDefaults
+    olm.pickOlmDefaults
 );
 );
 
 
 // authenticated.get(
 // authenticated.get(

+ 4 - 2
server/routers/messageHandlers.ts

@@ -1,8 +1,10 @@
-import { handleRegisterMessage } from "./newt";
+import { handleNewtRegisterMessage } from "./newt";
+import { handleOlmRegisterMessage } from "./olm";
 import { handleGetConfigMessage } from "./newt/handleGetConfigMessage";
 import { handleGetConfigMessage } from "./newt/handleGetConfigMessage";
 import { MessageHandler } from "./ws";
 import { MessageHandler } from "./ws";
 
 
 export const messageHandlers: Record<string, MessageHandler> = {
 export const messageHandlers: Record<string, MessageHandler> = {
-    "newt/wg/register": handleRegisterMessage,
+    "newt/wg/register": handleNewtRegisterMessage,
+    "olm/wg/register": handleOlmRegisterMessage,
     "newt/wg/get-config": handleGetConfigMessage,
     "newt/wg/get-config": handleGetConfigMessage,
 };
 };

+ 4 - 2
server/routers/newt/handleRegisterMessage.ts → server/routers/newt/handleNewtRegisterMessage.ts

@@ -11,8 +11,10 @@ import { eq, and, sql } from "drizzle-orm";
 import { addPeer, deletePeer } from "../gerbil/peers";
 import { addPeer, deletePeer } from "../gerbil/peers";
 import logger from "@server/logger";
 import logger from "@server/logger";
 
 
-export const handleRegisterMessage: MessageHandler = async (context) => {
-    const { message, newt, sendToClient } = context;
+export const handleNewtRegisterMessage: MessageHandler = async (context) => {
+    const { message, client, sendToClient } = context;
+
+    const newt = client;
 
 
     logger.info("Handling register message!");
     logger.info("Handling register message!");
 
 

+ 106 - 0
server/routers/olm/createOlm.ts

@@ -0,0 +1,106 @@
+import { NextFunction, Request, Response } from "express";
+import db from "@server/db";
+import { hash } from "@node-rs/argon2";
+import HttpCode from "@server/types/HttpCode";
+import { z } from "zod";
+import { newts } from "@server/db/schema";
+import createHttpError from "http-errors";
+import response from "@server/lib/response";
+import { SqliteError } from "better-sqlite3";
+import moment from "moment";
+import { generateSessionToken } from "@server/auth/sessions/app";
+import { createNewtSession } from "@server/auth/sessions/newt";
+import { fromError } from "zod-validation-error";
+import { hashPassword } from "@server/auth/password";
+
+export const createNewtBodySchema = z.object({});
+
+export type CreateNewtBody = z.infer<typeof createNewtBodySchema>;
+
+export type CreateNewtResponse = {
+    token: string;
+    newtId: string;
+    secret: string;
+};
+
+const createNewtSchema = z
+    .object({
+        newtId: z.string(),
+        secret: z.string()
+    })
+    .strict();
+
+export async function createNewt(
+    req: Request,
+    res: Response,
+    next: NextFunction
+): Promise<any> {
+    try {
+
+        const parsedBody = createNewtSchema.safeParse(req.body);
+        if (!parsedBody.success) {
+            return next(
+                createHttpError(
+                    HttpCode.BAD_REQUEST,
+                    fromError(parsedBody.error).toString()
+                )
+            );
+        }
+
+        const { newtId, secret } = parsedBody.data;
+
+        if (!req.userOrgRoleId) {
+            return next(
+                createHttpError(HttpCode.FORBIDDEN, "User does not have a role")
+            );
+        }
+
+        const secretHash = await hashPassword(secret);
+
+        await db.insert(newts).values({
+            newtId: newtId,
+            secretHash,
+            dateCreated: moment().toISOString(),
+        });
+
+        // give the newt their default permissions:
+        // await db.insert(newtActions).values({
+        //     newtId: newtId,
+        //     actionId: ActionsEnum.createOrg,
+        //     orgId: null,
+        // });
+
+        const token = generateSessionToken();
+        await createNewtSession(token, newtId);
+
+        return response<CreateNewtResponse>(res, {
+            data: {
+                newtId,
+                secret,
+                token,
+            },
+            success: true,
+            error: false,
+            message: "Newt created successfully",
+            status: HttpCode.OK,
+        });
+    } catch (e) {
+        if (e instanceof SqliteError && e.code === "SQLITE_CONSTRAINT_UNIQUE") {
+            return next(
+                createHttpError(
+                    HttpCode.BAD_REQUEST,
+                    "A newt with that email address already exists"
+                )
+            );
+        } else {
+            console.error(e);
+
+            return next(
+                createHttpError(
+                    HttpCode.INTERNAL_SERVER_ERROR,
+                    "Failed to create newt"
+                )
+            );
+        }
+    }
+}

+ 115 - 0
server/routers/olm/getToken.ts

@@ -0,0 +1,115 @@
+import { generateSessionToken } from "@server/auth/sessions/app";
+import db from "@server/db";
+import { newts } from "@server/db/schema";
+import HttpCode from "@server/types/HttpCode";
+import response from "@server/lib/response";
+import { eq } from "drizzle-orm";
+import { NextFunction, Request, Response } from "express";
+import createHttpError from "http-errors";
+import { z } from "zod";
+import { fromError } from "zod-validation-error";
+import {
+    createNewtSession,
+    validateNewtSessionToken
+} from "@server/auth/sessions/newt";
+import { verifyPassword } from "@server/auth/password";
+import logger from "@server/logger";
+import config from "@server/lib/config";
+
+export const newtGetTokenBodySchema = z.object({
+    newtId: z.string(),
+    secret: z.string(),
+    token: z.string().optional()
+});
+
+export type NewtGetTokenBody = z.infer<typeof newtGetTokenBodySchema>;
+
+export async function getToken(
+    req: Request,
+    res: Response,
+    next: NextFunction
+): Promise<any> {
+    const parsedBody = newtGetTokenBodySchema.safeParse(req.body);
+
+    if (!parsedBody.success) {
+        return next(
+            createHttpError(
+                HttpCode.BAD_REQUEST,
+                fromError(parsedBody.error).toString()
+            )
+        );
+    }
+
+    const { newtId, secret, token } = parsedBody.data;
+
+    try {
+        if (token) {
+            const { session, newt } = await validateNewtSessionToken(token);
+            if (session) {
+                if (config.getRawConfig().app.log_failed_attempts) {
+                    logger.info(
+                        `Newt session already valid. Newt ID: ${newtId}. IP: ${req.ip}.`
+                    );
+                }
+                return response<null>(res, {
+                    data: null,
+                    success: true,
+                    error: false,
+                    message: "Token session already valid",
+                    status: HttpCode.OK
+                });
+            }
+        }
+
+        const existingNewtRes = await db
+            .select()
+            .from(newts)
+            .where(eq(newts.newtId, newtId));
+        if (!existingNewtRes || !existingNewtRes.length) {
+            return next(
+                createHttpError(
+                    HttpCode.BAD_REQUEST,
+                    "No newt found with that newtId"
+                )
+            );
+        }
+
+        const existingNewt = existingNewtRes[0];
+
+        const validSecret = await verifyPassword(
+            secret,
+            existingNewt.secretHash
+        );
+        if (!validSecret) {
+            if (config.getRawConfig().app.log_failed_attempts) {
+                logger.info(
+                    `Newt id or secret is incorrect. Newt: ID ${newtId}. IP: ${req.ip}.`
+                );
+            }
+            return next(
+                createHttpError(HttpCode.BAD_REQUEST, "Secret is incorrect")
+            );
+        }
+
+        const resToken = generateSessionToken();
+        await createNewtSession(resToken, existingNewt.newtId);
+
+        return response<{ token: string }>(res, {
+            data: {
+                token: resToken
+            },
+            success: true,
+            error: false,
+            message: "Token created successfully",
+            status: HttpCode.OK
+        });
+    } catch (e) {
+        console.error(e);
+        return next(
+            createHttpError(
+                HttpCode.INTERNAL_SERVER_ERROR,
+                "Failed to authenticate newt"
+            )
+        );
+    }
+}

+ 147 - 0
server/routers/olm/handleGetConfigMessage.ts

@@ -0,0 +1,147 @@
+import { z } from "zod";
+import { MessageHandler } from "../ws";
+import logger from "@server/logger";
+import { fromError } from "zod-validation-error";
+import db from "@server/db";
+import { olms, Site, sites } from "@server/db/schema";
+import { eq, isNotNull } from "drizzle-orm";
+import { findNextAvailableCidr } from "@server/lib/ip";
+import config from "@server/lib/config";
+
+const inputSchema = z.object({
+    publicKey: z.string(),
+    endpoint: z.string(),
+    listenPort: z.number()
+});
+
+type Input = z.infer<typeof inputSchema>;
+
+export const handleGetConfigMessage: MessageHandler = async (context) => {
+    const { message, newt, sendToClient } = context;
+
+    logger.debug("Handling Newt get config message!");
+
+    if (!newt) {
+        logger.warn("Newt not found");
+        return;
+    }
+
+    if (!newt.siteId) {
+        logger.warn("Newt has no site!"); // TODO: Maybe we create the site here?
+        return;
+    }
+
+    const parsed = inputSchema.safeParse(message.data);
+    if (!parsed.success) {
+        logger.error(
+            "handleGetConfigMessage: Invalid input: " +
+                fromError(parsed.error).toString()
+        );
+        return;
+    }
+
+    const { publicKey, endpoint, listenPort } = message.data as Input;
+
+    const siteId = newt.siteId;
+
+    const [siteRes] = await db
+        .select()
+        .from(sites)
+        .where(eq(sites.siteId, siteId));
+
+    if (!siteRes) {
+        logger.warn("handleGetConfigMessage: Site not found");
+        return;
+    }
+
+    let site: Site | undefined;
+    if (!site) {
+        const address = await getNextAvailableSubnet();
+
+        // create a new exit node
+        const [updateRes] = await db
+            .update(sites)
+            .set({
+                publicKey,
+                endpoint,
+                address,
+                listenPort
+            })
+            .where(eq(sites.siteId, siteId))
+            .returning();
+
+        site = updateRes;
+
+        logger.info(`Updated site ${siteId} with new WG Newt info`);
+    } else {
+        site = siteRes;
+    }
+
+    if (!site) {
+        logger.error("handleGetConfigMessage: Failed to update site");
+        return;
+    }
+
+    const clientsRes = await db
+        .select()
+        .from(olms)
+        .where(eq(olms.siteId, siteId));
+
+    const peers = await Promise.all(
+        clientsRes.map(async (client) => {
+            return {
+                publicKey: client.pubKey,
+                allowedIps: "0.0.0.0/0"
+            };
+        })
+    );
+
+    const configResponse = {
+        listenPort: site.listenPort, // ?????
+        // ipAddress: exitNode[0].address,
+        peers
+    };
+
+    logger.debug("Sending config: ", configResponse);
+
+    return {
+        message: {
+            type: "olm/wg/connect", // what to make the response type?
+            data: {
+                config: configResponse
+            }
+        },
+        broadcast: false, // Send to all clients
+        excludeSender: false // Include sender in broadcast
+    };
+};
+
+async function getNextAvailableSubnet(): Promise<string> {
+    const existingAddresses = await db
+        .select({
+            address: sites.address
+        })
+        .from(sites)
+        .where(isNotNull(sites.address));
+
+    const addresses = existingAddresses
+        .map((a) => a.address)
+        .filter((a) => a) as string[];
+
+    let subnet = findNextAvailableCidr(
+        addresses,
+        config.getRawConfig().wg_site.block_size,
+        config.getRawConfig().wg_site.subnet_group
+    );
+    if (!subnet) {
+        throw new Error("No available subnets remaining in space");
+    }
+
+    // replace the last octet with 1
+    subnet =
+        subnet.split(".").slice(0, 3).join(".") +
+        ".1" +
+        "/" +
+        subnet.split("/")[1];
+    return subnet;
+}

+ 93 - 0
server/routers/olm/handleOlmRegisterMessage.ts

@@ -0,0 +1,93 @@
+import db from "@server/db";
+import { MessageHandler } from "../ws";
+import {
+    exitNodes,
+    resources,
+    sites,
+    Target,
+    targets
+} from "@server/db/schema";
+import { eq, and, sql } from "drizzle-orm";
+import { addPeer, deletePeer } from "../gerbil/peers";
+import logger from "@server/logger";
+
+export const handleOlmRegisterMessage: MessageHandler = async (context) => {
+    const { message, client, sendToClient } = context;
+
+    const olm = client;
+
+    logger.info("Handling register message!");
+
+    if (!olm) {
+        logger.warn("Olm not found");
+        return;
+    }
+
+    if (!olm.siteId) {
+        logger.warn("Olm has no site!"); // TODO: Maybe we create the site here?
+        return;
+    }
+
+    const siteId = olm.siteId;
+
+    const { publicKey } = message.data;
+    if (!publicKey) {
+        logger.warn("Public key not provided");
+        return;
+    }
+
+    const [site] = await db
+        .select()
+        .from(sites)
+        .where(eq(sites.siteId, siteId))
+        .limit(1);
+
+    if (!site || !site.exitNodeId) {
+        logger.warn("Site not found or does not have exit node");
+        return;
+    }
+
+    await db
+        .update(sites)
+        .set({
+            pubKey: publicKey
+        })
+        .where(eq(sites.siteId, siteId))
+        .returning();
+
+    const [exitNode] = await db
+        .select()
+        .from(exitNodes)
+        .where(eq(exitNodes.exitNodeId, site.exitNodeId))
+        .limit(1);
+
+    if (site.pubKey && site.pubKey !== publicKey) {
+        logger.info("Public key mismatch. Deleting old peer...");
+        await deletePeer(site.exitNodeId, site.pubKey);
+    }
+
+    if (!site.subnet) {
+        logger.warn("Site has no subnet");
+        return;
+    }
+
+    // add the peer to the exit node
+    await addPeer(site.exitNodeId, {
+        publicKey: publicKey,
+        allowedIps: [site.subnet]
+    });
+
+    return {
+        message: {
+            type: "olm/wg/connect",
+            data: {
+                endpoint: `${exitNode.endpoint}:${exitNode.listenPort}`,
+                publicKey: exitNode.publicKey,
+                serverIP: exitNode.address.split("/")[0],
+                tunnelIP: site.subnet.split("/")[0]
+            }
+        },
+        broadcast: false, // Send to all olms
+        excludeSender: false // Include sender in broadcast
+    };
+};

+ 1 - 0
server/routers/olm/index.ts

@@ -0,0 +1 @@
+export * from "./pickOlmDefaults";

+ 5 - 5
server/routers/client/pickClientDefaults.ts → server/routers/olm/pickOlmDefaults.ts

@@ -1,6 +1,6 @@
 import { Request, Response, NextFunction } from "express";
 import { Request, Response, NextFunction } from "express";
 import { db } from "@server/db";
 import { db } from "@server/db";
-import { clients, sites } from "@server/db/schema";
+import { olms, sites } from "@server/db/schema";
 import { eq } from "drizzle-orm";
 import { eq } from "drizzle-orm";
 import response from "@server/lib/response";
 import response from "@server/lib/response";
 import HttpCode from "@server/types/HttpCode";
 import HttpCode from "@server/types/HttpCode";
@@ -30,7 +30,7 @@ export type PickClientDefaultsResponse = {
     clientSecret: string;
     clientSecret: string;
 };
 };
 
 
-export async function pickClientDefaults(
+export async function pickOlmDefaults(
     req: Request,
     req: Request,
     res: Response,
     res: Response,
     next: NextFunction
     next: NextFunction
@@ -71,10 +71,10 @@ export async function pickClientDefaults(
 
 
         const clientsQuery = await db
         const clientsQuery = await db
             .select({
             .select({
-                subnet: clients.subnet
+                subnet: olms.subnet
             })
             })
-            .from(clients)
-            .where(eq(clients.siteId, site.siteId));
+            .from(olms)
+            .where(eq(olms.siteId, site.siteId));
 
 
         let subnets = clientsQuery.map((client) => client.subnet);
         let subnets = clientsQuery.map((client) => client.subnet);
 
 

+ 87 - 75
server/routers/ws.ts

@@ -3,10 +3,11 @@ import { Server as HttpServer } from "http";
 import { WebSocket, WebSocketServer } from "ws";
 import { WebSocket, WebSocketServer } from "ws";
 import { IncomingMessage } from "http";
 import { IncomingMessage } from "http";
 import { Socket } from "net";
 import { Socket } from "net";
-import { Newt, newts, NewtSession } from "@server/db/schema";
+import { Newt, newts, NewtSession, Olm, olms, OlmSession } from "@server/db/schema";
 import { eq } from "drizzle-orm";
 import { eq } from "drizzle-orm";
 import db from "@server/db";
 import db from "@server/db";
 import { validateNewtSessionToken } from "@server/auth/sessions/newt";
 import { validateNewtSessionToken } from "@server/auth/sessions/newt";
+import { validateOlmSessionToken } from "@server/auth/sessions/olm";
 import { messageHandlers } from "./messageHandlers";
 import { messageHandlers } from "./messageHandlers";
 import logger from "@server/logger";
 import logger from "@server/logger";
 
 
@@ -15,13 +16,17 @@ interface WebSocketRequest extends IncomingMessage {
     token?: string;
     token?: string;
 }
 }
 
 
+type ClientType = 'newt' | 'olm';
+
 interface AuthenticatedWebSocket extends WebSocket {
 interface AuthenticatedWebSocket extends WebSocket {
-    newt?: Newt;
+    client?: Newt | Olm;
+    clientType?: ClientType;
 }
 }
 
 
 interface TokenPayload {
 interface TokenPayload {
-    newt: Newt;
-    session: NewtSession;
+    client: Newt | Olm;
+    session: NewtSession | OlmSession;
+    clientType: ClientType;
 }
 }
 
 
 interface WSMessage {
 interface WSMessage {
@@ -33,15 +38,16 @@ interface HandlerResponse {
     message: WSMessage;
     message: WSMessage;
     broadcast?: boolean;
     broadcast?: boolean;
     excludeSender?: boolean;
     excludeSender?: boolean;
-    targetNewtId?: string;
+    targetClientId?: string;
 }
 }
 
 
 interface HandlerContext {
 interface HandlerContext {
     message: WSMessage;
     message: WSMessage;
     senderWs: WebSocket;
     senderWs: WebSocket;
-    newt: Newt | undefined;
-    sendToClient: (newtId: string, message: WSMessage) => boolean;
-    broadcastToAllExcept: (message: WSMessage, excludeNewtId?: string) => void;
+    client: Newt | Olm | undefined;
+    clientType: ClientType;
+    sendToClient: (clientId: string, message: WSMessage) => boolean;
+    broadcastToAllExcept: (message: WSMessage, excludeClientId?: string) => void;
     connectedClients: Map<string, WebSocket[]>;
     connectedClients: Map<string, WebSocket[]>;
 }
 }
 
 
@@ -54,34 +60,32 @@ const wss: WebSocketServer = new WebSocketServer({ noServer: true });
 let connectedClients: Map<string, AuthenticatedWebSocket[]> = new Map();
 let connectedClients: Map<string, AuthenticatedWebSocket[]> = new Map();
 
 
 // Helper functions for client management
 // Helper functions for client management
-const addClient = (newtId: string, ws: AuthenticatedWebSocket): void => {
-    const existingClients = connectedClients.get(newtId) || [];
+const addClient = (clientId: string, ws: AuthenticatedWebSocket, clientType: ClientType): void => {
+    const existingClients = connectedClients.get(clientId) || [];
     existingClients.push(ws);
     existingClients.push(ws);
-    connectedClients.set(newtId, existingClients);
-    logger.info(`Client added to tracking - Newt ID: ${newtId}, Total connections: ${existingClients.length}`);
+    connectedClients.set(clientId, existingClients);
+    logger.info(`Client added to tracking - ${clientType.toUpperCase()} ID: ${clientId}, Total connections: ${existingClients.length}`);
 };
 };
 
 
-const removeClient = (newtId: string, ws: AuthenticatedWebSocket): void => {
-    const existingClients = connectedClients.get(newtId) || [];
+const removeClient = (clientId: string, ws: AuthenticatedWebSocket, clientType: ClientType): void => {
+    const existingClients = connectedClients.get(clientId) || [];
     const updatedClients = existingClients.filter(client => client !== ws);
     const updatedClients = existingClients.filter(client => client !== ws);
-
     if (updatedClients.length === 0) {
     if (updatedClients.length === 0) {
-        connectedClients.delete(newtId);
-        logger.info(`All connections removed for Newt ID: ${newtId}`);
+        connectedClients.delete(clientId);
+        logger.info(`All connections removed for ${clientType.toUpperCase()} ID: ${clientId}`);
     } else {
     } else {
-        connectedClients.set(newtId, updatedClients);
-        logger.info(`Connection removed - Newt ID: ${newtId}, Remaining connections: ${updatedClients.length}`);
+        connectedClients.set(clientId, updatedClients);
+        logger.info(`Connection removed - ${clientType.toUpperCase()} ID: ${clientId}, Remaining connections: ${updatedClients.length}`);
     }
     }
 };
 };
 
 
 // Helper functions for sending messages
 // Helper functions for sending messages
-const sendToClient = (newtId: string, message: WSMessage): boolean => {
-    const clients = connectedClients.get(newtId);
+const sendToClient = (clientId: string, message: WSMessage): boolean => {
+    const clients = connectedClients.get(clientId);
     if (!clients || clients.length === 0) {
     if (!clients || clients.length === 0) {
-        logger.info(`No active connections found for Newt ID: ${newtId}`);
+        logger.info(`No active connections found for Client ID: ${clientId}`);
         return false;
         return false;
     }
     }
-
     const messageString = JSON.stringify(message);
     const messageString = JSON.stringify(message);
     clients.forEach(client => {
     clients.forEach(client => {
         if (client.readyState === WebSocket.OPEN) {
         if (client.readyState === WebSocket.OPEN) {
@@ -91,9 +95,9 @@ const sendToClient = (newtId: string, message: WSMessage): boolean => {
     return true;
     return true;
 };
 };
 
 
-const broadcastToAllExcept = (message: WSMessage, excludeNewtId?: string): void => {
-    connectedClients.forEach((clients, newtId) => {
-        if (newtId !== excludeNewtId) {
+const broadcastToAllExcept = (message: WSMessage, excludeClientId?: string): void => {
+    connectedClients.forEach((clients, clientId) => {
+        if (clientId !== excludeClientId) {
             clients.forEach(client => {
             clients.forEach(client => {
                 if (client.readyState === WebSocket.OPEN) {
                 if (client.readyState === WebSocket.OPEN) {
                     client.send(JSON.stringify(message));
                     client.send(JSON.stringify(message));
@@ -103,84 +107,88 @@ const broadcastToAllExcept = (message: WSMessage, excludeNewtId?: string): void
     });
     });
 };
 };
 
 
-// Token verification middleware (unchanged)
-const verifyToken = async (token: string): Promise<TokenPayload | null> => {
+// Token verification middleware
+const verifyToken = async (token: string, clientType: ClientType): Promise<TokenPayload | null> => {
     try {
     try {
-        const { session, newt } = await validateNewtSessionToken(token);
-
-        if (!session || !newt) {
-            return null;
-        }
-
-        const existingNewt = await db
-            .select()
-            .from(newts)
-            .where(eq(newts.newtId, newt.newtId));
-
-        if (!existingNewt || !existingNewt[0]) {
-            return null;
+        if (clientType === 'newt') {
+            const { session, newt } = await validateNewtSessionToken(token);
+            if (!session || !newt) {
+                return null;
+            }
+            const existingNewt = await db
+                .select()
+                .from(newts)
+                .where(eq(newts.newtId, newt.newtId));
+            if (!existingNewt || !existingNewt[0]) {
+                return null;
+            }
+            return { client: existingNewt[0], session, clientType };
+        } else {
+            const { session, olm } = await validateOlmSessionToken(token);
+            if (!session || !olm) {
+                return null;
+            }
+            const existingOlm = await db
+                .select()
+                .from(olms)
+                .where(eq(olms.olmId, olm.olmId));
+            if (!existingOlm || !existingOlm[0]) {
+                return null;
+            }
+            return { client: existingOlm[0], session, clientType };
         }
         }
-
-        return { newt: existingNewt[0], session };
     } catch (error) {
     } catch (error) {
         logger.error("Token verification failed:", error);
         logger.error("Token verification failed:", error);
         return null;
         return null;
     }
     }
 };
 };
 
 
-const setupConnection = (ws: AuthenticatedWebSocket, newt: Newt): void => {
+const setupConnection = (ws: AuthenticatedWebSocket, client: Newt | Olm, clientType: ClientType): void => {
     logger.info("Establishing websocket connection");
     logger.info("Establishing websocket connection");
-
-    if (!newt) {
-        logger.error("Connection attempt without newt");
+    if (!client) {
+        logger.error("Connection attempt without client");
         return ws.terminate();
         return ws.terminate();
     }
     }
 
 
-    ws.newt = newt;
+    ws.client = client;
+    ws.clientType = clientType;
 
 
     // Add client to tracking
     // Add client to tracking
-    addClient(newt.newtId, ws);
+    const clientId = clientType === 'newt' ? (client as Newt).newtId : (client as Olm).olmId;
+    addClient(clientId, ws, clientType);
 
 
     ws.on("message", async (data) => {
     ws.on("message", async (data) => {
         try {
         try {
             const message: WSMessage = JSON.parse(data.toString());
             const message: WSMessage = JSON.parse(data.toString());
-            // logger.info(`Message received from Newt ID ${newtId}:`, message);
 
 
-            // Validate message format
             if (!message.type || typeof message.type !== "string") {
             if (!message.type || typeof message.type !== "string") {
                 throw new Error("Invalid message format: missing or invalid type");
                 throw new Error("Invalid message format: missing or invalid type");
             }
             }
 
 
-            // Get the appropriate handler for the message type
             const handler = messageHandlers[message.type];
             const handler = messageHandlers[message.type];
             if (!handler) {
             if (!handler) {
                 throw new Error(`Unsupported message type: ${message.type}`);
                 throw new Error(`Unsupported message type: ${message.type}`);
             }
             }
 
 
-            // Process the message and get response
             const response = await handler({
             const response = await handler({
                 message,
                 message,
                 senderWs: ws,
                 senderWs: ws,
-                newt: ws.newt,
+                client: ws.client,
+                clientType: ws.clientType!,
                 sendToClient,
                 sendToClient,
                 broadcastToAllExcept,
                 broadcastToAllExcept,
                 connectedClients
                 connectedClients
             });
             });
 
 
-            // Send response if one was returned
             if (response) {
             if (response) {
                 if (response.broadcast) {
                 if (response.broadcast) {
-                    // Broadcast to all clients except sender if specified
-                    broadcastToAllExcept(response.message, response.excludeSender ? newt.newtId : undefined);
-                } else if (response.targetNewtId) {
-                    // Send to specific client if targetNewtId is provided
-                    sendToClient(response.targetNewtId, response.message);
+                    broadcastToAllExcept(response.message, response.excludeSender ? clientId : undefined);
+                } else if (response.targetClientId) {
+                    sendToClient(response.targetClientId, response.message);
                 } else {
                 } else {
-                    // Send back to sender
                     ws.send(JSON.stringify(response.message));
                     ws.send(JSON.stringify(response.message));
                 }
                 }
             }
             }
-
         } catch (error) {
         } catch (error) {
             logger.error("Message handling error:", error);
             logger.error("Message handling error:", error);
             ws.send(JSON.stringify({
             ws.send(JSON.stringify({
@@ -194,18 +202,18 @@ const setupConnection = (ws: AuthenticatedWebSocket, newt: Newt): void => {
     });
     });
 
 
     ws.on("close", () => {
     ws.on("close", () => {
-        removeClient(newt.newtId, ws);
-        logger.info(`Client disconnected - Newt ID: ${newt.newtId}`);
+        removeClient(clientId, ws, clientType);
+        logger.info(`Client disconnected - ${clientType.toUpperCase()} ID: ${clientId}`);
     });
     });
 
 
     ws.on("error", (error: Error) => {
     ws.on("error", (error: Error) => {
-        logger.error(`WebSocket error for Newt ID ${newt.newtId}:`, error);
+        logger.error(`WebSocket error for ${clientType.toUpperCase()} ID ${clientId}:`, error);
     });
     });
 
 
-    logger.info(`WebSocket connection established - Newt ID: ${newt.newtId}`);
+    logger.info(`WebSocket connection established - ${clientType.toUpperCase()} ID: ${clientId}`);
 };
 };
 
 
-// Router endpoint (unchanged)
+// Router endpoint
 router.get("/ws", (req: Request, res: Response) => {
 router.get("/ws", (req: Request, res: Response) => {
     res.status(200).send("WebSocket endpoint");
     res.status(200).send("WebSocket endpoint");
 });
 });
@@ -214,18 +222,22 @@ router.get("/ws", (req: Request, res: Response) => {
 const handleWSUpgrade = (server: HttpServer): void => {
 const handleWSUpgrade = (server: HttpServer): void => {
     server.on("upgrade", async (request: WebSocketRequest, socket: Socket, head: Buffer) => {
     server.on("upgrade", async (request: WebSocketRequest, socket: Socket, head: Buffer) => {
         try {
         try {
-            const token = request.url?.includes("?")
-                ? new URLSearchParams(request.url.split("?")[1]).get("token") || ""
-                : request.headers["sec-websocket-protocol"];
+            const url = new URL(request.url || '', `http://${request.headers.host}`);
+            const token = url.searchParams.get('token') || request.headers["sec-websocket-protocol"] || '';
+            let clientType = url.searchParams.get('clientType') as ClientType;
+
+            if (!clientType) {
+                clientType = "newt";
+            }
 
 
-            if (!token) {
-                logger.warn("Unauthorized connection attempt: no token...");
+            if (!token || !clientType || !['newt', 'olm'].includes(clientType)) {
+                logger.warn("Unauthorized connection attempt: invalid token or client type...");
                 socket.write("HTTP/1.1 401 Unauthorized\r\n\r\n");
                 socket.write("HTTP/1.1 401 Unauthorized\r\n\r\n");
                 socket.destroy();
                 socket.destroy();
                 return;
                 return;
             }
             }
 
 
-            const tokenPayload = await verifyToken(token);
+            const tokenPayload = await verifyToken(token, clientType);
             if (!tokenPayload) {
             if (!tokenPayload) {
                 logger.warn("Unauthorized connection attempt: invalid token...");
                 logger.warn("Unauthorized connection attempt: invalid token...");
                 socket.write("HTTP/1.1 401 Unauthorized\r\n\r\n");
                 socket.write("HTTP/1.1 401 Unauthorized\r\n\r\n");
@@ -234,7 +246,7 @@ const handleWSUpgrade = (server: HttpServer): void => {
             }
             }
 
 
             wss.handleUpgrade(request, socket, head, (ws: AuthenticatedWebSocket) => {
             wss.handleUpgrade(request, socket, head, (ws: AuthenticatedWebSocket) => {
-                setupConnection(ws, tokenPayload.newt);
+                setupConnection(ws, tokenPayload.client, tokenPayload.clientType);
             });
             });
         } catch (error) {
         } catch (error) {
             logger.error("WebSocket upgrade error:", error);
             logger.error("WebSocket upgrade error:", error);
@@ -250,4 +262,4 @@ export {
     sendToClient,
     sendToClient,
     broadcastToAllExcept,
     broadcastToAllExcept,
     connectedClients
     connectedClients
-};
+};