Owen 4 місяців тому
батько
коміт
bec303821b

+ 3 - 2
server/routers/newt/handleGetConfigMessage.ts

@@ -3,7 +3,7 @@ import { MessageHandler } from "../ws";
 import logger from "@server/logger";
 import { fromError } from "zod-validation-error";
 import db from "@server/db";
-import { clients, Site, sites } from "@server/db/schema";
+import { clients, Newt, Site, sites } from "@server/db/schema";
 import { eq, isNotNull } from "drizzle-orm";
 import { findNextAvailableCidr } from "@server/lib/ip";
 import config from "@server/lib/config";
@@ -17,7 +17,8 @@ const inputSchema = z.object({
 type Input = z.infer<typeof inputSchema>;
 
 export const handleGetConfigMessage: MessageHandler = async (context) => {
-    const { message, newt, sendToClient } = context;
+    const { message, client, sendToClient } = context;
+    const newt = client as Newt;
 
     logger.debug("Handling Newt get config message!");
 

+ 2 - 2
server/routers/newt/handleNewtRegisterMessage.ts

@@ -2,6 +2,7 @@ import db from "@server/db";
 import { MessageHandler } from "../ws";
 import {
     exitNodes,
+    Newt,
     resources,
     sites,
     Target,
@@ -13,8 +14,7 @@ import logger from "@server/logger";
 
 export const handleNewtRegisterMessage: MessageHandler = async (context) => {
     const { message, client, sendToClient } = context;
-
-    const newt = client;
+    const newt = client as Newt;
 
     logger.info("Handling register newt message!");
 

+ 23 - 11
server/routers/olm/handleOlmRegisterMessage.ts

@@ -1,6 +1,8 @@
 import db from "@server/db";
 import { MessageHandler } from "../ws";
 import {
+    clients,
+    Olm,
     olms,
     sites,
 } from "@server/db/schema";
@@ -9,9 +11,8 @@ import { addPeer, deletePeer } from "../newt/peers";
 import logger from "@server/logger";
 
 export const handleOlmRegisterMessage: MessageHandler = async (context) => {
-    const { message, client, sendToClient } = context;
-
-    const olm = client;
+    const { message, client: c, sendToClient } = context;
+    const olm = c as Olm;
 
     logger.info("Handling register olm message!");
 
@@ -20,12 +21,12 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
         return;
     }
 
-    if (!olm.siteId) {
+    if (!olm.clientId) {
         logger.warn("Olm has no site!"); // TODO: Maybe we create the site here?
         return;
     }
 
-    const siteId = olm.siteId;
+    const clientId = olm.clientId;
 
     const { publicKey } = message.data;
     if (!publicKey) {
@@ -33,28 +34,39 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
         return;
     }
 
+    const [client] = await db
+        .select()
+        .from(clients)
+        .where(eq(clients.clientId, clientId))
+        .limit(1);
+
+    if (!client || !client.siteId) {
+        logger.warn("Site not found or does not have exit node");
+        return;
+    }
+
     const [site] = await db
         .select()
         .from(sites)
-        .where(eq(sites.siteId, siteId))
+        .where(eq(sites.siteId, client.siteId))
         .limit(1);
 
-    if (!site) {
+    if (!client) {
         logger.warn("Site not found or does not have exit node");
         return;
     }
 
     await db
-        .update(olms)
+        .update(clients)
         .set({
             pubKey: publicKey
         })
-        .where(eq(olms.olmId, olm.olmId))
+        .where(eq(clients.clientId, olm.clientId))
         .returning();
 
-    if (olm.pubKey && olm.pubKey !== publicKey) {
+    if (client.pubKey && client.pubKey !== publicKey) {
         logger.info("Public key mismatch. Deleting old peer...");
-        await deletePeer(site.siteId, site.pubKey);
+        await deletePeer(site.siteId, client.pubKey);
     }
 
     if (!site.subnet) {