Selaa lähdekoodia

sql providers: remove unnecessary []byte to string conversion

always check affected rows for updates

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
Nicola Murino 2 vuotta sitten
vanhempi
commit
2c1319985d

+ 1 - 1
go.mod

@@ -121,7 +121,7 @@ require (
 	github.com/jackc/pgpassfile v1.0.0 // indirect
 	github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
 	github.com/jmespath/go-jmespath v0.4.0 // indirect
-	github.com/klauspost/cpuid/v2 v2.2.3 // indirect
+	github.com/klauspost/cpuid/v2 v2.2.4 // indirect
 	github.com/kr/fs v0.1.0 // indirect
 	github.com/lestrrat-go/blackmagic v1.0.1 // indirect
 	github.com/lestrrat-go/httpcc v1.0.1 // indirect

+ 2 - 1
go.sum

@@ -1401,8 +1401,9 @@ github.com/klauspost/compress v1.15.1/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47e
 github.com/klauspost/compress v1.15.15 h1:EF27CXIuDsYJ6mmvtBRlEuB2UVOqHG1tAXgZ7yIO+lw=
 github.com/klauspost/compress v1.15.15/go.mod h1:ZcK2JAFqKOpnBlxcLsJzYfrS9X1akm9fHZNnD9+Vo/4=
 github.com/klauspost/cpuid/v2 v2.0.4/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
-github.com/klauspost/cpuid/v2 v2.2.3 h1:sxCkb+qR91z4vsqw4vGGZlDgPz3G7gjaLyK3V8y70BU=
 github.com/klauspost/cpuid/v2 v2.2.3/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
+github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
+github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
 github.com/kolo/xmlrpc v0.0.0-20201022064351-38db28db192b/go.mod h1:pcaDhQK0/NJZEvtCO0qQPPropqV0sJOJ6YW7X+9kRwM=
 github.com/kolo/xmlrpc v0.0.0-20220921171641-a4b6fa1dd06b/go.mod h1:pcaDhQK0/NJZEvtCO0qQPPropqV0sJOJ6YW7X+9kRwM=
 github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=

+ 3 - 3
internal/dataprovider/mysql.go

@@ -230,7 +230,7 @@ func initializeMySQLProvider() error {
 	}
 	dbHandle, err := sql.Open("mysql", connString)
 	if err == nil {
-		providerLog(logger.LevelDebug, "mysql database handle created, connection string: %#v, pool size: %v",
+		providerLog(logger.LevelDebug, "mysql database handle created, connection string: %q, pool size: %v",
 			redactedConnString, config.PoolSize)
 		dbHandle.SetMaxOpenConns(config.PoolSize)
 		if config.PoolSize > 0 {
@@ -242,7 +242,7 @@ func initializeMySQLProvider() error {
 		dbHandle.SetConnMaxIdleTime(120 * time.Second)
 		provider = &MySQLProvider{dbHandle: dbHandle}
 	} else {
-		providerLog(logger.LevelError, "error creating mysql database handler, connection string: %#v, error: %v",
+		providerLog(logger.LevelError, "error creating mysql database handler, connection string: %q, error: %v",
 			redactedConnString, err)
 	}
 	return err
@@ -260,7 +260,7 @@ func getMySQLConnectionString(redactedPwd bool) (string, error) {
 				return "", err
 			}
 		}
-		connectionString = fmt.Sprintf("%v:%v@tcp([%v]:%v)/%v?charset=utf8mb4&interpolateParams=true&timeout=10s&parseTime=true&tls=%v&writeTimeout=60s&readTimeout=60s",
+		connectionString = fmt.Sprintf("%s:%s@tcp([%s]:%d)/%s?collation=utf8mb4_unicode_ci&interpolateParams=true&timeout=10s&parseTime=true&clientFoundRows=true&tls=%s&writeTimeout=60s&readTimeout=60s",
 			config.Username, password, config.Host, config.Port, config.Name, sslMode)
 	} else {
 		connectionString = config.ConnectionString

+ 5 - 5
internal/dataprovider/pgsql.go

@@ -236,7 +236,7 @@ func initializePGSQLProvider() error {
 	var err error
 	dbHandle, err := sql.Open("pgx", getPGSQLConnectionString(false))
 	if err == nil {
-		providerLog(logger.LevelDebug, "postgres database handle created, connection string: %#v, pool size: %v",
+		providerLog(logger.LevelDebug, "postgres database handle created, connection string: %q, pool size: %d",
 			getPGSQLConnectionString(true), config.PoolSize)
 		dbHandle.SetMaxOpenConns(config.PoolSize)
 		if config.PoolSize > 0 {
@@ -248,7 +248,7 @@ func initializePGSQLProvider() error {
 		dbHandle.SetConnMaxIdleTime(120 * time.Second)
 		provider = &PGSQLProvider{dbHandle: dbHandle}
 	} else {
-		providerLog(logger.LevelError, "error creating postgres database handler, connection string: %#v, error: %v",
+		providerLog(logger.LevelError, "error creating postgres database handler, connection string: %q, error: %v",
 			getPGSQLConnectionString(true), err)
 	}
 	return err
@@ -261,13 +261,13 @@ func getPGSQLConnectionString(redactedPwd bool) string {
 		if redactedPwd && password != "" {
 			password = "[redacted]"
 		}
-		connectionString = fmt.Sprintf("host='%v' port=%v dbname='%v' user='%v' password='%v' sslmode=%v connect_timeout=10",
+		connectionString = fmt.Sprintf("host='%s' port=%d dbname='%s' user='%s' password='%s' sslmode=%s connect_timeout=10",
 			config.Host, config.Port, config.Name, config.Username, password, getSSLMode())
 		if config.RootCert != "" {
-			connectionString += fmt.Sprintf(" sslrootcert='%v'", config.RootCert)
+			connectionString += fmt.Sprintf(" sslrootcert='%s'", config.RootCert)
 		}
 		if config.ClientCert != "" && config.ClientKey != "" {
-			connectionString += fmt.Sprintf(" sslcert='%v' sslkey='%v'", config.ClientCert, config.ClientKey)
+			connectionString += fmt.Sprintf(" sslcert='%s' sslkey='%s'", config.ClientCert, config.ClientKey)
 		}
 		if config.DisableSNI {
 			connectionString += " sslsni=0"

+ 81 - 61
internal/dataprovider/sqlcommon.go

@@ -118,11 +118,11 @@ func sqlCommonAddShare(share *Share, dbHandle *sql.DB) error {
 	if err != nil {
 		return err
 	}
-	allowFrom := ""
+	var allowFrom []byte
 	if len(share.AllowFrom) > 0 {
 		res, err := json.Marshal(share.AllowFrom)
 		if err == nil {
-			allowFrom = string(res)
+			allowFrom = res
 		}
 	}
 
@@ -145,7 +145,7 @@ func sqlCommonAddShare(share *Share, dbHandle *sql.DB) error {
 		lastUseAt = share.LastUseAt
 	}
 	_, err = dbHandle.ExecContext(ctx, q, share.ShareID, share.Name, share.Description, share.Scope,
-		string(paths), createdAt, updatedAt, lastUseAt, share.ExpiresAt, share.Password,
+		paths, createdAt, updatedAt, lastUseAt, share.ExpiresAt, share.Password,
 		share.MaxTokens, usedTokens, allowFrom, user.ID)
 	return err
 }
@@ -161,11 +161,11 @@ func sqlCommonUpdateShare(share *Share, dbHandle *sql.DB) error {
 		return err
 	}
 
-	allowFrom := ""
+	var allowFrom []byte
 	if len(share.AllowFrom) > 0 {
 		res, err := json.Marshal(share.AllowFrom)
 		if err == nil {
-			allowFrom = string(res)
+			allowFrom = res
 		}
 	}
 
@@ -184,6 +184,7 @@ func sqlCommonUpdateShare(share *Share, dbHandle *sql.DB) error {
 		q = getUpdateShareQuery()
 	}
 
+	var res sql.Result
 	if share.IsRestore {
 		if share.CreatedAt == 0 {
 			share.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
@@ -191,15 +192,18 @@ func sqlCommonUpdateShare(share *Share, dbHandle *sql.DB) error {
 		if share.UpdatedAt == 0 {
 			share.UpdatedAt = share.CreatedAt
 		}
-		_, err = dbHandle.ExecContext(ctx, q, share.Name, share.Description, share.Scope, string(paths),
+		res, err = dbHandle.ExecContext(ctx, q, share.Name, share.Description, share.Scope, paths,
 			share.CreatedAt, share.UpdatedAt, share.LastUseAt, share.ExpiresAt, share.Password, share.MaxTokens,
 			share.UsedTokens, allowFrom, user.ID, share.ShareID)
 	} else {
-		_, err = dbHandle.ExecContext(ctx, q, share.Name, share.Description, share.Scope, string(paths),
+		res, err = dbHandle.ExecContext(ctx, q, share.Name, share.Description, share.Scope, paths,
 			util.GetTimeAsMsSinceEpoch(time.Now()), share.ExpiresAt, share.Password, share.MaxTokens,
 			allowFrom, user.ID, share.ShareID)
 	}
-	return err
+	if err != nil {
+		return err
+	}
+	return sqlCommonRequireRowAffected(res)
 }
 
 func sqlCommonDeleteShare(share Share, dbHandle *sql.DB) error {
@@ -311,9 +315,12 @@ func sqlCommonUpdateAPIKey(apiKey *APIKey, dbHandle *sql.DB) error {
 	defer cancel()
 
 	q := getUpdateAPIKeyQuery()
-	_, err = dbHandle.ExecContext(ctx, q, apiKey.Name, apiKey.Scope, apiKey.ExpiresAt, userID, adminID,
+	res, err := dbHandle.ExecContext(ctx, q, apiKey.Name, apiKey.Scope, apiKey.ExpiresAt, userID, adminID,
 		apiKey.Description, util.GetTimeAsMsSinceEpoch(time.Now()), apiKey.KeyID)
-	return err
+	if err != nil {
+		return err
+	}
+	return sqlCommonRequireRowAffected(res)
 }
 
 func sqlCommonDeleteAPIKey(apiKey APIKey, dbHandle *sql.DB) error {
@@ -436,8 +443,8 @@ func sqlCommonAddAdmin(admin *Admin, dbHandle *sql.DB) error {
 
 	return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
 		q := getAddAdminQuery(admin.Role)
-		_, err = tx.ExecContext(ctx, q, admin.Username, admin.Password, admin.Status, admin.Email, string(perms),
-			string(filters), admin.AdditionalInfo, admin.Description, util.GetTimeAsMsSinceEpoch(time.Now()),
+		_, err = tx.ExecContext(ctx, q, admin.Username, admin.Password, admin.Status, admin.Email, perms,
+			filters, admin.AdditionalInfo, admin.Description, util.GetTimeAsMsSinceEpoch(time.Now()),
 			util.GetTimeAsMsSinceEpoch(time.Now()), admin.Role)
 		if err != nil {
 			return err
@@ -467,7 +474,7 @@ func sqlCommonUpdateAdmin(admin *Admin, dbHandle *sql.DB) error {
 
 	return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
 		q := getUpdateAdminQuery(admin.Role)
-		_, err = tx.ExecContext(ctx, q, admin.Password, admin.Status, admin.Email, string(perms), string(filters),
+		_, err = tx.ExecContext(ctx, q, admin.Password, admin.Status, admin.Email, perms, filters,
 			admin.AdditionalInfo, admin.Description, util.GetTimeAsMsSinceEpoch(time.Now()), admin.Role, admin.Username)
 		if err != nil {
 			return err
@@ -753,9 +760,12 @@ func sqlCommonUpdateIPListEntry(entry *IPListEntry, dbHandle *sql.DB) error {
 	defer cancel()
 
 	q := getUpdateIPListEntryQuery()
-	_, err := dbHandle.ExecContext(ctx, q, entry.Mode, entry.Protocols, entry.Description,
+	res, err := dbHandle.ExecContext(ctx, q, entry.Mode, entry.Protocols, entry.Description,
 		util.GetTimeAsMsSinceEpoch(time.Now()), entry.Type, entry.IPOrNet)
-	return err
+	if err != nil {
+		return err
+	}
+	return sqlCommonRequireRowAffected(res)
 }
 
 func sqlCommonDeleteIPListEntry(entry IPListEntry, softDelete bool, dbHandle *sql.DB) error {
@@ -876,8 +886,11 @@ func sqlCommonUpdateRole(role *Role, dbHandle *sql.DB) error {
 	defer cancel()
 
 	q := getUpdateRoleQuery()
-	_, err := dbHandle.ExecContext(ctx, q, role.Description, util.GetTimeAsMsSinceEpoch(time.Now()), role.Name)
-	return err
+	res, err := dbHandle.ExecContext(ctx, q, role.Description, util.GetTimeAsMsSinceEpoch(time.Now()), role.Name)
+	if err != nil {
+		return err
+	}
+	return sqlCommonRequireRowAffected(res)
 }
 
 func sqlCommonDeleteRole(role Role, dbHandle *sql.DB) error {
@@ -1069,7 +1082,7 @@ func sqlCommonAddGroup(group *Group, dbHandle *sql.DB) error {
 	return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
 		q := getAddGroupQuery()
 		_, err := tx.ExecContext(ctx, q, group.Name, group.Description, util.GetTimeAsMsSinceEpoch(time.Now()),
-			util.GetTimeAsMsSinceEpoch(time.Now()), string(settings))
+			util.GetTimeAsMsSinceEpoch(time.Now()), settings)
 		if err != nil {
 			return err
 		}
@@ -1170,7 +1183,7 @@ func sqlCommonValidateUserAndPubKey(username string, pubKey []byte, isSSHCert bo
 func sqlCommonCheckAvailability(dbHandle *sql.DB) (err error) {
 	defer func() {
 		if r := recover(); r != nil {
-			providerLog(logger.LevelError, "panic in check provider availability, stack trace: %v", string(debug.Stack()))
+			providerLog(logger.LevelError, "panic in check provider availability, stack trace: %s", string(debug.Stack()))
 			err = errors.New("unable to check provider status")
 		}
 	}()
@@ -1189,10 +1202,10 @@ func sqlCommonUpdateTransferQuota(username string, uploadSize, downloadSize int6
 	q := getUpdateTransferQuotaQuery(reset)
 	_, err := dbHandle.ExecContext(ctx, q, uploadSize, downloadSize, util.GetTimeAsMsSinceEpoch(time.Now()), username)
 	if err == nil {
-		providerLog(logger.LevelDebug, "transfer quota updated for user %#v, ul increment: %v dl increment: %v is reset? %v",
+		providerLog(logger.LevelDebug, "transfer quota updated for user %q, ul increment: %d dl increment: %d is reset? %t",
 			username, uploadSize, downloadSize, reset)
 	} else {
-		providerLog(logger.LevelError, "error updating quota for user %#v: %v", username, err)
+		providerLog(logger.LevelError, "error updating quota for user %q: %v", username, err)
 	}
 	return err
 }
@@ -1204,10 +1217,10 @@ func sqlCommonUpdateQuota(username string, filesAdd int, sizeAdd int64, reset bo
 	q := getUpdateQuotaQuery(reset)
 	_, err := dbHandle.ExecContext(ctx, q, sizeAdd, filesAdd, util.GetTimeAsMsSinceEpoch(time.Now()), username)
 	if err == nil {
-		providerLog(logger.LevelDebug, "quota updated for user %#v, files increment: %v size increment: %v is reset? %v",
+		providerLog(logger.LevelDebug, "quota updated for user %q, files increment: %d size increment: %d is reset? %t",
 			username, filesAdd, sizeAdd, reset)
 	} else {
-		providerLog(logger.LevelError, "error updating quota for user %#v: %v", username, err)
+		providerLog(logger.LevelError, "error updating quota for user %q: %v", username, err)
 	}
 	return err
 }
@@ -1234,9 +1247,9 @@ func sqlCommonUpdateShareLastUse(shareID string, numTokens int, dbHandle *sql.DB
 	q := getUpdateShareLastUseQuery()
 	_, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), numTokens, shareID)
 	if err == nil {
-		providerLog(logger.LevelDebug, "last use updated for shared object %#v", shareID)
+		providerLog(logger.LevelDebug, "last use updated for shared object %q", shareID)
 	} else {
-		providerLog(logger.LevelWarn, "error updating last use for shared object %#v: %v", shareID, err)
+		providerLog(logger.LevelWarn, "error updating last use for shared object %q: %v", shareID, err)
 	}
 	return err
 }
@@ -1248,9 +1261,9 @@ func sqlCommonUpdateAPIKeyLastUse(keyID string, dbHandle *sql.DB) error {
 	q := getUpdateAPIKeyLastUseQuery()
 	_, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), keyID)
 	if err == nil {
-		providerLog(logger.LevelDebug, "last use updated for key %#v", keyID)
+		providerLog(logger.LevelDebug, "last use updated for key %q", keyID)
 	} else {
-		providerLog(logger.LevelWarn, "error updating last use for key %#v: %v", keyID, err)
+		providerLog(logger.LevelWarn, "error updating last use for key %q: %v", keyID, err)
 	}
 	return err
 }
@@ -1262,9 +1275,9 @@ func sqlCommonUpdateAdminLastLogin(username string, dbHandle *sql.DB) error {
 	q := getUpdateAdminLastLoginQuery()
 	_, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), username)
 	if err == nil {
-		providerLog(logger.LevelDebug, "last login updated for admin %#v", username)
+		providerLog(logger.LevelDebug, "last login updated for admin %q", username)
 	} else {
-		providerLog(logger.LevelWarn, "error updating last login for admin %#v: %v", username, err)
+		providerLog(logger.LevelWarn, "error updating last login for admin %q: %v", username, err)
 	}
 	return err
 }
@@ -1276,9 +1289,9 @@ func sqlCommonSetUpdatedAt(username string, dbHandle *sql.DB) {
 	q := getSetUpdateAtQuery()
 	_, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), username)
 	if err == nil {
-		providerLog(logger.LevelDebug, "updated_at set for user %#v", username)
+		providerLog(logger.LevelDebug, "updated_at set for user %q", username)
 	} else {
-		providerLog(logger.LevelWarn, "error setting updated_at for user %#v: %v", username, err)
+		providerLog(logger.LevelWarn, "error setting updated_at for user %q: %v", username, err)
 	}
 }
 
@@ -1313,9 +1326,9 @@ func sqlCommonUpdateLastLogin(username string, dbHandle *sql.DB) error {
 	q := getUpdateLastLoginQuery()
 	_, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), username)
 	if err == nil {
-		providerLog(logger.LevelDebug, "last login updated for user %#v", username)
+		providerLog(logger.LevelDebug, "last login updated for user %q", username)
 	} else {
-		providerLog(logger.LevelWarn, "error updating last login for user %#v: %v", username, err)
+		providerLog(logger.LevelWarn, "error updating last login for user %q: %v", username, err)
 	}
 	return err
 }
@@ -1353,9 +1366,9 @@ func sqlCommonAddUser(user *User, dbHandle *sql.DB) error {
 			}
 		}
 		q := getAddUserQuery(user.Role)
-		_, err := tx.ExecContext(ctx, q, user.Username, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID,
-			user.MaxSessions, user.QuotaSize, user.QuotaFiles, string(permissions), user.UploadBandwidth,
-			user.DownloadBandwidth, user.Status, user.ExpirationDate, string(filters), string(fsConfig), user.AdditionalInfo,
+		_, err := tx.ExecContext(ctx, q, user.Username, user.Password, publicKeys, user.HomeDir, user.UID, user.GID,
+			user.MaxSessions, user.QuotaSize, user.QuotaFiles, permissions, user.UploadBandwidth,
+			user.DownloadBandwidth, user.Status, user.ExpirationDate, filters, fsConfig, user.AdditionalInfo,
 			user.Description, user.Email, util.GetTimeAsMsSinceEpoch(time.Now()), util.GetTimeAsMsSinceEpoch(time.Now()),
 			user.UploadDataTransfer, user.DownloadDataTransfer, user.TotalDataTransfer, user.Role, user.LastPasswordChange)
 		if err != nil {
@@ -1373,8 +1386,11 @@ func sqlCommonUpdateUserPassword(username, password string, dbHandle *sql.DB) er
 	defer cancel()
 
 	q := getUpdateUserPasswordQuery()
-	_, err := dbHandle.ExecContext(ctx, q, password, username)
-	return err
+	res, err := dbHandle.ExecContext(ctx, q, password, username)
+	if err != nil {
+		return err
+	}
+	return sqlCommonRequireRowAffected(res)
 }
 
 func sqlCommonUpdateUser(user *User, dbHandle *sql.DB) error {
@@ -1404,14 +1420,17 @@ func sqlCommonUpdateUser(user *User, dbHandle *sql.DB) error {
 
 	return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
 		q := getUpdateUserQuery(user.Role)
-		_, err := tx.ExecContext(ctx, q, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions,
-			user.QuotaSize, user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.Status,
-			user.ExpirationDate, string(filters), string(fsConfig), user.AdditionalInfo, user.Description, user.Email,
+		res, err := tx.ExecContext(ctx, q, user.Password, publicKeys, user.HomeDir, user.UID, user.GID, user.MaxSessions,
+			user.QuotaSize, user.QuotaFiles, permissions, user.UploadBandwidth, user.DownloadBandwidth, user.Status,
+			user.ExpirationDate, filters, fsConfig, user.AdditionalInfo, user.Description, user.Email,
 			util.GetTimeAsMsSinceEpoch(time.Now()), user.UploadDataTransfer, user.DownloadDataTransfer, user.TotalDataTransfer,
-			user.Role, user.LastPasswordChange, user.ID)
+			user.Role, user.LastPasswordChange, user.Username)
 		if err != nil {
 			return err
 		}
+		if err := sqlCommonRequireRowAffected(res); err != nil {
+			return err
+		}
 		if err := generateUserVirtualFoldersMapping(ctx, user, tx); err != nil {
 			return err
 		}
@@ -1440,7 +1459,7 @@ func sqlCommonDeleteUser(user User, softDelete bool, dbHandle *sql.DB) error {
 			return sqlCommonRequireRowAffected(res)
 		})
 	}
-	res, err := dbHandle.ExecContext(ctx, q, user.ID)
+	res, err := dbHandle.ExecContext(ctx, q, user.Username)
 	if err != nil {
 		return err
 	}
@@ -2292,7 +2311,7 @@ func sqlCommonAddOrUpdateFolder(ctx context.Context, baseFolder *vfs.BaseVirtual
 	}
 	q := getUpsertFolderQuery()
 	_, err = dbHandle.ExecContext(ctx, q, baseFolder.MappedPath, usedQuotaSize, usedQuotaFiles,
-		lastQuotaUpdate, baseFolder.Name, baseFolder.Description, string(fsConfig))
+		lastQuotaUpdate, baseFolder.Name, baseFolder.Description, fsConfig)
 	return err
 }
 
@@ -2310,7 +2329,7 @@ func sqlCommonAddFolder(folder *vfs.BaseVirtualFolder, dbHandle sqlQuerier) erro
 
 	q := getAddFolderQuery()
 	_, err = dbHandle.ExecContext(ctx, q, folder.MappedPath, folder.UsedQuotaSize, folder.UsedQuotaFiles,
-		folder.LastQuotaUpdate, folder.Name, folder.Description, string(fsConfig))
+		folder.LastQuotaUpdate, folder.Name, folder.Description, fsConfig)
 	return err
 }
 
@@ -2327,8 +2346,11 @@ func sqlCommonUpdateFolder(folder *vfs.BaseVirtualFolder, dbHandle sqlQuerier) e
 	defer cancel()
 
 	q := getUpdateFolderQuery()
-	_, err = dbHandle.ExecContext(ctx, q, folder.MappedPath, folder.Description, string(fsConfig), folder.Name)
-	return err
+	res, err := dbHandle.ExecContext(ctx, q, folder.MappedPath, folder.Description, fsConfig, folder.Name)
+	if err != nil {
+		return err
+	}
+	return sqlCommonRequireRowAffected(res)
 }
 
 func sqlCommonDeleteFolder(folder vfs.BaseVirtualFolder, dbHandle sqlQuerier) error {
@@ -2336,7 +2358,7 @@ func sqlCommonDeleteFolder(folder vfs.BaseVirtualFolder, dbHandle sqlQuerier) er
 	defer cancel()
 
 	q := getDeleteFolderQuery()
-	res, err := dbHandle.ExecContext(ctx, q, folder.ID)
+	res, err := dbHandle.ExecContext(ctx, q, folder.Name)
 	if err != nil {
 		return err
 	}
@@ -2485,7 +2507,7 @@ func sqlCommonAddAdminGroupMapping(ctx context.Context, username, groupName stri
 		return err
 	}
 	q := getAddAdminGroupMappingQuery()
-	_, err = dbHandle.ExecContext(ctx, q, username, groupName, string(options))
+	_, err = dbHandle.ExecContext(ctx, q, username, groupName, options)
 	return err
 }
 
@@ -3082,7 +3104,7 @@ func sqlCommonUpdateFolderQuota(name string, filesAdd int, sizeAdd int64, reset
 	q := getUpdateFolderQuotaQuery(reset)
 	_, err := dbHandle.ExecContext(ctx, q, sizeAdd, filesAdd, util.GetTimeAsMsSinceEpoch(time.Now()), name)
 	if err == nil {
-		providerLog(logger.LevelDebug, "quota updated for folder %#v, files increment: %v size increment: %v is reset? %v",
+		providerLog(logger.LevelDebug, "quota updated for folder %q, files increment: %d size increment: %d is reset? %t",
 			name, filesAdd, sizeAdd, reset)
 	} else {
 		providerLog(logger.LevelWarn, "error updating quota for folder %#v: %v", name, err)
@@ -3337,7 +3359,7 @@ func generateEventRuleActionsMapping(ctx context.Context, rule *EventRule, dbHan
 			return err
 		}
 		q = getAddRuleActionMappingQuery()
-		_, err = dbHandle.ExecContext(ctx, q, rule.Name, action.Name, action.Order, string(options))
+		_, err = dbHandle.ExecContext(ctx, q, rule.Name, action.Name, action.Order, options)
 		if err != nil {
 			return err
 		}
@@ -3444,7 +3466,7 @@ func sqlCommonAddEventAction(action *BaseEventAction, dbHandle *sql.DB) error {
 	if err != nil {
 		return err
 	}
-	_, err = dbHandle.ExecContext(ctx, q, action.Name, action.Description, action.Type, string(options))
+	_, err = dbHandle.ExecContext(ctx, q, action.Name, action.Description, action.Type, options)
 	return err
 }
 
@@ -3461,10 +3483,13 @@ func sqlCommonUpdateEventAction(action *BaseEventAction, dbHandle *sql.DB) error
 
 	return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
 		q := getUpdateEventActionQuery()
-		_, err = tx.ExecContext(ctx, q, action.Description, action.Type, string(options), action.Name)
+		res, err := tx.ExecContext(ctx, q, action.Description, action.Type, options, action.Name)
 		if err != nil {
 			return err
 		}
+		if err := sqlCommonRequireRowAffected(res); err != nil {
+			return err
+		}
 		q = getUpdateRulesTimestampQuery()
 		_, err = tx.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), action.ID)
 		return err
@@ -3609,7 +3634,7 @@ func sqlCommonAddEventRule(rule *EventRule, dbHandle *sql.DB) error {
 		}
 		q := getAddEventRuleQuery()
 		_, err := tx.ExecContext(ctx, q, rule.Name, rule.Description, util.GetTimeAsMsSinceEpoch(time.Now()),
-			util.GetTimeAsMsSinceEpoch(time.Now()), rule.Trigger, string(conditions), rule.Status)
+			util.GetTimeAsMsSinceEpoch(time.Now()), rule.Trigger, conditions, rule.Status)
 		if err != nil {
 			return err
 		}
@@ -3631,7 +3656,7 @@ func sqlCommonUpdateEventRule(rule *EventRule, dbHandle *sql.DB) error {
 	return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
 		q := getUpdateEventRuleQuery()
 		_, err := tx.ExecContext(ctx, q, rule.Description, util.GetTimeAsMsSinceEpoch(time.Now()),
-			rule.Trigger, string(conditions), rule.Status, rule.Name)
+			rule.Trigger, conditions, rule.Status, rule.Name)
 		if err != nil {
 			return err
 		}
@@ -3744,7 +3769,7 @@ func sqlCommonAddNode(dbHandle *sql.DB) error {
 	defer cancel()
 
 	q := getAddNodeQuery()
-	_, err = dbHandle.ExecContext(ctx, q, currentNode.Name, string(data), util.GetTimeAsMsSinceEpoch(time.Now()),
+	_, err = dbHandle.ExecContext(ctx, q, currentNode.Name, data, util.GetTimeAsMsSinceEpoch(time.Now()),
 		util.GetTimeAsMsSinceEpoch(time.Now()))
 	if err != nil {
 		return fmt.Errorf("unable to register cluster node: %w", err)
@@ -3857,9 +3882,6 @@ func sqlCommonSetConfigs(configs *Configs, dbHandle *sql.DB) error {
 	if err != nil {
 		return err
 	}
-	if config.Driver == MySQLDataProviderName {
-		return nil
-	}
 	return sqlCommonRequireRowAffected(res)
 }
 
@@ -3884,8 +3906,6 @@ func sqlCommonGetDatabaseVersion(dbHandle sqlQuerier, showInitWarn bool) (schema
 }
 
 func sqlCommonRequireRowAffected(res sql.Result) error {
-	// MariaDB/MySQL returns 0 rows affected for updates that don't change anything
-	// so we don't check rows affected for updates
 	affected, err := res.RowsAffected()
 	if err == nil && affected == 0 {
 		return util.NewRecordNotFoundError(sql.ErrNoRows.Error())

+ 3 - 3
internal/dataprovider/sqlite.go

@@ -221,17 +221,17 @@ func initializeSQLiteProvider(basePath string) error {
 		if !filepath.IsAbs(dbPath) {
 			dbPath = filepath.Join(basePath, dbPath)
 		}
-		connectionString = fmt.Sprintf("file:%v?cache=shared&_foreign_keys=1", dbPath)
+		connectionString = fmt.Sprintf("file:%s?cache=shared&_foreign_keys=1", dbPath)
 	} else {
 		connectionString = config.ConnectionString
 	}
 	dbHandle, err := sql.Open("sqlite3", connectionString)
 	if err == nil {
-		providerLog(logger.LevelDebug, "sqlite database handle created, connection string: %#v", connectionString)
+		providerLog(logger.LevelDebug, "sqlite database handle created, connection string: %q", connectionString)
 		dbHandle.SetMaxOpenConns(1)
 		provider = &SQLiteProvider{dbHandle: dbHandle}
 	} else {
-		providerLog(logger.LevelError, "error creating sqlite database handler, connection string: %#v, error: %v",
+		providerLog(logger.LevelError, "error creating sqlite database handler, connection string: %q, error: %v",
 			connectionString, err)
 	}
 	return err

+ 3 - 3
internal/dataprovider/sqlqueries.go

@@ -706,7 +706,7 @@ func getUpdateUserQuery(role string) string {
 	return fmt.Sprintf(`UPDATE %s SET password=%s,public_keys=%s,home_dir=%s,uid=%s,gid=%s,max_sessions=%s,quota_size=%s,
 		quota_files=%s,permissions=%s,upload_bandwidth=%s,download_bandwidth=%s,status=%s,expiration_date=%s,filters=%s,filesystem=%s,
 		additional_info=%s,description=%s,email=%s,updated_at=%s,upload_data_transfer=%s,download_data_transfer=%s,
-		total_data_transfer=%s,role_id=COALESCE((SELECT id from %s WHERE name=%s),%s),last_password_change=%s WHERE id = %s`,
+		total_data_transfer=%s,role_id=COALESCE((SELECT id from %s WHERE name=%s),%s),last_password_change=%s WHERE username = %s`,
 		sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4],
 		sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7], sqlPlaceholders[8], sqlPlaceholders[9],
 		sqlPlaceholders[10], sqlPlaceholders[11], sqlPlaceholders[12], sqlPlaceholders[13], sqlPlaceholders[14],
@@ -725,7 +725,7 @@ func getDeleteUserQuery(softDelete bool) string {
 		return fmt.Sprintf(`UPDATE %s SET updated_at=%s,deleted_at=%s WHERE username = %s`,
 			sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2])
 	}
-	return fmt.Sprintf(`DELETE FROM %s WHERE id = %s`, sqlTableUsers, sqlPlaceholders[0])
+	return fmt.Sprintf(`DELETE FROM %s WHERE username = %s`, sqlTableUsers, sqlPlaceholders[0])
 }
 
 func getRemoveSoftDeletedUserQuery() string {
@@ -748,7 +748,7 @@ func getUpdateFolderQuery() string {
 }
 
 func getDeleteFolderQuery() string {
-	return fmt.Sprintf(`DELETE FROM %s WHERE id = %s`, sqlTableFolders, sqlPlaceholders[0])
+	return fmt.Sprintf(`DELETE FROM %s WHERE name = %s`, sqlTableFolders, sqlPlaceholders[0])
 }
 
 func getUpsertFolderQuery() string {