|
@@ -23,11 +23,13 @@ import (
|
|
|
"database/sql"
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
+ "net"
|
|
|
+ "strconv"
|
|
|
"strings"
|
|
|
"time"
|
|
|
|
|
|
- // we import pgx here to be able to disable PostgreSQL support using a build tag
|
|
|
- _ "github.com/jackc/pgx/v5/stdlib"
|
|
|
+ "github.com/jackc/pgx/v5"
|
|
|
+ "github.com/jackc/pgx/v5/stdlib"
|
|
|
|
|
|
"github.com/drakkan/sftpgo/v2/internal/logger"
|
|
|
"github.com/drakkan/sftpgo/v2/internal/version"
|
|
@@ -233,25 +235,61 @@ func init() {
|
|
|
}
|
|
|
|
|
|
func initializePGSQLProvider() error {
|
|
|
- var err error
|
|
|
- dbHandle, err := sql.Open("pgx", getPGSQLConnectionString(false))
|
|
|
- if err == nil {
|
|
|
- providerLog(logger.LevelDebug, "postgres database handle created, connection string: %q, pool size: %d",
|
|
|
- getPGSQLConnectionString(true), config.PoolSize)
|
|
|
- dbHandle.SetMaxOpenConns(config.PoolSize)
|
|
|
- if config.PoolSize > 0 {
|
|
|
- dbHandle.SetMaxIdleConns(config.PoolSize)
|
|
|
- } else {
|
|
|
- dbHandle.SetMaxIdleConns(2)
|
|
|
+ var dbHandle *sql.DB
|
|
|
+ if config.TargetSessionAttrs == "any" {
|
|
|
+ pgxConfig, err := pgx.ParseConfig(getPGSQLConnectionString(false))
|
|
|
+ if err != nil {
|
|
|
+ providerLog(logger.LevelError, "error parsing postgres configuration, connection string: %q, error: %v",
|
|
|
+ getPGSQLConnectionString(true), err)
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ dbHandle = stdlib.OpenDB(*pgxConfig, stdlib.OptionBeforeConnect(stdlib.RandomizeHostOrderFunc))
|
|
|
+ } else {
|
|
|
+ var err error
|
|
|
+ dbHandle, err = sql.Open("pgx", getPGSQLConnectionString(false))
|
|
|
+ if err != nil {
|
|
|
+ providerLog(logger.LevelError, "error creating postgres database handler, connection string: %q, error: %v",
|
|
|
+ getPGSQLConnectionString(true), err)
|
|
|
+ return err
|
|
|
}
|
|
|
- dbHandle.SetConnMaxLifetime(240 * time.Second)
|
|
|
- dbHandle.SetConnMaxIdleTime(120 * time.Second)
|
|
|
- provider = &PGSQLProvider{dbHandle: dbHandle}
|
|
|
+ }
|
|
|
+ providerLog(logger.LevelDebug, "postgres database handle created, connection string: %q, pool size: %d",
|
|
|
+ getPGSQLConnectionString(true), config.PoolSize)
|
|
|
+ dbHandle.SetMaxOpenConns(config.PoolSize)
|
|
|
+ if config.PoolSize > 0 {
|
|
|
+ dbHandle.SetMaxIdleConns(config.PoolSize)
|
|
|
} else {
|
|
|
- providerLog(logger.LevelError, "error creating postgres database handler, connection string: %q, error: %v",
|
|
|
- getPGSQLConnectionString(true), err)
|
|
|
+ dbHandle.SetMaxIdleConns(2)
|
|
|
+ }
|
|
|
+ dbHandle.SetConnMaxLifetime(240 * time.Second)
|
|
|
+ dbHandle.SetConnMaxIdleTime(120 * time.Second)
|
|
|
+ provider = &PGSQLProvider{dbHandle: dbHandle}
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+func getPGSQLHostsAndPorts(configHost string, configPort int) (string, string) {
|
|
|
+ var hosts, ports []string
|
|
|
+ defaultPort := strconv.Itoa(configPort)
|
|
|
+ if defaultPort == "0" {
|
|
|
+ defaultPort = "5432"
|
|
|
}
|
|
|
- return err
|
|
|
+
|
|
|
+ for _, hostport := range strings.Split(configHost, ",") {
|
|
|
+ hostport = strings.TrimSpace(hostport)
|
|
|
+ if hostport == "" {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ host, port, err := net.SplitHostPort(hostport)
|
|
|
+ if err == nil {
|
|
|
+ hosts = append(hosts, host)
|
|
|
+ ports = append(ports, port)
|
|
|
+ } else {
|
|
|
+ hosts = append(hosts, hostport)
|
|
|
+ ports = append(ports, defaultPort)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return strings.Join(hosts, ","), strings.Join(ports, ",")
|
|
|
}
|
|
|
|
|
|
func getPGSQLConnectionString(redactedPwd bool) string {
|
|
@@ -261,8 +299,9 @@ func getPGSQLConnectionString(redactedPwd bool) string {
|
|
|
if redactedPwd && password != "" {
|
|
|
password = "[redacted]"
|
|
|
}
|
|
|
- connectionString = fmt.Sprintf("host='%s' port=%d dbname='%s' user='%s' password='%s' sslmode=%s connect_timeout=10",
|
|
|
- config.Host, config.Port, config.Name, config.Username, password, getSSLMode())
|
|
|
+ host, port := getPGSQLHostsAndPorts(config.Host, config.Port)
|
|
|
+ connectionString = fmt.Sprintf("host='%s' port='%s' dbname='%s' user='%s' password='%s' sslmode=%s connect_timeout=10",
|
|
|
+ host, port, config.Name, config.Username, password, getSSLMode())
|
|
|
if config.RootCert != "" {
|
|
|
connectionString += fmt.Sprintf(" sslrootcert='%s'", config.RootCert)
|
|
|
}
|