diff --git a/go.mod b/go.mod index d70f5900..9003aac8 100644 --- a/go.mod +++ b/go.mod @@ -121,7 +121,7 @@ require ( github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect - github.com/klauspost/cpuid/v2 v2.2.3 // indirect + github.com/klauspost/cpuid/v2 v2.2.4 // indirect github.com/kr/fs v0.1.0 // indirect github.com/lestrrat-go/blackmagic v1.0.1 // indirect github.com/lestrrat-go/httpcc v1.0.1 // indirect diff --git a/go.sum b/go.sum index 63dc7977..e22685b6 100644 --- a/go.sum +++ b/go.sum @@ -1401,8 +1401,9 @@ github.com/klauspost/compress v1.15.1/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47e github.com/klauspost/compress v1.15.15 h1:EF27CXIuDsYJ6mmvtBRlEuB2UVOqHG1tAXgZ7yIO+lw= github.com/klauspost/compress v1.15.15/go.mod h1:ZcK2JAFqKOpnBlxcLsJzYfrS9X1akm9fHZNnD9+Vo/4= github.com/klauspost/cpuid/v2 v2.0.4/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= -github.com/klauspost/cpuid/v2 v2.2.3 h1:sxCkb+qR91z4vsqw4vGGZlDgPz3G7gjaLyK3V8y70BU= github.com/klauspost/cpuid/v2 v2.2.3/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= +github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk= +github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= github.com/kolo/xmlrpc v0.0.0-20201022064351-38db28db192b/go.mod h1:pcaDhQK0/NJZEvtCO0qQPPropqV0sJOJ6YW7X+9kRwM= github.com/kolo/xmlrpc v0.0.0-20220921171641-a4b6fa1dd06b/go.mod h1:pcaDhQK0/NJZEvtCO0qQPPropqV0sJOJ6YW7X+9kRwM= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= diff --git a/internal/dataprovider/mysql.go b/internal/dataprovider/mysql.go index 2e3bcaa9..1d2c91c3 100644 --- a/internal/dataprovider/mysql.go +++ b/internal/dataprovider/mysql.go @@ -230,7 +230,7 @@ func initializeMySQLProvider() error { } dbHandle, err := sql.Open("mysql", connString) if err == nil { - providerLog(logger.LevelDebug, "mysql database handle created, connection string: %#v, pool size: %v", + providerLog(logger.LevelDebug, "mysql database handle created, connection string: %q, pool size: %v", redactedConnString, config.PoolSize) dbHandle.SetMaxOpenConns(config.PoolSize) if config.PoolSize > 0 { @@ -242,7 +242,7 @@ func initializeMySQLProvider() error { dbHandle.SetConnMaxIdleTime(120 * time.Second) provider = &MySQLProvider{dbHandle: dbHandle} } else { - providerLog(logger.LevelError, "error creating mysql database handler, connection string: %#v, error: %v", + providerLog(logger.LevelError, "error creating mysql database handler, connection string: %q, error: %v", redactedConnString, err) } return err @@ -260,7 +260,7 @@ func getMySQLConnectionString(redactedPwd bool) (string, error) { return "", err } } - connectionString = fmt.Sprintf("%v:%v@tcp([%v]:%v)/%v?charset=utf8mb4&interpolateParams=true&timeout=10s&parseTime=true&tls=%v&writeTimeout=60s&readTimeout=60s", + connectionString = fmt.Sprintf("%s:%s@tcp([%s]:%d)/%s?collation=utf8mb4_unicode_ci&interpolateParams=true&timeout=10s&parseTime=true&clientFoundRows=true&tls=%s&writeTimeout=60s&readTimeout=60s", config.Username, password, config.Host, config.Port, config.Name, sslMode) } else { connectionString = config.ConnectionString diff --git a/internal/dataprovider/pgsql.go b/internal/dataprovider/pgsql.go index c4122dce..a4bfdf4c 100644 --- a/internal/dataprovider/pgsql.go +++ b/internal/dataprovider/pgsql.go @@ -236,7 +236,7 @@ func initializePGSQLProvider() error { var err error dbHandle, err := sql.Open("pgx", getPGSQLConnectionString(false)) if err == nil { - providerLog(logger.LevelDebug, "postgres database handle created, connection string: %#v, pool size: %v", + providerLog(logger.LevelDebug, "postgres database handle created, connection string: %q, pool size: %d", getPGSQLConnectionString(true), config.PoolSize) dbHandle.SetMaxOpenConns(config.PoolSize) if config.PoolSize > 0 { @@ -248,7 +248,7 @@ func initializePGSQLProvider() error { dbHandle.SetConnMaxIdleTime(120 * time.Second) provider = &PGSQLProvider{dbHandle: dbHandle} } else { - providerLog(logger.LevelError, "error creating postgres database handler, connection string: %#v, error: %v", + providerLog(logger.LevelError, "error creating postgres database handler, connection string: %q, error: %v", getPGSQLConnectionString(true), err) } return err @@ -261,13 +261,13 @@ func getPGSQLConnectionString(redactedPwd bool) string { if redactedPwd && password != "" { password = "[redacted]" } - connectionString = fmt.Sprintf("host='%v' port=%v dbname='%v' user='%v' password='%v' sslmode=%v connect_timeout=10", + connectionString = fmt.Sprintf("host='%s' port=%d dbname='%s' user='%s' password='%s' sslmode=%s connect_timeout=10", config.Host, config.Port, config.Name, config.Username, password, getSSLMode()) if config.RootCert != "" { - connectionString += fmt.Sprintf(" sslrootcert='%v'", config.RootCert) + connectionString += fmt.Sprintf(" sslrootcert='%s'", config.RootCert) } if config.ClientCert != "" && config.ClientKey != "" { - connectionString += fmt.Sprintf(" sslcert='%v' sslkey='%v'", config.ClientCert, config.ClientKey) + connectionString += fmt.Sprintf(" sslcert='%s' sslkey='%s'", config.ClientCert, config.ClientKey) } if config.DisableSNI { connectionString += " sslsni=0" diff --git a/internal/dataprovider/sqlcommon.go b/internal/dataprovider/sqlcommon.go index 549abb80..91012c0a 100644 --- a/internal/dataprovider/sqlcommon.go +++ b/internal/dataprovider/sqlcommon.go @@ -118,11 +118,11 @@ func sqlCommonAddShare(share *Share, dbHandle *sql.DB) error { if err != nil { return err } - allowFrom := "" + var allowFrom []byte if len(share.AllowFrom) > 0 { res, err := json.Marshal(share.AllowFrom) if err == nil { - allowFrom = string(res) + allowFrom = res } } @@ -145,7 +145,7 @@ func sqlCommonAddShare(share *Share, dbHandle *sql.DB) error { lastUseAt = share.LastUseAt } _, err = dbHandle.ExecContext(ctx, q, share.ShareID, share.Name, share.Description, share.Scope, - string(paths), createdAt, updatedAt, lastUseAt, share.ExpiresAt, share.Password, + paths, createdAt, updatedAt, lastUseAt, share.ExpiresAt, share.Password, share.MaxTokens, usedTokens, allowFrom, user.ID) return err } @@ -161,11 +161,11 @@ func sqlCommonUpdateShare(share *Share, dbHandle *sql.DB) error { return err } - allowFrom := "" + var allowFrom []byte if len(share.AllowFrom) > 0 { res, err := json.Marshal(share.AllowFrom) if err == nil { - allowFrom = string(res) + allowFrom = res } } @@ -184,6 +184,7 @@ func sqlCommonUpdateShare(share *Share, dbHandle *sql.DB) error { q = getUpdateShareQuery() } + var res sql.Result if share.IsRestore { if share.CreatedAt == 0 { share.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) @@ -191,15 +192,18 @@ func sqlCommonUpdateShare(share *Share, dbHandle *sql.DB) error { if share.UpdatedAt == 0 { share.UpdatedAt = share.CreatedAt } - _, err = dbHandle.ExecContext(ctx, q, share.Name, share.Description, share.Scope, string(paths), + res, err = dbHandle.ExecContext(ctx, q, share.Name, share.Description, share.Scope, paths, share.CreatedAt, share.UpdatedAt, share.LastUseAt, share.ExpiresAt, share.Password, share.MaxTokens, share.UsedTokens, allowFrom, user.ID, share.ShareID) } else { - _, err = dbHandle.ExecContext(ctx, q, share.Name, share.Description, share.Scope, string(paths), + res, err = dbHandle.ExecContext(ctx, q, share.Name, share.Description, share.Scope, paths, util.GetTimeAsMsSinceEpoch(time.Now()), share.ExpiresAt, share.Password, share.MaxTokens, allowFrom, user.ID, share.ShareID) } - return err + if err != nil { + return err + } + return sqlCommonRequireRowAffected(res) } func sqlCommonDeleteShare(share Share, dbHandle *sql.DB) error { @@ -311,9 +315,12 @@ func sqlCommonUpdateAPIKey(apiKey *APIKey, dbHandle *sql.DB) error { defer cancel() q := getUpdateAPIKeyQuery() - _, err = dbHandle.ExecContext(ctx, q, apiKey.Name, apiKey.Scope, apiKey.ExpiresAt, userID, adminID, + res, err := dbHandle.ExecContext(ctx, q, apiKey.Name, apiKey.Scope, apiKey.ExpiresAt, userID, adminID, apiKey.Description, util.GetTimeAsMsSinceEpoch(time.Now()), apiKey.KeyID) - return err + if err != nil { + return err + } + return sqlCommonRequireRowAffected(res) } func sqlCommonDeleteAPIKey(apiKey APIKey, dbHandle *sql.DB) error { @@ -436,8 +443,8 @@ func sqlCommonAddAdmin(admin *Admin, dbHandle *sql.DB) error { return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error { q := getAddAdminQuery(admin.Role) - _, err = tx.ExecContext(ctx, q, admin.Username, admin.Password, admin.Status, admin.Email, string(perms), - string(filters), admin.AdditionalInfo, admin.Description, util.GetTimeAsMsSinceEpoch(time.Now()), + _, err = tx.ExecContext(ctx, q, admin.Username, admin.Password, admin.Status, admin.Email, perms, + filters, admin.AdditionalInfo, admin.Description, util.GetTimeAsMsSinceEpoch(time.Now()), util.GetTimeAsMsSinceEpoch(time.Now()), admin.Role) if err != nil { return err @@ -467,7 +474,7 @@ func sqlCommonUpdateAdmin(admin *Admin, dbHandle *sql.DB) error { return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error { q := getUpdateAdminQuery(admin.Role) - _, err = tx.ExecContext(ctx, q, admin.Password, admin.Status, admin.Email, string(perms), string(filters), + _, err = tx.ExecContext(ctx, q, admin.Password, admin.Status, admin.Email, perms, filters, admin.AdditionalInfo, admin.Description, util.GetTimeAsMsSinceEpoch(time.Now()), admin.Role, admin.Username) if err != nil { return err @@ -753,9 +760,12 @@ func sqlCommonUpdateIPListEntry(entry *IPListEntry, dbHandle *sql.DB) error { defer cancel() q := getUpdateIPListEntryQuery() - _, err := dbHandle.ExecContext(ctx, q, entry.Mode, entry.Protocols, entry.Description, + res, err := dbHandle.ExecContext(ctx, q, entry.Mode, entry.Protocols, entry.Description, util.GetTimeAsMsSinceEpoch(time.Now()), entry.Type, entry.IPOrNet) - return err + if err != nil { + return err + } + return sqlCommonRequireRowAffected(res) } func sqlCommonDeleteIPListEntry(entry IPListEntry, softDelete bool, dbHandle *sql.DB) error { @@ -876,8 +886,11 @@ func sqlCommonUpdateRole(role *Role, dbHandle *sql.DB) error { defer cancel() q := getUpdateRoleQuery() - _, err := dbHandle.ExecContext(ctx, q, role.Description, util.GetTimeAsMsSinceEpoch(time.Now()), role.Name) - return err + res, err := dbHandle.ExecContext(ctx, q, role.Description, util.GetTimeAsMsSinceEpoch(time.Now()), role.Name) + if err != nil { + return err + } + return sqlCommonRequireRowAffected(res) } func sqlCommonDeleteRole(role Role, dbHandle *sql.DB) error { @@ -1069,7 +1082,7 @@ func sqlCommonAddGroup(group *Group, dbHandle *sql.DB) error { return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error { q := getAddGroupQuery() _, err := tx.ExecContext(ctx, q, group.Name, group.Description, util.GetTimeAsMsSinceEpoch(time.Now()), - util.GetTimeAsMsSinceEpoch(time.Now()), string(settings)) + util.GetTimeAsMsSinceEpoch(time.Now()), settings) if err != nil { return err } @@ -1170,7 +1183,7 @@ func sqlCommonValidateUserAndPubKey(username string, pubKey []byte, isSSHCert bo func sqlCommonCheckAvailability(dbHandle *sql.DB) (err error) { defer func() { if r := recover(); r != nil { - providerLog(logger.LevelError, "panic in check provider availability, stack trace: %v", string(debug.Stack())) + providerLog(logger.LevelError, "panic in check provider availability, stack trace: %s", string(debug.Stack())) err = errors.New("unable to check provider status") } }() @@ -1189,10 +1202,10 @@ func sqlCommonUpdateTransferQuota(username string, uploadSize, downloadSize int6 q := getUpdateTransferQuotaQuery(reset) _, err := dbHandle.ExecContext(ctx, q, uploadSize, downloadSize, util.GetTimeAsMsSinceEpoch(time.Now()), username) if err == nil { - providerLog(logger.LevelDebug, "transfer quota updated for user %#v, ul increment: %v dl increment: %v is reset? %v", + providerLog(logger.LevelDebug, "transfer quota updated for user %q, ul increment: %d dl increment: %d is reset? %t", username, uploadSize, downloadSize, reset) } else { - providerLog(logger.LevelError, "error updating quota for user %#v: %v", username, err) + providerLog(logger.LevelError, "error updating quota for user %q: %v", username, err) } return err } @@ -1204,10 +1217,10 @@ func sqlCommonUpdateQuota(username string, filesAdd int, sizeAdd int64, reset bo q := getUpdateQuotaQuery(reset) _, err := dbHandle.ExecContext(ctx, q, sizeAdd, filesAdd, util.GetTimeAsMsSinceEpoch(time.Now()), username) if err == nil { - providerLog(logger.LevelDebug, "quota updated for user %#v, files increment: %v size increment: %v is reset? %v", + providerLog(logger.LevelDebug, "quota updated for user %q, files increment: %d size increment: %d is reset? %t", username, filesAdd, sizeAdd, reset) } else { - providerLog(logger.LevelError, "error updating quota for user %#v: %v", username, err) + providerLog(logger.LevelError, "error updating quota for user %q: %v", username, err) } return err } @@ -1234,9 +1247,9 @@ func sqlCommonUpdateShareLastUse(shareID string, numTokens int, dbHandle *sql.DB q := getUpdateShareLastUseQuery() _, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), numTokens, shareID) if err == nil { - providerLog(logger.LevelDebug, "last use updated for shared object %#v", shareID) + providerLog(logger.LevelDebug, "last use updated for shared object %q", shareID) } else { - providerLog(logger.LevelWarn, "error updating last use for shared object %#v: %v", shareID, err) + providerLog(logger.LevelWarn, "error updating last use for shared object %q: %v", shareID, err) } return err } @@ -1248,9 +1261,9 @@ func sqlCommonUpdateAPIKeyLastUse(keyID string, dbHandle *sql.DB) error { q := getUpdateAPIKeyLastUseQuery() _, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), keyID) if err == nil { - providerLog(logger.LevelDebug, "last use updated for key %#v", keyID) + providerLog(logger.LevelDebug, "last use updated for key %q", keyID) } else { - providerLog(logger.LevelWarn, "error updating last use for key %#v: %v", keyID, err) + providerLog(logger.LevelWarn, "error updating last use for key %q: %v", keyID, err) } return err } @@ -1262,9 +1275,9 @@ func sqlCommonUpdateAdminLastLogin(username string, dbHandle *sql.DB) error { q := getUpdateAdminLastLoginQuery() _, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), username) if err == nil { - providerLog(logger.LevelDebug, "last login updated for admin %#v", username) + providerLog(logger.LevelDebug, "last login updated for admin %q", username) } else { - providerLog(logger.LevelWarn, "error updating last login for admin %#v: %v", username, err) + providerLog(logger.LevelWarn, "error updating last login for admin %q: %v", username, err) } return err } @@ -1276,9 +1289,9 @@ func sqlCommonSetUpdatedAt(username string, dbHandle *sql.DB) { q := getSetUpdateAtQuery() _, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), username) if err == nil { - providerLog(logger.LevelDebug, "updated_at set for user %#v", username) + providerLog(logger.LevelDebug, "updated_at set for user %q", username) } else { - providerLog(logger.LevelWarn, "error setting updated_at for user %#v: %v", username, err) + providerLog(logger.LevelWarn, "error setting updated_at for user %q: %v", username, err) } } @@ -1313,9 +1326,9 @@ func sqlCommonUpdateLastLogin(username string, dbHandle *sql.DB) error { q := getUpdateLastLoginQuery() _, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), username) if err == nil { - providerLog(logger.LevelDebug, "last login updated for user %#v", username) + providerLog(logger.LevelDebug, "last login updated for user %q", username) } else { - providerLog(logger.LevelWarn, "error updating last login for user %#v: %v", username, err) + providerLog(logger.LevelWarn, "error updating last login for user %q: %v", username, err) } return err } @@ -1353,9 +1366,9 @@ func sqlCommonAddUser(user *User, dbHandle *sql.DB) error { } } q := getAddUserQuery(user.Role) - _, err := tx.ExecContext(ctx, q, user.Username, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, - user.MaxSessions, user.QuotaSize, user.QuotaFiles, string(permissions), user.UploadBandwidth, - user.DownloadBandwidth, user.Status, user.ExpirationDate, string(filters), string(fsConfig), user.AdditionalInfo, + _, err := tx.ExecContext(ctx, q, user.Username, user.Password, publicKeys, user.HomeDir, user.UID, user.GID, + user.MaxSessions, user.QuotaSize, user.QuotaFiles, permissions, user.UploadBandwidth, + user.DownloadBandwidth, user.Status, user.ExpirationDate, filters, fsConfig, user.AdditionalInfo, user.Description, user.Email, util.GetTimeAsMsSinceEpoch(time.Now()), util.GetTimeAsMsSinceEpoch(time.Now()), user.UploadDataTransfer, user.DownloadDataTransfer, user.TotalDataTransfer, user.Role, user.LastPasswordChange) if err != nil { @@ -1373,8 +1386,11 @@ func sqlCommonUpdateUserPassword(username, password string, dbHandle *sql.DB) er defer cancel() q := getUpdateUserPasswordQuery() - _, err := dbHandle.ExecContext(ctx, q, password, username) - return err + res, err := dbHandle.ExecContext(ctx, q, password, username) + if err != nil { + return err + } + return sqlCommonRequireRowAffected(res) } func sqlCommonUpdateUser(user *User, dbHandle *sql.DB) error { @@ -1404,14 +1420,17 @@ func sqlCommonUpdateUser(user *User, dbHandle *sql.DB) error { return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error { q := getUpdateUserQuery(user.Role) - _, err := tx.ExecContext(ctx, q, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions, - user.QuotaSize, user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.Status, - user.ExpirationDate, string(filters), string(fsConfig), user.AdditionalInfo, user.Description, user.Email, + res, err := tx.ExecContext(ctx, q, user.Password, publicKeys, user.HomeDir, user.UID, user.GID, user.MaxSessions, + user.QuotaSize, user.QuotaFiles, permissions, user.UploadBandwidth, user.DownloadBandwidth, user.Status, + user.ExpirationDate, filters, fsConfig, user.AdditionalInfo, user.Description, user.Email, util.GetTimeAsMsSinceEpoch(time.Now()), user.UploadDataTransfer, user.DownloadDataTransfer, user.TotalDataTransfer, - user.Role, user.LastPasswordChange, user.ID) + user.Role, user.LastPasswordChange, user.Username) if err != nil { return err } + if err := sqlCommonRequireRowAffected(res); err != nil { + return err + } if err := generateUserVirtualFoldersMapping(ctx, user, tx); err != nil { return err } @@ -1440,7 +1459,7 @@ func sqlCommonDeleteUser(user User, softDelete bool, dbHandle *sql.DB) error { return sqlCommonRequireRowAffected(res) }) } - res, err := dbHandle.ExecContext(ctx, q, user.ID) + res, err := dbHandle.ExecContext(ctx, q, user.Username) if err != nil { return err } @@ -2292,7 +2311,7 @@ func sqlCommonAddOrUpdateFolder(ctx context.Context, baseFolder *vfs.BaseVirtual } q := getUpsertFolderQuery() _, err = dbHandle.ExecContext(ctx, q, baseFolder.MappedPath, usedQuotaSize, usedQuotaFiles, - lastQuotaUpdate, baseFolder.Name, baseFolder.Description, string(fsConfig)) + lastQuotaUpdate, baseFolder.Name, baseFolder.Description, fsConfig) return err } @@ -2310,7 +2329,7 @@ func sqlCommonAddFolder(folder *vfs.BaseVirtualFolder, dbHandle sqlQuerier) erro q := getAddFolderQuery() _, err = dbHandle.ExecContext(ctx, q, folder.MappedPath, folder.UsedQuotaSize, folder.UsedQuotaFiles, - folder.LastQuotaUpdate, folder.Name, folder.Description, string(fsConfig)) + folder.LastQuotaUpdate, folder.Name, folder.Description, fsConfig) return err } @@ -2327,8 +2346,11 @@ func sqlCommonUpdateFolder(folder *vfs.BaseVirtualFolder, dbHandle sqlQuerier) e defer cancel() q := getUpdateFolderQuery() - _, err = dbHandle.ExecContext(ctx, q, folder.MappedPath, folder.Description, string(fsConfig), folder.Name) - return err + res, err := dbHandle.ExecContext(ctx, q, folder.MappedPath, folder.Description, fsConfig, folder.Name) + if err != nil { + return err + } + return sqlCommonRequireRowAffected(res) } func sqlCommonDeleteFolder(folder vfs.BaseVirtualFolder, dbHandle sqlQuerier) error { @@ -2336,7 +2358,7 @@ func sqlCommonDeleteFolder(folder vfs.BaseVirtualFolder, dbHandle sqlQuerier) er defer cancel() q := getDeleteFolderQuery() - res, err := dbHandle.ExecContext(ctx, q, folder.ID) + res, err := dbHandle.ExecContext(ctx, q, folder.Name) if err != nil { return err } @@ -2485,7 +2507,7 @@ func sqlCommonAddAdminGroupMapping(ctx context.Context, username, groupName stri return err } q := getAddAdminGroupMappingQuery() - _, err = dbHandle.ExecContext(ctx, q, username, groupName, string(options)) + _, err = dbHandle.ExecContext(ctx, q, username, groupName, options) return err } @@ -3082,7 +3104,7 @@ func sqlCommonUpdateFolderQuota(name string, filesAdd int, sizeAdd int64, reset q := getUpdateFolderQuotaQuery(reset) _, err := dbHandle.ExecContext(ctx, q, sizeAdd, filesAdd, util.GetTimeAsMsSinceEpoch(time.Now()), name) if err == nil { - providerLog(logger.LevelDebug, "quota updated for folder %#v, files increment: %v size increment: %v is reset? %v", + providerLog(logger.LevelDebug, "quota updated for folder %q, files increment: %d size increment: %d is reset? %t", name, filesAdd, sizeAdd, reset) } else { providerLog(logger.LevelWarn, "error updating quota for folder %#v: %v", name, err) @@ -3337,7 +3359,7 @@ func generateEventRuleActionsMapping(ctx context.Context, rule *EventRule, dbHan return err } q = getAddRuleActionMappingQuery() - _, err = dbHandle.ExecContext(ctx, q, rule.Name, action.Name, action.Order, string(options)) + _, err = dbHandle.ExecContext(ctx, q, rule.Name, action.Name, action.Order, options) if err != nil { return err } @@ -3444,7 +3466,7 @@ func sqlCommonAddEventAction(action *BaseEventAction, dbHandle *sql.DB) error { if err != nil { return err } - _, err = dbHandle.ExecContext(ctx, q, action.Name, action.Description, action.Type, string(options)) + _, err = dbHandle.ExecContext(ctx, q, action.Name, action.Description, action.Type, options) return err } @@ -3461,10 +3483,13 @@ func sqlCommonUpdateEventAction(action *BaseEventAction, dbHandle *sql.DB) error return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error { q := getUpdateEventActionQuery() - _, err = tx.ExecContext(ctx, q, action.Description, action.Type, string(options), action.Name) + res, err := tx.ExecContext(ctx, q, action.Description, action.Type, options, action.Name) if err != nil { return err } + if err := sqlCommonRequireRowAffected(res); err != nil { + return err + } q = getUpdateRulesTimestampQuery() _, err = tx.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), action.ID) return err @@ -3609,7 +3634,7 @@ func sqlCommonAddEventRule(rule *EventRule, dbHandle *sql.DB) error { } q := getAddEventRuleQuery() _, err := tx.ExecContext(ctx, q, rule.Name, rule.Description, util.GetTimeAsMsSinceEpoch(time.Now()), - util.GetTimeAsMsSinceEpoch(time.Now()), rule.Trigger, string(conditions), rule.Status) + util.GetTimeAsMsSinceEpoch(time.Now()), rule.Trigger, conditions, rule.Status) if err != nil { return err } @@ -3631,7 +3656,7 @@ func sqlCommonUpdateEventRule(rule *EventRule, dbHandle *sql.DB) error { return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error { q := getUpdateEventRuleQuery() _, err := tx.ExecContext(ctx, q, rule.Description, util.GetTimeAsMsSinceEpoch(time.Now()), - rule.Trigger, string(conditions), rule.Status, rule.Name) + rule.Trigger, conditions, rule.Status, rule.Name) if err != nil { return err } @@ -3744,7 +3769,7 @@ func sqlCommonAddNode(dbHandle *sql.DB) error { defer cancel() q := getAddNodeQuery() - _, err = dbHandle.ExecContext(ctx, q, currentNode.Name, string(data), util.GetTimeAsMsSinceEpoch(time.Now()), + _, err = dbHandle.ExecContext(ctx, q, currentNode.Name, data, util.GetTimeAsMsSinceEpoch(time.Now()), util.GetTimeAsMsSinceEpoch(time.Now())) if err != nil { return fmt.Errorf("unable to register cluster node: %w", err) @@ -3857,9 +3882,6 @@ func sqlCommonSetConfigs(configs *Configs, dbHandle *sql.DB) error { if err != nil { return err } - if config.Driver == MySQLDataProviderName { - return nil - } return sqlCommonRequireRowAffected(res) } @@ -3884,8 +3906,6 @@ func sqlCommonGetDatabaseVersion(dbHandle sqlQuerier, showInitWarn bool) (schema } func sqlCommonRequireRowAffected(res sql.Result) error { - // MariaDB/MySQL returns 0 rows affected for updates that don't change anything - // so we don't check rows affected for updates affected, err := res.RowsAffected() if err == nil && affected == 0 { return util.NewRecordNotFoundError(sql.ErrNoRows.Error()) diff --git a/internal/dataprovider/sqlite.go b/internal/dataprovider/sqlite.go index 092af313..2bcef645 100644 --- a/internal/dataprovider/sqlite.go +++ b/internal/dataprovider/sqlite.go @@ -221,17 +221,17 @@ func initializeSQLiteProvider(basePath string) error { if !filepath.IsAbs(dbPath) { dbPath = filepath.Join(basePath, dbPath) } - connectionString = fmt.Sprintf("file:%v?cache=shared&_foreign_keys=1", dbPath) + connectionString = fmt.Sprintf("file:%s?cache=shared&_foreign_keys=1", dbPath) } else { connectionString = config.ConnectionString } dbHandle, err := sql.Open("sqlite3", connectionString) if err == nil { - providerLog(logger.LevelDebug, "sqlite database handle created, connection string: %#v", connectionString) + providerLog(logger.LevelDebug, "sqlite database handle created, connection string: %q", connectionString) dbHandle.SetMaxOpenConns(1) provider = &SQLiteProvider{dbHandle: dbHandle} } else { - providerLog(logger.LevelError, "error creating sqlite database handler, connection string: %#v, error: %v", + providerLog(logger.LevelError, "error creating sqlite database handler, connection string: %q, error: %v", connectionString, err) } return err diff --git a/internal/dataprovider/sqlqueries.go b/internal/dataprovider/sqlqueries.go index 18c4ce52..0078cacb 100644 --- a/internal/dataprovider/sqlqueries.go +++ b/internal/dataprovider/sqlqueries.go @@ -706,7 +706,7 @@ func getUpdateUserQuery(role string) string { return fmt.Sprintf(`UPDATE %s SET password=%s,public_keys=%s,home_dir=%s,uid=%s,gid=%s,max_sessions=%s,quota_size=%s, quota_files=%s,permissions=%s,upload_bandwidth=%s,download_bandwidth=%s,status=%s,expiration_date=%s,filters=%s,filesystem=%s, additional_info=%s,description=%s,email=%s,updated_at=%s,upload_data_transfer=%s,download_data_transfer=%s, - total_data_transfer=%s,role_id=COALESCE((SELECT id from %s WHERE name=%s),%s),last_password_change=%s WHERE id = %s`, + total_data_transfer=%s,role_id=COALESCE((SELECT id from %s WHERE name=%s),%s),last_password_change=%s WHERE username = %s`, sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7], sqlPlaceholders[8], sqlPlaceholders[9], sqlPlaceholders[10], sqlPlaceholders[11], sqlPlaceholders[12], sqlPlaceholders[13], sqlPlaceholders[14], @@ -725,7 +725,7 @@ func getDeleteUserQuery(softDelete bool) string { return fmt.Sprintf(`UPDATE %s SET updated_at=%s,deleted_at=%s WHERE username = %s`, sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2]) } - return fmt.Sprintf(`DELETE FROM %s WHERE id = %s`, sqlTableUsers, sqlPlaceholders[0]) + return fmt.Sprintf(`DELETE FROM %s WHERE username = %s`, sqlTableUsers, sqlPlaceholders[0]) } func getRemoveSoftDeletedUserQuery() string { @@ -748,7 +748,7 @@ func getUpdateFolderQuery() string { } func getDeleteFolderQuery() string { - return fmt.Sprintf(`DELETE FROM %s WHERE id = %s`, sqlTableFolders, sqlPlaceholders[0]) + return fmt.Sprintf(`DELETE FROM %s WHERE name = %s`, sqlTableFolders, sqlPlaceholders[0]) } func getUpsertFolderQuery() string {