mirror of
https://github.com/drakkan/sftpgo.git
synced 2024-11-26 01:20:29 +00:00
8d4964c16d
Added a compatibility layer that converts newline-delimited public keys to an array when the user is fetched from the database. This code will be removed in future versions, so please update your public keys: you only need to re-save the users using the REST API.
300 lines
8.4 KiB
Go
300 lines
8.4 KiB
Go
package dataprovider
|
|
|
|
import (
	"crypto/subtle"
	"database/sql"
	"encoding/json"
	"errors"
	"strings"
	"time"

	"golang.org/x/crypto/ssh"

	"github.com/alexedwards/argon2id"
	"golang.org/x/crypto/bcrypt"

	"github.com/drakkan/sftpgo/logger"
	"github.com/drakkan/sftpgo/utils"
)
|
|
|
|
func getUserByUsername(username string) (User, error) {
|
|
var user User
|
|
q := getUserByUsernameQuery()
|
|
stmt, err := dbHandle.Prepare(q)
|
|
if err != nil {
|
|
logger.Debug(logSender, "error preparing database query %v: %v", q, err)
|
|
return user, err
|
|
}
|
|
defer stmt.Close()
|
|
|
|
row := stmt.QueryRow(username)
|
|
return getUserFromDbRow(row, nil)
|
|
}
|
|
|
|
func sqlCommonValidateUserAndPass(username string, password string) (User, error) {
|
|
var user User
|
|
if len(password) == 0 {
|
|
return user, errors.New("Credentials cannot be null or empty")
|
|
}
|
|
user, err := getUserByUsername(username)
|
|
if err != nil {
|
|
logger.Warn(logSender, "error authenticating user: %v, error: %v", username, err)
|
|
} else {
|
|
// even if the password is empty inside the database an empty user password
|
|
// will be refused anyway so it cannot match, additional check to be paranoid
|
|
if len(user.Password) == 0 {
|
|
return user, errors.New("Credentials cannot be null or empty")
|
|
}
|
|
var match bool
|
|
if strings.HasPrefix(user.Password, argonPwdPrefix) {
|
|
match, err = argon2id.ComparePasswordAndHash(password, user.Password)
|
|
if err != nil {
|
|
logger.Warn(logSender, "error comparing password with argon hash: %v", err)
|
|
return user, err
|
|
}
|
|
} else if strings.HasPrefix(user.Password, bcryptPwdPrefix) {
|
|
if err = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); err != nil {
|
|
logger.Warn(logSender, "error comparing password with bcrypt hash: %v", err)
|
|
return user, err
|
|
}
|
|
match = true
|
|
} else {
|
|
// clear text password match
|
|
match = (user.Password == password)
|
|
}
|
|
if !match {
|
|
err = errors.New("Invalid credentials")
|
|
}
|
|
}
|
|
return user, err
|
|
}
|
|
|
|
func sqlCommonValidateUserAndPubKey(username string, pubKey string) (User, error) {
|
|
var user User
|
|
if len(pubKey) == 0 {
|
|
return user, errors.New("Credentials cannot be null or empty")
|
|
}
|
|
user, err := getUserByUsername(username)
|
|
if err != nil {
|
|
logger.Warn(logSender, "error authenticating user: %v, error: %v", username, err)
|
|
return user, err
|
|
}
|
|
if len(user.PublicKey) == 0 {
|
|
return user, errors.New("Invalid credentials")
|
|
}
|
|
|
|
for i, k := range user.PublicKey {
|
|
storedPubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(k))
|
|
if err != nil {
|
|
logger.Warn(logSender, "error parsing stored public key %d for user %v: %v", i, username, err)
|
|
return user, err
|
|
}
|
|
if string(storedPubKey.Marshal()) == pubKey {
|
|
return user, nil
|
|
}
|
|
}
|
|
return user, errors.New("Invalid credentials")
|
|
}
|
|
|
|
func sqlCommonGetUserByID(ID int64) (User, error) {
|
|
var user User
|
|
q := getUserByIDQuery()
|
|
stmt, err := dbHandle.Prepare(q)
|
|
if err != nil {
|
|
logger.Debug(logSender, "error preparing database query %v: %v", q, err)
|
|
return user, err
|
|
}
|
|
defer stmt.Close()
|
|
|
|
row := stmt.QueryRow(ID)
|
|
return getUserFromDbRow(row, nil)
|
|
}
|
|
|
|
func sqlCommonUpdateQuota(username string, filesAdd int, sizeAdd int64, reset bool, p Provider) error {
|
|
q := getUpdateQuotaQuery(reset)
|
|
stmt, err := dbHandle.Prepare(q)
|
|
if err != nil {
|
|
logger.Debug(logSender, "error preparing database query %v: %v", q, err)
|
|
return err
|
|
}
|
|
defer stmt.Close()
|
|
_, err = stmt.Exec(sizeAdd, filesAdd, utils.GetTimeAsMsSinceEpoch(time.Now()), username)
|
|
if err == nil {
|
|
logger.Debug(logSender, "quota updated for user %v, files increment: %v size increment: %v is reset? %v",
|
|
username, filesAdd, sizeAdd, reset)
|
|
} else {
|
|
logger.Warn(logSender, "error updating quota for username %v: %v", username, err)
|
|
}
|
|
return err
|
|
}
|
|
|
|
func sqlCommonGetUsedQuota(username string) (int, int64, error) {
|
|
q := getQuotaQuery()
|
|
stmt, err := dbHandle.Prepare(q)
|
|
if err != nil {
|
|
logger.Warn(logSender, "error preparing database query %v: %v", q, err)
|
|
return 0, 0, err
|
|
}
|
|
defer stmt.Close()
|
|
|
|
var usedFiles int
|
|
var usedSize int64
|
|
err = stmt.QueryRow(username).Scan(&usedSize, &usedFiles)
|
|
if err != nil {
|
|
logger.Warn(logSender, "error getting user quota: %v, error: %v", username, err)
|
|
return 0, 0, err
|
|
}
|
|
return usedFiles, usedSize, err
|
|
}
|
|
|
|
func sqlCommonCheckUserExists(username string) (User, error) {
|
|
var user User
|
|
q := getUserByUsernameQuery()
|
|
stmt, err := dbHandle.Prepare(q)
|
|
if err != nil {
|
|
logger.Warn(logSender, "error preparing database query %v: %v", q, err)
|
|
return user, err
|
|
}
|
|
defer stmt.Close()
|
|
row := stmt.QueryRow(username)
|
|
return getUserFromDbRow(row, nil)
|
|
}
|
|
|
|
func sqlCommonAddUser(user User) error {
|
|
err := validateUser(&user)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
q := getAddUserQuery()
|
|
stmt, err := dbHandle.Prepare(q)
|
|
if err != nil {
|
|
logger.Warn(logSender, "error preparing database query %v: %v", q, err)
|
|
return err
|
|
}
|
|
defer stmt.Close()
|
|
permissions, err := user.GetPermissionsAsJSON()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
publicKeys, err := user.GetPublicKeysAsJSON()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = stmt.Exec(user.Username, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions, user.QuotaSize,
|
|
user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth)
|
|
return err
|
|
}
|
|
|
|
func sqlCommonUpdateUser(user User) error {
|
|
err := validateUser(&user)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
q := getUpdateUserQuery()
|
|
stmt, err := dbHandle.Prepare(q)
|
|
if err != nil {
|
|
logger.Warn(logSender, "error preparing database query %v: %v", q, err)
|
|
return err
|
|
}
|
|
defer stmt.Close()
|
|
permissions, err := user.GetPermissionsAsJSON()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
publicKeys, err := user.GetPublicKeysAsJSON()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = stmt.Exec(user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions, user.QuotaSize,
|
|
user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.ID)
|
|
return err
|
|
}
|
|
|
|
func sqlCommonDeleteUser(user User) error {
|
|
q := getDeleteUserQuery()
|
|
stmt, err := dbHandle.Prepare(q)
|
|
if err != nil {
|
|
logger.Warn(logSender, "error preparing database query %v: %v", q, err)
|
|
return err
|
|
}
|
|
defer stmt.Close()
|
|
_, err = stmt.Exec(user.ID)
|
|
return err
|
|
}
|
|
|
|
func sqlCommonGetUsers(limit int, offset int, order string, username string) ([]User, error) {
|
|
users := []User{}
|
|
q := getUsersQuery(order, username)
|
|
stmt, err := dbHandle.Prepare(q)
|
|
if err != nil {
|
|
logger.Warn(logSender, "error preparing database query %v: %v", q, err)
|
|
return nil, err
|
|
}
|
|
defer stmt.Close()
|
|
var rows *sql.Rows
|
|
if len(username) > 0 {
|
|
rows, err = stmt.Query(username, limit, offset)
|
|
} else {
|
|
rows, err = stmt.Query(limit, offset)
|
|
}
|
|
if err == nil {
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
u, err := getUserFromDbRow(nil, rows)
|
|
// hide password and public key
|
|
u.Password = ""
|
|
u.PublicKey = []string{}
|
|
if err == nil {
|
|
users = append(users, u)
|
|
} else {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
return users, err
|
|
}
|
|
|
|
func getUserFromDbRow(row *sql.Row, rows *sql.Rows) (User, error) {
|
|
var user User
|
|
var permissions sql.NullString
|
|
var password sql.NullString
|
|
var publicKey sql.NullString
|
|
var err error
|
|
if row != nil {
|
|
err = row.Scan(&user.ID, &user.Username, &password, &publicKey, &user.HomeDir, &user.UID, &user.GID, &user.MaxSessions,
|
|
&user.QuotaSize, &user.QuotaFiles, &permissions, &user.UsedQuotaSize, &user.UsedQuotaFiles, &user.LastQuotaUpdate,
|
|
&user.UploadBandwidth, &user.DownloadBandwidth)
|
|
|
|
} else {
|
|
err = rows.Scan(&user.ID, &user.Username, &password, &publicKey, &user.HomeDir, &user.UID, &user.GID, &user.MaxSessions,
|
|
&user.QuotaSize, &user.QuotaFiles, &permissions, &user.UsedQuotaSize, &user.UsedQuotaFiles, &user.LastQuotaUpdate,
|
|
&user.UploadBandwidth, &user.DownloadBandwidth)
|
|
}
|
|
if err != nil {
|
|
return user, err
|
|
}
|
|
if password.Valid {
|
|
user.Password = password.String
|
|
}
|
|
if publicKey.Valid {
|
|
var list []string
|
|
err = json.Unmarshal([]byte(publicKey.String), &list)
|
|
if err == nil {
|
|
user.PublicKey = list
|
|
} else {
|
|
// compatibility layer: initially we store public keys as string newline delimited
|
|
// we need to remove this code in future
|
|
user.PublicKey = strings.Split(publicKey.String, "\n")
|
|
logger.Warn(logSender, "public keys loaded using compatibility mode, this will not work in future versions! "+
|
|
"Number of public keys loaded: %v, username: %v", len(user.PublicKey), user.Username)
|
|
}
|
|
}
|
|
if permissions.Valid {
|
|
var list []string
|
|
err = json.Unmarshal([]byte(permissions.String), &list)
|
|
if err == nil {
|
|
user.Permissions = list
|
|
}
|
|
}
|
|
return user, err
|
|
}
|