sql providers: remove unnecessary []byte to string conversion
always check affected rows for updates Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
parent
a3fff56da5
commit
2c1319985d
7 changed files with 98 additions and 77 deletions
2
go.mod
2
go.mod
|
@ -121,7 +121,7 @@ require (
|
|||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
|
||||
github.com/jmespath/go-jmespath v0.4.0 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.3 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
|
||||
github.com/kr/fs v0.1.0 // indirect
|
||||
github.com/lestrrat-go/blackmagic v1.0.1 // indirect
|
||||
github.com/lestrrat-go/httpcc v1.0.1 // indirect
|
||||
|
|
3
go.sum
3
go.sum
|
@ -1401,8 +1401,9 @@ github.com/klauspost/compress v1.15.1/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47e
|
|||
github.com/klauspost/compress v1.15.15 h1:EF27CXIuDsYJ6mmvtBRlEuB2UVOqHG1tAXgZ7yIO+lw=
|
||||
github.com/klauspost/compress v1.15.15/go.mod h1:ZcK2JAFqKOpnBlxcLsJzYfrS9X1akm9fHZNnD9+Vo/4=
|
||||
github.com/klauspost/cpuid/v2 v2.0.4/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
github.com/klauspost/cpuid/v2 v2.2.3 h1:sxCkb+qR91z4vsqw4vGGZlDgPz3G7gjaLyK3V8y70BU=
|
||||
github.com/klauspost/cpuid/v2 v2.2.3/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
|
||||
github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
|
||||
github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
|
||||
github.com/kolo/xmlrpc v0.0.0-20201022064351-38db28db192b/go.mod h1:pcaDhQK0/NJZEvtCO0qQPPropqV0sJOJ6YW7X+9kRwM=
|
||||
github.com/kolo/xmlrpc v0.0.0-20220921171641-a4b6fa1dd06b/go.mod h1:pcaDhQK0/NJZEvtCO0qQPPropqV0sJOJ6YW7X+9kRwM=
|
||||
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||
|
|
|
@ -230,7 +230,7 @@ func initializeMySQLProvider() error {
|
|||
}
|
||||
dbHandle, err := sql.Open("mysql", connString)
|
||||
if err == nil {
|
||||
providerLog(logger.LevelDebug, "mysql database handle created, connection string: %#v, pool size: %v",
|
||||
providerLog(logger.LevelDebug, "mysql database handle created, connection string: %q, pool size: %v",
|
||||
redactedConnString, config.PoolSize)
|
||||
dbHandle.SetMaxOpenConns(config.PoolSize)
|
||||
if config.PoolSize > 0 {
|
||||
|
@ -242,7 +242,7 @@ func initializeMySQLProvider() error {
|
|||
dbHandle.SetConnMaxIdleTime(120 * time.Second)
|
||||
provider = &MySQLProvider{dbHandle: dbHandle}
|
||||
} else {
|
||||
providerLog(logger.LevelError, "error creating mysql database handler, connection string: %#v, error: %v",
|
||||
providerLog(logger.LevelError, "error creating mysql database handler, connection string: %q, error: %v",
|
||||
redactedConnString, err)
|
||||
}
|
||||
return err
|
||||
|
@ -260,7 +260,7 @@ func getMySQLConnectionString(redactedPwd bool) (string, error) {
|
|||
return "", err
|
||||
}
|
||||
}
|
||||
connectionString = fmt.Sprintf("%v:%v@tcp([%v]:%v)/%v?charset=utf8mb4&interpolateParams=true&timeout=10s&parseTime=true&tls=%v&writeTimeout=60s&readTimeout=60s",
|
||||
connectionString = fmt.Sprintf("%s:%s@tcp([%s]:%d)/%s?collation=utf8mb4_unicode_ci&interpolateParams=true&timeout=10s&parseTime=true&clientFoundRows=true&tls=%s&writeTimeout=60s&readTimeout=60s",
|
||||
config.Username, password, config.Host, config.Port, config.Name, sslMode)
|
||||
} else {
|
||||
connectionString = config.ConnectionString
|
||||
|
|
|
@ -236,7 +236,7 @@ func initializePGSQLProvider() error {
|
|||
var err error
|
||||
dbHandle, err := sql.Open("pgx", getPGSQLConnectionString(false))
|
||||
if err == nil {
|
||||
providerLog(logger.LevelDebug, "postgres database handle created, connection string: %#v, pool size: %v",
|
||||
providerLog(logger.LevelDebug, "postgres database handle created, connection string: %q, pool size: %d",
|
||||
getPGSQLConnectionString(true), config.PoolSize)
|
||||
dbHandle.SetMaxOpenConns(config.PoolSize)
|
||||
if config.PoolSize > 0 {
|
||||
|
@ -248,7 +248,7 @@ func initializePGSQLProvider() error {
|
|||
dbHandle.SetConnMaxIdleTime(120 * time.Second)
|
||||
provider = &PGSQLProvider{dbHandle: dbHandle}
|
||||
} else {
|
||||
providerLog(logger.LevelError, "error creating postgres database handler, connection string: %#v, error: %v",
|
||||
providerLog(logger.LevelError, "error creating postgres database handler, connection string: %q, error: %v",
|
||||
getPGSQLConnectionString(true), err)
|
||||
}
|
||||
return err
|
||||
|
@ -261,13 +261,13 @@ func getPGSQLConnectionString(redactedPwd bool) string {
|
|||
if redactedPwd && password != "" {
|
||||
password = "[redacted]"
|
||||
}
|
||||
connectionString = fmt.Sprintf("host='%v' port=%v dbname='%v' user='%v' password='%v' sslmode=%v connect_timeout=10",
|
||||
connectionString = fmt.Sprintf("host='%s' port=%d dbname='%s' user='%s' password='%s' sslmode=%s connect_timeout=10",
|
||||
config.Host, config.Port, config.Name, config.Username, password, getSSLMode())
|
||||
if config.RootCert != "" {
|
||||
connectionString += fmt.Sprintf(" sslrootcert='%v'", config.RootCert)
|
||||
connectionString += fmt.Sprintf(" sslrootcert='%s'", config.RootCert)
|
||||
}
|
||||
if config.ClientCert != "" && config.ClientKey != "" {
|
||||
connectionString += fmt.Sprintf(" sslcert='%v' sslkey='%v'", config.ClientCert, config.ClientKey)
|
||||
connectionString += fmt.Sprintf(" sslcert='%s' sslkey='%s'", config.ClientCert, config.ClientKey)
|
||||
}
|
||||
if config.DisableSNI {
|
||||
connectionString += " sslsni=0"
|
||||
|
|
|
@ -118,11 +118,11 @@ func sqlCommonAddShare(share *Share, dbHandle *sql.DB) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
allowFrom := ""
|
||||
var allowFrom []byte
|
||||
if len(share.AllowFrom) > 0 {
|
||||
res, err := json.Marshal(share.AllowFrom)
|
||||
if err == nil {
|
||||
allowFrom = string(res)
|
||||
allowFrom = res
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -145,7 +145,7 @@ func sqlCommonAddShare(share *Share, dbHandle *sql.DB) error {
|
|||
lastUseAt = share.LastUseAt
|
||||
}
|
||||
_, err = dbHandle.ExecContext(ctx, q, share.ShareID, share.Name, share.Description, share.Scope,
|
||||
string(paths), createdAt, updatedAt, lastUseAt, share.ExpiresAt, share.Password,
|
||||
paths, createdAt, updatedAt, lastUseAt, share.ExpiresAt, share.Password,
|
||||
share.MaxTokens, usedTokens, allowFrom, user.ID)
|
||||
return err
|
||||
}
|
||||
|
@ -161,11 +161,11 @@ func sqlCommonUpdateShare(share *Share, dbHandle *sql.DB) error {
|
|||
return err
|
||||
}
|
||||
|
||||
allowFrom := ""
|
||||
var allowFrom []byte
|
||||
if len(share.AllowFrom) > 0 {
|
||||
res, err := json.Marshal(share.AllowFrom)
|
||||
if err == nil {
|
||||
allowFrom = string(res)
|
||||
allowFrom = res
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -184,6 +184,7 @@ func sqlCommonUpdateShare(share *Share, dbHandle *sql.DB) error {
|
|||
q = getUpdateShareQuery()
|
||||
}
|
||||
|
||||
var res sql.Result
|
||||
if share.IsRestore {
|
||||
if share.CreatedAt == 0 {
|
||||
share.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
|
||||
|
@ -191,15 +192,18 @@ func sqlCommonUpdateShare(share *Share, dbHandle *sql.DB) error {
|
|||
if share.UpdatedAt == 0 {
|
||||
share.UpdatedAt = share.CreatedAt
|
||||
}
|
||||
_, err = dbHandle.ExecContext(ctx, q, share.Name, share.Description, share.Scope, string(paths),
|
||||
res, err = dbHandle.ExecContext(ctx, q, share.Name, share.Description, share.Scope, paths,
|
||||
share.CreatedAt, share.UpdatedAt, share.LastUseAt, share.ExpiresAt, share.Password, share.MaxTokens,
|
||||
share.UsedTokens, allowFrom, user.ID, share.ShareID)
|
||||
} else {
|
||||
_, err = dbHandle.ExecContext(ctx, q, share.Name, share.Description, share.Scope, string(paths),
|
||||
res, err = dbHandle.ExecContext(ctx, q, share.Name, share.Description, share.Scope, paths,
|
||||
util.GetTimeAsMsSinceEpoch(time.Now()), share.ExpiresAt, share.Password, share.MaxTokens,
|
||||
allowFrom, user.ID, share.ShareID)
|
||||
}
|
||||
return err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sqlCommonRequireRowAffected(res)
|
||||
}
|
||||
|
||||
func sqlCommonDeleteShare(share Share, dbHandle *sql.DB) error {
|
||||
|
@ -311,9 +315,12 @@ func sqlCommonUpdateAPIKey(apiKey *APIKey, dbHandle *sql.DB) error {
|
|||
defer cancel()
|
||||
|
||||
q := getUpdateAPIKeyQuery()
|
||||
_, err = dbHandle.ExecContext(ctx, q, apiKey.Name, apiKey.Scope, apiKey.ExpiresAt, userID, adminID,
|
||||
res, err := dbHandle.ExecContext(ctx, q, apiKey.Name, apiKey.Scope, apiKey.ExpiresAt, userID, adminID,
|
||||
apiKey.Description, util.GetTimeAsMsSinceEpoch(time.Now()), apiKey.KeyID)
|
||||
return err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sqlCommonRequireRowAffected(res)
|
||||
}
|
||||
|
||||
func sqlCommonDeleteAPIKey(apiKey APIKey, dbHandle *sql.DB) error {
|
||||
|
@ -436,8 +443,8 @@ func sqlCommonAddAdmin(admin *Admin, dbHandle *sql.DB) error {
|
|||
|
||||
return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
|
||||
q := getAddAdminQuery(admin.Role)
|
||||
_, err = tx.ExecContext(ctx, q, admin.Username, admin.Password, admin.Status, admin.Email, string(perms),
|
||||
string(filters), admin.AdditionalInfo, admin.Description, util.GetTimeAsMsSinceEpoch(time.Now()),
|
||||
_, err = tx.ExecContext(ctx, q, admin.Username, admin.Password, admin.Status, admin.Email, perms,
|
||||
filters, admin.AdditionalInfo, admin.Description, util.GetTimeAsMsSinceEpoch(time.Now()),
|
||||
util.GetTimeAsMsSinceEpoch(time.Now()), admin.Role)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -467,7 +474,7 @@ func sqlCommonUpdateAdmin(admin *Admin, dbHandle *sql.DB) error {
|
|||
|
||||
return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
|
||||
q := getUpdateAdminQuery(admin.Role)
|
||||
_, err = tx.ExecContext(ctx, q, admin.Password, admin.Status, admin.Email, string(perms), string(filters),
|
||||
_, err = tx.ExecContext(ctx, q, admin.Password, admin.Status, admin.Email, perms, filters,
|
||||
admin.AdditionalInfo, admin.Description, util.GetTimeAsMsSinceEpoch(time.Now()), admin.Role, admin.Username)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -753,9 +760,12 @@ func sqlCommonUpdateIPListEntry(entry *IPListEntry, dbHandle *sql.DB) error {
|
|||
defer cancel()
|
||||
|
||||
q := getUpdateIPListEntryQuery()
|
||||
_, err := dbHandle.ExecContext(ctx, q, entry.Mode, entry.Protocols, entry.Description,
|
||||
res, err := dbHandle.ExecContext(ctx, q, entry.Mode, entry.Protocols, entry.Description,
|
||||
util.GetTimeAsMsSinceEpoch(time.Now()), entry.Type, entry.IPOrNet)
|
||||
return err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sqlCommonRequireRowAffected(res)
|
||||
}
|
||||
|
||||
func sqlCommonDeleteIPListEntry(entry IPListEntry, softDelete bool, dbHandle *sql.DB) error {
|
||||
|
@ -876,8 +886,11 @@ func sqlCommonUpdateRole(role *Role, dbHandle *sql.DB) error {
|
|||
defer cancel()
|
||||
|
||||
q := getUpdateRoleQuery()
|
||||
_, err := dbHandle.ExecContext(ctx, q, role.Description, util.GetTimeAsMsSinceEpoch(time.Now()), role.Name)
|
||||
return err
|
||||
res, err := dbHandle.ExecContext(ctx, q, role.Description, util.GetTimeAsMsSinceEpoch(time.Now()), role.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sqlCommonRequireRowAffected(res)
|
||||
}
|
||||
|
||||
func sqlCommonDeleteRole(role Role, dbHandle *sql.DB) error {
|
||||
|
@ -1069,7 +1082,7 @@ func sqlCommonAddGroup(group *Group, dbHandle *sql.DB) error {
|
|||
return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
|
||||
q := getAddGroupQuery()
|
||||
_, err := tx.ExecContext(ctx, q, group.Name, group.Description, util.GetTimeAsMsSinceEpoch(time.Now()),
|
||||
util.GetTimeAsMsSinceEpoch(time.Now()), string(settings))
|
||||
util.GetTimeAsMsSinceEpoch(time.Now()), settings)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -1170,7 +1183,7 @@ func sqlCommonValidateUserAndPubKey(username string, pubKey []byte, isSSHCert bo
|
|||
func sqlCommonCheckAvailability(dbHandle *sql.DB) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
providerLog(logger.LevelError, "panic in check provider availability, stack trace: %v", string(debug.Stack()))
|
||||
providerLog(logger.LevelError, "panic in check provider availability, stack trace: %s", string(debug.Stack()))
|
||||
err = errors.New("unable to check provider status")
|
||||
}
|
||||
}()
|
||||
|
@ -1189,10 +1202,10 @@ func sqlCommonUpdateTransferQuota(username string, uploadSize, downloadSize int6
|
|||
q := getUpdateTransferQuotaQuery(reset)
|
||||
_, err := dbHandle.ExecContext(ctx, q, uploadSize, downloadSize, util.GetTimeAsMsSinceEpoch(time.Now()), username)
|
||||
if err == nil {
|
||||
providerLog(logger.LevelDebug, "transfer quota updated for user %#v, ul increment: %v dl increment: %v is reset? %v",
|
||||
providerLog(logger.LevelDebug, "transfer quota updated for user %q, ul increment: %d dl increment: %d is reset? %t",
|
||||
username, uploadSize, downloadSize, reset)
|
||||
} else {
|
||||
providerLog(logger.LevelError, "error updating quota for user %#v: %v", username, err)
|
||||
providerLog(logger.LevelError, "error updating quota for user %q: %v", username, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
@ -1204,10 +1217,10 @@ func sqlCommonUpdateQuota(username string, filesAdd int, sizeAdd int64, reset bo
|
|||
q := getUpdateQuotaQuery(reset)
|
||||
_, err := dbHandle.ExecContext(ctx, q, sizeAdd, filesAdd, util.GetTimeAsMsSinceEpoch(time.Now()), username)
|
||||
if err == nil {
|
||||
providerLog(logger.LevelDebug, "quota updated for user %#v, files increment: %v size increment: %v is reset? %v",
|
||||
providerLog(logger.LevelDebug, "quota updated for user %q, files increment: %d size increment: %d is reset? %t",
|
||||
username, filesAdd, sizeAdd, reset)
|
||||
} else {
|
||||
providerLog(logger.LevelError, "error updating quota for user %#v: %v", username, err)
|
||||
providerLog(logger.LevelError, "error updating quota for user %q: %v", username, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
@ -1234,9 +1247,9 @@ func sqlCommonUpdateShareLastUse(shareID string, numTokens int, dbHandle *sql.DB
|
|||
q := getUpdateShareLastUseQuery()
|
||||
_, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), numTokens, shareID)
|
||||
if err == nil {
|
||||
providerLog(logger.LevelDebug, "last use updated for shared object %#v", shareID)
|
||||
providerLog(logger.LevelDebug, "last use updated for shared object %q", shareID)
|
||||
} else {
|
||||
providerLog(logger.LevelWarn, "error updating last use for shared object %#v: %v", shareID, err)
|
||||
providerLog(logger.LevelWarn, "error updating last use for shared object %q: %v", shareID, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
@ -1248,9 +1261,9 @@ func sqlCommonUpdateAPIKeyLastUse(keyID string, dbHandle *sql.DB) error {
|
|||
q := getUpdateAPIKeyLastUseQuery()
|
||||
_, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), keyID)
|
||||
if err == nil {
|
||||
providerLog(logger.LevelDebug, "last use updated for key %#v", keyID)
|
||||
providerLog(logger.LevelDebug, "last use updated for key %q", keyID)
|
||||
} else {
|
||||
providerLog(logger.LevelWarn, "error updating last use for key %#v: %v", keyID, err)
|
||||
providerLog(logger.LevelWarn, "error updating last use for key %q: %v", keyID, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
@ -1262,9 +1275,9 @@ func sqlCommonUpdateAdminLastLogin(username string, dbHandle *sql.DB) error {
|
|||
q := getUpdateAdminLastLoginQuery()
|
||||
_, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), username)
|
||||
if err == nil {
|
||||
providerLog(logger.LevelDebug, "last login updated for admin %#v", username)
|
||||
providerLog(logger.LevelDebug, "last login updated for admin %q", username)
|
||||
} else {
|
||||
providerLog(logger.LevelWarn, "error updating last login for admin %#v: %v", username, err)
|
||||
providerLog(logger.LevelWarn, "error updating last login for admin %q: %v", username, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
@ -1276,9 +1289,9 @@ func sqlCommonSetUpdatedAt(username string, dbHandle *sql.DB) {
|
|||
q := getSetUpdateAtQuery()
|
||||
_, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), username)
|
||||
if err == nil {
|
||||
providerLog(logger.LevelDebug, "updated_at set for user %#v", username)
|
||||
providerLog(logger.LevelDebug, "updated_at set for user %q", username)
|
||||
} else {
|
||||
providerLog(logger.LevelWarn, "error setting updated_at for user %#v: %v", username, err)
|
||||
providerLog(logger.LevelWarn, "error setting updated_at for user %q: %v", username, err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1313,9 +1326,9 @@ func sqlCommonUpdateLastLogin(username string, dbHandle *sql.DB) error {
|
|||
q := getUpdateLastLoginQuery()
|
||||
_, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), username)
|
||||
if err == nil {
|
||||
providerLog(logger.LevelDebug, "last login updated for user %#v", username)
|
||||
providerLog(logger.LevelDebug, "last login updated for user %q", username)
|
||||
} else {
|
||||
providerLog(logger.LevelWarn, "error updating last login for user %#v: %v", username, err)
|
||||
providerLog(logger.LevelWarn, "error updating last login for user %q: %v", username, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
@ -1353,9 +1366,9 @@ func sqlCommonAddUser(user *User, dbHandle *sql.DB) error {
|
|||
}
|
||||
}
|
||||
q := getAddUserQuery(user.Role)
|
||||
_, err := tx.ExecContext(ctx, q, user.Username, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID,
|
||||
user.MaxSessions, user.QuotaSize, user.QuotaFiles, string(permissions), user.UploadBandwidth,
|
||||
user.DownloadBandwidth, user.Status, user.ExpirationDate, string(filters), string(fsConfig), user.AdditionalInfo,
|
||||
_, err := tx.ExecContext(ctx, q, user.Username, user.Password, publicKeys, user.HomeDir, user.UID, user.GID,
|
||||
user.MaxSessions, user.QuotaSize, user.QuotaFiles, permissions, user.UploadBandwidth,
|
||||
user.DownloadBandwidth, user.Status, user.ExpirationDate, filters, fsConfig, user.AdditionalInfo,
|
||||
user.Description, user.Email, util.GetTimeAsMsSinceEpoch(time.Now()), util.GetTimeAsMsSinceEpoch(time.Now()),
|
||||
user.UploadDataTransfer, user.DownloadDataTransfer, user.TotalDataTransfer, user.Role, user.LastPasswordChange)
|
||||
if err != nil {
|
||||
|
@ -1373,8 +1386,11 @@ func sqlCommonUpdateUserPassword(username, password string, dbHandle *sql.DB) er
|
|||
defer cancel()
|
||||
|
||||
q := getUpdateUserPasswordQuery()
|
||||
_, err := dbHandle.ExecContext(ctx, q, password, username)
|
||||
return err
|
||||
res, err := dbHandle.ExecContext(ctx, q, password, username)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sqlCommonRequireRowAffected(res)
|
||||
}
|
||||
|
||||
func sqlCommonUpdateUser(user *User, dbHandle *sql.DB) error {
|
||||
|
@ -1404,14 +1420,17 @@ func sqlCommonUpdateUser(user *User, dbHandle *sql.DB) error {
|
|||
|
||||
return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
|
||||
q := getUpdateUserQuery(user.Role)
|
||||
_, err := tx.ExecContext(ctx, q, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions,
|
||||
user.QuotaSize, user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.Status,
|
||||
user.ExpirationDate, string(filters), string(fsConfig), user.AdditionalInfo, user.Description, user.Email,
|
||||
res, err := tx.ExecContext(ctx, q, user.Password, publicKeys, user.HomeDir, user.UID, user.GID, user.MaxSessions,
|
||||
user.QuotaSize, user.QuotaFiles, permissions, user.UploadBandwidth, user.DownloadBandwidth, user.Status,
|
||||
user.ExpirationDate, filters, fsConfig, user.AdditionalInfo, user.Description, user.Email,
|
||||
util.GetTimeAsMsSinceEpoch(time.Now()), user.UploadDataTransfer, user.DownloadDataTransfer, user.TotalDataTransfer,
|
||||
user.Role, user.LastPasswordChange, user.ID)
|
||||
user.Role, user.LastPasswordChange, user.Username)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := sqlCommonRequireRowAffected(res); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := generateUserVirtualFoldersMapping(ctx, user, tx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -1440,7 +1459,7 @@ func sqlCommonDeleteUser(user User, softDelete bool, dbHandle *sql.DB) error {
|
|||
return sqlCommonRequireRowAffected(res)
|
||||
})
|
||||
}
|
||||
res, err := dbHandle.ExecContext(ctx, q, user.ID)
|
||||
res, err := dbHandle.ExecContext(ctx, q, user.Username)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -2292,7 +2311,7 @@ func sqlCommonAddOrUpdateFolder(ctx context.Context, baseFolder *vfs.BaseVirtual
|
|||
}
|
||||
q := getUpsertFolderQuery()
|
||||
_, err = dbHandle.ExecContext(ctx, q, baseFolder.MappedPath, usedQuotaSize, usedQuotaFiles,
|
||||
lastQuotaUpdate, baseFolder.Name, baseFolder.Description, string(fsConfig))
|
||||
lastQuotaUpdate, baseFolder.Name, baseFolder.Description, fsConfig)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -2310,7 +2329,7 @@ func sqlCommonAddFolder(folder *vfs.BaseVirtualFolder, dbHandle sqlQuerier) erro
|
|||
|
||||
q := getAddFolderQuery()
|
||||
_, err = dbHandle.ExecContext(ctx, q, folder.MappedPath, folder.UsedQuotaSize, folder.UsedQuotaFiles,
|
||||
folder.LastQuotaUpdate, folder.Name, folder.Description, string(fsConfig))
|
||||
folder.LastQuotaUpdate, folder.Name, folder.Description, fsConfig)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -2327,8 +2346,11 @@ func sqlCommonUpdateFolder(folder *vfs.BaseVirtualFolder, dbHandle sqlQuerier) e
|
|||
defer cancel()
|
||||
|
||||
q := getUpdateFolderQuery()
|
||||
_, err = dbHandle.ExecContext(ctx, q, folder.MappedPath, folder.Description, string(fsConfig), folder.Name)
|
||||
return err
|
||||
res, err := dbHandle.ExecContext(ctx, q, folder.MappedPath, folder.Description, fsConfig, folder.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sqlCommonRequireRowAffected(res)
|
||||
}
|
||||
|
||||
func sqlCommonDeleteFolder(folder vfs.BaseVirtualFolder, dbHandle sqlQuerier) error {
|
||||
|
@ -2336,7 +2358,7 @@ func sqlCommonDeleteFolder(folder vfs.BaseVirtualFolder, dbHandle sqlQuerier) er
|
|||
defer cancel()
|
||||
|
||||
q := getDeleteFolderQuery()
|
||||
res, err := dbHandle.ExecContext(ctx, q, folder.ID)
|
||||
res, err := dbHandle.ExecContext(ctx, q, folder.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -2485,7 +2507,7 @@ func sqlCommonAddAdminGroupMapping(ctx context.Context, username, groupName stri
|
|||
return err
|
||||
}
|
||||
q := getAddAdminGroupMappingQuery()
|
||||
_, err = dbHandle.ExecContext(ctx, q, username, groupName, string(options))
|
||||
_, err = dbHandle.ExecContext(ctx, q, username, groupName, options)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -3082,7 +3104,7 @@ func sqlCommonUpdateFolderQuota(name string, filesAdd int, sizeAdd int64, reset
|
|||
q := getUpdateFolderQuotaQuery(reset)
|
||||
_, err := dbHandle.ExecContext(ctx, q, sizeAdd, filesAdd, util.GetTimeAsMsSinceEpoch(time.Now()), name)
|
||||
if err == nil {
|
||||
providerLog(logger.LevelDebug, "quota updated for folder %#v, files increment: %v size increment: %v is reset? %v",
|
||||
providerLog(logger.LevelDebug, "quota updated for folder %q, files increment: %d size increment: %d is reset? %t",
|
||||
name, filesAdd, sizeAdd, reset)
|
||||
} else {
|
||||
providerLog(logger.LevelWarn, "error updating quota for folder %#v: %v", name, err)
|
||||
|
@ -3337,7 +3359,7 @@ func generateEventRuleActionsMapping(ctx context.Context, rule *EventRule, dbHan
|
|||
return err
|
||||
}
|
||||
q = getAddRuleActionMappingQuery()
|
||||
_, err = dbHandle.ExecContext(ctx, q, rule.Name, action.Name, action.Order, string(options))
|
||||
_, err = dbHandle.ExecContext(ctx, q, rule.Name, action.Name, action.Order, options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -3444,7 +3466,7 @@ func sqlCommonAddEventAction(action *BaseEventAction, dbHandle *sql.DB) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = dbHandle.ExecContext(ctx, q, action.Name, action.Description, action.Type, string(options))
|
||||
_, err = dbHandle.ExecContext(ctx, q, action.Name, action.Description, action.Type, options)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -3461,10 +3483,13 @@ func sqlCommonUpdateEventAction(action *BaseEventAction, dbHandle *sql.DB) error
|
|||
|
||||
return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
|
||||
q := getUpdateEventActionQuery()
|
||||
_, err = tx.ExecContext(ctx, q, action.Description, action.Type, string(options), action.Name)
|
||||
res, err := tx.ExecContext(ctx, q, action.Description, action.Type, options, action.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := sqlCommonRequireRowAffected(res); err != nil {
|
||||
return err
|
||||
}
|
||||
q = getUpdateRulesTimestampQuery()
|
||||
_, err = tx.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), action.ID)
|
||||
return err
|
||||
|
@ -3609,7 +3634,7 @@ func sqlCommonAddEventRule(rule *EventRule, dbHandle *sql.DB) error {
|
|||
}
|
||||
q := getAddEventRuleQuery()
|
||||
_, err := tx.ExecContext(ctx, q, rule.Name, rule.Description, util.GetTimeAsMsSinceEpoch(time.Now()),
|
||||
util.GetTimeAsMsSinceEpoch(time.Now()), rule.Trigger, string(conditions), rule.Status)
|
||||
util.GetTimeAsMsSinceEpoch(time.Now()), rule.Trigger, conditions, rule.Status)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -3631,7 +3656,7 @@ func sqlCommonUpdateEventRule(rule *EventRule, dbHandle *sql.DB) error {
|
|||
return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
|
||||
q := getUpdateEventRuleQuery()
|
||||
_, err := tx.ExecContext(ctx, q, rule.Description, util.GetTimeAsMsSinceEpoch(time.Now()),
|
||||
rule.Trigger, string(conditions), rule.Status, rule.Name)
|
||||
rule.Trigger, conditions, rule.Status, rule.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -3744,7 +3769,7 @@ func sqlCommonAddNode(dbHandle *sql.DB) error {
|
|||
defer cancel()
|
||||
|
||||
q := getAddNodeQuery()
|
||||
_, err = dbHandle.ExecContext(ctx, q, currentNode.Name, string(data), util.GetTimeAsMsSinceEpoch(time.Now()),
|
||||
_, err = dbHandle.ExecContext(ctx, q, currentNode.Name, data, util.GetTimeAsMsSinceEpoch(time.Now()),
|
||||
util.GetTimeAsMsSinceEpoch(time.Now()))
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to register cluster node: %w", err)
|
||||
|
@ -3857,9 +3882,6 @@ func sqlCommonSetConfigs(configs *Configs, dbHandle *sql.DB) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if config.Driver == MySQLDataProviderName {
|
||||
return nil
|
||||
}
|
||||
return sqlCommonRequireRowAffected(res)
|
||||
}
|
||||
|
||||
|
@ -3884,8 +3906,6 @@ func sqlCommonGetDatabaseVersion(dbHandle sqlQuerier, showInitWarn bool) (schema
|
|||
}
|
||||
|
||||
func sqlCommonRequireRowAffected(res sql.Result) error {
|
||||
// MariaDB/MySQL returns 0 rows affected for updates that don't change anything
|
||||
// so we don't check rows affected for updates
|
||||
affected, err := res.RowsAffected()
|
||||
if err == nil && affected == 0 {
|
||||
return util.NewRecordNotFoundError(sql.ErrNoRows.Error())
|
||||
|
|
|
@ -221,17 +221,17 @@ func initializeSQLiteProvider(basePath string) error {
|
|||
if !filepath.IsAbs(dbPath) {
|
||||
dbPath = filepath.Join(basePath, dbPath)
|
||||
}
|
||||
connectionString = fmt.Sprintf("file:%v?cache=shared&_foreign_keys=1", dbPath)
|
||||
connectionString = fmt.Sprintf("file:%s?cache=shared&_foreign_keys=1", dbPath)
|
||||
} else {
|
||||
connectionString = config.ConnectionString
|
||||
}
|
||||
dbHandle, err := sql.Open("sqlite3", connectionString)
|
||||
if err == nil {
|
||||
providerLog(logger.LevelDebug, "sqlite database handle created, connection string: %#v", connectionString)
|
||||
providerLog(logger.LevelDebug, "sqlite database handle created, connection string: %q", connectionString)
|
||||
dbHandle.SetMaxOpenConns(1)
|
||||
provider = &SQLiteProvider{dbHandle: dbHandle}
|
||||
} else {
|
||||
providerLog(logger.LevelError, "error creating sqlite database handler, connection string: %#v, error: %v",
|
||||
providerLog(logger.LevelError, "error creating sqlite database handler, connection string: %q, error: %v",
|
||||
connectionString, err)
|
||||
}
|
||||
return err
|
||||
|
|
|
@ -706,7 +706,7 @@ func getUpdateUserQuery(role string) string {
|
|||
return fmt.Sprintf(`UPDATE %s SET password=%s,public_keys=%s,home_dir=%s,uid=%s,gid=%s,max_sessions=%s,quota_size=%s,
|
||||
quota_files=%s,permissions=%s,upload_bandwidth=%s,download_bandwidth=%s,status=%s,expiration_date=%s,filters=%s,filesystem=%s,
|
||||
additional_info=%s,description=%s,email=%s,updated_at=%s,upload_data_transfer=%s,download_data_transfer=%s,
|
||||
total_data_transfer=%s,role_id=COALESCE((SELECT id from %s WHERE name=%s),%s),last_password_change=%s WHERE id = %s`,
|
||||
total_data_transfer=%s,role_id=COALESCE((SELECT id from %s WHERE name=%s),%s),last_password_change=%s WHERE username = %s`,
|
||||
sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4],
|
||||
sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7], sqlPlaceholders[8], sqlPlaceholders[9],
|
||||
sqlPlaceholders[10], sqlPlaceholders[11], sqlPlaceholders[12], sqlPlaceholders[13], sqlPlaceholders[14],
|
||||
|
@ -725,7 +725,7 @@ func getDeleteUserQuery(softDelete bool) string {
|
|||
return fmt.Sprintf(`UPDATE %s SET updated_at=%s,deleted_at=%s WHERE username = %s`,
|
||||
sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2])
|
||||
}
|
||||
return fmt.Sprintf(`DELETE FROM %s WHERE id = %s`, sqlTableUsers, sqlPlaceholders[0])
|
||||
return fmt.Sprintf(`DELETE FROM %s WHERE username = %s`, sqlTableUsers, sqlPlaceholders[0])
|
||||
}
|
||||
|
||||
func getRemoveSoftDeletedUserQuery() string {
|
||||
|
@ -748,7 +748,7 @@ func getUpdateFolderQuery() string {
|
|||
}
|
||||
|
||||
func getDeleteFolderQuery() string {
|
||||
return fmt.Sprintf(`DELETE FROM %s WHERE id = %s`, sqlTableFolders, sqlPlaceholders[0])
|
||||
return fmt.Sprintf(`DELETE FROM %s WHERE name = %s`, sqlTableFolders, sqlPlaceholders[0])
|
||||
}
|
||||
|
||||
func getUpsertFolderQuery() string {
|
||||
|
|
Loading…
Reference in a new issue