mysql.go 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. package dataprovider
  2. import (
  3. "database/sql"
  4. "fmt"
  5. "runtime"
  6. "time"
  7. "github.com/drakkan/sftpgo/logger"
  8. )
  // MySQLProvider is the data provider implementation backed by a MySQL
  // database. The type holds no fields: its methods in this file forward to the
  // shared sqlCommon* helpers, and updateQuota additionally uses the
  // package-level dbHandle.
  type MySQLProvider struct{}
  12. func initializeMySQLProvider() error {
  13. var err error
  14. var connectionString string
  15. if len(config.ConnectionString) == 0 {
  16. connectionString = fmt.Sprintf("%v:%v@tcp([%v]:%v)/%v?charset=utf8&interpolateParams=true&timeout=10s&tls=%v",
  17. config.Username, config.Password, config.Host, config.Port, config.Name, getSSLMode())
  18. } else {
  19. connectionString = config.ConnectionString
  20. }
  21. dbHandle, err = sql.Open("mysql", connectionString)
  22. if err == nil {
  23. numCPU := runtime.NumCPU()
  24. logger.Debug(logSender, "mysql database handle created, connection string: \"%v\", pool size: %v", connectionString, numCPU)
  25. dbHandle.SetMaxIdleConns(numCPU)
  26. dbHandle.SetMaxOpenConns(numCPU)
  27. dbHandle.SetConnMaxLifetime(1800 * time.Second)
  28. } else {
  29. logger.Warn(logSender, "error creating mysql database handler, connection string: \"%v\", error: %v", connectionString, err)
  30. }
  31. return err
  32. }
  33. func (p MySQLProvider) validateUserAndPass(username string, password string) (User, error) {
  34. return sqlCommonValidateUserAndPass(username, password)
  35. }
  36. func (p MySQLProvider) validateUserAndPubKey(username string, publicKey string) (User, error) {
  37. return sqlCommonValidateUserAndPubKey(username, publicKey)
  38. }
  39. func (p MySQLProvider) getUserByID(ID int64) (User, error) {
  40. return sqlCommonGetUserByID(ID)
  41. }
  42. func (p MySQLProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error {
  43. tx, err := dbHandle.Begin()
  44. if err != nil {
  45. logger.Warn(logSender, "error starting transaction to update quota for user %v: %v", username, err)
  46. return err
  47. }
  48. err = sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p)
  49. if err == nil {
  50. err = tx.Commit()
  51. } else {
  52. err = tx.Rollback()
  53. }
  54. if err != nil {
  55. logger.Warn(logSender, "error closing transaction to update quota for user %v: %v", username, err)
  56. }
  57. return err
  58. }
  59. func (p MySQLProvider) getUsedQuota(username string) (int, int64, error) {
  60. return sqlCommonGetUsedQuota(username)
  61. }
  62. func (p MySQLProvider) userExists(username string) (User, error) {
  63. return sqlCommonCheckUserExists(username)
  64. }
  65. func (p MySQLProvider) addUser(user User) error {
  66. return sqlCommonAddUser(user)
  67. }
  68. func (p MySQLProvider) updateUser(user User) error {
  69. return sqlCommonUpdateUser(user)
  70. }
  71. func (p MySQLProvider) deleteUser(user User) error {
  72. return sqlCommonDeleteUser(user)
  73. }
  74. func (p MySQLProvider) getUsers(limit int, offset int, order string, username string) ([]User, error) {
  75. return sqlCommonGetUsers(limit, offset, order, username)
  76. }