sql providers: remove unnecessary []byte to string conversion

always check affected rows for updates

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
Nicola Murino 2023-02-20 18:14:02 +01:00
parent a3fff56da5
commit 2c1319985d
No known key found for this signature in database
GPG key ID: 935D2952DEC4EECF
7 changed files with 98 additions and 77 deletions

2
go.mod
View file

@ -121,7 +121,7 @@ require (
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/klauspost/cpuid/v2 v2.2.3 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/kr/fs v0.1.0 // indirect
github.com/lestrrat-go/blackmagic v1.0.1 // indirect
github.com/lestrrat-go/httpcc v1.0.1 // indirect

3
go.sum
View file

@ -1401,8 +1401,9 @@ github.com/klauspost/compress v1.15.1/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47e
github.com/klauspost/compress v1.15.15 h1:EF27CXIuDsYJ6mmvtBRlEuB2UVOqHG1tAXgZ7yIO+lw=
github.com/klauspost/compress v1.15.15/go.mod h1:ZcK2JAFqKOpnBlxcLsJzYfrS9X1akm9fHZNnD9+Vo/4=
github.com/klauspost/cpuid/v2 v2.0.4/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.3 h1:sxCkb+qR91z4vsqw4vGGZlDgPz3G7gjaLyK3V8y70BU=
github.com/klauspost/cpuid/v2 v2.2.3/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
github.com/kolo/xmlrpc v0.0.0-20201022064351-38db28db192b/go.mod h1:pcaDhQK0/NJZEvtCO0qQPPropqV0sJOJ6YW7X+9kRwM=
github.com/kolo/xmlrpc v0.0.0-20220921171641-a4b6fa1dd06b/go.mod h1:pcaDhQK0/NJZEvtCO0qQPPropqV0sJOJ6YW7X+9kRwM=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=

View file

@ -230,7 +230,7 @@ func initializeMySQLProvider() error {
}
dbHandle, err := sql.Open("mysql", connString)
if err == nil {
providerLog(logger.LevelDebug, "mysql database handle created, connection string: %#v, pool size: %v",
providerLog(logger.LevelDebug, "mysql database handle created, connection string: %q, pool size: %v",
redactedConnString, config.PoolSize)
dbHandle.SetMaxOpenConns(config.PoolSize)
if config.PoolSize > 0 {
@ -242,7 +242,7 @@ func initializeMySQLProvider() error {
dbHandle.SetConnMaxIdleTime(120 * time.Second)
provider = &MySQLProvider{dbHandle: dbHandle}
} else {
providerLog(logger.LevelError, "error creating mysql database handler, connection string: %#v, error: %v",
providerLog(logger.LevelError, "error creating mysql database handler, connection string: %q, error: %v",
redactedConnString, err)
}
return err
@ -260,7 +260,7 @@ func getMySQLConnectionString(redactedPwd bool) (string, error) {
return "", err
}
}
connectionString = fmt.Sprintf("%v:%v@tcp([%v]:%v)/%v?charset=utf8mb4&interpolateParams=true&timeout=10s&parseTime=true&tls=%v&writeTimeout=60s&readTimeout=60s",
connectionString = fmt.Sprintf("%s:%s@tcp([%s]:%d)/%s?collation=utf8mb4_unicode_ci&interpolateParams=true&timeout=10s&parseTime=true&clientFoundRows=true&tls=%s&writeTimeout=60s&readTimeout=60s",
config.Username, password, config.Host, config.Port, config.Name, sslMode)
} else {
connectionString = config.ConnectionString

View file

@ -236,7 +236,7 @@ func initializePGSQLProvider() error {
var err error
dbHandle, err := sql.Open("pgx", getPGSQLConnectionString(false))
if err == nil {
providerLog(logger.LevelDebug, "postgres database handle created, connection string: %#v, pool size: %v",
providerLog(logger.LevelDebug, "postgres database handle created, connection string: %q, pool size: %d",
getPGSQLConnectionString(true), config.PoolSize)
dbHandle.SetMaxOpenConns(config.PoolSize)
if config.PoolSize > 0 {
@ -248,7 +248,7 @@ func initializePGSQLProvider() error {
dbHandle.SetConnMaxIdleTime(120 * time.Second)
provider = &PGSQLProvider{dbHandle: dbHandle}
} else {
providerLog(logger.LevelError, "error creating postgres database handler, connection string: %#v, error: %v",
providerLog(logger.LevelError, "error creating postgres database handler, connection string: %q, error: %v",
getPGSQLConnectionString(true), err)
}
return err
@ -261,13 +261,13 @@ func getPGSQLConnectionString(redactedPwd bool) string {
if redactedPwd && password != "" {
password = "[redacted]"
}
connectionString = fmt.Sprintf("host='%v' port=%v dbname='%v' user='%v' password='%v' sslmode=%v connect_timeout=10",
connectionString = fmt.Sprintf("host='%s' port=%d dbname='%s' user='%s' password='%s' sslmode=%s connect_timeout=10",
config.Host, config.Port, config.Name, config.Username, password, getSSLMode())
if config.RootCert != "" {
connectionString += fmt.Sprintf(" sslrootcert='%v'", config.RootCert)
connectionString += fmt.Sprintf(" sslrootcert='%s'", config.RootCert)
}
if config.ClientCert != "" && config.ClientKey != "" {
connectionString += fmt.Sprintf(" sslcert='%v' sslkey='%v'", config.ClientCert, config.ClientKey)
connectionString += fmt.Sprintf(" sslcert='%s' sslkey='%s'", config.ClientCert, config.ClientKey)
}
if config.DisableSNI {
connectionString += " sslsni=0"

View file

@ -118,11 +118,11 @@ func sqlCommonAddShare(share *Share, dbHandle *sql.DB) error {
if err != nil {
return err
}
allowFrom := ""
var allowFrom []byte
if len(share.AllowFrom) > 0 {
res, err := json.Marshal(share.AllowFrom)
if err == nil {
allowFrom = string(res)
allowFrom = res
}
}
@ -145,7 +145,7 @@ func sqlCommonAddShare(share *Share, dbHandle *sql.DB) error {
lastUseAt = share.LastUseAt
}
_, err = dbHandle.ExecContext(ctx, q, share.ShareID, share.Name, share.Description, share.Scope,
string(paths), createdAt, updatedAt, lastUseAt, share.ExpiresAt, share.Password,
paths, createdAt, updatedAt, lastUseAt, share.ExpiresAt, share.Password,
share.MaxTokens, usedTokens, allowFrom, user.ID)
return err
}
@ -161,11 +161,11 @@ func sqlCommonUpdateShare(share *Share, dbHandle *sql.DB) error {
return err
}
allowFrom := ""
var allowFrom []byte
if len(share.AllowFrom) > 0 {
res, err := json.Marshal(share.AllowFrom)
if err == nil {
allowFrom = string(res)
allowFrom = res
}
}
@ -184,6 +184,7 @@ func sqlCommonUpdateShare(share *Share, dbHandle *sql.DB) error {
q = getUpdateShareQuery()
}
var res sql.Result
if share.IsRestore {
if share.CreatedAt == 0 {
share.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
@ -191,15 +192,18 @@ func sqlCommonUpdateShare(share *Share, dbHandle *sql.DB) error {
if share.UpdatedAt == 0 {
share.UpdatedAt = share.CreatedAt
}
_, err = dbHandle.ExecContext(ctx, q, share.Name, share.Description, share.Scope, string(paths),
res, err = dbHandle.ExecContext(ctx, q, share.Name, share.Description, share.Scope, paths,
share.CreatedAt, share.UpdatedAt, share.LastUseAt, share.ExpiresAt, share.Password, share.MaxTokens,
share.UsedTokens, allowFrom, user.ID, share.ShareID)
} else {
_, err = dbHandle.ExecContext(ctx, q, share.Name, share.Description, share.Scope, string(paths),
res, err = dbHandle.ExecContext(ctx, q, share.Name, share.Description, share.Scope, paths,
util.GetTimeAsMsSinceEpoch(time.Now()), share.ExpiresAt, share.Password, share.MaxTokens,
allowFrom, user.ID, share.ShareID)
}
return err
if err != nil {
return err
}
return sqlCommonRequireRowAffected(res)
}
func sqlCommonDeleteShare(share Share, dbHandle *sql.DB) error {
@ -311,9 +315,12 @@ func sqlCommonUpdateAPIKey(apiKey *APIKey, dbHandle *sql.DB) error {
defer cancel()
q := getUpdateAPIKeyQuery()
_, err = dbHandle.ExecContext(ctx, q, apiKey.Name, apiKey.Scope, apiKey.ExpiresAt, userID, adminID,
res, err := dbHandle.ExecContext(ctx, q, apiKey.Name, apiKey.Scope, apiKey.ExpiresAt, userID, adminID,
apiKey.Description, util.GetTimeAsMsSinceEpoch(time.Now()), apiKey.KeyID)
return err
if err != nil {
return err
}
return sqlCommonRequireRowAffected(res)
}
func sqlCommonDeleteAPIKey(apiKey APIKey, dbHandle *sql.DB) error {
@ -436,8 +443,8 @@ func sqlCommonAddAdmin(admin *Admin, dbHandle *sql.DB) error {
return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
q := getAddAdminQuery(admin.Role)
_, err = tx.ExecContext(ctx, q, admin.Username, admin.Password, admin.Status, admin.Email, string(perms),
string(filters), admin.AdditionalInfo, admin.Description, util.GetTimeAsMsSinceEpoch(time.Now()),
_, err = tx.ExecContext(ctx, q, admin.Username, admin.Password, admin.Status, admin.Email, perms,
filters, admin.AdditionalInfo, admin.Description, util.GetTimeAsMsSinceEpoch(time.Now()),
util.GetTimeAsMsSinceEpoch(time.Now()), admin.Role)
if err != nil {
return err
@ -467,7 +474,7 @@ func sqlCommonUpdateAdmin(admin *Admin, dbHandle *sql.DB) error {
return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
q := getUpdateAdminQuery(admin.Role)
_, err = tx.ExecContext(ctx, q, admin.Password, admin.Status, admin.Email, string(perms), string(filters),
_, err = tx.ExecContext(ctx, q, admin.Password, admin.Status, admin.Email, perms, filters,
admin.AdditionalInfo, admin.Description, util.GetTimeAsMsSinceEpoch(time.Now()), admin.Role, admin.Username)
if err != nil {
return err
@ -753,9 +760,12 @@ func sqlCommonUpdateIPListEntry(entry *IPListEntry, dbHandle *sql.DB) error {
defer cancel()
q := getUpdateIPListEntryQuery()
_, err := dbHandle.ExecContext(ctx, q, entry.Mode, entry.Protocols, entry.Description,
res, err := dbHandle.ExecContext(ctx, q, entry.Mode, entry.Protocols, entry.Description,
util.GetTimeAsMsSinceEpoch(time.Now()), entry.Type, entry.IPOrNet)
return err
if err != nil {
return err
}
return sqlCommonRequireRowAffected(res)
}
func sqlCommonDeleteIPListEntry(entry IPListEntry, softDelete bool, dbHandle *sql.DB) error {
@ -876,8 +886,11 @@ func sqlCommonUpdateRole(role *Role, dbHandle *sql.DB) error {
defer cancel()
q := getUpdateRoleQuery()
_, err := dbHandle.ExecContext(ctx, q, role.Description, util.GetTimeAsMsSinceEpoch(time.Now()), role.Name)
return err
res, err := dbHandle.ExecContext(ctx, q, role.Description, util.GetTimeAsMsSinceEpoch(time.Now()), role.Name)
if err != nil {
return err
}
return sqlCommonRequireRowAffected(res)
}
func sqlCommonDeleteRole(role Role, dbHandle *sql.DB) error {
@ -1069,7 +1082,7 @@ func sqlCommonAddGroup(group *Group, dbHandle *sql.DB) error {
return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
q := getAddGroupQuery()
_, err := tx.ExecContext(ctx, q, group.Name, group.Description, util.GetTimeAsMsSinceEpoch(time.Now()),
util.GetTimeAsMsSinceEpoch(time.Now()), string(settings))
util.GetTimeAsMsSinceEpoch(time.Now()), settings)
if err != nil {
return err
}
@ -1170,7 +1183,7 @@ func sqlCommonValidateUserAndPubKey(username string, pubKey []byte, isSSHCert bo
func sqlCommonCheckAvailability(dbHandle *sql.DB) (err error) {
defer func() {
if r := recover(); r != nil {
providerLog(logger.LevelError, "panic in check provider availability, stack trace: %v", string(debug.Stack()))
providerLog(logger.LevelError, "panic in check provider availability, stack trace: %s", string(debug.Stack()))
err = errors.New("unable to check provider status")
}
}()
@ -1189,10 +1202,10 @@ func sqlCommonUpdateTransferQuota(username string, uploadSize, downloadSize int6
q := getUpdateTransferQuotaQuery(reset)
_, err := dbHandle.ExecContext(ctx, q, uploadSize, downloadSize, util.GetTimeAsMsSinceEpoch(time.Now()), username)
if err == nil {
providerLog(logger.LevelDebug, "transfer quota updated for user %#v, ul increment: %v dl increment: %v is reset? %v",
providerLog(logger.LevelDebug, "transfer quota updated for user %q, ul increment: %d dl increment: %d is reset? %t",
username, uploadSize, downloadSize, reset)
} else {
providerLog(logger.LevelError, "error updating quota for user %#v: %v", username, err)
providerLog(logger.LevelError, "error updating quota for user %q: %v", username, err)
}
return err
}
@ -1204,10 +1217,10 @@ func sqlCommonUpdateQuota(username string, filesAdd int, sizeAdd int64, reset bo
q := getUpdateQuotaQuery(reset)
_, err := dbHandle.ExecContext(ctx, q, sizeAdd, filesAdd, util.GetTimeAsMsSinceEpoch(time.Now()), username)
if err == nil {
providerLog(logger.LevelDebug, "quota updated for user %#v, files increment: %v size increment: %v is reset? %v",
providerLog(logger.LevelDebug, "quota updated for user %q, files increment: %d size increment: %d is reset? %t",
username, filesAdd, sizeAdd, reset)
} else {
providerLog(logger.LevelError, "error updating quota for user %#v: %v", username, err)
providerLog(logger.LevelError, "error updating quota for user %q: %v", username, err)
}
return err
}
@ -1234,9 +1247,9 @@ func sqlCommonUpdateShareLastUse(shareID string, numTokens int, dbHandle *sql.DB
q := getUpdateShareLastUseQuery()
_, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), numTokens, shareID)
if err == nil {
providerLog(logger.LevelDebug, "last use updated for shared object %#v", shareID)
providerLog(logger.LevelDebug, "last use updated for shared object %q", shareID)
} else {
providerLog(logger.LevelWarn, "error updating last use for shared object %#v: %v", shareID, err)
providerLog(logger.LevelWarn, "error updating last use for shared object %q: %v", shareID, err)
}
return err
}
@ -1248,9 +1261,9 @@ func sqlCommonUpdateAPIKeyLastUse(keyID string, dbHandle *sql.DB) error {
q := getUpdateAPIKeyLastUseQuery()
_, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), keyID)
if err == nil {
providerLog(logger.LevelDebug, "last use updated for key %#v", keyID)
providerLog(logger.LevelDebug, "last use updated for key %q", keyID)
} else {
providerLog(logger.LevelWarn, "error updating last use for key %#v: %v", keyID, err)
providerLog(logger.LevelWarn, "error updating last use for key %q: %v", keyID, err)
}
return err
}
@ -1262,9 +1275,9 @@ func sqlCommonUpdateAdminLastLogin(username string, dbHandle *sql.DB) error {
q := getUpdateAdminLastLoginQuery()
_, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), username)
if err == nil {
providerLog(logger.LevelDebug, "last login updated for admin %#v", username)
providerLog(logger.LevelDebug, "last login updated for admin %q", username)
} else {
providerLog(logger.LevelWarn, "error updating last login for admin %#v: %v", username, err)
providerLog(logger.LevelWarn, "error updating last login for admin %q: %v", username, err)
}
return err
}
@ -1276,9 +1289,9 @@ func sqlCommonSetUpdatedAt(username string, dbHandle *sql.DB) {
q := getSetUpdateAtQuery()
_, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), username)
if err == nil {
providerLog(logger.LevelDebug, "updated_at set for user %#v", username)
providerLog(logger.LevelDebug, "updated_at set for user %q", username)
} else {
providerLog(logger.LevelWarn, "error setting updated_at for user %#v: %v", username, err)
providerLog(logger.LevelWarn, "error setting updated_at for user %q: %v", username, err)
}
}
@ -1313,9 +1326,9 @@ func sqlCommonUpdateLastLogin(username string, dbHandle *sql.DB) error {
q := getUpdateLastLoginQuery()
_, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), username)
if err == nil {
providerLog(logger.LevelDebug, "last login updated for user %#v", username)
providerLog(logger.LevelDebug, "last login updated for user %q", username)
} else {
providerLog(logger.LevelWarn, "error updating last login for user %#v: %v", username, err)
providerLog(logger.LevelWarn, "error updating last login for user %q: %v", username, err)
}
return err
}
@ -1353,9 +1366,9 @@ func sqlCommonAddUser(user *User, dbHandle *sql.DB) error {
}
}
q := getAddUserQuery(user.Role)
_, err := tx.ExecContext(ctx, q, user.Username, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID,
user.MaxSessions, user.QuotaSize, user.QuotaFiles, string(permissions), user.UploadBandwidth,
user.DownloadBandwidth, user.Status, user.ExpirationDate, string(filters), string(fsConfig), user.AdditionalInfo,
_, err := tx.ExecContext(ctx, q, user.Username, user.Password, publicKeys, user.HomeDir, user.UID, user.GID,
user.MaxSessions, user.QuotaSize, user.QuotaFiles, permissions, user.UploadBandwidth,
user.DownloadBandwidth, user.Status, user.ExpirationDate, filters, fsConfig, user.AdditionalInfo,
user.Description, user.Email, util.GetTimeAsMsSinceEpoch(time.Now()), util.GetTimeAsMsSinceEpoch(time.Now()),
user.UploadDataTransfer, user.DownloadDataTransfer, user.TotalDataTransfer, user.Role, user.LastPasswordChange)
if err != nil {
@ -1373,8 +1386,11 @@ func sqlCommonUpdateUserPassword(username, password string, dbHandle *sql.DB) er
defer cancel()
q := getUpdateUserPasswordQuery()
_, err := dbHandle.ExecContext(ctx, q, password, username)
return err
res, err := dbHandle.ExecContext(ctx, q, password, username)
if err != nil {
return err
}
return sqlCommonRequireRowAffected(res)
}
func sqlCommonUpdateUser(user *User, dbHandle *sql.DB) error {
@ -1404,14 +1420,17 @@ func sqlCommonUpdateUser(user *User, dbHandle *sql.DB) error {
return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
q := getUpdateUserQuery(user.Role)
_, err := tx.ExecContext(ctx, q, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions,
user.QuotaSize, user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.Status,
user.ExpirationDate, string(filters), string(fsConfig), user.AdditionalInfo, user.Description, user.Email,
res, err := tx.ExecContext(ctx, q, user.Password, publicKeys, user.HomeDir, user.UID, user.GID, user.MaxSessions,
user.QuotaSize, user.QuotaFiles, permissions, user.UploadBandwidth, user.DownloadBandwidth, user.Status,
user.ExpirationDate, filters, fsConfig, user.AdditionalInfo, user.Description, user.Email,
util.GetTimeAsMsSinceEpoch(time.Now()), user.UploadDataTransfer, user.DownloadDataTransfer, user.TotalDataTransfer,
user.Role, user.LastPasswordChange, user.ID)
user.Role, user.LastPasswordChange, user.Username)
if err != nil {
return err
}
if err := sqlCommonRequireRowAffected(res); err != nil {
return err
}
if err := generateUserVirtualFoldersMapping(ctx, user, tx); err != nil {
return err
}
@ -1440,7 +1459,7 @@ func sqlCommonDeleteUser(user User, softDelete bool, dbHandle *sql.DB) error {
return sqlCommonRequireRowAffected(res)
})
}
res, err := dbHandle.ExecContext(ctx, q, user.ID)
res, err := dbHandle.ExecContext(ctx, q, user.Username)
if err != nil {
return err
}
@ -2292,7 +2311,7 @@ func sqlCommonAddOrUpdateFolder(ctx context.Context, baseFolder *vfs.BaseVirtual
}
q := getUpsertFolderQuery()
_, err = dbHandle.ExecContext(ctx, q, baseFolder.MappedPath, usedQuotaSize, usedQuotaFiles,
lastQuotaUpdate, baseFolder.Name, baseFolder.Description, string(fsConfig))
lastQuotaUpdate, baseFolder.Name, baseFolder.Description, fsConfig)
return err
}
@ -2310,7 +2329,7 @@ func sqlCommonAddFolder(folder *vfs.BaseVirtualFolder, dbHandle sqlQuerier) erro
q := getAddFolderQuery()
_, err = dbHandle.ExecContext(ctx, q, folder.MappedPath, folder.UsedQuotaSize, folder.UsedQuotaFiles,
folder.LastQuotaUpdate, folder.Name, folder.Description, string(fsConfig))
folder.LastQuotaUpdate, folder.Name, folder.Description, fsConfig)
return err
}
@ -2327,8 +2346,11 @@ func sqlCommonUpdateFolder(folder *vfs.BaseVirtualFolder, dbHandle sqlQuerier) e
defer cancel()
q := getUpdateFolderQuery()
_, err = dbHandle.ExecContext(ctx, q, folder.MappedPath, folder.Description, string(fsConfig), folder.Name)
return err
res, err := dbHandle.ExecContext(ctx, q, folder.MappedPath, folder.Description, fsConfig, folder.Name)
if err != nil {
return err
}
return sqlCommonRequireRowAffected(res)
}
func sqlCommonDeleteFolder(folder vfs.BaseVirtualFolder, dbHandle sqlQuerier) error {
@ -2336,7 +2358,7 @@ func sqlCommonDeleteFolder(folder vfs.BaseVirtualFolder, dbHandle sqlQuerier) er
defer cancel()
q := getDeleteFolderQuery()
res, err := dbHandle.ExecContext(ctx, q, folder.ID)
res, err := dbHandle.ExecContext(ctx, q, folder.Name)
if err != nil {
return err
}
@ -2485,7 +2507,7 @@ func sqlCommonAddAdminGroupMapping(ctx context.Context, username, groupName stri
return err
}
q := getAddAdminGroupMappingQuery()
_, err = dbHandle.ExecContext(ctx, q, username, groupName, string(options))
_, err = dbHandle.ExecContext(ctx, q, username, groupName, options)
return err
}
@ -3082,7 +3104,7 @@ func sqlCommonUpdateFolderQuota(name string, filesAdd int, sizeAdd int64, reset
q := getUpdateFolderQuotaQuery(reset)
_, err := dbHandle.ExecContext(ctx, q, sizeAdd, filesAdd, util.GetTimeAsMsSinceEpoch(time.Now()), name)
if err == nil {
providerLog(logger.LevelDebug, "quota updated for folder %#v, files increment: %v size increment: %v is reset? %v",
providerLog(logger.LevelDebug, "quota updated for folder %q, files increment: %d size increment: %d is reset? %t",
name, filesAdd, sizeAdd, reset)
} else {
providerLog(logger.LevelWarn, "error updating quota for folder %#v: %v", name, err)
@ -3337,7 +3359,7 @@ func generateEventRuleActionsMapping(ctx context.Context, rule *EventRule, dbHan
return err
}
q = getAddRuleActionMappingQuery()
_, err = dbHandle.ExecContext(ctx, q, rule.Name, action.Name, action.Order, string(options))
_, err = dbHandle.ExecContext(ctx, q, rule.Name, action.Name, action.Order, options)
if err != nil {
return err
}
@ -3444,7 +3466,7 @@ func sqlCommonAddEventAction(action *BaseEventAction, dbHandle *sql.DB) error {
if err != nil {
return err
}
_, err = dbHandle.ExecContext(ctx, q, action.Name, action.Description, action.Type, string(options))
_, err = dbHandle.ExecContext(ctx, q, action.Name, action.Description, action.Type, options)
return err
}
@ -3461,10 +3483,13 @@ func sqlCommonUpdateEventAction(action *BaseEventAction, dbHandle *sql.DB) error
return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
q := getUpdateEventActionQuery()
_, err = tx.ExecContext(ctx, q, action.Description, action.Type, string(options), action.Name)
res, err := tx.ExecContext(ctx, q, action.Description, action.Type, options, action.Name)
if err != nil {
return err
}
if err := sqlCommonRequireRowAffected(res); err != nil {
return err
}
q = getUpdateRulesTimestampQuery()
_, err = tx.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), action.ID)
return err
@ -3609,7 +3634,7 @@ func sqlCommonAddEventRule(rule *EventRule, dbHandle *sql.DB) error {
}
q := getAddEventRuleQuery()
_, err := tx.ExecContext(ctx, q, rule.Name, rule.Description, util.GetTimeAsMsSinceEpoch(time.Now()),
util.GetTimeAsMsSinceEpoch(time.Now()), rule.Trigger, string(conditions), rule.Status)
util.GetTimeAsMsSinceEpoch(time.Now()), rule.Trigger, conditions, rule.Status)
if err != nil {
return err
}
@ -3631,7 +3656,7 @@ func sqlCommonUpdateEventRule(rule *EventRule, dbHandle *sql.DB) error {
return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
q := getUpdateEventRuleQuery()
_, err := tx.ExecContext(ctx, q, rule.Description, util.GetTimeAsMsSinceEpoch(time.Now()),
rule.Trigger, string(conditions), rule.Status, rule.Name)
rule.Trigger, conditions, rule.Status, rule.Name)
if err != nil {
return err
}
@ -3744,7 +3769,7 @@ func sqlCommonAddNode(dbHandle *sql.DB) error {
defer cancel()
q := getAddNodeQuery()
_, err = dbHandle.ExecContext(ctx, q, currentNode.Name, string(data), util.GetTimeAsMsSinceEpoch(time.Now()),
_, err = dbHandle.ExecContext(ctx, q, currentNode.Name, data, util.GetTimeAsMsSinceEpoch(time.Now()),
util.GetTimeAsMsSinceEpoch(time.Now()))
if err != nil {
return fmt.Errorf("unable to register cluster node: %w", err)
@ -3857,9 +3882,6 @@ func sqlCommonSetConfigs(configs *Configs, dbHandle *sql.DB) error {
if err != nil {
return err
}
if config.Driver == MySQLDataProviderName {
return nil
}
return sqlCommonRequireRowAffected(res)
}
@ -3884,8 +3906,6 @@ func sqlCommonGetDatabaseVersion(dbHandle sqlQuerier, showInitWarn bool) (schema
}
func sqlCommonRequireRowAffected(res sql.Result) error {
// MariaDB/MySQL returns 0 rows affected for updates that don't change anything
// so we don't check rows affected for updates
affected, err := res.RowsAffected()
if err == nil && affected == 0 {
return util.NewRecordNotFoundError(sql.ErrNoRows.Error())

View file

@ -221,17 +221,17 @@ func initializeSQLiteProvider(basePath string) error {
if !filepath.IsAbs(dbPath) {
dbPath = filepath.Join(basePath, dbPath)
}
connectionString = fmt.Sprintf("file:%v?cache=shared&_foreign_keys=1", dbPath)
connectionString = fmt.Sprintf("file:%s?cache=shared&_foreign_keys=1", dbPath)
} else {
connectionString = config.ConnectionString
}
dbHandle, err := sql.Open("sqlite3", connectionString)
if err == nil {
providerLog(logger.LevelDebug, "sqlite database handle created, connection string: %#v", connectionString)
providerLog(logger.LevelDebug, "sqlite database handle created, connection string: %q", connectionString)
dbHandle.SetMaxOpenConns(1)
provider = &SQLiteProvider{dbHandle: dbHandle}
} else {
providerLog(logger.LevelError, "error creating sqlite database handler, connection string: %#v, error: %v",
providerLog(logger.LevelError, "error creating sqlite database handler, connection string: %q, error: %v",
connectionString, err)
}
return err

View file

@ -706,7 +706,7 @@ func getUpdateUserQuery(role string) string {
return fmt.Sprintf(`UPDATE %s SET password=%s,public_keys=%s,home_dir=%s,uid=%s,gid=%s,max_sessions=%s,quota_size=%s,
quota_files=%s,permissions=%s,upload_bandwidth=%s,download_bandwidth=%s,status=%s,expiration_date=%s,filters=%s,filesystem=%s,
additional_info=%s,description=%s,email=%s,updated_at=%s,upload_data_transfer=%s,download_data_transfer=%s,
total_data_transfer=%s,role_id=COALESCE((SELECT id from %s WHERE name=%s),%s),last_password_change=%s WHERE id = %s`,
total_data_transfer=%s,role_id=COALESCE((SELECT id from %s WHERE name=%s),%s),last_password_change=%s WHERE username = %s`,
sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4],
sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7], sqlPlaceholders[8], sqlPlaceholders[9],
sqlPlaceholders[10], sqlPlaceholders[11], sqlPlaceholders[12], sqlPlaceholders[13], sqlPlaceholders[14],
@ -725,7 +725,7 @@ func getDeleteUserQuery(softDelete bool) string {
return fmt.Sprintf(`UPDATE %s SET updated_at=%s,deleted_at=%s WHERE username = %s`,
sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2])
}
return fmt.Sprintf(`DELETE FROM %s WHERE id = %s`, sqlTableUsers, sqlPlaceholders[0])
return fmt.Sprintf(`DELETE FROM %s WHERE username = %s`, sqlTableUsers, sqlPlaceholders[0])
}
func getRemoveSoftDeletedUserQuery() string {
@ -748,7 +748,7 @@ func getUpdateFolderQuery() string {
}
func getDeleteFolderQuery() string {
return fmt.Sprintf(`DELETE FROM %s WHERE id = %s`, sqlTableFolders, sqlPlaceholders[0])
return fmt.Sprintf(`DELETE FROM %s WHERE name = %s`, sqlTableFolders, sqlPlaceholders[0])
}
func getUpsertFolderQuery() string {