From fb2d59ec9250e86b1e8272e5c99e90e90ddb2e86 Mon Sep 17 00:00:00 2001 From: Nicola Murino Date: Sun, 30 Jan 2022 18:04:03 +0100 Subject: [PATCH] data provider: add config options for certs validation/authentication Fixes #682 Signed-off-by: Nicola Murino --- config/config.go | 6 ++++ dataprovider/dataprovider.go | 38 ++++++++++++++++++----- dataprovider/mysql.go | 60 +++++++++++++++++++++++++++++++----- dataprovider/pgsql.go | 6 ++++ docs/full-configuration.md | 5 ++- sftpgo.json | 3 ++ 6 files changed, 101 insertions(+), 17 deletions(-) diff --git a/config/config.go b/config/config.go index 1b83953e..11d66e1e 100644 --- a/config/config.go +++ b/config/config.go @@ -229,6 +229,9 @@ func Init() { ConnectionString: "", SQLTablesPrefix: "", SSLMode: 0, + RootCert: "", + ClientCert: "", + ClientKey: "", TrackQuota: 1, PoolSize: 0, UsersBaseDir: "", @@ -1276,6 +1279,9 @@ func setViperDefaults() { viper.SetDefault("data_provider.username", globalConf.ProviderConf.Username) viper.SetDefault("data_provider.password", globalConf.ProviderConf.Password) viper.SetDefault("data_provider.sslmode", globalConf.ProviderConf.SSLMode) + viper.SetDefault("data_provider.root_cert", globalConf.ProviderConf.RootCert) + viper.SetDefault("data_provider.client_cert", globalConf.ProviderConf.ClientCert) + viper.SetDefault("data_provider.client_key", globalConf.ProviderConf.ClientKey) viper.SetDefault("data_provider.connection_string", globalConf.ProviderConf.ConnectionString) viper.SetDefault("data_provider.sql_tables_prefix", globalConf.ProviderConf.SQLTablesPrefix) viper.SetDefault("data_provider.track_quota", globalConf.ProviderConf.TrackQuota) diff --git a/dataprovider/dataprovider.go b/dataprovider/dataprovider.go index 68a7170a..dd6a9e91 100644 --- a/dataprovider/dataprovider.go +++ b/dataprovider/dataprovider.go @@ -252,6 +252,12 @@ type Config struct { // 2 set ssl mode to verify-ca for driver postgresql and skip-verify for driver mysql. // 3 set ssl mode to verify-full for driver postgresql and preferred for driver mysql. SSLMode int `json:"sslmode" mapstructure:"sslmode"` + // Path to the root certificate authority used to verify that the server certificate was signed by a trusted CA + RootCert string `json:"root_cert" mapstructure:"root_cert"` + // Path to the client certificate for two-way TLS authentication + ClientCert string `json:"client_cert" mapstructure:"client_cert"` + // Path to the client key for two-way TLS authentication + ClientKey string `json:"client_key" mapstructure:"client_key"` // Custom database connection string. // If not empty this connection string will be used instead of build one using the previous parameters ConnectionString string `json:"connection_string" mapstructure:"connection_string"` @@ -392,6 +398,17 @@ func (c *Config) IsDefenderSupported() bool { } } +func (c *Config) requireCustomTLSForMySQL() bool { + if config.RootCert != "" && util.IsFileInputValid(config.RootCert) { + return config.SSLMode != 0 + } + if config.ClientCert != "" && config.ClientKey != "" && util.IsFileInputValid(config.ClientCert) && + util.IsFileInputValid(config.ClientKey) { + return config.SSLMode != 0 + } + return false +} + // ActiveTransfer defines an active protocol transfer type ActiveTransfer struct { ID int64 @@ -2420,23 +2437,28 @@ func addFolderCredentialsToUser(user *User) error { func getSSLMode() string { if config.Driver == PGSQLDataProviderName || config.Driver == CockroachDataProviderName { - if config.SSLMode == 0 { + switch config.SSLMode { + case 0: return "disable" - } else if config.SSLMode == 1 { + case 1: return "require" - } else if config.SSLMode == 2 { + case 2: return "verify-ca" - } else if config.SSLMode == 3 { + case 3: return "verify-full" } } else if config.Driver == MySQLDataProviderName { - if config.SSLMode == 0 { + if config.requireCustomTLSForMySQL() { + return "custom" + } + switch config.SSLMode { + case 0: return "false" - } else if config.SSLMode == 1 { + case 1: return "true" - } else if config.SSLMode == 2 { + case 2: return "skip-verify" - } else if config.SSLMode == 3 { + case 3: return "preferred" } } diff --git a/dataprovider/mysql.go b/dataprovider/mysql.go index bcf54253..e9dd7444 100644 --- a/dataprovider/mysql.go +++ b/dataprovider/mysql.go @@ -5,15 +5,16 @@ package dataprovider import ( "context" + "crypto/tls" "crypto/x509" "database/sql" "errors" "fmt" + "os" "strings" "time" - // we import go-sql-driver/mysql here to be able to disable MySQL support using a build tag - _ "github.com/go-sql-driver/mysql" + "github.com/go-sql-driver/mysql" "github.com/drakkan/sftpgo/v2/logger" "github.com/drakkan/sftpgo/v2/version" @@ -114,10 +115,18 @@ func init() { func initializeMySQLProvider() error { var err error - dbHandle, err := sql.Open("mysql", getMySQLConnectionString(false)) + connString, err := getMySQLConnectionString(false) + if err != nil { + return err + } + redactedConnString, err := getMySQLConnectionString(true) + if err != nil { + return err + } + dbHandle, err := sql.Open("mysql", connString) if err == nil { providerLog(logger.LevelDebug, "mysql database handle created, connection string: %#v, pool size: %v", - getMySQLConnectionString(true), config.PoolSize) + redactedConnString, config.PoolSize) dbHandle.SetMaxOpenConns(config.PoolSize) if config.PoolSize > 0 { dbHandle.SetMaxIdleConns(config.PoolSize) @@ -128,23 +137,58 @@ func initializeMySQLProvider() error { provider = &MySQLProvider{dbHandle: dbHandle} } else { providerLog(logger.LevelError, "error creating mysql database handler, connection string: %#v, error: %v", - getMySQLConnectionString(true), err) + redactedConnString, err) } return err } -func getMySQLConnectionString(redactedPwd bool) string { +func getMySQLConnectionString(redactedPwd bool) (string, error) { var connectionString string if config.ConnectionString == "" { password := config.Password if redactedPwd { password = "[redacted]" } + sslMode := getSSLMode() + if sslMode == "custom" && !redactedPwd { + tlsConfig := &tls.Config{} + if config.RootCert != "" { + rootCAs, err := x509.SystemCertPool() + if err != nil { + rootCAs = x509.NewCertPool() + } + rootCrt, err := os.ReadFile(config.RootCert) + if err != nil { + return "", fmt.Errorf("unable to load root certificate %#v: %v", config.RootCert, err) + } + if !rootCAs.AppendCertsFromPEM(rootCrt) { + return "", fmt.Errorf("unable to parse root certificate %#v", config.RootCert) + } + tlsConfig.RootCAs = rootCAs + } + if config.ClientCert != "" && config.ClientKey != "" { + clientCert := make([]tls.Certificate, 0, 1) + tlsCert, err := tls.LoadX509KeyPair(config.ClientCert, config.ClientKey) + if err != nil { + return "", fmt.Errorf("unable to load key pair %#v, %#v: %v", config.ClientCert, config.ClientKey, err) + } + clientCert = append(clientCert, tlsCert) + tlsConfig.Certificates = clientCert + } + if config.SSLMode == 2 { + tlsConfig.InsecureSkipVerify = true + } + providerLog(logger.LevelInfo, "registering custom TLS config, root cert %#v, client cert %#v, client key %#v", + config.RootCert, config.ClientCert, config.ClientKey) + if err := mysql.RegisterTLSConfig("custom", tlsConfig); err != nil { + return "", fmt.Errorf("unable to register tls config: %v", err) + } + } connectionString = fmt.Sprintf("%v:%v@tcp([%v]:%v)/%v?charset=utf8mb4&interpolateParams=true&timeout=10s&parseTime=true&tls=%v&writeTimeout=10s&readTimeout=10s", - config.Username, password, config.Host, config.Port, config.Name, getSSLMode()) + config.Username, password, config.Host, config.Port, config.Name, sslMode) } else { connectionString = config.ConnectionString } - return connectionString + return connectionString, nil } func (p *MySQLProvider) checkAvailability() error { diff --git a/dataprovider/pgsql.go b/dataprovider/pgsql.go index faf8484e..7c5697f6 100644 --- a/dataprovider/pgsql.go +++ b/dataprovider/pgsql.go @@ -155,6 +155,12 @@ func getPGSQLConnectionString(redactedPwd bool) string { } connectionString = fmt.Sprintf("host='%v' port=%v dbname='%v' user='%v' password='%v' sslmode=%v connect_timeout=10", config.Host, config.Port, config.Name, config.Username, password, getSSLMode()) + if config.RootCert != "" { + connectionString += fmt.Sprintf(" sslrootcert='%v'", config.RootCert) + } + if config.ClientCert != "" && config.ClientKey != "" { + connectionString += fmt.Sprintf(" sslcert='%v' sslkey='%v'", config.ClientCert, config.ClientKey) + } } else { connectionString = config.ConnectionString } diff --git a/docs/full-configuration.md b/docs/full-configuration.md index 1d91d2bc..4dde9c28 100644 --- a/docs/full-configuration.md +++ b/docs/full-configuration.md @@ -174,7 +174,10 @@ The configuration file contains the following sections: - `port`, integer. Database port. Leave empty for drivers `sqlite`, `bolt` and `memory` - `username`, string. Database user. Leave empty for drivers `sqlite`, `bolt` and `memory` - `password`, string. Database password. Leave empty for drivers `sqlite`, `bolt` and `memory` - - `sslmode`, integer. Used for drivers `mysql` and `postgresql`. 0 disable SSL/TLS connections, 1 require ssl, 2 set ssl mode to `verify-ca` for driver `postgresql` and `skip-verify` for driver `mysql`, 3 set ssl mode to `verify-full` for driver `postgresql` and `preferred` for driver `mysql` + - `sslmode`, integer. Used for drivers `mysql` and `postgresql`. 0 disable TLS connections, 1 require TLS, 2 set TLS mode to `verify-ca` for driver `postgresql` and `skip-verify` for driver `mysql`, 3 set TLS mode to `verify-full` for driver `postgresql` and `preferred` for driver `mysql` + - `root_cert`, string. Path to the root certificate authority used to verify that the server certificate was signed by a trusted CA + - `client_cert`, string. Path to the client certificate for two-way TLS authentication + - `client_key`,string. Path to the client key for two-way TLS authentication - `connection_string`, string. Provide a custom database connection string. If not empty, this connection string will be used instead of building one using the previous parameters. Leave empty for drivers `bolt` and `memory` - `sql_tables_prefix`, string. Prefix for SQL tables - `track_quota`, integer. Set the preferred mode to track users quota between the following choices: diff --git a/sftpgo.json b/sftpgo.json index f5b86344..15d15472 100644 --- a/sftpgo.json +++ b/sftpgo.json @@ -155,6 +155,9 @@ "username": "", "password": "", "sslmode": 0, + "root_cert": "", + "client_cert": "", + "client_key": "", "connection_string": "", "sql_tables_prefix": "", "track_quota": 2,