|
@@ -7,7 +7,6 @@ import (
|
|
|
"os"
|
|
|
"time"
|
|
|
|
|
|
- "entgo.io/ent/dialect"
|
|
|
entsql "entgo.io/ent/dialect/sql"
|
|
|
"github.com/crowdsecurity/crowdsec/pkg/csconfig"
|
|
|
"github.com/crowdsecurity/crowdsec/pkg/database/ent"
|
|
@@ -28,6 +27,20 @@ type Client struct {
|
|
|
CanFlush bool
|
|
|
}
|
|
|
|
|
|
+func getEntDriver(dbtype string, dsn string, config *csconfig.DatabaseCfg) (*entsql.Driver, error) {
|
|
|
+ db, err := sql.Open(dbtype, dsn)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ if config.MaxOpenConns == nil {
|
|
|
+ log.Warningf("MaxOpenConns is 0, defaulting to %d", csconfig.DEFAULT_MAX_OPEN_CONNS)
|
|
|
+ config.MaxOpenConns = types.IntPtr(csconfig.DEFAULT_MAX_OPEN_CONNS)
|
|
|
+ }
|
|
|
+ db.SetMaxOpenConns(*config.MaxOpenConns)
|
|
|
+ drv := entsql.OpenDB(dbtype, db)
|
|
|
+ return drv, nil
|
|
|
+}
|
|
|
+
|
|
|
func NewClient(config *csconfig.DatabaseCfg) (*Client, error) {
|
|
|
var client *ent.Client
|
|
|
var err error
|
|
@@ -62,27 +75,28 @@ func NewClient(config *csconfig.DatabaseCfg) (*Client, error) {
|
|
|
return &Client{}, fmt.Errorf("unable to set perms on %s: %v", config.DbPath, err)
|
|
|
}
|
|
|
}
|
|
|
- client, err = ent.Open("sqlite3", fmt.Sprintf("file:%s?_busy_timeout=100000&_fk=1", config.DbPath), entOpt)
|
|
|
+ drv, err := getEntDriver("sqlite3", fmt.Sprintf("file:%s?_busy_timeout=100000&_fk=1", config.DbPath), config)
|
|
|
if err != nil {
|
|
|
- return &Client{}, fmt.Errorf("failed opening connection to sqlite: %v", err)
|
|
|
+ return &Client{}, errors.Wrapf(err, "failed opening connection to sqlite: %v", config.DbPath)
|
|
|
}
|
|
|
+ client = ent.NewClient(ent.Driver(drv), entOpt)
|
|
|
case "mysql":
|
|
|
- client, err = ent.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=True", config.User, config.Password, config.Host, config.Port, config.DbName), entOpt)
|
|
|
+ drv, err := getEntDriver("mysql", fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=True", config.User, config.Password, config.Host, config.Port, config.DbName), config)
|
|
|
if err != nil {
|
|
|
return &Client{}, fmt.Errorf("failed opening connection to mysql: %v", err)
|
|
|
}
|
|
|
+ client = ent.NewClient(ent.Driver(drv), entOpt)
|
|
|
case "postgres", "postgresql":
|
|
|
- client, err = ent.Open("postgres", fmt.Sprintf("host=%s port=%d user=%s dbname=%s password=%s sslmode=%s", config.Host, config.Port, config.User, config.DbName, config.Password, config.Sslmode), entOpt)
|
|
|
+ drv, err := getEntDriver("postgres", fmt.Sprintf("host=%s port=%d user=%s dbname=%s password=%s sslmode=%s", config.Host, config.Port, config.User, config.DbName, config.Password, config.Sslmode), config)
|
|
|
if err != nil {
|
|
|
- return &Client{}, fmt.Errorf("failed opening connection to postgres: %v", err)
|
|
|
+ return &Client{}, fmt.Errorf("failed opening connection to postgresql: %v", err)
|
|
|
}
|
|
|
+ client = ent.NewClient(ent.Driver(drv), entOpt)
|
|
|
case "pgx":
|
|
|
- db, err := sql.Open("pgx", fmt.Sprintf("postgresql://%s:%s@%s:%d/%s?sslmode=%s", config.User, config.Password, config.Host, config.Port, config.DbName, config.Sslmode))
|
|
|
+ drv, err := getEntDriver("pgx", fmt.Sprintf("postgresql://%s:%s@%s:%d/%s?sslmode=%s", config.User, config.Password, config.Host, config.Port, config.DbName, config.Sslmode), config)
|
|
|
if err != nil {
|
|
|
return &Client{}, fmt.Errorf("failed opening connection to pgx: %v", err)
|
|
|
}
|
|
|
- // Create an ent.Driver from `db`.
|
|
|
- drv := entsql.OpenDB(dialect.Postgres, db)
|
|
|
client = ent.NewClient(ent.Driver(drv), entOpt)
|
|
|
default:
|
|
|
return &Client{}, fmt.Errorf("unknown database type")
|