|
@@ -7,7 +7,6 @@ import (
|
|
|
"os"
|
|
|
"time"
|
|
|
|
|
|
- "entgo.io/ent/dialect"
|
|
|
entsql "entgo.io/ent/dialect/sql"
|
|
|
"github.com/crowdsecurity/crowdsec/pkg/csconfig"
|
|
|
"github.com/crowdsecurity/crowdsec/pkg/database/ent"
|
|
@@ -61,8 +60,11 @@ func NewClient(config *csconfig.DatabaseCfg) (*Client, error) {
|
|
|
entLogger := clog.WithField("context", "ent")
|
|
|
|
|
|
entOpt := ent.Log(entLogger.Debug)
|
|
|
- switch config.Type {
|
|
|
- case "sqlite":
|
|
|
+ typ, dia, err := config.ConnectionDialect()
|
|
|
+ if err != nil {
|
|
|
+ return &Client{}, err //unsupported database caught here
|
|
|
+ }
|
|
|
+ if config.Type == "sqlite" {
|
|
|
/*if it's the first startup, we want to touch and chmod file*/
|
|
|
if _, err := os.Stat(config.DbPath); os.IsNotExist(err) {
|
|
|
f, err := os.OpenFile(config.DbPath, os.O_CREATE|os.O_RDWR, 0600)
|
|
@@ -77,45 +79,12 @@ func NewClient(config *csconfig.DatabaseCfg) (*Client, error) {
|
|
|
if err := setFilePerm(config.DbPath, 0640); err != nil {
|
|
|
return &Client{}, fmt.Errorf("unable to set perms on %s: %v", config.DbPath, err)
|
|
|
}
|
|
|
- var sqliteConnectionStringParameters string
|
|
|
- if config.UseWal != nil && *config.UseWal {
|
|
|
- sqliteConnectionStringParameters = "_busy_timeout=100000&_fk=1&_journal_mode=WAL"
|
|
|
- } else {
|
|
|
- sqliteConnectionStringParameters = "_busy_timeout=100000&_fk=1"
|
|
|
- }
|
|
|
- drv, err := getEntDriver("sqlite3", dialect.SQLite, fmt.Sprintf("file:%s?%s", config.DbPath, sqliteConnectionStringParameters), config)
|
|
|
- if err != nil {
|
|
|
- return &Client{}, errors.Wrapf(err, "failed opening connection to sqlite: %v", config.DbPath)
|
|
|
- }
|
|
|
- client = ent.NewClient(ent.Driver(drv), entOpt)
|
|
|
- case "mysql":
|
|
|
- connString := ""
|
|
|
- if config.Host == "" && config.Port == 0 && config.DbPath != "" {
|
|
|
- connString = fmt.Sprintf("%s:%s@unix(%s)/%s?parseTime=True", config.User, config.Password, config.DbPath, config.DbName)
|
|
|
- } else {
|
|
|
- connString = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=True", config.User, config.Password, config.Host, config.Port, config.DbName)
|
|
|
- }
|
|
|
- drv, err := getEntDriver("mysql", dialect.MySQL, connString, config)
|
|
|
- if err != nil {
|
|
|
- return &Client{}, fmt.Errorf("failed opening connection to mysql: %v", err)
|
|
|
- }
|
|
|
- client = ent.NewClient(ent.Driver(drv), entOpt)
|
|
|
- case "postgres", "postgresql":
|
|
|
- drv, err := getEntDriver("postgres", dialect.Postgres, fmt.Sprintf("host=%s port=%d user=%s dbname=%s password=%s sslmode=%s", config.Host, config.Port, config.User, config.DbName, config.Password, config.Sslmode), config)
|
|
|
- if err != nil {
|
|
|
- return &Client{}, fmt.Errorf("failed opening connection to postgresql: %v", err)
|
|
|
- }
|
|
|
- client = ent.NewClient(ent.Driver(drv), entOpt)
|
|
|
- case "pgx":
|
|
|
- drv, err := getEntDriver("pgx", dialect.Postgres, fmt.Sprintf("postgresql://%s:%s@%s:%d/%s?sslmode=%s", config.User, config.Password, config.Host, config.Port, config.DbName, config.Sslmode), config)
|
|
|
- if err != nil {
|
|
|
- return &Client{}, fmt.Errorf("failed opening connection to pgx: %v", err)
|
|
|
- }
|
|
|
- client = ent.NewClient(ent.Driver(drv), entOpt)
|
|
|
- default:
|
|
|
- return &Client{}, fmt.Errorf("unknown database type '%s'", config.Type)
|
|
|
}
|
|
|
-
|
|
|
+ drv, err := getEntDriver(typ, dia, config.ConnectionString(), config)
|
|
|
+ if err != nil {
|
|
|
+ return &Client{}, fmt.Errorf("failed opening connection to %s: %v", config.Type, err)
|
|
|
+ }
|
|
|
+ client = ent.NewClient(ent.Driver(drv), entOpt)
|
|
|
if config.LogLevel != nil && *config.LogLevel >= log.DebugLevel {
|
|
|
clog.Debugf("Enabling request debug")
|
|
|
client = client.Debug()
|