database.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. package database
  import (
	"context"
	"database/sql"
	"fmt"
	"net"
	"net/url"
	"os"
	"strings"
	"time"

	"entgo.io/ent/dialect"
	entsql "entgo.io/ent/dialect/sql"
	"github.com/crowdsecurity/crowdsec/pkg/csconfig"
	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
	"github.com/crowdsecurity/crowdsec/pkg/types"
	"github.com/go-co-op/gocron"
	_ "github.com/go-sql-driver/mysql"
	_ "github.com/jackc/pgx/v4/stdlib"
	_ "github.com/lib/pq"
	_ "github.com/mattn/go-sqlite3"
	"github.com/pkg/errors"
	log "github.com/sirupsen/logrus"
)
  21. type Client struct {
  22. Ent *ent.Client
  23. CTX context.Context
  24. Log *log.Logger
  25. CanFlush bool
  26. }
  27. func getEntDriver(dbtype string, dbdialect string, dsn string, config *csconfig.DatabaseCfg) (*entsql.Driver, error) {
  28. db, err := sql.Open(dbtype, dsn)
  29. if err != nil {
  30. return nil, err
  31. }
  32. if config.MaxOpenConns == nil {
  33. log.Warningf("MaxOpenConns is 0, defaulting to %d", csconfig.DEFAULT_MAX_OPEN_CONNS)
  34. config.MaxOpenConns = types.IntPtr(csconfig.DEFAULT_MAX_OPEN_CONNS)
  35. }
  36. db.SetMaxOpenConns(*config.MaxOpenConns)
  37. drv := entsql.OpenDB(dbdialect, db)
  38. return drv, nil
  39. }
  40. func NewClient(config *csconfig.DatabaseCfg) (*Client, error) {
  41. var client *ent.Client
  42. var err error
  43. if config == nil {
  44. return &Client{}, fmt.Errorf("DB config is empty")
  45. }
  46. /*The logger that will be used by db operations*/
  47. clog := log.New()
  48. if err := types.ConfigureLogger(clog); err != nil {
  49. return nil, errors.Wrap(err, "while configuring db logger")
  50. }
  51. if config.LogLevel != nil {
  52. clog.SetLevel(*config.LogLevel)
  53. }
  54. entLogger := clog.WithField("context", "ent")
  55. entOpt := ent.Log(entLogger.Debug)
  56. switch config.Type {
  57. case "sqlite":
  58. /*if it's the first startup, we want to touch and chmod file*/
  59. if _, err := os.Stat(config.DbPath); os.IsNotExist(err) {
  60. f, err := os.OpenFile(config.DbPath, os.O_CREATE|os.O_RDWR, 0600)
  61. if err != nil {
  62. return &Client{}, errors.Wrapf(err, "failed to create SQLite database file %q", config.DbPath)
  63. }
  64. if err := f.Close(); err != nil {
  65. return &Client{}, errors.Wrapf(err, "failed to create SQLite database file %q", config.DbPath)
  66. }
  67. } else { /*ensure file perms*/
  68. if err := os.Chmod(config.DbPath, 0660); err != nil {
  69. return &Client{}, fmt.Errorf("unable to set perms on %s: %v", config.DbPath, err)
  70. }
  71. }
  72. drv, err := getEntDriver("sqlite3", dialect.SQLite, fmt.Sprintf("file:%s?_busy_timeout=100000&_fk=1", config.DbPath), config)
  73. if err != nil {
  74. return &Client{}, errors.Wrapf(err, "failed opening connection to sqlite: %v", config.DbPath)
  75. }
  76. client = ent.NewClient(ent.Driver(drv), entOpt)
  77. case "mysql":
  78. drv, err := getEntDriver("mysql", dialect.MySQL, fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=True", config.User, config.Password, config.Host, config.Port, config.DbName), config)
  79. if err != nil {
  80. return &Client{}, fmt.Errorf("failed opening connection to mysql: %v", err)
  81. }
  82. client = ent.NewClient(ent.Driver(drv), entOpt)
  83. case "postgres", "postgresql":
  84. drv, err := getEntDriver("postgres", dialect.Postgres, fmt.Sprintf("host=%s port=%d user=%s dbname=%s password=%s sslmode=%s", config.Host, config.Port, config.User, config.DbName, config.Password, config.Sslmode), config)
  85. if err != nil {
  86. return &Client{}, fmt.Errorf("failed opening connection to postgresql: %v", err)
  87. }
  88. client = ent.NewClient(ent.Driver(drv), entOpt)
  89. case "pgx":
  90. drv, err := getEntDriver("pgx", dialect.Postgres, fmt.Sprintf("postgresql://%s:%s@%s:%d/%s?sslmode=%s", config.User, config.Password, config.Host, config.Port, config.DbName, config.Sslmode), config)
  91. if err != nil {
  92. return &Client{}, fmt.Errorf("failed opening connection to pgx: %v", err)
  93. }
  94. client = ent.NewClient(ent.Driver(drv), entOpt)
  95. default:
  96. return &Client{}, fmt.Errorf("unknown database type '%s'", config.Type)
  97. }
  98. if config.LogLevel != nil && *config.LogLevel >= log.DebugLevel {
  99. clog.Debugf("Enabling request debug")
  100. client = client.Debug()
  101. }
  102. if err = client.Schema.Create(context.Background()); err != nil {
  103. return nil, fmt.Errorf("failed creating schema resources: %v", err)
  104. }
  105. return &Client{Ent: client, CTX: context.Background(), Log: clog, CanFlush: true}, nil
  106. }
  107. func (c *Client) StartFlushScheduler(config *csconfig.FlushDBCfg) (*gocron.Scheduler, error) {
  108. maxItems := 0
  109. maxAge := ""
  110. if config.MaxItems != nil && *config.MaxItems <= 0 {
  111. return nil, fmt.Errorf("max_items can't be zero or negative number")
  112. }
  113. if config.MaxItems != nil {
  114. maxItems = *config.MaxItems
  115. }
  116. if config.MaxAge != nil && *config.MaxAge != "" {
  117. maxAge = *config.MaxAge
  118. }
  119. // Init & Start cronjob every minute
  120. scheduler := gocron.NewScheduler(time.UTC)
  121. job, _ := scheduler.Every(1).Minute().Do(c.FlushAlerts, maxAge, maxItems)
  122. job.SingletonMode()
  123. scheduler.StartAsync()
  124. return scheduler, nil
  125. }