瀏覽代碼

Merge branch 'main' of https://github.com/fosrl/pangolin

Milo Schwartz 7 月之前
父節點
當前提交
4cdaa9b588

+ 17 - 17
server/auth/resourceOtp.ts

@@ -3,11 +3,10 @@ import { resourceOtp } from "@server/db/schema";
 import { and, eq } from "drizzle-orm";
 import { createDate, isWithinExpirationDate, TimeSpan } from "oslo";
 import { alphabet, generateRandomString, sha256 } from "oslo/crypto";
-import { encodeHex } from "oslo/encoding";
 import { sendEmail } from "@server/emails";
 import ResourceOTPCode from "@server/emails/templates/ResourceOTPCode";
 import config from "@server/config";
-import { hash, verify } from "@node-rs/argon2";
+import { verifyPassword } from "./password";
 import { hashPassword } from "./password";
 
 export async function sendResourceOtpEmail(
@@ -37,24 +36,25 @@ export async function generateResourceOtpCode(
     resourceId: number,
     email: string
 ): Promise<string> {
-    await db
-        .delete(resourceOtp)
-        .where(
-            and(
-                eq(resourceOtp.email, email),
-                eq(resourceOtp.resourceId, resourceId)
-            )
-        );
-
     const otp = generateRandomString(8, alphabet("0-9", "A-Z", "a-z"));
+    await db.transaction(async (trx) => {
+        await trx
+            .delete(resourceOtp)
+            .where(
+                and(
+                    eq(resourceOtp.email, email),
+                    eq(resourceOtp.resourceId, resourceId)
+                )
+            );
 
-    const otpHash = await hashPassword(otp);
+        const otpHash = await hashPassword(otp);
 
-    await db.insert(resourceOtp).values({
-        resourceId,
-        email,
-        otpHash,
-        expiresAt: createDate(new TimeSpan(15, "m")).getTime()
+        await trx.insert(resourceOtp).values({
+            resourceId,
+            email,
+            otpHash,
+            expiresAt: createDate(new TimeSpan(15, "m")).getTime()
+        });
     });
 
     return otp;

+ 10 - 10
server/auth/sendEmailVerificationCode.ts

@@ -31,18 +31,18 @@ async function generateEmailVerificationCode(
     userId: string,
     email: string
 ): Promise<string> {
-    await db
-        .delete(emailVerificationCodes)
-        .where(eq(emailVerificationCodes.userId, userId));
-
     const code = generateRandomString(8, alphabet("0-9"));
+    await db.transaction(async (trx) => {
+        await trx
+            .delete(emailVerificationCodes)
+            .where(eq(emailVerificationCodes.userId, userId));
 
-    await db.insert(emailVerificationCodes).values({
-        userId,
-        email,
-        code,
-        expiresAt: createDate(new TimeSpan(15, "m")).getTime()
+        await trx.insert(emailVerificationCodes).values({
+            userId,
+            email,
+            code,
+            expiresAt: createDate(new TimeSpan(15, "m")).getTime()
+        });
     });
-
     return code;
 }

+ 0 - 4
server/routers/auth/disable2fa.ts

@@ -92,10 +92,6 @@ export async function disable2fa(
             .set({ twoFactorEnabled: false })
             .where(eq(users.userId, user.userId));
 
-        await db
-            .delete(twoFactorBackupCodes)
-            .where(eq(twoFactorBackupCodes.userId, user.userId));
-
         sendEmail(
             TwoFactorAuthNotification({
                 email: user.email,

+ 15 - 10
server/routers/auth/requestPasswordReset.ts

@@ -63,18 +63,23 @@ export async function requestPasswordReset(
             );
         }
 
-        await db
-            .delete(passwordResetTokens)
-            .where(eq(passwordResetTokens.userId, existingUser[0].userId));
+        const token = generateRandomString(
+            8,
+            alphabet("0-9", "A-Z", "a-z")
+        );
+        await db.transaction(async (trx) => {
+            await trx
+                .delete(passwordResetTokens)
+                .where(eq(passwordResetTokens.userId, existingUser[0].userId));
 
-        const token = generateRandomString(8, alphabet("0-9", "A-Z", "a-z"));
-        const tokenHash = await hashPassword(token);
+            const tokenHash = await hashPassword(token);
 
-        await db.insert(passwordResetTokens).values({
-            userId: existingUser[0].userId,
-            email: existingUser[0].email,
-            tokenHash,
-            expiresAt: createDate(new TimeSpan(2, "h")).getTime()
+            await trx.insert(passwordResetTokens).values({
+                userId: existingUser[0].userId,
+                email: existingUser[0].email,
+                tokenHash,
+                expiresAt: createDate(new TimeSpan(2, "h")).getTime()
+            });
         });
 
         const url = `${config.app.base_url}/auth/reset-password?email=${email}&token=${token}`;

+ 11 - 9
server/routers/auth/resetPassword.ts

@@ -135,20 +135,22 @@ export async function resetPassword(
 
         await invalidateAllSessions(resetRequest[0].userId);
 
-        await db
-            .update(users)
-            .set({ passwordHash })
-            .where(eq(users.userId, resetRequest[0].userId));
-
-        await db
-            .delete(passwordResetTokens)
-            .where(eq(passwordResetTokens.email, email));
+        await db.transaction(async (trx) => {
+            await trx
+                .update(users)
+                .set({ passwordHash })
+                .where(eq(users.userId, resetRequest[0].userId));
+
+            await trx
+                .delete(passwordResetTokens)
+                .where(eq(passwordResetTokens.email, email));
+        });
 
         await sendEmail(ConfirmPasswordReset({ email }), {
             from: config.email?.no_reply,
             to: email,
             subject: "Password Reset Confirmation"
-        })
+        });
 
         return response<ResetPasswordResponse>(res, {
             data: null,

+ 12 - 10
server/routers/auth/verifyEmail.ts

@@ -62,16 +62,18 @@ export async function verifyEmail(
         const valid = await isValidCode(user, code);
 
         if (valid) {
-            await db
-                .delete(emailVerificationCodes)
-                .where(eq(emailVerificationCodes.userId, user.userId));
-
-            await db
-                .update(users)
-                .set({
-                    emailVerified: true
-                })
-                .where(eq(users.userId, user.userId));
+            await db.transaction(async (trx) => {
+                await trx
+                    .delete(emailVerificationCodes)
+                    .where(eq(emailVerificationCodes.userId, user.userId));
+
+                await trx
+                    .update(users)
+                    .set({
+                        emailVerified: true
+                    })
+                    .where(eq(users.userId, user.userId));
+            });
         } else {
             return next(
                 createHttpError(

+ 17 - 15
server/routers/auth/verifyTotp.ts

@@ -76,21 +76,23 @@ export async function verifyTotp(
         let codes;
         if (valid) {
             // if valid, enable two-factor authentication; the totp secret is no longer temporary
-            await db
-                .update(users)
-                .set({ twoFactorEnabled: true })
-                .where(eq(users.userId, user.userId));
-
-            const backupCodes = await generateBackupCodes();
-            codes = backupCodes;
-            for (const code of backupCodes) {
-                const hash = await hashPassword(code);
-
-                await db.insert(twoFactorBackupCodes).values({
-                    userId: user.userId,
-                    codeHash: hash
-                });
-            }
+            await db.transaction(async (trx) => {
+                await trx
+                    .update(users)
+                    .set({ twoFactorEnabled: true })
+                    .where(eq(users.userId, user.userId));
+
+                const backupCodes = await generateBackupCodes();
+                codes = backupCodes;
+                for (const code of backupCodes) {
+                    const hash = await hashPassword(code);
+
+                    await trx.insert(twoFactorBackupCodes).values({
+                        userId: user.userId,
+                        codeHash: hash
+                    });
+                }
+            });
         }
 
         if (!valid) {

+ 58 - 44
server/routers/gerbil/receiveBandwidth.ts

@@ -1,10 +1,10 @@
-import { Request, Response, NextFunction } from 'express';
-import { DrizzleError, eq } from 'drizzle-orm';
-import { sites, resources, targets, exitNodes } from '@server/db/schema';
-import db from '@server/db';
-import logger from '@server/logger';
-import createHttpError from 'http-errors';
-import HttpCode from '@server/types/HttpCode';
+import { Request, Response, NextFunction } from "express";
+import { DrizzleError, eq } from "drizzle-orm";
+import { sites, resources, targets, exitNodes } from "@server/db/schema";
+import db from "@server/db";
+import logger from "@server/logger";
+import createHttpError from "http-errors";
+import HttpCode from "@server/types/HttpCode";
 import response from "@server/utils/response";
 
 interface PeerBandwidth {
@@ -13,62 +13,76 @@ interface PeerBandwidth {
     bytesOut: number;
 }
 
-export const receiveBandwidth = async (req: Request, res: Response, next: NextFunction): Promise<any> => {
+export const receiveBandwidth = async (
+    req: Request,
+    res: Response,
+    next: NextFunction
+): Promise<any> => {
     try {
         const bandwidthData: PeerBandwidth[] = req.body;
 
         if (!Array.isArray(bandwidthData)) {
-            throw new Error('Invalid bandwidth data');
+            throw new Error("Invalid bandwidth data");
         }
 
-        for (const peer of bandwidthData) {
-            const { publicKey, bytesIn, bytesOut } = peer;
+        await db.transaction(async (trx) => {
+            for (const peer of bandwidthData) {
+                const { publicKey, bytesIn, bytesOut } = peer;
 
-            // Find the site by public key
-            const site = await db.query.sites.findFirst({
-                where: eq(sites.pubKey, publicKey),
-            });
+                // Find the site by public key
+                const site = await trx.query.sites.findFirst({
+                    where: eq(sites.pubKey, publicKey)
+                });
 
-            if (!site) {
-                logger.warn(`Site not found for public key: ${publicKey}`);
-                continue;
-            }
-            let online = site.online;
-
-            // if the bandwidth for the site is > 0 then set it to online. if it has been less than 0 (no update) for 5 minutes then set it to offline
-            if (bytesIn > 0 || bytesOut > 0) {
-                online = true;
-            } else if (site.lastBandwidthUpdate) {
-                const lastBandwidthUpdate = new Date(site.lastBandwidthUpdate);
-                const currentTime = new Date();
-                const diff = currentTime.getTime() - lastBandwidthUpdate.getTime();
-                if (diff < 300000) {
-                    online = false;
+                if (!site) {
+                    logger.warn(`Site not found for public key: ${publicKey}`);
+                    continue;
                 }
-            }
+                let online = site.online;
 
-            // Update the site's bandwidth usage
-            await db.update(sites)
-                .set({
-                    megabytesIn: (site.megabytesIn || 0) + bytesIn,
-                    megabytesOut: (site.megabytesOut || 0) + bytesOut,
-                    lastBandwidthUpdate: new Date().toISOString(),
-                    online,
-                })
-                .where(eq(sites.siteId, site.siteId));
+                // if the bandwidth for the site is > 0 then set it to online. if it has been less than 0 (no update) for 5 minutes then set it to offline
+                if (bytesIn > 0 || bytesOut > 0) {
+                    online = true;
+                } else if (site.lastBandwidthUpdate) {
+                    const lastBandwidthUpdate = new Date(
+                        site.lastBandwidthUpdate
+                    );
+                    const currentTime = new Date();
+                    const diff =
+                        currentTime.getTime() - lastBandwidthUpdate.getTime();
+                    if (diff < 300000) {
+                        online = false;
+                    }
+                }
 
-        }
+                // Update the site's bandwidth usage
+                await trx
+                    .update(sites)
+                    .set({
+                        megabytesIn: (site.megabytesIn || 0) + bytesIn,
+                        megabytesOut: (site.megabytesOut || 0) + bytesOut,
+                        lastBandwidthUpdate: new Date().toISOString(),
+                        online
+                    })
+                    .where(eq(sites.siteId, site.siteId));
+            }
+        });
 
         return response(res, {
             data: {},
             success: true,
             error: false,
             message: "Organization retrieved successfully",
-            status: HttpCode.OK,
+            status: HttpCode.OK
         });
     } catch (error) {
-        logger.error('Error updating bandwidth data:', error);
-        return next(createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred..."));
+        logger.error("Error updating bandwidth data:", error);
+        return next(
+            createHttpError(
+                HttpCode.INTERNAL_SERVER_ERROR,
+                "An error occurred..."
+            )
+        );
     }
 };
 

+ 24 - 19
server/routers/newt/handleRegisterMessage.ts

@@ -39,15 +39,14 @@ export const handleRegisterMessage: MessageHandler = async (context) => {
         return;
     }
 
-    const [updatedSite] = await db
-    .update(sites)
+    await db
+        .update(sites)
         .set({
             pubKey: publicKey
         })
         .where(eq(sites.siteId, siteId))
         .returning();
 
-
     const [exitNode] = await db
         .select()
         .from(exitNodes)
@@ -67,35 +66,41 @@ export const handleRegisterMessage: MessageHandler = async (context) => {
     // add the peer to the exit node
     await addPeer(site.exitNodeId, {
         publicKey: publicKey,
-        allowedIps: [site.subnet],
+        allowedIps: [site.subnet]
     });
 
-    const siteResources = await db.select().from(resources).where(eq(resources.siteId, siteId));
+    const siteResources = await db
+        .select()
+        .from(resources)
+        .where(eq(resources.siteId, siteId));
 
     // get the targets from the resourceIds
     const siteTargets = await db
-    .select()
-    .from(targets)
-    .where(
-        inArray(
-            targets.resourceId,
-            siteResources.map(resource => resource.resourceId)
-        )
-    );
+        .select()
+        .from(targets)
+        .where(
+            inArray(
+                targets.resourceId,
+                siteResources.map((resource) => resource.resourceId)
+            )
+        );
 
     const udpTargets = siteTargets
         .filter((target) => target.protocol === "udp")
         .map((target) => {
-            return `${target.internalPort ? target.internalPort + ":" : ""}${target.ip}:${target.port}`;
+            return `${target.internalPort ? target.internalPort + ":" : ""}${
+                target.ip
+            }:${target.port}`;
         });
 
     const tcpTargets = siteTargets
         .filter((target) => target.protocol === "tcp")
         .map((target) => {
-            return `${target.internalPort ? target.internalPort + ":" : ""}${target.ip}:${target.port}`;
+            return `${target.internalPort ? target.internalPort + ":" : ""}${
+                target.ip
+            }:${target.port}`;
         });
 
-
     return {
         message: {
             type: "newt/wg/connect",
@@ -106,11 +111,11 @@ export const handleRegisterMessage: MessageHandler = async (context) => {
                 tunnelIP: site.subnet.split("/")[0],
                 targets: {
                     udp: udpTargets,
-                    tcp: tcpTargets,
+                    tcp: tcpTargets
                 }
-            },
+            }
         },
         broadcast: false, // Send to all clients
-        excludeSender: false, // Include sender in broadcast
+        excludeSender: false // Include sender in broadcast
     };
 };

+ 36 - 30
server/routers/org/deleteOrg.ts

@@ -24,9 +24,7 @@ const deleteOrgSchema = z
     })
     .strict();
 
-export type DeleteOrgResponse = {
-
-}
+export type DeleteOrgResponse = {};
 
 export async function deleteOrg(
     req: Request,
@@ -79,39 +77,47 @@ export async function deleteOrg(
             .where(eq(sites.orgId, orgId))
             .limit(1);
 
-        if (sites) {
-            for (const site of orgSites) {
-                if (site.pubKey) {
-                    if (site.type == "wireguard") {
-                        await deletePeer(site.exitNodeId!, site.pubKey);
-                    } else if (site.type == "newt") {
-                        // get the newt on the site by querying the newt table for siteId
-                        const [deletedNewt] = await db
-                            .delete(newts)
-                            .where(eq(newts.siteId, site.siteId))
-                            .returning();
-                        if (deletedNewt) {
-                            const payload = {
-                                type: `newt/terminate`,
-                                data: {}
-                            };
-                            sendToClient(deletedNewt.newtId, payload);
+        await db.transaction(async (trx) => {
+            if (sites) {
+                for (const site of orgSites) {
+                    if (site.pubKey) {
+                        if (site.type == "wireguard") {
+                            await deletePeer(site.exitNodeId!, site.pubKey);
+                        } else if (site.type == "newt") {
+                            // get the newt on the site by querying the newt table for siteId
+                            const [deletedNewt] = await trx
+                                .delete(newts)
+                                .where(eq(newts.siteId, site.siteId))
+                                .returning();
+                            if (deletedNewt) {
+                                const payload = {
+                                    type: `newt/terminate`,
+                                    data: {}
+                                };
+                                sendToClient(deletedNewt.newtId, payload);
 
-                            // delete all of the sessions for the newt
-                            await db.delete(newtSessions)
-                                .where(
-                                    eq(newtSessions.newtId, deletedNewt.newtId)
-                                );
+                                // delete all of the sessions for the newt
+                                await trx
+                                    .delete(newtSessions)
+                                    .where(
+                                        eq(
+                                            newtSessions.newtId,
+                                            deletedNewt.newtId
+                                        )
+                                    );
+                            }
                         }
                     }
-                }
 
-                logger.info(`Deleting site ${site.siteId}`);
-                await db.delete(sites).where(eq(sites.siteId, site.siteId))
+                    logger.info(`Deleting site ${site.siteId}`);
+                    await trx
+                        .delete(sites)
+                        .where(eq(sites.siteId, site.siteId));
+                }
             }
-        }
 
-        await db.delete(orgs).where(eq(orgs.orgId, orgId));
+            await trx.delete(orgs).where(eq(orgs.orgId, orgId));
+        });
 
         return response(res, {
             data: null,

+ 41 - 41
server/routers/resource/createResource.ts

@@ -89,50 +89,50 @@ export async function createResource(
         }
 
         const fullDomain = `${subdomain}.${org[0].domain}`;
-
-        const newResource = await db
-            .insert(resources)
-            .values({
-                siteId,
-                fullDomain,
-                orgId,
-                name,
-                subdomain,
-                ssl: true
-            })
-            .returning();
-
-        const adminRole = await db
-            .select()
-            .from(roles)
-            .where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId)))
-            .limit(1);
-
-        if (adminRole.length === 0) {
-            return next(
-                createHttpError(HttpCode.NOT_FOUND, `Admin role not found`)
-            );
-        }
-
-        await db.insert(roleResources).values({
-            roleId: adminRole[0].roleId,
-            resourceId: newResource[0].resourceId
-        });
-
-        if (req.userOrgRoleId != adminRole[0].roleId) {
-            // make sure the user can access the resource
-            await db.insert(userResources).values({
-                userId: req.user?.userId!,
+        await db.transaction(async (trx) => {
+            const newResource = await trx
+                .insert(resources)
+                .values({
+                    siteId,
+                    fullDomain,
+                    orgId,
+                    name,
+                    subdomain,
+                    ssl: true
+                })
+                .returning();
+
+            const adminRole = await db
+                .select()
+                .from(roles)
+                .where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId)))
+                .limit(1);
+
+            if (adminRole.length === 0) {
+                return next(
+                    createHttpError(HttpCode.NOT_FOUND, `Admin role not found`)
+                );
+            }
+
+            await trx.insert(roleResources).values({
+                roleId: adminRole[0].roleId,
                 resourceId: newResource[0].resourceId
             });
-        }
 
-        response<CreateResourceResponse>(res, {
-            data: newResource[0],
-            success: true,
-            error: false,
-            message: "Resource created successfully",
-            status: HttpCode.CREATED
+            if (req.userOrgRoleId != adminRole[0].roleId) {
+                // make sure the user can access the resource
+                await trx.insert(userResources).values({
+                    userId: req.user?.userId!,
+                    resourceId: newResource[0].resourceId
+                });
+            }
+            response<CreateResourceResponse>(res, {
+                data: newResource[0],
+                success: true,
+                error: false,
+                message: "Resource created successfully",
+                status: HttpCode.CREATED
+            });
         });
     } catch (error) {
         if (

+ 25 - 23
server/routers/role/addRoleSite.ts

@@ -51,32 +51,34 @@ export async function addRoleSite(
 
         const { roleId } = parsedParams.data;
 
-        const newRoleSite = await db
-            .insert(roleSites)
-            .values({
-                roleId,
-                siteId
-            })
-            .returning();
+        await db.transaction(async (trx) => {
+            const newRoleSite = await trx
+                .insert(roleSites)
+                .values({
+                    roleId,
+                    siteId
+                })
+                .returning();
 
-        const siteResources = await db
-            .select()
-            .from(resources)
-            .where(eq(resources.siteId, siteId));
+            const siteResources = await db
+                .select()
+                .from(resources)
+                .where(eq(resources.siteId, siteId));
 
-        for (const resource of siteResources) {
-            await db.insert(roleResources).values({
-                roleId,
-                resourceId: resource.resourceId
-            });
-        }
+            for (const resource of siteResources) {
+                await trx.insert(roleResources).values({
+                    roleId,
+                    resourceId: resource.resourceId
+                });
+            }
 
-        return response(res, {
-            data: newRoleSite[0],
-            success: true,
-            error: false,
-            message: "Site added to role successfully",
-            status: HttpCode.CREATED
+            return response(res, {
+                data: newRoleSite[0],
+                success: true,
+                error: false,
+                message: "Site added to role successfully",
+                status: HttpCode.CREATED
+            });
         });
     } catch (error) {
         logger.error(error);

+ 25 - 23
server/routers/role/createRole.ts

@@ -82,31 +82,33 @@ export async function createRole(
             );
         }
 
-        const newRole = await db
-            .insert(roles)
-            .values({
-                ...roleData,
-                orgId
-            })
-            .returning();
-
-        await db
-            .insert(roleActions)
-            .values(
-                defaultRoleAllowedActions.map((action) => ({
-                    roleId: newRole[0].roleId,
-                    actionId: action,
+        await db.transaction(async (trx) => {
+            const newRole = await trx
+                .insert(roles)
+                .values({
+                    ...roleData,
                     orgId
-                }))
-            )
-            .execute();
+                })
+                .returning();
+
+            await trx
+                .insert(roleActions)
+                .values(
+                    defaultRoleAllowedActions.map((action) => ({
+                        roleId: newRole[0].roleId,
+                        actionId: action,
+                        orgId
+                    }))
+                )
+                .execute();
 
-        return response<Role>(res, {
-            data: newRole[0],
-            success: true,
-            error: false,
-            message: "Role created successfully",
-            status: HttpCode.CREATED
+            return response<Role>(res, {
+                data: newRole[0],
+                success: true,
+                error: false,
+                message: "Role created successfully",
+                status: HttpCode.CREATED
+            });
         });
     } catch (error) {
         logger.error(error);

+ 10 - 8
server/routers/role/deleteRole.ts

@@ -98,15 +98,17 @@ export async function deleteRole(
             );
         }
 
-        // move all users from the userOrgs table with roleId to newRoleId
-        await db
-            .update(userOrgs)
-            .set({ roleId: newRoleId })
-            .where(eq(userOrgs.roleId, roleId));
-
-        // delete the old role
-        await db.delete(roles).where(eq(roles.roleId, roleId));
+        await db.transaction(async (trx) => {
+            // move all users from the userOrgs table with roleId to newRoleId
+            await trx
+                .update(userOrgs)
+                .set({ roleId: newRoleId })
+                .where(eq(userOrgs.roleId, roleId));
 
+            // delete the old role
+            await trx.delete(roles).where(eq(roles.roleId, roleId));
+        });
+        
         return response(res, {
             data: null,
             success: true,

+ 32 - 27
server/routers/role/removeRoleSite.ts

@@ -51,38 +51,43 @@ export async function removeRoleSite(
 
         const { roleId } = parsedBody.data;
 
-        const deletedRoleSite = await db
-            .delete(roleSites)
-            .where(
-                and(eq(roleSites.roleId, roleId), eq(roleSites.siteId, siteId))
-            )
-            .returning();
-
-        if (deletedRoleSite.length === 0) {
-            return next(
-                createHttpError(
-                    HttpCode.NOT_FOUND,
-                    `Site with ID ${siteId} not found for role with ID ${roleId}`
-                )
-            );
-        }
-
-        const siteResources = await db
-            .select()
-            .from(resources)
-            .where(eq(resources.siteId, siteId));
-
-        for (const resource of siteResources) {
-            await db
-                .delete(roleResources)
+        await db.transaction(async (trx) => {
+            const deletedRoleSite = await trx
+                .delete(roleSites)
                 .where(
                     and(
-                        eq(roleResources.roleId, roleId),
-                        eq(roleResources.resourceId, resource.resourceId)
+                        eq(roleSites.roleId, roleId),
+                        eq(roleSites.siteId, siteId)
                     )
                 )
                 .returning();
-        }
+
+            if (deletedRoleSite.length === 0) {
+                return next(
+                    createHttpError(
+                        HttpCode.NOT_FOUND,
+                        `Site with ID ${siteId} not found for role with ID ${roleId}`
+                    )
+                );
+            }
+
+            const siteResources = await db
+                .select()
+                .from(resources)
+                .where(eq(resources.siteId, siteId));
+
+            for (const resource of siteResources) {
+                await trx
+                    .delete(roleResources)
+                    .where(
+                        and(
+                            eq(roleResources.roleId, roleId),
+                            eq(roleResources.resourceId, resource.resourceId)
+                        )
+                    )
+                    .returning();
+            }
+        });
 
         return response(res, {
             data: null,

+ 56 - 51
server/routers/site/createSite.ts

@@ -94,64 +94,69 @@ export async function createSite(
             };
         }
 
-        const [newSite] = await db.insert(sites).values(payload).returning();
-
-        const adminRole = await db
-            .select()
-            .from(roles)
-            .where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId)))
-            .limit(1);
-
-        if (adminRole.length === 0) {
-            return next(
-                createHttpError(HttpCode.NOT_FOUND, `Admin role not found`)
-            );
-        }
-
-        await db.insert(roleSites).values({
-            roleId: adminRole[0].roleId,
-            siteId: newSite.siteId
-        });
+        await db.transaction(async (trx) => {
+            const [newSite] = await trx
+                .insert(sites)
+                .values(payload)
+                .returning();
+
+            const adminRole = await trx
+                .select()
+                .from(roles)
+                .where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId)))
+                .limit(1);
+
+            if (adminRole.length === 0) {
+                return next(
+                    createHttpError(HttpCode.NOT_FOUND, `Admin role not found`)
+                );
+            }
 
-        if (req.userOrgRoleId != adminRole[0].roleId) {
-            // make sure the user can access the site
-            db.insert(userSites).values({
-                userId: req.user?.userId!,
+            await trx.insert(roleSites).values({
+                roleId: adminRole[0].roleId,
                 siteId: newSite.siteId
             });
-        }
 
-        // add the peer to the exit node
-        if (type == "newt") {
-            const secretHash = await hashPassword(secret!);
+            if (req.userOrgRoleId != adminRole[0].roleId) {
+                // make sure the user can access the site
+                trx.insert(userSites).values({
+                    userId: req.user?.userId!,
+                    siteId: newSite.siteId
+                });
+            }
 
-            await db.insert(newts).values({
-                newtId: newtId!,
-                secretHash,
-                siteId: newSite.siteId,
-                dateCreated: moment().toISOString()
-            });
-        } else if (type == "wireguard") {
-            if (!pubKey) {
-                return next(
-                    createHttpError(
-                        HttpCode.BAD_REQUEST,
-                        "Public key is required for wireguard sites"
-                    )
-                );
+            // add the peer to the exit node
+            if (type == "newt") {
+                const secretHash = await hashPassword(secret!);
+
+                await trx.insert(newts).values({
+                    newtId: newtId!,
+                    secretHash,
+                    siteId: newSite.siteId,
+                    dateCreated: moment().toISOString()
+                });
+            } else if (type == "wireguard") {
+                if (!pubKey) {
+                    return next(
+                        createHttpError(
+                            HttpCode.BAD_REQUEST,
+                            "Public key is required for wireguard sites"
+                        )
+                    );
+                }
+                await addPeer(exitNodeId, {
+                    publicKey: pubKey,
+                    allowedIps: []
+                });
             }
-            await addPeer(exitNodeId, {
-                publicKey: pubKey,
-                allowedIps: []
-            });
-        }
 
-        return response<CreateSiteResponse>(res, {
-            data: newSite,
-            success: true,
-            error: false,
-            message: "Site created successfully",
-            status: HttpCode.CREATED
+            return response<CreateSiteResponse>(res, {
+                data: newSite,
+                success: true,
+                error: false,
+                message: "Site created successfully",
+                status: HttpCode.CREATED
+            });
         });
     } catch (error) {
         logger.error(error);

+ 24 - 22
server/routers/site/deleteSite.ts

@@ -50,32 +50,34 @@ export async function deleteSite(
             );
         }
 
-        if (site.pubKey) {
-            if (site.type == "wireguard") {
-                await deletePeer(site.exitNodeId!, site.pubKey);
-            } else if (site.type == "newt") {
-                // get the newt on the site by querying the newt table for siteId
-                const [deletedNewt] = await db
-                    .delete(newts)
-                    .where(eq(newts.siteId, siteId))
-                    .returning();
-                if (deletedNewt) {
-                    const payload = {
-                        type: `newt/terminate`,
-                        data: {}
-                    };
-                    sendToClient(deletedNewt.newtId, payload);
+        await db.transaction(async (trx) => {
+            if (site.pubKey) {
+                if (site.type == "wireguard") {
+                    await deletePeer(site.exitNodeId!, site.pubKey);
+                } else if (site.type == "newt") {
+                    // get the newt on the site by querying the newt table for siteId
+                    const [deletedNewt] = await trx
+                        .delete(newts)
+                        .where(eq(newts.siteId, siteId))
+                        .returning();
+                    if (deletedNewt) {
+                        const payload = {
+                            type: `newt/terminate`,
+                            data: {}
+                        };
+                        sendToClient(deletedNewt.newtId, payload);
 
-                    // delete all of the sessions for the newt
-                    db.delete(newtSessions)
-                        .where(eq(newtSessions.newtId, deletedNewt.newtId))
-                        .run();
+                        // delete all of the sessions for the newt
+                        await trx
+                            .delete(newtSessions)
+                            .where(eq(newtSessions.newtId, deletedNewt.newtId));
+                    }
                 }
             }
-        }
-
-        db.delete(sites).where(eq(sites.siteId, siteId)).run();
 
+            await trx.delete(sites).where(eq(sites.siteId, siteId));
+        });
+        
         return response(res, {
             data: null,
             success: true,

+ 12 - 8
server/routers/user/acceptInvite.ts

@@ -100,15 +100,19 @@ export async function acceptInvite(
             );
         }
 
-        // add the user to the org
-        await db.insert(userOrgs).values({
-            userId: existingUser[0].userId,
-            orgId: existingInvite.orgId,
-            roleId: existingInvite.roleId
-        });
+        await db.transaction(async (trx) => {
+            // add the user to the org
+            await trx.insert(userOrgs).values({
+                userId: existingUser[0].userId,
+                orgId: existingInvite.orgId,
+                roleId: existingInvite.roleId
+            });
 
-        // delete the invite
-        await db.delete(userInvites).where(eq(userInvites.inviteId, inviteId));
+            // delete the invite
+            await trx
+                .delete(userInvites)
+                .where(eq(userInvites.inviteId, inviteId));
+        });
 
         return response<AcceptInviteResponse>(res, {
             data: { accepted: true, orgId: existingInvite.orgId },

+ 26 - 23
server/routers/user/addUserSite.ts

@@ -34,33 +34,36 @@ export async function addUserSite(
 
         const { userId, siteId } = parsedBody.data;
 
-        const newUserSite = await db
-            .insert(userSites)
-            .values({
-                userId,
-                siteId
-            })
-            .returning();
+        await db.transaction(async (trx) => {
+            const newUserSite = await trx
+                .insert(userSites)
+                .values({
+                    userId,
+                    siteId
+                })
+                .returning();
 
-        const siteResources = await db
-            .select()
-            .from(resources)
-            .where(eq(resources.siteId, siteId));
+            const siteResources = await trx
+                .select()
+                .from(resources)
+                .where(eq(resources.siteId, siteId));
 
-        for (const resource of siteResources) {
-            await db.insert(userResources).values({
-                userId,
-                resourceId: resource.resourceId
-            });
-        }
+            for (const resource of siteResources) {
+                await trx.insert(userResources).values({
+                    userId,
+                    resourceId: resource.resourceId
+                });
+            }
 
-        return response(res, {
-            data: newUserSite[0],
-            success: true,
-            error: false,
-            message: "Site added to user successfully",
-            status: HttpCode.CREATED
+            return response(res, {
+                data: newUserSite[0],
+                success: true,
+                error: false,
+                message: "Site added to user successfully",
+                status: HttpCode.CREATED
+            });
         });
+        
     } catch (error) {
         logger.error(error);
         return next(

+ 20 - 15
server/routers/user/inviteUser.ts

@@ -130,21 +130,26 @@ export async function inviteUser(
 
         const tokenHash = await hashPassword(token);
 
-        // delete any existing invites for this email
-        await db
-            .delete(userInvites)
-            .where(
-                and(eq(userInvites.email, email), eq(userInvites.orgId, orgId))
-            )
-            .execute();
-
-        await db.insert(userInvites).values({
-            inviteId,
-            orgId,
-            email,
-            expiresAt,
-            tokenHash,
-            roleId
+        await db.transaction(async (trx) => {
+            // delete any existing invites for this email
+            await trx
+                .delete(userInvites)
+                .where(
+                    and(
+                        eq(userInvites.email, email),
+                        eq(userInvites.orgId, orgId)
+                    )
+                )
+                .execute();
+
+            await trx.insert(userInvites).values({
+                inviteId,
+                orgId,
+                email,
+                expiresAt,
+                tokenHash,
+                roleId
+            });
         });
 
         const inviteLink = `${config.app.base_url}/invite?token=${inviteId}-${token}`;

+ 32 - 27
server/routers/user/removeUserSite.ts

@@ -51,38 +51,43 @@ export async function removeUserSite(
 
         const { siteId } = parsedBody.data;
 
-        const deletedUserSite = await db
-            .delete(userSites)
-            .where(
-                and(eq(userSites.userId, userId), eq(userSites.siteId, siteId))
-            )
-            .returning();
-
-        if (deletedUserSite.length === 0) {
-            return next(
-                createHttpError(
-                    HttpCode.NOT_FOUND,
-                    `Site with ID ${siteId} not found for user with ID ${userId}`
-                )
-            );
-        }
-
-        const siteResources = await db
-            .select()
-            .from(resources)
-            .where(eq(resources.siteId, siteId));
-
-        for (const resource of siteResources) {
-            await db
-                .delete(userResources)
+        await db.transaction(async (trx) => {
+            const deletedUserSite = await trx
+                .delete(userSites)
                 .where(
                     and(
-                        eq(userResources.userId, userId),
-                        eq(userResources.resourceId, resource.resourceId)
+                        eq(userSites.userId, userId),
+                        eq(userSites.siteId, siteId)
                     )
                 )
                 .returning();
-        }
+
+            if (deletedUserSite.length === 0) {
+                return next(
+                    createHttpError(
+                        HttpCode.NOT_FOUND,
+                        `Site with ID ${siteId} not found for user with ID ${userId}`
+                    )
+                );
+            }
+
+            const siteResources = await trx
+                .select()
+                .from(resources)
+                .where(eq(resources.siteId, siteId));
+
+            for (const resource of siteResources) {
+                await trx
+                    .delete(userResources)
+                    .where(
+                        and(
+                            eq(userResources.userId, userId),
+                            eq(userResources.resourceId, resource.resourceId)
+                        )
+                    )
+                    .returning();
+            }
+        });
 
         return response(res, {
             data: null,

+ 23 - 8
server/setup/ensureActions.ts

@@ -22,13 +22,15 @@ export async function ensureActions() {
         .where(eq(roles.isAdmin, true))
         .execute();
 
+        await db.transaction(async (trx) => {
+
     // Add new actions
     for (const actionId of actionsToAdd) {
         logger.debug(`Adding action: ${actionId}`);
-        await db.insert(actions).values({ actionId }).execute();
+        await trx.insert(actions).values({ actionId }).execute();
         // Add new actions to the Default role
         if (defaultRoles.length != 0) {
-            await db
+            await trx
                 .insert(roleActions)
                 .values(
                     defaultRoles.map((role) => ({
@@ -44,19 +46,23 @@ export async function ensureActions() {
     // Remove deprecated actions
     if (actionsToRemove.length > 0) {
         logger.debug(`Removing actions: ${actionsToRemove.join(", ")}`);
-        await db
+        await trx
             .delete(actions)
             .where(inArray(actions.actionId, actionsToRemove))
             .execute();
-        await db
+        await trx
             .delete(roleActions)
             .where(inArray(roleActions.actionId, actionsToRemove))
             .execute();
     }
+});
 }
 
 export async function createAdminRole(orgId: string) {
-    const [insertedRole] = await db
+    let roleId: any;
+    await db.transaction(async (trx) => {
+
+    const [insertedRole] = await trx
         .insert(roles)
         .values({
             orgId,
@@ -67,16 +73,20 @@ export async function createAdminRole(orgId: string) {
         .returning({ roleId: roles.roleId })
         .execute();
 
-    const roleId = insertedRole.roleId;
+    if (!insertedRole || !insertedRole.roleId) {
+        throw new Error("Failed to create Admin role");
+    }
+
+    roleId = insertedRole.roleId;
 
-    const actionIds = await db.select().from(actions).execute();
+    const actionIds = await trx.select().from(actions).execute();
 
     if (actionIds.length === 0) {
         logger.info("No actions to assign to the Admin role");
         return;
     }
 
-    await db
+    await trx
         .insert(roleActions)
         .values(
             actionIds.map((action) => ({
@@ -86,6 +96,11 @@ export async function createAdminRole(orgId: string) {
             }))
         )
         .execute();
+    });
+
+    if (!roleId) {
+        throw new Error("Failed to create Admin role");
+    }
 
     return roleId;
 }