|
@@ -5,15 +5,16 @@ package dataprovider
|
|
|
|
|
|
import (
|
|
|
"context"
|
|
|
+ "crypto/tls"
|
|
|
"crypto/x509"
|
|
|
"database/sql"
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
+ "os"
|
|
|
"strings"
|
|
|
"time"
|
|
|
|
|
|
- // we import go-sql-driver/mysql here to be able to disable MySQL support using a build tag
|
|
|
- _ "github.com/go-sql-driver/mysql"
|
|
|
+ "github.com/go-sql-driver/mysql"
|
|
|
|
|
|
"github.com/drakkan/sftpgo/v2/logger"
|
|
|
"github.com/drakkan/sftpgo/v2/version"
|
|
@@ -114,10 +115,18 @@ func init() {
|
|
|
func initializeMySQLProvider() error {
|
|
|
var err error
|
|
|
|
|
|
- dbHandle, err := sql.Open("mysql", getMySQLConnectionString(false))
|
|
|
+ connString, err := getMySQLConnectionString(false)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ redactedConnString, err := getMySQLConnectionString(true)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ dbHandle, err := sql.Open("mysql", connString)
|
|
|
if err == nil {
|
|
|
providerLog(logger.LevelDebug, "mysql database handle created, connection string: %#v, pool size: %v",
|
|
|
- getMySQLConnectionString(true), config.PoolSize)
|
|
|
+ redactedConnString, config.PoolSize)
|
|
|
dbHandle.SetMaxOpenConns(config.PoolSize)
|
|
|
if config.PoolSize > 0 {
|
|
|
dbHandle.SetMaxIdleConns(config.PoolSize)
|
|
@@ -128,23 +137,58 @@ func initializeMySQLProvider() error {
|
|
|
provider = &MySQLProvider{dbHandle: dbHandle}
|
|
|
} else {
|
|
|
providerLog(logger.LevelError, "error creating mysql database handler, connection string: %#v, error: %v",
|
|
|
- getMySQLConnectionString(true), err)
|
|
|
+ redactedConnString, err)
|
|
|
}
|
|
|
return err
|
|
|
}
|
|
|
-func getMySQLConnectionString(redactedPwd bool) string {
|
|
|
+func getMySQLConnectionString(redactedPwd bool) (string, error) {
|
|
|
var connectionString string
|
|
|
if config.ConnectionString == "" {
|
|
|
password := config.Password
|
|
|
if redactedPwd {
|
|
|
password = "[redacted]"
|
|
|
}
|
|
|
+ sslMode := getSSLMode()
|
|
|
+ if sslMode == "custom" && !redactedPwd {
|
|
|
+ tlsConfig := &tls.Config{}
|
|
|
+ if config.RootCert != "" {
|
|
|
+ rootCAs, err := x509.SystemCertPool()
|
|
|
+ if err != nil {
|
|
|
+ rootCAs = x509.NewCertPool()
|
|
|
+ }
|
|
|
+ rootCrt, err := os.ReadFile(config.RootCert)
|
|
|
+ if err != nil {
|
|
|
+ return "", fmt.Errorf("unable to load root certificate %#v: %v", config.RootCert, err)
|
|
|
+ }
|
|
|
+ if !rootCAs.AppendCertsFromPEM(rootCrt) {
|
|
|
+ return "", fmt.Errorf("unable to parse root certificate %#v", config.RootCert)
|
|
|
+ }
|
|
|
+ tlsConfig.RootCAs = rootCAs
|
|
|
+ }
|
|
|
+ if config.ClientCert != "" && config.ClientKey != "" {
|
|
|
+ clientCert := make([]tls.Certificate, 0, 1)
|
|
|
+ tlsCert, err := tls.LoadX509KeyPair(config.ClientCert, config.ClientKey)
|
|
|
+ if err != nil {
|
|
|
+ return "", fmt.Errorf("unable to load key pair %#v, %#v: %v", config.ClientCert, config.ClientKey, err)
|
|
|
+ }
|
|
|
+ clientCert = append(clientCert, tlsCert)
|
|
|
+ tlsConfig.Certificates = clientCert
|
|
|
+ }
|
|
|
+ if config.SSLMode == 2 {
|
|
|
+ tlsConfig.InsecureSkipVerify = true
|
|
|
+ }
|
|
|
+ providerLog(logger.LevelInfo, "registering custom TLS config, root cert %#v, client cert %#v, client key %#v",
|
|
|
+ config.RootCert, config.ClientCert, config.ClientKey)
|
|
|
+ if err := mysql.RegisterTLSConfig("custom", tlsConfig); err != nil {
|
|
|
+ return "", fmt.Errorf("unable to register tls config: %v", err)
|
|
|
+ }
|
|
|
+ }
|
|
|
connectionString = fmt.Sprintf("%v:%v@tcp([%v]:%v)/%v?charset=utf8mb4&interpolateParams=true&timeout=10s&parseTime=true&tls=%v&writeTimeout=10s&readTimeout=10s",
|
|
|
- config.Username, password, config.Host, config.Port, config.Name, getSSLMode())
|
|
|
+ config.Username, password, config.Host, config.Port, config.Name, sslMode)
|
|
|
} else {
|
|
|
connectionString = config.ConnectionString
|
|
|
}
|
|
|
- return connectionString
|
|
|
+ return connectionString, nil
|
|
|
}
|
|
|
|
|
|
func (p *MySQLProvider) checkAvailability() error {
|