Bläddra i källkod

Add postgres socket support, clean some code (#1926)

Laurence Jones 2 år sedan
förälder
incheckning
fe23da6e0c
2 ändrade filer med 56 tillägg och 41 borttagningar
  1. 46 0
      pkg/csconfig/database.go
  2. 10 41
      pkg/database/database.go

+ 46 - 0
pkg/csconfig/database.go

@@ -4,6 +4,7 @@ import (
 	"fmt"
 	"time"
 
+	"entgo.io/ent/dialect"
 	"github.com/crowdsecurity/crowdsec/pkg/types"
 	log "github.com/sirupsen/logrus"
 )
@@ -67,3 +68,48 @@ func (c *Config) LoadDBConfig() error {
 
 	return nil
 }
+
+func (d *DatabaseCfg) ConnectionString() string {
+	connString := ""
+	switch d.Type {
+	case "sqlite":
+		var sqliteConnectionStringParameters string
+		if d.UseWal != nil && *d.UseWal {
+			sqliteConnectionStringParameters = "_busy_timeout=100000&_fk=1&_journal_mode=WAL"
+		} else {
+			sqliteConnectionStringParameters = "_busy_timeout=100000&_fk=1"
+		}
+		connString = fmt.Sprintf("file:%s?%s", d.DbPath, sqliteConnectionStringParameters)
+	case "mysql":
+		if d.isSocketConfig() {
+			connString = fmt.Sprintf("%s:%s@unix(%s)/%s?parseTime=True", d.User, d.Password, d.DbPath, d.DbName)
+		} else {
+			connString = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=True", d.User, d.Password, d.Host, d.Port, d.DbName)
+		}
+	case "postgres", "postgresql", "pgx":
+		if d.isSocketConfig() {
+			connString = fmt.Sprintf("host=%s user=%s dbname=%s password=%s", d.DbPath, d.User, d.DbName, d.Password)
+		} else {
+			connString = fmt.Sprintf("host=%s port=%d user=%s dbname=%s password=%s sslmode=%s", d.Host, d.Port, d.User, d.DbName, d.Password, d.Sslmode)
+		}
+	}
+	return connString
+}
+
+func (d *DatabaseCfg) ConnectionDialect() (string, string, error) {
+	switch d.Type {
+	case "sqlite":
+		return "sqlite3", dialect.SQLite, nil
+	case "mysql":
+		return "mysql", dialect.MySQL, nil
+	case "postgres", "postgresql":
+		return "postgres", dialect.Postgres, nil
+	case "pgx":
+		return "pgx", dialect.Postgres, nil
+	}
+	return "", "", fmt.Errorf("unknown database type '%s'", d.Type)
+}
+
+func (d *DatabaseCfg) isSocketConfig() bool {
+	return d.Host == "" && d.Port == 0 && d.DbPath != ""
+}

+ 10 - 41
pkg/database/database.go

@@ -7,7 +7,6 @@ import (
 	"os"
 	"time"
 
-	"entgo.io/ent/dialect"
 	entsql "entgo.io/ent/dialect/sql"
 	"github.com/crowdsecurity/crowdsec/pkg/csconfig"
 	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
@@ -61,8 +60,11 @@ func NewClient(config *csconfig.DatabaseCfg) (*Client, error) {
 	entLogger := clog.WithField("context", "ent")
 
 	entOpt := ent.Log(entLogger.Debug)
-	switch config.Type {
-	case "sqlite":
+	typ, dia, err := config.ConnectionDialect()
+	if err != nil {
+		return &Client{}, err //unsupported database caught here
+	}
+	if config.Type == "sqlite" {
 		/*if it's the first startup, we want to touch and chmod file*/
 		if _, err := os.Stat(config.DbPath); os.IsNotExist(err) {
 			f, err := os.OpenFile(config.DbPath, os.O_CREATE|os.O_RDWR, 0600)
@@ -77,45 +79,12 @@ func NewClient(config *csconfig.DatabaseCfg) (*Client, error) {
 		if err := setFilePerm(config.DbPath, 0640); err != nil {
 			return &Client{}, fmt.Errorf("unable to set perms on %s: %v", config.DbPath, err)
 		}
-		var sqliteConnectionStringParameters string
-		if config.UseWal != nil && *config.UseWal {
-			sqliteConnectionStringParameters = "_busy_timeout=100000&_fk=1&_journal_mode=WAL"
-		} else {
-			sqliteConnectionStringParameters = "_busy_timeout=100000&_fk=1"
-		}
-		drv, err := getEntDriver("sqlite3", dialect.SQLite, fmt.Sprintf("file:%s?%s", config.DbPath, sqliteConnectionStringParameters), config)
-		if err != nil {
-			return &Client{}, errors.Wrapf(err, "failed opening connection to sqlite: %v", config.DbPath)
-		}
-		client = ent.NewClient(ent.Driver(drv), entOpt)
-	case "mysql":
-		connString := ""
-		if config.Host == "" && config.Port == 0 && config.DbPath != "" {
-			connString = fmt.Sprintf("%s:%s@unix(%s)/%s?parseTime=True", config.User, config.Password, config.DbPath, config.DbName)
-		} else {
-			connString = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=True", config.User, config.Password, config.Host, config.Port, config.DbName)
-		}
-		drv, err := getEntDriver("mysql", dialect.MySQL, connString, config)
-		if err != nil {
-			return &Client{}, fmt.Errorf("failed opening connection to mysql: %v", err)
-		}
-		client = ent.NewClient(ent.Driver(drv), entOpt)
-	case "postgres", "postgresql":
-		drv, err := getEntDriver("postgres", dialect.Postgres, fmt.Sprintf("host=%s port=%d user=%s dbname=%s password=%s sslmode=%s", config.Host, config.Port, config.User, config.DbName, config.Password, config.Sslmode), config)
-		if err != nil {
-			return &Client{}, fmt.Errorf("failed opening connection to postgresql: %v", err)
-		}
-		client = ent.NewClient(ent.Driver(drv), entOpt)
-	case "pgx":
-		drv, err := getEntDriver("pgx", dialect.Postgres, fmt.Sprintf("postgresql://%s:%s@%s:%d/%s?sslmode=%s", config.User, config.Password, config.Host, config.Port, config.DbName, config.Sslmode), config)
-		if err != nil {
-			return &Client{}, fmt.Errorf("failed opening connection to pgx: %v", err)
-		}
-		client = ent.NewClient(ent.Driver(drv), entOpt)
-	default:
-		return &Client{}, fmt.Errorf("unknown database type '%s'", config.Type)
 	}
-
+	drv, err := getEntDriver(typ, dia, config.ConnectionString(), config)
+	if err != nil {
+		return &Client{}, fmt.Errorf("failed opening connection to %s: %v", config.Type, err)
+	}
+	client = ent.NewClient(ent.Driver(drv), entOpt)
 	if config.LogLevel != nil && *config.LogLevel >= log.DebugLevel {
 		clog.Debugf("Enabling request debug")
 		client = client.Debug()