Jelajahi Sumber

add wg site get config and pick client defaults

miloschwartz 4 bulan lalu
induk
melakukan
41983ce356

+ 1 - 0
server/auth/actions.ts

@@ -62,6 +62,7 @@ export enum ActionsEnum {
     deleteResourceRule = "deleteResourceRule",
     listResourceRules = "listResourceRules",
     updateResourceRule = "updateResourceRule",
+    createClient = "createClient"
 }
 
 export async function checkUserActionPermission(

+ 11 - 3
server/db/schema.ts

@@ -31,8 +31,7 @@ export const sites = sqliteTable("sites", {
     address: text("address"), // this is the address of the wireguard interface in gerbil
     endpoint: text("endpoint"), // this is how to reach gerbil externally - gets put into the wireguard config
     publicKey: text("pubicKey"),
-    listenPort: integer("listenPort"),
-    reachableAt: text("reachableAt") // this is the internal address of the gerbil http server for command control
+    listenPort: integer("listenPort")
 });
 
 export const resources = sqliteTable("resources", {
@@ -121,7 +120,16 @@ export const clients = sqliteTable("clients", {
     dateCreated: text("dateCreated").notNull(),
     siteId: integer("siteId").references(() => sites.siteId, {
         onDelete: "cascade"
-    })
+    }),
+
+    // wgstuff
+    pubKey: text("pubKey"),
+    subnet: text("subnet").notNull(),
+    megabytesIn: integer("bytesIn"),
+    megabytesOut: integer("bytesOut"),
+    lastBandwidthUpdate: text("lastBandwidthUpdate"),
+    type: text("type").notNull(), // "newt" or "wireguard"
+    online: integer("online", { mode: "boolean" }).notNull().default(false),
 });
 
 export const twoFactorBackupCodes = sqliteTable("twoFactorBackupCodes", {

+ 4 - 0
server/lib/config.ts

@@ -109,6 +109,10 @@ const configSchema = z.object({
         block_size: z.number().positive().gt(0),
         site_block_size: z.number().positive().gt(0)
     }),
+    wg_site: z.object({
+        block_size: z.number().positive().gt(0),
+        subnet_group: z.string(),
+    }),
     rate_limits: z.object({
         global: z.object({
             window_minutes: z.number().positive().gt(0),

+ 1 - 0
server/routers/client/index.ts

@@ -0,0 +1 @@
+export * from "./pickClientDefaults";

+ 128 - 0
server/routers/client/pickClientDefaults.ts

@@ -0,0 +1,128 @@
+import { Request, Response, NextFunction } from "express";
+import { db } from "@server/db";
+import { clients, sites } from "@server/db/schema";
+import { eq } from "drizzle-orm";
+import response from "@server/lib/response";
+import HttpCode from "@server/types/HttpCode";
+import createHttpError from "http-errors";
+import logger from "@server/logger";
+import { findNextAvailableCidr } from "@server/lib/ip";
+import { generateId } from "@server/auth/sessions/app";
+import config from "@server/lib/config";
+import { z } from "zod";
+import { fromError } from "zod-validation-error";
+
+const getSiteSchema = z
+    .object({
+        siteId: z.number().int().positive()
+    })
+    .strict();
+
+export type PickClientDefaultsResponse = {
+    siteId: number;
+    address: string;
+    publicKey: string;
+    name: string;
+    listenPort: number;
+    endpoint: string;
+    subnet: string;
+    clientId: string;
+    clientSecret: string;
+};
+
+export async function pickClientDefaults(
+    req: Request,
+    res: Response,
+    next: NextFunction
+): Promise<any> {
+    try {
+        const parsedParams = getSiteSchema.safeParse(req.params);
+        if (!parsedParams.success) {
+            return next(
+                createHttpError(
+                    HttpCode.BAD_REQUEST,
+                    fromError(parsedParams.error).toString()
+                )
+            );
+        }
+
+        const { siteId } = parsedParams.data;
+
+        const [site] = await db
+            .select()
+            .from(sites)
+            .where(eq(sites.siteId, siteId));
+
+        if (!site) {
+            return next(createHttpError(HttpCode.NOT_FOUND, "Site not found"));
+        }
+
+        // make sure all the required fields are present
+        if (
+            !site.address ||
+            !site.publicKey ||
+            !site.listenPort ||
+            !site.endpoint
+        ) {
+            return next(
+                createHttpError(HttpCode.BAD_REQUEST, "Site has no address")
+            );
+        }
+
+        const clientsQuery = await db
+            .select({
+                subnet: clients.subnet
+            })
+            .from(clients)
+            .where(eq(clients.siteId, site.siteId));
+
+        let subnets = clientsQuery.map((client) => client.subnet);
+
+        // exclude the exit node address by replacing after the / with a site block size
+        subnets.push(
+            site.address.replace(
+                /\/\d+$/,
+                `/${config.getRawConfig().wg_site.block_size}`
+            )
+        );
+        const newSubnet = findNextAvailableCidr(
+            subnets,
+            config.getRawConfig().wg_site.block_size,
+            site.address
+        );
+        if (!newSubnet) {
+            return next(
+                createHttpError(
+                    HttpCode.INTERNAL_SERVER_ERROR,
+                    "No available subnets"
+                )
+            );
+        }
+
+        const clientId = generateId(15);
+        const secret = generateId(48);
+
+        return response<PickClientDefaultsResponse>(res, {
+            data: {
+                siteId: site.siteId,
+                address: site.address,
+                publicKey: site.publicKey,
+                name: site.name,
+                listenPort: site.listenPort,
+                endpoint: site.endpoint,
+                subnet: newSubnet,
+                clientId,
+                clientSecret: secret
+            },
+            success: true,
+            error: false,
+            message: "Organization retrieved successfully",
+            status: HttpCode.OK
+        });
+    } catch (error) {
+        logger.error(error);
+        return next(
+            createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
+        );
+    }
+}

+ 9 - 0
server/routers/external.ts

@@ -7,6 +7,7 @@ import * as target from "./target";
 import * as user from "./user";
 import * as auth from "./auth";
 import * as role from "./role";
+import * as client from "./client";
 import * as accessToken from "./accessToken";
 import HttpCode from "@server/types/HttpCode";
 import {
@@ -94,6 +95,14 @@ authenticated.get(
     verifyUserHasAction(ActionsEnum.getSite),
     site.getSite
 );
+
+authenticated.get(
+    "/site/:siteId/pick-client-defaults",
+    verifyOrgAccess,
+    verifyUserHasAction(ActionsEnum.createClient),
+    client.pickClientDefaults
+);
+
 // authenticated.get(
 //     "/site/:siteId/roles",
 //     verifySiteAccess,

+ 1 - 1
server/routers/gerbil/getConfig.ts

@@ -86,7 +86,7 @@ export async function getConfig(req: Request, res: Response, next: NextFunction)
         const peers = await Promise.all(sitesRes.map(async (site) => {
             return {
                 publicKey: site.pubKey,
-                allowedIps: await getAllowedIps(site.siteId)
+                allowedIps: await getAllowedIps(site.siteId) // put 0.0.0.0/0 for now
             };
         }));
 

+ 3 - 1
server/routers/messageHandlers.ts

@@ -1,6 +1,8 @@
 import { handleRegisterMessage } from "./newt";
+import { handleGetConfigMessage } from "./newt/handleGetConfigMessage";
 import { MessageHandler } from "./ws";
 
 export const messageHandlers: Record<string, MessageHandler> = {
     "newt/wg/register": handleRegisterMessage,
-};
+    "newt/wg/get-config": handleGetConfigMessage,
+};

+ 147 - 0
server/routers/newt/handleGetConfigMessage.ts

@@ -0,0 +1,147 @@
+import { z } from "zod";
+import { MessageHandler } from "../ws";
+import logger from "@server/logger";
+import { fromError } from "zod-validation-error";
+import db from "@server/db";
+import { clients, Site, sites } from "@server/db/schema";
+import { eq, isNotNull } from "drizzle-orm";
+import { findNextAvailableCidr } from "@server/lib/ip";
+import config from "@server/lib/config";
+
+const inputSchema = z.object({
+    publicKey: z.string(),
+    endpoint: z.string(),
+    listenPort: z.number()
+});
+
+type Input = z.infer<typeof inputSchema>;
+
+export const handleGetConfigMessage: MessageHandler = async (context) => {
+    const { message, newt, sendToClient } = context;
+
+    logger.debug("Handling Newt get config message!");
+
+    if (!newt) {
+        logger.warn("Newt not found");
+        return;
+    }
+
+    if (!newt.siteId) {
+        logger.warn("Newt has no site!"); // TODO: Maybe we create the site here?
+        return;
+    }
+
+    const parsed = inputSchema.safeParse(message.data);
+    if (!parsed.success) {
+        logger.error(
+            "handleGetConfigMessage: Invalid input: " +
+                fromError(parsed.error).toString()
+        );
+        return;
+    }
+
+    const { publicKey, endpoint, listenPort } = message.data as Input;
+
+    const siteId = newt.siteId;
+
+    const [siteRes] = await db
+        .select()
+        .from(sites)
+        .where(eq(sites.siteId, siteId));
+
+    if (!siteRes) {
+        logger.warn("handleGetConfigMessage: Site not found");
+        return;
+    }
+
+    let site: Site | undefined;
+    if (!site) {
+        const address = await getNextAvailableSubnet();
+
+        // create a new exit node
+        const [updateRes] = await db
+            .update(sites)
+            .set({
+                publicKey,
+                endpoint,
+                address,
+                listenPort
+            })
+            .where(eq(sites.siteId, siteId))
+            .returning();
+
+        site = updateRes;
+
+        logger.info(`Updated site ${siteId} with new WG Newt info`);
+    } else {
+        site = siteRes;
+    }
+
+    if (!site) {
+        logger.error("handleGetConfigMessage: Failed to update site");
+        return;
+    }
+
+    const clientsRes = await db
+        .select()
+        .from(clients)
+        .where(eq(clients.siteId, siteId));
+
+    const peers = await Promise.all(
+        clientsRes.map(async (client) => {
+            return {
+                publicKey: client.pubKey,
+                allowedIps: "0.0.0.0/0"
+            };
+        })
+    );
+
+    const configResponse = {
+        listenPort: site.listenPort, // ?????
+        // ipAddress: exitNode[0].address,
+        peers
+    };
+
+    logger.debug("Sending config: ", configResponse);
+
+    return {
+        message: {
+            type: "newt/wg/connect", // what to make the response type?
+            data: {
+                config: configResponse
+            }
+        },
+        broadcast: false, // Send to all clients
+        excludeSender: false // Include sender in broadcast
+    };
+};
+
+async function getNextAvailableSubnet(): Promise<string> {
+    const existingAddresses = await db
+        .select({
+            address: sites.address
+        })
+        .from(sites)
+        .where(isNotNull(sites.address));
+
+    const addresses = existingAddresses
+        .map((a) => a.address)
+        .filter((a) => a) as string[];
+
+    let subnet = findNextAvailableCidr(
+        addresses,
+        config.getRawConfig().wg_site.block_size,
+        config.getRawConfig().wg_site.subnet_group
+    );
+    if (!subnet) {
+        throw new Error("No available subnets remaining in space");
+    }
+
+    // replace the last octet with 1
+    subnet =
+        subnet.split(".").slice(0, 3).join(".") +
+        ".1" +
+        "/" +
+        subnet.split("/")[1];
+    return subnet;
+}