浏览代码

data provider: add config options for certs validation/authentication

Fixes #682

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
Nicola Murino 3 年之前
父节点
当前提交
fb2d59ec92
共有 6 个文件被更改,包括 101 次插入17 次删除
  1. 6 0
      config/config.go
  2. 30 8
      dataprovider/dataprovider.go
  3. 52 8
      dataprovider/mysql.go
  4. 6 0
      dataprovider/pgsql.go
  5. 4 1
      docs/full-configuration.md
  6. 3 0
      sftpgo.json

+ 6 - 0
config/config.go

@@ -229,6 +229,9 @@ func Init() {
 			ConnectionString: "",
 			SQLTablesPrefix:  "",
 			SSLMode:          0,
+			RootCert:         "",
+			ClientCert:       "",
+			ClientKey:        "",
 			TrackQuota:       1,
 			PoolSize:         0,
 			UsersBaseDir:     "",
@@ -1276,6 +1279,9 @@ func setViperDefaults() {
 	viper.SetDefault("data_provider.username", globalConf.ProviderConf.Username)
 	viper.SetDefault("data_provider.password", globalConf.ProviderConf.Password)
 	viper.SetDefault("data_provider.sslmode", globalConf.ProviderConf.SSLMode)
+	viper.SetDefault("data_provider.root_cert", globalConf.ProviderConf.RootCert)
+	viper.SetDefault("data_provider.client_cert", globalConf.ProviderConf.ClientCert)
+	viper.SetDefault("data_provider.client_key", globalConf.ProviderConf.ClientKey)
 	viper.SetDefault("data_provider.connection_string", globalConf.ProviderConf.ConnectionString)
 	viper.SetDefault("data_provider.sql_tables_prefix", globalConf.ProviderConf.SQLTablesPrefix)
 	viper.SetDefault("data_provider.track_quota", globalConf.ProviderConf.TrackQuota)

+ 30 - 8
dataprovider/dataprovider.go

@@ -252,6 +252,12 @@ type Config struct {
 	// 2 set ssl mode to verify-ca for driver postgresql and skip-verify for driver mysql.
 	// 3 set ssl mode to verify-full for driver postgresql and preferred for driver mysql.
 	SSLMode int `json:"sslmode" mapstructure:"sslmode"`
+	// Path to the root certificate authority used to verify that the server certificate was signed by a trusted CA
+	RootCert string `json:"root_cert" mapstructure:"root_cert"`
+	// Path to the client certificate for two-way TLS authentication
+	ClientCert string `json:"client_cert" mapstructure:"client_cert"`
+	// Path to the client key for two-way TLS authentication
+	ClientKey string `json:"client_key" mapstructure:"client_key"`
 	// Custom database connection string.
 	// If not empty this connection string will be used instead of build one using the previous parameters
 	ConnectionString string `json:"connection_string" mapstructure:"connection_string"`
@@ -392,6 +398,17 @@ func (c *Config) IsDefenderSupported() bool {
 	}
 }
 
+func (c *Config) requireCustomTLSForMySQL() bool {
+	if config.RootCert != "" && util.IsFileInputValid(config.RootCert) {
+		return config.SSLMode != 0
+	}
+	if config.ClientCert != "" && config.ClientKey != "" && util.IsFileInputValid(config.ClientCert) &&
+		util.IsFileInputValid(config.ClientKey) {
+		return config.SSLMode != 0
+	}
+	return false
+}
+
 // ActiveTransfer defines an active protocol transfer
 type ActiveTransfer struct {
 	ID            int64
@@ -2420,23 +2437,28 @@ func addFolderCredentialsToUser(user *User) error {
 
 func getSSLMode() string {
 	if config.Driver == PGSQLDataProviderName || config.Driver == CockroachDataProviderName {
-		if config.SSLMode == 0 {
+		switch config.SSLMode {
+		case 0:
 			return "disable"
-		} else if config.SSLMode == 1 {
+		case 1:
 			return "require"
-		} else if config.SSLMode == 2 {
+		case 2:
 			return "verify-ca"
-		} else if config.SSLMode == 3 {
+		case 3:
 			return "verify-full"
 		}
 	} else if config.Driver == MySQLDataProviderName {
-		if config.SSLMode == 0 {
+		if config.requireCustomTLSForMySQL() {
+			return "custom"
+		}
+		switch config.SSLMode {
+		case 0:
 			return "false"
-		} else if config.SSLMode == 1 {
+		case 1:
 			return "true"
-		} else if config.SSLMode == 2 {
+		case 2:
 			return "skip-verify"
-		} else if config.SSLMode == 3 {
+		case 3:
 			return "preferred"
 		}
 	}

+ 52 - 8
dataprovider/mysql.go

@@ -5,15 +5,16 @@ package dataprovider
 
 import (
 	"context"
+	"crypto/tls"
 	"crypto/x509"
 	"database/sql"
 	"errors"
 	"fmt"
+	"os"
 	"strings"
 	"time"
 
-	// we import go-sql-driver/mysql here to be able to disable MySQL support using a build tag
-	_ "github.com/go-sql-driver/mysql"
+	"github.com/go-sql-driver/mysql"
 
 	"github.com/drakkan/sftpgo/v2/logger"
 	"github.com/drakkan/sftpgo/v2/version"
@@ -114,10 +115,18 @@ func init() {
 func initializeMySQLProvider() error {
 	var err error
 
-	dbHandle, err := sql.Open("mysql", getMySQLConnectionString(false))
+	connString, err := getMySQLConnectionString(false)
+	if err != nil {
+		return err
+	}
+	redactedConnString, err := getMySQLConnectionString(true)
+	if err != nil {
+		return err
+	}
+	dbHandle, err := sql.Open("mysql", connString)
 	if err == nil {
 		providerLog(logger.LevelDebug, "mysql database handle created, connection string: %#v, pool size: %v",
-			getMySQLConnectionString(true), config.PoolSize)
+			redactedConnString, config.PoolSize)
 		dbHandle.SetMaxOpenConns(config.PoolSize)
 		if config.PoolSize > 0 {
 			dbHandle.SetMaxIdleConns(config.PoolSize)
@@ -128,23 +137,58 @@ func initializeMySQLProvider() error {
 		provider = &MySQLProvider{dbHandle: dbHandle}
 	} else {
 		providerLog(logger.LevelError, "error creating mysql database handler, connection string: %#v, error: %v",
-			getMySQLConnectionString(true), err)
+			redactedConnString, err)
 	}
 	return err
 }
-func getMySQLConnectionString(redactedPwd bool) string {
+func getMySQLConnectionString(redactedPwd bool) (string, error) {
 	var connectionString string
 	if config.ConnectionString == "" {
 		password := config.Password
 		if redactedPwd {
 			password = "[redacted]"
 		}
+		sslMode := getSSLMode()
+		if sslMode == "custom" && !redactedPwd {
+			tlsConfig := &tls.Config{}
+			if config.RootCert != "" {
+				rootCAs, err := x509.SystemCertPool()
+				if err != nil {
+					rootCAs = x509.NewCertPool()
+				}
+				rootCrt, err := os.ReadFile(config.RootCert)
+				if err != nil {
+					return "", fmt.Errorf("unable to load root certificate %#v: %v", config.RootCert, err)
+				}
+				if !rootCAs.AppendCertsFromPEM(rootCrt) {
+					return "", fmt.Errorf("unable to parse root certificate %#v", config.RootCert)
+				}
+				tlsConfig.RootCAs = rootCAs
+			}
+			if config.ClientCert != "" && config.ClientKey != "" {
+				clientCert := make([]tls.Certificate, 0, 1)
+				tlsCert, err := tls.LoadX509KeyPair(config.ClientCert, config.ClientKey)
+				if err != nil {
+					return "", fmt.Errorf("unable to load key pair %#v, %#v: %v", config.ClientCert, config.ClientKey, err)
+				}
+				clientCert = append(clientCert, tlsCert)
+				tlsConfig.Certificates = clientCert
+			}
+			if config.SSLMode == 2 {
+				tlsConfig.InsecureSkipVerify = true
+			}
+			providerLog(logger.LevelInfo, "registering custom TLS config, root cert %#v, client cert %#v, client key %#v",
+				config.RootCert, config.ClientCert, config.ClientKey)
+			if err := mysql.RegisterTLSConfig("custom", tlsConfig); err != nil {
+				return "", fmt.Errorf("unable to register tls config: %v", err)
+			}
+		}
 		connectionString = fmt.Sprintf("%v:%v@tcp([%v]:%v)/%v?charset=utf8mb4&interpolateParams=true&timeout=10s&parseTime=true&tls=%v&writeTimeout=10s&readTimeout=10s",
-			config.Username, password, config.Host, config.Port, config.Name, getSSLMode())
+			config.Username, password, config.Host, config.Port, config.Name, sslMode)
 	} else {
 		connectionString = config.ConnectionString
 	}
-	return connectionString
+	return connectionString, nil
 }
 
 func (p *MySQLProvider) checkAvailability() error {

+ 6 - 0
dataprovider/pgsql.go

@@ -155,6 +155,12 @@ func getPGSQLConnectionString(redactedPwd bool) string {
 		}
 		connectionString = fmt.Sprintf("host='%v' port=%v dbname='%v' user='%v' password='%v' sslmode=%v connect_timeout=10",
 			config.Host, config.Port, config.Name, config.Username, password, getSSLMode())
+		if config.RootCert != "" {
+			connectionString += fmt.Sprintf(" sslrootcert='%v'", config.RootCert)
+		}
+		if config.ClientCert != "" && config.ClientKey != "" {
+			connectionString += fmt.Sprintf(" sslcert='%v' sslkey='%v'", config.ClientCert, config.ClientKey)
+		}
 	} else {
 		connectionString = config.ConnectionString
 	}

+ 4 - 1
docs/full-configuration.md

@@ -174,7 +174,10 @@ The configuration file contains the following sections:
   - `port`, integer. Database port. Leave empty for drivers `sqlite`, `bolt` and `memory`
   - `username`, string. Database user. Leave empty for drivers `sqlite`, `bolt` and `memory`
   - `password`, string. Database password. Leave empty for drivers `sqlite`, `bolt` and `memory`
-  - `sslmode`, integer. Used for drivers `mysql` and `postgresql`. 0 disable SSL/TLS connections, 1 require ssl, 2 set ssl mode to `verify-ca` for driver `postgresql` and `skip-verify` for driver `mysql`, 3 set ssl mode to `verify-full` for driver `postgresql` and `preferred` for driver `mysql`
+  - `sslmode`, integer. Used for drivers `mysql` and `postgresql`. 0 disable TLS connections, 1 require TLS, 2 set TLS mode to `verify-ca` for driver `postgresql` and `skip-verify` for driver `mysql`, 3 set TLS mode to `verify-full` for driver `postgresql` and `preferred` for driver `mysql`
+  - `root_cert`, string. Path to the root certificate authority used to verify that the server certificate was signed by a trusted CA
+  - `client_cert`, string. Path to the client certificate for two-way TLS authentication
+  - `client_key`,string. Path to the client key for two-way TLS authentication
   - `connection_string`, string. Provide a custom database connection string. If not empty, this connection string will be used instead of building one using the previous parameters. Leave empty for drivers `bolt` and `memory`
   - `sql_tables_prefix`, string. Prefix for SQL tables
   - `track_quota`, integer. Set the preferred mode to track users quota between the following choices:

+ 3 - 0
sftpgo.json

@@ -155,6 +155,9 @@
     "username": "",
     "password": "",
     "sslmode": 0,
+    "root_cert": "",
+    "client_cert": "",
+    "client_key": "",
     "connection_string": "",
     "sql_tables_prefix": "",
     "track_quota": 2,