Merge branch 'master' into http_plugin_unix_socket

This commit is contained in:
Laurence Jones 2024-02-15 13:24:37 +00:00 committed by GitHub
commit 90a621d233
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
146 changed files with 9598 additions and 9505 deletions

View file

@ -8,9 +8,6 @@ on:
GIST_BADGES_ID:
required: true
env:
PREFIX_TEST_NAMES_WITH_FILE: true
jobs:
build:
strategy:
@ -36,7 +33,7 @@ jobs:
- name: "Set up Go"
uses: actions/setup-go@v5
with:
go-version: "1.21.6"
go-version: "1.21.7"
- name: "Install bats dependencies"
env:
@ -50,7 +47,7 @@ jobs:
- name: "Run hub tests"
run: |
./test/bin/generate-hub-tests
./test/run-tests test/dyn-bats/${{ matrix.test-file }}
./test/run-tests ./test/dyn-bats/${{ matrix.test-file }} --formatter $(pwd)/test/lib/color-formatter
- name: "Collect hub coverage"
run: ./test/bin/collect-hub-coverage >> $GITHUB_ENV

View file

@ -7,9 +7,6 @@ on:
required: true
type: string
env:
PREFIX_TEST_NAMES_WITH_FILE: true
jobs:
build:
name: "Functional tests"
@ -39,7 +36,7 @@ jobs:
- name: "Set up Go"
uses: actions/setup-go@v5
with:
go-version: "1.21.6"
go-version: "1.21.7"
- name: "Install bats dependencies"
env:
@ -58,7 +55,7 @@ jobs:
MYSQL_USER: root
- name: "Run tests"
run: make bats-test
run: ./test/run-tests ./test/bats --formatter $(pwd)/test/lib/color-formatter
env:
DB_BACKEND: mysql
MYSQL_HOST: 127.0.0.1

View file

@ -3,9 +3,6 @@ name: (sub) Bats / Postgres
on:
workflow_call:
env:
PREFIX_TEST_NAMES_WITH_FILE: true
jobs:
build:
name: "Functional tests"
@ -48,7 +45,7 @@ jobs:
- name: "Set up Go"
uses: actions/setup-go@v5
with:
go-version: "1.21.6"
go-version: "1.21.7"
- name: "Install bats dependencies"
env:
@ -67,7 +64,7 @@ jobs:
PGUSER: postgres
- name: "Run tests (DB_BACKEND: pgx)"
run: make bats-test
run: ./test/run-tests ./test/bats --formatter $(pwd)/test/lib/color-formatter
env:
DB_BACKEND: pgx
PGHOST: 127.0.0.1

View file

@ -4,7 +4,6 @@ on:
workflow_call:
env:
PREFIX_TEST_NAMES_WITH_FILE: true
TEST_COVERAGE: true
jobs:
@ -29,7 +28,7 @@ jobs:
- name: "Set up Go"
uses: actions/setup-go@v5
with:
go-version: "1.21.6"
go-version: "1.21.7"
- name: "Install bats dependencies"
env:
@ -42,7 +41,7 @@ jobs:
make clean bats-build bats-fixture BUILD_STATIC=1
- name: "Run tests"
run: make bats-test
run: ./test/run-tests ./test/bats --formatter $(pwd)/test/lib/color-formatter
- name: "Collect coverage data"
run: |

View file

@ -35,7 +35,7 @@ jobs:
- name: "Set up Go"
uses: actions/setup-go@v5
with:
go-version: "1.21.6"
go-version: "1.21.7"
- name: Build
run: make windows_installer BUILD_RE2_WASM=1

View file

@ -49,9 +49,15 @@ jobs:
# required to pick up tags for BUILD_VERSION
fetch-depth: 0
- name: "Set up Go"
uses: actions/setup-go@v5
with:
go-version: "1.21.7"
cache-dependency-path: "**/go.sum"
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@v2
uses: github/codeql-action/init@v3
with:
languages: ${{ matrix.language }}
# If you wish to specify custom queries, you can do so here or in a config file.
@ -71,14 +77,8 @@ jobs:
# and modify them (or add more) to build your code if your project
# uses a compiled language
- name: "Set up Go"
uses: actions/setup-go@v5
with:
go-version: "1.21.6"
cache-dependency-path: "**/go.sum"
- run: |
make clean build BUILD_RE2_WASM=1
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v2
uses: github/codeql-action/analyze@v3

View file

@ -50,7 +50,7 @@ jobs:
cache-to: type=gha,mode=min
- name: "Setup Python"
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: "3.x"
@ -61,7 +61,7 @@ jobs:
- name: "Cache virtualenvs"
id: cache-pipenv
uses: actions/cache@v3
uses: actions/cache@v4
with:
path: ~/.local/share/virtualenvs
key: ${{ runner.os }}-pipenv-${{ hashFiles('**/Pipfile.lock') }}

View file

@ -34,7 +34,7 @@ jobs:
- name: "Set up Go"
uses: actions/setup-go@v5
with:
go-version: "1.21.6"
go-version: "1.21.7"
- name: Build
run: |

View file

@ -126,7 +126,7 @@ jobs:
- name: "Set up Go"
uses: actions/setup-go@v5
with:
go-version: "1.21.6"
go-version: "1.21.7"
- name: Create localstack streams
run: |

View file

@ -25,7 +25,7 @@ jobs:
- name: "Set up Go"
uses: actions/setup-go@v5
with:
go-version: "1.21.6"
go-version: "1.21.7"
- name: Build the binaries
run: |

View file

@ -11,7 +11,7 @@ run:
linters-settings:
cyclop:
# lower this after refactoring
max-complexity: 66
max-complexity: 53
gci:
sections:
@ -26,7 +26,7 @@ linters-settings:
gocyclo:
# lower this after refactoring
min-complexity: 64
min-complexity: 49
funlen:
# Checks the number of lines in a function.
@ -46,14 +46,14 @@ linters-settings:
maintidx:
# raise this after refactoring
under: 9
under: 11
misspell:
locale: US
nestif:
# lower this after refactoring
min-complexity: 27
min-complexity: 28
nlreturn:
block-size: 4
@ -263,6 +263,10 @@ issues:
- perfsprint
text: "fmt.Sprintf can be replaced .*"
- linters:
- perfsprint
text: "fmt.Errorf can be replaced with errors.New"
#
# Will fix, easy but some neurons required
#
@ -310,10 +314,6 @@ issues:
# Will fix, might be trickier
#
- linters:
- staticcheck
text: "x509.ParseCRL has been deprecated since Go 1.19: Use ParseRevocationList instead"
# https://github.com/pkg/errors/issues/245
- linters:
- depguard

View file

@ -1,5 +1,5 @@
# vim: set ft=dockerfile:
FROM golang:1.21.6-alpine3.18 AS build
FROM golang:1.21.7-alpine3.18 AS build
ARG BUILD_VERSION

View file

@ -1,5 +1,5 @@
# vim: set ft=dockerfile:
FROM golang:1.21.6-bookworm AS build
FROM golang:1.21.7-bookworm AS build
ARG BUILD_VERSION

View file

@ -27,7 +27,7 @@ stages:
- task: GoTool@0
displayName: "Install Go"
inputs:
version: '1.21.6'
version: '1.21.7'
- pwsh: |
choco install -y make

View file

@ -29,39 +29,46 @@ import (
func DecisionsFromAlert(alert *models.Alert) string {
ret := ""
var decMap = make(map[string]int)
decMap := make(map[string]int)
for _, decision := range alert.Decisions {
k := *decision.Type
if *decision.Simulated {
k = fmt.Sprintf("(simul)%s", k)
}
v := decMap[k]
decMap[k] = v + 1
}
for k, v := range decMap {
if len(ret) > 0 {
ret += " "
}
ret += fmt.Sprintf("%s:%d", k, v)
}
return ret
}
func alertsToTable(alerts *models.GetAlertsResponse, printMachine bool) error {
switch csConfig.Cscli.Output {
func (cli *cliAlerts) alertsToTable(alerts *models.GetAlertsResponse, printMachine bool) error {
switch cli.cfg().Cscli.Output {
case "raw":
csvwriter := csv.NewWriter(os.Stdout)
header := []string{"id", "scope", "value", "reason", "country", "as", "decisions", "created_at"}
if printMachine {
header = append(header, "machine")
}
err := csvwriter.Write(header)
if err != nil {
if err := csvwriter.Write(header); err != nil {
return err
}
for _, alertItem := range *alerts {
row := []string{
fmt.Sprintf("%d", alertItem.ID),
strconv.FormatInt(alertItem.ID, 10),
*alertItem.Source.Scope,
*alertItem.Source.Value,
*alertItem.Scenario,
@ -73,11 +80,12 @@ func alertsToTable(alerts *models.GetAlertsResponse, printMachine bool) error {
if printMachine {
row = append(row, alertItem.MachineID)
}
err := csvwriter.Write(row)
if err != nil {
if err := csvwriter.Write(row); err != nil {
return err
}
}
csvwriter.Flush()
case "json":
if *alerts == nil {
@ -86,6 +94,7 @@ func alertsToTable(alerts *models.GetAlertsResponse, printMachine bool) error {
fmt.Println("[]")
return nil
}
x, _ := json.MarshalIndent(alerts, "", " ")
fmt.Print(string(x))
case "human":
@ -93,8 +102,10 @@ func alertsToTable(alerts *models.GetAlertsResponse, printMachine bool) error {
fmt.Println("No active alerts")
return nil
}
alertsTable(color.Output, alerts, printMachine)
}
return nil
}
@ -116,13 +127,13 @@ var alertTemplate = `
`
func displayOneAlert(alert *models.Alert, withDetail bool) error {
func (cli *cliAlerts) displayOneAlert(alert *models.Alert, withDetail bool) error {
tmpl, err := template.New("alert").Parse(alertTemplate)
if err != nil {
return err
}
err = tmpl.Execute(os.Stdout, alert)
if err != nil {
if err = tmpl.Execute(os.Stdout, alert); err != nil {
return err
}
@ -133,14 +144,17 @@ func displayOneAlert(alert *models.Alert, withDetail bool) error {
sort.Slice(alert.Meta, func(i, j int) bool {
return alert.Meta[i].Key < alert.Meta[j].Key
})
table := newTable(color.Output)
table.SetRowLines(false)
table.SetHeaders("Key", "Value")
for _, meta := range alert.Meta {
var valSlice []string
if err := json.Unmarshal([]byte(meta.Value), &valSlice); err != nil {
return fmt.Errorf("unknown context value type '%s' : %s", meta.Value, err)
return fmt.Errorf("unknown context value type '%s': %w", meta.Value, err)
}
for _, value := range valSlice {
table.AddRow(
meta.Key,
@ -148,11 +162,13 @@ func displayOneAlert(alert *models.Alert, withDetail bool) error {
)
}
}
table.Render()
}
if withDetail {
fmt.Printf("\n - Events :\n")
for _, event := range alert.Events {
alertEventTable(color.Output, event)
}
@ -163,10 +179,13 @@ func displayOneAlert(alert *models.Alert, withDetail bool) error {
type cliAlerts struct{
client *apiclient.ApiClient
cfg configGetter
}
func NewCLIAlerts() *cliAlerts {
return &cliAlerts{}
func NewCLIAlerts(getconfig configGetter) *cliAlerts {
return &cliAlerts{
cfg: getconfig,
}
}
func (cli *cliAlerts) NewCommand() *cobra.Command {
@ -176,18 +195,18 @@ func (cli *cliAlerts) NewCommand() *cobra.Command {
Args: cobra.MinimumNArgs(1),
DisableAutoGenTag: true,
Aliases: []string{"alert"},
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
var err error
if err := csConfig.LoadAPIClient(); err != nil {
PersistentPreRunE: func(_ *cobra.Command, _ []string) error {
cfg := cli.cfg()
if err := cfg.LoadAPIClient(); err != nil {
return fmt.Errorf("loading api client: %w", err)
}
apiURL, err := url.Parse(csConfig.API.Client.Credentials.URL)
apiURL, err := url.Parse(cfg.API.Client.Credentials.URL)
if err != nil {
return fmt.Errorf("parsing api url %s: %w", apiURL, err)
}
cli.client, err = apiclient.NewClient(&apiclient.Config{
MachineID: csConfig.API.Client.Credentials.Login,
Password: strfmt.Password(csConfig.API.Client.Credentials.Password),
MachineID: cfg.API.Client.Credentials.Login,
Password: strfmt.Password(cfg.API.Client.Credentials.Password),
UserAgent: fmt.Sprintf("crowdsec/%s", version.String()),
URL: apiURL,
VersionPrefix: "v1",
@ -196,6 +215,7 @@ func (cli *cliAlerts) NewCommand() *cobra.Command {
if err != nil {
return fmt.Errorf("new api client: %w", err)
}
return nil
},
}
@ -221,8 +241,10 @@ func (cli *cliAlerts) NewListCmd() *cobra.Command {
IncludeCAPI: new(bool),
OriginEquals: new(string),
}
limit := new(int)
contained := new(bool)
var printMachine bool
cmd := &cobra.Command{
@ -234,9 +256,7 @@ cscli alerts list --range 1.2.3.0/24
cscli alerts list -s crowdsecurity/ssh-bf
cscli alerts list --type ban`,
DisableAutoGenTag: true,
RunE: func(cmd *cobra.Command, args []string) error {
var err error
RunE: func(cmd *cobra.Command, _ []string) error {
if err := manageCliDecisionAlerts(alertListFilter.IPEquals, alertListFilter.RangeEquals,
alertListFilter.ScopeEquals, alertListFilter.ValueEquals); err != nil {
printHelp(cmd)
@ -304,40 +324,43 @@ cscli alerts list --type ban`,
alerts, _, err := cli.client.Alerts.List(context.Background(), alertListFilter)
if err != nil {
return fmt.Errorf("unable to list alerts: %v", err)
return fmt.Errorf("unable to list alerts: %w", err)
}
err = alertsToTable(alerts, printMachine)
if err != nil {
return fmt.Errorf("unable to list alerts: %v", err)
if err = cli.alertsToTable(alerts, printMachine); err != nil {
return fmt.Errorf("unable to list alerts: %w", err)
}
return nil
},
}
cmd.Flags().SortFlags = false
cmd.Flags().BoolVarP(alertListFilter.IncludeCAPI, "all", "a", false, "Include decisions from Central API")
cmd.Flags().StringVar(alertListFilter.Until, "until", "", "restrict to alerts older than until (ie. 4h, 30d)")
cmd.Flags().StringVar(alertListFilter.Since, "since", "", "restrict to alerts newer than since (ie. 4h, 30d)")
cmd.Flags().StringVarP(alertListFilter.IPEquals, "ip", "i", "", "restrict to alerts from this source ip (shorthand for --scope ip --value <IP>)")
cmd.Flags().StringVarP(alertListFilter.ScenarioEquals, "scenario", "s", "", "the scenario (ie. crowdsecurity/ssh-bf)")
cmd.Flags().StringVarP(alertListFilter.RangeEquals, "range", "r", "", "restrict to alerts from this range (shorthand for --scope range --value <RANGE/X>)")
cmd.Flags().StringVar(alertListFilter.TypeEquals, "type", "", "restrict to alerts with given decision type (ie. ban, captcha)")
cmd.Flags().StringVar(alertListFilter.ScopeEquals, "scope", "", "restrict to alerts of this scope (ie. ip,range)")
cmd.Flags().StringVarP(alertListFilter.ValueEquals, "value", "v", "", "the value to match for in the specified scope")
cmd.Flags().StringVar(alertListFilter.OriginEquals, "origin", "", fmt.Sprintf("the value to match for the specified origin (%s ...)", strings.Join(types.GetOrigins(), ",")))
cmd.Flags().BoolVar(contained, "contained", false, "query decisions contained by range")
cmd.Flags().BoolVarP(&printMachine, "machine", "m", false, "print machines that sent alerts")
cmd.Flags().IntVarP(limit, "limit", "l", 50, "limit size of alerts list table (0 to view all alerts)")
flags := cmd.Flags()
flags.SortFlags = false
flags.BoolVarP(alertListFilter.IncludeCAPI, "all", "a", false, "Include decisions from Central API")
flags.StringVar(alertListFilter.Until, "until", "", "restrict to alerts older than until (ie. 4h, 30d)")
flags.StringVar(alertListFilter.Since, "since", "", "restrict to alerts newer than since (ie. 4h, 30d)")
flags.StringVarP(alertListFilter.IPEquals, "ip", "i", "", "restrict to alerts from this source ip (shorthand for --scope ip --value <IP>)")
flags.StringVarP(alertListFilter.ScenarioEquals, "scenario", "s", "", "the scenario (ie. crowdsecurity/ssh-bf)")
flags.StringVarP(alertListFilter.RangeEquals, "range", "r", "", "restrict to alerts from this range (shorthand for --scope range --value <RANGE/X>)")
flags.StringVar(alertListFilter.TypeEquals, "type", "", "restrict to alerts with given decision type (ie. ban, captcha)")
flags.StringVar(alertListFilter.ScopeEquals, "scope", "", "restrict to alerts of this scope (ie. ip,range)")
flags.StringVarP(alertListFilter.ValueEquals, "value", "v", "", "the value to match for in the specified scope")
flags.StringVar(alertListFilter.OriginEquals, "origin", "", fmt.Sprintf("the value to match for the specified origin (%s ...)", strings.Join(types.GetOrigins(), ",")))
flags.BoolVar(contained, "contained", false, "query decisions contained by range")
flags.BoolVarP(&printMachine, "machine", "m", false, "print machines that sent alerts")
flags.IntVarP(limit, "limit", "l", 50, "limit size of alerts list table (0 to view all alerts)")
return cmd
}
func (cli *cliAlerts) NewDeleteCmd() *cobra.Command {
var ActiveDecision *bool
var AlertDeleteAll bool
var delAlertByID string
contained := new(bool)
var (
ActiveDecision *bool
AlertDeleteAll bool
delAlertByID string
)
var alertDeleteFilter = apiclient.AlertsDeleteOpts{
ScopeEquals: new(string),
ValueEquals: new(string),
@ -345,6 +368,9 @@ func (cli *cliAlerts) NewDeleteCmd() *cobra.Command {
IPEquals: new(string),
RangeEquals: new(string),
}
contained := new(bool)
cmd := &cobra.Command{
Use: "delete [filters] [--all]",
Short: `Delete alerts
@ -355,7 +381,7 @@ cscli alerts delete -s crowdsecurity/ssh-bf"`,
DisableAutoGenTag: true,
Aliases: []string{"remove"},
Args: cobra.ExactArgs(0),
PreRunE: func(cmd *cobra.Command, args []string) error {
PreRunE: func(cmd *cobra.Command, _ []string) error {
if AlertDeleteAll {
return nil
}
@ -368,11 +394,11 @@ cscli alerts delete -s crowdsecurity/ssh-bf"`,
return nil
},
RunE: func(cmd *cobra.Command, args []string) error {
RunE: func(cmd *cobra.Command, _ []string) error {
var err error
if !AlertDeleteAll {
if err := manageCliDecisionAlerts(alertDeleteFilter.IPEquals, alertDeleteFilter.RangeEquals,
if err = manageCliDecisionAlerts(alertDeleteFilter.IPEquals, alertDeleteFilter.RangeEquals,
alertDeleteFilter.ScopeEquals, alertDeleteFilter.ValueEquals); err != nil {
printHelp(cmd)
return err
@ -410,12 +436,12 @@ cscli alerts delete -s crowdsecurity/ssh-bf"`,
if delAlertByID == "" {
alerts, _, err = cli.client.Alerts.Delete(context.Background(), alertDeleteFilter)
if err != nil {
return fmt.Errorf("unable to delete alerts : %v", err)
return fmt.Errorf("unable to delete alerts: %w", err)
}
} else {
alerts, _, err = cli.client.Alerts.DeleteOne(context.Background(), delAlertByID)
if err != nil {
return fmt.Errorf("unable to delete alert: %v", err)
return fmt.Errorf("unable to delete alert: %w", err)
}
}
log.Infof("%s alert(s) deleted", alerts.NbDeleted)
@ -423,26 +449,31 @@ cscli alerts delete -s crowdsecurity/ssh-bf"`,
return nil
},
}
cmd.Flags().SortFlags = false
cmd.Flags().StringVar(alertDeleteFilter.ScopeEquals, "scope", "", "the scope (ie. ip,range)")
cmd.Flags().StringVarP(alertDeleteFilter.ValueEquals, "value", "v", "", "the value to match for in the specified scope")
cmd.Flags().StringVarP(alertDeleteFilter.ScenarioEquals, "scenario", "s", "", "the scenario (ie. crowdsecurity/ssh-bf)")
cmd.Flags().StringVarP(alertDeleteFilter.IPEquals, "ip", "i", "", "Source ip (shorthand for --scope ip --value <IP>)")
cmd.Flags().StringVarP(alertDeleteFilter.RangeEquals, "range", "r", "", "Range source ip (shorthand for --scope range --value <RANGE>)")
cmd.Flags().StringVar(&delAlertByID, "id", "", "alert ID")
cmd.Flags().BoolVarP(&AlertDeleteAll, "all", "a", false, "delete all alerts")
cmd.Flags().BoolVar(contained, "contained", false, "query decisions contained by range")
flags := cmd.Flags()
flags.SortFlags = false
flags.StringVar(alertDeleteFilter.ScopeEquals, "scope", "", "the scope (ie. ip,range)")
flags.StringVarP(alertDeleteFilter.ValueEquals, "value", "v", "", "the value to match for in the specified scope")
flags.StringVarP(alertDeleteFilter.ScenarioEquals, "scenario", "s", "", "the scenario (ie. crowdsecurity/ssh-bf)")
flags.StringVarP(alertDeleteFilter.IPEquals, "ip", "i", "", "Source ip (shorthand for --scope ip --value <IP>)")
flags.StringVarP(alertDeleteFilter.RangeEquals, "range", "r", "", "Range source ip (shorthand for --scope range --value <RANGE>)")
flags.StringVar(&delAlertByID, "id", "", "alert ID")
flags.BoolVarP(&AlertDeleteAll, "all", "a", false, "delete all alerts")
flags.BoolVar(contained, "contained", false, "query decisions contained by range")
return cmd
}
func (cli *cliAlerts) NewInspectCmd() *cobra.Command {
var details bool
cmd := &cobra.Command{
Use: `inspect "alert_id"`,
Short: `Show info about an alert`,
Example: `cscli alerts inspect 123`,
DisableAutoGenTag: true,
RunE: func(cmd *cobra.Command, args []string) error {
cfg := cli.cfg()
if len(args) == 0 {
printHelp(cmd)
return fmt.Errorf("missing alert_id")
@ -454,31 +485,32 @@ func (cli *cliAlerts) NewInspectCmd() *cobra.Command {
}
alert, _, err := cli.client.Alerts.GetByID(context.Background(), id)
if err != nil {
return fmt.Errorf("can't find alert with id %s: %s", alertID, err)
return fmt.Errorf("can't find alert with id %s: %w", alertID, err)
}
switch csConfig.Cscli.Output {
switch cfg.Cscli.Output {
case "human":
if err := displayOneAlert(alert, details); err != nil {
if err := cli.displayOneAlert(alert, details); err != nil {
continue
}
case "json":
data, err := json.MarshalIndent(alert, "", " ")
if err != nil {
return fmt.Errorf("unable to marshal alert with id %s: %s", alertID, err)
return fmt.Errorf("unable to marshal alert with id %s: %w", alertID, err)
}
fmt.Printf("%s\n", string(data))
case "raw":
data, err := yaml.Marshal(alert)
if err != nil {
return fmt.Errorf("unable to marshal alert with id %s: %s", alertID, err)
return fmt.Errorf("unable to marshal alert with id %s: %w", alertID, err)
}
fmt.Printf("%s\n", string(data))
fmt.Println(string(data))
}
}
return nil
},
}
cmd.Flags().SortFlags = false
cmd.Flags().BoolVarP(&details, "details", "d", false, "show alerts with events")
@ -486,27 +518,30 @@ func (cli *cliAlerts) NewInspectCmd() *cobra.Command {
}
func (cli *cliAlerts) NewFlushCmd() *cobra.Command {
var maxItems int
var maxAge string
var (
maxItems int
maxAge string
)
cmd := &cobra.Command{
Use: `flush`,
Short: `Flush alerts
/!\ This command can be used only on the same machine than the local API`,
Example: `cscli alerts flush --max-items 1000 --max-age 7d`,
DisableAutoGenTag: true,
RunE: func(cmd *cobra.Command, args []string) error {
var err error
if err := require.LAPI(csConfig); err != nil {
RunE: func(_ *cobra.Command, _ []string) error {
cfg := cli.cfg()
if err := require.LAPI(cfg); err != nil {
return err
}
db, err := database.NewClient(csConfig.DbConfig)
db, err := database.NewClient(cfg.DbConfig)
if err != nil {
return fmt.Errorf("unable to create new database client: %s", err)
return fmt.Errorf("unable to create new database client: %w", err)
}
log.Info("Flushing alerts. !! This may take a long time !!")
err = db.FlushAlerts(maxAge, maxItems)
if err != nil {
return fmt.Errorf("unable to flush alerts: %s", err)
return fmt.Errorf("unable to flush alerts: %w", err)
}
log.Info("Alerts flushed")

View file

@ -36,13 +36,13 @@ func askYesNo(message string, defaultAnswer bool) (bool, error) {
}
type cliBouncers struct {
db *database.Client
db *database.Client
cfg configGetter
}
func NewCLIBouncers(getconfig configGetter) *cliBouncers {
func NewCLIBouncers(cfg configGetter) *cliBouncers {
return &cliBouncers{
cfg: getconfig,
cfg: cfg,
}
}
@ -197,13 +197,13 @@ cscli bouncers add MyBouncerName --key <random-key>`,
return cmd
}
func (cli *cliBouncers) deleteValid(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
func (cli *cliBouncers) deleteValid(_ *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
bouncers, err := cli.db.ListBouncers()
if err != nil {
cobra.CompError("unable to list bouncers " + err.Error())
}
ret :=[]string{}
ret := []string{}
for _, bouncer := range bouncers {
if strings.Contains(bouncer.Name, toComplete) && !slices.Contains(args, bouncer.Name) {

View file

@ -146,7 +146,12 @@ func restoreConfigFromDirectory(dirPath string, oldBackup bool) error {
// Now we have config.yaml, we should regenerate config struct to have rights paths etc
ConfigFilePath = fmt.Sprintf("%s/config.yaml", csConfig.ConfigPaths.ConfigDir)
initConfig()
log.Debug("Reloading configuration")
csConfig, _, err = loadConfigFor("config")
if err != nil {
return fmt.Errorf("failed to reload configuration: %s", err)
}
backupCAPICreds := fmt.Sprintf("%s/online_api_credentials.yaml", dirPath)
if _, err = os.Stat(backupCAPICreds); err == nil {
@ -227,7 +232,7 @@ func restoreConfigFromDirectory(dirPath string, oldBackup bool) error {
}
}
// if there is files in the acquis backup dir, restore them
// if there are files in the acquis backup dir, restore them
acquisBackupDir := filepath.Join(dirPath, "acquis", "*.yaml")
if acquisFiles, err := filepath.Glob(acquisBackupDir); err == nil {
for _, acquisFile := range acquisFiles {

View file

@ -25,32 +25,53 @@ import (
"github.com/crowdsecurity/crowdsec/pkg/types"
)
func NewConsoleCmd() *cobra.Command {
var cmdConsole = &cobra.Command{
type cliConsole struct {
cfg configGetter
}
func NewCLIConsole(cfg configGetter) *cliConsole {
return &cliConsole{
cfg: cfg,
}
}
func (cli *cliConsole) NewCommand() *cobra.Command {
var cmd = &cobra.Command{
Use: "console [action]",
Short: "Manage interaction with Crowdsec console (https://app.crowdsec.net)",
Args: cobra.MinimumNArgs(1),
DisableAutoGenTag: true,
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
if err := require.LAPI(csConfig); err != nil {
PersistentPreRunE: func(_ *cobra.Command, _ []string) error {
cfg := cli.cfg()
if err := require.LAPI(cfg); err != nil {
return err
}
if err := require.CAPI(csConfig); err != nil {
if err := require.CAPI(cfg); err != nil {
return err
}
if err := require.CAPIRegistered(csConfig); err != nil {
if err := require.CAPIRegistered(cfg); err != nil {
return err
}
return nil
},
}
cmd.AddCommand(cli.newEnrollCmd())
cmd.AddCommand(cli.newEnableCmd())
cmd.AddCommand(cli.newDisableCmd())
cmd.AddCommand(cli.newStatusCmd())
return cmd
}
func (cli *cliConsole) newEnrollCmd() *cobra.Command {
name := ""
overwrite := false
tags := []string{}
opts := []string{}
cmdEnroll := &cobra.Command{
cmd := &cobra.Command{
Use: "enroll [enroll-key]",
Short: "Enroll this instance to https://app.crowdsec.net [requires local API]",
Long: `
@ -66,96 +87,107 @@ After running this command your will need to validate the enrollment in the weba
valid options are : %s,all (see 'cscli console status' for details)`, strings.Join(csconfig.CONSOLE_CONFIGS, ",")),
Args: cobra.ExactArgs(1),
DisableAutoGenTag: true,
RunE: func(cmd *cobra.Command, args []string) error {
password := strfmt.Password(csConfig.API.Server.OnlineClient.Credentials.Password)
apiURL, err := url.Parse(csConfig.API.Server.OnlineClient.Credentials.URL)
RunE: func(_ *cobra.Command, args []string) error {
cfg := cli.cfg()
password := strfmt.Password(cfg.API.Server.OnlineClient.Credentials.Password)
apiURL, err := url.Parse(cfg.API.Server.OnlineClient.Credentials.URL)
if err != nil {
return fmt.Errorf("could not parse CAPI URL: %s", err)
return fmt.Errorf("could not parse CAPI URL: %w", err)
}
hub, err := require.Hub(csConfig, nil, nil)
hub, err := require.Hub(cfg, nil, nil)
if err != nil {
return err
}
scenarios, err := hub.GetInstalledItemNames(cwhub.SCENARIOS)
if err != nil {
return fmt.Errorf("failed to get installed scenarios: %s", err)
return fmt.Errorf("failed to get installed scenarios: %w", err)
}
if len(scenarios) == 0 {
scenarios = make([]string, 0)
}
enable_opts := []string{csconfig.SEND_MANUAL_SCENARIOS, csconfig.SEND_TAINTED_SCENARIOS}
enableOpts := []string{csconfig.SEND_MANUAL_SCENARIOS, csconfig.SEND_TAINTED_SCENARIOS}
if len(opts) != 0 {
for _, opt := range opts {
valid := false
if opt == "all" {
enable_opts = csconfig.CONSOLE_CONFIGS
enableOpts = csconfig.CONSOLE_CONFIGS
break
}
for _, available_opt := range csconfig.CONSOLE_CONFIGS {
if opt == available_opt {
for _, availableOpt := range csconfig.CONSOLE_CONFIGS {
if opt == availableOpt {
valid = true
enable := true
for _, enabled_opt := range enable_opts {
if opt == enabled_opt {
for _, enabledOpt := range enableOpts {
if opt == enabledOpt {
enable = false
continue
}
}
if enable {
enable_opts = append(enable_opts, opt)
enableOpts = append(enableOpts, opt)
}
break
}
}
if !valid {
return fmt.Errorf("option %s doesn't exist", opt)
}
}
}
c, _ := apiclient.NewClient(&apiclient.Config{
MachineID: csConfig.API.Server.OnlineClient.Credentials.Login,
MachineID: cli.cfg().API.Server.OnlineClient.Credentials.Login,
Password: password,
Scenarios: scenarios,
UserAgent: fmt.Sprintf("crowdsec/%s", version.String()),
URL: apiURL,
VersionPrefix: "v3",
})
resp, err := c.Auth.EnrollWatcher(context.Background(), args[0], name, tags, overwrite)
if err != nil {
return fmt.Errorf("could not enroll instance: %s", err)
return fmt.Errorf("could not enroll instance: %w", err)
}
if resp.Response.StatusCode == 200 && !overwrite {
log.Warning("Instance already enrolled. You can use '--overwrite' to force enroll")
return nil
}
if err := SetConsoleOpts(enable_opts, true); err != nil {
if err := cli.setConsoleOpts(enableOpts, true); err != nil {
return err
}
for _, opt := range enable_opts {
for _, opt := range enableOpts {
log.Infof("Enabled %s : %s", opt, csconfig.CONSOLE_CONFIGS_HELP[opt])
}
log.Info("Watcher successfully enrolled. Visit https://app.crowdsec.net to accept it.")
log.Info("Please restart crowdsec after accepting the enrollment.")
return nil
},
}
cmdEnroll.Flags().StringVarP(&name, "name", "n", "", "Name to display in the console")
cmdEnroll.Flags().BoolVarP(&overwrite, "overwrite", "", false, "Force enroll the instance")
cmdEnroll.Flags().StringSliceVarP(&tags, "tags", "t", tags, "Tags to display in the console")
cmdEnroll.Flags().StringSliceVarP(&opts, "enable", "e", opts, "Enable console options")
cmdConsole.AddCommand(cmdEnroll)
var enableAll, disableAll bool
flags := cmd.Flags()
flags.StringVarP(&name, "name", "n", "", "Name to display in the console")
flags.BoolVarP(&overwrite, "overwrite", "", false, "Force enroll the instance")
flags.StringSliceVarP(&tags, "tags", "t", tags, "Tags to display in the console")
flags.StringSliceVarP(&opts, "enable", "e", opts, "Enable console options")
cmdEnable := &cobra.Command{
return cmd
}
func (cli *cliConsole) newEnableCmd() *cobra.Command {
var enableAll bool
cmd := &cobra.Command{
Use: "enable [option]",
Short: "Enable a console option",
Example: "sudo cscli console enable tainted",
@ -163,9 +195,9 @@ After running this command your will need to validate the enrollment in the weba
Enable given information push to the central API. Allows to empower the console`,
ValidArgs: csconfig.CONSOLE_CONFIGS,
DisableAutoGenTag: true,
RunE: func(cmd *cobra.Command, args []string) error {
RunE: func(_ *cobra.Command, args []string) error {
if enableAll {
if err := SetConsoleOpts(csconfig.CONSOLE_CONFIGS, true); err != nil {
if err := cli.setConsoleOpts(csconfig.CONSOLE_CONFIGS, true); err != nil {
return err
}
log.Infof("All features have been enabled successfully")
@ -173,19 +205,26 @@ Enable given information push to the central API. Allows to empower the console`
if len(args) == 0 {
return fmt.Errorf("you must specify at least one feature to enable")
}
if err := SetConsoleOpts(args, true); err != nil {
if err := cli.setConsoleOpts(args, true); err != nil {
return err
}
log.Infof("%v have been enabled", args)
}
log.Infof(ReloadMessage())
return nil
},
}
cmdEnable.Flags().BoolVarP(&enableAll, "all", "a", false, "Enable all console options")
cmdConsole.AddCommand(cmdEnable)
cmd.Flags().BoolVarP(&enableAll, "all", "a", false, "Enable all console options")
cmdDisable := &cobra.Command{
return cmd
}
func (cli *cliConsole) newDisableCmd() *cobra.Command {
var disableAll bool
cmd := &cobra.Command{
Use: "disable [option]",
Short: "Disable a console option",
Example: "sudo cscli console disable tainted",
@ -193,47 +232,52 @@ Enable given information push to the central API. Allows to empower the console`
Disable given information push to the central API.`,
ValidArgs: csconfig.CONSOLE_CONFIGS,
DisableAutoGenTag: true,
RunE: func(cmd *cobra.Command, args []string) error {
RunE: func(_ *cobra.Command, args []string) error {
if disableAll {
if err := SetConsoleOpts(csconfig.CONSOLE_CONFIGS, false); err != nil {
if err := cli.setConsoleOpts(csconfig.CONSOLE_CONFIGS, false); err != nil {
return err
}
log.Infof("All features have been disabled")
} else {
if err := SetConsoleOpts(args, false); err != nil {
if err := cli.setConsoleOpts(args, false); err != nil {
return err
}
log.Infof("%v have been disabled", args)
}
log.Infof(ReloadMessage())
return nil
},
}
cmdDisable.Flags().BoolVarP(&disableAll, "all", "a", false, "Disable all console options")
cmdConsole.AddCommand(cmdDisable)
cmd.Flags().BoolVarP(&disableAll, "all", "a", false, "Disable all console options")
cmdConsoleStatus := &cobra.Command{
return cmd
}
func (cli *cliConsole) newStatusCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "status",
Short: "Shows status of the console options",
Example: `sudo cscli console status`,
DisableAutoGenTag: true,
RunE: func(cmd *cobra.Command, args []string) error {
switch csConfig.Cscli.Output {
RunE: func(_ *cobra.Command, _ []string) error {
cfg := cli.cfg()
consoleCfg := cfg.API.Server.ConsoleConfig
switch cfg.Cscli.Output {
case "human":
cmdConsoleStatusTable(color.Output, *csConfig)
cmdConsoleStatusTable(color.Output, *consoleCfg)
case "json":
c := csConfig.API.Server.ConsoleConfig
out := map[string](*bool){
csconfig.SEND_MANUAL_SCENARIOS: c.ShareManualDecisions,
csconfig.SEND_CUSTOM_SCENARIOS: c.ShareCustomScenarios,
csconfig.SEND_TAINTED_SCENARIOS: c.ShareTaintedScenarios,
csconfig.SEND_CONTEXT: c.ShareContext,
csconfig.CONSOLE_MANAGEMENT: c.ConsoleManagement,
csconfig.SEND_MANUAL_SCENARIOS: consoleCfg.ShareManualDecisions,
csconfig.SEND_CUSTOM_SCENARIOS: consoleCfg.ShareCustomScenarios,
csconfig.SEND_TAINTED_SCENARIOS: consoleCfg.ShareTaintedScenarios,
csconfig.SEND_CONTEXT: consoleCfg.ShareContext,
csconfig.CONSOLE_MANAGEMENT: consoleCfg.ConsoleManagement,
}
data, err := json.MarshalIndent(out, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal configuration: %s", err)
return fmt.Errorf("failed to marshal configuration: %w", err)
}
fmt.Println(string(data))
case "raw":
@ -244,11 +288,11 @@ Disable given information push to the central API.`,
}
rows := [][]string{
{csconfig.SEND_MANUAL_SCENARIOS, fmt.Sprintf("%t", *csConfig.API.Server.ConsoleConfig.ShareManualDecisions)},
{csconfig.SEND_CUSTOM_SCENARIOS, fmt.Sprintf("%t", *csConfig.API.Server.ConsoleConfig.ShareCustomScenarios)},
{csconfig.SEND_TAINTED_SCENARIOS, fmt.Sprintf("%t", *csConfig.API.Server.ConsoleConfig.ShareTaintedScenarios)},
{csconfig.SEND_CONTEXT, fmt.Sprintf("%t", *csConfig.API.Server.ConsoleConfig.ShareContext)},
{csconfig.CONSOLE_MANAGEMENT, fmt.Sprintf("%t", *csConfig.API.Server.ConsoleConfig.ConsoleManagement)},
{csconfig.SEND_MANUAL_SCENARIOS, fmt.Sprintf("%t", *consoleCfg.ShareManualDecisions)},
{csconfig.SEND_CUSTOM_SCENARIOS, fmt.Sprintf("%t", *consoleCfg.ShareCustomScenarios)},
{csconfig.SEND_TAINTED_SCENARIOS, fmt.Sprintf("%t", *consoleCfg.ShareTaintedScenarios)},
{csconfig.SEND_CONTEXT, fmt.Sprintf("%t", *consoleCfg.ShareContext)},
{csconfig.CONSOLE_MANAGEMENT, fmt.Sprintf("%t", *consoleCfg.ConsoleManagement)},
}
for _, row := range rows {
err = csvwriter.Write(row)
@ -258,132 +302,137 @@ Disable given information push to the central API.`,
}
csvwriter.Flush()
}
return nil
},
}
cmdConsole.AddCommand(cmdConsoleStatus)
return cmdConsole
return cmd
}
func dumpConsoleConfig(c *csconfig.LocalApiServerCfg) error {
out, err := yaml.Marshal(c.ConsoleConfig)
func (cli *cliConsole) dumpConfig() error {
serverCfg := cli.cfg().API.Server
out, err := yaml.Marshal(serverCfg.ConsoleConfig)
if err != nil {
return fmt.Errorf("while marshaling ConsoleConfig (for %s): %w", c.ConsoleConfigPath, err)
return fmt.Errorf("while marshaling ConsoleConfig (for %s): %w", serverCfg.ConsoleConfigPath, err)
}
if c.ConsoleConfigPath == "" {
c.ConsoleConfigPath = csconfig.DefaultConsoleConfigFilePath
log.Debugf("Empty console_path, defaulting to %s", c.ConsoleConfigPath)
if serverCfg.ConsoleConfigPath == "" {
serverCfg.ConsoleConfigPath = csconfig.DefaultConsoleConfigFilePath
log.Debugf("Empty console_path, defaulting to %s", serverCfg.ConsoleConfigPath)
}
if err := os.WriteFile(c.ConsoleConfigPath, out, 0o600); err != nil {
return fmt.Errorf("while dumping console config to %s: %w", c.ConsoleConfigPath, err)
if err := os.WriteFile(serverCfg.ConsoleConfigPath, out, 0o600); err != nil {
return fmt.Errorf("while dumping console config to %s: %w", serverCfg.ConsoleConfigPath, err)
}
return nil
}
func SetConsoleOpts(args []string, wanted bool) error {
func (cli *cliConsole) setConsoleOpts(args []string, wanted bool) error {
cfg := cli.cfg()
consoleCfg := cfg.API.Server.ConsoleConfig
for _, arg := range args {
switch arg {
case csconfig.CONSOLE_MANAGEMENT:
/*for each flag check if it's already set before setting it*/
if csConfig.API.Server.ConsoleConfig.ConsoleManagement != nil {
if *csConfig.API.Server.ConsoleConfig.ConsoleManagement == wanted {
if consoleCfg.ConsoleManagement != nil {
if *consoleCfg.ConsoleManagement == wanted {
log.Debugf("%s already set to %t", csconfig.CONSOLE_MANAGEMENT, wanted)
} else {
log.Infof("%s set to %t", csconfig.CONSOLE_MANAGEMENT, wanted)
*csConfig.API.Server.ConsoleConfig.ConsoleManagement = wanted
*consoleCfg.ConsoleManagement = wanted
}
} else {
log.Infof("%s set to %t", csconfig.CONSOLE_MANAGEMENT, wanted)
csConfig.API.Server.ConsoleConfig.ConsoleManagement = ptr.Of(wanted)
consoleCfg.ConsoleManagement = ptr.Of(wanted)
}
if csConfig.API.Server.OnlineClient.Credentials != nil {
if cfg.API.Server.OnlineClient.Credentials != nil {
changed := false
if wanted && csConfig.API.Server.OnlineClient.Credentials.PapiURL == "" {
if wanted && cfg.API.Server.OnlineClient.Credentials.PapiURL == "" {
changed = true
csConfig.API.Server.OnlineClient.Credentials.PapiURL = types.PAPIBaseURL
} else if !wanted && csConfig.API.Server.OnlineClient.Credentials.PapiURL != "" {
cfg.API.Server.OnlineClient.Credentials.PapiURL = types.PAPIBaseURL
} else if !wanted && cfg.API.Server.OnlineClient.Credentials.PapiURL != "" {
changed = true
csConfig.API.Server.OnlineClient.Credentials.PapiURL = ""
cfg.API.Server.OnlineClient.Credentials.PapiURL = ""
}
if changed {
fileContent, err := yaml.Marshal(csConfig.API.Server.OnlineClient.Credentials)
fileContent, err := yaml.Marshal(cfg.API.Server.OnlineClient.Credentials)
if err != nil {
return fmt.Errorf("cannot marshal credentials: %s", err)
return fmt.Errorf("cannot marshal credentials: %w", err)
}
log.Infof("Updating credentials file: %s", csConfig.API.Server.OnlineClient.CredentialsFilePath)
log.Infof("Updating credentials file: %s", cfg.API.Server.OnlineClient.CredentialsFilePath)
err = os.WriteFile(csConfig.API.Server.OnlineClient.CredentialsFilePath, fileContent, 0o600)
err = os.WriteFile(cfg.API.Server.OnlineClient.CredentialsFilePath, fileContent, 0o600)
if err != nil {
return fmt.Errorf("cannot write credentials file: %s", err)
return fmt.Errorf("cannot write credentials file: %w", err)
}
}
}
case csconfig.SEND_CUSTOM_SCENARIOS:
/*for each flag check if it's already set before setting it*/
if csConfig.API.Server.ConsoleConfig.ShareCustomScenarios != nil {
if *csConfig.API.Server.ConsoleConfig.ShareCustomScenarios == wanted {
if consoleCfg.ShareCustomScenarios != nil {
if *consoleCfg.ShareCustomScenarios == wanted {
log.Debugf("%s already set to %t", csconfig.SEND_CUSTOM_SCENARIOS, wanted)
} else {
log.Infof("%s set to %t", csconfig.SEND_CUSTOM_SCENARIOS, wanted)
*csConfig.API.Server.ConsoleConfig.ShareCustomScenarios = wanted
*consoleCfg.ShareCustomScenarios = wanted
}
} else {
log.Infof("%s set to %t", csconfig.SEND_CUSTOM_SCENARIOS, wanted)
csConfig.API.Server.ConsoleConfig.ShareCustomScenarios = ptr.Of(wanted)
consoleCfg.ShareCustomScenarios = ptr.Of(wanted)
}
case csconfig.SEND_TAINTED_SCENARIOS:
/*for each flag check if it's already set before setting it*/
if csConfig.API.Server.ConsoleConfig.ShareTaintedScenarios != nil {
if *csConfig.API.Server.ConsoleConfig.ShareTaintedScenarios == wanted {
if consoleCfg.ShareTaintedScenarios != nil {
if *consoleCfg.ShareTaintedScenarios == wanted {
log.Debugf("%s already set to %t", csconfig.SEND_TAINTED_SCENARIOS, wanted)
} else {
log.Infof("%s set to %t", csconfig.SEND_TAINTED_SCENARIOS, wanted)
*csConfig.API.Server.ConsoleConfig.ShareTaintedScenarios = wanted
*consoleCfg.ShareTaintedScenarios = wanted
}
} else {
log.Infof("%s set to %t", csconfig.SEND_TAINTED_SCENARIOS, wanted)
csConfig.API.Server.ConsoleConfig.ShareTaintedScenarios = ptr.Of(wanted)
consoleCfg.ShareTaintedScenarios = ptr.Of(wanted)
}
case csconfig.SEND_MANUAL_SCENARIOS:
/*for each flag check if it's already set before setting it*/
if csConfig.API.Server.ConsoleConfig.ShareManualDecisions != nil {
if *csConfig.API.Server.ConsoleConfig.ShareManualDecisions == wanted {
if consoleCfg.ShareManualDecisions != nil {
if *consoleCfg.ShareManualDecisions == wanted {
log.Debugf("%s already set to %t", csconfig.SEND_MANUAL_SCENARIOS, wanted)
} else {
log.Infof("%s set to %t", csconfig.SEND_MANUAL_SCENARIOS, wanted)
*csConfig.API.Server.ConsoleConfig.ShareManualDecisions = wanted
*consoleCfg.ShareManualDecisions = wanted
}
} else {
log.Infof("%s set to %t", csconfig.SEND_MANUAL_SCENARIOS, wanted)
csConfig.API.Server.ConsoleConfig.ShareManualDecisions = ptr.Of(wanted)
consoleCfg.ShareManualDecisions = ptr.Of(wanted)
}
case csconfig.SEND_CONTEXT:
/*for each flag check if it's already set before setting it*/
if csConfig.API.Server.ConsoleConfig.ShareContext != nil {
if *csConfig.API.Server.ConsoleConfig.ShareContext == wanted {
if consoleCfg.ShareContext != nil {
if *consoleCfg.ShareContext == wanted {
log.Debugf("%s already set to %t", csconfig.SEND_CONTEXT, wanted)
} else {
log.Infof("%s set to %t", csconfig.SEND_CONTEXT, wanted)
*csConfig.API.Server.ConsoleConfig.ShareContext = wanted
*consoleCfg.ShareContext = wanted
}
} else {
log.Infof("%s set to %t", csconfig.SEND_CONTEXT, wanted)
csConfig.API.Server.ConsoleConfig.ShareContext = ptr.Of(wanted)
consoleCfg.ShareContext = ptr.Of(wanted)
}
default:
return fmt.Errorf("unknown flag %s", arg)
}
}
if err := dumpConsoleConfig(csConfig.API.Server); err != nil {
return fmt.Errorf("failed writing console config: %s", err)
if err := cli.dumpConfig(); err != nil {
return fmt.Errorf("failed writing console config: %w", err)
}
return nil

View file

@ -9,7 +9,7 @@ import (
"github.com/crowdsecurity/crowdsec/pkg/csconfig"
)
func cmdConsoleStatusTable(out io.Writer, csConfig csconfig.Config) {
func cmdConsoleStatusTable(out io.Writer, consoleCfg csconfig.ConsoleConfig) {
t := newTable(out)
t.SetRowLines(false)
@ -18,28 +18,30 @@ func cmdConsoleStatusTable(out io.Writer, csConfig csconfig.Config) {
for _, option := range csconfig.CONSOLE_CONFIGS {
activated := string(emoji.CrossMark)
switch option {
case csconfig.SEND_CUSTOM_SCENARIOS:
if *csConfig.API.Server.ConsoleConfig.ShareCustomScenarios {
if *consoleCfg.ShareCustomScenarios {
activated = string(emoji.CheckMarkButton)
}
case csconfig.SEND_MANUAL_SCENARIOS:
if *csConfig.API.Server.ConsoleConfig.ShareManualDecisions {
if *consoleCfg.ShareManualDecisions {
activated = string(emoji.CheckMarkButton)
}
case csconfig.SEND_TAINTED_SCENARIOS:
if *csConfig.API.Server.ConsoleConfig.ShareTaintedScenarios {
if *consoleCfg.ShareTaintedScenarios {
activated = string(emoji.CheckMarkButton)
}
case csconfig.SEND_CONTEXT:
if *csConfig.API.Server.ConsoleConfig.ShareContext {
if *consoleCfg.ShareContext {
activated = string(emoji.CheckMarkButton)
}
case csconfig.CONSOLE_MANAGEMENT:
if *csConfig.API.Server.ConsoleConfig.ConsoleManagement {
if *consoleCfg.ConsoleManagement {
activated = string(emoji.CheckMarkButton)
}
}
t.AddRow(option, activated, csconfig.CONSOLE_CONFIGS_HELP[option])
}

View file

@ -19,15 +19,14 @@ import (
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/crowdsecurity/crowdsec/pkg/metabase"
"github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require"
"github.com/crowdsecurity/crowdsec/pkg/metabase"
)
var (
metabaseUser = "crowdsec@crowdsec.net"
metabasePassword string
metabaseDbPath string
metabaseDBPath string
metabaseConfigPath string
metabaseConfigFolder = "metabase/"
metabaseConfigFile = "metabase.yaml"
@ -43,13 +42,13 @@ var (
// information needed to set up a random password on user's behalf
)
type cliDashboard struct{
type cliDashboard struct {
cfg configGetter
}
func NewCLIDashboard(getconfig configGetter) *cliDashboard {
func NewCLIDashboard(cfg configGetter) *cliDashboard {
return &cliDashboard{
cfg: getconfig,
cfg: cfg,
}
}
@ -99,6 +98,7 @@ cscli dashboard remove
metabaseContainerID = oldContainerID
}
}
return nil
},
}
@ -127,8 +127,8 @@ cscli dashboard setup --listen 0.0.0.0
cscli dashboard setup -l 0.0.0.0 -p 443 --password <password>
`,
RunE: func(_ *cobra.Command, _ []string) error {
if metabaseDbPath == "" {
metabaseDbPath = cli.cfg().ConfigPaths.DataDir
if metabaseDBPath == "" {
metabaseDBPath = cli.cfg().ConfigPaths.DataDir
}
if metabasePassword == "" {
@ -152,7 +152,7 @@ cscli dashboard setup -l 0.0.0.0 -p 443 --password <password>
if err = cli.chownDatabase(dockerGroup.Gid); err != nil {
return err
}
mb, err := metabase.SetupMetabase(cli.cfg().API.Server.DbConfig, metabaseListenAddress, metabaseListenPort, metabaseUser, metabasePassword, metabaseDbPath, dockerGroup.Gid, metabaseContainerID, metabaseImage)
mb, err := metabase.SetupMetabase(cli.cfg().API.Server.DbConfig, metabaseListenAddress, metabaseListenPort, metabaseUser, metabasePassword, metabaseDBPath, dockerGroup.Gid, metabaseContainerID, metabaseImage)
if err != nil {
return err
}
@ -165,18 +165,19 @@ cscli dashboard setup -l 0.0.0.0 -p 443 --password <password>
fmt.Printf("\tURL : '%s'\n", mb.Config.ListenURL)
fmt.Printf("\tusername : '%s'\n", mb.Config.Username)
fmt.Printf("\tpassword : '%s'\n", mb.Config.Password)
return nil
},
}
flags := cmd.Flags()
flags.BoolVarP(&force, "force", "f", false, "Force setup : override existing files")
flags.StringVarP(&metabaseDbPath, "dir", "d", "", "Shared directory with metabase container")
flags.StringVarP(&metabaseDBPath, "dir", "d", "", "Shared directory with metabase container")
flags.StringVarP(&metabaseListenAddress, "listen", "l", metabaseListenAddress, "Listen address of container")
flags.StringVar(&metabaseImage, "metabase-image", metabaseImage, "Metabase image to use")
flags.StringVarP(&metabaseListenPort, "port", "p", metabaseListenPort, "Listen port of container")
flags.BoolVarP(&forceYes, "yes", "y", false, "force yes")
//flags.StringVarP(&metabaseUser, "user", "u", "crowdsec@crowdsec.net", "metabase user")
// flags.StringVarP(&metabaseUser, "user", "u", "crowdsec@crowdsec.net", "metabase user")
flags.StringVar(&metabasePassword, "password", "", "metabase password")
return cmd
@ -203,6 +204,7 @@ func (cli *cliDashboard) newStartCmd() *cobra.Command {
}
log.Infof("Started metabase")
log.Infof("url : http://%s:%s", mb.Config.ListenAddr, mb.Config.ListenPort)
return nil
},
}
@ -241,6 +243,7 @@ func (cli *cliDashboard) newShowPasswordCmd() *cobra.Command {
return err
}
log.Printf("'%s'", m.Config.Password)
return nil
},
}
@ -313,6 +316,7 @@ cscli dashboard remove --force
}
}
}
return nil
},
}
@ -443,6 +447,7 @@ func checkGroups(forceYes *bool) (*user.Group, error) {
func (cli *cliDashboard) chownDatabase(gid string) error {
cfg := cli.cfg()
intID, err := strconv.Atoi(gid)
if err != nil {
return fmt.Errorf("unable to convert group ID to int: %s", err)
}

View file

@ -13,9 +13,9 @@ type cliDashboard struct{
cfg configGetter
}
func NewCLIDashboard(getconfig configGetter) *cliDashboard {
func NewCLIDashboard(cfg configGetter) *cliDashboard {
return &cliDashboard{
cfg: getconfig,
cfg: cfg,
}
}

View file

@ -25,7 +25,7 @@ import (
var Client *apiclient.ApiClient
func DecisionsToTable(alerts *models.GetAlertsResponse, printMachine bool) error {
func (cli *cliDecisions) decisionsToTable(alerts *models.GetAlertsResponse, printMachine bool) error {
/*here we cheat a bit : to make it more readable for the user, we dedup some entries*/
spamLimit := make(map[string]bool)
skipped := 0
@ -49,7 +49,8 @@ func DecisionsToTable(alerts *models.GetAlertsResponse, printMachine bool) error
alertItem.Decisions = newDecisions
}
if csConfig.Cscli.Output == "raw" {
switch cli.cfg().Cscli.Output {
case "raw":
csvwriter := csv.NewWriter(os.Stdout)
header := []string{"id", "source", "ip", "reason", "action", "country", "as", "events_count", "expiration", "simulated", "alert_id"}
@ -89,21 +90,24 @@ func DecisionsToTable(alerts *models.GetAlertsResponse, printMachine bool) error
}
csvwriter.Flush()
} else if csConfig.Cscli.Output == "json" {
case "json":
if *alerts == nil {
// avoid returning "null" in `json"
// could be cleaner if we used slice of alerts directly
fmt.Println("[]")
return nil
}
x, _ := json.MarshalIndent(alerts, "", " ")
fmt.Printf("%s", string(x))
} else if csConfig.Cscli.Output == "human" {
case "human":
if len(*alerts) == 0 {
fmt.Println("No active decisions")
return nil
}
decisionsTable(color.Output, alerts, printMachine)
cli.decisionsTable(color.Output, alerts, printMachine)
if skipped > 0 {
fmt.Printf("%d duplicated entries skipped\n", skipped)
}
@ -112,14 +116,17 @@ func DecisionsToTable(alerts *models.GetAlertsResponse, printMachine bool) error
return nil
}
type cliDecisions struct {}
func NewCLIDecisions() *cliDecisions {
return &cliDecisions{}
type cliDecisions struct {
cfg configGetter
}
func (cli cliDecisions) NewCommand() *cobra.Command {
func NewCLIDecisions(cfg configGetter) *cliDecisions {
return &cliDecisions{
cfg: cfg,
}
}
func (cli *cliDecisions) NewCommand() *cobra.Command {
cmd := &cobra.Command{
Use: "decisions [action]",
Short: "Manage decisions",
@ -130,16 +137,17 @@ func (cli cliDecisions) NewCommand() *cobra.Command {
Args: cobra.MinimumNArgs(1),
DisableAutoGenTag: true,
PersistentPreRunE: func(_ *cobra.Command, _ []string) error {
if err := csConfig.LoadAPIClient(); err != nil {
cfg := cli.cfg()
if err := cfg.LoadAPIClient(); err != nil {
return fmt.Errorf("loading api client: %w", err)
}
password := strfmt.Password(csConfig.API.Client.Credentials.Password)
apiurl, err := url.Parse(csConfig.API.Client.Credentials.URL)
password := strfmt.Password(cfg.API.Client.Credentials.Password)
apiurl, err := url.Parse(cfg.API.Client.Credentials.URL)
if err != nil {
return fmt.Errorf("parsing api url %s: %w", csConfig.API.Client.Credentials.URL, err)
return fmt.Errorf("parsing api url %s: %w", cfg.API.Client.Credentials.URL, err)
}
Client, err = apiclient.NewClient(&apiclient.Config{
MachineID: csConfig.API.Client.Credentials.Login,
MachineID: cfg.API.Client.Credentials.Login,
Password: password,
UserAgent: fmt.Sprintf("crowdsec/%s", version.String()),
URL: apiurl,
@ -148,19 +156,20 @@ func (cli cliDecisions) NewCommand() *cobra.Command {
if err != nil {
return fmt.Errorf("creating api client: %w", err)
}
return nil
},
}
cmd.AddCommand(cli.NewListCmd())
cmd.AddCommand(cli.NewAddCmd())
cmd.AddCommand(cli.NewDeleteCmd())
cmd.AddCommand(cli.NewImportCmd())
cmd.AddCommand(cli.newListCmd())
cmd.AddCommand(cli.newAddCmd())
cmd.AddCommand(cli.newDeleteCmd())
cmd.AddCommand(cli.newImportCmd())
return cmd
}
func (cli cliDecisions) NewListCmd() *cobra.Command {
func (cli *cliDecisions) newListCmd() *cobra.Command {
var filter = apiclient.AlertsListOpts{
ValueEquals: new(string),
ScopeEquals: new(string),
@ -262,7 +271,7 @@ cscli decisions list -t ban
return fmt.Errorf("unable to retrieve decisions: %w", err)
}
err = DecisionsToTable(alerts, printMachine)
err = cli.decisionsToTable(alerts, printMachine)
if err != nil {
return fmt.Errorf("unable to print decisions: %w", err)
}
@ -289,7 +298,7 @@ cscli decisions list -t ban
return cmd
}
func (cli cliDecisions) NewAddCmd() *cobra.Command {
func (cli *cliDecisions) newAddCmd() *cobra.Command {
var (
addIP string
addRange string
@ -325,7 +334,7 @@ cscli decisions add --scope username --value foobar
createdAt := time.Now().UTC().Format(time.RFC3339)
/*take care of shorthand options*/
if err := manageCliDecisionAlerts(&addIP, &addRange, &addScope, &addValue); err != nil {
if err = manageCliDecisionAlerts(&addIP, &addRange, &addScope, &addValue); err != nil {
return err
}
@ -341,7 +350,7 @@ cscli decisions add --scope username --value foobar
}
if addReason == "" {
addReason = fmt.Sprintf("manual '%s' from '%s'", addType, csConfig.API.Client.Credentials.Login)
addReason = fmt.Sprintf("manual '%s' from '%s'", addType, cli.cfg().API.Client.Credentials.Login)
}
decision := models.Decision{
Duration: &addDuration,
@ -384,6 +393,7 @@ cscli decisions add --scope username --value foobar
}
log.Info("Decision successfully added")
return nil
},
}
@ -400,7 +410,7 @@ cscli decisions add --scope username --value foobar
return cmd
}
func (cli cliDecisions) NewDeleteCmd() *cobra.Command {
func (cli *cliDecisions) newDeleteCmd() *cobra.Command {
var delFilter = apiclient.DecisionsDeleteOpts{
ScopeEquals: new(string),
ValueEquals: new(string),
@ -490,6 +500,7 @@ cscli decisions delete --type captcha
}
}
log.Infof("%s decision(s) deleted", decisions.NbDeleted)
return nil
},
}

View file

@ -67,7 +67,7 @@ func parseDecisionList(content []byte, format string) ([]decisionRaw, error) {
}
func (cli cliDecisions) runImport(cmd *cobra.Command, args []string) error {
func (cli *cliDecisions) runImport(cmd *cobra.Command, args []string) error {
flags := cmd.Flags()
input, err := flags.GetString("input")
@ -236,7 +236,7 @@ func (cli cliDecisions) runImport(cmd *cobra.Command, args []string) error {
}
func (cli cliDecisions) NewImportCmd() *cobra.Command {
func (cli *cliDecisions) newImportCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "import [options]",
Short: "Import decisions from a file or pipe",

View file

@ -8,13 +8,15 @@ import (
"github.com/crowdsecurity/crowdsec/pkg/models"
)
func decisionsTable(out io.Writer, alerts *models.GetAlertsResponse, printMachine bool) {
func (cli *cliDecisions) decisionsTable(out io.Writer, alerts *models.GetAlertsResponse, printMachine bool) {
t := newTable(out)
t.SetRowLines(false)
header := []string{"ID", "Source", "Scope:Value", "Reason", "Action", "Country", "AS", "Events", "expiration", "Alert ID"}
if printMachine {
header = append(header, "Machine")
}
t.SetHeaders(header...)
for _, alertItem := range *alerts {
@ -22,6 +24,7 @@ func decisionsTable(out io.Writer, alerts *models.GetAlertsResponse, printMachin
if *alertItem.Simulated {
*decisionItem.Type = fmt.Sprintf("(simul)%s", *decisionItem.Type)
}
row := []string{
strconv.Itoa(int(decisionItem.ID)),
*decisionItem.Origin,
@ -42,5 +45,6 @@ func decisionsTable(out io.Writer, alerts *models.GetAlertsResponse, printMachin
t.AddRow(row...)
}
}
t.Render()
}

View file

@ -16,33 +16,53 @@ import (
"github.com/crowdsecurity/crowdsec/pkg/hubtest"
)
func GetLineCountForFile(filepath string) (int, error) {
func getLineCountForFile(filepath string) (int, error) {
f, err := os.Open(filepath)
if err != nil {
return 0, err
}
defer f.Close()
lc := 0
fs := bufio.NewReader(f)
for {
input, err := fs.ReadBytes('\n')
if len(input) > 1 {
lc++
}
if err != nil && err == io.EOF {
break
}
}
return lc, nil
}
type cliExplain struct{}
func NewCLIExplain() *cliExplain {
return &cliExplain{}
type cliExplain struct {
cfg configGetter
flags struct {
logFile string
dsn string
logLine string
logType string
details bool
skipOk bool
onlySuccessfulParsers bool
noClean bool
crowdsec string
labels string
}
}
func (cli cliExplain) NewCommand() *cobra.Command {
func NewCLIExplain(cfg configGetter) *cliExplain {
return &cliExplain{
cfg: cfg,
}
}
func (cli *cliExplain) NewCommand() *cobra.Command {
cmd := &cobra.Command{
Use: "explain",
Short: "Explain log pipeline",
@ -57,118 +77,50 @@ tail -n 5 myfile.log | cscli explain --type nginx -f -
`,
Args: cobra.ExactArgs(0),
DisableAutoGenTag: true,
RunE: cli.run,
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
flags := cmd.Flags()
logFile, err := flags.GetString("file")
if err != nil {
return err
}
dsn, err := flags.GetString("dsn")
if err != nil {
return err
}
logLine, err := flags.GetString("log")
if err != nil {
return err
}
logType, err := flags.GetString("type")
if err != nil {
return err
}
if logLine == "" && logFile == "" && dsn == "" {
printHelp(cmd)
fmt.Println()
return fmt.Errorf("please provide --log, --file or --dsn flag")
}
if logType == "" {
printHelp(cmd)
fmt.Println()
return fmt.Errorf("please provide --type flag")
}
RunE: func(_ *cobra.Command, _ []string) error {
return cli.run()
},
PersistentPreRunE: func(_ *cobra.Command, _ []string) error {
fileInfo, _ := os.Stdin.Stat()
if logFile == "-" && ((fileInfo.Mode() & os.ModeCharDevice) == os.ModeCharDevice) {
if cli.flags.logFile == "-" && ((fileInfo.Mode() & os.ModeCharDevice) == os.ModeCharDevice) {
return fmt.Errorf("the option -f - is intended to work with pipes")
}
return nil
},
}
flags := cmd.Flags()
flags.StringP("file", "f", "", "Log file to test")
flags.StringP("dsn", "d", "", "DSN to test")
flags.StringP("log", "l", "", "Log line to test")
flags.StringP("type", "t", "", "Type of the acquisition to test")
flags.String("labels", "", "Additional labels to add to the acquisition format (key:value,key2:value2)")
flags.BoolP("verbose", "v", false, "Display individual changes")
flags.Bool("failures", false, "Only show failed lines")
flags.Bool("only-successful-parsers", false, "Only show successful parsers")
flags.String("crowdsec", "crowdsec", "Path to crowdsec")
flags.Bool("no-clean", false, "Don't clean runtime environment after tests")
flags.StringVarP(&cli.flags.logFile, "file", "f", "", "Log file to test")
flags.StringVarP(&cli.flags.dsn, "dsn", "d", "", "DSN to test")
flags.StringVarP(&cli.flags.logLine, "log", "l", "", "Log line to test")
flags.StringVarP(&cli.flags.logType, "type", "t", "", "Type of the acquisition to test")
flags.StringVar(&cli.flags.labels, "labels", "", "Additional labels to add to the acquisition format (key:value,key2:value2)")
flags.BoolVarP(&cli.flags.details, "verbose", "v", false, "Display individual changes")
flags.BoolVar(&cli.flags.skipOk, "failures", false, "Only show failed lines")
flags.BoolVar(&cli.flags.onlySuccessfulParsers, "only-successful-parsers", false, "Only show successful parsers")
flags.StringVar(&cli.flags.crowdsec, "crowdsec", "crowdsec", "Path to crowdsec")
flags.BoolVar(&cli.flags.noClean, "no-clean", false, "Don't clean runtime environment after tests")
cmd.MarkFlagRequired("type")
cmd.MarkFlagsOneRequired("log", "file", "dsn")
return cmd
}
func (cli cliExplain) run(cmd *cobra.Command, args []string) error {
flags := cmd.Flags()
func (cli *cliExplain) run() error {
logFile := cli.flags.logFile
logLine := cli.flags.logLine
logType := cli.flags.logType
dsn := cli.flags.dsn
labels := cli.flags.labels
crowdsec := cli.flags.crowdsec
logFile, err := flags.GetString("file")
if err != nil {
return err
}
dsn, err := flags.GetString("dsn")
if err != nil {
return err
}
logLine, err := flags.GetString("log")
if err != nil {
return err
}
logType, err := flags.GetString("type")
if err != nil {
return err
}
opts := dumps.DumpOpts{}
opts.Details, err = flags.GetBool("verbose")
if err != nil {
return err
}
no_clean, err := flags.GetBool("no-clean")
if err != nil {
return err
}
opts.SkipOk, err = flags.GetBool("failures")
if err != nil {
return err
}
opts.ShowNotOkParsers, err = flags.GetBool("only-successful-parsers")
opts.ShowNotOkParsers = !opts.ShowNotOkParsers
if err != nil {
return err
}
crowdsec, err := flags.GetString("crowdsec")
if err != nil {
return err
}
labels, err := flags.GetString("labels")
if err != nil {
return err
opts := dumps.DumpOpts{
Details: cli.flags.details,
SkipOk: cli.flags.skipOk,
ShowNotOkParsers: !cli.flags.onlySuccessfulParsers,
}
var f *os.File
@ -176,21 +128,25 @@ func (cli cliExplain) run(cmd *cobra.Command, args []string) error {
// using empty string fallback to /tmp
dir, err := os.MkdirTemp("", "cscli_explain")
if err != nil {
return fmt.Errorf("couldn't create a temporary directory to store cscli explain result: %s", err)
return fmt.Errorf("couldn't create a temporary directory to store cscli explain result: %w", err)
}
defer func() {
if no_clean {
if cli.flags.noClean {
return
}
if _, err := os.Stat(dir); !os.IsNotExist(err) {
if err := os.RemoveAll(dir); err != nil {
log.Errorf("unable to delete temporary directory '%s': %s", dir, err)
}
}
}()
// we create a temporary log file if a log line/stdin has been provided
if logLine != "" || logFile == "-" {
tmpFile := filepath.Join(dir, "cscli_test_tmp.log")
f, err = os.Create(tmpFile)
if err != nil {
return err
@ -220,6 +176,7 @@ func (cli cliExplain) run(cmd *cobra.Command, args []string) error {
log.Warnf("Failed to write %d lines to %s", errCount, tmpFile)
}
}
f.Close()
// this is the file that was going to be read by crowdsec anyway
logFile = tmpFile
@ -230,15 +187,20 @@ func (cli cliExplain) run(cmd *cobra.Command, args []string) error {
if err != nil {
return fmt.Errorf("unable to get absolute path of '%s', exiting", logFile)
}
dsn = fmt.Sprintf("file://%s", absolutePath)
lineCount, err := GetLineCountForFile(absolutePath)
lineCount, err := getLineCountForFile(absolutePath)
if err != nil {
return err
}
log.Debugf("file %s has %d lines", absolutePath, lineCount)
if lineCount == 0 {
return fmt.Errorf("the log file is empty: %s", absolutePath)
}
if lineCount > 100 {
log.Warnf("%s contains %d lines. This may take a lot of resources.", absolutePath, lineCount)
}
@ -249,15 +211,19 @@ func (cli cliExplain) run(cmd *cobra.Command, args []string) error {
}
cmdArgs := []string{"-c", ConfigFilePath, "-type", logType, "-dsn", dsn, "-dump-data", dir, "-no-api"}
if labels != "" {
log.Debugf("adding labels %s", labels)
cmdArgs = append(cmdArgs, "-label", labels)
}
crowdsecCmd := exec.Command(crowdsec, cmdArgs...)
output, err := crowdsecCmd.CombinedOutput()
if err != nil {
fmt.Println(string(output))
return fmt.Errorf("fail to run crowdsec for test: %v", err)
return fmt.Errorf("fail to run crowdsec for test: %w", err)
}
parserDumpFile := filepath.Join(dir, hubtest.ParserResultFileName)
@ -265,12 +231,12 @@ func (cli cliExplain) run(cmd *cobra.Command, args []string) error {
parserDump, err := dumps.LoadParserDump(parserDumpFile)
if err != nil {
return fmt.Errorf("unable to load parser dump result: %s", err)
return fmt.Errorf("unable to load parser dump result: %w", err)
}
bucketStateDump, err := dumps.LoadBucketPourDump(bucketStateDumpFile)
if err != nil {
return fmt.Errorf("unable to load bucket dump result: %s", err)
return fmt.Errorf("unable to load bucket dump result: %w", err)
}
dumps.DumpTree(*parserDump, *bucketStateDump, opts)

View file

@ -18,6 +18,7 @@ func (p *MachinePassword) Set(v string) error {
if len(v) > 72 {
return errors.New("password too long (max 72 characters)")
}
*p = MachinePassword(v)
return nil

View file

@ -13,13 +13,13 @@ import (
"github.com/crowdsecurity/crowdsec/pkg/cwhub"
)
type cliHub struct{
type cliHub struct {
cfg configGetter
}
func NewCLIHub(getconfig configGetter) *cliHub {
func NewCLIHub(cfg configGetter) *cliHub {
return &cliHub{
cfg: getconfig,
cfg: cfg,
}
}

View file

@ -6,6 +6,7 @@ import (
"fmt"
"net/url"
"os"
"slices"
"sort"
"strings"
@ -13,7 +14,6 @@ import (
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"gopkg.in/yaml.v2"
"slices"
"github.com/crowdsecurity/go-cs-lib/version"
@ -29,15 +29,27 @@ import (
const LAPIURLPrefix = "v1"
func runLapiStatus(cmd *cobra.Command, args []string) error {
password := strfmt.Password(csConfig.API.Client.Credentials.Password)
apiurl, err := url.Parse(csConfig.API.Client.Credentials.URL)
login := csConfig.API.Client.Credentials.Login
type cliLapi struct {
cfg configGetter
}
func NewCLILapi(cfg configGetter) *cliLapi {
return &cliLapi{
cfg: cfg,
}
}
func (cli *cliLapi) status() error {
cfg := cli.cfg()
password := strfmt.Password(cfg.API.Client.Credentials.Password)
login := cfg.API.Client.Credentials.Login
apiurl, err := url.Parse(cfg.API.Client.Credentials.URL)
if err != nil {
return fmt.Errorf("parsing api url: %w", err)
}
hub, err := require.Hub(csConfig, nil, nil)
hub, err := require.Hub(cfg, nil, nil)
if err != nil {
return err
}
@ -54,13 +66,14 @@ func runLapiStatus(cmd *cobra.Command, args []string) error {
if err != nil {
return fmt.Errorf("init default client: %w", err)
}
t := models.WatcherAuthRequest{
MachineID: &login,
Password: &password,
Scenarios: scenarios,
}
log.Infof("Loaded credentials from %s", csConfig.API.Client.CredentialsFilePath)
log.Infof("Loaded credentials from %s", cfg.API.Client.CredentialsFilePath)
log.Infof("Trying to authenticate with username %s on %s", login, apiurl)
_, _, err = Client.Auth.AuthenticateWatcher(context.Background(), t)
@ -69,26 +82,15 @@ func runLapiStatus(cmd *cobra.Command, args []string) error {
}
log.Infof("You can successfully interact with Local API (LAPI)")
return nil
}
func runLapiRegister(cmd *cobra.Command, args []string) error {
flags := cmd.Flags()
func (cli *cliLapi) register(apiURL string, outputFile string, machine string) error {
var err error
apiURL, err := flags.GetString("url")
if err != nil {
return err
}
outputFile, err := flags.GetString("file")
if err != nil {
return err
}
lapiUser, err := flags.GetString("machine")
if err != nil {
return err
}
lapiUser := machine
cfg := cli.cfg()
if lapiUser == "" {
lapiUser, err = generateID("")
@ -96,12 +98,15 @@ func runLapiRegister(cmd *cobra.Command, args []string) error {
return fmt.Errorf("unable to generate machine id: %w", err)
}
}
password := strfmt.Password(generatePassword(passwordLength))
if apiURL == "" {
if csConfig.API.Client == nil || csConfig.API.Client.Credentials == nil || csConfig.API.Client.Credentials.URL == "" {
if cfg.API.Client == nil || cfg.API.Client.Credentials == nil || cfg.API.Client.Credentials.URL == "" {
return fmt.Errorf("no Local API URL. Please provide it in your configuration or with the -u parameter")
}
apiURL = csConfig.API.Client.Credentials.URL
apiURL = cfg.API.Client.Credentials.URL
}
/*URL needs to end with /, but user doesn't care*/
if !strings.HasSuffix(apiURL, "/") {
@ -111,10 +116,12 @@ func runLapiRegister(cmd *cobra.Command, args []string) error {
if !strings.HasPrefix(apiURL, "http://") && !strings.HasPrefix(apiURL, "https://") {
apiURL = "http://" + apiURL
}
apiurl, err := url.Parse(apiURL)
if err != nil {
return fmt.Errorf("parsing api url: %w", err)
}
_, err = apiclient.RegisterClient(&apiclient.Config{
MachineID: lapiUser,
Password: password,
@ -130,138 +137,142 @@ func runLapiRegister(cmd *cobra.Command, args []string) error {
log.Printf("Successfully registered to Local API (LAPI)")
var dumpFile string
if outputFile != "" {
dumpFile = outputFile
} else if csConfig.API.Client.CredentialsFilePath != "" {
dumpFile = csConfig.API.Client.CredentialsFilePath
} else if cfg.API.Client.CredentialsFilePath != "" {
dumpFile = cfg.API.Client.CredentialsFilePath
} else {
dumpFile = ""
}
apiCfg := csconfig.ApiCredentialsCfg{
Login: lapiUser,
Password: password.String(),
URL: apiURL,
}
apiConfigDump, err := yaml.Marshal(apiCfg)
if err != nil {
return fmt.Errorf("unable to marshal api credentials: %w", err)
}
if dumpFile != "" {
err = os.WriteFile(dumpFile, apiConfigDump, 0o600)
if err != nil {
return fmt.Errorf("write api credentials to '%s' failed: %w", dumpFile, err)
}
log.Printf("Local API credentials written to '%s'", dumpFile)
} else {
fmt.Printf("%s\n", string(apiConfigDump))
}
log.Warning(ReloadMessage())
return nil
}
func NewLapiStatusCmd() *cobra.Command {
func (cli *cliLapi) newStatusCmd() *cobra.Command {
cmdLapiStatus := &cobra.Command{
Use: "status",
Short: "Check authentication to Local API (LAPI)",
Args: cobra.MinimumNArgs(0),
DisableAutoGenTag: true,
RunE: runLapiStatus,
RunE: func(cmd *cobra.Command, args []string) error {
return cli.status()
},
}
return cmdLapiStatus
}
func NewLapiRegisterCmd() *cobra.Command {
cmdLapiRegister := &cobra.Command{
func (cli *cliLapi) newRegisterCmd() *cobra.Command {
var (
apiURL string
outputFile string
machine string
)
cmd := &cobra.Command{
Use: "register",
Short: "Register a machine to Local API (LAPI)",
Long: `Register your machine to the Local API (LAPI).
Keep in mind the machine needs to be validated by an administrator on LAPI side to be effective.`,
Args: cobra.MinimumNArgs(0),
DisableAutoGenTag: true,
RunE: runLapiRegister,
RunE: func(_ *cobra.Command, _ []string) error {
return cli.register(apiURL, outputFile, machine)
},
}
flags := cmdLapiRegister.Flags()
flags.StringP("url", "u", "", "URL of the API (ie. http://127.0.0.1)")
flags.StringP("file", "f", "", "output file destination")
flags.String("machine", "", "Name of the machine to register with")
flags := cmd.Flags()
flags.StringVarP(&apiURL, "url", "u", "", "URL of the API (ie. http://127.0.0.1)")
flags.StringVarP(&outputFile, "file", "f", "", "output file destination")
flags.StringVar(&machine, "machine", "", "Name of the machine to register with")
return cmdLapiRegister
return cmd
}
func NewLapiCmd() *cobra.Command {
cmdLapi := &cobra.Command{
func (cli *cliLapi) NewCommand() *cobra.Command {
cmd := &cobra.Command{
Use: "lapi [action]",
Short: "Manage interaction with Local API (LAPI)",
Args: cobra.MinimumNArgs(1),
DisableAutoGenTag: true,
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
if err := csConfig.LoadAPIClient(); err != nil {
PersistentPreRunE: func(_ *cobra.Command, _ []string) error {
if err := cli.cfg().LoadAPIClient(); err != nil {
return fmt.Errorf("loading api client: %w", err)
}
return nil
},
}
cmdLapi.AddCommand(NewLapiRegisterCmd())
cmdLapi.AddCommand(NewLapiStatusCmd())
cmdLapi.AddCommand(NewLapiContextCmd())
cmd.AddCommand(cli.newRegisterCmd())
cmd.AddCommand(cli.newStatusCmd())
cmd.AddCommand(cli.newContextCmd())
return cmdLapi
return cmd
}
func AddContext(key string, values []string) error {
func (cli *cliLapi) addContext(key string, values []string) error {
cfg := cli.cfg()
if err := alertcontext.ValidateContextExpr(key, values); err != nil {
return fmt.Errorf("invalid context configuration :%s", err)
return fmt.Errorf("invalid context configuration: %w", err)
}
if _, ok := csConfig.Crowdsec.ContextToSend[key]; !ok {
csConfig.Crowdsec.ContextToSend[key] = make([]string, 0)
if _, ok := cfg.Crowdsec.ContextToSend[key]; !ok {
cfg.Crowdsec.ContextToSend[key] = make([]string, 0)
log.Infof("key '%s' added", key)
}
data := csConfig.Crowdsec.ContextToSend[key]
data := cfg.Crowdsec.ContextToSend[key]
for _, val := range values {
if !slices.Contains(data, val) {
log.Infof("value '%s' added to key '%s'", val, key)
data = append(data, val)
}
csConfig.Crowdsec.ContextToSend[key] = data
cfg.Crowdsec.ContextToSend[key] = data
}
if err := csConfig.Crowdsec.DumpContextConfigFile(); err != nil {
if err := cfg.Crowdsec.DumpContextConfigFile(); err != nil {
return err
}
return nil
}
func NewLapiContextCmd() *cobra.Command {
cmdContext := &cobra.Command{
Use: "context [command]",
Short: "Manage context to send with alerts",
DisableAutoGenTag: true,
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
if err := csConfig.LoadCrowdsec(); err != nil {
fileNotFoundMessage := fmt.Sprintf("failed to open context file: open %s: no such file or directory", csConfig.Crowdsec.ConsoleContextPath)
if err.Error() != fileNotFoundMessage {
return fmt.Errorf("unable to load CrowdSec agent configuration: %w", err)
}
}
if csConfig.DisableAgent {
return errors.New("agent is disabled and lapi context can only be used on the agent")
}
func (cli *cliLapi) newContextAddCmd() *cobra.Command {
var (
keyToAdd string
valuesToAdd []string
)
return nil
},
Run: func(cmd *cobra.Command, args []string) {
printHelp(cmd)
},
}
var keyToAdd string
var valuesToAdd []string
cmdContextAdd := &cobra.Command{
cmd := &cobra.Command{
Use: "add",
Short: "Add context to send with alerts. You must specify the output key with the expr value you want",
Example: `cscli lapi context add --key source_ip --value evt.Meta.source_ip
@ -269,18 +280,18 @@ cscli lapi context add --key file_source --value evt.Line.Src
cscli lapi context add --value evt.Meta.source_ip --value evt.Meta.target_user
`,
DisableAutoGenTag: true,
RunE: func(cmd *cobra.Command, args []string) error {
hub, err := require.Hub(csConfig, nil, nil)
RunE: func(_ *cobra.Command, _ []string) error {
hub, err := require.Hub(cli.cfg(), nil, nil)
if err != nil {
return err
}
if err = alertcontext.LoadConsoleContext(csConfig, hub); err != nil {
if err = alertcontext.LoadConsoleContext(cli.cfg(), hub); err != nil {
return fmt.Errorf("while loading context: %w", err)
}
if keyToAdd != "" {
if err := AddContext(keyToAdd, valuesToAdd); err != nil {
if err := cli.addContext(keyToAdd, valuesToAdd); err != nil {
return err
}
return nil
@ -290,7 +301,7 @@ cscli lapi context add --value evt.Meta.source_ip --value evt.Meta.target_user
keySlice := strings.Split(v, ".")
key := keySlice[len(keySlice)-1]
value := []string{v}
if err := AddContext(key, value); err != nil {
if err := cli.addContext(key, value); err != nil {
return err
}
}
@ -298,31 +309,37 @@ cscli lapi context add --value evt.Meta.source_ip --value evt.Meta.target_user
return nil
},
}
cmdContextAdd.Flags().StringVarP(&keyToAdd, "key", "k", "", "The key of the different values to send")
cmdContextAdd.Flags().StringSliceVar(&valuesToAdd, "value", []string{}, "The expr fields to associate with the key")
cmdContextAdd.MarkFlagRequired("value")
cmdContext.AddCommand(cmdContextAdd)
cmdContextStatus := &cobra.Command{
flags := cmd.Flags()
flags.StringVarP(&keyToAdd, "key", "k", "", "The key of the different values to send")
flags.StringSliceVar(&valuesToAdd, "value", []string{}, "The expr fields to associate with the key")
cmd.MarkFlagRequired("value")
return cmd
}
func (cli *cliLapi) newContextStatusCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "status",
Short: "List context to send with alerts",
DisableAutoGenTag: true,
RunE: func(cmd *cobra.Command, args []string) error {
hub, err := require.Hub(csConfig, nil, nil)
RunE: func(_ *cobra.Command, _ []string) error {
cfg := cli.cfg()
hub, err := require.Hub(cfg, nil, nil)
if err != nil {
return err
}
if err = alertcontext.LoadConsoleContext(csConfig, hub); err != nil {
if err = alertcontext.LoadConsoleContext(cfg, hub); err != nil {
return fmt.Errorf("while loading context: %w", err)
}
if len(csConfig.Crowdsec.ContextToSend) == 0 {
if len(cfg.Crowdsec.ContextToSend) == 0 {
fmt.Println("No context found on this agent. You can use 'cscli lapi context add' to add context to your alerts.")
return nil
}
dump, err := yaml.Marshal(csConfig.Crowdsec.ContextToSend)
dump, err := yaml.Marshal(cfg.Crowdsec.ContextToSend)
if err != nil {
return fmt.Errorf("unable to show context status: %w", err)
}
@ -332,10 +349,14 @@ cscli lapi context add --value evt.Meta.source_ip --value evt.Meta.target_user
return nil
},
}
cmdContext.AddCommand(cmdContextStatus)
return cmd
}
func (cli *cliLapi) newContextDetectCmd() *cobra.Command {
var detectAll bool
cmdContextDetect := &cobra.Command{
cmd := &cobra.Command{
Use: "detect",
Short: "Detect available fields from the installed parsers",
Example: `cscli lapi context detect --all
@ -343,6 +364,7 @@ cscli lapi context detect crowdsecurity/sshd-logs
`,
DisableAutoGenTag: true,
RunE: func(cmd *cobra.Command, args []string) error {
cfg := cli.cfg()
if !detectAll && len(args) == 0 {
log.Infof("Please provide parsers to detect or --all flag.")
printHelp(cmd)
@ -355,13 +377,13 @@ cscli lapi context detect crowdsecurity/sshd-logs
return fmt.Errorf("failed to init expr helpers: %w", err)
}
hub, err := require.Hub(csConfig, nil, nil)
hub, err := require.Hub(cfg, nil, nil)
if err != nil {
return err
}
csParsers := parser.NewParsers(hub)
if csParsers, err = parser.LoadParsers(csConfig, csParsers); err != nil {
if csParsers, err = parser.LoadParsers(cfg, csParsers); err != nil {
return fmt.Errorf("unable to load parsers: %w", err)
}
@ -418,47 +440,85 @@ cscli lapi context detect crowdsecurity/sshd-logs
return nil
},
}
cmdContextDetect.Flags().BoolVarP(&detectAll, "all", "a", false, "Detect evt field for all installed parser")
cmdContext.AddCommand(cmdContextDetect)
cmd.Flags().BoolVarP(&detectAll, "all", "a", false, "Detect evt field for all installed parser")
cmdContextDelete := &cobra.Command{
return cmd
}
func (cli *cliLapi) newContextDeleteCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "delete",
DisableAutoGenTag: true,
RunE: func(_ *cobra.Command, _ []string) error {
filePath := csConfig.Crowdsec.ConsoleContextPath
filePath := cli.cfg().Crowdsec.ConsoleContextPath
if filePath == "" {
filePath = "the context file"
}
fmt.Printf("Command \"delete\" is deprecated, please manually edit %s.", filePath)
fmt.Printf("Command 'delete' is deprecated, please manually edit %s.", filePath)
return nil
},
}
cmdContext.AddCommand(cmdContextDelete)
return cmdContext
return cmd
}
func detectStaticField(GrokStatics []parser.ExtraField) []string {
func (cli *cliLapi) newContextCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "context [command]",
Short: "Manage context to send with alerts",
DisableAutoGenTag: true,
PersistentPreRunE: func(_ *cobra.Command, _ []string) error {
cfg := cli.cfg()
if err := cfg.LoadCrowdsec(); err != nil {
fileNotFoundMessage := fmt.Sprintf("failed to open context file: open %s: no such file or directory", cfg.Crowdsec.ConsoleContextPath)
if err.Error() != fileNotFoundMessage {
return fmt.Errorf("unable to load CrowdSec agent configuration: %w", err)
}
}
if cfg.DisableAgent {
return errors.New("agent is disabled and lapi context can only be used on the agent")
}
return nil
},
Run: func(cmd *cobra.Command, _ []string) {
printHelp(cmd)
},
}
cmd.AddCommand(cli.newContextAddCmd())
cmd.AddCommand(cli.newContextStatusCmd())
cmd.AddCommand(cli.newContextDetectCmd())
cmd.AddCommand(cli.newContextDeleteCmd())
return cmd
}
func detectStaticField(grokStatics []parser.ExtraField) []string {
ret := make([]string, 0)
for _, static := range GrokStatics {
for _, static := range grokStatics {
if static.Parsed != "" {
fieldName := fmt.Sprintf("evt.Parsed.%s", static.Parsed)
if !slices.Contains(ret, fieldName) {
ret = append(ret, fieldName)
}
}
if static.Meta != "" {
fieldName := fmt.Sprintf("evt.Meta.%s", static.Meta)
if !slices.Contains(ret, fieldName) {
ret = append(ret, fieldName)
}
}
if static.TargetByName != "" {
fieldName := static.TargetByName
if !strings.HasPrefix(fieldName, "evt.") {
fieldName = "evt." + fieldName
}
if !slices.Contains(ret, fieldName) {
ret = append(ret, fieldName)
}
@ -526,6 +586,7 @@ func detectSubNode(node parser.Node, parserCTX parser.UnixParserCtx) []string {
}
}
}
if subnode.Grok.RegexpName != "" {
grokCompiled, err := parserCTX.Grok.Get(subnode.Grok.RegexpName)
if err == nil {

View file

@ -7,6 +7,7 @@ import (
"fmt"
"math/big"
"os"
"slices"
"strings"
"time"
@ -17,7 +18,6 @@ import (
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"gopkg.in/yaml.v3"
"slices"
"github.com/crowdsecurity/machineid"
@ -45,6 +45,7 @@ func generatePassword(length int) string {
if err != nil {
log.Fatalf("failed getting data from prng for password generation : %s", err)
}
buf[i] = charset[rInt.Int64()]
}
@ -59,12 +60,14 @@ func generateIDPrefix() (string, error) {
if err == nil {
return prefix, nil
}
log.Debugf("failed to get machine-id with usual files: %s", err)
bID, err := uuid.NewRandom()
if err == nil {
return bID.String(), nil
}
return "", fmt.Errorf("generating machine id: %w", err)
}
@ -75,11 +78,14 @@ func generateID(prefix string) (string, error) {
if prefix == "" {
prefix, err = generateIDPrefix()
}
if err != nil {
return "", err
}
prefix = strings.ReplaceAll(prefix, "-", "")[:32]
suffix := generatePassword(16)
return prefix + suffix, nil
}
@ -100,14 +106,14 @@ func getLastHeartbeat(m *ent.Machine) (string, bool) {
return hb, true
}
type cliMachines struct{
db *database.Client
type cliMachines struct {
db *database.Client
cfg configGetter
}
func NewCLIMachines(getconfig configGetter) *cliMachines {
func NewCLIMachines(cfg configGetter) *cliMachines {
return &cliMachines{
cfg: getconfig,
cfg: cfg,
}
}
@ -130,6 +136,7 @@ Note: This command requires database direct access, so is intended to be run on
if err != nil {
return fmt.Errorf("unable to create new database client: %s", err)
}
return nil
},
}
@ -243,7 +250,7 @@ cscli machines add -f- --auto > /tmp/mycreds.yaml`,
func (cli *cliMachines) add(args []string, machinePassword string, dumpFile string, apiURL string, interactive bool, autoAdd bool, force bool) error {
var (
err error
err error
machineID string
)
@ -289,6 +296,7 @@ func (cli *cliMachines) add(args []string, machinePassword string, dumpFile stri
if !autoAdd {
return fmt.Errorf("please specify a password with --password or use --auto")
}
machinePassword = generatePassword(passwordLength)
} else if machinePassword == "" && interactive {
qs := &survey.Password{
@ -328,10 +336,10 @@ func (cli *cliMachines) add(args []string, machinePassword string, dumpFile stri
}
if dumpFile != "" && dumpFile != "-" {
err = os.WriteFile(dumpFile, apiConfigDump, 0o600)
if err != nil {
if err = os.WriteFile(dumpFile, apiConfigDump, 0o600); err != nil {
return fmt.Errorf("write api credentials in '%s' failed: %s", dumpFile, err)
}
fmt.Fprintf(os.Stderr, "API credentials written to '%s'.\n", dumpFile)
} else {
fmt.Print(string(apiConfigDump))
@ -340,7 +348,7 @@ func (cli *cliMachines) add(args []string, machinePassword string, dumpFile stri
return nil
}
func (cli *cliMachines) deleteValid(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
func (cli *cliMachines) deleteValid(_ *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
machines, err := cli.db.ListMachines()
if err != nil {
cobra.CompError("unable to list machines " + err.Error())
@ -359,11 +367,11 @@ func (cli *cliMachines) deleteValid(cmd *cobra.Command, args []string, toComplet
func (cli *cliMachines) delete(machines []string) error {
for _, machineID := range machines {
err := cli.db.DeleteWatcher(machineID)
if err != nil {
if err := cli.db.DeleteWatcher(machineID); err != nil {
log.Errorf("unable to delete machine '%s': %s", machineID, err)
return nil
}
log.Infof("machine '%s' deleted successfully", machineID)
}
@ -440,9 +448,9 @@ func (cli *cliMachines) prune(duration time.Duration, notValidOnly bool, force b
func (cli *cliMachines) newPruneCmd() *cobra.Command {
var (
duration time.Duration
notValidOnly bool
force bool
duration time.Duration
notValidOnly bool
force bool
)
const defaultDuration = 10 * time.Minute
@ -473,6 +481,7 @@ func (cli *cliMachines) validate(machineID string) error {
if err := cli.db.ValidateMachine(machineID); err != nil {
return fmt.Errorf("unable to validate machine '%s': %s", machineID, err)
}
log.Infof("machine '%s' validated successfully", machineID)
return nil

View file

@ -15,45 +15,88 @@ import (
"github.com/crowdsecurity/crowdsec/pkg/fflag"
)
var trace_lvl, dbg_lvl, nfo_lvl, wrn_lvl, err_lvl bool
var ConfigFilePath string
var csConfig *csconfig.Config
var dbClient *database.Client
var outputFormat string
var OutputColor string
type configGetter func() *csconfig.Config
var mergedConfig string
// flagBranch overrides the value in csConfig.Cscli.HubBranch
var flagBranch = ""
type cliRoot struct {
logTrace bool
logDebug bool
logInfo bool
logWarn bool
logErr bool
outputColor string
outputFormat string
// flagBranch overrides the value in csConfig.Cscli.HubBranch
flagBranch string
}
type configGetter func() *csconfig.Config
func newCliRoot() *cliRoot {
return &cliRoot{}
}
func initConfig() {
var err error
// cfg() is a helper function to get the configuration loaded from config.yaml,
// we pass it to subcommands because the file is not read until the Execute() call
func (cli *cliRoot) cfg() *csconfig.Config {
return csConfig
}
if trace_lvl {
log.SetLevel(log.TraceLevel)
} else if dbg_lvl {
log.SetLevel(log.DebugLevel)
} else if nfo_lvl {
log.SetLevel(log.InfoLevel)
} else if wrn_lvl {
log.SetLevel(log.WarnLevel)
} else if err_lvl {
log.SetLevel(log.ErrorLevel)
// wantedLogLevel returns the log level requested in the command line flags.
func (cli *cliRoot) wantedLogLevel() log.Level {
switch {
case cli.logTrace:
return log.TraceLevel
case cli.logDebug:
return log.DebugLevel
case cli.logInfo:
return log.InfoLevel
case cli.logWarn:
return log.WarnLevel
case cli.logErr:
return log.ErrorLevel
default:
return log.InfoLevel
}
}
// loadConfigFor loads the configuration file for the given sub-command.
// If the sub-command does not need it, it returns a default configuration.
func loadConfigFor(command string) (*csconfig.Config, string, error) {
noNeedConfig := []string{
"doc",
"help",
"completion",
"version",
"hubtest",
}
if !slices.Contains(NoNeedConfig, os.Args[1]) {
if !slices.Contains(noNeedConfig, command) {
log.Debugf("Using %s as configuration file", ConfigFilePath)
csConfig, mergedConfig, err = csconfig.NewConfig(ConfigFilePath, false, false, true)
config, merged, err := csconfig.NewConfig(ConfigFilePath, false, false, true)
if err != nil {
log.Fatal(err)
return nil, "", err
}
} else {
csConfig = csconfig.NewDefaultConfig()
return config, merged, nil
}
return csconfig.NewDefaultConfig(), "", nil
}
// initialize is called before the subcommand is executed.
func (cli *cliRoot) initialize() {
var err error
log.SetLevel(cli.wantedLogLevel())
csConfig, mergedConfig, err = loadConfigFor(os.Args[1])
if err != nil {
log.Fatal(err)
}
// recap of the enabled feature flags, because logging
@ -62,12 +105,12 @@ func initConfig() {
log.Debugf("Enabled feature flags: %s", fflist)
}
if flagBranch != "" {
csConfig.Cscli.HubBranch = flagBranch
if cli.flagBranch != "" {
csConfig.Cscli.HubBranch = cli.flagBranch
}
if outputFormat != "" {
csConfig.Cscli.Output = outputFormat
if cli.outputFormat != "" {
csConfig.Cscli.Output = cli.outputFormat
}
if csConfig.Cscli.Output == "" {
@ -85,11 +128,11 @@ func initConfig() {
log.SetLevel(log.ErrorLevel)
}
if OutputColor != "" {
csConfig.Cscli.Color = OutputColor
if cli.outputColor != "" {
csConfig.Cscli.Color = cli.outputColor
if OutputColor != "yes" && OutputColor != "no" && OutputColor != "auto" {
log.Fatalf("output color %s unknown", OutputColor)
if cli.outputColor != "yes" && cli.outputColor != "no" && cli.outputColor != "auto" {
log.Fatalf("output color %s unknown", cli.outputColor)
}
}
}
@ -102,15 +145,25 @@ var validArgs = []string{
"postoverflows", "scenarios", "simulation", "support", "version",
}
var NoNeedConfig = []string{
"doc",
"help",
"completion",
"version",
"hubtest",
func (cli *cliRoot) colorize(cmd *cobra.Command) {
cc.Init(&cc.Config{
RootCmd: cmd,
Headings: cc.Yellow,
Commands: cc.Green + cc.Bold,
CmdShortDescr: cc.Cyan,
Example: cc.Italic,
ExecName: cc.Bold,
Aliases: cc.Bold + cc.Italic,
FlagsDataType: cc.White,
Flags: cc.Green,
FlagsDescr: cc.Cyan,
NoExtraNewlines: true,
NoBottomNewline: true,
})
cmd.SetOut(color.Output)
}
func main() {
func (cli *cliRoot) NewCommand() *cobra.Command {
// set the formatter asap and worry about level later
logFormatter := &log.TextFormatter{TimestampFormat: time.RFC3339, FullTimestamp: true}
log.SetFormatter(logFormatter)
@ -135,31 +188,25 @@ It is meant to allow you to manage bans, parsers/scenarios/etc, api and generall
/*TBD examples*/
}
cc.Init(&cc.Config{
RootCmd: cmd,
Headings: cc.Yellow,
Commands: cc.Green + cc.Bold,
CmdShortDescr: cc.Cyan,
Example: cc.Italic,
ExecName: cc.Bold,
Aliases: cc.Bold + cc.Italic,
FlagsDataType: cc.White,
Flags: cc.Green,
FlagsDescr: cc.Cyan,
})
cmd.SetOut(color.Output)
cli.colorize(cmd)
cmd.PersistentFlags().StringVarP(&ConfigFilePath, "config", "c", csconfig.DefaultConfigPath("config.yaml"), "path to crowdsec config file")
cmd.PersistentFlags().StringVarP(&outputFormat, "output", "o", "", "Output format: human, json, raw")
cmd.PersistentFlags().StringVarP(&OutputColor, "color", "", "auto", "Output color: yes, no, auto")
cmd.PersistentFlags().BoolVar(&dbg_lvl, "debug", false, "Set logging to debug")
cmd.PersistentFlags().BoolVar(&nfo_lvl, "info", false, "Set logging to info")
cmd.PersistentFlags().BoolVar(&wrn_lvl, "warning", false, "Set logging to warning")
cmd.PersistentFlags().BoolVar(&err_lvl, "error", false, "Set logging to error")
cmd.PersistentFlags().BoolVar(&trace_lvl, "trace", false, "Set logging to trace")
/*don't sort flags so we can enforce order*/
cmd.Flags().SortFlags = false
cmd.PersistentFlags().StringVar(&flagBranch, "branch", "", "Override hub branch on github")
if err := cmd.PersistentFlags().MarkHidden("branch"); err != nil {
pflags := cmd.PersistentFlags()
pflags.SortFlags = false
pflags.StringVarP(&ConfigFilePath, "config", "c", csconfig.DefaultConfigPath("config.yaml"), "path to crowdsec config file")
pflags.StringVarP(&cli.outputFormat, "output", "o", "", "Output format: human, json, raw")
pflags.StringVarP(&cli.outputColor, "color", "", "auto", "Output color: yes, no, auto")
pflags.BoolVar(&cli.logDebug, "debug", false, "Set logging to debug")
pflags.BoolVar(&cli.logInfo, "info", false, "Set logging to info")
pflags.BoolVar(&cli.logWarn, "warning", false, "Set logging to warning")
pflags.BoolVar(&cli.logErr, "error", false, "Set logging to error")
pflags.BoolVar(&cli.logTrace, "trace", false, "Set logging to trace")
pflags.StringVar(&cli.flagBranch, "branch", "", "Override hub branch on github")
if err := pflags.MarkHidden("branch"); err != nil {
log.Fatalf("failed to hide flag: %s", err)
}
@ -179,38 +226,29 @@ It is meant to allow you to manage bans, parsers/scenarios/etc, api and generall
}
if len(os.Args) > 1 {
cobra.OnInitialize(initConfig)
}
/*don't sort flags so we can enforce order*/
cmd.Flags().SortFlags = false
cmd.PersistentFlags().SortFlags = false
// we use a getter because the config is not initialized until the Execute() call
getconfig := func() *csconfig.Config {
return csConfig
cobra.OnInitialize(cli.initialize)
}
cmd.AddCommand(NewCLIDoc().NewCommand(cmd))
cmd.AddCommand(NewCLIVersion().NewCommand())
cmd.AddCommand(NewConfigCmd())
cmd.AddCommand(NewCLIHub(getconfig).NewCommand())
cmd.AddCommand(NewMetricsCmd())
cmd.AddCommand(NewCLIDashboard(getconfig).NewCommand())
cmd.AddCommand(NewCLIDecisions().NewCommand())
cmd.AddCommand(NewCLIAlerts().NewCommand())
cmd.AddCommand(NewCLISimulation(getconfig).NewCommand())
cmd.AddCommand(NewCLIBouncers(getconfig).NewCommand())
cmd.AddCommand(NewCLIMachines(getconfig).NewCommand())
cmd.AddCommand(NewCLIHub(cli.cfg).NewCommand())
cmd.AddCommand(NewCLIMetrics(cli.cfg).NewCommand())
cmd.AddCommand(NewCLIDashboard(cli.cfg).NewCommand())
cmd.AddCommand(NewCLIDecisions(cli.cfg).NewCommand())
cmd.AddCommand(NewCLIAlerts(cli.cfg).NewCommand())
cmd.AddCommand(NewCLISimulation(cli.cfg).NewCommand())
cmd.AddCommand(NewCLIBouncers(cli.cfg).NewCommand())
cmd.AddCommand(NewCLIMachines(cli.cfg).NewCommand())
cmd.AddCommand(NewCLICapi().NewCommand())
cmd.AddCommand(NewLapiCmd())
cmd.AddCommand(NewCLILapi(cli.cfg).NewCommand())
cmd.AddCommand(NewCompletionCmd())
cmd.AddCommand(NewConsoleCmd())
cmd.AddCommand(NewCLIExplain().NewCommand())
cmd.AddCommand(NewCLIConsole(cli.cfg).NewCommand())
cmd.AddCommand(NewCLIExplain(cli.cfg).NewCommand())
cmd.AddCommand(NewCLIHubTest().NewCommand())
cmd.AddCommand(NewCLINotifications().NewCommand())
cmd.AddCommand(NewCLINotifications(cli.cfg).NewCommand())
cmd.AddCommand(NewCLISupport().NewCommand())
cmd.AddCommand(NewCLIPapi(getconfig).NewCommand())
cmd.AddCommand(NewCLIPapi(cli.cfg).NewCommand())
cmd.AddCommand(NewCLICollection().NewCommand())
cmd.AddCommand(NewCLIParser().NewCommand())
cmd.AddCommand(NewCLIScenario().NewCommand())
@ -223,6 +261,11 @@ It is meant to allow you to manage bans, parsers/scenarios/etc, api and generall
cmd.AddCommand(NewSetupCmd())
}
return cmd
}
func main() {
cmd := newCliRoot().NewCommand()
if err := cmd.Execute(); err != nil {
log.Fatal(err)
}

View file

@ -2,6 +2,7 @@ package main
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
@ -16,11 +17,64 @@ import (
"github.com/spf13/cobra"
"gopkg.in/yaml.v3"
"github.com/crowdsecurity/go-cs-lib/maptools"
"github.com/crowdsecurity/go-cs-lib/trace"
)
// FormatPrometheusMetrics is a complete rip from prom2json
func FormatPrometheusMetrics(out io.Writer, url string, formatType string) error {
type (
statAcquis map[string]map[string]int
statParser map[string]map[string]int
statBucket map[string]map[string]int
statWhitelist map[string]map[string]map[string]int
statLapi map[string]map[string]int
statLapiMachine map[string]map[string]map[string]int
statLapiBouncer map[string]map[string]map[string]int
statLapiDecision map[string]struct {
NonEmpty int
Empty int
}
statDecision map[string]map[string]map[string]int
statAppsecEngine map[string]map[string]int
statAppsecRule map[string]map[string]map[string]int
statAlert map[string]int
statStash map[string]struct {
Type string
Count int
}
)
var (
ErrMissingConfig = errors.New("prometheus section missing, can't show metrics")
ErrMetricsDisabled = errors.New("prometheus is not enabled, can't show metrics")
)
type metricSection interface {
Table(out io.Writer, noUnit bool, showEmpty bool)
Description() (string, string)
}
type metricStore map[string]metricSection
func NewMetricStore() metricStore {
return metricStore{
"acquisition": statAcquis{},
"buckets": statBucket{},
"parsers": statParser{},
"lapi": statLapi{},
"lapi-machine": statLapiMachine{},
"lapi-bouncer": statLapiBouncer{},
"lapi-decisions": statLapiDecision{},
"decisions": statDecision{},
"alerts": statAlert{},
"stash": statStash{},
"appsec-engine": statAppsecEngine{},
"appsec-rule": statAppsecRule{},
"whitelists": statWhitelist{},
}
}
func (ms metricStore) Fetch(url string) error {
mfChan := make(chan *dto.MetricFamily, 1024)
errChan := make(chan error, 1)
@ -33,9 +87,10 @@ func FormatPrometheusMetrics(out io.Writer, url string, formatType string) error
transport.ResponseHeaderTimeout = time.Minute
go func() {
defer trace.CatchPanic("crowdsec/ShowPrometheus")
err := prom2json.FetchMetricFamilies(url, mfChan, transport)
if err != nil {
errChan <- fmt.Errorf("failed to fetch prometheus metrics: %w", err)
errChan <- fmt.Errorf("failed to fetch metrics: %w", err)
return
}
errChan <- nil
@ -50,42 +105,42 @@ func FormatPrometheusMetrics(out io.Writer, url string, formatType string) error
return err
}
log.Debugf("Finished reading prometheus output, %d entries", len(result))
log.Debugf("Finished reading metrics output, %d entries", len(result))
/*walk*/
lapi_decisions_stats := map[string]struct {
NonEmpty int
Empty int
}{}
acquis_stats := map[string]map[string]int{}
parsers_stats := map[string]map[string]int{}
buckets_stats := map[string]map[string]int{}
lapi_stats := map[string]map[string]int{}
lapi_machine_stats := map[string]map[string]map[string]int{}
lapi_bouncer_stats := map[string]map[string]map[string]int{}
decisions_stats := map[string]map[string]map[string]int{}
appsec_engine_stats := map[string]map[string]int{}
appsec_rule_stats := map[string]map[string]map[string]int{}
alerts_stats := map[string]int{}
stash_stats := map[string]struct {
Type string
Count int
}{}
mAcquis := ms["acquisition"].(statAcquis)
mParser := ms["parsers"].(statParser)
mBucket := ms["buckets"].(statBucket)
mLapi := ms["lapi"].(statLapi)
mLapiMachine := ms["lapi-machine"].(statLapiMachine)
mLapiBouncer := ms["lapi-bouncer"].(statLapiBouncer)
mLapiDecision := ms["lapi-decisions"].(statLapiDecision)
mDecision := ms["decisions"].(statDecision)
mAppsecEngine := ms["appsec-engine"].(statAppsecEngine)
mAppsecRule := ms["appsec-rule"].(statAppsecRule)
mAlert := ms["alerts"].(statAlert)
mStash := ms["stash"].(statStash)
mWhitelist := ms["whitelists"].(statWhitelist)
for idx, fam := range result {
if !strings.HasPrefix(fam.Name, "cs_") {
continue
}
log.Tracef("round %d", idx)
for _, m := range fam.Metrics {
metric, ok := m.(prom2json.Metric)
if !ok {
log.Debugf("failed to convert metric to prom2json.Metric")
continue
}
name, ok := metric.Labels["name"]
if !ok {
log.Debugf("no name in Metric %v", metric.Labels)
}
source, ok := metric.Labels["source"]
if !ok {
log.Debugf("no source in Metric %v for %s", metric.Labels, fam.Name)
@ -106,148 +161,89 @@ func FormatPrometheusMetrics(out io.Writer, url string, formatType string) error
origin := metric.Labels["origin"]
action := metric.Labels["action"]
appsecEngine := metric.Labels["appsec_engine"]
appsecRule := metric.Labels["rule_name"]
mtype := metric.Labels["type"]
fval, err := strconv.ParseFloat(value, 32)
if err != nil {
log.Errorf("Unexpected int value %s : %s", value, err)
}
ival := int(fval)
switch fam.Name {
/*buckets*/
//
// buckets
//
case "cs_bucket_created_total":
if _, ok := buckets_stats[name]; !ok {
buckets_stats[name] = make(map[string]int)
}
buckets_stats[name]["instantiation"] += ival
mBucket.Process(name, "instantiation", ival)
case "cs_buckets":
if _, ok := buckets_stats[name]; !ok {
buckets_stats[name] = make(map[string]int)
}
buckets_stats[name]["curr_count"] += ival
mBucket.Process(name, "curr_count", ival)
case "cs_bucket_overflowed_total":
if _, ok := buckets_stats[name]; !ok {
buckets_stats[name] = make(map[string]int)
}
buckets_stats[name]["overflow"] += ival
mBucket.Process(name, "overflow", ival)
case "cs_bucket_poured_total":
if _, ok := buckets_stats[name]; !ok {
buckets_stats[name] = make(map[string]int)
}
if _, ok := acquis_stats[source]; !ok {
acquis_stats[source] = make(map[string]int)
}
buckets_stats[name]["pour"] += ival
acquis_stats[source]["pour"] += ival
mBucket.Process(name, "pour", ival)
mAcquis.Process(source, "pour", ival)
case "cs_bucket_underflowed_total":
if _, ok := buckets_stats[name]; !ok {
buckets_stats[name] = make(map[string]int)
}
buckets_stats[name]["underflow"] += ival
/*acquis*/
mBucket.Process(name, "underflow", ival)
//
// parsers
//
case "cs_parser_hits_total":
if _, ok := acquis_stats[source]; !ok {
acquis_stats[source] = make(map[string]int)
}
acquis_stats[source]["reads"] += ival
mAcquis.Process(source, "reads", ival)
case "cs_parser_hits_ok_total":
if _, ok := acquis_stats[source]; !ok {
acquis_stats[source] = make(map[string]int)
}
acquis_stats[source]["parsed"] += ival
mAcquis.Process(source, "parsed", ival)
case "cs_parser_hits_ko_total":
if _, ok := acquis_stats[source]; !ok {
acquis_stats[source] = make(map[string]int)
}
acquis_stats[source]["unparsed"] += ival
mAcquis.Process(source, "unparsed", ival)
case "cs_node_hits_total":
if _, ok := parsers_stats[name]; !ok {
parsers_stats[name] = make(map[string]int)
}
parsers_stats[name]["hits"] += ival
mParser.Process(name, "hits", ival)
case "cs_node_hits_ok_total":
if _, ok := parsers_stats[name]; !ok {
parsers_stats[name] = make(map[string]int)
}
parsers_stats[name]["parsed"] += ival
mParser.Process(name, "parsed", ival)
case "cs_node_hits_ko_total":
if _, ok := parsers_stats[name]; !ok {
parsers_stats[name] = make(map[string]int)
}
parsers_stats[name]["unparsed"] += ival
mParser.Process(name, "unparsed", ival)
//
// whitelists
//
case "cs_node_wl_hits_total":
mWhitelist.Process(name, reason, "hits", ival)
case "cs_node_wl_hits_ok_total":
mWhitelist.Process(name, reason, "whitelisted", ival)
// track as well whitelisted lines at acquis level
mAcquis.Process(source, "whitelisted", ival)
//
// lapi
//
case "cs_lapi_route_requests_total":
if _, ok := lapi_stats[route]; !ok {
lapi_stats[route] = make(map[string]int)
}
lapi_stats[route][method] += ival
mLapi.Process(route, method, ival)
case "cs_lapi_machine_requests_total":
if _, ok := lapi_machine_stats[machine]; !ok {
lapi_machine_stats[machine] = make(map[string]map[string]int)
}
if _, ok := lapi_machine_stats[machine][route]; !ok {
lapi_machine_stats[machine][route] = make(map[string]int)
}
lapi_machine_stats[machine][route][method] += ival
mLapiMachine.Process(machine, route, method, ival)
case "cs_lapi_bouncer_requests_total":
if _, ok := lapi_bouncer_stats[bouncer]; !ok {
lapi_bouncer_stats[bouncer] = make(map[string]map[string]int)
}
if _, ok := lapi_bouncer_stats[bouncer][route]; !ok {
lapi_bouncer_stats[bouncer][route] = make(map[string]int)
}
lapi_bouncer_stats[bouncer][route][method] += ival
mLapiBouncer.Process(bouncer, route, method, ival)
case "cs_lapi_decisions_ko_total", "cs_lapi_decisions_ok_total":
if _, ok := lapi_decisions_stats[bouncer]; !ok {
lapi_decisions_stats[bouncer] = struct {
NonEmpty int
Empty int
}{}
}
x := lapi_decisions_stats[bouncer]
if fam.Name == "cs_lapi_decisions_ko_total" {
x.Empty += ival
} else if fam.Name == "cs_lapi_decisions_ok_total" {
x.NonEmpty += ival
}
lapi_decisions_stats[bouncer] = x
mLapiDecision.Process(bouncer, fam.Name, ival)
//
// decisions
//
case "cs_active_decisions":
if _, ok := decisions_stats[reason]; !ok {
decisions_stats[reason] = make(map[string]map[string]int)
}
if _, ok := decisions_stats[reason][origin]; !ok {
decisions_stats[reason][origin] = make(map[string]int)
}
decisions_stats[reason][origin][action] += ival
mDecision.Process(reason, origin, action, ival)
case "cs_alerts":
/*if _, ok := alerts_stats[scenario]; !ok {
alerts_stats[scenario] = make(map[string]int)
}*/
alerts_stats[reason] += ival
mAlert.Process(reason, ival)
//
// stash
//
case "cs_cache_size":
stash_stats[name] = struct {
Type string
Count int
}{Type: mtype, Count: ival}
mStash.Process(name, mtype, ival)
//
// appsec
//
case "cs_appsec_reqs_total":
if _, ok := appsec_engine_stats[metric.Labels["appsec_engine"]]; !ok {
appsec_engine_stats[metric.Labels["appsec_engine"]] = make(map[string]int, 0)
}
appsec_engine_stats[metric.Labels["appsec_engine"]]["processed"] = ival
mAppsecEngine.Process(appsecEngine, "processed", ival)
case "cs_appsec_block_total":
if _, ok := appsec_engine_stats[metric.Labels["appsec_engine"]]; !ok {
appsec_engine_stats[metric.Labels["appsec_engine"]] = make(map[string]int, 0)
}
appsec_engine_stats[metric.Labels["appsec_engine"]]["blocked"] = ival
mAppsecEngine.Process(appsecEngine, "blocked", ival)
case "cs_appsec_rule_hits":
appsecEngine := metric.Labels["appsec_engine"]
ruleID := metric.Labels["rule_name"]
if _, ok := appsec_rule_stats[appsecEngine]; !ok {
appsec_rule_stats[appsecEngine] = make(map[string]map[string]int, 0)
}
if _, ok := appsec_rule_stats[appsecEngine][ruleID]; !ok {
appsec_rule_stats[appsecEngine][ruleID] = make(map[string]int, 0)
}
appsec_rule_stats[appsecEngine][ruleID]["triggered"] = ival
mAppsecRule.Process(appsecEngine, appsecRule, "triggered", ival)
default:
log.Debugf("unknown: %+v", fam.Name)
continue
@ -255,46 +251,52 @@ func FormatPrometheusMetrics(out io.Writer, url string, formatType string) error
}
}
if formatType == "human" {
acquisStatsTable(out, acquis_stats)
bucketStatsTable(out, buckets_stats)
parserStatsTable(out, parsers_stats)
lapiStatsTable(out, lapi_stats)
lapiMachineStatsTable(out, lapi_machine_stats)
lapiBouncerStatsTable(out, lapi_bouncer_stats)
lapiDecisionStatsTable(out, lapi_decisions_stats)
decisionStatsTable(out, decisions_stats)
alertStatsTable(out, alerts_stats)
stashStatsTable(out, stash_stats)
appsecMetricsToTable(out, appsec_engine_stats)
appsecRulesToTable(out, appsec_rule_stats)
return nil
return nil
}
type cliMetrics struct {
cfg configGetter
}
func NewCLIMetrics(cfg configGetter) *cliMetrics {
return &cliMetrics{
cfg: cfg,
}
}
func (ms metricStore) Format(out io.Writer, sections []string, formatType string, noUnit bool) error {
// copy only the sections we want
want := map[string]metricSection{}
// if explicitly asking for sections, we want to show empty tables
showEmpty := len(sections) > 0
// if no sections are specified, we want all of them
if len(sections) == 0 {
for section := range ms {
sections = append(sections, section)
}
}
stats := make(map[string]any)
stats["acquisition"] = acquis_stats
stats["buckets"] = buckets_stats
stats["parsers"] = parsers_stats
stats["lapi"] = lapi_stats
stats["lapi_machine"] = lapi_machine_stats
stats["lapi_bouncer"] = lapi_bouncer_stats
stats["lapi_decisions"] = lapi_decisions_stats
stats["decisions"] = decisions_stats
stats["alerts"] = alerts_stats
stats["stash"] = stash_stats
for _, section := range sections {
want[section] = ms[section]
}
switch formatType {
case "human":
for section := range want {
want[section].Table(out, noUnit, showEmpty)
}
case "json":
x, err := json.MarshalIndent(stats, "", " ")
x, err := json.MarshalIndent(want, "", " ")
if err != nil {
return fmt.Errorf("failed to unmarshal metrics : %v", err)
return fmt.Errorf("failed to marshal metrics: %w", err)
}
out.Write(x)
case "raw":
x, err := yaml.Marshal(stats)
x, err := yaml.Marshal(want)
if err != nil {
return fmt.Errorf("failed to unmarshal metrics : %v", err)
return fmt.Errorf("failed to marshal metrics: %w", err)
}
out.Write(x)
default:
@ -304,52 +306,190 @@ func FormatPrometheusMetrics(out io.Writer, url string, formatType string) error
return nil
}
var noUnit bool
func runMetrics(cmd *cobra.Command, args []string) error {
flags := cmd.Flags()
url, err := flags.GetString("url")
if err != nil {
return err
}
func (cli *cliMetrics) show(sections []string, url string, noUnit bool) error {
cfg := cli.cfg()
if url != "" {
csConfig.Cscli.PrometheusUrl = url
cfg.Cscli.PrometheusUrl = url
}
noUnit, err = flags.GetBool("no-unit")
if err != nil {
if cfg.Prometheus == nil {
return ErrMissingConfig
}
if !cfg.Prometheus.Enabled {
return ErrMetricsDisabled
}
ms := NewMetricStore()
if err := ms.Fetch(cfg.Cscli.PrometheusUrl); err != nil {
return err
}
if csConfig.Prometheus == nil {
return fmt.Errorf("prometheus section missing, can't show metrics")
// any section that we don't have in the store is an error
for _, section := range sections {
if _, ok := ms[section]; !ok {
return fmt.Errorf("unknown metrics type: %s", section)
}
}
if !csConfig.Prometheus.Enabled {
return fmt.Errorf("prometheus is not enabled, can't show metrics")
}
if err = FormatPrometheusMetrics(color.Output, csConfig.Cscli.PrometheusUrl, csConfig.Cscli.Output); err != nil {
if err := ms.Format(color.Output, sections, cfg.Cscli.Output, noUnit); err != nil {
return err
}
return nil
}
func NewMetricsCmd() *cobra.Command {
cmdMetrics := &cobra.Command{
Use: "metrics",
Short: "Display crowdsec prometheus metrics.",
Long: `Fetch metrics from the prometheus server and display them in a human-friendly way`,
func (cli *cliMetrics) NewCommand() *cobra.Command {
var (
url string
noUnit bool
)
cmd := &cobra.Command{
Use: "metrics",
Short: "Display crowdsec prometheus metrics.",
Long: `Fetch metrics from a Local API server and display them`,
Example: `# Show all Metrics, skip empty tables (same as "cecli metrics show")
cscli metrics
# Show only some metrics, connect to a different url
cscli metrics --url http://lapi.local:6060/metrics show acquisition parsers
# List available metric types
cscli metrics list`,
Args: cobra.ExactArgs(0),
DisableAutoGenTag: true,
RunE: runMetrics,
RunE: func(cmd *cobra.Command, args []string) error {
return cli.show(nil, url, noUnit)
},
}
flags := cmdMetrics.PersistentFlags()
flags.StringP("url", "u", "", "Prometheus url (http://<ip>:<port>/metrics)")
flags.Bool("no-unit", false, "Show the real number instead of formatted with units")
flags := cmd.Flags()
flags.StringVarP(&url, "url", "u", "", "Prometheus url (http://<ip>:<port>/metrics)")
flags.BoolVar(&noUnit, "no-unit", false, "Show the real number instead of formatted with units")
return cmdMetrics
cmd.AddCommand(cli.newShowCmd())
cmd.AddCommand(cli.newListCmd())
return cmd
}
// expandAlias returns a list of sections. The input can be a list of sections or alias.
func (cli *cliMetrics) expandSectionGroups(args []string) []string {
ret := []string{}
for _, section := range args {
switch section {
case "engine":
ret = append(ret, "acquisition", "parsers", "buckets", "stash", "whitelists")
case "lapi":
ret = append(ret, "alerts", "decisions", "lapi", "lapi-bouncer", "lapi-decisions", "lapi-machine")
case "appsec":
ret = append(ret, "appsec-engine", "appsec-rule")
default:
ret = append(ret, section)
}
}
return ret
}
func (cli *cliMetrics) newShowCmd() *cobra.Command {
var (
url string
noUnit bool
)
cmd := &cobra.Command{
Use: "show [type]...",
Short: "Display all or part of the available metrics.",
Long: `Fetch metrics from a Local API server and display them, optionally filtering on specific types.`,
Example: `# Show all Metrics, skip empty tables
cscli metrics show
# Use an alias: "engine", "lapi" or "appsec" to show a group of metrics
cscli metrics show engine
# Show some specific metrics, show empty tables, connect to a different url
cscli metrics show acquisition parsers buckets stash --url http://lapi.local:6060/metrics
# Show metrics in json format
cscli metrics show acquisition parsers buckets stash -o json`,
// Positional args are optional
DisableAutoGenTag: true,
RunE: func(_ *cobra.Command, args []string) error {
args = cli.expandSectionGroups(args)
return cli.show(args, url, noUnit)
},
}
flags := cmd.Flags()
flags.StringVarP(&url, "url", "u", "", "Metrics url (http://<ip>:<port>/metrics)")
flags.BoolVar(&noUnit, "no-unit", false, "Show the real number instead of formatted with units")
return cmd
}
func (cli *cliMetrics) list() error {
type metricType struct {
Type string `json:"type" yaml:"type"`
Title string `json:"title" yaml:"title"`
Description string `json:"description" yaml:"description"`
}
var allMetrics []metricType
ms := NewMetricStore()
for _, section := range maptools.SortedKeys(ms) {
title, description := ms[section].Description()
allMetrics = append(allMetrics, metricType{
Type: section,
Title: title,
Description: description,
})
}
switch cli.cfg().Cscli.Output {
case "human":
t := newTable(color.Output)
t.SetRowLines(true)
t.SetHeaders("Type", "Title", "Description")
for _, metric := range allMetrics {
t.AddRow(metric.Type, metric.Title, metric.Description)
}
t.Render()
case "json":
x, err := json.MarshalIndent(allMetrics, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal metric types: %w", err)
}
fmt.Println(string(x))
case "raw":
x, err := yaml.Marshal(allMetrics)
if err != nil {
return fmt.Errorf("failed to marshal metric types: %w", err)
}
fmt.Println(string(x))
}
return nil
}
func (cli *cliMetrics) newListCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "list",
Short: "List available types of metrics.",
Long: `List available types of metrics.`,
Args: cobra.ExactArgs(0),
DisableAutoGenTag: true,
RunE: func(_ *cobra.Command, _ []string) error {
return cli.list()
},
}
return cmd
}

View file

@ -4,22 +4,29 @@ import (
"fmt"
"io"
"sort"
"strconv"
"github.com/aquasecurity/table"
log "github.com/sirupsen/logrus"
"github.com/crowdsecurity/go-cs-lib/maptools"
)
// ErrNilTable means a nil pointer was passed instead of a table instance. This is a programming error.
var ErrNilTable = fmt.Errorf("nil table")
func lapiMetricsToTable(t *table.Table, stats map[string]map[string]map[string]int) int {
// stats: machine -> route -> method -> count
// sort keys to keep consistent order when printing
machineKeys := []string{}
for k := range stats {
machineKeys = append(machineKeys, k)
}
sort.Strings(machineKeys)
numRows := 0
for _, machine := range machineKeys {
// oneRow: route -> method -> count
machineRow := stats[machine]
@ -31,41 +38,77 @@ func lapiMetricsToTable(t *table.Table, stats map[string]map[string]map[string]i
methodName,
}
if count != 0 {
row = append(row, fmt.Sprintf("%d", count))
row = append(row, strconv.Itoa(count))
} else {
row = append(row, "-")
}
t.AddRow(row...)
numRows++
}
}
}
return numRows
}
func metricsToTable(t *table.Table, stats map[string]map[string]int, keys []string) (int, error) {
func wlMetricsToTable(t *table.Table, stats map[string]map[string]map[string]int, noUnit bool) (int, error) {
if t == nil {
return 0, fmt.Errorf("nil table")
return 0, ErrNilTable
}
// sort keys to keep consistent order when printing
sortedKeys := []string{}
for k := range stats {
sortedKeys = append(sortedKeys, k)
}
sort.Strings(sortedKeys)
numRows := 0
for _, alabel := range sortedKeys {
for _, name := range maptools.SortedKeys(stats) {
for _, reason := range maptools.SortedKeys(stats[name]) {
row := []string{
name,
reason,
"-",
"-",
}
for _, action := range maptools.SortedKeys(stats[name][reason]) {
value := stats[name][reason][action]
switch action {
case "whitelisted":
row[3] = strconv.Itoa(value)
case "hits":
row[2] = strconv.Itoa(value)
default:
log.Debugf("unexpected counter '%s' for whitelists = %d", action, value)
}
}
t.AddRow(row...)
numRows++
}
}
return numRows, nil
}
func metricsToTable(t *table.Table, stats map[string]map[string]int, keys []string, noUnit bool) (int, error) {
if t == nil {
return 0, ErrNilTable
}
numRows := 0
for _, alabel := range maptools.SortedKeys(stats) {
astats, ok := stats[alabel]
if !ok {
continue
}
row := []string{
alabel,
}
for _, sl := range keys {
if v, ok := astats[sl]; ok && v != 0 {
numberToShow := fmt.Sprintf("%d", v)
numberToShow := strconv.Itoa(v)
if !noUnit {
numberToShow = formatNumber(v)
}
@ -75,13 +118,29 @@ func metricsToTable(t *table.Table, stats map[string]map[string]int, keys []stri
row = append(row, "-")
}
}
t.AddRow(row...)
numRows++
}
return numRows, nil
}
func bucketStatsTable(out io.Writer, stats map[string]map[string]int) {
func (s statBucket) Description() (string, string) {
return "Bucket Metrics",
`Measure events in different scenarios. Current count is the number of buckets during metrics collection. ` +
`Overflows are past event-producing buckets, while Expired are the ones that didnt receive enough events to Overflow.`
}
func (s statBucket) Process(bucket, metric string, val int) {
if _, ok := s[bucket]; !ok {
s[bucket] = make(map[string]int)
}
s[bucket][metric] += val
}
func (s statBucket) Table(out io.Writer, noUnit bool, showEmpty bool) {
t := newTable(out)
t.SetRowLines(false)
t.SetHeaders("Bucket", "Current Count", "Overflows", "Instantiated", "Poured", "Expired")
@ -89,62 +148,161 @@ func bucketStatsTable(out io.Writer, stats map[string]map[string]int) {
keys := []string{"curr_count", "overflow", "instantiation", "pour", "underflow"}
if numRows, err := metricsToTable(t, stats, keys); err != nil {
if numRows, err := metricsToTable(t, s, keys, noUnit); err != nil {
log.Warningf("while collecting bucket stats: %s", err)
} else if numRows > 0 {
renderTableTitle(out, "\nBucket Metrics:")
} else if numRows > 0 || showEmpty {
title, _ := s.Description()
renderTableTitle(out, "\n"+title+":")
t.Render()
}
}
func acquisStatsTable(out io.Writer, stats map[string]map[string]int) {
func (s statAcquis) Description() (string, string) {
return "Acquisition Metrics",
`Measures the lines read, parsed, and unparsed per datasource. ` +
`Zero read lines indicate a misconfigured or inactive datasource. ` +
`Zero parsed lines mean the parser(s) failed. ` +
`Non-zero parsed lines are fine as crowdsec selects relevant lines.`
}
func (s statAcquis) Process(source, metric string, val int) {
if _, ok := s[source]; !ok {
s[source] = make(map[string]int)
}
s[source][metric] += val
}
func (s statAcquis) Table(out io.Writer, noUnit bool, showEmpty bool) {
t := newTable(out)
t.SetRowLines(false)
t.SetHeaders("Source", "Lines read", "Lines parsed", "Lines unparsed", "Lines poured to bucket")
t.SetHeaders("Source", "Lines read", "Lines parsed", "Lines unparsed", "Lines poured to bucket", "Lines whitelisted")
t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft)
keys := []string{"reads", "parsed", "unparsed", "pour"}
keys := []string{"reads", "parsed", "unparsed", "pour", "whitelisted"}
if numRows, err := metricsToTable(t, stats, keys); err != nil {
if numRows, err := metricsToTable(t, s, keys, noUnit); err != nil {
log.Warningf("while collecting acquis stats: %s", err)
} else if numRows > 0 {
renderTableTitle(out, "\nAcquisition Metrics:")
} else if numRows > 0 || showEmpty {
title, _ := s.Description()
renderTableTitle(out, "\n"+title+":")
t.Render()
}
}
func appsecMetricsToTable(out io.Writer, metrics map[string]map[string]int) {
func (s statAppsecEngine) Description() (string, string) {
return "Appsec Metrics",
`Measures the number of parsed and blocked requests by the AppSec Component.`
}
func (s statAppsecEngine) Process(appsecEngine, metric string, val int) {
if _, ok := s[appsecEngine]; !ok {
s[appsecEngine] = make(map[string]int)
}
s[appsecEngine][metric] += val
}
func (s statAppsecEngine) Table(out io.Writer, noUnit bool, showEmpty bool) {
t := newTable(out)
t.SetRowLines(false)
t.SetHeaders("Appsec Engine", "Processed", "Blocked")
t.SetAlignment(table.AlignLeft, table.AlignLeft)
keys := []string{"processed", "blocked"}
if numRows, err := metricsToTable(t, metrics, keys); err != nil {
if numRows, err := metricsToTable(t, s, keys, noUnit); err != nil {
log.Warningf("while collecting appsec stats: %s", err)
} else if numRows > 0 {
renderTableTitle(out, "\nAppsec Metrics:")
} else if numRows > 0 || showEmpty {
title, _ := s.Description()
renderTableTitle(out, "\n"+title+":")
t.Render()
}
}
func appsecRulesToTable(out io.Writer, metrics map[string]map[string]map[string]int) {
for appsecEngine, appsecEngineRulesStats := range metrics {
func (s statAppsecRule) Description() (string, string) {
return "Appsec Rule Metrics",
`Provides “per AppSec Component” information about the number of matches for loaded AppSec Rules.`
}
func (s statAppsecRule) Process(appsecEngine, appsecRule string, metric string, val int) {
if _, ok := s[appsecEngine]; !ok {
s[appsecEngine] = make(map[string]map[string]int)
}
if _, ok := s[appsecEngine][appsecRule]; !ok {
s[appsecEngine][appsecRule] = make(map[string]int)
}
s[appsecEngine][appsecRule][metric] += val
}
func (s statAppsecRule) Table(out io.Writer, noUnit bool, showEmpty bool) {
for appsecEngine, appsecEngineRulesStats := range s {
t := newTable(out)
t.SetRowLines(false)
t.SetHeaders("Rule ID", "Triggered")
t.SetAlignment(table.AlignLeft, table.AlignLeft)
keys := []string{"triggered"}
if numRows, err := metricsToTable(t, appsecEngineRulesStats, keys); err != nil {
if numRows, err := metricsToTable(t, appsecEngineRulesStats, keys, noUnit); err != nil {
log.Warningf("while collecting appsec rules stats: %s", err)
} else if numRows > 0 {
} else if numRows > 0 || showEmpty {
renderTableTitle(out, fmt.Sprintf("\nAppsec '%s' Rules Metrics:", appsecEngine))
t.Render()
}
}
}
func parserStatsTable(out io.Writer, stats map[string]map[string]int) {
func (s statWhitelist) Description() (string, string) {
return "Whitelist Metrics",
`Tracks the number of events processed and possibly whitelisted by each parser whitelist.`
}
func (s statWhitelist) Process(whitelist, reason, metric string, val int) {
if _, ok := s[whitelist]; !ok {
s[whitelist] = make(map[string]map[string]int)
}
if _, ok := s[whitelist][reason]; !ok {
s[whitelist][reason] = make(map[string]int)
}
s[whitelist][reason][metric] += val
}
func (s statWhitelist) Table(out io.Writer, noUnit bool, showEmpty bool) {
t := newTable(out)
t.SetRowLines(false)
t.SetHeaders("Whitelist", "Reason", "Hits", "Whitelisted")
t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft)
if numRows, err := wlMetricsToTable(t, s, noUnit); err != nil {
log.Warningf("while collecting parsers stats: %s", err)
} else if numRows > 0 || showEmpty {
title, _ := s.Description()
renderTableTitle(out, "\n"+title+":")
t.Render()
}
}
func (s statParser) Description() (string, string) {
return "Parser Metrics",
`Tracks the number of events processed by each parser and indicates success of failure. ` +
`Zero parsed lines means the parer(s) failed. ` +
`Non-zero unparsed lines are fine as crowdsec select relevant lines.`
}
func (s statParser) Process(parser, metric string, val int) {
if _, ok := s[parser]; !ok {
s[parser] = make(map[string]int)
}
s[parser][metric] += val
}
func (s statParser) Table(out io.Writer, noUnit bool, showEmpty bool) {
t := newTable(out)
t.SetRowLines(false)
t.SetHeaders("Parsers", "Hits", "Parsed", "Unparsed")
@ -152,187 +310,296 @@ func parserStatsTable(out io.Writer, stats map[string]map[string]int) {
keys := []string{"hits", "parsed", "unparsed"}
if numRows, err := metricsToTable(t, stats, keys); err != nil {
if numRows, err := metricsToTable(t, s, keys, noUnit); err != nil {
log.Warningf("while collecting parsers stats: %s", err)
} else if numRows > 0 {
renderTableTitle(out, "\nParser Metrics:")
} else if numRows > 0 || showEmpty {
title, _ := s.Description()
renderTableTitle(out, "\n"+title+":")
t.Render()
}
}
func stashStatsTable(out io.Writer, stats map[string]struct {
Type string
Count int
}) {
func (s statStash) Description() (string, string) {
return "Parser Stash Metrics",
`Tracks the status of stashes that might be created by various parsers and scenarios.`
}
func (s statStash) Process(name, mtype string, val int) {
s[name] = struct {
Type string
Count int
}{
Type: mtype,
Count: val,
}
}
func (s statStash) Table(out io.Writer, noUnit bool, showEmpty bool) {
t := newTable(out)
t.SetRowLines(false)
t.SetHeaders("Name", "Type", "Items")
t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft)
// unfortunately, we can't reuse metricsToTable as the structure is too different :/
sortedKeys := []string{}
for k := range stats {
sortedKeys = append(sortedKeys, k)
}
sort.Strings(sortedKeys)
numRows := 0
for _, alabel := range sortedKeys {
astats := stats[alabel]
for _, alabel := range maptools.SortedKeys(s) {
astats := s[alabel]
row := []string{
alabel,
astats.Type,
fmt.Sprintf("%d", astats.Count),
strconv.Itoa(astats.Count),
}
t.AddRow(row...)
numRows++
}
if numRows > 0 {
renderTableTitle(out, "\nParser Stash Metrics:")
if numRows > 0 || showEmpty {
title, _ := s.Description()
renderTableTitle(out, "\n"+title+":")
t.Render()
}
}
func lapiStatsTable(out io.Writer, stats map[string]map[string]int) {
func (s statLapi) Description() (string, string) {
return "Local API Metrics",
`Monitors the requests made to local API routes.`
}
func (s statLapi) Process(route, method string, val int) {
if _, ok := s[route]; !ok {
s[route] = make(map[string]int)
}
s[route][method] += val
}
func (s statLapi) Table(out io.Writer, noUnit bool, showEmpty bool) {
t := newTable(out)
t.SetRowLines(false)
t.SetHeaders("Route", "Method", "Hits")
t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft)
// unfortunately, we can't reuse metricsToTable as the structure is too different :/
sortedKeys := []string{}
for k := range stats {
sortedKeys = append(sortedKeys, k)
}
sort.Strings(sortedKeys)
numRows := 0
for _, alabel := range sortedKeys {
astats := stats[alabel]
for _, alabel := range maptools.SortedKeys(s) {
astats := s[alabel]
subKeys := []string{}
for skey := range astats {
subKeys = append(subKeys, skey)
}
sort.Strings(subKeys)
for _, sl := range subKeys {
row := []string{
alabel,
sl,
fmt.Sprintf("%d", astats[sl]),
strconv.Itoa(astats[sl]),
}
t.AddRow(row...)
numRows++
}
}
if numRows > 0 {
renderTableTitle(out, "\nLocal API Metrics:")
if numRows > 0 || showEmpty {
title, _ := s.Description()
renderTableTitle(out, "\n"+title+":")
t.Render()
}
}
func lapiMachineStatsTable(out io.Writer, stats map[string]map[string]map[string]int) {
func (s statLapiMachine) Description() (string, string) {
return "Local API Machines Metrics",
`Tracks the number of calls to the local API from each registered machine.`
}
func (s statLapiMachine) Process(machine, route, method string, val int) {
if _, ok := s[machine]; !ok {
s[machine] = make(map[string]map[string]int)
}
if _, ok := s[machine][route]; !ok {
s[machine][route] = make(map[string]int)
}
s[machine][route][method] += val
}
func (s statLapiMachine) Table(out io.Writer, noUnit bool, showEmpty bool) {
t := newTable(out)
t.SetRowLines(false)
t.SetHeaders("Machine", "Route", "Method", "Hits")
t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft)
numRows := lapiMetricsToTable(t, stats)
numRows := lapiMetricsToTable(t, s)
if numRows > 0 {
renderTableTitle(out, "\nLocal API Machines Metrics:")
if numRows > 0 || showEmpty {
title, _ := s.Description()
renderTableTitle(out, "\n"+title+":")
t.Render()
}
}
func lapiBouncerStatsTable(out io.Writer, stats map[string]map[string]map[string]int) {
func (s statLapiBouncer) Description() (string, string) {
return "Local API Bouncers Metrics",
`Tracks total hits to remediation component related API routes.`
}
func (s statLapiBouncer) Process(bouncer, route, method string, val int) {
if _, ok := s[bouncer]; !ok {
s[bouncer] = make(map[string]map[string]int)
}
if _, ok := s[bouncer][route]; !ok {
s[bouncer][route] = make(map[string]int)
}
s[bouncer][route][method] += val
}
func (s statLapiBouncer) Table(out io.Writer, noUnit bool, showEmpty bool) {
t := newTable(out)
t.SetRowLines(false)
t.SetHeaders("Bouncer", "Route", "Method", "Hits")
t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft)
numRows := lapiMetricsToTable(t, stats)
numRows := lapiMetricsToTable(t, s)
if numRows > 0 {
renderTableTitle(out, "\nLocal API Bouncers Metrics:")
if numRows > 0 || showEmpty {
title, _ := s.Description()
renderTableTitle(out, "\n"+title+":")
t.Render()
}
}
func lapiDecisionStatsTable(out io.Writer, stats map[string]struct {
NonEmpty int
Empty int
},
) {
func (s statLapiDecision) Description() (string, string) {
return "Local API Bouncers Decisions",
`Tracks the number of empty/non-empty answers from LAPI to bouncers that are working in "live" mode.`
}
func (s statLapiDecision) Process(bouncer, fam string, val int) {
if _, ok := s[bouncer]; !ok {
s[bouncer] = struct {
NonEmpty int
Empty int
}{}
}
x := s[bouncer]
switch fam {
case "cs_lapi_decisions_ko_total":
x.Empty += val
case "cs_lapi_decisions_ok_total":
x.NonEmpty += val
}
s[bouncer] = x
}
func (s statLapiDecision) Table(out io.Writer, noUnit bool, showEmpty bool) {
t := newTable(out)
t.SetRowLines(false)
t.SetHeaders("Bouncer", "Empty answers", "Non-empty answers")
t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft)
numRows := 0
for bouncer, hits := range stats {
for bouncer, hits := range s {
t.AddRow(
bouncer,
fmt.Sprintf("%d", hits.Empty),
fmt.Sprintf("%d", hits.NonEmpty),
strconv.Itoa(hits.Empty),
strconv.Itoa(hits.NonEmpty),
)
numRows++
}
if numRows > 0 {
renderTableTitle(out, "\nLocal API Bouncers Decisions:")
if numRows > 0 || showEmpty {
title, _ := s.Description()
renderTableTitle(out, "\n"+title+":")
t.Render()
}
}
func decisionStatsTable(out io.Writer, stats map[string]map[string]map[string]int) {
func (s statDecision) Description() (string, string) {
return "Local API Decisions",
`Provides information about all currently active decisions. ` +
`Includes both local (crowdsec) and global decisions (CAPI), and lists subscriptions (lists).`
}
func (s statDecision) Process(reason, origin, action string, val int) {
if _, ok := s[reason]; !ok {
s[reason] = make(map[string]map[string]int)
}
if _, ok := s[reason][origin]; !ok {
s[reason][origin] = make(map[string]int)
}
s[reason][origin][action] += val
}
func (s statDecision) Table(out io.Writer, noUnit bool, showEmpty bool) {
t := newTable(out)
t.SetRowLines(false)
t.SetHeaders("Reason", "Origin", "Action", "Count")
t.SetAlignment(table.AlignLeft, table.AlignLeft, table.AlignLeft, table.AlignLeft)
numRows := 0
for reason, origins := range stats {
for reason, origins := range s {
for origin, actions := range origins {
for action, hits := range actions {
t.AddRow(
reason,
origin,
action,
fmt.Sprintf("%d", hits),
strconv.Itoa(hits),
)
numRows++
}
}
}
if numRows > 0 {
renderTableTitle(out, "\nLocal API Decisions:")
if numRows > 0 || showEmpty {
title, _ := s.Description()
renderTableTitle(out, "\n"+title+":")
t.Render()
}
}
func alertStatsTable(out io.Writer, stats map[string]int) {
func (s statAlert) Description() (string, string) {
return "Local API Alerts",
`Tracks the total number of past and present alerts for the installed scenarios.`
}
func (s statAlert) Process(reason string, val int) {
s[reason] += val
}
func (s statAlert) Table(out io.Writer, noUnit bool, showEmpty bool) {
t := newTable(out)
t.SetRowLines(false)
t.SetHeaders("Reason", "Count")
t.SetAlignment(table.AlignLeft, table.AlignLeft)
numRows := 0
for scenario, hits := range stats {
for scenario, hits := range s {
t.AddRow(
scenario,
fmt.Sprintf("%d", hits),
strconv.Itoa(hits),
)
numRows++
}
if numRows > 0 {
renderTableTitle(out, "\nLocal API Alerts:")
if numRows > 0 || showEmpty {
title, _ := s.Description()
renderTableTitle(out, "\n"+title+":")
t.Render()
}
}

View file

@ -23,14 +23,13 @@ import (
"github.com/crowdsecurity/go-cs-lib/ptr"
"github.com/crowdsecurity/go-cs-lib/version"
"github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require"
"github.com/crowdsecurity/crowdsec/pkg/apiclient"
"github.com/crowdsecurity/crowdsec/pkg/csconfig"
"github.com/crowdsecurity/crowdsec/pkg/csplugin"
"github.com/crowdsecurity/crowdsec/pkg/csprofiles"
"github.com/crowdsecurity/crowdsec/pkg/types"
"github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require"
"github.com/crowdsecurity/crowdsec/pkg/models"
"github.com/crowdsecurity/crowdsec/pkg/types"
)
type NotificationsCfg struct {
@ -39,13 +38,17 @@ type NotificationsCfg struct {
ids []uint
}
type cliNotifications struct{}
func NewCLINotifications() *cliNotifications {
return &cliNotifications{}
type cliNotifications struct {
cfg configGetter
}
func (cli cliNotifications) NewCommand() *cobra.Command {
func NewCLINotifications(cfg configGetter) *cliNotifications {
return &cliNotifications{
cfg: cfg,
}
}
func (cli *cliNotifications) NewCommand() *cobra.Command {
cmd := &cobra.Command{
Use: "notifications [action]",
Short: "Helper for notification plugin configuration",
@ -53,14 +56,15 @@ func (cli cliNotifications) NewCommand() *cobra.Command {
Args: cobra.MinimumNArgs(1),
Aliases: []string{"notifications", "notification"},
DisableAutoGenTag: true,
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
if err := require.LAPI(csConfig); err != nil {
PersistentPreRunE: func(_ *cobra.Command, _ []string) error {
cfg := cli.cfg()
if err := require.LAPI(cfg); err != nil {
return err
}
if err := csConfig.LoadAPIClient(); err != nil {
if err := cfg.LoadAPIClient(); err != nil {
return fmt.Errorf("loading api client: %w", err)
}
if err := require.Notifications(csConfig); err != nil {
if err := require.Notifications(cfg); err != nil {
return err
}
@ -76,67 +80,79 @@ func (cli cliNotifications) NewCommand() *cobra.Command {
return cmd
}
func getPluginConfigs() (map[string]csplugin.PluginConfig, error) {
func (cli *cliNotifications) getPluginConfigs() (map[string]csplugin.PluginConfig, error) {
cfg := cli.cfg()
pcfgs := map[string]csplugin.PluginConfig{}
wf := func(path string, info fs.FileInfo, err error) error {
if info == nil {
return fmt.Errorf("error while traversing directory %s: %w", path, err)
}
name := filepath.Join(csConfig.ConfigPaths.NotificationDir, info.Name()) //Avoid calling info.Name() twice
name := filepath.Join(cfg.ConfigPaths.NotificationDir, info.Name()) //Avoid calling info.Name() twice
if (strings.HasSuffix(name, "yaml") || strings.HasSuffix(name, "yml")) && !(info.IsDir()) {
ts, err := csplugin.ParsePluginConfigFile(name)
if err != nil {
return fmt.Errorf("loading notifification plugin configuration with %s: %w", name, err)
}
for _, t := range ts {
csplugin.SetRequiredFields(&t)
pcfgs[t.Name] = t
}
}
return nil
}
if err := filepath.Walk(csConfig.ConfigPaths.NotificationDir, wf); err != nil {
if err := filepath.Walk(cfg.ConfigPaths.NotificationDir, wf); err != nil {
return nil, fmt.Errorf("while loading notifification plugin configuration: %w", err)
}
return pcfgs, nil
}
func getProfilesConfigs() (map[string]NotificationsCfg, error) {
func (cli *cliNotifications) getProfilesConfigs() (map[string]NotificationsCfg, error) {
cfg := cli.cfg()
// A bit of a tricky stuf now: reconcile profiles and notification plugins
pcfgs, err := getPluginConfigs()
pcfgs, err := cli.getPluginConfigs()
if err != nil {
return nil, err
}
ncfgs := map[string]NotificationsCfg{}
for _, pc := range pcfgs {
ncfgs[pc.Name] = NotificationsCfg{
Config: pc,
}
}
profiles, err := csprofiles.NewProfile(csConfig.API.Server.Profiles)
profiles, err := csprofiles.NewProfile(cfg.API.Server.Profiles)
if err != nil {
return nil, fmt.Errorf("while extracting profiles from configuration: %w", err)
}
for profileID, profile := range profiles {
for _, notif := range profile.Cfg.Notifications {
pc, ok := pcfgs[notif]
if !ok {
return nil, fmt.Errorf("notification plugin '%s' does not exist", notif)
}
tmp, ok := ncfgs[pc.Name]
if !ok {
return nil, fmt.Errorf("notification plugin '%s' does not exist", pc.Name)
}
tmp.Profiles = append(tmp.Profiles, profile.Cfg)
tmp.ids = append(tmp.ids, uint(profileID))
ncfgs[pc.Name] = tmp
}
}
return ncfgs, nil
}
func (cli cliNotifications) NewListCmd() *cobra.Command {
func (cli *cliNotifications) NewListCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "list",
Short: "list active notifications plugins",
@ -144,21 +160,22 @@ func (cli cliNotifications) NewListCmd() *cobra.Command {
Example: `cscli notifications list`,
Args: cobra.ExactArgs(0),
DisableAutoGenTag: true,
RunE: func(cmd *cobra.Command, arg []string) error {
ncfgs, err := getProfilesConfigs()
RunE: func(_ *cobra.Command, _ []string) error {
cfg := cli.cfg()
ncfgs, err := cli.getProfilesConfigs()
if err != nil {
return fmt.Errorf("can't build profiles configuration: %w", err)
}
if csConfig.Cscli.Output == "human" {
if cfg.Cscli.Output == "human" {
notificationListTable(color.Output, ncfgs)
} else if csConfig.Cscli.Output == "json" {
} else if cfg.Cscli.Output == "json" {
x, err := json.MarshalIndent(ncfgs, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal notification configuration: %w", err)
}
fmt.Printf("%s", string(x))
} else if csConfig.Cscli.Output == "raw" {
} else if cfg.Cscli.Output == "raw" {
csvwriter := csv.NewWriter(os.Stdout)
err := csvwriter.Write([]string{"Name", "Type", "Profile name"})
if err != nil {
@ -176,6 +193,7 @@ func (cli cliNotifications) NewListCmd() *cobra.Command {
}
csvwriter.Flush()
}
return nil
},
}
@ -183,7 +201,7 @@ func (cli cliNotifications) NewListCmd() *cobra.Command {
return cmd
}
func (cli cliNotifications) NewInspectCmd() *cobra.Command {
func (cli *cliNotifications) NewInspectCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "inspect",
Short: "Inspect active notifications plugin configuration",
@ -191,36 +209,32 @@ func (cli cliNotifications) NewInspectCmd() *cobra.Command {
Example: `cscli notifications inspect <plugin_name>`,
Args: cobra.ExactArgs(1),
DisableAutoGenTag: true,
PreRunE: func(cmd *cobra.Command, args []string) error {
if args[0] == "" {
return fmt.Errorf("please provide a plugin name to inspect")
}
return nil
},
RunE: func(cmd *cobra.Command, args []string) error {
ncfgs, err := getProfilesConfigs()
RunE: func(_ *cobra.Command, args []string) error {
cfg := cli.cfg()
ncfgs, err := cli.getProfilesConfigs()
if err != nil {
return fmt.Errorf("can't build profiles configuration: %w", err)
}
cfg, ok := ncfgs[args[0]]
ncfg, ok := ncfgs[args[0]]
if !ok {
return fmt.Errorf("plugin '%s' does not exist or is not active", args[0])
}
if csConfig.Cscli.Output == "human" || csConfig.Cscli.Output == "raw" {
fmt.Printf(" - %15s: %15s\n", "Type", cfg.Config.Type)
fmt.Printf(" - %15s: %15s\n", "Name", cfg.Config.Name)
fmt.Printf(" - %15s: %15s\n", "Timeout", cfg.Config.TimeOut)
fmt.Printf(" - %15s: %15s\n", "Format", cfg.Config.Format)
for k, v := range cfg.Config.Config {
if cfg.Cscli.Output == "human" || cfg.Cscli.Output == "raw" {
fmt.Printf(" - %15s: %15s\n", "Type", ncfg.Config.Type)
fmt.Printf(" - %15s: %15s\n", "Name", ncfg.Config.Name)
fmt.Printf(" - %15s: %15s\n", "Timeout", ncfg.Config.TimeOut)
fmt.Printf(" - %15s: %15s\n", "Format", ncfg.Config.Format)
for k, v := range ncfg.Config.Config {
fmt.Printf(" - %15s: %15v\n", k, v)
}
} else if csConfig.Cscli.Output == "json" {
} else if cfg.Cscli.Output == "json" {
x, err := json.MarshalIndent(cfg, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal notification configuration: %w", err)
}
fmt.Printf("%s", string(x))
}
return nil
},
}
@ -228,12 +242,13 @@ func (cli cliNotifications) NewInspectCmd() *cobra.Command {
return cmd
}
func (cli cliNotifications) NewTestCmd() *cobra.Command {
func (cli *cliNotifications) NewTestCmd() *cobra.Command {
var (
pluginBroker csplugin.PluginBroker
pluginTomb tomb.Tomb
alertOverride string
)
cmd := &cobra.Command{
Use: "test [plugin name]",
Short: "send a generic test alert to notification plugin",
@ -241,25 +256,26 @@ func (cli cliNotifications) NewTestCmd() *cobra.Command {
Example: `cscli notifications test [plugin_name]`,
Args: cobra.ExactArgs(1),
DisableAutoGenTag: true,
PreRunE: func(cmd *cobra.Command, args []string) error {
pconfigs, err := getPluginConfigs()
PreRunE: func(_ *cobra.Command, args []string) error {
cfg := cli.cfg()
pconfigs, err := cli.getPluginConfigs()
if err != nil {
return fmt.Errorf("can't build profiles configuration: %w", err)
}
cfg, ok := pconfigs[args[0]]
pcfg, ok := pconfigs[args[0]]
if !ok {
return fmt.Errorf("plugin name: '%s' does not exist", args[0])
}
//Create a single profile with plugin name as notification name
return pluginBroker.Init(csConfig.PluginConfig, []*csconfig.ProfileCfg{
return pluginBroker.Init(cfg.PluginConfig, []*csconfig.ProfileCfg{
{
Notifications: []string{
cfg.Name,
pcfg.Name,
},
},
}, csConfig.ConfigPaths)
}, cfg.ConfigPaths)
},
RunE: func(cmd *cobra.Command, args []string) error {
RunE: func(_ *cobra.Command, _ []string) error {
pluginTomb.Go(func() error {
pluginBroker.Run(&pluginTomb)
return nil
@ -298,13 +314,16 @@ func (cli cliNotifications) NewTestCmd() *cobra.Command {
if err := yaml.Unmarshal([]byte(alertOverride), alert); err != nil {
return fmt.Errorf("failed to unmarshal alert override: %w", err)
}
pluginBroker.PluginChannel <- csplugin.ProfileAlert{
ProfileID: uint(0),
Alert: alert,
}
//time.Sleep(2 * time.Second) // There's no mechanism to ensure notification has been sent
pluginTomb.Kill(fmt.Errorf("terminating"))
pluginTomb.Wait()
return nil
},
}
@ -313,9 +332,11 @@ func (cli cliNotifications) NewTestCmd() *cobra.Command {
return cmd
}
func (cli cliNotifications) NewReinjectCmd() *cobra.Command {
var alertOverride string
var alert *models.Alert
func (cli *cliNotifications) NewReinjectCmd() *cobra.Command {
var (
alertOverride string
alert *models.Alert
)
cmd := &cobra.Command{
Use: "reinject",
@ -328,25 +349,30 @@ cscli notifications reinject <alert_id> -a '{"remediation": true,"scenario":"not
`,
Args: cobra.ExactArgs(1),
DisableAutoGenTag: true,
PreRunE: func(cmd *cobra.Command, args []string) error {
PreRunE: func(_ *cobra.Command, args []string) error {
var err error
alert, err = FetchAlertFromArgString(args[0])
alert, err = cli.fetchAlertFromArgString(args[0])
if err != nil {
return err
}
return nil
},
RunE: func(cmd *cobra.Command, args []string) error {
RunE: func(_ *cobra.Command, _ []string) error {
var (
pluginBroker csplugin.PluginBroker
pluginTomb tomb.Tomb
)
cfg := cli.cfg()
if alertOverride != "" {
if err := json.Unmarshal([]byte(alertOverride), alert); err != nil {
return fmt.Errorf("can't unmarshal data in the alert flag: %w", err)
}
}
err := pluginBroker.Init(csConfig.PluginConfig, csConfig.API.Server.Profiles, csConfig.ConfigPaths)
err := pluginBroker.Init(cfg.PluginConfig, cfg.API.Server.Profiles, cfg.ConfigPaths)
if err != nil {
return fmt.Errorf("can't initialize plugins: %w", err)
}
@ -356,7 +382,7 @@ cscli notifications reinject <alert_id> -a '{"remediation": true,"scenario":"not
return nil
})
profiles, err := csprofiles.NewProfile(csConfig.API.Server.Profiles)
profiles, err := csprofiles.NewProfile(cfg.API.Server.Profiles)
if err != nil {
return fmt.Errorf("cannot extract profiles from configuration: %w", err)
}
@ -382,9 +408,9 @@ cscli notifications reinject <alert_id> -a '{"remediation": true,"scenario":"not
default:
time.Sleep(50 * time.Millisecond)
log.Info("sleeping\n")
}
}
if profile.Cfg.OnSuccess == "break" {
log.Infof("The profile %s contains a 'on_success: break' so bailing out", profile.Cfg.Name)
break
@ -393,6 +419,7 @@ cscli notifications reinject <alert_id> -a '{"remediation": true,"scenario":"not
//time.Sleep(2 * time.Second) // There's no mechanism to ensure notification has been sent
pluginTomb.Kill(fmt.Errorf("terminating"))
pluginTomb.Wait()
return nil
},
}
@ -401,18 +428,22 @@ cscli notifications reinject <alert_id> -a '{"remediation": true,"scenario":"not
return cmd
}
func FetchAlertFromArgString(toParse string) (*models.Alert, error) {
func (cli *cliNotifications) fetchAlertFromArgString(toParse string) (*models.Alert, error) {
cfg := cli.cfg()
id, err := strconv.Atoi(toParse)
if err != nil {
return nil, fmt.Errorf("bad alert id %s", toParse)
}
apiURL, err := url.Parse(csConfig.API.Client.Credentials.URL)
apiURL, err := url.Parse(cfg.API.Client.Credentials.URL)
if err != nil {
return nil, fmt.Errorf("error parsing the URL of the API: %w", err)
}
client, err := apiclient.NewClient(&apiclient.Config{
MachineID: csConfig.API.Client.Credentials.Login,
Password: strfmt.Password(csConfig.API.Client.Credentials.Password),
MachineID: cfg.API.Client.Credentials.Login,
Password: strfmt.Password(cfg.API.Client.Credentials.Password),
UserAgent: fmt.Sprintf("crowdsec/%s", version.String()),
URL: apiURL,
VersionPrefix: "v1",
@ -420,9 +451,11 @@ func FetchAlertFromArgString(toParse string) (*models.Alert, error) {
if err != nil {
return nil, fmt.Errorf("error creating the client for the API: %w", err)
}
alert, _, err := client.Alerts.GetByID(context.Background(), id)
if err != nil {
return nil, fmt.Errorf("can't find alert with id %d: %w", id, err)
}
return alert, nil
}

View file

@ -10,19 +10,18 @@ import (
"github.com/crowdsecurity/go-cs-lib/ptr"
"github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require"
"github.com/crowdsecurity/crowdsec/pkg/apiserver"
"github.com/crowdsecurity/crowdsec/pkg/database"
"github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require"
)
type cliPapi struct {
cfg configGetter
}
func NewCLIPapi(getconfig configGetter) *cliPapi {
func NewCLIPapi(cfg configGetter) *cliPapi {
return &cliPapi{
cfg: getconfig,
cfg: cfg,
}
}
@ -32,7 +31,7 @@ func (cli *cliPapi) NewCommand() *cobra.Command {
Short: "Manage interaction with Polling API (PAPI)",
Args: cobra.MinimumNArgs(1),
DisableAutoGenTag: true,
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
PersistentPreRunE: func(_ *cobra.Command, _ []string) error {
cfg := cli.cfg()
if err := require.LAPI(cfg); err != nil {
return err
@ -43,6 +42,7 @@ func (cli *cliPapi) NewCommand() *cobra.Command {
if err := require.PAPI(cfg); err != nil {
return err
}
return nil
},
}
@ -59,7 +59,7 @@ func (cli *cliPapi) NewStatusCmd() *cobra.Command {
Short: "Get status of the Polling API",
Args: cobra.MinimumNArgs(0),
DisableAutoGenTag: true,
RunE: func(cmd *cobra.Command, args []string) error {
RunE: func(_ *cobra.Command, _ []string) error {
var err error
cfg := cli.cfg()
dbClient, err = database.NewClient(cfg.DbConfig)
@ -111,7 +111,7 @@ func (cli *cliPapi) NewSyncCmd() *cobra.Command {
Short: "Sync with the Polling API, pulling all non-expired orders for the instance",
Args: cobra.MinimumNArgs(0),
DisableAutoGenTag: true,
RunE: func(cmd *cobra.Command, args []string) error {
RunE: func(_ *cobra.Command, _ []string) error {
var err error
cfg := cli.cfg()
t := tomb.Tomb{}

View file

@ -3,23 +3,23 @@ package main
import (
"fmt"
"os"
"slices"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"gopkg.in/yaml.v2"
"slices"
"github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require"
"github.com/crowdsecurity/crowdsec/pkg/cwhub"
)
type cliSimulation struct{
type cliSimulation struct {
cfg configGetter
}
func NewCLISimulation(getconfig configGetter) *cliSimulation {
func NewCLISimulation(cfg configGetter) *cliSimulation {
return &cliSimulation{
cfg: getconfig,
cfg: cfg,
}
}
@ -38,6 +38,7 @@ cscli simulation disable crowdsecurity/ssh-bf`,
if cli.cfg().Cscli.SimulationConfig == nil {
return fmt.Errorf("no simulation configured")
}
return nil
},
PersistentPostRun: func(cmd *cobra.Command, _ []string) {
@ -211,14 +212,17 @@ func (cli *cliSimulation) enableGlobalSimulation() error {
func (cli *cliSimulation) dumpSimulationFile() error {
cfg := cli.cfg()
newConfigSim, err := yaml.Marshal(cfg.Cscli.SimulationConfig)
if err != nil {
return fmt.Errorf("unable to marshal simulation configuration: %s", err)
}
err = os.WriteFile(cfg.ConfigPaths.SimulationFilePath, newConfigSim, 0o644)
if err != nil {
return fmt.Errorf("write simulation config in '%s' failed: %s", cfg.ConfigPaths.SimulationFilePath, err)
}
log.Debugf("updated simulation file %s", cfg.ConfigPaths.SimulationFilePath)
return nil
@ -230,16 +234,19 @@ func (cli *cliSimulation) disableGlobalSimulation() error {
*cfg.Cscli.SimulationConfig.Simulation = false
cfg.Cscli.SimulationConfig.Exclusions = []string{}
newConfigSim, err := yaml.Marshal(cfg.Cscli.SimulationConfig)
if err != nil {
return fmt.Errorf("unable to marshal new simulation configuration: %s", err)
}
err = os.WriteFile(cfg.ConfigPaths.SimulationFilePath, newConfigSim, 0o644)
if err != nil {
return fmt.Errorf("unable to write new simulation config in '%s' : %s", cfg.ConfigPaths.SimulationFilePath, err)
return fmt.Errorf("unable to write new simulation config in '%s': %s", cfg.ConfigPaths.SimulationFilePath, err)
}
log.Printf("global simulation: disabled")
return nil
}
@ -249,10 +256,13 @@ func (cli *cliSimulation) status() {
log.Printf("global simulation: disabled (configuration file is missing)")
return
}
if *cfg.Cscli.SimulationConfig.Simulation {
log.Println("global simulation: enabled")
if len(cfg.Cscli.SimulationConfig.Exclusions) > 0 {
log.Println("Scenarios not in simulation mode :")
for _, scenario := range cfg.Cscli.SimulationConfig.Exclusions {
log.Printf(" - %s", scenario)
}

View file

@ -66,19 +66,25 @@ func collectMetrics() ([]byte, []byte, error) {
}
humanMetrics := bytes.NewBuffer(nil)
err := FormatPrometheusMetrics(humanMetrics, csConfig.Cscli.PrometheusUrl, "human")
if err != nil {
return nil, nil, fmt.Errorf("could not fetch promtheus metrics: %s", err)
ms := NewMetricStore()
if err := ms.Fetch(csConfig.Cscli.PrometheusUrl); err != nil {
return nil, nil, fmt.Errorf("could not fetch prometheus metrics: %s", err)
}
if err := ms.Format(humanMetrics, nil, "human", false); err != nil {
return nil, nil, err
}
req, err := http.NewRequest(http.MethodGet, csConfig.Cscli.PrometheusUrl, nil)
if err != nil {
return nil, nil, fmt.Errorf("could not create requests to prometheus endpoint: %s", err)
}
client := &http.Client{}
resp, err := client.Do(req)
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return nil, nil, fmt.Errorf("could not get metrics from prometheus endpoint: %s", err)
}
@ -100,17 +106,20 @@ func collectVersion() []byte {
func collectFeatures() []byte {
log.Info("Collecting feature flags")
enabledFeatures := fflag.Crowdsec.GetEnabledFeatures()
w := bytes.NewBuffer(nil)
for _, k := range enabledFeatures {
fmt.Fprintf(w, "%s\n", k)
}
return w.Bytes()
}
func collectOSInfo() ([]byte, error) {
log.Info("Collecting OS info")
info, err := osinfo.GetOSInfo()
if err != nil {
@ -133,6 +142,7 @@ func collectHubItems(hub *cwhub.Hub, itemType string) []byte {
var err error
out := bytes.NewBuffer(nil)
log.Infof("Collecting %s list", itemType)
items := make(map[string][]*cwhub.Item)
@ -144,26 +154,33 @@ func collectHubItems(hub *cwhub.Hub, itemType string) []byte {
if err := listItems(out, []string{itemType}, items, false); err != nil {
log.Warnf("could not collect %s list: %s", itemType, err)
}
return out.Bytes()
}
func collectBouncers(dbClient *database.Client) ([]byte, error) {
out := bytes.NewBuffer(nil)
bouncers, err := dbClient.ListBouncers()
if err != nil {
return nil, fmt.Errorf("unable to list bouncers: %s", err)
}
getBouncersTable(out, bouncers)
return out.Bytes(), nil
}
func collectAgents(dbClient *database.Client) ([]byte, error) {
out := bytes.NewBuffer(nil)
machines, err := dbClient.ListMachines()
if err != nil {
return nil, fmt.Errorf("unable to list machines: %s", err)
}
getAgentsTable(out, machines)
return out.Bytes(), nil
}
@ -171,12 +188,14 @@ func collectAPIStatus(login string, password string, endpoint string, prefix str
if csConfig.API.Client == nil || csConfig.API.Client.Credentials == nil {
return []byte("No agent credentials found, are we LAPI ?")
}
pwd := strfmt.Password(password)
apiurl, err := url.Parse(endpoint)
pwd := strfmt.Password(password)
apiurl, err := url.Parse(endpoint)
if err != nil {
return []byte(fmt.Sprintf("cannot parse API URL: %s", err))
}
scenarios, err := hub.GetInstalledItemNames(cwhub.SCENARIOS)
if err != nil {
return []byte(fmt.Sprintf("could not collect scenarios: %s", err))
@ -189,6 +208,7 @@ func collectAPIStatus(login string, password string, endpoint string, prefix str
if err != nil {
return []byte(fmt.Sprintf("could not init client: %s", err))
}
t := models.WatcherAuthRequest{
MachineID: &login,
Password: &pwd,
@ -205,6 +225,7 @@ func collectAPIStatus(login string, password string, endpoint string, prefix str
func collectCrowdsecConfig() []byte {
log.Info("Collecting crowdsec config")
config, err := os.ReadFile(*csConfig.FilePath)
if err != nil {
return []byte(fmt.Sprintf("could not read config file: %s", err))
@ -217,15 +238,18 @@ func collectCrowdsecConfig() []byte {
func collectCrowdsecProfile() []byte {
log.Info("Collecting crowdsec profile")
config, err := os.ReadFile(csConfig.API.Server.ProfilesPath)
if err != nil {
return []byte(fmt.Sprintf("could not read profile file: %s", err))
}
return config
}
func collectAcquisitionConfig() map[string][]byte {
log.Info("Collecting acquisition config")
ret := make(map[string][]byte)
for _, filename := range csConfig.Crowdsec.AcquisitionFiles {
@ -287,7 +311,7 @@ cscli support dump -f /tmp/crowdsec-support.zip
`,
Args: cobra.NoArgs,
DisableAutoGenTag: true,
Run: func(cmd *cobra.Command, args []string) {
Run: func(_ *cobra.Command, _ []string) {
var err error
var skipHub, skipDB, skipCAPI, skipLAPI, skipAgent bool
infos := map[string][]byte{
@ -307,13 +331,13 @@ cscli support dump -f /tmp/crowdsec-support.zip
infos[SUPPORT_AGENTS_PATH] = []byte(err.Error())
}
if err := csConfig.LoadAPIServer(true); err != nil {
if err = csConfig.LoadAPIServer(true); err != nil {
log.Warnf("could not load LAPI, skipping CAPI check")
skipLAPI = true
infos[SUPPORT_CAPI_STATUS_PATH] = []byte(err.Error())
}
if err := csConfig.LoadCrowdsec(); err != nil {
if err = csConfig.LoadCrowdsec(); err != nil {
log.Warnf("could not load agent config, skipping crowdsec config check")
skipAgent = true
}
@ -399,7 +423,6 @@ cscli support dump -f /tmp/crowdsec-support.zip
}
if !skipAgent {
acquis := collectAcquisitionConfig()
for filename, content := range acquis {

View file

@ -25,6 +25,7 @@ func manageCliDecisionAlerts(ip *string, ipRange *string, scope *string, value *
return fmt.Errorf("%s isn't a valid range", *ipRange)
}
}
if *ip != "" {
ipRepr := net.ParseIP(*ip)
if ipRepr == nil {
@ -32,7 +33,7 @@ func manageCliDecisionAlerts(ip *string, ipRange *string, scope *string, value *
}
}
//avoid confusion on scope (ip vs Ip and range vs Range)
// avoid confusion on scope (ip vs Ip and range vs Range)
switch strings.ToLower(*scope) {
case "ip":
*scope = types.Ip
@ -43,6 +44,7 @@ func manageCliDecisionAlerts(ip *string, ipRange *string, scope *string, value *
case "as":
*scope = types.AS
}
return nil
}

View file

@ -56,7 +56,8 @@ func initAPIServer(cConfig *csconfig.Config) (*apiserver.APIServer, error) {
return apiServer, nil
}
func serveAPIServer(apiServer *apiserver.APIServer, apiReady chan bool) {
func serveAPIServer(apiServer *apiserver.APIServer) {
apiReady := make(chan bool, 1)
apiTomb.Go(func() error {
defer trace.CatchPanic("crowdsec/serveAPIServer")
go func() {
@ -80,6 +81,7 @@ func serveAPIServer(apiServer *apiserver.APIServer, apiReady chan bool) {
}
return nil
})
<-apiReady
}
func hasPlugins(profiles []*csconfig.ProfileCfg) bool {

View file

@ -1,6 +1,7 @@
package main
import (
"context"
"fmt"
"os"
"path/filepath"
@ -13,8 +14,8 @@ import (
"github.com/crowdsecurity/go-cs-lib/trace"
"github.com/crowdsecurity/crowdsec/pkg/acquisition"
"github.com/crowdsecurity/crowdsec/pkg/appsec"
"github.com/crowdsecurity/crowdsec/pkg/alertcontext"
"github.com/crowdsecurity/crowdsec/pkg/appsec"
"github.com/crowdsecurity/crowdsec/pkg/csconfig"
"github.com/crowdsecurity/crowdsec/pkg/cwhub"
leaky "github.com/crowdsecurity/crowdsec/pkg/leakybucket"
@ -56,63 +57,86 @@ func runCrowdsec(cConfig *csconfig.Config, parsers *parser.Parsers, hub *cwhub.H
//start go-routines for parsing, buckets pour and outputs.
parserWg := &sync.WaitGroup{}
parsersTomb.Go(func() error {
parserWg.Add(1)
for i := 0; i < cConfig.Crowdsec.ParserRoutinesCount; i++ {
parsersTomb.Go(func() error {
defer trace.CatchPanic("crowdsec/runParse")
if err := runParse(inputLineChan, inputEventChan, *parsers.Ctx, parsers.Nodes); err != nil { //this error will never happen as parser.Parse is not able to return errors
log.Fatalf("starting parse error : %s", err)
return err
}
return nil
})
}
parserWg.Done()
return nil
})
parserWg.Wait()
bucketWg := &sync.WaitGroup{}
bucketsTomb.Go(func() error {
bucketWg.Add(1)
/*restore previous state as well if present*/
if cConfig.Crowdsec.BucketStateFile != "" {
log.Warningf("Restoring buckets state from %s", cConfig.Crowdsec.BucketStateFile)
if err := leaky.LoadBucketsState(cConfig.Crowdsec.BucketStateFile, buckets, holders); err != nil {
return fmt.Errorf("unable to restore buckets : %s", err)
return fmt.Errorf("unable to restore buckets: %w", err)
}
}
for i := 0; i < cConfig.Crowdsec.BucketsRoutinesCount; i++ {
bucketsTomb.Go(func() error {
defer trace.CatchPanic("crowdsec/runPour")
if err := runPour(inputEventChan, holders, buckets, cConfig); err != nil {
log.Fatalf("starting pour error : %s", err)
return err
}
return nil
})
}
bucketWg.Done()
return nil
})
bucketWg.Wait()
apiClient, err := AuthenticatedLAPIClient(*cConfig.API.Client.Credentials, hub)
if err != nil {
return err
}
log.Debugf("Starting HeartBeat service")
apiClient.HeartBeat.StartHeartBeat(context.Background(), &outputsTomb)
outputWg := &sync.WaitGroup{}
outputsTomb.Go(func() error {
outputWg.Add(1)
for i := 0; i < cConfig.Crowdsec.OutputRoutinesCount; i++ {
outputsTomb.Go(func() error {
defer trace.CatchPanic("crowdsec/runOutput")
if err := runOutput(inputEventChan, outputEventChan, buckets, *parsers.Povfwctx, parsers.Povfwnodes, *cConfig.API.Client.Credentials, hub); err != nil {
if err := runOutput(inputEventChan, outputEventChan, buckets, *parsers.Povfwctx, parsers.Povfwnodes, apiClient); err != nil {
log.Fatalf("starting outputs error : %s", err)
return err
}
return nil
})
}
outputWg.Done()
return nil
})
outputWg.Wait()
@ -122,16 +146,16 @@ func runCrowdsec(cConfig *csconfig.Config, parsers *parser.Parsers, hub *cwhub.H
if cConfig.Prometheus.Level == "aggregated" {
aggregated = true
}
if err := acquisition.GetMetrics(dataSources, aggregated); err != nil {
return fmt.Errorf("while fetching prometheus metrics for datasources: %w", err)
}
}
log.Info("Starting processing data")
if err := acquisition.StartAcquisition(dataSources, inputLineChan, &acquisTomb); err != nil {
log.Fatalf("starting acquisition error : %s", err)
return err
return fmt.Errorf("starting acquisition error: %w", err)
}
return nil
@ -140,11 +164,13 @@ func runCrowdsec(cConfig *csconfig.Config, parsers *parser.Parsers, hub *cwhub.H
func serveCrowdsec(parsers *parser.Parsers, cConfig *csconfig.Config, hub *cwhub.Hub, agentReady chan bool) {
crowdsecTomb.Go(func() error {
defer trace.CatchPanic("crowdsec/serveCrowdsec")
go func() {
defer trace.CatchPanic("crowdsec/runCrowdsec")
// this logs every time, even at config reload
log.Debugf("running agent after %s ms", time.Since(crowdsecT0))
agentReady <- true
if err := runCrowdsec(cConfig, parsers, hub); err != nil {
log.Fatalf("unable to start crowdsec routines: %s", err)
}
@ -156,16 +182,20 @@ func serveCrowdsec(parsers *parser.Parsers, cConfig *csconfig.Config, hub *cwhub
*/
waitOnTomb()
log.Debugf("Shutting down crowdsec routines")
if err := ShutdownCrowdsecRoutines(); err != nil {
log.Fatalf("unable to shutdown crowdsec routines: %s", err)
}
log.Debugf("everything is dead, return crowdsecTomb")
if dumpStates {
dumpParserState()
dumpOverflowState()
dumpBucketsPour()
os.Exit(0)
}
return nil
})
}
@ -175,55 +205,65 @@ func dumpBucketsPour() {
if err != nil {
log.Fatalf("open: %s", err)
}
out, err := yaml.Marshal(leaky.BucketPourCache)
if err != nil {
log.Fatalf("marshal: %s", err)
}
b, err := fd.Write(out)
if err != nil {
log.Fatalf("write: %s", err)
}
log.Tracef("wrote %d bytes", b)
if err := fd.Close(); err != nil {
log.Fatalf(" close: %s", err)
}
}
func dumpParserState() {
fd, err := os.OpenFile(filepath.Join(parser.DumpFolder, "parser-dump.yaml"), os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0666)
if err != nil {
log.Fatalf("open: %s", err)
}
out, err := yaml.Marshal(parser.StageParseCache)
if err != nil {
log.Fatalf("marshal: %s", err)
}
b, err := fd.Write(out)
if err != nil {
log.Fatalf("write: %s", err)
}
log.Tracef("wrote %d bytes", b)
if err := fd.Close(); err != nil {
log.Fatalf(" close: %s", err)
}
}
func dumpOverflowState() {
fd, err := os.OpenFile(filepath.Join(parser.DumpFolder, "bucket-dump.yaml"), os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0666)
if err != nil {
log.Fatalf("open: %s", err)
}
out, err := yaml.Marshal(bucketOverflows)
if err != nil {
log.Fatalf("marshal: %s", err)
}
b, err := fd.Write(out)
if err != nil {
log.Fatalf("write: %s", err)
}
log.Tracef("wrote %d bytes", b)
if err := fd.Close(); err != nil {
log.Fatalf(" close: %s", err)
}

View file

@ -0,0 +1,92 @@
package main
import (
"context"
"fmt"
"net/url"
"time"
"github.com/go-openapi/strfmt"
"github.com/crowdsecurity/go-cs-lib/version"
"github.com/crowdsecurity/crowdsec/pkg/apiclient"
"github.com/crowdsecurity/crowdsec/pkg/csconfig"
"github.com/crowdsecurity/crowdsec/pkg/cwhub"
"github.com/crowdsecurity/crowdsec/pkg/models"
)
func AuthenticatedLAPIClient(credentials csconfig.ApiCredentialsCfg, hub *cwhub.Hub) (*apiclient.ApiClient, error) {
scenarios, err := hub.GetInstalledItemNames(cwhub.SCENARIOS)
if err != nil {
return nil, fmt.Errorf("loading list of installed hub scenarios: %w", err)
}
appsecRules, err := hub.GetInstalledItemNames(cwhub.APPSEC_RULES)
if err != nil {
return nil, fmt.Errorf("loading list of installed hub appsec rules: %w", err)
}
installedScenariosAndAppsecRules := make([]string, 0, len(scenarios)+len(appsecRules))
installedScenariosAndAppsecRules = append(installedScenariosAndAppsecRules, scenarios...)
installedScenariosAndAppsecRules = append(installedScenariosAndAppsecRules, appsecRules...)
apiURL, err := url.Parse(credentials.URL)
if err != nil {
return nil, fmt.Errorf("parsing api url ('%s'): %w", credentials.URL, err)
}
papiURL, err := url.Parse(credentials.PapiURL)
if err != nil {
return nil, fmt.Errorf("parsing polling api url ('%s'): %w", credentials.PapiURL, err)
}
password := strfmt.Password(credentials.Password)
client, err := apiclient.NewClient(&apiclient.Config{
MachineID: credentials.Login,
Password: password,
Scenarios: installedScenariosAndAppsecRules,
UserAgent: fmt.Sprintf("crowdsec/%s", version.String()),
URL: apiURL,
PapiURL: papiURL,
VersionPrefix: "v1",
UpdateScenario: func() ([]string, error) {
scenarios, err := hub.GetInstalledItemNames(cwhub.SCENARIOS)
if err != nil {
return nil, err
}
appsecRules, err := hub.GetInstalledItemNames(cwhub.APPSEC_RULES)
if err != nil {
return nil, err
}
ret := make([]string, 0, len(scenarios)+len(appsecRules))
ret = append(ret, scenarios...)
ret = append(ret, appsecRules...)
return ret, nil
},
})
if err != nil {
return nil, fmt.Errorf("new client api: %w", err)
}
authResp, _, err := client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{
MachineID: &credentials.Login,
Password: &password,
Scenarios: installedScenariosAndAppsecRules,
})
if err != nil {
return nil, fmt.Errorf("authenticate watcher (%s): %w", credentials.Login, err)
}
var expiration time.Time
if err := expiration.UnmarshalText([]byte(authResp.Expire)); err != nil {
return nil, fmt.Errorf("unable to parse jwt expiration: %w", err)
}
client.GetClient().Transport.(*apiclient.JWTTransport).Token = authResp.Token
client.GetClient().Transport.(*apiclient.JWTTransport).Expiration = expiration
return client, nil
}

View file

@ -114,13 +114,17 @@ func computeDynamicMetrics(next http.Handler, dbClient *database.Client) http.Ha
}
decisionsFilters := make(map[string][]string, 0)
decisions, err := dbClient.QueryDecisionCountByScenario(decisionsFilters)
if err != nil {
log.Errorf("Error querying decisions for metrics: %v", err)
next.ServeHTTP(w, r)
return
}
globalActiveDecisions.Reset()
for _, d := range decisions {
globalActiveDecisions.With(prometheus.Labels{"reason": d.Scenario, "origin": d.Origin, "action": d.Type}).Set(float64(d.Count))
}
@ -136,6 +140,7 @@ func computeDynamicMetrics(next http.Handler, dbClient *database.Client) http.Ha
if err != nil {
log.Errorf("Error querying alerts for metrics: %v", err)
next.ServeHTTP(w, r)
return
}
@ -161,7 +166,7 @@ func registerPrometheus(config *csconfig.PrometheusCfg) {
leaky.BucketsUnderflow, leaky.BucketsCanceled, leaky.BucketsInstantiation, leaky.BucketsOverflow,
v1.LapiRouteHits,
leaky.BucketsCurrentCount,
cache.CacheMetrics, exprhelpers.RegexpCacheMetrics,
cache.CacheMetrics, exprhelpers.RegexpCacheMetrics, parser.NodesWlHitsOk, parser.NodesWlHits,
)
} else {
log.Infof("Loading prometheus collectors")
@ -170,14 +175,15 @@ func registerPrometheus(config *csconfig.PrometheusCfg) {
globalCsInfo, globalParsingHistogram, globalPourHistogram,
v1.LapiRouteHits, v1.LapiMachineHits, v1.LapiBouncerHits, v1.LapiNilDecisions, v1.LapiNonNilDecisions, v1.LapiResponseTime,
leaky.BucketsPour, leaky.BucketsUnderflow, leaky.BucketsCanceled, leaky.BucketsInstantiation, leaky.BucketsOverflow, leaky.BucketsCurrentCount,
globalActiveDecisions, globalAlerts,
globalActiveDecisions, globalAlerts, parser.NodesWlHitsOk, parser.NodesWlHits,
cache.CacheMetrics, exprhelpers.RegexpCacheMetrics,
)
}
}
func servePrometheus(config *csconfig.PrometheusCfg, dbClient *database.Client, apiReady chan bool, agentReady chan bool) {
func servePrometheus(config *csconfig.PrometheusCfg, dbClient *database.Client, agentReady chan bool) {
<-agentReady
if !config.Enabled {
return
}
@ -185,9 +191,8 @@ func servePrometheus(config *csconfig.PrometheusCfg, dbClient *database.Client,
defer trace.CatchPanic("crowdsec/servePrometheus")
http.Handle("/metrics", computeDynamicMetrics(promhttp.Handler(), dbClient))
<-apiReady
<-agentReady
log.Debugf("serving metrics after %s ms", time.Since(crowdsecT0))
if err := http.ListenAndServe(fmt.Sprintf("%s:%d", config.ListenAddr, config.ListenPort), nil); err != nil {
log.Warningf("prometheus: %s", err)
}

View file

@ -3,18 +3,12 @@ package main
import (
"context"
"fmt"
"net/url"
"sync"
"time"
"github.com/go-openapi/strfmt"
log "github.com/sirupsen/logrus"
"github.com/crowdsecurity/go-cs-lib/version"
"github.com/crowdsecurity/crowdsec/pkg/apiclient"
"github.com/crowdsecurity/crowdsec/pkg/csconfig"
"github.com/crowdsecurity/crowdsec/pkg/cwhub"
leaky "github.com/crowdsecurity/crowdsec/pkg/leakybucket"
"github.com/crowdsecurity/crowdsec/pkg/models"
"github.com/crowdsecurity/crowdsec/pkg/parser"
@ -22,7 +16,6 @@ import (
)
func dedupAlerts(alerts []types.RuntimeAlert) ([]*models.Alert, error) {
var dedupCache []*models.Alert
for idx, alert := range alerts {
@ -32,16 +25,21 @@ func dedupAlerts(alerts []types.RuntimeAlert) ([]*models.Alert, error) {
dedupCache = append(dedupCache, alert.Alert)
continue
}
for k, src := range alert.Sources {
refsrc := *alert.Alert //copy
log.Tracef("source[%s]", k)
refsrc.Source = &src
dedupCache = append(dedupCache, &refsrc)
}
}
if len(dedupCache) != len(alerts) {
log.Tracef("went from %d to %d alerts", len(alerts), len(dedupCache))
}
return dedupCache, nil
}
@ -52,93 +50,25 @@ func PushAlerts(alerts []types.RuntimeAlert, client *apiclient.ApiClient) error
if err != nil {
return fmt.Errorf("failed to transform alerts for api: %w", err)
}
_, _, err = client.Alerts.Add(ctx, alertsToPush)
if err != nil {
return fmt.Errorf("failed sending alert to LAPI: %w", err)
}
return nil
}
var bucketOverflows []types.Event
func runOutput(input chan types.Event, overflow chan types.Event, buckets *leaky.Buckets,
postOverflowCTX parser.UnixParserCtx, postOverflowNodes []parser.Node,
apiConfig csconfig.ApiCredentialsCfg, hub *cwhub.Hub) error {
func runOutput(input chan types.Event, overflow chan types.Event, buckets *leaky.Buckets, postOverflowCTX parser.UnixParserCtx,
postOverflowNodes []parser.Node, client *apiclient.ApiClient) error {
var (
cache []types.RuntimeAlert
cacheMutex sync.Mutex
)
var err error
ticker := time.NewTicker(1 * time.Second)
var cache []types.RuntimeAlert
var cacheMutex sync.Mutex
scenarios, err := hub.GetInstalledItemNames(cwhub.SCENARIOS)
if err != nil {
return fmt.Errorf("loading list of installed hub scenarios: %w", err)
}
appsecRules, err := hub.GetInstalledItemNames(cwhub.APPSEC_RULES)
if err != nil {
return fmt.Errorf("loading list of installed hub appsec rules: %w", err)
}
installedScenariosAndAppsecRules := make([]string, 0, len(scenarios)+len(appsecRules))
installedScenariosAndAppsecRules = append(installedScenariosAndAppsecRules, scenarios...)
installedScenariosAndAppsecRules = append(installedScenariosAndAppsecRules, appsecRules...)
apiURL, err := url.Parse(apiConfig.URL)
if err != nil {
return fmt.Errorf("parsing api url ('%s'): %w", apiConfig.URL, err)
}
papiURL, err := url.Parse(apiConfig.PapiURL)
if err != nil {
return fmt.Errorf("parsing polling api url ('%s'): %w", apiConfig.PapiURL, err)
}
password := strfmt.Password(apiConfig.Password)
Client, err := apiclient.NewClient(&apiclient.Config{
MachineID: apiConfig.Login,
Password: password,
Scenarios: installedScenariosAndAppsecRules,
UserAgent: fmt.Sprintf("crowdsec/%s", version.String()),
URL: apiURL,
PapiURL: papiURL,
VersionPrefix: "v1",
UpdateScenario: func() ([]string, error) {
scenarios, err := hub.GetInstalledItemNames(cwhub.SCENARIOS)
if err != nil {
return nil, err
}
appsecRules, err := hub.GetInstalledItemNames(cwhub.APPSEC_RULES)
if err != nil {
return nil, err
}
ret := make([]string, 0, len(scenarios)+len(appsecRules))
ret = append(ret, scenarios...)
ret = append(ret, appsecRules...)
return ret, nil
},
})
if err != nil {
return fmt.Errorf("new client api: %w", err)
}
authResp, _, err := Client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{
MachineID: &apiConfig.Login,
Password: &password,
Scenarios: installedScenariosAndAppsecRules,
})
if err != nil {
return fmt.Errorf("authenticate watcher (%s): %w", apiConfig.Login, err)
}
if err := Client.GetClient().Transport.(*apiclient.JWTTransport).Expiration.UnmarshalText([]byte(authResp.Expire)); err != nil {
return fmt.Errorf("unable to parse jwt expiration: %w", err)
}
Client.GetClient().Transport.(*apiclient.JWTTransport).Token = authResp.Token
//start the heartbeat service
log.Debugf("Starting HeartBeat service")
Client.HeartBeat.StartHeartBeat(context.Background(), &outputsTomb)
LOOP:
for {
select {
@ -149,7 +79,7 @@ LOOP:
newcache := make([]types.RuntimeAlert, 0)
cache = newcache
cacheMutex.Unlock()
if err := PushAlerts(cachecopy, Client); err != nil {
if err := PushAlerts(cachecopy, client); err != nil {
log.Errorf("while pushing to api : %s", err)
//just push back the events to the queue
cacheMutex.Lock()
@ -162,10 +92,11 @@ LOOP:
cacheMutex.Lock()
cachecopy := cache
cacheMutex.Unlock()
if err := PushAlerts(cachecopy, Client); err != nil {
if err := PushAlerts(cachecopy, client); err != nil {
log.Errorf("while pushing leftovers to api : %s", err)
}
}
break LOOP
case event := <-overflow:
/*if alert is empty and mapKey is present, the overflow is just to cleanup bucket*/
@ -176,7 +107,7 @@ LOOP:
/* process post overflow parser nodes */
event, err := parser.Parse(postOverflowCTX, event, postOverflowNodes)
if err != nil {
return fmt.Errorf("postoverflow failed : %s", err)
return fmt.Errorf("postoverflow failed: %w", err)
}
log.Printf("%s", *event.Overflow.Alert.Message)
//if the Alert is nil, it's to signal bucket is ready for GC, don't track this
@ -206,6 +137,6 @@ LOOP:
}
ticker.Stop()
return nil
return nil
}

View file

@ -33,7 +33,6 @@ func StartRunSvc() error {
log.Infof("Crowdsec %s", version.String())
apiReady := make(chan bool, 1)
agentReady := make(chan bool, 1)
// Enable profiling early
@ -46,14 +45,19 @@ func StartRunSvc() error {
dbClient, err = database.NewClient(cConfig.DbConfig)
if err != nil {
return fmt.Errorf("unable to create database client: %s", err)
return fmt.Errorf("unable to create database client: %w", err)
}
}
registerPrometheus(cConfig.Prometheus)
go servePrometheus(cConfig.Prometheus, dbClient, apiReady, agentReady)
go servePrometheus(cConfig.Prometheus, dbClient, agentReady)
} else {
// avoid leaking the channel
go func() {
<-agentReady
}()
}
return Serve(cConfig, apiReady, agentReady)
return Serve(cConfig, agentReady)
}

View file

@ -73,7 +73,6 @@ func WindowsRun() error {
log.Infof("Crowdsec %s", version.String())
apiReady := make(chan bool, 1)
agentReady := make(chan bool, 1)
// Enable profiling early
@ -85,11 +84,11 @@ func WindowsRun() error {
dbClient, err = database.NewClient(cConfig.DbConfig)
if err != nil {
return fmt.Errorf("unable to create database client: %s", err)
return fmt.Errorf("unable to create database client: %w", err)
}
}
registerPrometheus(cConfig.Prometheus)
go servePrometheus(cConfig.Prometheus, dbClient, apiReady, agentReady)
go servePrometheus(cConfig.Prometheus, dbClient, agentReady)
}
return Serve(cConfig, apiReady, agentReady)
return Serve(cConfig, agentReady)
}

View file

@ -42,7 +42,9 @@ func debugHandler(sig os.Signal, cConfig *csconfig.Config) error {
if err := leaky.ShutdownAllBuckets(buckets); err != nil {
log.Warningf("Failed to shut down routines : %s", err)
}
log.Printf("Shutdown is finished, buckets are in %s", tmpFile)
return nil
}
@ -66,15 +68,16 @@ func reloadHandler(sig os.Signal) (*csconfig.Config, error) {
if !cConfig.DisableAPI {
if flags.DisableCAPI {
log.Warningf("Communication with CrowdSec Central API disabled from args")
cConfig.API.Server.OnlineClient = nil
}
apiServer, err := initAPIServer(cConfig)
if err != nil {
return nil, fmt.Errorf("unable to init api server: %w", err)
}
apiReady := make(chan bool, 1)
serveAPIServer(apiServer, apiReady)
serveAPIServer(apiServer)
}
if !cConfig.DisableAgent {
@ -110,6 +113,7 @@ func reloadHandler(sig os.Signal) (*csconfig.Config, error) {
log.Warningf("Failed to delete temp file (%s) : %s", tmpFile, err)
}
}
return cConfig, nil
}
@ -117,10 +121,12 @@ func ShutdownCrowdsecRoutines() error {
var reterr error
log.Debugf("Shutting down crowdsec sub-routines")
if len(dataSources) > 0 {
acquisTomb.Kill(nil)
log.Debugf("waiting for acquisition to finish")
drainChan(inputLineChan)
if err := acquisTomb.Wait(); err != nil {
log.Warningf("Acquisition returned error : %s", err)
reterr = err
@ -130,6 +136,7 @@ func ShutdownCrowdsecRoutines() error {
log.Debugf("acquisition is finished, wait for parser/bucket/ouputs.")
parsersTomb.Kill(nil)
drainChan(inputEventChan)
if err := parsersTomb.Wait(); err != nil {
log.Warningf("Parsers returned error : %s", err)
reterr = err
@ -160,6 +167,7 @@ func ShutdownCrowdsecRoutines() error {
log.Warningf("Outputs returned error : %s", err)
reterr = err
}
log.Debugf("outputs are done")
case <-time.After(3 * time.Second):
// this can happen if outputs are stuck in a http retry loop
@ -181,6 +189,7 @@ func shutdownAPI() error {
}
log.Debugf("done")
return nil
}
@ -193,6 +202,7 @@ func shutdownCrowdsec() error {
}
log.Debugf("done")
return nil
}
@ -292,10 +302,11 @@ func HandleSignals(cConfig *csconfig.Config) error {
if err == nil {
log.Warning("Crowdsec service shutting down")
}
return err
}
func Serve(cConfig *csconfig.Config, apiReady chan bool, agentReady chan bool) error {
func Serve(cConfig *csconfig.Config, agentReady chan bool) error {
acquisTomb = tomb.Tomb{}
parsersTomb = tomb.Tomb{}
bucketsTomb = tomb.Tomb{}
@ -325,6 +336,7 @@ func Serve(cConfig *csconfig.Config, apiReady chan bool, agentReady chan bool) e
if cConfig.API.CTI != nil && *cConfig.API.CTI.Enabled {
log.Infof("Crowdsec CTI helper enabled")
if err := exprhelpers.InitCrowdsecCTI(cConfig.API.CTI.Key, cConfig.API.CTI.CacheTimeout, cConfig.API.CTI.CacheSize, cConfig.API.CTI.LogLevel); err != nil {
return fmt.Errorf("failed to init crowdsec cti: %w", err)
}
@ -337,6 +349,7 @@ func Serve(cConfig *csconfig.Config, apiReady chan bool, agentReady chan bool) e
if flags.DisableCAPI {
log.Warningf("Communication with CrowdSec Central API disabled from args")
cConfig.API.Server.OnlineClient = nil
}
@ -346,10 +359,8 @@ func Serve(cConfig *csconfig.Config, apiReady chan bool, agentReady chan bool) e
}
if !flags.TestMode {
serveAPIServer(apiServer, apiReady)
serveAPIServer(apiServer)
}
} else {
apiReady <- true
}
if !cConfig.DisableAgent {
@ -366,6 +377,8 @@ func Serve(cConfig *csconfig.Config, apiReady chan bool, agentReady chan bool) e
// if it's just linting, we're done
if !flags.TestMode {
serveCrowdsec(csParsers, cConfig, hub, agentReady)
} else {
agentReady <- true
}
} else {
agentReady <- true
@ -395,6 +408,7 @@ func Serve(cConfig *csconfig.Config, apiReady chan bool, agentReady chan bool) e
for _, ch := range waitChans {
<-ch
switch ch {
case apiTomb.Dead():
log.Infof("api shutdown")
@ -402,5 +416,6 @@ func Serve(cConfig *csconfig.Config, apiReady chan bool, agentReady chan bool) e
log.Infof("crowdsec shutdown")
}
}
return nil
}

4
go.mod
View file

@ -77,7 +77,7 @@ require (
github.com/shirou/gopsutil/v3 v3.23.5
github.com/sirupsen/logrus v1.9.3
github.com/slack-go/slack v0.12.2
github.com/spf13/cobra v1.7.0
github.com/spf13/cobra v1.8.0
github.com/stretchr/testify v1.8.4
github.com/umahmood/haversine v0.0.0-20151105152445-808ab04add26
github.com/wasilibs/go-re2 v1.3.0
@ -108,7 +108,7 @@ require (
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/corazawaf/libinjection-go v0.1.2 // indirect
github.com/coreos/go-systemd/v22 v22.5.0 // indirect
github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect
github.com/cpuguy83/go-md2man/v2 v2.0.3 // indirect
github.com/creack/pty v1.1.18 // indirect
github.com/docker/distribution v2.8.2+incompatible // indirect
github.com/docker/go-units v0.5.0 // indirect

16
go.sum
View file

@ -91,21 +91,17 @@ github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7
github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/cpuguy83/go-md2man/v2 v2.0.1/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w=
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/cpuguy83/go-md2man/v2 v2.0.3 h1:qMCsGGgs+MAzDFyp9LpAe1Lqy/fY/qCovCm0qnXZOBM=
github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/creack/pty v1.1.17/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
github.com/crowdsecurity/coraza/v3 v3.0.0-20231213144607-41d5358da94f h1:FkOB9aDw0xzDd14pTarGRLsUNAymONq3dc7zhvsXElg=
github.com/crowdsecurity/coraza/v3 v3.0.0-20231213144607-41d5358da94f/go.mod h1:TrU7Li+z2RHNrPy0TKJ6R65V6Yzpan2sTIRryJJyJso=
github.com/crowdsecurity/coraza/v3 v3.0.0-20240108124027-a62b8d8e5607 h1:hyrYw3h8clMcRL2u5ooZ3tmwnmJftmhb9Ws1MKmavvI=
github.com/crowdsecurity/coraza/v3 v3.0.0-20240108124027-a62b8d8e5607/go.mod h1:br36fEqurGYZQGit+iDYsIzW0FF6VufMbDzyyLxEuPA=
github.com/crowdsecurity/dlog v0.0.0-20170105205344-4fb5f8204f26 h1:r97WNVC30Uen+7WnLs4xDScS/Ex988+id2k6mDf8psU=
github.com/crowdsecurity/dlog v0.0.0-20170105205344-4fb5f8204f26/go.mod h1:zpv7r+7KXwgVUZnUNjyP22zc/D7LKjyoY02weH2RBbk=
github.com/crowdsecurity/go-cs-lib v0.0.5 h1:eVLW+BRj3ZYn0xt5/xmgzfbbB8EBo32gM4+WpQQk2e8=
github.com/crowdsecurity/go-cs-lib v0.0.5/go.mod h1:8FMKNGsh3hMZi2SEv6P15PURhEJnZV431XjzzBSuf0k=
github.com/crowdsecurity/go-cs-lib v0.0.6 h1:Ef6MylXe0GaJE9vrfvxEdbHb31+JUP1os+murPz7Pos=
github.com/crowdsecurity/go-cs-lib v0.0.6/go.mod h1:8FMKNGsh3hMZi2SEv6P15PURhEJnZV431XjzzBSuf0k=
github.com/crowdsecurity/grokky v0.2.1 h1:t4VYnDlAd0RjDM2SlILalbwfCrQxtJSMGdQOR0zwkE4=
@ -546,6 +542,8 @@ github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
github.com/oklog/run v1.0.0 h1:Ru7dDtJNOyC66gQ5dQmaCa0qIsAUFY3sFpK1Xk8igrw=
github.com/oklog/run v1.0.0/go.mod h1:dlhp/R75TPv97u0XWUtDeV/lRKWPKSdTuV0TZvrmrQA=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.0.3-0.20211202183452-c5a74bcca799 h1:rc3tiVYb5z54aKaDfakKn0dDjIyPpTtszkjuMzyt7ec=
@ -640,8 +638,8 @@ github.com/spf13/cast v1.3.1 h1:nFm6S0SMdyzrzcmThSipiEubIDy8WEXKNZ0UOgiRpng=
github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE=
github.com/spf13/cobra v0.0.3/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ=
github.com/spf13/cobra v1.4.0/go.mod h1:Wo4iy3BUC+X2Fybo0PDqwJIv3dNRiZLHQymsfxlB84g=
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0=
github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho=
github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
@ -809,8 +807,6 @@ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE=
golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=

View file

@ -354,15 +354,17 @@ func (w *AppsecSource) appsecHandler(rw http.ResponseWriter, r *http.Request) {
w.InChan <- parsedRequest
/*
response is a copy of w.AppSecRuntime.Response that is safe to use.
As OutOfBand might still be running, the original one can be modified
*/
response := <-parsedRequest.ResponseChannel
statusCode := http.StatusOK
if response.InBandInterrupt {
statusCode = http.StatusForbidden
AppsecBlockCounter.With(prometheus.Labels{"source": parsedRequest.RemoteAddrNormalized, "appsec_engine": parsedRequest.AppsecEngine}).Inc()
}
appsecResponse := w.AppsecRuntime.GenerateResponse(response, logger)
statusCode, appsecResponse := w.AppsecRuntime.GenerateResponse(response, logger)
logger.Debugf("Response: %+v", appsecResponse)
rw.WriteHeader(statusCode)

View file

@ -226,7 +226,8 @@ func (r *AppsecRunner) handleInBandInterrupt(request *appsec.ParsedRequest) {
if in := request.Tx.Interruption(); in != nil {
r.logger.Debugf("inband rules matched : %d", in.RuleID)
r.AppsecRuntime.Response.InBandInterrupt = true
r.AppsecRuntime.Response.HTTPResponseCode = r.AppsecRuntime.Config.BlockedHTTPCode
r.AppsecRuntime.Response.BouncerHTTPResponseCode = r.AppsecRuntime.Config.BouncerBlockedHTTPCode
r.AppsecRuntime.Response.UserHTTPResponseCode = r.AppsecRuntime.Config.UserBlockedHTTPCode
r.AppsecRuntime.Response.Action = r.AppsecRuntime.DefaultRemediation
if _, ok := r.AppsecRuntime.RemediationById[in.RuleID]; ok {
@ -252,7 +253,9 @@ func (r *AppsecRunner) handleInBandInterrupt(request *appsec.ParsedRequest) {
r.logger.Errorf("unable to generate appsec event : %s", err)
return
}
r.outChan <- *appsecOvlfw
if appsecOvlfw != nil {
r.outChan <- *appsecOvlfw
}
}
// Should the in band match trigger an event ?

View file

@ -1,6 +1,7 @@
package appsecacquisition
import (
"net/http"
"net/url"
"testing"
"time"
@ -21,16 +22,21 @@ Missing tests (wip):
*/
type appsecRuleTest struct {
name string
expected_load_ok bool
inband_rules []appsec_rule.CustomRule
outofband_rules []appsec_rule.CustomRule
on_load []appsec.Hook
pre_eval []appsec.Hook
post_eval []appsec.Hook
on_match []appsec.Hook
input_request appsec.ParsedRequest
output_asserts func(events []types.Event, responses []appsec.AppsecTempResponse)
name string
expected_load_ok bool
inband_rules []appsec_rule.CustomRule
outofband_rules []appsec_rule.CustomRule
on_load []appsec.Hook
pre_eval []appsec.Hook
post_eval []appsec.Hook
on_match []appsec.Hook
BouncerBlockedHTTPCode int
UserBlockedHTTPCode int
UserPassedHTTPCode int
DefaultRemediation string
DefaultPassAction string
input_request appsec.ParsedRequest
output_asserts func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int)
}
func TestAppsecOnMatchHooks(t *testing.T) {
@ -53,13 +59,14 @@ func TestAppsecOnMatchHooks(t *testing.T) {
URI: "/urllll",
Args: url.Values{"foo": []string{"toto"}},
},
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) {
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
require.Len(t, events, 2)
require.Equal(t, types.APPSEC, events[0].Type)
require.Equal(t, types.LOG, events[1].Type)
require.Len(t, responses, 1)
require.Equal(t, 403, responses[0].HTTPResponseCode)
require.Equal(t, "ban", responses[0].Action)
require.Equal(t, 403, responses[0].BouncerHTTPResponseCode)
require.Equal(t, 403, responses[0].UserHTTPResponseCode)
require.Equal(t, appsec.BanRemediation, responses[0].Action)
},
},
@ -84,17 +91,18 @@ func TestAppsecOnMatchHooks(t *testing.T) {
URI: "/urllll",
Args: url.Values{"foo": []string{"toto"}},
},
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) {
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
require.Len(t, events, 2)
require.Equal(t, types.APPSEC, events[0].Type)
require.Equal(t, types.LOG, events[1].Type)
require.Len(t, responses, 1)
require.Equal(t, 413, responses[0].HTTPResponseCode)
require.Equal(t, "ban", responses[0].Action)
require.Equal(t, 403, responses[0].BouncerHTTPResponseCode)
require.Equal(t, 413, responses[0].UserHTTPResponseCode)
require.Equal(t, appsec.BanRemediation, responses[0].Action)
},
},
{
name: "on_match: change action to another standard one (log)",
name: "on_match: change action to a non standard one (log)",
expected_load_ok: true,
inband_rules: []appsec_rule.CustomRule{
{
@ -114,7 +122,7 @@ func TestAppsecOnMatchHooks(t *testing.T) {
URI: "/urllll",
Args: url.Values{"foo": []string{"toto"}},
},
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) {
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
require.Len(t, events, 2)
require.Equal(t, types.APPSEC, events[0].Type)
require.Equal(t, types.LOG, events[1].Type)
@ -143,16 +151,16 @@ func TestAppsecOnMatchHooks(t *testing.T) {
URI: "/urllll",
Args: url.Values{"foo": []string{"toto"}},
},
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) {
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
require.Len(t, events, 2)
require.Equal(t, types.APPSEC, events[0].Type)
require.Equal(t, types.LOG, events[1].Type)
require.Len(t, responses, 1)
require.Equal(t, "allow", responses[0].Action)
require.Equal(t, appsec.AllowRemediation, responses[0].Action)
},
},
{
name: "on_match: change action to another standard one (deny/ban/block)",
name: "on_match: change action to another standard one (ban)",
expected_load_ok: true,
inband_rules: []appsec_rule.CustomRule{
{
@ -164,7 +172,7 @@ func TestAppsecOnMatchHooks(t *testing.T) {
},
},
on_match: []appsec.Hook{
{Filter: "IsInBand == true", Apply: []string{"SetRemediation('deny')"}},
{Filter: "IsInBand == true", Apply: []string{"SetRemediation('ban')"}},
},
input_request: appsec.ParsedRequest{
RemoteAddr: "1.2.3.4",
@ -172,10 +180,10 @@ func TestAppsecOnMatchHooks(t *testing.T) {
URI: "/urllll",
Args: url.Values{"foo": []string{"toto"}},
},
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) {
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
require.Len(t, responses, 1)
//note: SetAction normalizes deny, ban and block to ban
require.Equal(t, "ban", responses[0].Action)
require.Equal(t, appsec.BanRemediation, responses[0].Action)
},
},
{
@ -199,10 +207,10 @@ func TestAppsecOnMatchHooks(t *testing.T) {
URI: "/urllll",
Args: url.Values{"foo": []string{"toto"}},
},
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) {
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
require.Len(t, responses, 1)
//note: SetAction normalizes deny, ban and block to ban
require.Equal(t, "captcha", responses[0].Action)
require.Equal(t, appsec.CaptchaRemediation, responses[0].Action)
},
},
{
@ -226,7 +234,7 @@ func TestAppsecOnMatchHooks(t *testing.T) {
URI: "/urllll",
Args: url.Values{"foo": []string{"toto"}},
},
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) {
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
require.Len(t, events, 2)
require.Equal(t, types.APPSEC, events[0].Type)
require.Equal(t, types.LOG, events[1].Type)
@ -255,11 +263,11 @@ func TestAppsecOnMatchHooks(t *testing.T) {
URI: "/urllll",
Args: url.Values{"foo": []string{"toto"}},
},
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) {
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
require.Len(t, events, 1)
require.Equal(t, types.LOG, events[0].Type)
require.Len(t, responses, 1)
require.Equal(t, "ban", responses[0].Action)
require.Equal(t, appsec.BanRemediation, responses[0].Action)
},
},
{
@ -283,11 +291,11 @@ func TestAppsecOnMatchHooks(t *testing.T) {
URI: "/urllll",
Args: url.Values{"foo": []string{"toto"}},
},
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) {
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
require.Len(t, events, 1)
require.Equal(t, types.APPSEC, events[0].Type)
require.Len(t, responses, 1)
require.Equal(t, "ban", responses[0].Action)
require.Equal(t, appsec.BanRemediation, responses[0].Action)
},
},
}
@ -328,7 +336,7 @@ func TestAppsecPreEvalHooks(t *testing.T) {
URI: "/urllll",
Args: url.Values{"foo": []string{"toto"}},
},
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) {
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
require.Empty(t, events)
require.Len(t, responses, 1)
require.False(t, responses[0].InBandInterrupt)
@ -356,7 +364,7 @@ func TestAppsecPreEvalHooks(t *testing.T) {
URI: "/urllll",
Args: url.Values{"foo": []string{"toto"}},
},
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) {
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
require.Len(t, events, 2)
require.Equal(t, types.APPSEC, events[0].Type)
@ -391,7 +399,7 @@ func TestAppsecPreEvalHooks(t *testing.T) {
URI: "/urllll",
Args: url.Values{"foo": []string{"toto"}},
},
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) {
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
require.Empty(t, events)
require.Len(t, responses, 1)
require.False(t, responses[0].InBandInterrupt)
@ -419,7 +427,7 @@ func TestAppsecPreEvalHooks(t *testing.T) {
URI: "/urllll",
Args: url.Values{"foo": []string{"toto"}},
},
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) {
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
require.Empty(t, events)
require.Len(t, responses, 1)
require.False(t, responses[0].InBandInterrupt)
@ -447,7 +455,7 @@ func TestAppsecPreEvalHooks(t *testing.T) {
URI: "/urllll",
Args: url.Values{"foo": []string{"toto"}},
},
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) {
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
require.Empty(t, events)
require.Len(t, responses, 1)
require.False(t, responses[0].InBandInterrupt)
@ -472,7 +480,7 @@ func TestAppsecPreEvalHooks(t *testing.T) {
URI: "/urllll",
Args: url.Values{"foo": []string{"toto"}},
},
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) {
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
require.Len(t, events, 1)
require.Equal(t, types.LOG, events[0].Type)
require.True(t, events[0].Appsec.HasOutBandMatches)
@ -506,7 +514,7 @@ func TestAppsecPreEvalHooks(t *testing.T) {
URI: "/urllll",
Args: url.Values{"foo": []string{"toto"}},
},
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) {
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
require.Len(t, events, 2)
require.Len(t, responses, 1)
require.Equal(t, "foobar", responses[0].Action)
@ -533,7 +541,7 @@ func TestAppsecPreEvalHooks(t *testing.T) {
URI: "/urllll",
Args: url.Values{"foo": []string{"toto"}},
},
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) {
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
require.Len(t, events, 2)
require.Len(t, responses, 1)
require.Equal(t, "foobar", responses[0].Action)
@ -560,10 +568,12 @@ func TestAppsecPreEvalHooks(t *testing.T) {
URI: "/urllll",
Args: url.Values{"foo": []string{"toto"}},
},
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) {
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
require.Len(t, events, 2)
require.Len(t, responses, 1)
require.Equal(t, "foobar", responses[0].Action)
require.Equal(t, "foobar", appsecResponse.Action)
require.Equal(t, http.StatusForbidden, appsecResponse.HTTPStatus)
},
},
}
@ -574,6 +584,473 @@ func TestAppsecPreEvalHooks(t *testing.T) {
})
}
}
func TestAppsecRemediationConfigHooks(t *testing.T) {
tests := []appsecRuleTest{
{
name: "Basic matching rule",
expected_load_ok: true,
inband_rules: []appsec_rule.CustomRule{
{
Name: "rule1",
Zones: []string{"ARGS"},
Variables: []string{"foo"},
Match: appsec_rule.Match{Type: "regex", Value: "^toto"},
Transform: []string{"lowercase"},
},
},
input_request: appsec.ParsedRequest{
RemoteAddr: "1.2.3.4",
Method: "GET",
URI: "/urllll",
Args: url.Values{"foo": []string{"toto"}},
},
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
require.Equal(t, appsec.BanRemediation, responses[0].Action)
require.Equal(t, http.StatusForbidden, statusCode)
require.Equal(t, appsec.BanRemediation, appsecResponse.Action)
require.Equal(t, http.StatusForbidden, appsecResponse.HTTPStatus)
},
},
{
name: "SetRemediation",
expected_load_ok: true,
inband_rules: []appsec_rule.CustomRule{
{
Name: "rule1",
Zones: []string{"ARGS"},
Variables: []string{"foo"},
Match: appsec_rule.Match{Type: "regex", Value: "^toto"},
Transform: []string{"lowercase"},
},
},
input_request: appsec.ParsedRequest{
RemoteAddr: "1.2.3.4",
Method: "GET",
URI: "/urllll",
Args: url.Values{"foo": []string{"toto"}},
},
on_match: []appsec.Hook{{Apply: []string{"SetRemediation('captcha')"}}}, //rule ID is generated at runtime. If you change rule, it will break the test (:
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
require.Equal(t, appsec.CaptchaRemediation, responses[0].Action)
require.Equal(t, http.StatusForbidden, statusCode)
require.Equal(t, appsec.CaptchaRemediation, appsecResponse.Action)
require.Equal(t, http.StatusForbidden, appsecResponse.HTTPStatus)
},
},
{
name: "SetRemediation",
expected_load_ok: true,
inband_rules: []appsec_rule.CustomRule{
{
Name: "rule1",
Zones: []string{"ARGS"},
Variables: []string{"foo"},
Match: appsec_rule.Match{Type: "regex", Value: "^toto"},
Transform: []string{"lowercase"},
},
},
input_request: appsec.ParsedRequest{
RemoteAddr: "1.2.3.4",
Method: "GET",
URI: "/urllll",
Args: url.Values{"foo": []string{"toto"}},
},
on_match: []appsec.Hook{{Apply: []string{"SetReturnCode(418)"}}}, //rule ID is generated at runtime. If you change rule, it will break the test (:
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
require.Equal(t, appsec.BanRemediation, responses[0].Action)
require.Equal(t, http.StatusForbidden, statusCode)
require.Equal(t, appsec.BanRemediation, appsecResponse.Action)
require.Equal(t, http.StatusTeapot, appsecResponse.HTTPStatus)
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
loadAppSecEngine(test, t)
})
}
}
func TestOnMatchRemediationHooks(t *testing.T) {
tests := []appsecRuleTest{
{
name: "set remediation to allow with on_match hook",
expected_load_ok: true,
inband_rules: []appsec_rule.CustomRule{
{
Name: "rule42",
Zones: []string{"ARGS"},
Variables: []string{"foo"},
Match: appsec_rule.Match{Type: "regex", Value: "^toto"},
Transform: []string{"lowercase"},
},
},
input_request: appsec.ParsedRequest{
RemoteAddr: "1.2.3.4",
Method: "GET",
URI: "/urllll",
Args: url.Values{"foo": []string{"toto"}},
},
on_match: []appsec.Hook{
{Filter: "IsInBand == true", Apply: []string{"SetRemediation('allow')"}},
},
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
require.Equal(t, appsec.AllowRemediation, appsecResponse.Action)
require.Equal(t, http.StatusOK, appsecResponse.HTTPStatus)
},
},
{
name: "set remediation to captcha + custom user code with on_match hook",
expected_load_ok: true,
inband_rules: []appsec_rule.CustomRule{
{
Name: "rule42",
Zones: []string{"ARGS"},
Variables: []string{"foo"},
Match: appsec_rule.Match{Type: "regex", Value: "^toto"},
Transform: []string{"lowercase"},
},
},
input_request: appsec.ParsedRequest{
RemoteAddr: "1.2.3.4",
Method: "GET",
URI: "/urllll",
Args: url.Values{"foo": []string{"toto"}},
},
DefaultRemediation: appsec.AllowRemediation,
on_match: []appsec.Hook{
{Filter: "IsInBand == true", Apply: []string{"SetRemediation('captcha')", "SetReturnCode(418)"}},
},
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
spew.Dump(responses)
spew.Dump(appsecResponse)
log.Errorf("http status : %d", statusCode)
require.Equal(t, appsec.CaptchaRemediation, appsecResponse.Action)
require.Equal(t, http.StatusTeapot, appsecResponse.HTTPStatus)
require.Equal(t, http.StatusForbidden, statusCode)
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
loadAppSecEngine(test, t)
})
}
}
func TestAppsecDefaultPassRemediation(t *testing.T) {
tests := []appsecRuleTest{
{
name: "Basic non-matching rule",
expected_load_ok: true,
inband_rules: []appsec_rule.CustomRule{
{
Name: "rule1",
Zones: []string{"ARGS"},
Variables: []string{"foo"},
Match: appsec_rule.Match{Type: "regex", Value: "^toto"},
Transform: []string{"lowercase"},
},
},
input_request: appsec.ParsedRequest{
RemoteAddr: "1.2.3.4",
Method: "GET",
URI: "/",
Args: url.Values{"foo": []string{"tutu"}},
},
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
require.Equal(t, appsec.AllowRemediation, responses[0].Action)
require.Equal(t, http.StatusOK, statusCode)
require.Equal(t, appsec.AllowRemediation, appsecResponse.Action)
require.Equal(t, http.StatusOK, appsecResponse.HTTPStatus)
},
},
{
name: "DefaultPassAction: pass",
expected_load_ok: true,
inband_rules: []appsec_rule.CustomRule{
{
Name: "rule1",
Zones: []string{"ARGS"},
Variables: []string{"foo"},
Match: appsec_rule.Match{Type: "regex", Value: "^toto"},
Transform: []string{"lowercase"},
},
},
input_request: appsec.ParsedRequest{
RemoteAddr: "1.2.3.4",
Method: "GET",
URI: "/",
Args: url.Values{"foo": []string{"tutu"}},
},
DefaultPassAction: "allow",
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
require.Equal(t, appsec.AllowRemediation, responses[0].Action)
require.Equal(t, http.StatusOK, statusCode)
require.Equal(t, appsec.AllowRemediation, appsecResponse.Action)
require.Equal(t, http.StatusOK, appsecResponse.HTTPStatus)
},
},
{
name: "DefaultPassAction: captcha",
expected_load_ok: true,
inband_rules: []appsec_rule.CustomRule{
{
Name: "rule1",
Zones: []string{"ARGS"},
Variables: []string{"foo"},
Match: appsec_rule.Match{Type: "regex", Value: "^toto"},
Transform: []string{"lowercase"},
},
},
input_request: appsec.ParsedRequest{
RemoteAddr: "1.2.3.4",
Method: "GET",
URI: "/",
Args: url.Values{"foo": []string{"tutu"}},
},
DefaultPassAction: "captcha",
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
require.Equal(t, appsec.CaptchaRemediation, responses[0].Action)
require.Equal(t, http.StatusOK, statusCode) //@tko: body is captcha, but as it's 200, captcha won't be showed to user
require.Equal(t, appsec.CaptchaRemediation, appsecResponse.Action)
require.Equal(t, http.StatusOK, appsecResponse.HTTPStatus)
},
},
{
name: "DefaultPassHTTPCode: 200",
expected_load_ok: true,
inband_rules: []appsec_rule.CustomRule{
{
Name: "rule1",
Zones: []string{"ARGS"},
Variables: []string{"foo"},
Match: appsec_rule.Match{Type: "regex", Value: "^toto"},
Transform: []string{"lowercase"},
},
},
input_request: appsec.ParsedRequest{
RemoteAddr: "1.2.3.4",
Method: "GET",
URI: "/",
Args: url.Values{"foo": []string{"tutu"}},
},
UserPassedHTTPCode: 200,
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
require.Equal(t, appsec.AllowRemediation, responses[0].Action)
require.Equal(t, http.StatusOK, statusCode)
require.Equal(t, appsec.AllowRemediation, appsecResponse.Action)
require.Equal(t, http.StatusOK, appsecResponse.HTTPStatus)
},
},
{
name: "DefaultPassHTTPCode: 200",
expected_load_ok: true,
inband_rules: []appsec_rule.CustomRule{
{
Name: "rule1",
Zones: []string{"ARGS"},
Variables: []string{"foo"},
Match: appsec_rule.Match{Type: "regex", Value: "^toto"},
Transform: []string{"lowercase"},
},
},
input_request: appsec.ParsedRequest{
RemoteAddr: "1.2.3.4",
Method: "GET",
URI: "/",
Args: url.Values{"foo": []string{"tutu"}},
},
UserPassedHTTPCode: 418,
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
require.Equal(t, appsec.AllowRemediation, responses[0].Action)
require.Equal(t, http.StatusOK, statusCode)
require.Equal(t, appsec.AllowRemediation, appsecResponse.Action)
require.Equal(t, http.StatusTeapot, appsecResponse.HTTPStatus)
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
loadAppSecEngine(test, t)
})
}
}
func TestAppsecDefaultRemediation(t *testing.T) {
tests := []appsecRuleTest{
{
name: "Basic matching rule",
expected_load_ok: true,
inband_rules: []appsec_rule.CustomRule{
{
Name: "rule1",
Zones: []string{"ARGS"},
Variables: []string{"foo"},
Match: appsec_rule.Match{Type: "regex", Value: "^toto"},
Transform: []string{"lowercase"},
},
},
input_request: appsec.ParsedRequest{
RemoteAddr: "1.2.3.4",
Method: "GET",
URI: "/urllll",
Args: url.Values{"foo": []string{"toto"}},
},
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
require.Equal(t, appsec.BanRemediation, responses[0].Action)
require.Equal(t, http.StatusForbidden, statusCode)
require.Equal(t, appsec.BanRemediation, appsecResponse.Action)
require.Equal(t, http.StatusForbidden, appsecResponse.HTTPStatus)
},
},
{
name: "default remediation to ban (default)",
expected_load_ok: true,
inband_rules: []appsec_rule.CustomRule{
{
Name: "rule42",
Zones: []string{"ARGS"},
Variables: []string{"foo"},
Match: appsec_rule.Match{Type: "regex", Value: "^toto"},
Transform: []string{"lowercase"},
},
},
input_request: appsec.ParsedRequest{
RemoteAddr: "1.2.3.4",
Method: "GET",
URI: "/urllll",
Args: url.Values{"foo": []string{"toto"}},
},
DefaultRemediation: "ban",
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
require.Equal(t, appsec.BanRemediation, responses[0].Action)
require.Equal(t, http.StatusForbidden, statusCode)
require.Equal(t, appsec.BanRemediation, appsecResponse.Action)
require.Equal(t, http.StatusForbidden, appsecResponse.HTTPStatus)
},
},
{
name: "default remediation to allow",
expected_load_ok: true,
inband_rules: []appsec_rule.CustomRule{
{
Name: "rule42",
Zones: []string{"ARGS"},
Variables: []string{"foo"},
Match: appsec_rule.Match{Type: "regex", Value: "^toto"},
Transform: []string{"lowercase"},
},
},
input_request: appsec.ParsedRequest{
RemoteAddr: "1.2.3.4",
Method: "GET",
URI: "/urllll",
Args: url.Values{"foo": []string{"toto"}},
},
DefaultRemediation: "allow",
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
require.Equal(t, appsec.AllowRemediation, responses[0].Action)
require.Equal(t, http.StatusOK, statusCode)
require.Equal(t, appsec.AllowRemediation, appsecResponse.Action)
require.Equal(t, http.StatusOK, appsecResponse.HTTPStatus)
},
},
{
name: "default remediation to captcha",
expected_load_ok: true,
inband_rules: []appsec_rule.CustomRule{
{
Name: "rule42",
Zones: []string{"ARGS"},
Variables: []string{"foo"},
Match: appsec_rule.Match{Type: "regex", Value: "^toto"},
Transform: []string{"lowercase"},
},
},
input_request: appsec.ParsedRequest{
RemoteAddr: "1.2.3.4",
Method: "GET",
URI: "/urllll",
Args: url.Values{"foo": []string{"toto"}},
},
DefaultRemediation: "captcha",
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
require.Equal(t, appsec.CaptchaRemediation, responses[0].Action)
require.Equal(t, http.StatusForbidden, statusCode)
require.Equal(t, appsec.CaptchaRemediation, appsecResponse.Action)
require.Equal(t, http.StatusForbidden, appsecResponse.HTTPStatus)
},
},
{
name: "custom user HTTP code",
expected_load_ok: true,
inband_rules: []appsec_rule.CustomRule{
{
Name: "rule42",
Zones: []string{"ARGS"},
Variables: []string{"foo"},
Match: appsec_rule.Match{Type: "regex", Value: "^toto"},
Transform: []string{"lowercase"},
},
},
input_request: appsec.ParsedRequest{
RemoteAddr: "1.2.3.4",
Method: "GET",
URI: "/urllll",
Args: url.Values{"foo": []string{"toto"}},
},
UserBlockedHTTPCode: 418,
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
require.Equal(t, appsec.BanRemediation, responses[0].Action)
require.Equal(t, http.StatusForbidden, statusCode)
require.Equal(t, appsec.BanRemediation, appsecResponse.Action)
require.Equal(t, http.StatusTeapot, appsecResponse.HTTPStatus)
},
},
{
name: "custom remediation + HTTP code",
expected_load_ok: true,
inband_rules: []appsec_rule.CustomRule{
{
Name: "rule42",
Zones: []string{"ARGS"},
Variables: []string{"foo"},
Match: appsec_rule.Match{Type: "regex", Value: "^toto"},
Transform: []string{"lowercase"},
},
},
input_request: appsec.ParsedRequest{
RemoteAddr: "1.2.3.4",
Method: "GET",
URI: "/urllll",
Args: url.Values{"foo": []string{"toto"}},
},
UserBlockedHTTPCode: 418,
DefaultRemediation: "foobar",
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
require.Equal(t, "foobar", responses[0].Action)
require.Equal(t, http.StatusForbidden, statusCode)
require.Equal(t, "foobar", appsecResponse.Action)
require.Equal(t, http.StatusTeapot, appsecResponse.HTTPStatus)
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
loadAppSecEngine(test, t)
})
}
}
func TestAppsecRuleMatches(t *testing.T) {
/*
@ -601,7 +1078,7 @@ func TestAppsecRuleMatches(t *testing.T) {
URI: "/urllll",
Args: url.Values{"foo": []string{"toto"}},
},
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) {
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
require.Len(t, events, 2)
require.Equal(t, types.APPSEC, events[0].Type)
@ -632,13 +1109,172 @@ func TestAppsecRuleMatches(t *testing.T) {
URI: "/urllll",
Args: url.Values{"foo": []string{"tutu"}},
},
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse) {
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
require.Empty(t, events)
require.Len(t, responses, 1)
require.False(t, responses[0].InBandInterrupt)
require.False(t, responses[0].OutOfBandInterrupt)
},
},
{
name: "default remediation to allow",
expected_load_ok: true,
inband_rules: []appsec_rule.CustomRule{
{
Name: "rule42",
Zones: []string{"ARGS"},
Variables: []string{"foo"},
Match: appsec_rule.Match{Type: "regex", Value: "^toto"},
Transform: []string{"lowercase"},
},
},
input_request: appsec.ParsedRequest{
RemoteAddr: "1.2.3.4",
Method: "GET",
URI: "/urllll",
Args: url.Values{"foo": []string{"toto"}},
},
DefaultRemediation: "allow",
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
require.Equal(t, appsec.AllowRemediation, responses[0].Action)
require.Equal(t, http.StatusOK, statusCode)
require.Equal(t, appsec.AllowRemediation, appsecResponse.Action)
require.Equal(t, http.StatusOK, appsecResponse.HTTPStatus)
},
},
{
name: "default remediation to captcha",
expected_load_ok: true,
inband_rules: []appsec_rule.CustomRule{
{
Name: "rule42",
Zones: []string{"ARGS"},
Variables: []string{"foo"},
Match: appsec_rule.Match{Type: "regex", Value: "^toto"},
Transform: []string{"lowercase"},
},
},
input_request: appsec.ParsedRequest{
RemoteAddr: "1.2.3.4",
Method: "GET",
URI: "/urllll",
Args: url.Values{"foo": []string{"toto"}},
},
DefaultRemediation: "captcha",
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
require.Equal(t, appsec.CaptchaRemediation, responses[0].Action)
require.Equal(t, http.StatusForbidden, statusCode)
require.Equal(t, appsec.CaptchaRemediation, appsecResponse.Action)
require.Equal(t, http.StatusForbidden, appsecResponse.HTTPStatus)
},
},
{
name: "no default remediation / custom user HTTP code",
expected_load_ok: true,
inband_rules: []appsec_rule.CustomRule{
{
Name: "rule42",
Zones: []string{"ARGS"},
Variables: []string{"foo"},
Match: appsec_rule.Match{Type: "regex", Value: "^toto"},
Transform: []string{"lowercase"},
},
},
input_request: appsec.ParsedRequest{
RemoteAddr: "1.2.3.4",
Method: "GET",
URI: "/urllll",
Args: url.Values{"foo": []string{"toto"}},
},
UserBlockedHTTPCode: 418,
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
require.Equal(t, appsec.BanRemediation, responses[0].Action)
require.Equal(t, http.StatusForbidden, statusCode)
require.Equal(t, appsec.BanRemediation, appsecResponse.Action)
require.Equal(t, http.StatusTeapot, appsecResponse.HTTPStatus)
},
},
{
name: "no match but try to set remediation to captcha with on_match hook",
expected_load_ok: true,
inband_rules: []appsec_rule.CustomRule{
{
Name: "rule42",
Zones: []string{"ARGS"},
Variables: []string{"foo"},
Match: appsec_rule.Match{Type: "regex", Value: "^toto"},
Transform: []string{"lowercase"},
},
},
on_match: []appsec.Hook{
{Filter: "IsInBand == true", Apply: []string{"SetRemediation('captcha')"}},
},
input_request: appsec.ParsedRequest{
RemoteAddr: "1.2.3.4",
Method: "GET",
URI: "/urllll",
Args: url.Values{"foo": []string{"bla"}},
},
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
require.Empty(t, events)
require.Equal(t, http.StatusOK, statusCode)
require.Equal(t, appsec.AllowRemediation, appsecResponse.Action)
},
},
{
name: "no match but try to set user HTTP code with on_match hook",
expected_load_ok: true,
inband_rules: []appsec_rule.CustomRule{
{
Name: "rule42",
Zones: []string{"ARGS"},
Variables: []string{"foo"},
Match: appsec_rule.Match{Type: "regex", Value: "^toto"},
Transform: []string{"lowercase"},
},
},
on_match: []appsec.Hook{
{Filter: "IsInBand == true", Apply: []string{"SetReturnCode(418)"}},
},
input_request: appsec.ParsedRequest{
RemoteAddr: "1.2.3.4",
Method: "GET",
URI: "/urllll",
Args: url.Values{"foo": []string{"bla"}},
},
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
require.Empty(t, events)
require.Equal(t, http.StatusOK, statusCode)
require.Equal(t, appsec.AllowRemediation, appsecResponse.Action)
},
},
{
name: "no match but try to set remediation with pre_eval hook",
expected_load_ok: true,
inband_rules: []appsec_rule.CustomRule{
{
Name: "rule42",
Zones: []string{"ARGS"},
Variables: []string{"foo"},
Match: appsec_rule.Match{Type: "regex", Value: "^toto"},
Transform: []string{"lowercase"},
},
},
pre_eval: []appsec.Hook{
{Filter: "IsInBand == true", Apply: []string{"SetRemediationByName('rule42', 'captcha')"}},
},
input_request: appsec.ParsedRequest{
RemoteAddr: "1.2.3.4",
Method: "GET",
URI: "/urllll",
Args: url.Values{"foo": []string{"bla"}},
},
output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) {
require.Empty(t, events)
require.Equal(t, http.StatusOK, statusCode)
require.Equal(t, appsec.AllowRemediation, appsecResponse.Action)
},
},
}
for _, test := range tests {
@ -678,7 +1314,16 @@ func loadAppSecEngine(test appsecRuleTest, t *testing.T) {
outofbandRules = append(outofbandRules, strRule)
}
appsecCfg := appsec.AppsecConfig{Logger: logger, OnLoad: test.on_load, PreEval: test.pre_eval, PostEval: test.post_eval, OnMatch: test.on_match}
appsecCfg := appsec.AppsecConfig{Logger: logger,
OnLoad: test.on_load,
PreEval: test.pre_eval,
PostEval: test.post_eval,
OnMatch: test.on_match,
BouncerBlockedHTTPCode: test.BouncerBlockedHTTPCode,
UserBlockedHTTPCode: test.UserBlockedHTTPCode,
UserPassedHTTPCode: test.UserPassedHTTPCode,
DefaultRemediation: test.DefaultRemediation,
DefaultPassAction: test.DefaultPassAction}
AppsecRuntime, err := appsecCfg.Build()
if err != nil {
t.Fatalf("unable to build appsec runtime : %s", err)
@ -724,8 +1369,10 @@ func loadAppSecEngine(test appsecRuleTest, t *testing.T) {
runner.handleRequest(&input)
time.Sleep(50 * time.Millisecond)
http_status, appsecResponse := AppsecRuntime.GenerateResponse(OutputResponses[0], logger)
log.Infof("events : %s", spew.Sdump(OutputEvents))
log.Infof("responses : %s", spew.Sdump(OutputResponses))
test.output_asserts(OutputEvents, OutputResponses)
test.output_asserts(OutputEvents, OutputResponses, appsecResponse, http_status)
}

View file

@ -25,6 +25,7 @@ type LokiClient struct {
t *tomb.Tomb
fail_start time.Time
currentTickerInterval time.Duration
requestHeaders map[string]string
}
type Config struct {
@ -116,7 +117,7 @@ func (lc *LokiClient) queryRange(uri string, ctx context.Context, c chan *LokiQu
case <-lc.t.Dying():
return lc.t.Err()
case <-ticker.C:
resp, err := http.Get(uri)
resp, err := lc.Get(uri)
if err != nil {
if ok := lc.shouldRetry(); !ok {
return errors.Wrapf(err, "error querying range")
@ -127,6 +128,7 @@ func (lc *LokiClient) queryRange(uri string, ctx context.Context, c chan *LokiQu
}
if resp.StatusCode != http.StatusOK {
lc.Logger.Warnf("bad HTTP response code for query range: %d", resp.StatusCode)
body, _ := io.ReadAll(resp.Body)
resp.Body.Close()
if ok := lc.shouldRetry(); !ok {
@ -215,7 +217,7 @@ func (lc *LokiClient) Ready(ctx context.Context) error {
return lc.t.Err()
case <-tick.C:
lc.Logger.Debug("Checking if Loki is ready")
resp, err := http.Get(url)
resp, err := lc.Get(url)
if err != nil {
lc.Logger.Warnf("Error checking if Loki is ready: %s", err)
continue
@ -251,10 +253,9 @@ func (lc *LokiClient) Tail(ctx context.Context) (chan *LokiResponse, error) {
}
requestHeader := http.Header{}
for k, v := range lc.config.Headers {
for k, v := range lc.requestHeaders {
requestHeader.Add(k, v)
}
requestHeader.Set("User-Agent", "Crowdsec "+cwversion.VersionStr())
lc.Logger.Infof("Connecting to %s", u)
conn, _, err := dialer.Dial(u, requestHeader)
@ -293,16 +294,6 @@ func (lc *LokiClient) QueryRange(ctx context.Context, infinite bool) chan *LokiQ
lc.Logger.Debugf("Since: %s (%s)", lc.config.Since, time.Now().Add(-lc.config.Since))
requestHeader := http.Header{}
for k, v := range lc.config.Headers {
requestHeader.Add(k, v)
}
if lc.config.Username != "" || lc.config.Password != "" {
requestHeader.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(lc.config.Username+":"+lc.config.Password)))
}
requestHeader.Set("User-Agent", "Crowdsec "+cwversion.VersionStr())
lc.Logger.Infof("Connecting to %s", url)
lc.t.Go(func() error {
return lc.queryRange(url, ctx, c, infinite)
@ -310,6 +301,26 @@ func (lc *LokiClient) QueryRange(ctx context.Context, infinite bool) chan *LokiQ
return c
}
func NewLokiClient(config Config) *LokiClient {
return &LokiClient{Logger: log.WithField("component", "lokiclient"), config: config}
// Create a wrapper for http.Get to be able to set headers and auth
func (lc *LokiClient) Get(url string) (*http.Response, error) {
request, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, err
}
for k, v := range lc.requestHeaders {
request.Header.Add(k, v)
}
return http.DefaultClient.Do(request)
}
func NewLokiClient(config Config) *LokiClient {
headers := make(map[string]string)
for k, v := range config.Headers {
headers[k] = v
}
if config.Username != "" || config.Password != "" {
headers["Authorization"] = "Basic " + base64.StdEncoding.EncodeToString([]byte(config.Username+":"+config.Password))
}
headers["User-Agent"] = "Crowdsec " + cwversion.VersionStr()
return &LokiClient{Logger: log.WithField("component", "lokiclient"), config: config, requestHeaders: headers}
}

View file

@ -276,10 +276,17 @@ func feedLoki(logger *log.Entry, n int, title string) error {
if err != nil {
return err
}
resp, err := http.Post("http://127.0.0.1:3100/loki/api/v1/push", "application/json", bytes.NewBuffer(buff))
req, err := http.NewRequest(http.MethodPost, "http://127.0.0.1:3100/loki/api/v1/push", bytes.NewBuffer(buff))
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-Scope-OrgID", "1234")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusNoContent {
b, _ := io.ReadAll(resp.Body)
logger.Error(string(b))
@ -306,6 +313,8 @@ mode: cat
source: loki
url: http://127.0.0.1:3100
query: '{server="demo",key="%s"}'
headers:
x-scope-orgid: "1234"
since: 1h
`, title),
},
@ -362,26 +371,26 @@ func TestStreamingAcquisition(t *testing.T) {
}{
{
name: "Bad port",
config: `
mode: tail
config: `mode: tail
source: loki
url: http://127.0.0.1:3101
url: "http://127.0.0.1:3101"
headers:
x-scope-orgid: "1234"
query: >
{server="demo"}
`, // No Loki server here
{server="demo"}`, // No Loki server here
expectedErr: "",
streamErr: `loki is not ready: context deadline exceeded`,
expectedLines: 0,
},
{
name: "ok",
config: `
mode: tail
config: `mode: tail
source: loki
url: http://127.0.0.1:3100
url: "http://127.0.0.1:3100"
headers:
x-scope-orgid: "1234"
query: >
{server="demo"}
`,
{server="demo"}`,
expectedErr: "",
streamErr: "",
expectedLines: 20,
@ -456,6 +465,8 @@ func TestStopStreaming(t *testing.T) {
mode: tail
source: loki
url: http://127.0.0.1:3100
headers:
x-scope-orgid: "1234"
query: >
{server="demo"}
`

View file

@ -633,6 +633,13 @@ func (a *apic) PullTop(forcePull bool) error {
}
}
log.Debug("Acquiring lock for pullCAPI")
err = a.dbClient.AcquirePullCAPILock()
if a.dbClient.IsLocked(err) {
log.Info("PullCAPI is already running, skipping")
return nil
}
log.Infof("Starting community-blocklist update")
data, _, err := a.apiClient.Decisions.GetStreamV3(context.Background(), apiclient.DecisionsStreamOpts{Startup: a.startup})
@ -684,6 +691,11 @@ func (a *apic) PullTop(forcePull bool) error {
return fmt.Errorf("while updating blocklists: %w", err)
}
log.Debug("Releasing lock for pullCAPI")
if err := a.dbClient.ReleasePullCAPILock(); err != nil {
return fmt.Errorf("while releasing lock: %w", err)
}
return nil
}

View file

@ -26,15 +26,15 @@ func TestAPICSendMetrics(t *testing.T) {
}{
{
name: "basic",
duration: time.Millisecond * 60,
metricsInterval: time.Millisecond * 10,
duration: time.Millisecond * 120,
metricsInterval: time.Millisecond * 20,
expectedCalls: 5,
setUp: func(api *apic) {},
},
{
name: "with some metrics",
duration: time.Millisecond * 60,
metricsInterval: time.Millisecond * 10,
duration: time.Millisecond * 120,
metricsInterval: time.Millisecond * 20,
expectedCalls: 5,
setUp: func(api *apic) {
api.dbClient.Ent.Machine.Delete().ExecX(context.Background())

View file

@ -66,7 +66,7 @@ func (a *APIKey) authTLS(c *gin.Context, logger *log.Entry) *ent.Bouncer {
validCert, extractedCN, err := a.TlsAuth.ValidateCert(c)
if !validCert {
logger.Errorf("invalid client certificate: %s", err)
logger.Error(err)
return nil
}

View file

@ -4,6 +4,7 @@ import (
"bytes"
"crypto"
"crypto/x509"
"encoding/pem"
"fmt"
"io"
"net/http"
@ -19,14 +20,13 @@ import (
type TLSAuth struct {
AllowedOUs []string
CrlPath string
revokationCache map[string]cacheEntry
revocationCache map[string]cacheEntry
cacheExpiration time.Duration
logger *log.Entry
}
type cacheEntry struct {
revoked bool
err error
timestamp time.Time
}
@ -89,10 +89,12 @@ func (ta *TLSAuth) isExpired(cert *x509.Certificate) bool {
return false
}
func (ta *TLSAuth) isOCSPRevoked(cert *x509.Certificate, issuer *x509.Certificate) (bool, error) {
if cert.OCSPServer == nil || (cert.OCSPServer != nil && len(cert.OCSPServer) == 0) {
// isOCSPRevoked checks if the client certificate is revoked by any of the OCSP servers present in the certificate.
// It returns a boolean indicating if the certificate is revoked and a boolean indicating if the OCSP check was successful and could be cached.
func (ta *TLSAuth) isOCSPRevoked(cert *x509.Certificate, issuer *x509.Certificate) (bool, bool) {
if cert.OCSPServer == nil || len(cert.OCSPServer) == 0 {
ta.logger.Infof("TLSAuth: no OCSP Server present in client certificate, skipping OCSP verification")
return false, nil
return false, true
}
for _, server := range cert.OCSPServer {
@ -104,9 +106,10 @@ func (ta *TLSAuth) isOCSPRevoked(cert *x509.Certificate, issuer *x509.Certificat
switch ocspResponse.Status {
case ocsp.Good:
return false, nil
return false, true
case ocsp.Revoked:
return true, fmt.Errorf("client certificate is revoked by server %s", server)
ta.logger.Errorf("TLSAuth: client certificate is revoked by server %s", server)
return true, true
case ocsp.Unknown:
log.Debugf("unknow OCSP status for server %s", server)
continue
@ -115,83 +118,82 @@ func (ta *TLSAuth) isOCSPRevoked(cert *x509.Certificate, issuer *x509.Certificat
log.Infof("Could not get any valid OCSP response, assuming the cert is revoked")
return true, nil
return true, false
}
func (ta *TLSAuth) isCRLRevoked(cert *x509.Certificate) (bool, error) {
// isCRLRevoked checks if the client certificate is revoked by the CRL present in the CrlPath.
// It returns a boolean indicating if the certificate is revoked and a boolean indicating if the CRL check was successful and could be cached.
func (ta *TLSAuth) isCRLRevoked(cert *x509.Certificate) (bool, bool) {
if ta.CrlPath == "" {
ta.logger.Warn("no crl_path, skipping CRL check")
return false, nil
ta.logger.Info("no crl_path, skipping CRL check")
return false, true
}
crlContent, err := os.ReadFile(ta.CrlPath)
if err != nil {
ta.logger.Warnf("could not read CRL file, skipping check: %s", err)
return false, nil
ta.logger.Errorf("could not read CRL file, skipping check: %s", err)
return false, false
}
crl, err := x509.ParseCRL(crlContent)
crlBinary, rest := pem.Decode(crlContent)
if len(rest) > 0 {
ta.logger.Warn("CRL file contains more than one PEM block, ignoring the rest")
}
crl, err := x509.ParseRevocationList(crlBinary.Bytes)
if err != nil {
ta.logger.Warnf("could not parse CRL file, skipping check: %s", err)
return false, nil
ta.logger.Errorf("could not parse CRL file, skipping check: %s", err)
return false, false
}
if crl.HasExpired(time.Now().UTC()) {
now := time.Now().UTC()
if now.After(crl.NextUpdate) {
ta.logger.Warn("CRL has expired, will still validate the cert against it.")
}
for _, revoked := range crl.TBSCertList.RevokedCertificates {
if now.Before(crl.ThisUpdate) {
ta.logger.Warn("CRL is not yet valid, will still validate the cert against it.")
}
for _, revoked := range crl.RevokedCertificateEntries {
if revoked.SerialNumber.Cmp(cert.SerialNumber) == 0 {
return true, fmt.Errorf("client certificate is revoked by CRL")
ta.logger.Warn("client certificate is revoked by CRL")
return true, true
}
}
return false, nil
return false, true
}
func (ta *TLSAuth) isRevoked(cert *x509.Certificate, issuer *x509.Certificate) (bool, error) {
sn := cert.SerialNumber.String()
if cacheValue, ok := ta.revokationCache[sn]; ok {
if cacheValue, ok := ta.revocationCache[sn]; ok {
if time.Now().UTC().Sub(cacheValue.timestamp) < ta.cacheExpiration {
ta.logger.Debugf("TLSAuth: using cached value for cert %s: %t | %s", sn, cacheValue.revoked, cacheValue.err)
return cacheValue.revoked, cacheValue.err
} else {
ta.logger.Debugf("TLSAuth: cached value expired, removing from cache")
delete(ta.revokationCache, sn)
ta.logger.Debugf("TLSAuth: using cached value for cert %s: %t", sn, cacheValue.revoked)
return cacheValue.revoked, nil
}
ta.logger.Debugf("TLSAuth: cached value expired, removing from cache")
delete(ta.revocationCache, sn)
} else {
ta.logger.Tracef("TLSAuth: no cached value for cert %s", sn)
}
revoked, err := ta.isOCSPRevoked(cert, issuer)
if err != nil {
ta.revokationCache[sn] = cacheEntry{
revokedByOCSP, cacheOCSP := ta.isOCSPRevoked(cert, issuer)
revokedByCRL, cacheCRL := ta.isCRLRevoked(cert)
revoked := revokedByOCSP || revokedByCRL
if cacheOCSP && cacheCRL {
ta.revocationCache[sn] = cacheEntry{
revoked: revoked,
err: err,
timestamp: time.Now().UTC(),
}
return true, err
}
if revoked {
ta.revokationCache[sn] = cacheEntry{
revoked: revoked,
err: err,
timestamp: time.Now().UTC(),
}
return true, nil
}
revoked, err = ta.isCRLRevoked(cert)
ta.revokationCache[sn] = cacheEntry{
revoked: revoked,
err: err,
timestamp: time.Now().UTC(),
}
return revoked, err
return revoked, nil
}
func (ta *TLSAuth) isInvalid(cert *x509.Certificate, issuer *x509.Certificate) (bool, error) {
@ -265,11 +267,11 @@ func (ta *TLSAuth) ValidateCert(c *gin.Context) (bool, string, error) {
revoked, err := ta.isInvalid(clientCert, c.Request.TLS.VerifiedChains[0][1])
if err != nil {
ta.logger.Errorf("TLSAuth: error checking if client certificate is revoked: %s", err)
return false, "", fmt.Errorf("could not check for client certification revokation status: %w", err)
return false, "", fmt.Errorf("could not check for client certification revocation status: %w", err)
}
if revoked {
return false, "", fmt.Errorf("client certificate is revoked")
return false, "", fmt.Errorf("client certificate for CN=%s OU=%s is revoked", clientCert.Subject.CommonName, clientCert.Subject.OrganizationalUnit)
}
ta.logger.Debugf("client OU %v is allowed vs required OU %v", clientCert.Subject.OrganizationalUnit, ta.AllowedOUs)
@ -282,7 +284,7 @@ func (ta *TLSAuth) ValidateCert(c *gin.Context) (bool, string, error) {
func NewTLSAuth(allowedOus []string, crlPath string, cacheExpiration time.Duration, logger *log.Entry) (*TLSAuth, error) {
ta := &TLSAuth{
revokationCache: map[string]cacheEntry{},
revocationCache: map[string]cacheEntry{},
cacheExpiration: cacheExpiration,
CrlPath: crlPath,
logger: logger,

View file

@ -2,6 +2,7 @@ package appsec
import (
"fmt"
"net/http"
"os"
"regexp"
@ -30,6 +31,12 @@ const (
hookOnMatch
)
const (
BanRemediation = "ban"
CaptchaRemediation = "captcha"
AllowRemediation = "allow"
)
func (h *Hook) Build(hookStage int) error {
ctx := map[string]interface{}{}
@ -62,12 +69,13 @@ func (h *Hook) Build(hookStage int) error {
}
type AppsecTempResponse struct {
InBandInterrupt bool
OutOfBandInterrupt bool
Action string //allow, deny, captcha, log
HTTPResponseCode int
SendEvent bool //do we send an internal event on rule match
SendAlert bool //do we send an alert on rule match
InBandInterrupt bool
OutOfBandInterrupt bool
Action string //allow, deny, captcha, log
UserHTTPResponseCode int //The response code to send to the user
BouncerHTTPResponseCode int //The response code to send to the remediation component
SendEvent bool //do we send an internal event on rule match
SendAlert bool //do we send an alert on rule match
}
type AppsecSubEngineOpts struct {
@ -110,31 +118,33 @@ type AppsecRuntimeConfig struct {
}
type AppsecConfig struct {
Name string `yaml:"name"`
OutOfBandRules []string `yaml:"outofband_rules"`
InBandRules []string `yaml:"inband_rules"`
DefaultRemediation string `yaml:"default_remediation"`
DefaultPassAction string `yaml:"default_pass_action"`
BlockedHTTPCode int `yaml:"blocked_http_code"`
PassedHTTPCode int `yaml:"passed_http_code"`
OnLoad []Hook `yaml:"on_load"`
PreEval []Hook `yaml:"pre_eval"`
PostEval []Hook `yaml:"post_eval"`
OnMatch []Hook `yaml:"on_match"`
VariablesTracking []string `yaml:"variables_tracking"`
InbandOptions AppsecSubEngineOpts `yaml:"inband_options"`
OutOfBandOptions AppsecSubEngineOpts `yaml:"outofband_options"`
Name string `yaml:"name"`
OutOfBandRules []string `yaml:"outofband_rules"`
InBandRules []string `yaml:"inband_rules"`
DefaultRemediation string `yaml:"default_remediation"`
DefaultPassAction string `yaml:"default_pass_action"`
BouncerBlockedHTTPCode int `yaml:"blocked_http_code"` //returned to the bouncer
BouncerPassedHTTPCode int `yaml:"passed_http_code"` //returned to the bouncer
UserBlockedHTTPCode int `yaml:"user_blocked_http_code"` //returned to the user
UserPassedHTTPCode int `yaml:"user_passed_http_code"` //returned to the user
OnLoad []Hook `yaml:"on_load"`
PreEval []Hook `yaml:"pre_eval"`
PostEval []Hook `yaml:"post_eval"`
OnMatch []Hook `yaml:"on_match"`
VariablesTracking []string `yaml:"variables_tracking"`
InbandOptions AppsecSubEngineOpts `yaml:"inband_options"`
OutOfBandOptions AppsecSubEngineOpts `yaml:"outofband_options"`
LogLevel *log.Level `yaml:"log_level"`
Logger *log.Entry `yaml:"-"`
}
func (w *AppsecRuntimeConfig) ClearResponse() {
w.Logger.Debugf("#-> %p", w)
w.Response = AppsecTempResponse{}
w.Logger.Debugf("-> %p", w.Config)
w.Response.Action = w.Config.DefaultPassAction
w.Response.HTTPResponseCode = w.Config.PassedHTTPCode
w.Response.BouncerHTTPResponseCode = w.Config.BouncerPassedHTTPCode
w.Response.UserHTTPResponseCode = w.Config.UserPassedHTTPCode
w.Response.SendEvent = true
w.Response.SendAlert = true
}
@ -191,24 +201,35 @@ func (wc *AppsecConfig) GetDataDir() string {
func (wc *AppsecConfig) Build() (*AppsecRuntimeConfig, error) {
ret := &AppsecRuntimeConfig{Logger: wc.Logger.WithField("component", "appsec_runtime_config")}
//set the defaults
switch wc.DefaultRemediation {
case "":
wc.DefaultRemediation = "ban"
case "ban", "captcha", "log":
//those are the officially supported remediation(s)
default:
wc.Logger.Warningf("default '%s' remediation of %s is none of [ban,captcha,log] ensure bouncer compatbility!", wc.DefaultRemediation, wc.Name)
if wc.BouncerBlockedHTTPCode == 0 {
wc.BouncerBlockedHTTPCode = http.StatusForbidden
}
if wc.BlockedHTTPCode == 0 {
wc.BlockedHTTPCode = 403
if wc.BouncerPassedHTTPCode == 0 {
wc.BouncerPassedHTTPCode = http.StatusOK
}
if wc.PassedHTTPCode == 0 {
wc.PassedHTTPCode = 200
if wc.UserBlockedHTTPCode == 0 {
wc.UserBlockedHTTPCode = http.StatusForbidden
}
if wc.UserPassedHTTPCode == 0 {
wc.UserPassedHTTPCode = http.StatusOK
}
if wc.DefaultPassAction == "" {
wc.DefaultPassAction = "allow"
wc.DefaultPassAction = AllowRemediation
}
if wc.DefaultRemediation == "" {
wc.DefaultRemediation = BanRemediation
}
//set the defaults
switch wc.DefaultRemediation {
case BanRemediation, CaptchaRemediation, AllowRemediation:
//those are the officially supported remediation(s)
default:
wc.Logger.Warningf("default '%s' remediation of %s is none of [%s,%s,%s] ensure bouncer compatbility!", wc.DefaultRemediation, wc.Name, BanRemediation, CaptchaRemediation, AllowRemediation)
}
ret.Name = wc.Name
ret.Config = wc
ret.DefaultRemediation = wc.DefaultRemediation
@ -553,27 +574,13 @@ func (w *AppsecRuntimeConfig) SetActionByName(name string, action string) error
func (w *AppsecRuntimeConfig) SetAction(action string) error {
//log.Infof("setting to %s", action)
w.Logger.Debugf("setting action to %s", action)
switch action {
case "allow":
w.Response.Action = action
w.Response.HTTPResponseCode = w.Config.PassedHTTPCode
//@tko how should we handle this ? it seems bouncer only understand bans, but it might be misleading ?
case "deny", "ban", "block":
w.Response.Action = "ban"
case "log":
w.Response.Action = action
w.Response.HTTPResponseCode = w.Config.PassedHTTPCode
case "captcha":
w.Response.Action = action
default:
w.Response.Action = action
}
w.Response.Action = action
return nil
}
func (w *AppsecRuntimeConfig) SetHTTPCode(code int) error {
w.Logger.Debugf("setting http code to %d", code)
w.Response.HTTPResponseCode = code
w.Response.UserHTTPResponseCode = code
return nil
}
@ -582,24 +589,23 @@ type BodyResponse struct {
HTTPStatus int `json:"http_status"`
}
func (w *AppsecRuntimeConfig) GenerateResponse(response AppsecTempResponse, logger *log.Entry) BodyResponse {
resp := BodyResponse{}
//if there is no interrupt, we should allow with default code
if !response.InBandInterrupt {
resp.Action = w.Config.DefaultPassAction
resp.HTTPStatus = w.Config.PassedHTTPCode
return resp
}
resp.Action = response.Action
if resp.Action == "" {
resp.Action = w.Config.DefaultRemediation
}
logger.Debugf("action is %s", resp.Action)
func (w *AppsecRuntimeConfig) GenerateResponse(response AppsecTempResponse, logger *log.Entry) (int, BodyResponse) {
var bouncerStatusCode int
resp.HTTPStatus = response.HTTPResponseCode
if resp.HTTPStatus == 0 {
resp.HTTPStatus = w.Config.BlockedHTTPCode
resp := BodyResponse{Action: response.Action}
if response.Action == AllowRemediation {
resp.HTTPStatus = w.Config.UserPassedHTTPCode
bouncerStatusCode = w.Config.BouncerPassedHTTPCode
} else { //ban, captcha and anything else
resp.HTTPStatus = response.UserHTTPResponseCode
if resp.HTTPStatus == 0 {
resp.HTTPStatus = w.Config.UserBlockedHTTPCode
}
bouncerStatusCode = response.BouncerHTTPResponseCode
if bouncerStatusCode == 0 {
bouncerStatusCode = w.Config.BouncerBlockedHTTPCode
}
}
logger.Debugf("http status is %d", resp.HTTPStatus)
return resp
return bouncerStatusCode, resp
}

View file

@ -178,6 +178,7 @@ func (l *LocalApiClientCfg) Load() error {
func (lapiCfg *LocalApiServerCfg) GetTrustedIPs() ([]net.IPNet, error) {
trustedIPs := make([]net.IPNet, 0)
for _, ip := range lapiCfg.TrustedIPs {
cidr := toValidCIDR(ip)
@ -265,7 +266,7 @@ func (c *Config) LoadAPIServer(inCli bool) error {
return fmt.Errorf("no listen_uri specified")
}
//inherit log level from common, then api->server
// inherit log level from common, then api->server
var logLevel log.Level
if c.API.Server.LogLevel != nil {
logLevel = *c.API.Server.LogLevel

View file

@ -25,7 +25,7 @@ var globalConfig = Config{}
// Config contains top-level defaults -> overridden by configuration file -> overridden by CLI flags
type Config struct {
//just a path to ourselves :p
// just a path to ourselves :p
FilePath *string `yaml:"-"`
Self []byte `yaml:"-"`
Common *CommonCfg `yaml:"common,omitempty"`
@ -44,10 +44,12 @@ type Config struct {
func NewConfig(configFile string, disableAgent bool, disableAPI bool, inCli bool) (*Config, string, error) {
patcher := yamlpatch.NewPatcher(configFile, ".local")
patcher.SetQuiet(inCli)
fcontent, err := patcher.MergedPatchContent()
if err != nil {
return nil, "", err
}
configData := csstring.StrictExpand(string(fcontent), os.LookupEnv)
cfg := Config{
FilePath: &configFile,

View file

@ -7,6 +7,7 @@ import (
"strings"
"time"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/alert"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/machine"
@ -67,6 +68,7 @@ type Alert struct {
// The values are being populated by the AlertQuery when eager-loading is set.
Edges AlertEdges `json:"edges"`
machine_alerts *int
selectValues sql.SelectValues
}
// AlertEdges holds the relations/edges for other nodes in the graph.
@ -142,7 +144,7 @@ func (*Alert) scanValues(columns []string) ([]any, error) {
case alert.ForeignKeys[0]: // machine_alerts
values[i] = new(sql.NullInt64)
default:
return nil, fmt.Errorf("unexpected column %q for type Alert", columns[i])
values[i] = new(sql.UnknownType)
}
}
return values, nil
@ -309,36 +311,44 @@ func (a *Alert) assignValues(columns []string, values []any) error {
a.machine_alerts = new(int)
*a.machine_alerts = int(value.Int64)
}
default:
a.selectValues.Set(columns[i], values[i])
}
}
return nil
}
// Value returns the ent.Value that was dynamically selected and assigned to the Alert.
// This includes values selected through modifiers, order, etc.
func (a *Alert) Value(name string) (ent.Value, error) {
return a.selectValues.Get(name)
}
// QueryOwner queries the "owner" edge of the Alert entity.
func (a *Alert) QueryOwner() *MachineQuery {
return (&AlertClient{config: a.config}).QueryOwner(a)
return NewAlertClient(a.config).QueryOwner(a)
}
// QueryDecisions queries the "decisions" edge of the Alert entity.
func (a *Alert) QueryDecisions() *DecisionQuery {
return (&AlertClient{config: a.config}).QueryDecisions(a)
return NewAlertClient(a.config).QueryDecisions(a)
}
// QueryEvents queries the "events" edge of the Alert entity.
func (a *Alert) QueryEvents() *EventQuery {
return (&AlertClient{config: a.config}).QueryEvents(a)
return NewAlertClient(a.config).QueryEvents(a)
}
// QueryMetas queries the "metas" edge of the Alert entity.
func (a *Alert) QueryMetas() *MetaQuery {
return (&AlertClient{config: a.config}).QueryMetas(a)
return NewAlertClient(a.config).QueryMetas(a)
}
// Update returns a builder for updating this Alert.
// Note that you need to call Alert.Unwrap() before calling this method if this Alert
// was returned from a transaction, and the transaction was committed or rolled back.
func (a *Alert) Update() *AlertUpdateOne {
return (&AlertClient{config: a.config}).UpdateOne(a)
return NewAlertClient(a.config).UpdateOne(a)
}
// Unwrap unwraps the Alert entity that was returned from a transaction after it was closed,
@ -435,9 +445,3 @@ func (a *Alert) String() string {
// Alerts is a parsable slice of Alert.
type Alerts []*Alert
func (a Alerts) config(cfg config) {
for _i := range a {
a[_i].config = cfg
}
}

View file

@ -4,6 +4,9 @@ package alert
import (
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
)
const (
@ -168,3 +171,203 @@ var (
// DefaultSimulated holds the default value on creation for the "simulated" field.
DefaultSimulated bool
)
// OrderOption defines the ordering options for the Alert queries.
type OrderOption func(*sql.Selector)
// ByID orders the results by the id field.
func ByID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldID, opts...).ToFunc()
}
// ByCreatedAt orders the results by the created_at field.
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
}
// ByUpdatedAt orders the results by the updated_at field.
func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
}
// ByScenario orders the results by the scenario field.
func ByScenario(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldScenario, opts...).ToFunc()
}
// ByBucketId orders the results by the bucketId field.
func ByBucketId(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldBucketId, opts...).ToFunc()
}
// ByMessage orders the results by the message field.
func ByMessage(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldMessage, opts...).ToFunc()
}
// ByEventsCountField orders the results by the eventsCount field.
func ByEventsCountField(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldEventsCount, opts...).ToFunc()
}
// ByStartedAt orders the results by the startedAt field.
func ByStartedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldStartedAt, opts...).ToFunc()
}
// ByStoppedAt orders the results by the stoppedAt field.
func ByStoppedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldStoppedAt, opts...).ToFunc()
}
// BySourceIp orders the results by the sourceIp field.
func BySourceIp(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSourceIp, opts...).ToFunc()
}
// BySourceRange orders the results by the sourceRange field.
func BySourceRange(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSourceRange, opts...).ToFunc()
}
// BySourceAsNumber orders the results by the sourceAsNumber field.
func BySourceAsNumber(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSourceAsNumber, opts...).ToFunc()
}
// BySourceAsName orders the results by the sourceAsName field.
func BySourceAsName(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSourceAsName, opts...).ToFunc()
}
// BySourceCountry orders the results by the sourceCountry field.
func BySourceCountry(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSourceCountry, opts...).ToFunc()
}
// BySourceLatitude orders the results by the sourceLatitude field.
func BySourceLatitude(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSourceLatitude, opts...).ToFunc()
}
// BySourceLongitude orders the results by the sourceLongitude field.
func BySourceLongitude(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSourceLongitude, opts...).ToFunc()
}
// BySourceScope orders the results by the sourceScope field.
func BySourceScope(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSourceScope, opts...).ToFunc()
}
// BySourceValue orders the results by the sourceValue field.
func BySourceValue(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSourceValue, opts...).ToFunc()
}
// ByCapacity orders the results by the capacity field.
func ByCapacity(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCapacity, opts...).ToFunc()
}
// ByLeakSpeed orders the results by the leakSpeed field.
func ByLeakSpeed(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldLeakSpeed, opts...).ToFunc()
}
// ByScenarioVersion orders the results by the scenarioVersion field.
func ByScenarioVersion(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldScenarioVersion, opts...).ToFunc()
}
// ByScenarioHash orders the results by the scenarioHash field.
func ByScenarioHash(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldScenarioHash, opts...).ToFunc()
}
// BySimulated orders the results by the simulated field.
func BySimulated(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSimulated, opts...).ToFunc()
}
// ByUUID orders the results by the uuid field.
func ByUUID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUUID, opts...).ToFunc()
}
// ByOwnerField orders the results by owner field.
func ByOwnerField(field string, opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newOwnerStep(), sql.OrderByField(field, opts...))
}
}
// ByDecisionsCount orders the results by decisions count.
func ByDecisionsCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborsCount(s, newDecisionsStep(), opts...)
}
}
// ByDecisions orders the results by decisions terms.
func ByDecisions(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newDecisionsStep(), append([]sql.OrderTerm{term}, terms...)...)
}
}
// ByEventsCount orders the results by events count.
func ByEventsCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborsCount(s, newEventsStep(), opts...)
}
}
// ByEvents orders the results by events terms.
func ByEvents(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newEventsStep(), append([]sql.OrderTerm{term}, terms...)...)
}
}
// ByMetasCount orders the results by metas count.
func ByMetasCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborsCount(s, newMetasStep(), opts...)
}
}
// ByMetas orders the results by metas terms.
func ByMetas(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newMetasStep(), append([]sql.OrderTerm{term}, terms...)...)
}
}
func newOwnerStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(OwnerInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn),
)
}
func newDecisionsStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(DecisionsInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, DecisionsTable, DecisionsColumn),
)
}
func newEventsStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(EventsInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, EventsTable, EventsColumn),
)
}
func newMetasStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(MetasInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.O2M, false, MetasTable, MetasColumn),
)
}

File diff suppressed because it is too large Load diff

View file

@ -409,50 +409,8 @@ func (ac *AlertCreate) Mutation() *AlertMutation {
// Save creates the Alert in the database.
func (ac *AlertCreate) Save(ctx context.Context) (*Alert, error) {
var (
err error
node *Alert
)
ac.defaults()
if len(ac.hooks) == 0 {
if err = ac.check(); err != nil {
return nil, err
}
node, err = ac.sqlSave(ctx)
} else {
var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
mutation, ok := m.(*AlertMutation)
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T", m)
}
if err = ac.check(); err != nil {
return nil, err
}
ac.mutation = mutation
if node, err = ac.sqlSave(ctx); err != nil {
return nil, err
}
mutation.id = &node.ID
mutation.done = true
return node, err
})
for i := len(ac.hooks) - 1; i >= 0; i-- {
if ac.hooks[i] == nil {
return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)")
}
mut = ac.hooks[i](mut)
}
v, err := mut.Mutate(ctx, ac.mutation)
if err != nil {
return nil, err
}
nv, ok := v.(*Alert)
if !ok {
return nil, fmt.Errorf("unexpected node type %T returned from AlertMutation", v)
}
node = nv
}
return node, err
return withHooks(ctx, ac.sqlSave, ac.mutation, ac.hooks)
}
// SaveX calls Save and panics if Save returns an error.
@ -525,6 +483,9 @@ func (ac *AlertCreate) check() error {
}
func (ac *AlertCreate) sqlSave(ctx context.Context) (*Alert, error) {
if err := ac.check(); err != nil {
return nil, err
}
_node, _spec := ac.createSpec()
if err := sqlgraph.CreateNode(ctx, ac.driver, _spec); err != nil {
if sqlgraph.IsConstraintError(err) {
@ -534,202 +495,106 @@ func (ac *AlertCreate) sqlSave(ctx context.Context) (*Alert, error) {
}
id := _spec.ID.Value.(int64)
_node.ID = int(id)
ac.mutation.id = &_node.ID
ac.mutation.done = true
return _node, nil
}
func (ac *AlertCreate) createSpec() (*Alert, *sqlgraph.CreateSpec) {
var (
_node = &Alert{config: ac.config}
_spec = &sqlgraph.CreateSpec{
Table: alert.Table,
ID: &sqlgraph.FieldSpec{
Type: field.TypeInt,
Column: alert.FieldID,
},
}
_spec = sqlgraph.NewCreateSpec(alert.Table, sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt))
)
if value, ok := ac.mutation.CreatedAt(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Value: value,
Column: alert.FieldCreatedAt,
})
_spec.SetField(alert.FieldCreatedAt, field.TypeTime, value)
_node.CreatedAt = &value
}
if value, ok := ac.mutation.UpdatedAt(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Value: value,
Column: alert.FieldUpdatedAt,
})
_spec.SetField(alert.FieldUpdatedAt, field.TypeTime, value)
_node.UpdatedAt = &value
}
if value, ok := ac.mutation.Scenario(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: alert.FieldScenario,
})
_spec.SetField(alert.FieldScenario, field.TypeString, value)
_node.Scenario = value
}
if value, ok := ac.mutation.BucketId(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: alert.FieldBucketId,
})
_spec.SetField(alert.FieldBucketId, field.TypeString, value)
_node.BucketId = value
}
if value, ok := ac.mutation.Message(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: alert.FieldMessage,
})
_spec.SetField(alert.FieldMessage, field.TypeString, value)
_node.Message = value
}
if value, ok := ac.mutation.EventsCount(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeInt32,
Value: value,
Column: alert.FieldEventsCount,
})
_spec.SetField(alert.FieldEventsCount, field.TypeInt32, value)
_node.EventsCount = value
}
if value, ok := ac.mutation.StartedAt(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Value: value,
Column: alert.FieldStartedAt,
})
_spec.SetField(alert.FieldStartedAt, field.TypeTime, value)
_node.StartedAt = value
}
if value, ok := ac.mutation.StoppedAt(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Value: value,
Column: alert.FieldStoppedAt,
})
_spec.SetField(alert.FieldStoppedAt, field.TypeTime, value)
_node.StoppedAt = value
}
if value, ok := ac.mutation.SourceIp(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: alert.FieldSourceIp,
})
_spec.SetField(alert.FieldSourceIp, field.TypeString, value)
_node.SourceIp = value
}
if value, ok := ac.mutation.SourceRange(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: alert.FieldSourceRange,
})
_spec.SetField(alert.FieldSourceRange, field.TypeString, value)
_node.SourceRange = value
}
if value, ok := ac.mutation.SourceAsNumber(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: alert.FieldSourceAsNumber,
})
_spec.SetField(alert.FieldSourceAsNumber, field.TypeString, value)
_node.SourceAsNumber = value
}
if value, ok := ac.mutation.SourceAsName(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: alert.FieldSourceAsName,
})
_spec.SetField(alert.FieldSourceAsName, field.TypeString, value)
_node.SourceAsName = value
}
if value, ok := ac.mutation.SourceCountry(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: alert.FieldSourceCountry,
})
_spec.SetField(alert.FieldSourceCountry, field.TypeString, value)
_node.SourceCountry = value
}
if value, ok := ac.mutation.SourceLatitude(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeFloat32,
Value: value,
Column: alert.FieldSourceLatitude,
})
_spec.SetField(alert.FieldSourceLatitude, field.TypeFloat32, value)
_node.SourceLatitude = value
}
if value, ok := ac.mutation.SourceLongitude(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeFloat32,
Value: value,
Column: alert.FieldSourceLongitude,
})
_spec.SetField(alert.FieldSourceLongitude, field.TypeFloat32, value)
_node.SourceLongitude = value
}
if value, ok := ac.mutation.SourceScope(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: alert.FieldSourceScope,
})
_spec.SetField(alert.FieldSourceScope, field.TypeString, value)
_node.SourceScope = value
}
if value, ok := ac.mutation.SourceValue(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: alert.FieldSourceValue,
})
_spec.SetField(alert.FieldSourceValue, field.TypeString, value)
_node.SourceValue = value
}
if value, ok := ac.mutation.Capacity(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeInt32,
Value: value,
Column: alert.FieldCapacity,
})
_spec.SetField(alert.FieldCapacity, field.TypeInt32, value)
_node.Capacity = value
}
if value, ok := ac.mutation.LeakSpeed(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: alert.FieldLeakSpeed,
})
_spec.SetField(alert.FieldLeakSpeed, field.TypeString, value)
_node.LeakSpeed = value
}
if value, ok := ac.mutation.ScenarioVersion(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: alert.FieldScenarioVersion,
})
_spec.SetField(alert.FieldScenarioVersion, field.TypeString, value)
_node.ScenarioVersion = value
}
if value, ok := ac.mutation.ScenarioHash(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: alert.FieldScenarioHash,
})
_spec.SetField(alert.FieldScenarioHash, field.TypeString, value)
_node.ScenarioHash = value
}
if value, ok := ac.mutation.Simulated(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeBool,
Value: value,
Column: alert.FieldSimulated,
})
_spec.SetField(alert.FieldSimulated, field.TypeBool, value)
_node.Simulated = value
}
if value, ok := ac.mutation.UUID(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: alert.FieldUUID,
})
_spec.SetField(alert.FieldUUID, field.TypeString, value)
_node.UUID = value
}
if nodes := ac.mutation.OwnerIDs(); len(nodes) > 0 {
@ -740,10 +605,7 @@ func (ac *AlertCreate) createSpec() (*Alert, *sqlgraph.CreateSpec) {
Columns: []string{alert.OwnerColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: &sqlgraph.FieldSpec{
Type: field.TypeInt,
Column: machine.FieldID,
},
IDSpec: sqlgraph.NewFieldSpec(machine.FieldID, field.TypeInt),
},
}
for _, k := range nodes {
@ -760,10 +622,7 @@ func (ac *AlertCreate) createSpec() (*Alert, *sqlgraph.CreateSpec) {
Columns: []string{alert.DecisionsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: &sqlgraph.FieldSpec{
Type: field.TypeInt,
Column: decision.FieldID,
},
IDSpec: sqlgraph.NewFieldSpec(decision.FieldID, field.TypeInt),
},
}
for _, k := range nodes {
@ -779,10 +638,7 @@ func (ac *AlertCreate) createSpec() (*Alert, *sqlgraph.CreateSpec) {
Columns: []string{alert.EventsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: &sqlgraph.FieldSpec{
Type: field.TypeInt,
Column: event.FieldID,
},
IDSpec: sqlgraph.NewFieldSpec(event.FieldID, field.TypeInt),
},
}
for _, k := range nodes {
@ -798,10 +654,7 @@ func (ac *AlertCreate) createSpec() (*Alert, *sqlgraph.CreateSpec) {
Columns: []string{alert.MetasColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: &sqlgraph.FieldSpec{
Type: field.TypeInt,
Column: meta.FieldID,
},
IDSpec: sqlgraph.NewFieldSpec(meta.FieldID, field.TypeInt),
},
}
for _, k := range nodes {
@ -815,11 +668,15 @@ func (ac *AlertCreate) createSpec() (*Alert, *sqlgraph.CreateSpec) {
// AlertCreateBulk is the builder for creating many Alert entities in bulk.
type AlertCreateBulk struct {
config
err error
builders []*AlertCreate
}
// Save creates the Alert entities in the database.
func (acb *AlertCreateBulk) Save(ctx context.Context) ([]*Alert, error) {
if acb.err != nil {
return nil, acb.err
}
specs := make([]*sqlgraph.CreateSpec, len(acb.builders))
nodes := make([]*Alert, len(acb.builders))
mutators := make([]Mutator, len(acb.builders))
@ -836,8 +693,8 @@ func (acb *AlertCreateBulk) Save(ctx context.Context) ([]*Alert, error) {
return nil, err
}
builder.mutation = mutation
nodes[i], specs[i] = builder.createSpec()
var err error
nodes[i], specs[i] = builder.createSpec()
if i < len(mutators)-1 {
_, err = mutators[i+1].Mutate(root, acb.builders[i+1].mutation)
} else {

View file

@ -4,7 +4,6 @@ package ent
import (
"context"
"fmt"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
@ -28,34 +27,7 @@ func (ad *AlertDelete) Where(ps ...predicate.Alert) *AlertDelete {
// Exec executes the deletion query and returns how many vertices were deleted.
func (ad *AlertDelete) Exec(ctx context.Context) (int, error) {
var (
err error
affected int
)
if len(ad.hooks) == 0 {
affected, err = ad.sqlExec(ctx)
} else {
var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
mutation, ok := m.(*AlertMutation)
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T", m)
}
ad.mutation = mutation
affected, err = ad.sqlExec(ctx)
mutation.done = true
return affected, err
})
for i := len(ad.hooks) - 1; i >= 0; i-- {
if ad.hooks[i] == nil {
return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)")
}
mut = ad.hooks[i](mut)
}
if _, err := mut.Mutate(ctx, ad.mutation); err != nil {
return 0, err
}
}
return affected, err
return withHooks(ctx, ad.sqlExec, ad.mutation, ad.hooks)
}
// ExecX is like Exec, but panics if an error occurs.
@ -68,15 +40,7 @@ func (ad *AlertDelete) ExecX(ctx context.Context) int {
}
func (ad *AlertDelete) sqlExec(ctx context.Context) (int, error) {
_spec := &sqlgraph.DeleteSpec{
Node: &sqlgraph.NodeSpec{
Table: alert.Table,
ID: &sqlgraph.FieldSpec{
Type: field.TypeInt,
Column: alert.FieldID,
},
},
}
_spec := sqlgraph.NewDeleteSpec(alert.Table, sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt))
if ps := ad.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
@ -88,6 +52,7 @@ func (ad *AlertDelete) sqlExec(ctx context.Context) (int, error) {
if err != nil && sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
ad.mutation.done = true
return affected, err
}
@ -96,6 +61,12 @@ type AlertDeleteOne struct {
ad *AlertDelete
}
// Where appends a list predicates to the AlertDelete builder.
func (ado *AlertDeleteOne) Where(ps ...predicate.Alert) *AlertDeleteOne {
ado.ad.mutation.Where(ps...)
return ado
}
// Exec executes the deletion query.
func (ado *AlertDeleteOne) Exec(ctx context.Context) error {
n, err := ado.ad.Exec(ctx)
@ -111,5 +82,7 @@ func (ado *AlertDeleteOne) Exec(ctx context.Context) error {
// ExecX is like Exec, but panics if an error occurs.
func (ado *AlertDeleteOne) ExecX(ctx context.Context) {
ado.ad.ExecX(ctx)
if err := ado.Exec(ctx); err != nil {
panic(err)
}
}

View file

@ -22,11 +22,9 @@ import (
// AlertQuery is the builder for querying Alert entities.
type AlertQuery struct {
config
limit *int
offset *int
unique *bool
order []OrderFunc
fields []string
ctx *QueryContext
order []alert.OrderOption
inters []Interceptor
predicates []predicate.Alert
withOwner *MachineQuery
withDecisions *DecisionQuery
@ -44,34 +42,34 @@ func (aq *AlertQuery) Where(ps ...predicate.Alert) *AlertQuery {
return aq
}
// Limit adds a limit step to the query.
// Limit the number of records to be returned by this query.
func (aq *AlertQuery) Limit(limit int) *AlertQuery {
aq.limit = &limit
aq.ctx.Limit = &limit
return aq
}
// Offset adds an offset step to the query.
// Offset to start from.
func (aq *AlertQuery) Offset(offset int) *AlertQuery {
aq.offset = &offset
aq.ctx.Offset = &offset
return aq
}
// Unique configures the query builder to filter duplicate records on query.
// By default, unique is set to true, and can be disabled using this method.
func (aq *AlertQuery) Unique(unique bool) *AlertQuery {
aq.unique = &unique
aq.ctx.Unique = &unique
return aq
}
// Order adds an order step to the query.
func (aq *AlertQuery) Order(o ...OrderFunc) *AlertQuery {
// Order specifies how the records should be ordered.
func (aq *AlertQuery) Order(o ...alert.OrderOption) *AlertQuery {
aq.order = append(aq.order, o...)
return aq
}
// QueryOwner chains the current query on the "owner" edge.
func (aq *AlertQuery) QueryOwner() *MachineQuery {
query := &MachineQuery{config: aq.config}
query := (&MachineClient{config: aq.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := aq.prepareQuery(ctx); err != nil {
return nil, err
@ -93,7 +91,7 @@ func (aq *AlertQuery) QueryOwner() *MachineQuery {
// QueryDecisions chains the current query on the "decisions" edge.
func (aq *AlertQuery) QueryDecisions() *DecisionQuery {
query := &DecisionQuery{config: aq.config}
query := (&DecisionClient{config: aq.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := aq.prepareQuery(ctx); err != nil {
return nil, err
@ -115,7 +113,7 @@ func (aq *AlertQuery) QueryDecisions() *DecisionQuery {
// QueryEvents chains the current query on the "events" edge.
func (aq *AlertQuery) QueryEvents() *EventQuery {
query := &EventQuery{config: aq.config}
query := (&EventClient{config: aq.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := aq.prepareQuery(ctx); err != nil {
return nil, err
@ -137,7 +135,7 @@ func (aq *AlertQuery) QueryEvents() *EventQuery {
// QueryMetas chains the current query on the "metas" edge.
func (aq *AlertQuery) QueryMetas() *MetaQuery {
query := &MetaQuery{config: aq.config}
query := (&MetaClient{config: aq.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := aq.prepareQuery(ctx); err != nil {
return nil, err
@ -160,7 +158,7 @@ func (aq *AlertQuery) QueryMetas() *MetaQuery {
// First returns the first Alert entity from the query.
// Returns a *NotFoundError when no Alert was found.
func (aq *AlertQuery) First(ctx context.Context) (*Alert, error) {
nodes, err := aq.Limit(1).All(ctx)
nodes, err := aq.Limit(1).All(setContextOp(ctx, aq.ctx, "First"))
if err != nil {
return nil, err
}
@ -183,7 +181,7 @@ func (aq *AlertQuery) FirstX(ctx context.Context) *Alert {
// Returns a *NotFoundError when no Alert ID was found.
func (aq *AlertQuery) FirstID(ctx context.Context) (id int, err error) {
var ids []int
if ids, err = aq.Limit(1).IDs(ctx); err != nil {
if ids, err = aq.Limit(1).IDs(setContextOp(ctx, aq.ctx, "FirstID")); err != nil {
return
}
if len(ids) == 0 {
@ -206,7 +204,7 @@ func (aq *AlertQuery) FirstIDX(ctx context.Context) int {
// Returns a *NotSingularError when more than one Alert entity is found.
// Returns a *NotFoundError when no Alert entities are found.
func (aq *AlertQuery) Only(ctx context.Context) (*Alert, error) {
nodes, err := aq.Limit(2).All(ctx)
nodes, err := aq.Limit(2).All(setContextOp(ctx, aq.ctx, "Only"))
if err != nil {
return nil, err
}
@ -234,7 +232,7 @@ func (aq *AlertQuery) OnlyX(ctx context.Context) *Alert {
// Returns a *NotFoundError when no entities are found.
func (aq *AlertQuery) OnlyID(ctx context.Context) (id int, err error) {
var ids []int
if ids, err = aq.Limit(2).IDs(ctx); err != nil {
if ids, err = aq.Limit(2).IDs(setContextOp(ctx, aq.ctx, "OnlyID")); err != nil {
return
}
switch len(ids) {
@ -259,10 +257,12 @@ func (aq *AlertQuery) OnlyIDX(ctx context.Context) int {
// All executes the query and returns a list of Alerts.
func (aq *AlertQuery) All(ctx context.Context) ([]*Alert, error) {
ctx = setContextOp(ctx, aq.ctx, "All")
if err := aq.prepareQuery(ctx); err != nil {
return nil, err
}
return aq.sqlAll(ctx)
qr := querierAll[[]*Alert, *AlertQuery]()
return withInterceptors[[]*Alert](ctx, aq, qr, aq.inters)
}
// AllX is like All, but panics if an error occurs.
@ -275,9 +275,12 @@ func (aq *AlertQuery) AllX(ctx context.Context) []*Alert {
}
// IDs executes the query and returns a list of Alert IDs.
func (aq *AlertQuery) IDs(ctx context.Context) ([]int, error) {
var ids []int
if err := aq.Select(alert.FieldID).Scan(ctx, &ids); err != nil {
func (aq *AlertQuery) IDs(ctx context.Context) (ids []int, err error) {
if aq.ctx.Unique == nil && aq.path != nil {
aq.Unique(true)
}
ctx = setContextOp(ctx, aq.ctx, "IDs")
if err = aq.Select(alert.FieldID).Scan(ctx, &ids); err != nil {
return nil, err
}
return ids, nil
@ -294,10 +297,11 @@ func (aq *AlertQuery) IDsX(ctx context.Context) []int {
// Count returns the count of the given query.
func (aq *AlertQuery) Count(ctx context.Context) (int, error) {
ctx = setContextOp(ctx, aq.ctx, "Count")
if err := aq.prepareQuery(ctx); err != nil {
return 0, err
}
return aq.sqlCount(ctx)
return withInterceptors[int](ctx, aq, querierCount[*AlertQuery](), aq.inters)
}
// CountX is like Count, but panics if an error occurs.
@ -311,10 +315,15 @@ func (aq *AlertQuery) CountX(ctx context.Context) int {
// Exist returns true if the query has elements in the graph.
func (aq *AlertQuery) Exist(ctx context.Context) (bool, error) {
if err := aq.prepareQuery(ctx); err != nil {
return false, err
ctx = setContextOp(ctx, aq.ctx, "Exist")
switch _, err := aq.FirstID(ctx); {
case IsNotFound(err):
return false, nil
case err != nil:
return false, fmt.Errorf("ent: check existence: %w", err)
default:
return true, nil
}
return aq.sqlExist(ctx)
}
// ExistX is like Exist, but panics if an error occurs.
@ -334,25 +343,24 @@ func (aq *AlertQuery) Clone() *AlertQuery {
}
return &AlertQuery{
config: aq.config,
limit: aq.limit,
offset: aq.offset,
order: append([]OrderFunc{}, aq.order...),
ctx: aq.ctx.Clone(),
order: append([]alert.OrderOption{}, aq.order...),
inters: append([]Interceptor{}, aq.inters...),
predicates: append([]predicate.Alert{}, aq.predicates...),
withOwner: aq.withOwner.Clone(),
withDecisions: aq.withDecisions.Clone(),
withEvents: aq.withEvents.Clone(),
withMetas: aq.withMetas.Clone(),
// clone intermediate query.
sql: aq.sql.Clone(),
path: aq.path,
unique: aq.unique,
sql: aq.sql.Clone(),
path: aq.path,
}
}
// WithOwner tells the query-builder to eager-load the nodes that are connected to
// the "owner" edge. The optional arguments are used to configure the query builder of the edge.
func (aq *AlertQuery) WithOwner(opts ...func(*MachineQuery)) *AlertQuery {
query := &MachineQuery{config: aq.config}
query := (&MachineClient{config: aq.config}).Query()
for _, opt := range opts {
opt(query)
}
@ -363,7 +371,7 @@ func (aq *AlertQuery) WithOwner(opts ...func(*MachineQuery)) *AlertQuery {
// WithDecisions tells the query-builder to eager-load the nodes that are connected to
// the "decisions" edge. The optional arguments are used to configure the query builder of the edge.
func (aq *AlertQuery) WithDecisions(opts ...func(*DecisionQuery)) *AlertQuery {
query := &DecisionQuery{config: aq.config}
query := (&DecisionClient{config: aq.config}).Query()
for _, opt := range opts {
opt(query)
}
@ -374,7 +382,7 @@ func (aq *AlertQuery) WithDecisions(opts ...func(*DecisionQuery)) *AlertQuery {
// WithEvents tells the query-builder to eager-load the nodes that are connected to
// the "events" edge. The optional arguments are used to configure the query builder of the edge.
func (aq *AlertQuery) WithEvents(opts ...func(*EventQuery)) *AlertQuery {
query := &EventQuery{config: aq.config}
query := (&EventClient{config: aq.config}).Query()
for _, opt := range opts {
opt(query)
}
@ -385,7 +393,7 @@ func (aq *AlertQuery) WithEvents(opts ...func(*EventQuery)) *AlertQuery {
// WithMetas tells the query-builder to eager-load the nodes that are connected to
// the "metas" edge. The optional arguments are used to configure the query builder of the edge.
func (aq *AlertQuery) WithMetas(opts ...func(*MetaQuery)) *AlertQuery {
query := &MetaQuery{config: aq.config}
query := (&MetaClient{config: aq.config}).Query()
for _, opt := range opts {
opt(query)
}
@ -408,16 +416,11 @@ func (aq *AlertQuery) WithMetas(opts ...func(*MetaQuery)) *AlertQuery {
// Aggregate(ent.Count()).
// Scan(ctx, &v)
func (aq *AlertQuery) GroupBy(field string, fields ...string) *AlertGroupBy {
grbuild := &AlertGroupBy{config: aq.config}
grbuild.fields = append([]string{field}, fields...)
grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) {
if err := aq.prepareQuery(ctx); err != nil {
return nil, err
}
return aq.sqlQuery(ctx), nil
}
aq.ctx.Fields = append([]string{field}, fields...)
grbuild := &AlertGroupBy{build: aq}
grbuild.flds = &aq.ctx.Fields
grbuild.label = alert.Label
grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan
grbuild.scan = grbuild.Scan
return grbuild
}
@ -434,15 +437,30 @@ func (aq *AlertQuery) GroupBy(field string, fields ...string) *AlertGroupBy {
// Select(alert.FieldCreatedAt).
// Scan(ctx, &v)
func (aq *AlertQuery) Select(fields ...string) *AlertSelect {
aq.fields = append(aq.fields, fields...)
selbuild := &AlertSelect{AlertQuery: aq}
selbuild.label = alert.Label
selbuild.flds, selbuild.scan = &aq.fields, selbuild.Scan
return selbuild
aq.ctx.Fields = append(aq.ctx.Fields, fields...)
sbuild := &AlertSelect{AlertQuery: aq}
sbuild.label = alert.Label
sbuild.flds, sbuild.scan = &aq.ctx.Fields, sbuild.Scan
return sbuild
}
// Aggregate returns a AlertSelect configured with the given aggregations.
func (aq *AlertQuery) Aggregate(fns ...AggregateFunc) *AlertSelect {
return aq.Select().Aggregate(fns...)
}
func (aq *AlertQuery) prepareQuery(ctx context.Context) error {
for _, f := range aq.fields {
for _, inter := range aq.inters {
if inter == nil {
return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
}
if trv, ok := inter.(Traverser); ok {
if err := trv.Traverse(ctx, aq); err != nil {
return err
}
}
}
for _, f := range aq.ctx.Fields {
if !alert.ValidColumn(f) {
return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
@ -536,6 +554,9 @@ func (aq *AlertQuery) loadOwner(ctx context.Context, query *MachineQuery, nodes
}
nodeids[fk] = append(nodeids[fk], nodes[i])
}
if len(ids) == 0 {
return nil
}
query.Where(machine.IDIn(ids...))
neighbors, err := query.All(ctx)
if err != nil {
@ -562,8 +583,11 @@ func (aq *AlertQuery) loadDecisions(ctx context.Context, query *DecisionQuery, n
init(nodes[i])
}
}
if len(query.ctx.Fields) > 0 {
query.ctx.AppendFieldOnce(decision.FieldAlertDecisions)
}
query.Where(predicate.Decision(func(s *sql.Selector) {
s.Where(sql.InValues(alert.DecisionsColumn, fks...))
s.Where(sql.InValues(s.C(alert.DecisionsColumn), fks...))
}))
neighbors, err := query.All(ctx)
if err != nil {
@ -573,7 +597,7 @@ func (aq *AlertQuery) loadDecisions(ctx context.Context, query *DecisionQuery, n
fk := n.AlertDecisions
node, ok := nodeids[fk]
if !ok {
return fmt.Errorf(`unexpected foreign-key "alert_decisions" returned %v for node %v`, fk, n.ID)
return fmt.Errorf(`unexpected referenced foreign-key "alert_decisions" returned %v for node %v`, fk, n.ID)
}
assign(node, n)
}
@ -589,8 +613,11 @@ func (aq *AlertQuery) loadEvents(ctx context.Context, query *EventQuery, nodes [
init(nodes[i])
}
}
if len(query.ctx.Fields) > 0 {
query.ctx.AppendFieldOnce(event.FieldAlertEvents)
}
query.Where(predicate.Event(func(s *sql.Selector) {
s.Where(sql.InValues(alert.EventsColumn, fks...))
s.Where(sql.InValues(s.C(alert.EventsColumn), fks...))
}))
neighbors, err := query.All(ctx)
if err != nil {
@ -600,7 +627,7 @@ func (aq *AlertQuery) loadEvents(ctx context.Context, query *EventQuery, nodes [
fk := n.AlertEvents
node, ok := nodeids[fk]
if !ok {
return fmt.Errorf(`unexpected foreign-key "alert_events" returned %v for node %v`, fk, n.ID)
return fmt.Errorf(`unexpected referenced foreign-key "alert_events" returned %v for node %v`, fk, n.ID)
}
assign(node, n)
}
@ -616,8 +643,11 @@ func (aq *AlertQuery) loadMetas(ctx context.Context, query *MetaQuery, nodes []*
init(nodes[i])
}
}
if len(query.ctx.Fields) > 0 {
query.ctx.AppendFieldOnce(meta.FieldAlertMetas)
}
query.Where(predicate.Meta(func(s *sql.Selector) {
s.Where(sql.InValues(alert.MetasColumn, fks...))
s.Where(sql.InValues(s.C(alert.MetasColumn), fks...))
}))
neighbors, err := query.All(ctx)
if err != nil {
@ -627,7 +657,7 @@ func (aq *AlertQuery) loadMetas(ctx context.Context, query *MetaQuery, nodes []*
fk := n.AlertMetas
node, ok := nodeids[fk]
if !ok {
return fmt.Errorf(`unexpected foreign-key "alert_metas" returned %v for node %v`, fk, n.ID)
return fmt.Errorf(`unexpected referenced foreign-key "alert_metas" returned %v for node %v`, fk, n.ID)
}
assign(node, n)
}
@ -636,41 +666,22 @@ func (aq *AlertQuery) loadMetas(ctx context.Context, query *MetaQuery, nodes []*
func (aq *AlertQuery) sqlCount(ctx context.Context) (int, error) {
_spec := aq.querySpec()
_spec.Node.Columns = aq.fields
if len(aq.fields) > 0 {
_spec.Unique = aq.unique != nil && *aq.unique
_spec.Node.Columns = aq.ctx.Fields
if len(aq.ctx.Fields) > 0 {
_spec.Unique = aq.ctx.Unique != nil && *aq.ctx.Unique
}
return sqlgraph.CountNodes(ctx, aq.driver, _spec)
}
func (aq *AlertQuery) sqlExist(ctx context.Context) (bool, error) {
switch _, err := aq.FirstID(ctx); {
case IsNotFound(err):
return false, nil
case err != nil:
return false, fmt.Errorf("ent: check existence: %w", err)
default:
return true, nil
}
}
func (aq *AlertQuery) querySpec() *sqlgraph.QuerySpec {
_spec := &sqlgraph.QuerySpec{
Node: &sqlgraph.NodeSpec{
Table: alert.Table,
Columns: alert.Columns,
ID: &sqlgraph.FieldSpec{
Type: field.TypeInt,
Column: alert.FieldID,
},
},
From: aq.sql,
Unique: true,
}
if unique := aq.unique; unique != nil {
_spec := sqlgraph.NewQuerySpec(alert.Table, alert.Columns, sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt))
_spec.From = aq.sql
if unique := aq.ctx.Unique; unique != nil {
_spec.Unique = *unique
} else if aq.path != nil {
_spec.Unique = true
}
if fields := aq.fields; len(fields) > 0 {
if fields := aq.ctx.Fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, alert.FieldID)
for i := range fields {
@ -686,10 +697,10 @@ func (aq *AlertQuery) querySpec() *sqlgraph.QuerySpec {
}
}
}
if limit := aq.limit; limit != nil {
if limit := aq.ctx.Limit; limit != nil {
_spec.Limit = *limit
}
if offset := aq.offset; offset != nil {
if offset := aq.ctx.Offset; offset != nil {
_spec.Offset = *offset
}
if ps := aq.order; len(ps) > 0 {
@ -705,7 +716,7 @@ func (aq *AlertQuery) querySpec() *sqlgraph.QuerySpec {
func (aq *AlertQuery) sqlQuery(ctx context.Context) *sql.Selector {
builder := sql.Dialect(aq.driver.Dialect())
t1 := builder.Table(alert.Table)
columns := aq.fields
columns := aq.ctx.Fields
if len(columns) == 0 {
columns = alert.Columns
}
@ -714,7 +725,7 @@ func (aq *AlertQuery) sqlQuery(ctx context.Context) *sql.Selector {
selector = aq.sql
selector.Select(selector.Columns(columns...)...)
}
if aq.unique != nil && *aq.unique {
if aq.ctx.Unique != nil && *aq.ctx.Unique {
selector.Distinct()
}
for _, p := range aq.predicates {
@ -723,12 +734,12 @@ func (aq *AlertQuery) sqlQuery(ctx context.Context) *sql.Selector {
for _, p := range aq.order {
p(selector)
}
if offset := aq.offset; offset != nil {
if offset := aq.ctx.Offset; offset != nil {
// limit is mandatory for offset clause. We start
// with default value, and override it below if needed.
selector.Offset(*offset).Limit(math.MaxInt32)
}
if limit := aq.limit; limit != nil {
if limit := aq.ctx.Limit; limit != nil {
selector.Limit(*limit)
}
return selector
@ -736,13 +747,8 @@ func (aq *AlertQuery) sqlQuery(ctx context.Context) *sql.Selector {
// AlertGroupBy is the group-by builder for Alert entities.
type AlertGroupBy struct {
config
selector
fields []string
fns []AggregateFunc
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
build *AlertQuery
}
// Aggregate adds the given aggregation functions to the group-by query.
@ -751,74 +757,77 @@ func (agb *AlertGroupBy) Aggregate(fns ...AggregateFunc) *AlertGroupBy {
return agb
}
// Scan applies the group-by query and scans the result into the given value.
// Scan applies the selector query and scans the result into the given value.
func (agb *AlertGroupBy) Scan(ctx context.Context, v any) error {
query, err := agb.path(ctx)
if err != nil {
ctx = setContextOp(ctx, agb.build.ctx, "GroupBy")
if err := agb.build.prepareQuery(ctx); err != nil {
return err
}
agb.sql = query
return agb.sqlScan(ctx, v)
return scanWithInterceptors[*AlertQuery, *AlertGroupBy](ctx, agb.build, agb, agb.build.inters, v)
}
func (agb *AlertGroupBy) sqlScan(ctx context.Context, v any) error {
for _, f := range agb.fields {
if !alert.ValidColumn(f) {
return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)}
}
}
selector := agb.sqlQuery()
if err := selector.Err(); err != nil {
return err
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := agb.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}
func (agb *AlertGroupBy) sqlQuery() *sql.Selector {
selector := agb.sql.Select()
func (agb *AlertGroupBy) sqlScan(ctx context.Context, root *AlertQuery, v any) error {
selector := root.sqlQuery(ctx).Select()
aggregation := make([]string, 0, len(agb.fns))
for _, fn := range agb.fns {
aggregation = append(aggregation, fn(selector))
}
// If no columns were selected in a custom aggregation function, the default
// selection is the fields used for "group-by", and the aggregation functions.
if len(selector.SelectedColumns()) == 0 {
columns := make([]string, 0, len(agb.fields)+len(agb.fns))
for _, f := range agb.fields {
columns := make([]string, 0, len(*agb.flds)+len(agb.fns))
for _, f := range *agb.flds {
columns = append(columns, selector.C(f))
}
columns = append(columns, aggregation...)
selector.Select(columns...)
}
return selector.GroupBy(selector.Columns(agb.fields...)...)
selector.GroupBy(selector.Columns(*agb.flds...)...)
if err := selector.Err(); err != nil {
return err
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := agb.build.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// AlertSelect is the builder for selecting fields of Alert entities.
type AlertSelect struct {
*AlertQuery
selector
// intermediate query (i.e. traversal path).
sql *sql.Selector
}
// Aggregate adds the given aggregation functions to the selector query.
func (as *AlertSelect) Aggregate(fns ...AggregateFunc) *AlertSelect {
as.fns = append(as.fns, fns...)
return as
}
// Scan applies the selector query and scans the result into the given value.
func (as *AlertSelect) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, as.ctx, "Select")
if err := as.prepareQuery(ctx); err != nil {
return err
}
as.sql = as.AlertQuery.sqlQuery(ctx)
return as.sqlScan(ctx, v)
return scanWithInterceptors[*AlertQuery, *AlertSelect](ctx, as.AlertQuery, as, as.inters, v)
}
func (as *AlertSelect) sqlScan(ctx context.Context, v any) error {
func (as *AlertSelect) sqlScan(ctx context.Context, root *AlertQuery, v any) error {
selector := root.sqlQuery(ctx)
aggregation := make([]string, 0, len(as.fns))
for _, fn := range as.fns {
aggregation = append(aggregation, fn(selector))
}
switch n := len(*as.selector.flds); {
case n == 0 && len(aggregation) > 0:
selector.Select(aggregation...)
case n != 0 && len(aggregation) > 0:
selector.AppendSelect(aggregation...)
}
rows := &sql.Rows{}
query, args := as.sql.Query()
query, args := selector.Query()
if err := as.driver.Query(ctx, query, args, rows); err != nil {
return err
}

File diff suppressed because it is too large Load diff

View file

@ -7,6 +7,7 @@ import (
"strings"
"time"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/bouncer"
)
@ -37,7 +38,8 @@ type Bouncer struct {
// LastPull holds the value of the "last_pull" field.
LastPull time.Time `json:"last_pull"`
// AuthType holds the value of the "auth_type" field.
AuthType string `json:"auth_type"`
AuthType string `json:"auth_type"`
selectValues sql.SelectValues
}
// scanValues returns the types for scanning values from sql.Rows.
@ -54,7 +56,7 @@ func (*Bouncer) scanValues(columns []string) ([]any, error) {
case bouncer.FieldCreatedAt, bouncer.FieldUpdatedAt, bouncer.FieldUntil, bouncer.FieldLastPull:
values[i] = new(sql.NullTime)
default:
return nil, fmt.Errorf("unexpected column %q for type Bouncer", columns[i])
values[i] = new(sql.UnknownType)
}
}
return values, nil
@ -142,16 +144,24 @@ func (b *Bouncer) assignValues(columns []string, values []any) error {
} else if value.Valid {
b.AuthType = value.String
}
default:
b.selectValues.Set(columns[i], values[i])
}
}
return nil
}
// Value returns the ent.Value that was dynamically selected and assigned to the Bouncer.
// This includes values selected through modifiers, order, etc.
func (b *Bouncer) Value(name string) (ent.Value, error) {
return b.selectValues.Get(name)
}
// Update returns a builder for updating this Bouncer.
// Note that you need to call Bouncer.Unwrap() before calling this method if this Bouncer
// was returned from a transaction, and the transaction was committed or rolled back.
func (b *Bouncer) Update() *BouncerUpdateOne {
return (&BouncerClient{config: b.config}).UpdateOne(b)
return NewBouncerClient(b.config).UpdateOne(b)
}
// Unwrap unwraps the Bouncer entity that was returned from a transaction after it was closed,
@ -212,9 +222,3 @@ func (b *Bouncer) String() string {
// Bouncers is a parsable slice of Bouncer.
type Bouncers []*Bouncer
func (b Bouncers) config(cfg config) {
for _i := range b {
b[_i].config = cfg
}
}

View file

@ -4,6 +4,8 @@ package bouncer
import (
"time"
"entgo.io/ent/dialect/sql"
)
const (
@ -81,3 +83,66 @@ var (
// DefaultAuthType holds the default value on creation for the "auth_type" field.
DefaultAuthType string
)
// OrderOption defines the ordering options for the Bouncer queries.
type OrderOption func(*sql.Selector)
// ByID orders the results by the id field.
func ByID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldID, opts...).ToFunc()
}
// ByCreatedAt orders the results by the created_at field.
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
}
// ByUpdatedAt orders the results by the updated_at field.
func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
}
// ByName orders the results by the name field.
func ByName(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldName, opts...).ToFunc()
}
// ByAPIKey orders the results by the api_key field.
func ByAPIKey(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldAPIKey, opts...).ToFunc()
}
// ByRevoked orders the results by the revoked field.
func ByRevoked(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldRevoked, opts...).ToFunc()
}
// ByIPAddress orders the results by the ip_address field.
func ByIPAddress(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldIPAddress, opts...).ToFunc()
}
// ByType orders the results by the type field.
func ByType(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldType, opts...).ToFunc()
}
// ByVersion orders the results by the version field.
func ByVersion(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldVersion, opts...).ToFunc()
}
// ByUntil orders the results by the until field.
func ByUntil(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUntil, opts...).ToFunc()
}
// ByLastPull orders the results by the last_pull field.
func ByLastPull(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldLastPull, opts...).ToFunc()
}
// ByAuthType orders the results by the auth_type field.
func ByAuthType(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldAuthType, opts...).ToFunc()
}

File diff suppressed because it is too large Load diff

View file

@ -157,50 +157,8 @@ func (bc *BouncerCreate) Mutation() *BouncerMutation {
// Save creates the Bouncer in the database.
func (bc *BouncerCreate) Save(ctx context.Context) (*Bouncer, error) {
var (
err error
node *Bouncer
)
bc.defaults()
if len(bc.hooks) == 0 {
if err = bc.check(); err != nil {
return nil, err
}
node, err = bc.sqlSave(ctx)
} else {
var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
mutation, ok := m.(*BouncerMutation)
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T", m)
}
if err = bc.check(); err != nil {
return nil, err
}
bc.mutation = mutation
if node, err = bc.sqlSave(ctx); err != nil {
return nil, err
}
mutation.id = &node.ID
mutation.done = true
return node, err
})
for i := len(bc.hooks) - 1; i >= 0; i-- {
if bc.hooks[i] == nil {
return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)")
}
mut = bc.hooks[i](mut)
}
v, err := mut.Mutate(ctx, bc.mutation)
if err != nil {
return nil, err
}
nv, ok := v.(*Bouncer)
if !ok {
return nil, fmt.Errorf("unexpected node type %T returned from BouncerMutation", v)
}
node = nv
}
return node, err
return withHooks(ctx, bc.sqlSave, bc.mutation, bc.hooks)
}
// SaveX calls Save and panics if Save returns an error.
@ -274,6 +232,9 @@ func (bc *BouncerCreate) check() error {
}
func (bc *BouncerCreate) sqlSave(ctx context.Context) (*Bouncer, error) {
if err := bc.check(); err != nil {
return nil, err
}
_node, _spec := bc.createSpec()
if err := sqlgraph.CreateNode(ctx, bc.driver, _spec); err != nil {
if sqlgraph.IsConstraintError(err) {
@ -283,106 +244,58 @@ func (bc *BouncerCreate) sqlSave(ctx context.Context) (*Bouncer, error) {
}
id := _spec.ID.Value.(int64)
_node.ID = int(id)
bc.mutation.id = &_node.ID
bc.mutation.done = true
return _node, nil
}
func (bc *BouncerCreate) createSpec() (*Bouncer, *sqlgraph.CreateSpec) {
var (
_node = &Bouncer{config: bc.config}
_spec = &sqlgraph.CreateSpec{
Table: bouncer.Table,
ID: &sqlgraph.FieldSpec{
Type: field.TypeInt,
Column: bouncer.FieldID,
},
}
_spec = sqlgraph.NewCreateSpec(bouncer.Table, sqlgraph.NewFieldSpec(bouncer.FieldID, field.TypeInt))
)
if value, ok := bc.mutation.CreatedAt(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Value: value,
Column: bouncer.FieldCreatedAt,
})
_spec.SetField(bouncer.FieldCreatedAt, field.TypeTime, value)
_node.CreatedAt = &value
}
if value, ok := bc.mutation.UpdatedAt(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Value: value,
Column: bouncer.FieldUpdatedAt,
})
_spec.SetField(bouncer.FieldUpdatedAt, field.TypeTime, value)
_node.UpdatedAt = &value
}
if value, ok := bc.mutation.Name(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: bouncer.FieldName,
})
_spec.SetField(bouncer.FieldName, field.TypeString, value)
_node.Name = value
}
if value, ok := bc.mutation.APIKey(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: bouncer.FieldAPIKey,
})
_spec.SetField(bouncer.FieldAPIKey, field.TypeString, value)
_node.APIKey = value
}
if value, ok := bc.mutation.Revoked(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeBool,
Value: value,
Column: bouncer.FieldRevoked,
})
_spec.SetField(bouncer.FieldRevoked, field.TypeBool, value)
_node.Revoked = value
}
if value, ok := bc.mutation.IPAddress(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: bouncer.FieldIPAddress,
})
_spec.SetField(bouncer.FieldIPAddress, field.TypeString, value)
_node.IPAddress = value
}
if value, ok := bc.mutation.GetType(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: bouncer.FieldType,
})
_spec.SetField(bouncer.FieldType, field.TypeString, value)
_node.Type = value
}
if value, ok := bc.mutation.Version(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: bouncer.FieldVersion,
})
_spec.SetField(bouncer.FieldVersion, field.TypeString, value)
_node.Version = value
}
if value, ok := bc.mutation.Until(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Value: value,
Column: bouncer.FieldUntil,
})
_spec.SetField(bouncer.FieldUntil, field.TypeTime, value)
_node.Until = value
}
if value, ok := bc.mutation.LastPull(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Value: value,
Column: bouncer.FieldLastPull,
})
_spec.SetField(bouncer.FieldLastPull, field.TypeTime, value)
_node.LastPull = value
}
if value, ok := bc.mutation.AuthType(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: bouncer.FieldAuthType,
})
_spec.SetField(bouncer.FieldAuthType, field.TypeString, value)
_node.AuthType = value
}
return _node, _spec
@ -391,11 +304,15 @@ func (bc *BouncerCreate) createSpec() (*Bouncer, *sqlgraph.CreateSpec) {
// BouncerCreateBulk is the builder for creating many Bouncer entities in bulk.
type BouncerCreateBulk struct {
config
err error
builders []*BouncerCreate
}
// Save creates the Bouncer entities in the database.
func (bcb *BouncerCreateBulk) Save(ctx context.Context) ([]*Bouncer, error) {
if bcb.err != nil {
return nil, bcb.err
}
specs := make([]*sqlgraph.CreateSpec, len(bcb.builders))
nodes := make([]*Bouncer, len(bcb.builders))
mutators := make([]Mutator, len(bcb.builders))
@ -412,8 +329,8 @@ func (bcb *BouncerCreateBulk) Save(ctx context.Context) ([]*Bouncer, error) {
return nil, err
}
builder.mutation = mutation
nodes[i], specs[i] = builder.createSpec()
var err error
nodes[i], specs[i] = builder.createSpec()
if i < len(mutators)-1 {
_, err = mutators[i+1].Mutate(root, bcb.builders[i+1].mutation)
} else {

View file

@ -4,7 +4,6 @@ package ent
import (
"context"
"fmt"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
@ -28,34 +27,7 @@ func (bd *BouncerDelete) Where(ps ...predicate.Bouncer) *BouncerDelete {
// Exec executes the deletion query and returns how many vertices were deleted.
func (bd *BouncerDelete) Exec(ctx context.Context) (int, error) {
var (
err error
affected int
)
if len(bd.hooks) == 0 {
affected, err = bd.sqlExec(ctx)
} else {
var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
mutation, ok := m.(*BouncerMutation)
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T", m)
}
bd.mutation = mutation
affected, err = bd.sqlExec(ctx)
mutation.done = true
return affected, err
})
for i := len(bd.hooks) - 1; i >= 0; i-- {
if bd.hooks[i] == nil {
return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)")
}
mut = bd.hooks[i](mut)
}
if _, err := mut.Mutate(ctx, bd.mutation); err != nil {
return 0, err
}
}
return affected, err
return withHooks(ctx, bd.sqlExec, bd.mutation, bd.hooks)
}
// ExecX is like Exec, but panics if an error occurs.
@ -68,15 +40,7 @@ func (bd *BouncerDelete) ExecX(ctx context.Context) int {
}
func (bd *BouncerDelete) sqlExec(ctx context.Context) (int, error) {
_spec := &sqlgraph.DeleteSpec{
Node: &sqlgraph.NodeSpec{
Table: bouncer.Table,
ID: &sqlgraph.FieldSpec{
Type: field.TypeInt,
Column: bouncer.FieldID,
},
},
}
_spec := sqlgraph.NewDeleteSpec(bouncer.Table, sqlgraph.NewFieldSpec(bouncer.FieldID, field.TypeInt))
if ps := bd.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
@ -88,6 +52,7 @@ func (bd *BouncerDelete) sqlExec(ctx context.Context) (int, error) {
if err != nil && sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
bd.mutation.done = true
return affected, err
}
@ -96,6 +61,12 @@ type BouncerDeleteOne struct {
bd *BouncerDelete
}
// Where appends a list predicates to the BouncerDelete builder.
func (bdo *BouncerDeleteOne) Where(ps ...predicate.Bouncer) *BouncerDeleteOne {
bdo.bd.mutation.Where(ps...)
return bdo
}
// Exec executes the deletion query.
func (bdo *BouncerDeleteOne) Exec(ctx context.Context) error {
n, err := bdo.bd.Exec(ctx)
@ -111,5 +82,7 @@ func (bdo *BouncerDeleteOne) Exec(ctx context.Context) error {
// ExecX is like Exec, but panics if an error occurs.
func (bdo *BouncerDeleteOne) ExecX(ctx context.Context) {
bdo.bd.ExecX(ctx)
if err := bdo.Exec(ctx); err != nil {
panic(err)
}
}

View file

@ -17,11 +17,9 @@ import (
// BouncerQuery is the builder for querying Bouncer entities.
type BouncerQuery struct {
config
limit *int
offset *int
unique *bool
order []OrderFunc
fields []string
ctx *QueryContext
order []bouncer.OrderOption
inters []Interceptor
predicates []predicate.Bouncer
// intermediate query (i.e. traversal path).
sql *sql.Selector
@ -34,27 +32,27 @@ func (bq *BouncerQuery) Where(ps ...predicate.Bouncer) *BouncerQuery {
return bq
}
// Limit adds a limit step to the query.
// Limit the number of records to be returned by this query.
func (bq *BouncerQuery) Limit(limit int) *BouncerQuery {
bq.limit = &limit
bq.ctx.Limit = &limit
return bq
}
// Offset adds an offset step to the query.
// Offset to start from.
func (bq *BouncerQuery) Offset(offset int) *BouncerQuery {
bq.offset = &offset
bq.ctx.Offset = &offset
return bq
}
// Unique configures the query builder to filter duplicate records on query.
// By default, unique is set to true, and can be disabled using this method.
func (bq *BouncerQuery) Unique(unique bool) *BouncerQuery {
bq.unique = &unique
bq.ctx.Unique = &unique
return bq
}
// Order adds an order step to the query.
func (bq *BouncerQuery) Order(o ...OrderFunc) *BouncerQuery {
// Order specifies how the records should be ordered.
func (bq *BouncerQuery) Order(o ...bouncer.OrderOption) *BouncerQuery {
bq.order = append(bq.order, o...)
return bq
}
@ -62,7 +60,7 @@ func (bq *BouncerQuery) Order(o ...OrderFunc) *BouncerQuery {
// First returns the first Bouncer entity from the query.
// Returns a *NotFoundError when no Bouncer was found.
func (bq *BouncerQuery) First(ctx context.Context) (*Bouncer, error) {
nodes, err := bq.Limit(1).All(ctx)
nodes, err := bq.Limit(1).All(setContextOp(ctx, bq.ctx, "First"))
if err != nil {
return nil, err
}
@ -85,7 +83,7 @@ func (bq *BouncerQuery) FirstX(ctx context.Context) *Bouncer {
// Returns a *NotFoundError when no Bouncer ID was found.
func (bq *BouncerQuery) FirstID(ctx context.Context) (id int, err error) {
var ids []int
if ids, err = bq.Limit(1).IDs(ctx); err != nil {
if ids, err = bq.Limit(1).IDs(setContextOp(ctx, bq.ctx, "FirstID")); err != nil {
return
}
if len(ids) == 0 {
@ -108,7 +106,7 @@ func (bq *BouncerQuery) FirstIDX(ctx context.Context) int {
// Returns a *NotSingularError when more than one Bouncer entity is found.
// Returns a *NotFoundError when no Bouncer entities are found.
func (bq *BouncerQuery) Only(ctx context.Context) (*Bouncer, error) {
nodes, err := bq.Limit(2).All(ctx)
nodes, err := bq.Limit(2).All(setContextOp(ctx, bq.ctx, "Only"))
if err != nil {
return nil, err
}
@ -136,7 +134,7 @@ func (bq *BouncerQuery) OnlyX(ctx context.Context) *Bouncer {
// Returns a *NotFoundError when no entities are found.
func (bq *BouncerQuery) OnlyID(ctx context.Context) (id int, err error) {
var ids []int
if ids, err = bq.Limit(2).IDs(ctx); err != nil {
if ids, err = bq.Limit(2).IDs(setContextOp(ctx, bq.ctx, "OnlyID")); err != nil {
return
}
switch len(ids) {
@ -161,10 +159,12 @@ func (bq *BouncerQuery) OnlyIDX(ctx context.Context) int {
// All executes the query and returns a list of Bouncers.
func (bq *BouncerQuery) All(ctx context.Context) ([]*Bouncer, error) {
ctx = setContextOp(ctx, bq.ctx, "All")
if err := bq.prepareQuery(ctx); err != nil {
return nil, err
}
return bq.sqlAll(ctx)
qr := querierAll[[]*Bouncer, *BouncerQuery]()
return withInterceptors[[]*Bouncer](ctx, bq, qr, bq.inters)
}
// AllX is like All, but panics if an error occurs.
@ -177,9 +177,12 @@ func (bq *BouncerQuery) AllX(ctx context.Context) []*Bouncer {
}
// IDs executes the query and returns a list of Bouncer IDs.
func (bq *BouncerQuery) IDs(ctx context.Context) ([]int, error) {
var ids []int
if err := bq.Select(bouncer.FieldID).Scan(ctx, &ids); err != nil {
func (bq *BouncerQuery) IDs(ctx context.Context) (ids []int, err error) {
if bq.ctx.Unique == nil && bq.path != nil {
bq.Unique(true)
}
ctx = setContextOp(ctx, bq.ctx, "IDs")
if err = bq.Select(bouncer.FieldID).Scan(ctx, &ids); err != nil {
return nil, err
}
return ids, nil
@ -196,10 +199,11 @@ func (bq *BouncerQuery) IDsX(ctx context.Context) []int {
// Count returns the count of the given query.
func (bq *BouncerQuery) Count(ctx context.Context) (int, error) {
ctx = setContextOp(ctx, bq.ctx, "Count")
if err := bq.prepareQuery(ctx); err != nil {
return 0, err
}
return bq.sqlCount(ctx)
return withInterceptors[int](ctx, bq, querierCount[*BouncerQuery](), bq.inters)
}
// CountX is like Count, but panics if an error occurs.
@ -213,10 +217,15 @@ func (bq *BouncerQuery) CountX(ctx context.Context) int {
// Exist returns true if the query has elements in the graph.
func (bq *BouncerQuery) Exist(ctx context.Context) (bool, error) {
if err := bq.prepareQuery(ctx); err != nil {
return false, err
ctx = setContextOp(ctx, bq.ctx, "Exist")
switch _, err := bq.FirstID(ctx); {
case IsNotFound(err):
return false, nil
case err != nil:
return false, fmt.Errorf("ent: check existence: %w", err)
default:
return true, nil
}
return bq.sqlExist(ctx)
}
// ExistX is like Exist, but panics if an error occurs.
@ -236,14 +245,13 @@ func (bq *BouncerQuery) Clone() *BouncerQuery {
}
return &BouncerQuery{
config: bq.config,
limit: bq.limit,
offset: bq.offset,
order: append([]OrderFunc{}, bq.order...),
ctx: bq.ctx.Clone(),
order: append([]bouncer.OrderOption{}, bq.order...),
inters: append([]Interceptor{}, bq.inters...),
predicates: append([]predicate.Bouncer{}, bq.predicates...),
// clone intermediate query.
sql: bq.sql.Clone(),
path: bq.path,
unique: bq.unique,
sql: bq.sql.Clone(),
path: bq.path,
}
}
@ -262,16 +270,11 @@ func (bq *BouncerQuery) Clone() *BouncerQuery {
// Aggregate(ent.Count()).
// Scan(ctx, &v)
func (bq *BouncerQuery) GroupBy(field string, fields ...string) *BouncerGroupBy {
grbuild := &BouncerGroupBy{config: bq.config}
grbuild.fields = append([]string{field}, fields...)
grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) {
if err := bq.prepareQuery(ctx); err != nil {
return nil, err
}
return bq.sqlQuery(ctx), nil
}
bq.ctx.Fields = append([]string{field}, fields...)
grbuild := &BouncerGroupBy{build: bq}
grbuild.flds = &bq.ctx.Fields
grbuild.label = bouncer.Label
grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan
grbuild.scan = grbuild.Scan
return grbuild
}
@ -288,15 +291,30 @@ func (bq *BouncerQuery) GroupBy(field string, fields ...string) *BouncerGroupBy
// Select(bouncer.FieldCreatedAt).
// Scan(ctx, &v)
func (bq *BouncerQuery) Select(fields ...string) *BouncerSelect {
bq.fields = append(bq.fields, fields...)
selbuild := &BouncerSelect{BouncerQuery: bq}
selbuild.label = bouncer.Label
selbuild.flds, selbuild.scan = &bq.fields, selbuild.Scan
return selbuild
bq.ctx.Fields = append(bq.ctx.Fields, fields...)
sbuild := &BouncerSelect{BouncerQuery: bq}
sbuild.label = bouncer.Label
sbuild.flds, sbuild.scan = &bq.ctx.Fields, sbuild.Scan
return sbuild
}
// Aggregate returns a BouncerSelect configured with the given aggregations.
func (bq *BouncerQuery) Aggregate(fns ...AggregateFunc) *BouncerSelect {
return bq.Select().Aggregate(fns...)
}
func (bq *BouncerQuery) prepareQuery(ctx context.Context) error {
for _, f := range bq.fields {
for _, inter := range bq.inters {
if inter == nil {
return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
}
if trv, ok := inter.(Traverser); ok {
if err := trv.Traverse(ctx, bq); err != nil {
return err
}
}
}
for _, f := range bq.ctx.Fields {
if !bouncer.ValidColumn(f) {
return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
@ -338,41 +356,22 @@ func (bq *BouncerQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Boun
func (bq *BouncerQuery) sqlCount(ctx context.Context) (int, error) {
_spec := bq.querySpec()
_spec.Node.Columns = bq.fields
if len(bq.fields) > 0 {
_spec.Unique = bq.unique != nil && *bq.unique
_spec.Node.Columns = bq.ctx.Fields
if len(bq.ctx.Fields) > 0 {
_spec.Unique = bq.ctx.Unique != nil && *bq.ctx.Unique
}
return sqlgraph.CountNodes(ctx, bq.driver, _spec)
}
func (bq *BouncerQuery) sqlExist(ctx context.Context) (bool, error) {
switch _, err := bq.FirstID(ctx); {
case IsNotFound(err):
return false, nil
case err != nil:
return false, fmt.Errorf("ent: check existence: %w", err)
default:
return true, nil
}
}
func (bq *BouncerQuery) querySpec() *sqlgraph.QuerySpec {
_spec := &sqlgraph.QuerySpec{
Node: &sqlgraph.NodeSpec{
Table: bouncer.Table,
Columns: bouncer.Columns,
ID: &sqlgraph.FieldSpec{
Type: field.TypeInt,
Column: bouncer.FieldID,
},
},
From: bq.sql,
Unique: true,
}
if unique := bq.unique; unique != nil {
_spec := sqlgraph.NewQuerySpec(bouncer.Table, bouncer.Columns, sqlgraph.NewFieldSpec(bouncer.FieldID, field.TypeInt))
_spec.From = bq.sql
if unique := bq.ctx.Unique; unique != nil {
_spec.Unique = *unique
} else if bq.path != nil {
_spec.Unique = true
}
if fields := bq.fields; len(fields) > 0 {
if fields := bq.ctx.Fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, bouncer.FieldID)
for i := range fields {
@ -388,10 +387,10 @@ func (bq *BouncerQuery) querySpec() *sqlgraph.QuerySpec {
}
}
}
if limit := bq.limit; limit != nil {
if limit := bq.ctx.Limit; limit != nil {
_spec.Limit = *limit
}
if offset := bq.offset; offset != nil {
if offset := bq.ctx.Offset; offset != nil {
_spec.Offset = *offset
}
if ps := bq.order; len(ps) > 0 {
@ -407,7 +406,7 @@ func (bq *BouncerQuery) querySpec() *sqlgraph.QuerySpec {
func (bq *BouncerQuery) sqlQuery(ctx context.Context) *sql.Selector {
builder := sql.Dialect(bq.driver.Dialect())
t1 := builder.Table(bouncer.Table)
columns := bq.fields
columns := bq.ctx.Fields
if len(columns) == 0 {
columns = bouncer.Columns
}
@ -416,7 +415,7 @@ func (bq *BouncerQuery) sqlQuery(ctx context.Context) *sql.Selector {
selector = bq.sql
selector.Select(selector.Columns(columns...)...)
}
if bq.unique != nil && *bq.unique {
if bq.ctx.Unique != nil && *bq.ctx.Unique {
selector.Distinct()
}
for _, p := range bq.predicates {
@ -425,12 +424,12 @@ func (bq *BouncerQuery) sqlQuery(ctx context.Context) *sql.Selector {
for _, p := range bq.order {
p(selector)
}
if offset := bq.offset; offset != nil {
if offset := bq.ctx.Offset; offset != nil {
// limit is mandatory for offset clause. We start
// with default value, and override it below if needed.
selector.Offset(*offset).Limit(math.MaxInt32)
}
if limit := bq.limit; limit != nil {
if limit := bq.ctx.Limit; limit != nil {
selector.Limit(*limit)
}
return selector
@ -438,13 +437,8 @@ func (bq *BouncerQuery) sqlQuery(ctx context.Context) *sql.Selector {
// BouncerGroupBy is the group-by builder for Bouncer entities.
type BouncerGroupBy struct {
config
selector
fields []string
fns []AggregateFunc
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
build *BouncerQuery
}
// Aggregate adds the given aggregation functions to the group-by query.
@ -453,74 +447,77 @@ func (bgb *BouncerGroupBy) Aggregate(fns ...AggregateFunc) *BouncerGroupBy {
return bgb
}
// Scan applies the group-by query and scans the result into the given value.
// Scan applies the selector query and scans the result into the given value.
func (bgb *BouncerGroupBy) Scan(ctx context.Context, v any) error {
query, err := bgb.path(ctx)
if err != nil {
ctx = setContextOp(ctx, bgb.build.ctx, "GroupBy")
if err := bgb.build.prepareQuery(ctx); err != nil {
return err
}
bgb.sql = query
return bgb.sqlScan(ctx, v)
return scanWithInterceptors[*BouncerQuery, *BouncerGroupBy](ctx, bgb.build, bgb, bgb.build.inters, v)
}
func (bgb *BouncerGroupBy) sqlScan(ctx context.Context, v any) error {
for _, f := range bgb.fields {
if !bouncer.ValidColumn(f) {
return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)}
}
}
selector := bgb.sqlQuery()
if err := selector.Err(); err != nil {
return err
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := bgb.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}
func (bgb *BouncerGroupBy) sqlQuery() *sql.Selector {
selector := bgb.sql.Select()
func (bgb *BouncerGroupBy) sqlScan(ctx context.Context, root *BouncerQuery, v any) error {
selector := root.sqlQuery(ctx).Select()
aggregation := make([]string, 0, len(bgb.fns))
for _, fn := range bgb.fns {
aggregation = append(aggregation, fn(selector))
}
// If no columns were selected in a custom aggregation function, the default
// selection is the fields used for "group-by", and the aggregation functions.
if len(selector.SelectedColumns()) == 0 {
columns := make([]string, 0, len(bgb.fields)+len(bgb.fns))
for _, f := range bgb.fields {
columns := make([]string, 0, len(*bgb.flds)+len(bgb.fns))
for _, f := range *bgb.flds {
columns = append(columns, selector.C(f))
}
columns = append(columns, aggregation...)
selector.Select(columns...)
}
return selector.GroupBy(selector.Columns(bgb.fields...)...)
selector.GroupBy(selector.Columns(*bgb.flds...)...)
if err := selector.Err(); err != nil {
return err
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := bgb.build.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// BouncerSelect is the builder for selecting fields of Bouncer entities.
type BouncerSelect struct {
*BouncerQuery
selector
// intermediate query (i.e. traversal path).
sql *sql.Selector
}
// Aggregate adds the given aggregation functions to the selector query.
func (bs *BouncerSelect) Aggregate(fns ...AggregateFunc) *BouncerSelect {
bs.fns = append(bs.fns, fns...)
return bs
}
// Scan applies the selector query and scans the result into the given value.
func (bs *BouncerSelect) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, bs.ctx, "Select")
if err := bs.prepareQuery(ctx); err != nil {
return err
}
bs.sql = bs.BouncerQuery.sqlQuery(ctx)
return bs.sqlScan(ctx, v)
return scanWithInterceptors[*BouncerQuery, *BouncerSelect](ctx, bs.BouncerQuery, bs, bs.inters, v)
}
func (bs *BouncerSelect) sqlScan(ctx context.Context, v any) error {
func (bs *BouncerSelect) sqlScan(ctx context.Context, root *BouncerQuery, v any) error {
selector := root.sqlQuery(ctx)
aggregation := make([]string, 0, len(bs.fns))
for _, fn := range bs.fns {
aggregation = append(aggregation, fn(selector))
}
switch n := len(*bs.selector.flds); {
case n == 0 && len(aggregation) > 0:
selector.Select(aggregation...)
case n != 0 && len(aggregation) > 0:
selector.AppendSelect(aggregation...)
}
rows := &sql.Rows{}
query, args := bs.sql.Query()
query, args := selector.Query()
if err := bs.driver.Query(ctx, query, args, rows); err != nil {
return err
}

View file

@ -185,35 +185,8 @@ func (bu *BouncerUpdate) Mutation() *BouncerMutation {
// Save executes the query and returns the number of nodes affected by the update operation.
func (bu *BouncerUpdate) Save(ctx context.Context) (int, error) {
var (
err error
affected int
)
bu.defaults()
if len(bu.hooks) == 0 {
affected, err = bu.sqlSave(ctx)
} else {
var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
mutation, ok := m.(*BouncerMutation)
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T", m)
}
bu.mutation = mutation
affected, err = bu.sqlSave(ctx)
mutation.done = true
return affected, err
})
for i := len(bu.hooks) - 1; i >= 0; i-- {
if bu.hooks[i] == nil {
return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)")
}
mut = bu.hooks[i](mut)
}
if _, err := mut.Mutate(ctx, bu.mutation); err != nil {
return 0, err
}
}
return affected, err
return withHooks(ctx, bu.sqlSave, bu.mutation, bu.hooks)
}
// SaveX is like Save, but panics if an error occurs.
@ -251,16 +224,7 @@ func (bu *BouncerUpdate) defaults() {
}
func (bu *BouncerUpdate) sqlSave(ctx context.Context) (n int, err error) {
_spec := &sqlgraph.UpdateSpec{
Node: &sqlgraph.NodeSpec{
Table: bouncer.Table,
Columns: bouncer.Columns,
ID: &sqlgraph.FieldSpec{
Type: field.TypeInt,
Column: bouncer.FieldID,
},
},
}
_spec := sqlgraph.NewUpdateSpec(bouncer.Table, bouncer.Columns, sqlgraph.NewFieldSpec(bouncer.FieldID, field.TypeInt))
if ps := bu.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
@ -269,117 +233,55 @@ func (bu *BouncerUpdate) sqlSave(ctx context.Context) (n int, err error) {
}
}
if value, ok := bu.mutation.CreatedAt(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Value: value,
Column: bouncer.FieldCreatedAt,
})
_spec.SetField(bouncer.FieldCreatedAt, field.TypeTime, value)
}
if bu.mutation.CreatedAtCleared() {
_spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Column: bouncer.FieldCreatedAt,
})
_spec.ClearField(bouncer.FieldCreatedAt, field.TypeTime)
}
if value, ok := bu.mutation.UpdatedAt(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Value: value,
Column: bouncer.FieldUpdatedAt,
})
_spec.SetField(bouncer.FieldUpdatedAt, field.TypeTime, value)
}
if bu.mutation.UpdatedAtCleared() {
_spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Column: bouncer.FieldUpdatedAt,
})
_spec.ClearField(bouncer.FieldUpdatedAt, field.TypeTime)
}
if value, ok := bu.mutation.Name(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: bouncer.FieldName,
})
_spec.SetField(bouncer.FieldName, field.TypeString, value)
}
if value, ok := bu.mutation.APIKey(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: bouncer.FieldAPIKey,
})
_spec.SetField(bouncer.FieldAPIKey, field.TypeString, value)
}
if value, ok := bu.mutation.Revoked(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeBool,
Value: value,
Column: bouncer.FieldRevoked,
})
_spec.SetField(bouncer.FieldRevoked, field.TypeBool, value)
}
if value, ok := bu.mutation.IPAddress(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: bouncer.FieldIPAddress,
})
_spec.SetField(bouncer.FieldIPAddress, field.TypeString, value)
}
if bu.mutation.IPAddressCleared() {
_spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{
Type: field.TypeString,
Column: bouncer.FieldIPAddress,
})
_spec.ClearField(bouncer.FieldIPAddress, field.TypeString)
}
if value, ok := bu.mutation.GetType(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: bouncer.FieldType,
})
_spec.SetField(bouncer.FieldType, field.TypeString, value)
}
if bu.mutation.TypeCleared() {
_spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{
Type: field.TypeString,
Column: bouncer.FieldType,
})
_spec.ClearField(bouncer.FieldType, field.TypeString)
}
if value, ok := bu.mutation.Version(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: bouncer.FieldVersion,
})
_spec.SetField(bouncer.FieldVersion, field.TypeString, value)
}
if bu.mutation.VersionCleared() {
_spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{
Type: field.TypeString,
Column: bouncer.FieldVersion,
})
_spec.ClearField(bouncer.FieldVersion, field.TypeString)
}
if value, ok := bu.mutation.Until(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Value: value,
Column: bouncer.FieldUntil,
})
_spec.SetField(bouncer.FieldUntil, field.TypeTime, value)
}
if bu.mutation.UntilCleared() {
_spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Column: bouncer.FieldUntil,
})
_spec.ClearField(bouncer.FieldUntil, field.TypeTime)
}
if value, ok := bu.mutation.LastPull(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Value: value,
Column: bouncer.FieldLastPull,
})
_spec.SetField(bouncer.FieldLastPull, field.TypeTime, value)
}
if value, ok := bu.mutation.AuthType(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: bouncer.FieldAuthType,
})
_spec.SetField(bouncer.FieldAuthType, field.TypeString, value)
}
if n, err = sqlgraph.UpdateNodes(ctx, bu.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
@ -389,6 +291,7 @@ func (bu *BouncerUpdate) sqlSave(ctx context.Context) (n int, err error) {
}
return 0, err
}
bu.mutation.done = true
return n, nil
}
@ -555,6 +458,12 @@ func (buo *BouncerUpdateOne) Mutation() *BouncerMutation {
return buo.mutation
}
// Where appends a list predicates to the BouncerUpdate builder.
func (buo *BouncerUpdateOne) Where(ps ...predicate.Bouncer) *BouncerUpdateOne {
buo.mutation.Where(ps...)
return buo
}
// Select allows selecting one or more fields (columns) of the returned entity.
// The default is selecting all fields defined in the entity schema.
func (buo *BouncerUpdateOne) Select(field string, fields ...string) *BouncerUpdateOne {
@ -564,41 +473,8 @@ func (buo *BouncerUpdateOne) Select(field string, fields ...string) *BouncerUpda
// Save executes the query and returns the updated Bouncer entity.
func (buo *BouncerUpdateOne) Save(ctx context.Context) (*Bouncer, error) {
var (
err error
node *Bouncer
)
buo.defaults()
if len(buo.hooks) == 0 {
node, err = buo.sqlSave(ctx)
} else {
var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
mutation, ok := m.(*BouncerMutation)
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T", m)
}
buo.mutation = mutation
node, err = buo.sqlSave(ctx)
mutation.done = true
return node, err
})
for i := len(buo.hooks) - 1; i >= 0; i-- {
if buo.hooks[i] == nil {
return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)")
}
mut = buo.hooks[i](mut)
}
v, err := mut.Mutate(ctx, buo.mutation)
if err != nil {
return nil, err
}
nv, ok := v.(*Bouncer)
if !ok {
return nil, fmt.Errorf("unexpected node type %T returned from BouncerMutation", v)
}
node = nv
}
return node, err
return withHooks(ctx, buo.sqlSave, buo.mutation, buo.hooks)
}
// SaveX is like Save, but panics if an error occurs.
@ -636,16 +512,7 @@ func (buo *BouncerUpdateOne) defaults() {
}
func (buo *BouncerUpdateOne) sqlSave(ctx context.Context) (_node *Bouncer, err error) {
_spec := &sqlgraph.UpdateSpec{
Node: &sqlgraph.NodeSpec{
Table: bouncer.Table,
Columns: bouncer.Columns,
ID: &sqlgraph.FieldSpec{
Type: field.TypeInt,
Column: bouncer.FieldID,
},
},
}
_spec := sqlgraph.NewUpdateSpec(bouncer.Table, bouncer.Columns, sqlgraph.NewFieldSpec(bouncer.FieldID, field.TypeInt))
id, ok := buo.mutation.ID()
if !ok {
return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Bouncer.id" for update`)}
@ -671,117 +538,55 @@ func (buo *BouncerUpdateOne) sqlSave(ctx context.Context) (_node *Bouncer, err e
}
}
if value, ok := buo.mutation.CreatedAt(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Value: value,
Column: bouncer.FieldCreatedAt,
})
_spec.SetField(bouncer.FieldCreatedAt, field.TypeTime, value)
}
if buo.mutation.CreatedAtCleared() {
_spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Column: bouncer.FieldCreatedAt,
})
_spec.ClearField(bouncer.FieldCreatedAt, field.TypeTime)
}
if value, ok := buo.mutation.UpdatedAt(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Value: value,
Column: bouncer.FieldUpdatedAt,
})
_spec.SetField(bouncer.FieldUpdatedAt, field.TypeTime, value)
}
if buo.mutation.UpdatedAtCleared() {
_spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Column: bouncer.FieldUpdatedAt,
})
_spec.ClearField(bouncer.FieldUpdatedAt, field.TypeTime)
}
if value, ok := buo.mutation.Name(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: bouncer.FieldName,
})
_spec.SetField(bouncer.FieldName, field.TypeString, value)
}
if value, ok := buo.mutation.APIKey(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: bouncer.FieldAPIKey,
})
_spec.SetField(bouncer.FieldAPIKey, field.TypeString, value)
}
if value, ok := buo.mutation.Revoked(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeBool,
Value: value,
Column: bouncer.FieldRevoked,
})
_spec.SetField(bouncer.FieldRevoked, field.TypeBool, value)
}
if value, ok := buo.mutation.IPAddress(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: bouncer.FieldIPAddress,
})
_spec.SetField(bouncer.FieldIPAddress, field.TypeString, value)
}
if buo.mutation.IPAddressCleared() {
_spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{
Type: field.TypeString,
Column: bouncer.FieldIPAddress,
})
_spec.ClearField(bouncer.FieldIPAddress, field.TypeString)
}
if value, ok := buo.mutation.GetType(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: bouncer.FieldType,
})
_spec.SetField(bouncer.FieldType, field.TypeString, value)
}
if buo.mutation.TypeCleared() {
_spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{
Type: field.TypeString,
Column: bouncer.FieldType,
})
_spec.ClearField(bouncer.FieldType, field.TypeString)
}
if value, ok := buo.mutation.Version(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: bouncer.FieldVersion,
})
_spec.SetField(bouncer.FieldVersion, field.TypeString, value)
}
if buo.mutation.VersionCleared() {
_spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{
Type: field.TypeString,
Column: bouncer.FieldVersion,
})
_spec.ClearField(bouncer.FieldVersion, field.TypeString)
}
if value, ok := buo.mutation.Until(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Value: value,
Column: bouncer.FieldUntil,
})
_spec.SetField(bouncer.FieldUntil, field.TypeTime, value)
}
if buo.mutation.UntilCleared() {
_spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Column: bouncer.FieldUntil,
})
_spec.ClearField(bouncer.FieldUntil, field.TypeTime)
}
if value, ok := buo.mutation.LastPull(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Value: value,
Column: bouncer.FieldLastPull,
})
_spec.SetField(bouncer.FieldLastPull, field.TypeTime, value)
}
if value, ok := buo.mutation.AuthType(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: bouncer.FieldAuthType,
})
_spec.SetField(bouncer.FieldAuthType, field.TypeString, value)
}
_node = &Bouncer{config: buo.config}
_spec.Assign = _node.assignValues
@ -794,5 +599,6 @@ func (buo *BouncerUpdateOne) sqlSave(ctx context.Context) (_node *Bouncer, err e
}
return nil, err
}
buo.mutation.done = true
return _node, nil
}

View file

@ -7,20 +7,22 @@ import (
"errors"
"fmt"
"log"
"reflect"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/migrate"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/alert"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/bouncer"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/decision"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/event"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/lock"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/machine"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/meta"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
)
// Client is the client that holds all ent builders.
@ -38,6 +40,8 @@ type Client struct {
Decision *DecisionClient
// Event is the client for interacting with the Event builders.
Event *EventClient
// Lock is the client for interacting with the Lock builders.
Lock *LockClient
// Machine is the client for interacting with the Machine builders.
Machine *MachineClient
// Meta is the client for interacting with the Meta builders.
@ -46,7 +50,7 @@ type Client struct {
// NewClient creates a new client configured with the given options.
func NewClient(opts ...Option) *Client {
cfg := config{log: log.Println, hooks: &hooks{}}
cfg := config{log: log.Println, hooks: &hooks{}, inters: &inters{}}
cfg.options(opts...)
client := &Client{config: cfg}
client.init()
@ -60,10 +64,60 @@ func (c *Client) init() {
c.ConfigItem = NewConfigItemClient(c.config)
c.Decision = NewDecisionClient(c.config)
c.Event = NewEventClient(c.config)
c.Lock = NewLockClient(c.config)
c.Machine = NewMachineClient(c.config)
c.Meta = NewMetaClient(c.config)
}
type (
// config is the configuration for the client and its builder.
config struct {
// driver used for executing database requests.
driver dialect.Driver
// debug enable a debug logging.
debug bool
// log used for logging on debug mode.
log func(...any)
// hooks to execute on mutations.
hooks *hooks
// interceptors to execute on queries.
inters *inters
}
// Option function to configure the client.
Option func(*config)
)
// options applies the options on the config object.
func (c *config) options(opts ...Option) {
for _, opt := range opts {
opt(c)
}
if c.debug {
c.driver = dialect.Debug(c.driver, c.log)
}
}
// Debug enables debug logging on the ent.Driver.
func Debug() Option {
return func(c *config) {
c.debug = true
}
}
// Log sets the logging function for debug mode.
func Log(fn func(...any)) Option {
return func(c *config) {
c.log = fn
}
}
// Driver configures the client driver.
func Driver(driver dialect.Driver) Option {
return func(c *config) {
c.driver = driver
}
}
// Open opens a database/sql.DB specified by the driver name and
// the data source name, and returns a new client attached to it.
// Optional parameters can be added for configuring the client.
@ -80,11 +134,14 @@ func Open(driverName, dataSourceName string, options ...Option) (*Client, error)
}
}
// ErrTxStarted is returned when trying to start a new transaction from a transactional client.
var ErrTxStarted = errors.New("ent: cannot start a transaction within a transaction")
// Tx returns a new transactional client. The provided context
// is used until the transaction is committed or rolled back.
func (c *Client) Tx(ctx context.Context) (*Tx, error) {
if _, ok := c.driver.(*txDriver); ok {
return nil, errors.New("ent: cannot start a transaction within a transaction")
return nil, ErrTxStarted
}
tx, err := newTx(ctx, c.driver)
if err != nil {
@ -100,6 +157,7 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) {
ConfigItem: NewConfigItemClient(cfg),
Decision: NewDecisionClient(cfg),
Event: NewEventClient(cfg),
Lock: NewLockClient(cfg),
Machine: NewMachineClient(cfg),
Meta: NewMetaClient(cfg),
}, nil
@ -126,6 +184,7 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
ConfigItem: NewConfigItemClient(cfg),
Decision: NewDecisionClient(cfg),
Event: NewEventClient(cfg),
Lock: NewLockClient(cfg),
Machine: NewMachineClient(cfg),
Meta: NewMetaClient(cfg),
}, nil
@ -156,13 +215,47 @@ func (c *Client) Close() error {
// Use adds the mutation hooks to all the entity clients.
// In order to add hooks to a specific client, call: `client.Node.Use(...)`.
func (c *Client) Use(hooks ...Hook) {
c.Alert.Use(hooks...)
c.Bouncer.Use(hooks...)
c.ConfigItem.Use(hooks...)
c.Decision.Use(hooks...)
c.Event.Use(hooks...)
c.Machine.Use(hooks...)
c.Meta.Use(hooks...)
for _, n := range []interface{ Use(...Hook) }{
c.Alert, c.Bouncer, c.ConfigItem, c.Decision, c.Event, c.Lock, c.Machine,
c.Meta,
} {
n.Use(hooks...)
}
}
// Intercept adds the query interceptors to all the entity clients.
// In order to add interceptors to a specific client, call: `client.Node.Intercept(...)`.
func (c *Client) Intercept(interceptors ...Interceptor) {
for _, n := range []interface{ Intercept(...Interceptor) }{
c.Alert, c.Bouncer, c.ConfigItem, c.Decision, c.Event, c.Lock, c.Machine,
c.Meta,
} {
n.Intercept(interceptors...)
}
}
// Mutate implements the ent.Mutator interface.
func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) {
switch m := m.(type) {
case *AlertMutation:
return c.Alert.mutate(ctx, m)
case *BouncerMutation:
return c.Bouncer.mutate(ctx, m)
case *ConfigItemMutation:
return c.ConfigItem.mutate(ctx, m)
case *DecisionMutation:
return c.Decision.mutate(ctx, m)
case *EventMutation:
return c.Event.mutate(ctx, m)
case *LockMutation:
return c.Lock.mutate(ctx, m)
case *MachineMutation:
return c.Machine.mutate(ctx, m)
case *MetaMutation:
return c.Meta.mutate(ctx, m)
default:
return nil, fmt.Errorf("ent: unknown mutation type %T", m)
}
}
// AlertClient is a client for the Alert schema.
@ -181,6 +274,12 @@ func (c *AlertClient) Use(hooks ...Hook) {
c.hooks.Alert = append(c.hooks.Alert, hooks...)
}
// Intercept adds a list of query interceptors to the interceptors stack.
// A call to `Intercept(f, g, h)` equals to `alert.Intercept(f(g(h())))`.
func (c *AlertClient) Intercept(interceptors ...Interceptor) {
c.inters.Alert = append(c.inters.Alert, interceptors...)
}
// Create returns a builder for creating a Alert entity.
func (c *AlertClient) Create() *AlertCreate {
mutation := newAlertMutation(c.config, OpCreate)
@ -192,6 +291,21 @@ func (c *AlertClient) CreateBulk(builders ...*AlertCreate) *AlertCreateBulk {
return &AlertCreateBulk{config: c.config, builders: builders}
}
// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
// a builder and applies setFunc on it.
func (c *AlertClient) MapCreateBulk(slice any, setFunc func(*AlertCreate, int)) *AlertCreateBulk {
rv := reflect.ValueOf(slice)
if rv.Kind() != reflect.Slice {
return &AlertCreateBulk{err: fmt.Errorf("calling to AlertClient.MapCreateBulk with wrong type %T, need slice", slice)}
}
builders := make([]*AlertCreate, rv.Len())
for i := 0; i < rv.Len(); i++ {
builders[i] = c.Create()
setFunc(builders[i], i)
}
return &AlertCreateBulk{config: c.config, builders: builders}
}
// Update returns an update builder for Alert.
func (c *AlertClient) Update() *AlertUpdate {
mutation := newAlertMutation(c.config, OpUpdate)
@ -221,7 +335,7 @@ func (c *AlertClient) DeleteOne(a *Alert) *AlertDeleteOne {
return c.DeleteOneID(a.ID)
}
// DeleteOne returns a builder for deleting the given entity by its id.
// DeleteOneID returns a builder for deleting the given entity by its id.
func (c *AlertClient) DeleteOneID(id int) *AlertDeleteOne {
builder := c.Delete().Where(alert.ID(id))
builder.mutation.id = &id
@ -233,6 +347,8 @@ func (c *AlertClient) DeleteOneID(id int) *AlertDeleteOne {
func (c *AlertClient) Query() *AlertQuery {
return &AlertQuery{
config: c.config,
ctx: &QueryContext{Type: TypeAlert},
inters: c.Interceptors(),
}
}
@ -252,8 +368,8 @@ func (c *AlertClient) GetX(ctx context.Context, id int) *Alert {
// QueryOwner queries the owner edge of a Alert.
func (c *AlertClient) QueryOwner(a *Alert) *MachineQuery {
query := &MachineQuery{config: c.config}
query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) {
query := (&MachineClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := a.ID
step := sqlgraph.NewStep(
sqlgraph.From(alert.Table, alert.FieldID, id),
@ -268,8 +384,8 @@ func (c *AlertClient) QueryOwner(a *Alert) *MachineQuery {
// QueryDecisions queries the decisions edge of a Alert.
func (c *AlertClient) QueryDecisions(a *Alert) *DecisionQuery {
query := &DecisionQuery{config: c.config}
query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) {
query := (&DecisionClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := a.ID
step := sqlgraph.NewStep(
sqlgraph.From(alert.Table, alert.FieldID, id),
@ -284,8 +400,8 @@ func (c *AlertClient) QueryDecisions(a *Alert) *DecisionQuery {
// QueryEvents queries the events edge of a Alert.
func (c *AlertClient) QueryEvents(a *Alert) *EventQuery {
query := &EventQuery{config: c.config}
query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) {
query := (&EventClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := a.ID
step := sqlgraph.NewStep(
sqlgraph.From(alert.Table, alert.FieldID, id),
@ -300,8 +416,8 @@ func (c *AlertClient) QueryEvents(a *Alert) *EventQuery {
// QueryMetas queries the metas edge of a Alert.
func (c *AlertClient) QueryMetas(a *Alert) *MetaQuery {
query := &MetaQuery{config: c.config}
query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) {
query := (&MetaClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := a.ID
step := sqlgraph.NewStep(
sqlgraph.From(alert.Table, alert.FieldID, id),
@ -319,6 +435,26 @@ func (c *AlertClient) Hooks() []Hook {
return c.hooks.Alert
}
// Interceptors returns the client interceptors.
func (c *AlertClient) Interceptors() []Interceptor {
return c.inters.Alert
}
func (c *AlertClient) mutate(ctx context.Context, m *AlertMutation) (Value, error) {
switch m.Op() {
case OpCreate:
return (&AlertCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdate:
return (&AlertUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdateOne:
return (&AlertUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpDelete, OpDeleteOne:
return (&AlertDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
default:
return nil, fmt.Errorf("ent: unknown Alert mutation op: %q", m.Op())
}
}
// BouncerClient is a client for the Bouncer schema.
type BouncerClient struct {
config
@ -335,6 +471,12 @@ func (c *BouncerClient) Use(hooks ...Hook) {
c.hooks.Bouncer = append(c.hooks.Bouncer, hooks...)
}
// Intercept adds a list of query interceptors to the interceptors stack.
// A call to `Intercept(f, g, h)` equals to `bouncer.Intercept(f(g(h())))`.
func (c *BouncerClient) Intercept(interceptors ...Interceptor) {
c.inters.Bouncer = append(c.inters.Bouncer, interceptors...)
}
// Create returns a builder for creating a Bouncer entity.
func (c *BouncerClient) Create() *BouncerCreate {
mutation := newBouncerMutation(c.config, OpCreate)
@ -346,6 +488,21 @@ func (c *BouncerClient) CreateBulk(builders ...*BouncerCreate) *BouncerCreateBul
return &BouncerCreateBulk{config: c.config, builders: builders}
}
// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
// a builder and applies setFunc on it.
func (c *BouncerClient) MapCreateBulk(slice any, setFunc func(*BouncerCreate, int)) *BouncerCreateBulk {
rv := reflect.ValueOf(slice)
if rv.Kind() != reflect.Slice {
return &BouncerCreateBulk{err: fmt.Errorf("calling to BouncerClient.MapCreateBulk with wrong type %T, need slice", slice)}
}
builders := make([]*BouncerCreate, rv.Len())
for i := 0; i < rv.Len(); i++ {
builders[i] = c.Create()
setFunc(builders[i], i)
}
return &BouncerCreateBulk{config: c.config, builders: builders}
}
// Update returns an update builder for Bouncer.
func (c *BouncerClient) Update() *BouncerUpdate {
mutation := newBouncerMutation(c.config, OpUpdate)
@ -375,7 +532,7 @@ func (c *BouncerClient) DeleteOne(b *Bouncer) *BouncerDeleteOne {
return c.DeleteOneID(b.ID)
}
// DeleteOne returns a builder for deleting the given entity by its id.
// DeleteOneID returns a builder for deleting the given entity by its id.
func (c *BouncerClient) DeleteOneID(id int) *BouncerDeleteOne {
builder := c.Delete().Where(bouncer.ID(id))
builder.mutation.id = &id
@ -387,6 +544,8 @@ func (c *BouncerClient) DeleteOneID(id int) *BouncerDeleteOne {
func (c *BouncerClient) Query() *BouncerQuery {
return &BouncerQuery{
config: c.config,
ctx: &QueryContext{Type: TypeBouncer},
inters: c.Interceptors(),
}
}
@ -409,6 +568,26 @@ func (c *BouncerClient) Hooks() []Hook {
return c.hooks.Bouncer
}
// Interceptors returns the client interceptors.
func (c *BouncerClient) Interceptors() []Interceptor {
return c.inters.Bouncer
}
func (c *BouncerClient) mutate(ctx context.Context, m *BouncerMutation) (Value, error) {
switch m.Op() {
case OpCreate:
return (&BouncerCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdate:
return (&BouncerUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdateOne:
return (&BouncerUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpDelete, OpDeleteOne:
return (&BouncerDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
default:
return nil, fmt.Errorf("ent: unknown Bouncer mutation op: %q", m.Op())
}
}
// ConfigItemClient is a client for the ConfigItem schema.
type ConfigItemClient struct {
config
@ -425,6 +604,12 @@ func (c *ConfigItemClient) Use(hooks ...Hook) {
c.hooks.ConfigItem = append(c.hooks.ConfigItem, hooks...)
}
// Intercept adds a list of query interceptors to the interceptors stack.
// A call to `Intercept(f, g, h)` equals to `configitem.Intercept(f(g(h())))`.
func (c *ConfigItemClient) Intercept(interceptors ...Interceptor) {
c.inters.ConfigItem = append(c.inters.ConfigItem, interceptors...)
}
// Create returns a builder for creating a ConfigItem entity.
func (c *ConfigItemClient) Create() *ConfigItemCreate {
mutation := newConfigItemMutation(c.config, OpCreate)
@ -436,6 +621,21 @@ func (c *ConfigItemClient) CreateBulk(builders ...*ConfigItemCreate) *ConfigItem
return &ConfigItemCreateBulk{config: c.config, builders: builders}
}
// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
// a builder and applies setFunc on it.
func (c *ConfigItemClient) MapCreateBulk(slice any, setFunc func(*ConfigItemCreate, int)) *ConfigItemCreateBulk {
rv := reflect.ValueOf(slice)
if rv.Kind() != reflect.Slice {
return &ConfigItemCreateBulk{err: fmt.Errorf("calling to ConfigItemClient.MapCreateBulk with wrong type %T, need slice", slice)}
}
builders := make([]*ConfigItemCreate, rv.Len())
for i := 0; i < rv.Len(); i++ {
builders[i] = c.Create()
setFunc(builders[i], i)
}
return &ConfigItemCreateBulk{config: c.config, builders: builders}
}
// Update returns an update builder for ConfigItem.
func (c *ConfigItemClient) Update() *ConfigItemUpdate {
mutation := newConfigItemMutation(c.config, OpUpdate)
@ -465,7 +665,7 @@ func (c *ConfigItemClient) DeleteOne(ci *ConfigItem) *ConfigItemDeleteOne {
return c.DeleteOneID(ci.ID)
}
// DeleteOne returns a builder for deleting the given entity by its id.
// DeleteOneID returns a builder for deleting the given entity by its id.
func (c *ConfigItemClient) DeleteOneID(id int) *ConfigItemDeleteOne {
builder := c.Delete().Where(configitem.ID(id))
builder.mutation.id = &id
@ -477,6 +677,8 @@ func (c *ConfigItemClient) DeleteOneID(id int) *ConfigItemDeleteOne {
func (c *ConfigItemClient) Query() *ConfigItemQuery {
return &ConfigItemQuery{
config: c.config,
ctx: &QueryContext{Type: TypeConfigItem},
inters: c.Interceptors(),
}
}
@ -499,6 +701,26 @@ func (c *ConfigItemClient) Hooks() []Hook {
return c.hooks.ConfigItem
}
// Interceptors returns the client interceptors.
func (c *ConfigItemClient) Interceptors() []Interceptor {
return c.inters.ConfigItem
}
func (c *ConfigItemClient) mutate(ctx context.Context, m *ConfigItemMutation) (Value, error) {
switch m.Op() {
case OpCreate:
return (&ConfigItemCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdate:
return (&ConfigItemUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdateOne:
return (&ConfigItemUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpDelete, OpDeleteOne:
return (&ConfigItemDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
default:
return nil, fmt.Errorf("ent: unknown ConfigItem mutation op: %q", m.Op())
}
}
// DecisionClient is a client for the Decision schema.
type DecisionClient struct {
config
@ -515,6 +737,12 @@ func (c *DecisionClient) Use(hooks ...Hook) {
c.hooks.Decision = append(c.hooks.Decision, hooks...)
}
// Intercept adds a list of query interceptors to the interceptors stack.
// A call to `Intercept(f, g, h)` equals to `decision.Intercept(f(g(h())))`.
func (c *DecisionClient) Intercept(interceptors ...Interceptor) {
c.inters.Decision = append(c.inters.Decision, interceptors...)
}
// Create returns a builder for creating a Decision entity.
func (c *DecisionClient) Create() *DecisionCreate {
mutation := newDecisionMutation(c.config, OpCreate)
@ -526,6 +754,21 @@ func (c *DecisionClient) CreateBulk(builders ...*DecisionCreate) *DecisionCreate
return &DecisionCreateBulk{config: c.config, builders: builders}
}
// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
// a builder and applies setFunc on it.
func (c *DecisionClient) MapCreateBulk(slice any, setFunc func(*DecisionCreate, int)) *DecisionCreateBulk {
rv := reflect.ValueOf(slice)
if rv.Kind() != reflect.Slice {
return &DecisionCreateBulk{err: fmt.Errorf("calling to DecisionClient.MapCreateBulk with wrong type %T, need slice", slice)}
}
builders := make([]*DecisionCreate, rv.Len())
for i := 0; i < rv.Len(); i++ {
builders[i] = c.Create()
setFunc(builders[i], i)
}
return &DecisionCreateBulk{config: c.config, builders: builders}
}
// Update returns an update builder for Decision.
func (c *DecisionClient) Update() *DecisionUpdate {
mutation := newDecisionMutation(c.config, OpUpdate)
@ -555,7 +798,7 @@ func (c *DecisionClient) DeleteOne(d *Decision) *DecisionDeleteOne {
return c.DeleteOneID(d.ID)
}
// DeleteOne returns a builder for deleting the given entity by its id.
// DeleteOneID returns a builder for deleting the given entity by its id.
func (c *DecisionClient) DeleteOneID(id int) *DecisionDeleteOne {
builder := c.Delete().Where(decision.ID(id))
builder.mutation.id = &id
@ -567,6 +810,8 @@ func (c *DecisionClient) DeleteOneID(id int) *DecisionDeleteOne {
func (c *DecisionClient) Query() *DecisionQuery {
return &DecisionQuery{
config: c.config,
ctx: &QueryContext{Type: TypeDecision},
inters: c.Interceptors(),
}
}
@ -586,8 +831,8 @@ func (c *DecisionClient) GetX(ctx context.Context, id int) *Decision {
// QueryOwner queries the owner edge of a Decision.
func (c *DecisionClient) QueryOwner(d *Decision) *AlertQuery {
query := &AlertQuery{config: c.config}
query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) {
query := (&AlertClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := d.ID
step := sqlgraph.NewStep(
sqlgraph.From(decision.Table, decision.FieldID, id),
@ -605,6 +850,26 @@ func (c *DecisionClient) Hooks() []Hook {
return c.hooks.Decision
}
// Interceptors returns the client interceptors.
func (c *DecisionClient) Interceptors() []Interceptor {
return c.inters.Decision
}
func (c *DecisionClient) mutate(ctx context.Context, m *DecisionMutation) (Value, error) {
switch m.Op() {
case OpCreate:
return (&DecisionCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdate:
return (&DecisionUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdateOne:
return (&DecisionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpDelete, OpDeleteOne:
return (&DecisionDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
default:
return nil, fmt.Errorf("ent: unknown Decision mutation op: %q", m.Op())
}
}
// EventClient is a client for the Event schema.
type EventClient struct {
config
@ -621,6 +886,12 @@ func (c *EventClient) Use(hooks ...Hook) {
c.hooks.Event = append(c.hooks.Event, hooks...)
}
// Intercept adds a list of query interceptors to the interceptors stack.
// A call to `Intercept(f, g, h)` equals to `event.Intercept(f(g(h())))`.
func (c *EventClient) Intercept(interceptors ...Interceptor) {
c.inters.Event = append(c.inters.Event, interceptors...)
}
// Create returns a builder for creating a Event entity.
func (c *EventClient) Create() *EventCreate {
mutation := newEventMutation(c.config, OpCreate)
@ -632,6 +903,21 @@ func (c *EventClient) CreateBulk(builders ...*EventCreate) *EventCreateBulk {
return &EventCreateBulk{config: c.config, builders: builders}
}
// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
// a builder and applies setFunc on it.
func (c *EventClient) MapCreateBulk(slice any, setFunc func(*EventCreate, int)) *EventCreateBulk {
rv := reflect.ValueOf(slice)
if rv.Kind() != reflect.Slice {
return &EventCreateBulk{err: fmt.Errorf("calling to EventClient.MapCreateBulk with wrong type %T, need slice", slice)}
}
builders := make([]*EventCreate, rv.Len())
for i := 0; i < rv.Len(); i++ {
builders[i] = c.Create()
setFunc(builders[i], i)
}
return &EventCreateBulk{config: c.config, builders: builders}
}
// Update returns an update builder for Event.
func (c *EventClient) Update() *EventUpdate {
mutation := newEventMutation(c.config, OpUpdate)
@ -661,7 +947,7 @@ func (c *EventClient) DeleteOne(e *Event) *EventDeleteOne {
return c.DeleteOneID(e.ID)
}
// DeleteOne returns a builder for deleting the given entity by its id.
// DeleteOneID returns a builder for deleting the given entity by its id.
func (c *EventClient) DeleteOneID(id int) *EventDeleteOne {
builder := c.Delete().Where(event.ID(id))
builder.mutation.id = &id
@ -673,6 +959,8 @@ func (c *EventClient) DeleteOneID(id int) *EventDeleteOne {
func (c *EventClient) Query() *EventQuery {
return &EventQuery{
config: c.config,
ctx: &QueryContext{Type: TypeEvent},
inters: c.Interceptors(),
}
}
@ -692,8 +980,8 @@ func (c *EventClient) GetX(ctx context.Context, id int) *Event {
// QueryOwner queries the owner edge of a Event.
func (c *EventClient) QueryOwner(e *Event) *AlertQuery {
query := &AlertQuery{config: c.config}
query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) {
query := (&AlertClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := e.ID
step := sqlgraph.NewStep(
sqlgraph.From(event.Table, event.FieldID, id),
@ -711,6 +999,159 @@ func (c *EventClient) Hooks() []Hook {
return c.hooks.Event
}
// Interceptors returns the client interceptors.
func (c *EventClient) Interceptors() []Interceptor {
return c.inters.Event
}
func (c *EventClient) mutate(ctx context.Context, m *EventMutation) (Value, error) {
switch m.Op() {
case OpCreate:
return (&EventCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdate:
return (&EventUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdateOne:
return (&EventUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpDelete, OpDeleteOne:
return (&EventDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
default:
return nil, fmt.Errorf("ent: unknown Event mutation op: %q", m.Op())
}
}
// LockClient is a client for the Lock schema.
type LockClient struct {
config
}
// NewLockClient returns a client for the Lock from the given config.
func NewLockClient(c config) *LockClient {
return &LockClient{config: c}
}
// Use adds a list of mutation hooks to the hooks stack.
// A call to `Use(f, g, h)` equals to `lock.Hooks(f(g(h())))`.
func (c *LockClient) Use(hooks ...Hook) {
c.hooks.Lock = append(c.hooks.Lock, hooks...)
}
// Intercept adds a list of query interceptors to the interceptors stack.
// A call to `Intercept(f, g, h)` equals to `lock.Intercept(f(g(h())))`.
func (c *LockClient) Intercept(interceptors ...Interceptor) {
c.inters.Lock = append(c.inters.Lock, interceptors...)
}
// Create returns a builder for creating a Lock entity.
func (c *LockClient) Create() *LockCreate {
mutation := newLockMutation(c.config, OpCreate)
return &LockCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// CreateBulk returns a builder for creating a bulk of Lock entities.
func (c *LockClient) CreateBulk(builders ...*LockCreate) *LockCreateBulk {
return &LockCreateBulk{config: c.config, builders: builders}
}
// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
// a builder and applies setFunc on it.
func (c *LockClient) MapCreateBulk(slice any, setFunc func(*LockCreate, int)) *LockCreateBulk {
rv := reflect.ValueOf(slice)
if rv.Kind() != reflect.Slice {
return &LockCreateBulk{err: fmt.Errorf("calling to LockClient.MapCreateBulk with wrong type %T, need slice", slice)}
}
builders := make([]*LockCreate, rv.Len())
for i := 0; i < rv.Len(); i++ {
builders[i] = c.Create()
setFunc(builders[i], i)
}
return &LockCreateBulk{config: c.config, builders: builders}
}
// Update returns an update builder for Lock.
func (c *LockClient) Update() *LockUpdate {
mutation := newLockMutation(c.config, OpUpdate)
return &LockUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOne returns an update builder for the given entity.
func (c *LockClient) UpdateOne(l *Lock) *LockUpdateOne {
mutation := newLockMutation(c.config, OpUpdateOne, withLock(l))
return &LockUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOneID returns an update builder for the given id.
func (c *LockClient) UpdateOneID(id int) *LockUpdateOne {
mutation := newLockMutation(c.config, OpUpdateOne, withLockID(id))
return &LockUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// Delete returns a delete builder for Lock.
func (c *LockClient) Delete() *LockDelete {
mutation := newLockMutation(c.config, OpDelete)
return &LockDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// DeleteOne returns a builder for deleting the given entity.
func (c *LockClient) DeleteOne(l *Lock) *LockDeleteOne {
return c.DeleteOneID(l.ID)
}
// DeleteOneID returns a builder for deleting the given entity by its id.
func (c *LockClient) DeleteOneID(id int) *LockDeleteOne {
builder := c.Delete().Where(lock.ID(id))
builder.mutation.id = &id
builder.mutation.op = OpDeleteOne
return &LockDeleteOne{builder}
}
// Query returns a query builder for Lock.
func (c *LockClient) Query() *LockQuery {
return &LockQuery{
config: c.config,
ctx: &QueryContext{Type: TypeLock},
inters: c.Interceptors(),
}
}
// Get returns a Lock entity by its id.
func (c *LockClient) Get(ctx context.Context, id int) (*Lock, error) {
return c.Query().Where(lock.ID(id)).Only(ctx)
}
// GetX is like Get, but panics if an error occurs.
func (c *LockClient) GetX(ctx context.Context, id int) *Lock {
obj, err := c.Get(ctx, id)
if err != nil {
panic(err)
}
return obj
}
// Hooks returns the client hooks.
func (c *LockClient) Hooks() []Hook {
return c.hooks.Lock
}
// Interceptors returns the client interceptors.
func (c *LockClient) Interceptors() []Interceptor {
return c.inters.Lock
}
func (c *LockClient) mutate(ctx context.Context, m *LockMutation) (Value, error) {
switch m.Op() {
case OpCreate:
return (&LockCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdate:
return (&LockUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdateOne:
return (&LockUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpDelete, OpDeleteOne:
return (&LockDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
default:
return nil, fmt.Errorf("ent: unknown Lock mutation op: %q", m.Op())
}
}
// MachineClient is a client for the Machine schema.
type MachineClient struct {
config
@ -727,6 +1168,12 @@ func (c *MachineClient) Use(hooks ...Hook) {
c.hooks.Machine = append(c.hooks.Machine, hooks...)
}
// Intercept adds a list of query interceptors to the interceptors stack.
// A call to `Intercept(f, g, h)` equals to `machine.Intercept(f(g(h())))`.
func (c *MachineClient) Intercept(interceptors ...Interceptor) {
c.inters.Machine = append(c.inters.Machine, interceptors...)
}
// Create returns a builder for creating a Machine entity.
func (c *MachineClient) Create() *MachineCreate {
mutation := newMachineMutation(c.config, OpCreate)
@ -738,6 +1185,21 @@ func (c *MachineClient) CreateBulk(builders ...*MachineCreate) *MachineCreateBul
return &MachineCreateBulk{config: c.config, builders: builders}
}
// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
// a builder and applies setFunc on it.
func (c *MachineClient) MapCreateBulk(slice any, setFunc func(*MachineCreate, int)) *MachineCreateBulk {
rv := reflect.ValueOf(slice)
if rv.Kind() != reflect.Slice {
return &MachineCreateBulk{err: fmt.Errorf("calling to MachineClient.MapCreateBulk with wrong type %T, need slice", slice)}
}
builders := make([]*MachineCreate, rv.Len())
for i := 0; i < rv.Len(); i++ {
builders[i] = c.Create()
setFunc(builders[i], i)
}
return &MachineCreateBulk{config: c.config, builders: builders}
}
// Update returns an update builder for Machine.
func (c *MachineClient) Update() *MachineUpdate {
mutation := newMachineMutation(c.config, OpUpdate)
@ -767,7 +1229,7 @@ func (c *MachineClient) DeleteOne(m *Machine) *MachineDeleteOne {
return c.DeleteOneID(m.ID)
}
// DeleteOne returns a builder for deleting the given entity by its id.
// DeleteOneID returns a builder for deleting the given entity by its id.
func (c *MachineClient) DeleteOneID(id int) *MachineDeleteOne {
builder := c.Delete().Where(machine.ID(id))
builder.mutation.id = &id
@ -779,6 +1241,8 @@ func (c *MachineClient) DeleteOneID(id int) *MachineDeleteOne {
func (c *MachineClient) Query() *MachineQuery {
return &MachineQuery{
config: c.config,
ctx: &QueryContext{Type: TypeMachine},
inters: c.Interceptors(),
}
}
@ -798,8 +1262,8 @@ func (c *MachineClient) GetX(ctx context.Context, id int) *Machine {
// QueryAlerts queries the alerts edge of a Machine.
func (c *MachineClient) QueryAlerts(m *Machine) *AlertQuery {
query := &AlertQuery{config: c.config}
query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) {
query := (&AlertClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := m.ID
step := sqlgraph.NewStep(
sqlgraph.From(machine.Table, machine.FieldID, id),
@ -817,6 +1281,26 @@ func (c *MachineClient) Hooks() []Hook {
return c.hooks.Machine
}
// Interceptors returns the client interceptors.
func (c *MachineClient) Interceptors() []Interceptor {
return c.inters.Machine
}
func (c *MachineClient) mutate(ctx context.Context, m *MachineMutation) (Value, error) {
switch m.Op() {
case OpCreate:
return (&MachineCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdate:
return (&MachineUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdateOne:
return (&MachineUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpDelete, OpDeleteOne:
return (&MachineDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
default:
return nil, fmt.Errorf("ent: unknown Machine mutation op: %q", m.Op())
}
}
// MetaClient is a client for the Meta schema.
type MetaClient struct {
config
@ -833,6 +1317,12 @@ func (c *MetaClient) Use(hooks ...Hook) {
c.hooks.Meta = append(c.hooks.Meta, hooks...)
}
// Intercept adds a list of query interceptors to the interceptors stack.
// A call to `Intercept(f, g, h)` equals to `meta.Intercept(f(g(h())))`.
func (c *MetaClient) Intercept(interceptors ...Interceptor) {
c.inters.Meta = append(c.inters.Meta, interceptors...)
}
// Create returns a builder for creating a Meta entity.
func (c *MetaClient) Create() *MetaCreate {
mutation := newMetaMutation(c.config, OpCreate)
@ -844,6 +1334,21 @@ func (c *MetaClient) CreateBulk(builders ...*MetaCreate) *MetaCreateBulk {
return &MetaCreateBulk{config: c.config, builders: builders}
}
// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
// a builder and applies setFunc on it.
func (c *MetaClient) MapCreateBulk(slice any, setFunc func(*MetaCreate, int)) *MetaCreateBulk {
rv := reflect.ValueOf(slice)
if rv.Kind() != reflect.Slice {
return &MetaCreateBulk{err: fmt.Errorf("calling to MetaClient.MapCreateBulk with wrong type %T, need slice", slice)}
}
builders := make([]*MetaCreate, rv.Len())
for i := 0; i < rv.Len(); i++ {
builders[i] = c.Create()
setFunc(builders[i], i)
}
return &MetaCreateBulk{config: c.config, builders: builders}
}
// Update returns an update builder for Meta.
func (c *MetaClient) Update() *MetaUpdate {
mutation := newMetaMutation(c.config, OpUpdate)
@ -873,7 +1378,7 @@ func (c *MetaClient) DeleteOne(m *Meta) *MetaDeleteOne {
return c.DeleteOneID(m.ID)
}
// DeleteOne returns a builder for deleting the given entity by its id.
// DeleteOneID returns a builder for deleting the given entity by its id.
func (c *MetaClient) DeleteOneID(id int) *MetaDeleteOne {
builder := c.Delete().Where(meta.ID(id))
builder.mutation.id = &id
@ -885,6 +1390,8 @@ func (c *MetaClient) DeleteOneID(id int) *MetaDeleteOne {
func (c *MetaClient) Query() *MetaQuery {
return &MetaQuery{
config: c.config,
ctx: &QueryContext{Type: TypeMeta},
inters: c.Interceptors(),
}
}
@ -904,8 +1411,8 @@ func (c *MetaClient) GetX(ctx context.Context, id int) *Meta {
// QueryOwner queries the owner edge of a Meta.
func (c *MetaClient) QueryOwner(m *Meta) *AlertQuery {
query := &AlertQuery{config: c.config}
query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) {
query := (&AlertClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := m.ID
step := sqlgraph.NewStep(
sqlgraph.From(meta.Table, meta.FieldID, id),
@ -922,3 +1429,34 @@ func (c *MetaClient) QueryOwner(m *Meta) *AlertQuery {
func (c *MetaClient) Hooks() []Hook {
return c.hooks.Meta
}
// Interceptors returns the client interceptors.
func (c *MetaClient) Interceptors() []Interceptor {
return c.inters.Meta
}
func (c *MetaClient) mutate(ctx context.Context, m *MetaMutation) (Value, error) {
switch m.Op() {
case OpCreate:
return (&MetaCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdate:
return (&MetaUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdateOne:
return (&MetaUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpDelete, OpDeleteOne:
return (&MetaDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
default:
return nil, fmt.Errorf("ent: unknown Meta mutation op: %q", m.Op())
}
}
// hooks and interceptors per client, for fast access.
type (
hooks struct {
Alert, Bouncer, ConfigItem, Decision, Event, Lock, Machine, Meta []ent.Hook
}
inters struct {
Alert, Bouncer, ConfigItem, Decision, Event, Lock, Machine,
Meta []ent.Interceptor
}
)

View file

@ -1,65 +0,0 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"entgo.io/ent"
"entgo.io/ent/dialect"
)
// Option function to configure the client.
type Option func(*config)
// Config is the configuration for the client and its builder.
type config struct {
// driver used for executing database requests.
driver dialect.Driver
// debug enable a debug logging.
debug bool
// log used for logging on debug mode.
log func(...any)
// hooks to execute on mutations.
hooks *hooks
}
// hooks per client, for fast access.
type hooks struct {
Alert []ent.Hook
Bouncer []ent.Hook
ConfigItem []ent.Hook
Decision []ent.Hook
Event []ent.Hook
Machine []ent.Hook
Meta []ent.Hook
}
// Options applies the options on the config object.
func (c *config) options(opts ...Option) {
for _, opt := range opts {
opt(c)
}
if c.debug {
c.driver = dialect.Debug(c.driver, c.log)
}
}
// Debug enables debug logging on the ent.Driver.
func Debug() Option {
return func(c *config) {
c.debug = true
}
}
// Log sets the logging function for debug mode.
func Log(fn func(...any)) Option {
return func(c *config) {
c.log = fn
}
}
// Driver configures the client driver.
func Driver(driver dialect.Driver) Option {
return func(c *config) {
c.driver = driver
}
}

View file

@ -7,6 +7,7 @@ import (
"strings"
"time"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem"
)
@ -23,7 +24,8 @@ type ConfigItem struct {
// Name holds the value of the "name" field.
Name string `json:"name"`
// Value holds the value of the "value" field.
Value string `json:"value"`
Value string `json:"value"`
selectValues sql.SelectValues
}
// scanValues returns the types for scanning values from sql.Rows.
@ -38,7 +40,7 @@ func (*ConfigItem) scanValues(columns []string) ([]any, error) {
case configitem.FieldCreatedAt, configitem.FieldUpdatedAt:
values[i] = new(sql.NullTime)
default:
return nil, fmt.Errorf("unexpected column %q for type ConfigItem", columns[i])
values[i] = new(sql.UnknownType)
}
}
return values, nil
@ -84,16 +86,24 @@ func (ci *ConfigItem) assignValues(columns []string, values []any) error {
} else if value.Valid {
ci.Value = value.String
}
default:
ci.selectValues.Set(columns[i], values[i])
}
}
return nil
}
// GetValue returns the ent.Value that was dynamically selected and assigned to the ConfigItem.
// This includes values selected through modifiers, order, etc.
func (ci *ConfigItem) GetValue(name string) (ent.Value, error) {
return ci.selectValues.Get(name)
}
// Update returns a builder for updating this ConfigItem.
// Note that you need to call ConfigItem.Unwrap() before calling this method if this ConfigItem
// was returned from a transaction, and the transaction was committed or rolled back.
func (ci *ConfigItem) Update() *ConfigItemUpdateOne {
return (&ConfigItemClient{config: ci.config}).UpdateOne(ci)
return NewConfigItemClient(ci.config).UpdateOne(ci)
}
// Unwrap unwraps the ConfigItem entity that was returned from a transaction after it was closed,
@ -133,9 +143,3 @@ func (ci *ConfigItem) String() string {
// ConfigItems is a parsable slice of ConfigItem.
type ConfigItems []*ConfigItem
func (ci ConfigItems) config(cfg config) {
for _i := range ci {
ci[_i].config = cfg
}
}

View file

@ -4,6 +4,8 @@ package configitem
import (
"time"
"entgo.io/ent/dialect/sql"
)
const (
@ -52,3 +54,31 @@ var (
// UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
UpdateDefaultUpdatedAt func() time.Time
)
// OrderOption defines the ordering options for the ConfigItem queries.
type OrderOption func(*sql.Selector)
// ByID orders the results by the id field.
func ByID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldID, opts...).ToFunc()
}
// ByCreatedAt orders the results by the created_at field.
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
}
// ByUpdatedAt orders the results by the updated_at field.
func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
}
// ByName orders the results by the name field.
func ByName(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldName, opts...).ToFunc()
}
// ByValue orders the results by the value field.
func ByValue(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldValue, opts...).ToFunc()
}

View file

@ -11,485 +11,310 @@ import (
// ID filters vertices based on their ID field.
func ID(id int) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.EQ(s.C(FieldID), id))
})
return predicate.ConfigItem(sql.FieldEQ(FieldID, id))
}
// IDEQ applies the EQ predicate on the ID field.
func IDEQ(id int) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.EQ(s.C(FieldID), id))
})
return predicate.ConfigItem(sql.FieldEQ(FieldID, id))
}
// IDNEQ applies the NEQ predicate on the ID field.
func IDNEQ(id int) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.NEQ(s.C(FieldID), id))
})
return predicate.ConfigItem(sql.FieldNEQ(FieldID, id))
}
// IDIn applies the In predicate on the ID field.
func IDIn(ids ...int) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
v := make([]any, len(ids))
for i := range v {
v[i] = ids[i]
}
s.Where(sql.In(s.C(FieldID), v...))
})
return predicate.ConfigItem(sql.FieldIn(FieldID, ids...))
}
// IDNotIn applies the NotIn predicate on the ID field.
func IDNotIn(ids ...int) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
v := make([]any, len(ids))
for i := range v {
v[i] = ids[i]
}
s.Where(sql.NotIn(s.C(FieldID), v...))
})
return predicate.ConfigItem(sql.FieldNotIn(FieldID, ids...))
}
// IDGT applies the GT predicate on the ID field.
func IDGT(id int) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.GT(s.C(FieldID), id))
})
return predicate.ConfigItem(sql.FieldGT(FieldID, id))
}
// IDGTE applies the GTE predicate on the ID field.
func IDGTE(id int) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.GTE(s.C(FieldID), id))
})
return predicate.ConfigItem(sql.FieldGTE(FieldID, id))
}
// IDLT applies the LT predicate on the ID field.
func IDLT(id int) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.LT(s.C(FieldID), id))
})
return predicate.ConfigItem(sql.FieldLT(FieldID, id))
}
// IDLTE applies the LTE predicate on the ID field.
func IDLTE(id int) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.LTE(s.C(FieldID), id))
})
return predicate.ConfigItem(sql.FieldLTE(FieldID, id))
}
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
func CreatedAt(v time.Time) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.EQ(s.C(FieldCreatedAt), v))
})
return predicate.ConfigItem(sql.FieldEQ(FieldCreatedAt, v))
}
// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
func UpdatedAt(v time.Time) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.EQ(s.C(FieldUpdatedAt), v))
})
return predicate.ConfigItem(sql.FieldEQ(FieldUpdatedAt, v))
}
// Name applies equality check predicate on the "name" field. It's identical to NameEQ.
func Name(v string) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.EQ(s.C(FieldName), v))
})
return predicate.ConfigItem(sql.FieldEQ(FieldName, v))
}
// Value applies equality check predicate on the "value" field. It's identical to ValueEQ.
func Value(v string) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.EQ(s.C(FieldValue), v))
})
return predicate.ConfigItem(sql.FieldEQ(FieldValue, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.EQ(s.C(FieldCreatedAt), v))
})
return predicate.ConfigItem(sql.FieldEQ(FieldCreatedAt, v))
}
// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
func CreatedAtNEQ(v time.Time) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.NEQ(s.C(FieldCreatedAt), v))
})
return predicate.ConfigItem(sql.FieldNEQ(FieldCreatedAt, v))
}
// CreatedAtIn applies the In predicate on the "created_at" field.
func CreatedAtIn(vs ...time.Time) predicate.ConfigItem {
v := make([]any, len(vs))
for i := range v {
v[i] = vs[i]
}
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.In(s.C(FieldCreatedAt), v...))
})
return predicate.ConfigItem(sql.FieldIn(FieldCreatedAt, vs...))
}
// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
func CreatedAtNotIn(vs ...time.Time) predicate.ConfigItem {
v := make([]any, len(vs))
for i := range v {
v[i] = vs[i]
}
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.NotIn(s.C(FieldCreatedAt), v...))
})
return predicate.ConfigItem(sql.FieldNotIn(FieldCreatedAt, vs...))
}
// CreatedAtGT applies the GT predicate on the "created_at" field.
func CreatedAtGT(v time.Time) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.GT(s.C(FieldCreatedAt), v))
})
return predicate.ConfigItem(sql.FieldGT(FieldCreatedAt, v))
}
// CreatedAtGTE applies the GTE predicate on the "created_at" field.
func CreatedAtGTE(v time.Time) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.GTE(s.C(FieldCreatedAt), v))
})
return predicate.ConfigItem(sql.FieldGTE(FieldCreatedAt, v))
}
// CreatedAtLT applies the LT predicate on the "created_at" field.
func CreatedAtLT(v time.Time) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.LT(s.C(FieldCreatedAt), v))
})
return predicate.ConfigItem(sql.FieldLT(FieldCreatedAt, v))
}
// CreatedAtLTE applies the LTE predicate on the "created_at" field.
func CreatedAtLTE(v time.Time) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.LTE(s.C(FieldCreatedAt), v))
})
return predicate.ConfigItem(sql.FieldLTE(FieldCreatedAt, v))
}
// CreatedAtIsNil applies the IsNil predicate on the "created_at" field.
func CreatedAtIsNil() predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.IsNull(s.C(FieldCreatedAt)))
})
return predicate.ConfigItem(sql.FieldIsNull(FieldCreatedAt))
}
// CreatedAtNotNil applies the NotNil predicate on the "created_at" field.
func CreatedAtNotNil() predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.NotNull(s.C(FieldCreatedAt)))
})
return predicate.ConfigItem(sql.FieldNotNull(FieldCreatedAt))
}
// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
func UpdatedAtEQ(v time.Time) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.EQ(s.C(FieldUpdatedAt), v))
})
return predicate.ConfigItem(sql.FieldEQ(FieldUpdatedAt, v))
}
// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
func UpdatedAtNEQ(v time.Time) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.NEQ(s.C(FieldUpdatedAt), v))
})
return predicate.ConfigItem(sql.FieldNEQ(FieldUpdatedAt, v))
}
// UpdatedAtIn applies the In predicate on the "updated_at" field.
func UpdatedAtIn(vs ...time.Time) predicate.ConfigItem {
v := make([]any, len(vs))
for i := range v {
v[i] = vs[i]
}
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.In(s.C(FieldUpdatedAt), v...))
})
return predicate.ConfigItem(sql.FieldIn(FieldUpdatedAt, vs...))
}
// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
func UpdatedAtNotIn(vs ...time.Time) predicate.ConfigItem {
v := make([]any, len(vs))
for i := range v {
v[i] = vs[i]
}
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.NotIn(s.C(FieldUpdatedAt), v...))
})
return predicate.ConfigItem(sql.FieldNotIn(FieldUpdatedAt, vs...))
}
// UpdatedAtGT applies the GT predicate on the "updated_at" field.
func UpdatedAtGT(v time.Time) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.GT(s.C(FieldUpdatedAt), v))
})
return predicate.ConfigItem(sql.FieldGT(FieldUpdatedAt, v))
}
// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
func UpdatedAtGTE(v time.Time) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.GTE(s.C(FieldUpdatedAt), v))
})
return predicate.ConfigItem(sql.FieldGTE(FieldUpdatedAt, v))
}
// UpdatedAtLT applies the LT predicate on the "updated_at" field.
func UpdatedAtLT(v time.Time) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.LT(s.C(FieldUpdatedAt), v))
})
return predicate.ConfigItem(sql.FieldLT(FieldUpdatedAt, v))
}
// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
func UpdatedAtLTE(v time.Time) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.LTE(s.C(FieldUpdatedAt), v))
})
return predicate.ConfigItem(sql.FieldLTE(FieldUpdatedAt, v))
}
// UpdatedAtIsNil applies the IsNil predicate on the "updated_at" field.
func UpdatedAtIsNil() predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.IsNull(s.C(FieldUpdatedAt)))
})
return predicate.ConfigItem(sql.FieldIsNull(FieldUpdatedAt))
}
// UpdatedAtNotNil applies the NotNil predicate on the "updated_at" field.
func UpdatedAtNotNil() predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.NotNull(s.C(FieldUpdatedAt)))
})
return predicate.ConfigItem(sql.FieldNotNull(FieldUpdatedAt))
}
// NameEQ applies the EQ predicate on the "name" field.
func NameEQ(v string) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.EQ(s.C(FieldName), v))
})
return predicate.ConfigItem(sql.FieldEQ(FieldName, v))
}
// NameNEQ applies the NEQ predicate on the "name" field.
func NameNEQ(v string) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.NEQ(s.C(FieldName), v))
})
return predicate.ConfigItem(sql.FieldNEQ(FieldName, v))
}
// NameIn applies the In predicate on the "name" field.
func NameIn(vs ...string) predicate.ConfigItem {
v := make([]any, len(vs))
for i := range v {
v[i] = vs[i]
}
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.In(s.C(FieldName), v...))
})
return predicate.ConfigItem(sql.FieldIn(FieldName, vs...))
}
// NameNotIn applies the NotIn predicate on the "name" field.
func NameNotIn(vs ...string) predicate.ConfigItem {
v := make([]any, len(vs))
for i := range v {
v[i] = vs[i]
}
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.NotIn(s.C(FieldName), v...))
})
return predicate.ConfigItem(sql.FieldNotIn(FieldName, vs...))
}
// NameGT applies the GT predicate on the "name" field.
func NameGT(v string) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.GT(s.C(FieldName), v))
})
return predicate.ConfigItem(sql.FieldGT(FieldName, v))
}
// NameGTE applies the GTE predicate on the "name" field.
func NameGTE(v string) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.GTE(s.C(FieldName), v))
})
return predicate.ConfigItem(sql.FieldGTE(FieldName, v))
}
// NameLT applies the LT predicate on the "name" field.
func NameLT(v string) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.LT(s.C(FieldName), v))
})
return predicate.ConfigItem(sql.FieldLT(FieldName, v))
}
// NameLTE applies the LTE predicate on the "name" field.
func NameLTE(v string) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.LTE(s.C(FieldName), v))
})
return predicate.ConfigItem(sql.FieldLTE(FieldName, v))
}
// NameContains applies the Contains predicate on the "name" field.
func NameContains(v string) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.Contains(s.C(FieldName), v))
})
return predicate.ConfigItem(sql.FieldContains(FieldName, v))
}
// NameHasPrefix applies the HasPrefix predicate on the "name" field.
func NameHasPrefix(v string) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.HasPrefix(s.C(FieldName), v))
})
return predicate.ConfigItem(sql.FieldHasPrefix(FieldName, v))
}
// NameHasSuffix applies the HasSuffix predicate on the "name" field.
func NameHasSuffix(v string) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.HasSuffix(s.C(FieldName), v))
})
return predicate.ConfigItem(sql.FieldHasSuffix(FieldName, v))
}
// NameEqualFold applies the EqualFold predicate on the "name" field.
func NameEqualFold(v string) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.EqualFold(s.C(FieldName), v))
})
return predicate.ConfigItem(sql.FieldEqualFold(FieldName, v))
}
// NameContainsFold applies the ContainsFold predicate on the "name" field.
func NameContainsFold(v string) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.ContainsFold(s.C(FieldName), v))
})
return predicate.ConfigItem(sql.FieldContainsFold(FieldName, v))
}
// ValueEQ applies the EQ predicate on the "value" field.
func ValueEQ(v string) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.EQ(s.C(FieldValue), v))
})
return predicate.ConfigItem(sql.FieldEQ(FieldValue, v))
}
// ValueNEQ applies the NEQ predicate on the "value" field.
func ValueNEQ(v string) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.NEQ(s.C(FieldValue), v))
})
return predicate.ConfigItem(sql.FieldNEQ(FieldValue, v))
}
// ValueIn applies the In predicate on the "value" field.
func ValueIn(vs ...string) predicate.ConfigItem {
v := make([]any, len(vs))
for i := range v {
v[i] = vs[i]
}
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.In(s.C(FieldValue), v...))
})
return predicate.ConfigItem(sql.FieldIn(FieldValue, vs...))
}
// ValueNotIn applies the NotIn predicate on the "value" field.
func ValueNotIn(vs ...string) predicate.ConfigItem {
v := make([]any, len(vs))
for i := range v {
v[i] = vs[i]
}
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.NotIn(s.C(FieldValue), v...))
})
return predicate.ConfigItem(sql.FieldNotIn(FieldValue, vs...))
}
// ValueGT applies the GT predicate on the "value" field.
func ValueGT(v string) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.GT(s.C(FieldValue), v))
})
return predicate.ConfigItem(sql.FieldGT(FieldValue, v))
}
// ValueGTE applies the GTE predicate on the "value" field.
func ValueGTE(v string) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.GTE(s.C(FieldValue), v))
})
return predicate.ConfigItem(sql.FieldGTE(FieldValue, v))
}
// ValueLT applies the LT predicate on the "value" field.
func ValueLT(v string) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.LT(s.C(FieldValue), v))
})
return predicate.ConfigItem(sql.FieldLT(FieldValue, v))
}
// ValueLTE applies the LTE predicate on the "value" field.
func ValueLTE(v string) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.LTE(s.C(FieldValue), v))
})
return predicate.ConfigItem(sql.FieldLTE(FieldValue, v))
}
// ValueContains applies the Contains predicate on the "value" field.
func ValueContains(v string) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.Contains(s.C(FieldValue), v))
})
return predicate.ConfigItem(sql.FieldContains(FieldValue, v))
}
// ValueHasPrefix applies the HasPrefix predicate on the "value" field.
func ValueHasPrefix(v string) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.HasPrefix(s.C(FieldValue), v))
})
return predicate.ConfigItem(sql.FieldHasPrefix(FieldValue, v))
}
// ValueHasSuffix applies the HasSuffix predicate on the "value" field.
func ValueHasSuffix(v string) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.HasSuffix(s.C(FieldValue), v))
})
return predicate.ConfigItem(sql.FieldHasSuffix(FieldValue, v))
}
// ValueEqualFold applies the EqualFold predicate on the "value" field.
func ValueEqualFold(v string) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.EqualFold(s.C(FieldValue), v))
})
return predicate.ConfigItem(sql.FieldEqualFold(FieldValue, v))
}
// ValueContainsFold applies the ContainsFold predicate on the "value" field.
func ValueContainsFold(v string) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s.Where(sql.ContainsFold(s.C(FieldValue), v))
})
return predicate.ConfigItem(sql.FieldContainsFold(FieldValue, v))
}
// And groups predicates with the AND operator between them.
func And(predicates ...predicate.ConfigItem) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s1 := s.Clone().SetP(nil)
for _, p := range predicates {
p(s1)
}
s.Where(s1.P())
})
return predicate.ConfigItem(sql.AndPredicates(predicates...))
}
// Or groups predicates with the OR operator between them.
func Or(predicates ...predicate.ConfigItem) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
s1 := s.Clone().SetP(nil)
for i, p := range predicates {
if i > 0 {
s1.Or()
}
p(s1)
}
s.Where(s1.P())
})
return predicate.ConfigItem(sql.OrPredicates(predicates...))
}
// Not applies the not operator on the given predicate.
func Not(p predicate.ConfigItem) predicate.ConfigItem {
return predicate.ConfigItem(func(s *sql.Selector) {
p(s.Not())
})
return predicate.ConfigItem(sql.NotPredicates(p))
}

View file

@ -67,50 +67,8 @@ func (cic *ConfigItemCreate) Mutation() *ConfigItemMutation {
// Save creates the ConfigItem in the database.
func (cic *ConfigItemCreate) Save(ctx context.Context) (*ConfigItem, error) {
var (
err error
node *ConfigItem
)
cic.defaults()
if len(cic.hooks) == 0 {
if err = cic.check(); err != nil {
return nil, err
}
node, err = cic.sqlSave(ctx)
} else {
var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
mutation, ok := m.(*ConfigItemMutation)
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T", m)
}
if err = cic.check(); err != nil {
return nil, err
}
cic.mutation = mutation
if node, err = cic.sqlSave(ctx); err != nil {
return nil, err
}
mutation.id = &node.ID
mutation.done = true
return node, err
})
for i := len(cic.hooks) - 1; i >= 0; i-- {
if cic.hooks[i] == nil {
return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)")
}
mut = cic.hooks[i](mut)
}
v, err := mut.Mutate(ctx, cic.mutation)
if err != nil {
return nil, err
}
nv, ok := v.(*ConfigItem)
if !ok {
return nil, fmt.Errorf("unexpected node type %T returned from ConfigItemMutation", v)
}
node = nv
}
return node, err
return withHooks(ctx, cic.sqlSave, cic.mutation, cic.hooks)
}
// SaveX calls Save and panics if Save returns an error.
@ -159,6 +117,9 @@ func (cic *ConfigItemCreate) check() error {
}
func (cic *ConfigItemCreate) sqlSave(ctx context.Context) (*ConfigItem, error) {
if err := cic.check(); err != nil {
return nil, err
}
_node, _spec := cic.createSpec()
if err := sqlgraph.CreateNode(ctx, cic.driver, _spec); err != nil {
if sqlgraph.IsConstraintError(err) {
@ -168,50 +129,30 @@ func (cic *ConfigItemCreate) sqlSave(ctx context.Context) (*ConfigItem, error) {
}
id := _spec.ID.Value.(int64)
_node.ID = int(id)
cic.mutation.id = &_node.ID
cic.mutation.done = true
return _node, nil
}
func (cic *ConfigItemCreate) createSpec() (*ConfigItem, *sqlgraph.CreateSpec) {
var (
_node = &ConfigItem{config: cic.config}
_spec = &sqlgraph.CreateSpec{
Table: configitem.Table,
ID: &sqlgraph.FieldSpec{
Type: field.TypeInt,
Column: configitem.FieldID,
},
}
_spec = sqlgraph.NewCreateSpec(configitem.Table, sqlgraph.NewFieldSpec(configitem.FieldID, field.TypeInt))
)
if value, ok := cic.mutation.CreatedAt(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Value: value,
Column: configitem.FieldCreatedAt,
})
_spec.SetField(configitem.FieldCreatedAt, field.TypeTime, value)
_node.CreatedAt = &value
}
if value, ok := cic.mutation.UpdatedAt(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Value: value,
Column: configitem.FieldUpdatedAt,
})
_spec.SetField(configitem.FieldUpdatedAt, field.TypeTime, value)
_node.UpdatedAt = &value
}
if value, ok := cic.mutation.Name(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: configitem.FieldName,
})
_spec.SetField(configitem.FieldName, field.TypeString, value)
_node.Name = value
}
if value, ok := cic.mutation.Value(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: configitem.FieldValue,
})
_spec.SetField(configitem.FieldValue, field.TypeString, value)
_node.Value = value
}
return _node, _spec
@ -220,11 +161,15 @@ func (cic *ConfigItemCreate) createSpec() (*ConfigItem, *sqlgraph.CreateSpec) {
// ConfigItemCreateBulk is the builder for creating many ConfigItem entities in bulk.
type ConfigItemCreateBulk struct {
config
err error
builders []*ConfigItemCreate
}
// Save creates the ConfigItem entities in the database.
func (cicb *ConfigItemCreateBulk) Save(ctx context.Context) ([]*ConfigItem, error) {
if cicb.err != nil {
return nil, cicb.err
}
specs := make([]*sqlgraph.CreateSpec, len(cicb.builders))
nodes := make([]*ConfigItem, len(cicb.builders))
mutators := make([]Mutator, len(cicb.builders))
@ -241,8 +186,8 @@ func (cicb *ConfigItemCreateBulk) Save(ctx context.Context) ([]*ConfigItem, erro
return nil, err
}
builder.mutation = mutation
nodes[i], specs[i] = builder.createSpec()
var err error
nodes[i], specs[i] = builder.createSpec()
if i < len(mutators)-1 {
_, err = mutators[i+1].Mutate(root, cicb.builders[i+1].mutation)
} else {

View file

@ -4,7 +4,6 @@ package ent
import (
"context"
"fmt"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
@ -28,34 +27,7 @@ func (cid *ConfigItemDelete) Where(ps ...predicate.ConfigItem) *ConfigItemDelete
// Exec executes the deletion query and returns how many vertices were deleted.
func (cid *ConfigItemDelete) Exec(ctx context.Context) (int, error) {
var (
err error
affected int
)
if len(cid.hooks) == 0 {
affected, err = cid.sqlExec(ctx)
} else {
var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
mutation, ok := m.(*ConfigItemMutation)
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T", m)
}
cid.mutation = mutation
affected, err = cid.sqlExec(ctx)
mutation.done = true
return affected, err
})
for i := len(cid.hooks) - 1; i >= 0; i-- {
if cid.hooks[i] == nil {
return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)")
}
mut = cid.hooks[i](mut)
}
if _, err := mut.Mutate(ctx, cid.mutation); err != nil {
return 0, err
}
}
return affected, err
return withHooks(ctx, cid.sqlExec, cid.mutation, cid.hooks)
}
// ExecX is like Exec, but panics if an error occurs.
@ -68,15 +40,7 @@ func (cid *ConfigItemDelete) ExecX(ctx context.Context) int {
}
func (cid *ConfigItemDelete) sqlExec(ctx context.Context) (int, error) {
_spec := &sqlgraph.DeleteSpec{
Node: &sqlgraph.NodeSpec{
Table: configitem.Table,
ID: &sqlgraph.FieldSpec{
Type: field.TypeInt,
Column: configitem.FieldID,
},
},
}
_spec := sqlgraph.NewDeleteSpec(configitem.Table, sqlgraph.NewFieldSpec(configitem.FieldID, field.TypeInt))
if ps := cid.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
@ -88,6 +52,7 @@ func (cid *ConfigItemDelete) sqlExec(ctx context.Context) (int, error) {
if err != nil && sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
cid.mutation.done = true
return affected, err
}
@ -96,6 +61,12 @@ type ConfigItemDeleteOne struct {
cid *ConfigItemDelete
}
// Where appends a list predicates to the ConfigItemDelete builder.
func (cido *ConfigItemDeleteOne) Where(ps ...predicate.ConfigItem) *ConfigItemDeleteOne {
cido.cid.mutation.Where(ps...)
return cido
}
// Exec executes the deletion query.
func (cido *ConfigItemDeleteOne) Exec(ctx context.Context) error {
n, err := cido.cid.Exec(ctx)
@ -111,5 +82,7 @@ func (cido *ConfigItemDeleteOne) Exec(ctx context.Context) error {
// ExecX is like Exec, but panics if an error occurs.
func (cido *ConfigItemDeleteOne) ExecX(ctx context.Context) {
cido.cid.ExecX(ctx)
if err := cido.Exec(ctx); err != nil {
panic(err)
}
}

View file

@ -17,11 +17,9 @@ import (
// ConfigItemQuery is the builder for querying ConfigItem entities.
type ConfigItemQuery struct {
config
limit *int
offset *int
unique *bool
order []OrderFunc
fields []string
ctx *QueryContext
order []configitem.OrderOption
inters []Interceptor
predicates []predicate.ConfigItem
// intermediate query (i.e. traversal path).
sql *sql.Selector
@ -34,27 +32,27 @@ func (ciq *ConfigItemQuery) Where(ps ...predicate.ConfigItem) *ConfigItemQuery {
return ciq
}
// Limit adds a limit step to the query.
// Limit the number of records to be returned by this query.
func (ciq *ConfigItemQuery) Limit(limit int) *ConfigItemQuery {
ciq.limit = &limit
ciq.ctx.Limit = &limit
return ciq
}
// Offset adds an offset step to the query.
// Offset to start from.
func (ciq *ConfigItemQuery) Offset(offset int) *ConfigItemQuery {
ciq.offset = &offset
ciq.ctx.Offset = &offset
return ciq
}
// Unique configures the query builder to filter duplicate records on query.
// By default, unique is set to true, and can be disabled using this method.
func (ciq *ConfigItemQuery) Unique(unique bool) *ConfigItemQuery {
ciq.unique = &unique
ciq.ctx.Unique = &unique
return ciq
}
// Order adds an order step to the query.
func (ciq *ConfigItemQuery) Order(o ...OrderFunc) *ConfigItemQuery {
// Order specifies how the records should be ordered.
func (ciq *ConfigItemQuery) Order(o ...configitem.OrderOption) *ConfigItemQuery {
ciq.order = append(ciq.order, o...)
return ciq
}
@ -62,7 +60,7 @@ func (ciq *ConfigItemQuery) Order(o ...OrderFunc) *ConfigItemQuery {
// First returns the first ConfigItem entity from the query.
// Returns a *NotFoundError when no ConfigItem was found.
func (ciq *ConfigItemQuery) First(ctx context.Context) (*ConfigItem, error) {
nodes, err := ciq.Limit(1).All(ctx)
nodes, err := ciq.Limit(1).All(setContextOp(ctx, ciq.ctx, "First"))
if err != nil {
return nil, err
}
@ -85,7 +83,7 @@ func (ciq *ConfigItemQuery) FirstX(ctx context.Context) *ConfigItem {
// Returns a *NotFoundError when no ConfigItem ID was found.
func (ciq *ConfigItemQuery) FirstID(ctx context.Context) (id int, err error) {
var ids []int
if ids, err = ciq.Limit(1).IDs(ctx); err != nil {
if ids, err = ciq.Limit(1).IDs(setContextOp(ctx, ciq.ctx, "FirstID")); err != nil {
return
}
if len(ids) == 0 {
@ -108,7 +106,7 @@ func (ciq *ConfigItemQuery) FirstIDX(ctx context.Context) int {
// Returns a *NotSingularError when more than one ConfigItem entity is found.
// Returns a *NotFoundError when no ConfigItem entities are found.
func (ciq *ConfigItemQuery) Only(ctx context.Context) (*ConfigItem, error) {
nodes, err := ciq.Limit(2).All(ctx)
nodes, err := ciq.Limit(2).All(setContextOp(ctx, ciq.ctx, "Only"))
if err != nil {
return nil, err
}
@ -136,7 +134,7 @@ func (ciq *ConfigItemQuery) OnlyX(ctx context.Context) *ConfigItem {
// Returns a *NotFoundError when no entities are found.
func (ciq *ConfigItemQuery) OnlyID(ctx context.Context) (id int, err error) {
var ids []int
if ids, err = ciq.Limit(2).IDs(ctx); err != nil {
if ids, err = ciq.Limit(2).IDs(setContextOp(ctx, ciq.ctx, "OnlyID")); err != nil {
return
}
switch len(ids) {
@ -161,10 +159,12 @@ func (ciq *ConfigItemQuery) OnlyIDX(ctx context.Context) int {
// All executes the query and returns a list of ConfigItems.
func (ciq *ConfigItemQuery) All(ctx context.Context) ([]*ConfigItem, error) {
ctx = setContextOp(ctx, ciq.ctx, "All")
if err := ciq.prepareQuery(ctx); err != nil {
return nil, err
}
return ciq.sqlAll(ctx)
qr := querierAll[[]*ConfigItem, *ConfigItemQuery]()
return withInterceptors[[]*ConfigItem](ctx, ciq, qr, ciq.inters)
}
// AllX is like All, but panics if an error occurs.
@ -177,9 +177,12 @@ func (ciq *ConfigItemQuery) AllX(ctx context.Context) []*ConfigItem {
}
// IDs executes the query and returns a list of ConfigItem IDs.
func (ciq *ConfigItemQuery) IDs(ctx context.Context) ([]int, error) {
var ids []int
if err := ciq.Select(configitem.FieldID).Scan(ctx, &ids); err != nil {
func (ciq *ConfigItemQuery) IDs(ctx context.Context) (ids []int, err error) {
if ciq.ctx.Unique == nil && ciq.path != nil {
ciq.Unique(true)
}
ctx = setContextOp(ctx, ciq.ctx, "IDs")
if err = ciq.Select(configitem.FieldID).Scan(ctx, &ids); err != nil {
return nil, err
}
return ids, nil
@ -196,10 +199,11 @@ func (ciq *ConfigItemQuery) IDsX(ctx context.Context) []int {
// Count returns the count of the given query.
func (ciq *ConfigItemQuery) Count(ctx context.Context) (int, error) {
ctx = setContextOp(ctx, ciq.ctx, "Count")
if err := ciq.prepareQuery(ctx); err != nil {
return 0, err
}
return ciq.sqlCount(ctx)
return withInterceptors[int](ctx, ciq, querierCount[*ConfigItemQuery](), ciq.inters)
}
// CountX is like Count, but panics if an error occurs.
@ -213,10 +217,15 @@ func (ciq *ConfigItemQuery) CountX(ctx context.Context) int {
// Exist returns true if the query has elements in the graph.
func (ciq *ConfigItemQuery) Exist(ctx context.Context) (bool, error) {
if err := ciq.prepareQuery(ctx); err != nil {
return false, err
ctx = setContextOp(ctx, ciq.ctx, "Exist")
switch _, err := ciq.FirstID(ctx); {
case IsNotFound(err):
return false, nil
case err != nil:
return false, fmt.Errorf("ent: check existence: %w", err)
default:
return true, nil
}
return ciq.sqlExist(ctx)
}
// ExistX is like Exist, but panics if an error occurs.
@ -236,14 +245,13 @@ func (ciq *ConfigItemQuery) Clone() *ConfigItemQuery {
}
return &ConfigItemQuery{
config: ciq.config,
limit: ciq.limit,
offset: ciq.offset,
order: append([]OrderFunc{}, ciq.order...),
ctx: ciq.ctx.Clone(),
order: append([]configitem.OrderOption{}, ciq.order...),
inters: append([]Interceptor{}, ciq.inters...),
predicates: append([]predicate.ConfigItem{}, ciq.predicates...),
// clone intermediate query.
sql: ciq.sql.Clone(),
path: ciq.path,
unique: ciq.unique,
sql: ciq.sql.Clone(),
path: ciq.path,
}
}
@ -262,16 +270,11 @@ func (ciq *ConfigItemQuery) Clone() *ConfigItemQuery {
// Aggregate(ent.Count()).
// Scan(ctx, &v)
func (ciq *ConfigItemQuery) GroupBy(field string, fields ...string) *ConfigItemGroupBy {
grbuild := &ConfigItemGroupBy{config: ciq.config}
grbuild.fields = append([]string{field}, fields...)
grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) {
if err := ciq.prepareQuery(ctx); err != nil {
return nil, err
}
return ciq.sqlQuery(ctx), nil
}
ciq.ctx.Fields = append([]string{field}, fields...)
grbuild := &ConfigItemGroupBy{build: ciq}
grbuild.flds = &ciq.ctx.Fields
grbuild.label = configitem.Label
grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan
grbuild.scan = grbuild.Scan
return grbuild
}
@ -288,15 +291,30 @@ func (ciq *ConfigItemQuery) GroupBy(field string, fields ...string) *ConfigItemG
// Select(configitem.FieldCreatedAt).
// Scan(ctx, &v)
func (ciq *ConfigItemQuery) Select(fields ...string) *ConfigItemSelect {
ciq.fields = append(ciq.fields, fields...)
selbuild := &ConfigItemSelect{ConfigItemQuery: ciq}
selbuild.label = configitem.Label
selbuild.flds, selbuild.scan = &ciq.fields, selbuild.Scan
return selbuild
ciq.ctx.Fields = append(ciq.ctx.Fields, fields...)
sbuild := &ConfigItemSelect{ConfigItemQuery: ciq}
sbuild.label = configitem.Label
sbuild.flds, sbuild.scan = &ciq.ctx.Fields, sbuild.Scan
return sbuild
}
// Aggregate returns a ConfigItemSelect configured with the given aggregations.
func (ciq *ConfigItemQuery) Aggregate(fns ...AggregateFunc) *ConfigItemSelect {
return ciq.Select().Aggregate(fns...)
}
func (ciq *ConfigItemQuery) prepareQuery(ctx context.Context) error {
for _, f := range ciq.fields {
for _, inter := range ciq.inters {
if inter == nil {
return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
}
if trv, ok := inter.(Traverser); ok {
if err := trv.Traverse(ctx, ciq); err != nil {
return err
}
}
}
for _, f := range ciq.ctx.Fields {
if !configitem.ValidColumn(f) {
return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
@ -338,41 +356,22 @@ func (ciq *ConfigItemQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*
func (ciq *ConfigItemQuery) sqlCount(ctx context.Context) (int, error) {
_spec := ciq.querySpec()
_spec.Node.Columns = ciq.fields
if len(ciq.fields) > 0 {
_spec.Unique = ciq.unique != nil && *ciq.unique
_spec.Node.Columns = ciq.ctx.Fields
if len(ciq.ctx.Fields) > 0 {
_spec.Unique = ciq.ctx.Unique != nil && *ciq.ctx.Unique
}
return sqlgraph.CountNodes(ctx, ciq.driver, _spec)
}
func (ciq *ConfigItemQuery) sqlExist(ctx context.Context) (bool, error) {
switch _, err := ciq.FirstID(ctx); {
case IsNotFound(err):
return false, nil
case err != nil:
return false, fmt.Errorf("ent: check existence: %w", err)
default:
return true, nil
}
}
func (ciq *ConfigItemQuery) querySpec() *sqlgraph.QuerySpec {
_spec := &sqlgraph.QuerySpec{
Node: &sqlgraph.NodeSpec{
Table: configitem.Table,
Columns: configitem.Columns,
ID: &sqlgraph.FieldSpec{
Type: field.TypeInt,
Column: configitem.FieldID,
},
},
From: ciq.sql,
Unique: true,
}
if unique := ciq.unique; unique != nil {
_spec := sqlgraph.NewQuerySpec(configitem.Table, configitem.Columns, sqlgraph.NewFieldSpec(configitem.FieldID, field.TypeInt))
_spec.From = ciq.sql
if unique := ciq.ctx.Unique; unique != nil {
_spec.Unique = *unique
} else if ciq.path != nil {
_spec.Unique = true
}
if fields := ciq.fields; len(fields) > 0 {
if fields := ciq.ctx.Fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, configitem.FieldID)
for i := range fields {
@ -388,10 +387,10 @@ func (ciq *ConfigItemQuery) querySpec() *sqlgraph.QuerySpec {
}
}
}
if limit := ciq.limit; limit != nil {
if limit := ciq.ctx.Limit; limit != nil {
_spec.Limit = *limit
}
if offset := ciq.offset; offset != nil {
if offset := ciq.ctx.Offset; offset != nil {
_spec.Offset = *offset
}
if ps := ciq.order; len(ps) > 0 {
@ -407,7 +406,7 @@ func (ciq *ConfigItemQuery) querySpec() *sqlgraph.QuerySpec {
func (ciq *ConfigItemQuery) sqlQuery(ctx context.Context) *sql.Selector {
builder := sql.Dialect(ciq.driver.Dialect())
t1 := builder.Table(configitem.Table)
columns := ciq.fields
columns := ciq.ctx.Fields
if len(columns) == 0 {
columns = configitem.Columns
}
@ -416,7 +415,7 @@ func (ciq *ConfigItemQuery) sqlQuery(ctx context.Context) *sql.Selector {
selector = ciq.sql
selector.Select(selector.Columns(columns...)...)
}
if ciq.unique != nil && *ciq.unique {
if ciq.ctx.Unique != nil && *ciq.ctx.Unique {
selector.Distinct()
}
for _, p := range ciq.predicates {
@ -425,12 +424,12 @@ func (ciq *ConfigItemQuery) sqlQuery(ctx context.Context) *sql.Selector {
for _, p := range ciq.order {
p(selector)
}
if offset := ciq.offset; offset != nil {
if offset := ciq.ctx.Offset; offset != nil {
// limit is mandatory for offset clause. We start
// with default value, and override it below if needed.
selector.Offset(*offset).Limit(math.MaxInt32)
}
if limit := ciq.limit; limit != nil {
if limit := ciq.ctx.Limit; limit != nil {
selector.Limit(*limit)
}
return selector
@ -438,13 +437,8 @@ func (ciq *ConfigItemQuery) sqlQuery(ctx context.Context) *sql.Selector {
// ConfigItemGroupBy is the group-by builder for ConfigItem entities.
type ConfigItemGroupBy struct {
config
selector
fields []string
fns []AggregateFunc
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
build *ConfigItemQuery
}
// Aggregate adds the given aggregation functions to the group-by query.
@ -453,74 +447,77 @@ func (cigb *ConfigItemGroupBy) Aggregate(fns ...AggregateFunc) *ConfigItemGroupB
return cigb
}
// Scan applies the group-by query and scans the result into the given value.
// Scan applies the selector query and scans the result into the given value.
func (cigb *ConfigItemGroupBy) Scan(ctx context.Context, v any) error {
query, err := cigb.path(ctx)
if err != nil {
ctx = setContextOp(ctx, cigb.build.ctx, "GroupBy")
if err := cigb.build.prepareQuery(ctx); err != nil {
return err
}
cigb.sql = query
return cigb.sqlScan(ctx, v)
return scanWithInterceptors[*ConfigItemQuery, *ConfigItemGroupBy](ctx, cigb.build, cigb, cigb.build.inters, v)
}
func (cigb *ConfigItemGroupBy) sqlScan(ctx context.Context, v any) error {
for _, f := range cigb.fields {
if !configitem.ValidColumn(f) {
return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)}
}
}
selector := cigb.sqlQuery()
if err := selector.Err(); err != nil {
return err
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := cigb.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}
func (cigb *ConfigItemGroupBy) sqlQuery() *sql.Selector {
selector := cigb.sql.Select()
func (cigb *ConfigItemGroupBy) sqlScan(ctx context.Context, root *ConfigItemQuery, v any) error {
selector := root.sqlQuery(ctx).Select()
aggregation := make([]string, 0, len(cigb.fns))
for _, fn := range cigb.fns {
aggregation = append(aggregation, fn(selector))
}
// If no columns were selected in a custom aggregation function, the default
// selection is the fields used for "group-by", and the aggregation functions.
if len(selector.SelectedColumns()) == 0 {
columns := make([]string, 0, len(cigb.fields)+len(cigb.fns))
for _, f := range cigb.fields {
columns := make([]string, 0, len(*cigb.flds)+len(cigb.fns))
for _, f := range *cigb.flds {
columns = append(columns, selector.C(f))
}
columns = append(columns, aggregation...)
selector.Select(columns...)
}
return selector.GroupBy(selector.Columns(cigb.fields...)...)
selector.GroupBy(selector.Columns(*cigb.flds...)...)
if err := selector.Err(); err != nil {
return err
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := cigb.build.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// ConfigItemSelect is the builder for selecting fields of ConfigItem entities.
type ConfigItemSelect struct {
*ConfigItemQuery
selector
// intermediate query (i.e. traversal path).
sql *sql.Selector
}
// Aggregate adds the given aggregation functions to the selector query.
func (cis *ConfigItemSelect) Aggregate(fns ...AggregateFunc) *ConfigItemSelect {
cis.fns = append(cis.fns, fns...)
return cis
}
// Scan applies the selector query and scans the result into the given value.
func (cis *ConfigItemSelect) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, cis.ctx, "Select")
if err := cis.prepareQuery(ctx); err != nil {
return err
}
cis.sql = cis.ConfigItemQuery.sqlQuery(ctx)
return cis.sqlScan(ctx, v)
return scanWithInterceptors[*ConfigItemQuery, *ConfigItemSelect](ctx, cis.ConfigItemQuery, cis, cis.inters, v)
}
func (cis *ConfigItemSelect) sqlScan(ctx context.Context, v any) error {
func (cis *ConfigItemSelect) sqlScan(ctx context.Context, root *ConfigItemQuery, v any) error {
selector := root.sqlQuery(ctx)
aggregation := make([]string, 0, len(cis.fns))
for _, fn := range cis.fns {
aggregation = append(aggregation, fn(selector))
}
switch n := len(*cis.selector.flds); {
case n == 0 && len(aggregation) > 0:
selector.Select(aggregation...)
case n != 0 && len(aggregation) > 0:
selector.AppendSelect(aggregation...)
}
rows := &sql.Rows{}
query, args := cis.sql.Query()
query, args := selector.Query()
if err := cis.driver.Query(ctx, query, args, rows); err != nil {
return err
}

View file

@ -71,35 +71,8 @@ func (ciu *ConfigItemUpdate) Mutation() *ConfigItemMutation {
// Save executes the query and returns the number of nodes affected by the update operation.
func (ciu *ConfigItemUpdate) Save(ctx context.Context) (int, error) {
var (
err error
affected int
)
ciu.defaults()
if len(ciu.hooks) == 0 {
affected, err = ciu.sqlSave(ctx)
} else {
var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
mutation, ok := m.(*ConfigItemMutation)
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T", m)
}
ciu.mutation = mutation
affected, err = ciu.sqlSave(ctx)
mutation.done = true
return affected, err
})
for i := len(ciu.hooks) - 1; i >= 0; i-- {
if ciu.hooks[i] == nil {
return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)")
}
mut = ciu.hooks[i](mut)
}
if _, err := mut.Mutate(ctx, ciu.mutation); err != nil {
return 0, err
}
}
return affected, err
return withHooks(ctx, ciu.sqlSave, ciu.mutation, ciu.hooks)
}
// SaveX is like Save, but panics if an error occurs.
@ -137,16 +110,7 @@ func (ciu *ConfigItemUpdate) defaults() {
}
func (ciu *ConfigItemUpdate) sqlSave(ctx context.Context) (n int, err error) {
_spec := &sqlgraph.UpdateSpec{
Node: &sqlgraph.NodeSpec{
Table: configitem.Table,
Columns: configitem.Columns,
ID: &sqlgraph.FieldSpec{
Type: field.TypeInt,
Column: configitem.FieldID,
},
},
}
_spec := sqlgraph.NewUpdateSpec(configitem.Table, configitem.Columns, sqlgraph.NewFieldSpec(configitem.FieldID, field.TypeInt))
if ps := ciu.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
@ -155,44 +119,22 @@ func (ciu *ConfigItemUpdate) sqlSave(ctx context.Context) (n int, err error) {
}
}
if value, ok := ciu.mutation.CreatedAt(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Value: value,
Column: configitem.FieldCreatedAt,
})
_spec.SetField(configitem.FieldCreatedAt, field.TypeTime, value)
}
if ciu.mutation.CreatedAtCleared() {
_spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Column: configitem.FieldCreatedAt,
})
_spec.ClearField(configitem.FieldCreatedAt, field.TypeTime)
}
if value, ok := ciu.mutation.UpdatedAt(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Value: value,
Column: configitem.FieldUpdatedAt,
})
_spec.SetField(configitem.FieldUpdatedAt, field.TypeTime, value)
}
if ciu.mutation.UpdatedAtCleared() {
_spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Column: configitem.FieldUpdatedAt,
})
_spec.ClearField(configitem.FieldUpdatedAt, field.TypeTime)
}
if value, ok := ciu.mutation.Name(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: configitem.FieldName,
})
_spec.SetField(configitem.FieldName, field.TypeString, value)
}
if value, ok := ciu.mutation.Value(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: configitem.FieldValue,
})
_spec.SetField(configitem.FieldValue, field.TypeString, value)
}
if n, err = sqlgraph.UpdateNodes(ctx, ciu.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
@ -202,6 +144,7 @@ func (ciu *ConfigItemUpdate) sqlSave(ctx context.Context) (n int, err error) {
}
return 0, err
}
ciu.mutation.done = true
return n, nil
}
@ -254,6 +197,12 @@ func (ciuo *ConfigItemUpdateOne) Mutation() *ConfigItemMutation {
return ciuo.mutation
}
// Where appends a list predicates to the ConfigItemUpdate builder.
func (ciuo *ConfigItemUpdateOne) Where(ps ...predicate.ConfigItem) *ConfigItemUpdateOne {
ciuo.mutation.Where(ps...)
return ciuo
}
// Select allows selecting one or more fields (columns) of the returned entity.
// The default is selecting all fields defined in the entity schema.
func (ciuo *ConfigItemUpdateOne) Select(field string, fields ...string) *ConfigItemUpdateOne {
@ -263,41 +212,8 @@ func (ciuo *ConfigItemUpdateOne) Select(field string, fields ...string) *ConfigI
// Save executes the query and returns the updated ConfigItem entity.
func (ciuo *ConfigItemUpdateOne) Save(ctx context.Context) (*ConfigItem, error) {
var (
err error
node *ConfigItem
)
ciuo.defaults()
if len(ciuo.hooks) == 0 {
node, err = ciuo.sqlSave(ctx)
} else {
var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
mutation, ok := m.(*ConfigItemMutation)
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T", m)
}
ciuo.mutation = mutation
node, err = ciuo.sqlSave(ctx)
mutation.done = true
return node, err
})
for i := len(ciuo.hooks) - 1; i >= 0; i-- {
if ciuo.hooks[i] == nil {
return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)")
}
mut = ciuo.hooks[i](mut)
}
v, err := mut.Mutate(ctx, ciuo.mutation)
if err != nil {
return nil, err
}
nv, ok := v.(*ConfigItem)
if !ok {
return nil, fmt.Errorf("unexpected node type %T returned from ConfigItemMutation", v)
}
node = nv
}
return node, err
return withHooks(ctx, ciuo.sqlSave, ciuo.mutation, ciuo.hooks)
}
// SaveX is like Save, but panics if an error occurs.
@ -335,16 +251,7 @@ func (ciuo *ConfigItemUpdateOne) defaults() {
}
func (ciuo *ConfigItemUpdateOne) sqlSave(ctx context.Context) (_node *ConfigItem, err error) {
_spec := &sqlgraph.UpdateSpec{
Node: &sqlgraph.NodeSpec{
Table: configitem.Table,
Columns: configitem.Columns,
ID: &sqlgraph.FieldSpec{
Type: field.TypeInt,
Column: configitem.FieldID,
},
},
}
_spec := sqlgraph.NewUpdateSpec(configitem.Table, configitem.Columns, sqlgraph.NewFieldSpec(configitem.FieldID, field.TypeInt))
id, ok := ciuo.mutation.ID()
if !ok {
return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "ConfigItem.id" for update`)}
@ -370,44 +277,22 @@ func (ciuo *ConfigItemUpdateOne) sqlSave(ctx context.Context) (_node *ConfigItem
}
}
if value, ok := ciuo.mutation.CreatedAt(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Value: value,
Column: configitem.FieldCreatedAt,
})
_spec.SetField(configitem.FieldCreatedAt, field.TypeTime, value)
}
if ciuo.mutation.CreatedAtCleared() {
_spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Column: configitem.FieldCreatedAt,
})
_spec.ClearField(configitem.FieldCreatedAt, field.TypeTime)
}
if value, ok := ciuo.mutation.UpdatedAt(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Value: value,
Column: configitem.FieldUpdatedAt,
})
_spec.SetField(configitem.FieldUpdatedAt, field.TypeTime, value)
}
if ciuo.mutation.UpdatedAtCleared() {
_spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Column: configitem.FieldUpdatedAt,
})
_spec.ClearField(configitem.FieldUpdatedAt, field.TypeTime)
}
if value, ok := ciuo.mutation.Name(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: configitem.FieldName,
})
_spec.SetField(configitem.FieldName, field.TypeString, value)
}
if value, ok := ciuo.mutation.Value(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: configitem.FieldValue,
})
_spec.SetField(configitem.FieldValue, field.TypeString, value)
}
_node = &ConfigItem{config: ciuo.config}
_spec.Assign = _node.assignValues
@ -420,5 +305,6 @@ func (ciuo *ConfigItemUpdateOne) sqlSave(ctx context.Context) (_node *ConfigItem
}
return nil, err
}
ciuo.mutation.done = true
return _node, nil
}

View file

@ -1,33 +0,0 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"context"
)
type clientCtxKey struct{}
// FromContext returns a Client stored inside a context, or nil if there isn't one.
func FromContext(ctx context.Context) *Client {
c, _ := ctx.Value(clientCtxKey{}).(*Client)
return c
}
// NewContext returns a new context with the given Client attached.
func NewContext(parent context.Context, c *Client) context.Context {
return context.WithValue(parent, clientCtxKey{}, c)
}
type txCtxKey struct{}
// TxFromContext returns a Tx stored inside a context, or nil if there isn't one.
func TxFromContext(ctx context.Context) *Tx {
tx, _ := ctx.Value(txCtxKey{}).(*Tx)
return tx
}
// NewTxContext returns a new context with the given Tx attached.
func NewTxContext(parent context.Context, tx *Tx) context.Context {
return context.WithValue(parent, txCtxKey{}, tx)
}

View file

@ -7,6 +7,7 @@ import (
"strings"
"time"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/alert"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/decision"
@ -51,7 +52,8 @@ type Decision struct {
AlertDecisions int `json:"alert_decisions,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the DecisionQuery when eager-loading is set.
Edges DecisionEdges `json:"edges"`
Edges DecisionEdges `json:"edges"`
selectValues sql.SelectValues
}
// DecisionEdges holds the relations/edges for other nodes in the graph.
@ -90,7 +92,7 @@ func (*Decision) scanValues(columns []string) ([]any, error) {
case decision.FieldCreatedAt, decision.FieldUpdatedAt, decision.FieldUntil:
values[i] = new(sql.NullTime)
default:
return nil, fmt.Errorf("unexpected column %q for type Decision", columns[i])
values[i] = new(sql.UnknownType)
}
}
return values, nil
@ -209,21 +211,29 @@ func (d *Decision) assignValues(columns []string, values []any) error {
} else if value.Valid {
d.AlertDecisions = int(value.Int64)
}
default:
d.selectValues.Set(columns[i], values[i])
}
}
return nil
}
// GetValue returns the ent.Value that was dynamically selected and assigned to the Decision.
// This includes values selected through modifiers, order, etc.
func (d *Decision) GetValue(name string) (ent.Value, error) {
return d.selectValues.Get(name)
}
// QueryOwner queries the "owner" edge of the Decision entity.
func (d *Decision) QueryOwner() *AlertQuery {
return (&DecisionClient{config: d.config}).QueryOwner(d)
return NewDecisionClient(d.config).QueryOwner(d)
}
// Update returns a builder for updating this Decision.
// Note that you need to call Decision.Unwrap() before calling this method if this Decision
// was returned from a transaction, and the transaction was committed or rolled back.
func (d *Decision) Update() *DecisionUpdateOne {
return (&DecisionClient{config: d.config}).UpdateOne(d)
return NewDecisionClient(d.config).UpdateOne(d)
}
// Unwrap unwraps the Decision entity that was returned from a transaction after it was closed,
@ -301,9 +311,3 @@ func (d *Decision) String() string {
// Decisions is a parsable slice of Decision.
type Decisions []*Decision
func (d Decisions) config(cfg config) {
for _i := range d {
d[_i].config = cfg
}
}

View file

@ -4,6 +4,9 @@ package decision
import (
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
)
const (
@ -99,3 +102,105 @@ var (
// DefaultSimulated holds the default value on creation for the "simulated" field.
DefaultSimulated bool
)
// OrderOption defines the ordering options for the Decision queries.
type OrderOption func(*sql.Selector)
// ByID orders the results by the id field.
func ByID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldID, opts...).ToFunc()
}
// ByCreatedAt orders the results by the created_at field.
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
}
// ByUpdatedAt orders the results by the updated_at field.
func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
}
// ByUntil orders the results by the until field.
func ByUntil(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUntil, opts...).ToFunc()
}
// ByScenario orders the results by the scenario field.
func ByScenario(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldScenario, opts...).ToFunc()
}
// ByType orders the results by the type field.
func ByType(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldType, opts...).ToFunc()
}
// ByStartIP orders the results by the start_ip field.
func ByStartIP(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldStartIP, opts...).ToFunc()
}
// ByEndIP orders the results by the end_ip field.
func ByEndIP(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldEndIP, opts...).ToFunc()
}
// ByStartSuffix orders the results by the start_suffix field.
func ByStartSuffix(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldStartSuffix, opts...).ToFunc()
}
// ByEndSuffix orders the results by the end_suffix field.
func ByEndSuffix(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldEndSuffix, opts...).ToFunc()
}
// ByIPSize orders the results by the ip_size field.
func ByIPSize(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldIPSize, opts...).ToFunc()
}
// ByScope orders the results by the scope field.
func ByScope(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldScope, opts...).ToFunc()
}
// ByValue orders the results by the value field.
func ByValue(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldValue, opts...).ToFunc()
}
// ByOrigin orders the results by the origin field.
func ByOrigin(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldOrigin, opts...).ToFunc()
}
// BySimulated orders the results by the simulated field.
func BySimulated(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSimulated, opts...).ToFunc()
}
// ByUUID orders the results by the uuid field.
func ByUUID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUUID, opts...).ToFunc()
}
// ByAlertDecisions orders the results by the alert_decisions field.
func ByAlertDecisions(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldAlertDecisions, opts...).ToFunc()
}
// ByOwnerField orders the results by owner field.
func ByOwnerField(field string, opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newOwnerStep(), sql.OrderByField(field, opts...))
}
}
func newOwnerStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(OwnerInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn),
)
}

File diff suppressed because it is too large Load diff

View file

@ -231,50 +231,8 @@ func (dc *DecisionCreate) Mutation() *DecisionMutation {
// Save creates the Decision in the database.
func (dc *DecisionCreate) Save(ctx context.Context) (*Decision, error) {
var (
err error
node *Decision
)
dc.defaults()
if len(dc.hooks) == 0 {
if err = dc.check(); err != nil {
return nil, err
}
node, err = dc.sqlSave(ctx)
} else {
var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
mutation, ok := m.(*DecisionMutation)
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T", m)
}
if err = dc.check(); err != nil {
return nil, err
}
dc.mutation = mutation
if node, err = dc.sqlSave(ctx); err != nil {
return nil, err
}
mutation.id = &node.ID
mutation.done = true
return node, err
})
for i := len(dc.hooks) - 1; i >= 0; i-- {
if dc.hooks[i] == nil {
return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)")
}
mut = dc.hooks[i](mut)
}
v, err := mut.Mutate(ctx, dc.mutation)
if err != nil {
return nil, err
}
nv, ok := v.(*Decision)
if !ok {
return nil, fmt.Errorf("unexpected node type %T returned from DecisionMutation", v)
}
node = nv
}
return node, err
return withHooks(ctx, dc.sqlSave, dc.mutation, dc.hooks)
}
// SaveX calls Save and panics if Save returns an error.
@ -339,6 +297,9 @@ func (dc *DecisionCreate) check() error {
}
func (dc *DecisionCreate) sqlSave(ctx context.Context) (*Decision, error) {
if err := dc.check(); err != nil {
return nil, err
}
_node, _spec := dc.createSpec()
if err := sqlgraph.CreateNode(ctx, dc.driver, _spec); err != nil {
if sqlgraph.IsConstraintError(err) {
@ -348,138 +309,74 @@ func (dc *DecisionCreate) sqlSave(ctx context.Context) (*Decision, error) {
}
id := _spec.ID.Value.(int64)
_node.ID = int(id)
dc.mutation.id = &_node.ID
dc.mutation.done = true
return _node, nil
}
func (dc *DecisionCreate) createSpec() (*Decision, *sqlgraph.CreateSpec) {
var (
_node = &Decision{config: dc.config}
_spec = &sqlgraph.CreateSpec{
Table: decision.Table,
ID: &sqlgraph.FieldSpec{
Type: field.TypeInt,
Column: decision.FieldID,
},
}
_spec = sqlgraph.NewCreateSpec(decision.Table, sqlgraph.NewFieldSpec(decision.FieldID, field.TypeInt))
)
if value, ok := dc.mutation.CreatedAt(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Value: value,
Column: decision.FieldCreatedAt,
})
_spec.SetField(decision.FieldCreatedAt, field.TypeTime, value)
_node.CreatedAt = &value
}
if value, ok := dc.mutation.UpdatedAt(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Value: value,
Column: decision.FieldUpdatedAt,
})
_spec.SetField(decision.FieldUpdatedAt, field.TypeTime, value)
_node.UpdatedAt = &value
}
if value, ok := dc.mutation.Until(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Value: value,
Column: decision.FieldUntil,
})
_spec.SetField(decision.FieldUntil, field.TypeTime, value)
_node.Until = &value
}
if value, ok := dc.mutation.Scenario(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: decision.FieldScenario,
})
_spec.SetField(decision.FieldScenario, field.TypeString, value)
_node.Scenario = value
}
if value, ok := dc.mutation.GetType(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: decision.FieldType,
})
_spec.SetField(decision.FieldType, field.TypeString, value)
_node.Type = value
}
if value, ok := dc.mutation.StartIP(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeInt64,
Value: value,
Column: decision.FieldStartIP,
})
_spec.SetField(decision.FieldStartIP, field.TypeInt64, value)
_node.StartIP = value
}
if value, ok := dc.mutation.EndIP(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeInt64,
Value: value,
Column: decision.FieldEndIP,
})
_spec.SetField(decision.FieldEndIP, field.TypeInt64, value)
_node.EndIP = value
}
if value, ok := dc.mutation.StartSuffix(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeInt64,
Value: value,
Column: decision.FieldStartSuffix,
})
_spec.SetField(decision.FieldStartSuffix, field.TypeInt64, value)
_node.StartSuffix = value
}
if value, ok := dc.mutation.EndSuffix(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeInt64,
Value: value,
Column: decision.FieldEndSuffix,
})
_spec.SetField(decision.FieldEndSuffix, field.TypeInt64, value)
_node.EndSuffix = value
}
if value, ok := dc.mutation.IPSize(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeInt64,
Value: value,
Column: decision.FieldIPSize,
})
_spec.SetField(decision.FieldIPSize, field.TypeInt64, value)
_node.IPSize = value
}
if value, ok := dc.mutation.Scope(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: decision.FieldScope,
})
_spec.SetField(decision.FieldScope, field.TypeString, value)
_node.Scope = value
}
if value, ok := dc.mutation.Value(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: decision.FieldValue,
})
_spec.SetField(decision.FieldValue, field.TypeString, value)
_node.Value = value
}
if value, ok := dc.mutation.Origin(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: decision.FieldOrigin,
})
_spec.SetField(decision.FieldOrigin, field.TypeString, value)
_node.Origin = value
}
if value, ok := dc.mutation.Simulated(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeBool,
Value: value,
Column: decision.FieldSimulated,
})
_spec.SetField(decision.FieldSimulated, field.TypeBool, value)
_node.Simulated = value
}
if value, ok := dc.mutation.UUID(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: decision.FieldUUID,
})
_spec.SetField(decision.FieldUUID, field.TypeString, value)
_node.UUID = value
}
if nodes := dc.mutation.OwnerIDs(); len(nodes) > 0 {
@ -490,10 +387,7 @@ func (dc *DecisionCreate) createSpec() (*Decision, *sqlgraph.CreateSpec) {
Columns: []string{decision.OwnerColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: &sqlgraph.FieldSpec{
Type: field.TypeInt,
Column: alert.FieldID,
},
IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt),
},
}
for _, k := range nodes {
@ -508,11 +402,15 @@ func (dc *DecisionCreate) createSpec() (*Decision, *sqlgraph.CreateSpec) {
// DecisionCreateBulk is the builder for creating many Decision entities in bulk.
type DecisionCreateBulk struct {
config
err error
builders []*DecisionCreate
}
// Save creates the Decision entities in the database.
func (dcb *DecisionCreateBulk) Save(ctx context.Context) ([]*Decision, error) {
if dcb.err != nil {
return nil, dcb.err
}
specs := make([]*sqlgraph.CreateSpec, len(dcb.builders))
nodes := make([]*Decision, len(dcb.builders))
mutators := make([]Mutator, len(dcb.builders))
@ -529,8 +427,8 @@ func (dcb *DecisionCreateBulk) Save(ctx context.Context) ([]*Decision, error) {
return nil, err
}
builder.mutation = mutation
nodes[i], specs[i] = builder.createSpec()
var err error
nodes[i], specs[i] = builder.createSpec()
if i < len(mutators)-1 {
_, err = mutators[i+1].Mutate(root, dcb.builders[i+1].mutation)
} else {

View file

@ -4,7 +4,6 @@ package ent
import (
"context"
"fmt"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
@ -28,34 +27,7 @@ func (dd *DecisionDelete) Where(ps ...predicate.Decision) *DecisionDelete {
// Exec executes the deletion query and returns how many vertices were deleted.
func (dd *DecisionDelete) Exec(ctx context.Context) (int, error) {
var (
err error
affected int
)
if len(dd.hooks) == 0 {
affected, err = dd.sqlExec(ctx)
} else {
var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
mutation, ok := m.(*DecisionMutation)
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T", m)
}
dd.mutation = mutation
affected, err = dd.sqlExec(ctx)
mutation.done = true
return affected, err
})
for i := len(dd.hooks) - 1; i >= 0; i-- {
if dd.hooks[i] == nil {
return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)")
}
mut = dd.hooks[i](mut)
}
if _, err := mut.Mutate(ctx, dd.mutation); err != nil {
return 0, err
}
}
return affected, err
return withHooks(ctx, dd.sqlExec, dd.mutation, dd.hooks)
}
// ExecX is like Exec, but panics if an error occurs.
@ -68,15 +40,7 @@ func (dd *DecisionDelete) ExecX(ctx context.Context) int {
}
func (dd *DecisionDelete) sqlExec(ctx context.Context) (int, error) {
_spec := &sqlgraph.DeleteSpec{
Node: &sqlgraph.NodeSpec{
Table: decision.Table,
ID: &sqlgraph.FieldSpec{
Type: field.TypeInt,
Column: decision.FieldID,
},
},
}
_spec := sqlgraph.NewDeleteSpec(decision.Table, sqlgraph.NewFieldSpec(decision.FieldID, field.TypeInt))
if ps := dd.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
@ -88,6 +52,7 @@ func (dd *DecisionDelete) sqlExec(ctx context.Context) (int, error) {
if err != nil && sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
dd.mutation.done = true
return affected, err
}
@ -96,6 +61,12 @@ type DecisionDeleteOne struct {
dd *DecisionDelete
}
// Where appends a list predicates to the DecisionDelete builder.
func (ddo *DecisionDeleteOne) Where(ps ...predicate.Decision) *DecisionDeleteOne {
ddo.dd.mutation.Where(ps...)
return ddo
}
// Exec executes the deletion query.
func (ddo *DecisionDeleteOne) Exec(ctx context.Context) error {
n, err := ddo.dd.Exec(ctx)
@ -111,5 +82,7 @@ func (ddo *DecisionDeleteOne) Exec(ctx context.Context) error {
// ExecX is like Exec, but panics if an error occurs.
func (ddo *DecisionDeleteOne) ExecX(ctx context.Context) {
ddo.dd.ExecX(ctx)
if err := ddo.Exec(ctx); err != nil {
panic(err)
}
}

View file

@ -18,11 +18,9 @@ import (
// DecisionQuery is the builder for querying Decision entities.
type DecisionQuery struct {
config
limit *int
offset *int
unique *bool
order []OrderFunc
fields []string
ctx *QueryContext
order []decision.OrderOption
inters []Interceptor
predicates []predicate.Decision
withOwner *AlertQuery
// intermediate query (i.e. traversal path).
@ -36,34 +34,34 @@ func (dq *DecisionQuery) Where(ps ...predicate.Decision) *DecisionQuery {
return dq
}
// Limit adds a limit step to the query.
// Limit the number of records to be returned by this query.
func (dq *DecisionQuery) Limit(limit int) *DecisionQuery {
dq.limit = &limit
dq.ctx.Limit = &limit
return dq
}
// Offset adds an offset step to the query.
// Offset to start from.
func (dq *DecisionQuery) Offset(offset int) *DecisionQuery {
dq.offset = &offset
dq.ctx.Offset = &offset
return dq
}
// Unique configures the query builder to filter duplicate records on query.
// By default, unique is set to true, and can be disabled using this method.
func (dq *DecisionQuery) Unique(unique bool) *DecisionQuery {
dq.unique = &unique
dq.ctx.Unique = &unique
return dq
}
// Order adds an order step to the query.
func (dq *DecisionQuery) Order(o ...OrderFunc) *DecisionQuery {
// Order specifies how the records should be ordered.
func (dq *DecisionQuery) Order(o ...decision.OrderOption) *DecisionQuery {
dq.order = append(dq.order, o...)
return dq
}
// QueryOwner chains the current query on the "owner" edge.
func (dq *DecisionQuery) QueryOwner() *AlertQuery {
query := &AlertQuery{config: dq.config}
query := (&AlertClient{config: dq.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := dq.prepareQuery(ctx); err != nil {
return nil, err
@ -86,7 +84,7 @@ func (dq *DecisionQuery) QueryOwner() *AlertQuery {
// First returns the first Decision entity from the query.
// Returns a *NotFoundError when no Decision was found.
func (dq *DecisionQuery) First(ctx context.Context) (*Decision, error) {
nodes, err := dq.Limit(1).All(ctx)
nodes, err := dq.Limit(1).All(setContextOp(ctx, dq.ctx, "First"))
if err != nil {
return nil, err
}
@ -109,7 +107,7 @@ func (dq *DecisionQuery) FirstX(ctx context.Context) *Decision {
// Returns a *NotFoundError when no Decision ID was found.
func (dq *DecisionQuery) FirstID(ctx context.Context) (id int, err error) {
var ids []int
if ids, err = dq.Limit(1).IDs(ctx); err != nil {
if ids, err = dq.Limit(1).IDs(setContextOp(ctx, dq.ctx, "FirstID")); err != nil {
return
}
if len(ids) == 0 {
@ -132,7 +130,7 @@ func (dq *DecisionQuery) FirstIDX(ctx context.Context) int {
// Returns a *NotSingularError when more than one Decision entity is found.
// Returns a *NotFoundError when no Decision entities are found.
func (dq *DecisionQuery) Only(ctx context.Context) (*Decision, error) {
nodes, err := dq.Limit(2).All(ctx)
nodes, err := dq.Limit(2).All(setContextOp(ctx, dq.ctx, "Only"))
if err != nil {
return nil, err
}
@ -160,7 +158,7 @@ func (dq *DecisionQuery) OnlyX(ctx context.Context) *Decision {
// Returns a *NotFoundError when no entities are found.
func (dq *DecisionQuery) OnlyID(ctx context.Context) (id int, err error) {
var ids []int
if ids, err = dq.Limit(2).IDs(ctx); err != nil {
if ids, err = dq.Limit(2).IDs(setContextOp(ctx, dq.ctx, "OnlyID")); err != nil {
return
}
switch len(ids) {
@ -185,10 +183,12 @@ func (dq *DecisionQuery) OnlyIDX(ctx context.Context) int {
// All executes the query and returns a list of Decisions.
func (dq *DecisionQuery) All(ctx context.Context) ([]*Decision, error) {
ctx = setContextOp(ctx, dq.ctx, "All")
if err := dq.prepareQuery(ctx); err != nil {
return nil, err
}
return dq.sqlAll(ctx)
qr := querierAll[[]*Decision, *DecisionQuery]()
return withInterceptors[[]*Decision](ctx, dq, qr, dq.inters)
}
// AllX is like All, but panics if an error occurs.
@ -201,9 +201,12 @@ func (dq *DecisionQuery) AllX(ctx context.Context) []*Decision {
}
// IDs executes the query and returns a list of Decision IDs.
func (dq *DecisionQuery) IDs(ctx context.Context) ([]int, error) {
var ids []int
if err := dq.Select(decision.FieldID).Scan(ctx, &ids); err != nil {
func (dq *DecisionQuery) IDs(ctx context.Context) (ids []int, err error) {
if dq.ctx.Unique == nil && dq.path != nil {
dq.Unique(true)
}
ctx = setContextOp(ctx, dq.ctx, "IDs")
if err = dq.Select(decision.FieldID).Scan(ctx, &ids); err != nil {
return nil, err
}
return ids, nil
@ -220,10 +223,11 @@ func (dq *DecisionQuery) IDsX(ctx context.Context) []int {
// Count returns the count of the given query.
func (dq *DecisionQuery) Count(ctx context.Context) (int, error) {
ctx = setContextOp(ctx, dq.ctx, "Count")
if err := dq.prepareQuery(ctx); err != nil {
return 0, err
}
return dq.sqlCount(ctx)
return withInterceptors[int](ctx, dq, querierCount[*DecisionQuery](), dq.inters)
}
// CountX is like Count, but panics if an error occurs.
@ -237,10 +241,15 @@ func (dq *DecisionQuery) CountX(ctx context.Context) int {
// Exist returns true if the query has elements in the graph.
func (dq *DecisionQuery) Exist(ctx context.Context) (bool, error) {
if err := dq.prepareQuery(ctx); err != nil {
return false, err
ctx = setContextOp(ctx, dq.ctx, "Exist")
switch _, err := dq.FirstID(ctx); {
case IsNotFound(err):
return false, nil
case err != nil:
return false, fmt.Errorf("ent: check existence: %w", err)
default:
return true, nil
}
return dq.sqlExist(ctx)
}
// ExistX is like Exist, but panics if an error occurs.
@ -260,22 +269,21 @@ func (dq *DecisionQuery) Clone() *DecisionQuery {
}
return &DecisionQuery{
config: dq.config,
limit: dq.limit,
offset: dq.offset,
order: append([]OrderFunc{}, dq.order...),
ctx: dq.ctx.Clone(),
order: append([]decision.OrderOption{}, dq.order...),
inters: append([]Interceptor{}, dq.inters...),
predicates: append([]predicate.Decision{}, dq.predicates...),
withOwner: dq.withOwner.Clone(),
// clone intermediate query.
sql: dq.sql.Clone(),
path: dq.path,
unique: dq.unique,
sql: dq.sql.Clone(),
path: dq.path,
}
}
// WithOwner tells the query-builder to eager-load the nodes that are connected to
// the "owner" edge. The optional arguments are used to configure the query builder of the edge.
func (dq *DecisionQuery) WithOwner(opts ...func(*AlertQuery)) *DecisionQuery {
query := &AlertQuery{config: dq.config}
query := (&AlertClient{config: dq.config}).Query()
for _, opt := range opts {
opt(query)
}
@ -298,16 +306,11 @@ func (dq *DecisionQuery) WithOwner(opts ...func(*AlertQuery)) *DecisionQuery {
// Aggregate(ent.Count()).
// Scan(ctx, &v)
func (dq *DecisionQuery) GroupBy(field string, fields ...string) *DecisionGroupBy {
grbuild := &DecisionGroupBy{config: dq.config}
grbuild.fields = append([]string{field}, fields...)
grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) {
if err := dq.prepareQuery(ctx); err != nil {
return nil, err
}
return dq.sqlQuery(ctx), nil
}
dq.ctx.Fields = append([]string{field}, fields...)
grbuild := &DecisionGroupBy{build: dq}
grbuild.flds = &dq.ctx.Fields
grbuild.label = decision.Label
grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan
grbuild.scan = grbuild.Scan
return grbuild
}
@ -324,15 +327,30 @@ func (dq *DecisionQuery) GroupBy(field string, fields ...string) *DecisionGroupB
// Select(decision.FieldCreatedAt).
// Scan(ctx, &v)
func (dq *DecisionQuery) Select(fields ...string) *DecisionSelect {
dq.fields = append(dq.fields, fields...)
selbuild := &DecisionSelect{DecisionQuery: dq}
selbuild.label = decision.Label
selbuild.flds, selbuild.scan = &dq.fields, selbuild.Scan
return selbuild
dq.ctx.Fields = append(dq.ctx.Fields, fields...)
sbuild := &DecisionSelect{DecisionQuery: dq}
sbuild.label = decision.Label
sbuild.flds, sbuild.scan = &dq.ctx.Fields, sbuild.Scan
return sbuild
}
// Aggregate returns a DecisionSelect configured with the given aggregations.
func (dq *DecisionQuery) Aggregate(fns ...AggregateFunc) *DecisionSelect {
return dq.Select().Aggregate(fns...)
}
func (dq *DecisionQuery) prepareQuery(ctx context.Context) error {
for _, f := range dq.fields {
for _, inter := range dq.inters {
if inter == nil {
return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
}
if trv, ok := inter.(Traverser); ok {
if err := trv.Traverse(ctx, dq); err != nil {
return err
}
}
}
for _, f := range dq.ctx.Fields {
if !decision.ValidColumn(f) {
return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
@ -392,6 +410,9 @@ func (dq *DecisionQuery) loadOwner(ctx context.Context, query *AlertQuery, nodes
}
nodeids[fk] = append(nodeids[fk], nodes[i])
}
if len(ids) == 0 {
return nil
}
query.Where(alert.IDIn(ids...))
neighbors, err := query.All(ctx)
if err != nil {
@ -411,41 +432,22 @@ func (dq *DecisionQuery) loadOwner(ctx context.Context, query *AlertQuery, nodes
func (dq *DecisionQuery) sqlCount(ctx context.Context) (int, error) {
_spec := dq.querySpec()
_spec.Node.Columns = dq.fields
if len(dq.fields) > 0 {
_spec.Unique = dq.unique != nil && *dq.unique
_spec.Node.Columns = dq.ctx.Fields
if len(dq.ctx.Fields) > 0 {
_spec.Unique = dq.ctx.Unique != nil && *dq.ctx.Unique
}
return sqlgraph.CountNodes(ctx, dq.driver, _spec)
}
func (dq *DecisionQuery) sqlExist(ctx context.Context) (bool, error) {
switch _, err := dq.FirstID(ctx); {
case IsNotFound(err):
return false, nil
case err != nil:
return false, fmt.Errorf("ent: check existence: %w", err)
default:
return true, nil
}
}
func (dq *DecisionQuery) querySpec() *sqlgraph.QuerySpec {
_spec := &sqlgraph.QuerySpec{
Node: &sqlgraph.NodeSpec{
Table: decision.Table,
Columns: decision.Columns,
ID: &sqlgraph.FieldSpec{
Type: field.TypeInt,
Column: decision.FieldID,
},
},
From: dq.sql,
Unique: true,
}
if unique := dq.unique; unique != nil {
_spec := sqlgraph.NewQuerySpec(decision.Table, decision.Columns, sqlgraph.NewFieldSpec(decision.FieldID, field.TypeInt))
_spec.From = dq.sql
if unique := dq.ctx.Unique; unique != nil {
_spec.Unique = *unique
} else if dq.path != nil {
_spec.Unique = true
}
if fields := dq.fields; len(fields) > 0 {
if fields := dq.ctx.Fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, decision.FieldID)
for i := range fields {
@ -453,6 +455,9 @@ func (dq *DecisionQuery) querySpec() *sqlgraph.QuerySpec {
_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
}
}
if dq.withOwner != nil {
_spec.Node.AddColumnOnce(decision.FieldAlertDecisions)
}
}
if ps := dq.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
@ -461,10 +466,10 @@ func (dq *DecisionQuery) querySpec() *sqlgraph.QuerySpec {
}
}
}
if limit := dq.limit; limit != nil {
if limit := dq.ctx.Limit; limit != nil {
_spec.Limit = *limit
}
if offset := dq.offset; offset != nil {
if offset := dq.ctx.Offset; offset != nil {
_spec.Offset = *offset
}
if ps := dq.order; len(ps) > 0 {
@ -480,7 +485,7 @@ func (dq *DecisionQuery) querySpec() *sqlgraph.QuerySpec {
func (dq *DecisionQuery) sqlQuery(ctx context.Context) *sql.Selector {
builder := sql.Dialect(dq.driver.Dialect())
t1 := builder.Table(decision.Table)
columns := dq.fields
columns := dq.ctx.Fields
if len(columns) == 0 {
columns = decision.Columns
}
@ -489,7 +494,7 @@ func (dq *DecisionQuery) sqlQuery(ctx context.Context) *sql.Selector {
selector = dq.sql
selector.Select(selector.Columns(columns...)...)
}
if dq.unique != nil && *dq.unique {
if dq.ctx.Unique != nil && *dq.ctx.Unique {
selector.Distinct()
}
for _, p := range dq.predicates {
@ -498,12 +503,12 @@ func (dq *DecisionQuery) sqlQuery(ctx context.Context) *sql.Selector {
for _, p := range dq.order {
p(selector)
}
if offset := dq.offset; offset != nil {
if offset := dq.ctx.Offset; offset != nil {
// limit is mandatory for offset clause. We start
// with default value, and override it below if needed.
selector.Offset(*offset).Limit(math.MaxInt32)
}
if limit := dq.limit; limit != nil {
if limit := dq.ctx.Limit; limit != nil {
selector.Limit(*limit)
}
return selector
@ -511,13 +516,8 @@ func (dq *DecisionQuery) sqlQuery(ctx context.Context) *sql.Selector {
// DecisionGroupBy is the group-by builder for Decision entities.
type DecisionGroupBy struct {
config
selector
fields []string
fns []AggregateFunc
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
build *DecisionQuery
}
// Aggregate adds the given aggregation functions to the group-by query.
@ -526,74 +526,77 @@ func (dgb *DecisionGroupBy) Aggregate(fns ...AggregateFunc) *DecisionGroupBy {
return dgb
}
// Scan applies the group-by query and scans the result into the given value.
// Scan applies the selector query and scans the result into the given value.
func (dgb *DecisionGroupBy) Scan(ctx context.Context, v any) error {
query, err := dgb.path(ctx)
if err != nil {
ctx = setContextOp(ctx, dgb.build.ctx, "GroupBy")
if err := dgb.build.prepareQuery(ctx); err != nil {
return err
}
dgb.sql = query
return dgb.sqlScan(ctx, v)
return scanWithInterceptors[*DecisionQuery, *DecisionGroupBy](ctx, dgb.build, dgb, dgb.build.inters, v)
}
func (dgb *DecisionGroupBy) sqlScan(ctx context.Context, v any) error {
for _, f := range dgb.fields {
if !decision.ValidColumn(f) {
return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)}
}
}
selector := dgb.sqlQuery()
if err := selector.Err(); err != nil {
return err
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := dgb.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}
func (dgb *DecisionGroupBy) sqlQuery() *sql.Selector {
selector := dgb.sql.Select()
func (dgb *DecisionGroupBy) sqlScan(ctx context.Context, root *DecisionQuery, v any) error {
selector := root.sqlQuery(ctx).Select()
aggregation := make([]string, 0, len(dgb.fns))
for _, fn := range dgb.fns {
aggregation = append(aggregation, fn(selector))
}
// If no columns were selected in a custom aggregation function, the default
// selection is the fields used for "group-by", and the aggregation functions.
if len(selector.SelectedColumns()) == 0 {
columns := make([]string, 0, len(dgb.fields)+len(dgb.fns))
for _, f := range dgb.fields {
columns := make([]string, 0, len(*dgb.flds)+len(dgb.fns))
for _, f := range *dgb.flds {
columns = append(columns, selector.C(f))
}
columns = append(columns, aggregation...)
selector.Select(columns...)
}
return selector.GroupBy(selector.Columns(dgb.fields...)...)
selector.GroupBy(selector.Columns(*dgb.flds...)...)
if err := selector.Err(); err != nil {
return err
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := dgb.build.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// DecisionSelect is the builder for selecting fields of Decision entities.
type DecisionSelect struct {
*DecisionQuery
selector
// intermediate query (i.e. traversal path).
sql *sql.Selector
}
// Aggregate adds the given aggregation functions to the selector query.
func (ds *DecisionSelect) Aggregate(fns ...AggregateFunc) *DecisionSelect {
ds.fns = append(ds.fns, fns...)
return ds
}
// Scan applies the selector query and scans the result into the given value.
func (ds *DecisionSelect) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, ds.ctx, "Select")
if err := ds.prepareQuery(ctx); err != nil {
return err
}
ds.sql = ds.DecisionQuery.sqlQuery(ctx)
return ds.sqlScan(ctx, v)
return scanWithInterceptors[*DecisionQuery, *DecisionSelect](ctx, ds.DecisionQuery, ds, ds.inters, v)
}
func (ds *DecisionSelect) sqlScan(ctx context.Context, v any) error {
func (ds *DecisionSelect) sqlScan(ctx context.Context, root *DecisionQuery, v any) error {
selector := root.sqlQuery(ctx)
aggregation := make([]string, 0, len(ds.fns))
for _, fn := range ds.fns {
aggregation = append(aggregation, fn(selector))
}
switch n := len(*ds.selector.flds); {
case n == 0 && len(aggregation) > 0:
selector.Select(aggregation...)
case n != 0 && len(aggregation) > 0:
selector.AppendSelect(aggregation...)
}
rows := &sql.Rows{}
query, args := ds.sql.Query()
query, args := selector.Query()
if err := ds.driver.Query(ctx, query, args, rows); err != nil {
return err
}

View file

@ -324,35 +324,8 @@ func (du *DecisionUpdate) ClearOwner() *DecisionUpdate {
// Save executes the query and returns the number of nodes affected by the update operation.
func (du *DecisionUpdate) Save(ctx context.Context) (int, error) {
var (
err error
affected int
)
du.defaults()
if len(du.hooks) == 0 {
affected, err = du.sqlSave(ctx)
} else {
var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
mutation, ok := m.(*DecisionMutation)
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T", m)
}
du.mutation = mutation
affected, err = du.sqlSave(ctx)
mutation.done = true
return affected, err
})
for i := len(du.hooks) - 1; i >= 0; i-- {
if du.hooks[i] == nil {
return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)")
}
mut = du.hooks[i](mut)
}
if _, err := mut.Mutate(ctx, du.mutation); err != nil {
return 0, err
}
}
return affected, err
return withHooks(ctx, du.sqlSave, du.mutation, du.hooks)
}
// SaveX is like Save, but panics if an error occurs.
@ -390,16 +363,7 @@ func (du *DecisionUpdate) defaults() {
}
func (du *DecisionUpdate) sqlSave(ctx context.Context) (n int, err error) {
_spec := &sqlgraph.UpdateSpec{
Node: &sqlgraph.NodeSpec{
Table: decision.Table,
Columns: decision.Columns,
ID: &sqlgraph.FieldSpec{
Type: field.TypeInt,
Column: decision.FieldID,
},
},
}
_spec := sqlgraph.NewUpdateSpec(decision.Table, decision.Columns, sqlgraph.NewFieldSpec(decision.FieldID, field.TypeInt))
if ps := du.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
@ -408,198 +372,91 @@ func (du *DecisionUpdate) sqlSave(ctx context.Context) (n int, err error) {
}
}
if value, ok := du.mutation.CreatedAt(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Value: value,
Column: decision.FieldCreatedAt,
})
_spec.SetField(decision.FieldCreatedAt, field.TypeTime, value)
}
if du.mutation.CreatedAtCleared() {
_spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Column: decision.FieldCreatedAt,
})
_spec.ClearField(decision.FieldCreatedAt, field.TypeTime)
}
if value, ok := du.mutation.UpdatedAt(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Value: value,
Column: decision.FieldUpdatedAt,
})
_spec.SetField(decision.FieldUpdatedAt, field.TypeTime, value)
}
if du.mutation.UpdatedAtCleared() {
_spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Column: decision.FieldUpdatedAt,
})
_spec.ClearField(decision.FieldUpdatedAt, field.TypeTime)
}
if value, ok := du.mutation.Until(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Value: value,
Column: decision.FieldUntil,
})
_spec.SetField(decision.FieldUntil, field.TypeTime, value)
}
if du.mutation.UntilCleared() {
_spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Column: decision.FieldUntil,
})
_spec.ClearField(decision.FieldUntil, field.TypeTime)
}
if value, ok := du.mutation.Scenario(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: decision.FieldScenario,
})
_spec.SetField(decision.FieldScenario, field.TypeString, value)
}
if value, ok := du.mutation.GetType(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: decision.FieldType,
})
_spec.SetField(decision.FieldType, field.TypeString, value)
}
if value, ok := du.mutation.StartIP(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeInt64,
Value: value,
Column: decision.FieldStartIP,
})
_spec.SetField(decision.FieldStartIP, field.TypeInt64, value)
}
if value, ok := du.mutation.AddedStartIP(); ok {
_spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{
Type: field.TypeInt64,
Value: value,
Column: decision.FieldStartIP,
})
_spec.AddField(decision.FieldStartIP, field.TypeInt64, value)
}
if du.mutation.StartIPCleared() {
_spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{
Type: field.TypeInt64,
Column: decision.FieldStartIP,
})
_spec.ClearField(decision.FieldStartIP, field.TypeInt64)
}
if value, ok := du.mutation.EndIP(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeInt64,
Value: value,
Column: decision.FieldEndIP,
})
_spec.SetField(decision.FieldEndIP, field.TypeInt64, value)
}
if value, ok := du.mutation.AddedEndIP(); ok {
_spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{
Type: field.TypeInt64,
Value: value,
Column: decision.FieldEndIP,
})
_spec.AddField(decision.FieldEndIP, field.TypeInt64, value)
}
if du.mutation.EndIPCleared() {
_spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{
Type: field.TypeInt64,
Column: decision.FieldEndIP,
})
_spec.ClearField(decision.FieldEndIP, field.TypeInt64)
}
if value, ok := du.mutation.StartSuffix(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeInt64,
Value: value,
Column: decision.FieldStartSuffix,
})
_spec.SetField(decision.FieldStartSuffix, field.TypeInt64, value)
}
if value, ok := du.mutation.AddedStartSuffix(); ok {
_spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{
Type: field.TypeInt64,
Value: value,
Column: decision.FieldStartSuffix,
})
_spec.AddField(decision.FieldStartSuffix, field.TypeInt64, value)
}
if du.mutation.StartSuffixCleared() {
_spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{
Type: field.TypeInt64,
Column: decision.FieldStartSuffix,
})
_spec.ClearField(decision.FieldStartSuffix, field.TypeInt64)
}
if value, ok := du.mutation.EndSuffix(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeInt64,
Value: value,
Column: decision.FieldEndSuffix,
})
_spec.SetField(decision.FieldEndSuffix, field.TypeInt64, value)
}
if value, ok := du.mutation.AddedEndSuffix(); ok {
_spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{
Type: field.TypeInt64,
Value: value,
Column: decision.FieldEndSuffix,
})
_spec.AddField(decision.FieldEndSuffix, field.TypeInt64, value)
}
if du.mutation.EndSuffixCleared() {
_spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{
Type: field.TypeInt64,
Column: decision.FieldEndSuffix,
})
_spec.ClearField(decision.FieldEndSuffix, field.TypeInt64)
}
if value, ok := du.mutation.IPSize(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeInt64,
Value: value,
Column: decision.FieldIPSize,
})
_spec.SetField(decision.FieldIPSize, field.TypeInt64, value)
}
if value, ok := du.mutation.AddedIPSize(); ok {
_spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{
Type: field.TypeInt64,
Value: value,
Column: decision.FieldIPSize,
})
_spec.AddField(decision.FieldIPSize, field.TypeInt64, value)
}
if du.mutation.IPSizeCleared() {
_spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{
Type: field.TypeInt64,
Column: decision.FieldIPSize,
})
_spec.ClearField(decision.FieldIPSize, field.TypeInt64)
}
if value, ok := du.mutation.Scope(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: decision.FieldScope,
})
_spec.SetField(decision.FieldScope, field.TypeString, value)
}
if value, ok := du.mutation.Value(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: decision.FieldValue,
})
_spec.SetField(decision.FieldValue, field.TypeString, value)
}
if value, ok := du.mutation.Origin(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: decision.FieldOrigin,
})
_spec.SetField(decision.FieldOrigin, field.TypeString, value)
}
if value, ok := du.mutation.Simulated(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeBool,
Value: value,
Column: decision.FieldSimulated,
})
_spec.SetField(decision.FieldSimulated, field.TypeBool, value)
}
if value, ok := du.mutation.UUID(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: decision.FieldUUID,
})
_spec.SetField(decision.FieldUUID, field.TypeString, value)
}
if du.mutation.UUIDCleared() {
_spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{
Type: field.TypeString,
Column: decision.FieldUUID,
})
_spec.ClearField(decision.FieldUUID, field.TypeString)
}
if du.mutation.OwnerCleared() {
edge := &sqlgraph.EdgeSpec{
@ -609,10 +466,7 @@ func (du *DecisionUpdate) sqlSave(ctx context.Context) (n int, err error) {
Columns: []string{decision.OwnerColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: &sqlgraph.FieldSpec{
Type: field.TypeInt,
Column: alert.FieldID,
},
IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
@ -625,10 +479,7 @@ func (du *DecisionUpdate) sqlSave(ctx context.Context) (n int, err error) {
Columns: []string{decision.OwnerColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: &sqlgraph.FieldSpec{
Type: field.TypeInt,
Column: alert.FieldID,
},
IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt),
},
}
for _, k := range nodes {
@ -644,6 +495,7 @@ func (du *DecisionUpdate) sqlSave(ctx context.Context) (n int, err error) {
}
return 0, err
}
du.mutation.done = true
return n, nil
}
@ -948,6 +800,12 @@ func (duo *DecisionUpdateOne) ClearOwner() *DecisionUpdateOne {
return duo
}
// Where appends a list predicates to the DecisionUpdate builder.
func (duo *DecisionUpdateOne) Where(ps ...predicate.Decision) *DecisionUpdateOne {
duo.mutation.Where(ps...)
return duo
}
// Select allows selecting one or more fields (columns) of the returned entity.
// The default is selecting all fields defined in the entity schema.
func (duo *DecisionUpdateOne) Select(field string, fields ...string) *DecisionUpdateOne {
@ -957,41 +815,8 @@ func (duo *DecisionUpdateOne) Select(field string, fields ...string) *DecisionUp
// Save executes the query and returns the updated Decision entity.
func (duo *DecisionUpdateOne) Save(ctx context.Context) (*Decision, error) {
var (
err error
node *Decision
)
duo.defaults()
if len(duo.hooks) == 0 {
node, err = duo.sqlSave(ctx)
} else {
var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
mutation, ok := m.(*DecisionMutation)
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T", m)
}
duo.mutation = mutation
node, err = duo.sqlSave(ctx)
mutation.done = true
return node, err
})
for i := len(duo.hooks) - 1; i >= 0; i-- {
if duo.hooks[i] == nil {
return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)")
}
mut = duo.hooks[i](mut)
}
v, err := mut.Mutate(ctx, duo.mutation)
if err != nil {
return nil, err
}
nv, ok := v.(*Decision)
if !ok {
return nil, fmt.Errorf("unexpected node type %T returned from DecisionMutation", v)
}
node = nv
}
return node, err
return withHooks(ctx, duo.sqlSave, duo.mutation, duo.hooks)
}
// SaveX is like Save, but panics if an error occurs.
@ -1029,16 +854,7 @@ func (duo *DecisionUpdateOne) defaults() {
}
func (duo *DecisionUpdateOne) sqlSave(ctx context.Context) (_node *Decision, err error) {
_spec := &sqlgraph.UpdateSpec{
Node: &sqlgraph.NodeSpec{
Table: decision.Table,
Columns: decision.Columns,
ID: &sqlgraph.FieldSpec{
Type: field.TypeInt,
Column: decision.FieldID,
},
},
}
_spec := sqlgraph.NewUpdateSpec(decision.Table, decision.Columns, sqlgraph.NewFieldSpec(decision.FieldID, field.TypeInt))
id, ok := duo.mutation.ID()
if !ok {
return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Decision.id" for update`)}
@ -1064,198 +880,91 @@ func (duo *DecisionUpdateOne) sqlSave(ctx context.Context) (_node *Decision, err
}
}
if value, ok := duo.mutation.CreatedAt(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Value: value,
Column: decision.FieldCreatedAt,
})
_spec.SetField(decision.FieldCreatedAt, field.TypeTime, value)
}
if duo.mutation.CreatedAtCleared() {
_spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Column: decision.FieldCreatedAt,
})
_spec.ClearField(decision.FieldCreatedAt, field.TypeTime)
}
if value, ok := duo.mutation.UpdatedAt(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Value: value,
Column: decision.FieldUpdatedAt,
})
_spec.SetField(decision.FieldUpdatedAt, field.TypeTime, value)
}
if duo.mutation.UpdatedAtCleared() {
_spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Column: decision.FieldUpdatedAt,
})
_spec.ClearField(decision.FieldUpdatedAt, field.TypeTime)
}
if value, ok := duo.mutation.Until(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Value: value,
Column: decision.FieldUntil,
})
_spec.SetField(decision.FieldUntil, field.TypeTime, value)
}
if duo.mutation.UntilCleared() {
_spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Column: decision.FieldUntil,
})
_spec.ClearField(decision.FieldUntil, field.TypeTime)
}
if value, ok := duo.mutation.Scenario(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: decision.FieldScenario,
})
_spec.SetField(decision.FieldScenario, field.TypeString, value)
}
if value, ok := duo.mutation.GetType(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: decision.FieldType,
})
_spec.SetField(decision.FieldType, field.TypeString, value)
}
if value, ok := duo.mutation.StartIP(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeInt64,
Value: value,
Column: decision.FieldStartIP,
})
_spec.SetField(decision.FieldStartIP, field.TypeInt64, value)
}
if value, ok := duo.mutation.AddedStartIP(); ok {
_spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{
Type: field.TypeInt64,
Value: value,
Column: decision.FieldStartIP,
})
_spec.AddField(decision.FieldStartIP, field.TypeInt64, value)
}
if duo.mutation.StartIPCleared() {
_spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{
Type: field.TypeInt64,
Column: decision.FieldStartIP,
})
_spec.ClearField(decision.FieldStartIP, field.TypeInt64)
}
if value, ok := duo.mutation.EndIP(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeInt64,
Value: value,
Column: decision.FieldEndIP,
})
_spec.SetField(decision.FieldEndIP, field.TypeInt64, value)
}
if value, ok := duo.mutation.AddedEndIP(); ok {
_spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{
Type: field.TypeInt64,
Value: value,
Column: decision.FieldEndIP,
})
_spec.AddField(decision.FieldEndIP, field.TypeInt64, value)
}
if duo.mutation.EndIPCleared() {
_spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{
Type: field.TypeInt64,
Column: decision.FieldEndIP,
})
_spec.ClearField(decision.FieldEndIP, field.TypeInt64)
}
if value, ok := duo.mutation.StartSuffix(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeInt64,
Value: value,
Column: decision.FieldStartSuffix,
})
_spec.SetField(decision.FieldStartSuffix, field.TypeInt64, value)
}
if value, ok := duo.mutation.AddedStartSuffix(); ok {
_spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{
Type: field.TypeInt64,
Value: value,
Column: decision.FieldStartSuffix,
})
_spec.AddField(decision.FieldStartSuffix, field.TypeInt64, value)
}
if duo.mutation.StartSuffixCleared() {
_spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{
Type: field.TypeInt64,
Column: decision.FieldStartSuffix,
})
_spec.ClearField(decision.FieldStartSuffix, field.TypeInt64)
}
if value, ok := duo.mutation.EndSuffix(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeInt64,
Value: value,
Column: decision.FieldEndSuffix,
})
_spec.SetField(decision.FieldEndSuffix, field.TypeInt64, value)
}
if value, ok := duo.mutation.AddedEndSuffix(); ok {
_spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{
Type: field.TypeInt64,
Value: value,
Column: decision.FieldEndSuffix,
})
_spec.AddField(decision.FieldEndSuffix, field.TypeInt64, value)
}
if duo.mutation.EndSuffixCleared() {
_spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{
Type: field.TypeInt64,
Column: decision.FieldEndSuffix,
})
_spec.ClearField(decision.FieldEndSuffix, field.TypeInt64)
}
if value, ok := duo.mutation.IPSize(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeInt64,
Value: value,
Column: decision.FieldIPSize,
})
_spec.SetField(decision.FieldIPSize, field.TypeInt64, value)
}
if value, ok := duo.mutation.AddedIPSize(); ok {
_spec.Fields.Add = append(_spec.Fields.Add, &sqlgraph.FieldSpec{
Type: field.TypeInt64,
Value: value,
Column: decision.FieldIPSize,
})
_spec.AddField(decision.FieldIPSize, field.TypeInt64, value)
}
if duo.mutation.IPSizeCleared() {
_spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{
Type: field.TypeInt64,
Column: decision.FieldIPSize,
})
_spec.ClearField(decision.FieldIPSize, field.TypeInt64)
}
if value, ok := duo.mutation.Scope(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: decision.FieldScope,
})
_spec.SetField(decision.FieldScope, field.TypeString, value)
}
if value, ok := duo.mutation.Value(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: decision.FieldValue,
})
_spec.SetField(decision.FieldValue, field.TypeString, value)
}
if value, ok := duo.mutation.Origin(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: decision.FieldOrigin,
})
_spec.SetField(decision.FieldOrigin, field.TypeString, value)
}
if value, ok := duo.mutation.Simulated(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeBool,
Value: value,
Column: decision.FieldSimulated,
})
_spec.SetField(decision.FieldSimulated, field.TypeBool, value)
}
if value, ok := duo.mutation.UUID(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: decision.FieldUUID,
})
_spec.SetField(decision.FieldUUID, field.TypeString, value)
}
if duo.mutation.UUIDCleared() {
_spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{
Type: field.TypeString,
Column: decision.FieldUUID,
})
_spec.ClearField(decision.FieldUUID, field.TypeString)
}
if duo.mutation.OwnerCleared() {
edge := &sqlgraph.EdgeSpec{
@ -1265,10 +974,7 @@ func (duo *DecisionUpdateOne) sqlSave(ctx context.Context) (_node *Decision, err
Columns: []string{decision.OwnerColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: &sqlgraph.FieldSpec{
Type: field.TypeInt,
Column: alert.FieldID,
},
IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
@ -1281,10 +987,7 @@ func (duo *DecisionUpdateOne) sqlSave(ctx context.Context) (_node *Decision, err
Columns: []string{decision.OwnerColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: &sqlgraph.FieldSpec{
Type: field.TypeInt,
Column: alert.FieldID,
},
IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt),
},
}
for _, k := range nodes {
@ -1303,5 +1006,6 @@ func (duo *DecisionUpdateOne) sqlSave(ctx context.Context) (_node *Decision, err
}
return nil, err
}
duo.mutation.done = true
return _node, nil
}

View file

@ -6,6 +6,8 @@ import (
"context"
"errors"
"fmt"
"reflect"
"sync"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
@ -15,56 +17,87 @@ import (
"github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/decision"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/event"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/lock"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/machine"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/meta"
)
// ent aliases to avoid import conflicts in user's code.
type (
Op = ent.Op
Hook = ent.Hook
Value = ent.Value
Query = ent.Query
Policy = ent.Policy
Mutator = ent.Mutator
Mutation = ent.Mutation
MutateFunc = ent.MutateFunc
Op = ent.Op
Hook = ent.Hook
Value = ent.Value
Query = ent.Query
QueryContext = ent.QueryContext
Querier = ent.Querier
QuerierFunc = ent.QuerierFunc
Interceptor = ent.Interceptor
InterceptFunc = ent.InterceptFunc
Traverser = ent.Traverser
TraverseFunc = ent.TraverseFunc
Policy = ent.Policy
Mutator = ent.Mutator
Mutation = ent.Mutation
MutateFunc = ent.MutateFunc
)
type clientCtxKey struct{}
// FromContext returns a Client stored inside a context, or nil if there isn't one.
func FromContext(ctx context.Context) *Client {
c, _ := ctx.Value(clientCtxKey{}).(*Client)
return c
}
// NewContext returns a new context with the given Client attached.
func NewContext(parent context.Context, c *Client) context.Context {
return context.WithValue(parent, clientCtxKey{}, c)
}
type txCtxKey struct{}
// TxFromContext returns a Tx stored inside a context, or nil if there isn't one.
func TxFromContext(ctx context.Context) *Tx {
tx, _ := ctx.Value(txCtxKey{}).(*Tx)
return tx
}
// NewTxContext returns a new context with the given Tx attached.
func NewTxContext(parent context.Context, tx *Tx) context.Context {
return context.WithValue(parent, txCtxKey{}, tx)
}
// OrderFunc applies an ordering on the sql selector.
// Deprecated: Use Asc/Desc functions or the package builders instead.
type OrderFunc func(*sql.Selector)
// columnChecker returns a function indicates if the column exists in the given column.
func columnChecker(table string) func(string) error {
checks := map[string]func(string) bool{
alert.Table: alert.ValidColumn,
bouncer.Table: bouncer.ValidColumn,
configitem.Table: configitem.ValidColumn,
decision.Table: decision.ValidColumn,
event.Table: event.ValidColumn,
machine.Table: machine.ValidColumn,
meta.Table: meta.ValidColumn,
}
check, ok := checks[table]
if !ok {
return func(string) error {
return fmt.Errorf("unknown table %q", table)
}
}
return func(column string) error {
if !check(column) {
return fmt.Errorf("unknown column %q for table %q", column, table)
}
return nil
}
var (
initCheck sync.Once
columnCheck sql.ColumnCheck
)
// columnChecker checks if the column exists in the given table.
func checkColumn(table, column string) error {
initCheck.Do(func() {
columnCheck = sql.NewColumnCheck(map[string]func(string) bool{
alert.Table: alert.ValidColumn,
bouncer.Table: bouncer.ValidColumn,
configitem.Table: configitem.ValidColumn,
decision.Table: decision.ValidColumn,
event.Table: event.ValidColumn,
lock.Table: lock.ValidColumn,
machine.Table: machine.ValidColumn,
meta.Table: meta.ValidColumn,
})
})
return columnCheck(table, column)
}
// Asc applies the given fields in ASC order.
func Asc(fields ...string) OrderFunc {
func Asc(fields ...string) func(*sql.Selector) {
return func(s *sql.Selector) {
check := columnChecker(s.TableName())
for _, f := range fields {
if err := check(f); err != nil {
if err := checkColumn(s.TableName(), f); err != nil {
s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ent: %w", err)})
}
s.OrderBy(sql.Asc(s.C(f)))
@ -73,11 +106,10 @@ func Asc(fields ...string) OrderFunc {
}
// Desc applies the given fields in DESC order.
func Desc(fields ...string) OrderFunc {
func Desc(fields ...string) func(*sql.Selector) {
return func(s *sql.Selector) {
check := columnChecker(s.TableName())
for _, f := range fields {
if err := check(f); err != nil {
if err := checkColumn(s.TableName(), f); err != nil {
s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ent: %w", err)})
}
s.OrderBy(sql.Desc(s.C(f)))
@ -109,8 +141,7 @@ func Count() AggregateFunc {
// Max applies the "max" aggregation function on the given field of each group.
func Max(field string) AggregateFunc {
return func(s *sql.Selector) string {
check := columnChecker(s.TableName())
if err := check(field); err != nil {
if err := checkColumn(s.TableName(), field); err != nil {
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)})
return ""
}
@ -121,8 +152,7 @@ func Max(field string) AggregateFunc {
// Mean applies the "mean" aggregation function on the given field of each group.
func Mean(field string) AggregateFunc {
return func(s *sql.Selector) string {
check := columnChecker(s.TableName())
if err := check(field); err != nil {
if err := checkColumn(s.TableName(), field); err != nil {
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)})
return ""
}
@ -133,8 +163,7 @@ func Mean(field string) AggregateFunc {
// Min applies the "min" aggregation function on the given field of each group.
func Min(field string) AggregateFunc {
return func(s *sql.Selector) string {
check := columnChecker(s.TableName())
if err := check(field); err != nil {
if err := checkColumn(s.TableName(), field); err != nil {
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)})
return ""
}
@ -145,8 +174,7 @@ func Min(field string) AggregateFunc {
// Sum applies the "sum" aggregation function on the given field of each group.
func Sum(field string) AggregateFunc {
return func(s *sql.Selector) string {
check := columnChecker(s.TableName())
if err := check(field); err != nil {
if err := checkColumn(s.TableName(), field); err != nil {
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)})
return ""
}
@ -275,6 +303,7 @@ func IsConstraintError(err error) bool {
type selector struct {
label string
flds *[]string
fns []AggregateFunc
scan func(context.Context, any) error
}
@ -473,5 +502,121 @@ func (s *selector) BoolX(ctx context.Context) bool {
return v
}
// withHooks invokes the builder operation with the given hooks, if any.
func withHooks[V Value, M any, PM interface {
*M
Mutation
}](ctx context.Context, exec func(context.Context) (V, error), mutation PM, hooks []Hook) (value V, err error) {
if len(hooks) == 0 {
return exec(ctx)
}
var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
mutationT, ok := any(m).(PM)
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T", m)
}
// Set the mutation to the builder.
*mutation = *mutationT
return exec(ctx)
})
for i := len(hooks) - 1; i >= 0; i-- {
if hooks[i] == nil {
return value, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)")
}
mut = hooks[i](mut)
}
v, err := mut.Mutate(ctx, mutation)
if err != nil {
return value, err
}
nv, ok := v.(V)
if !ok {
return value, fmt.Errorf("unexpected node type %T returned from %T", v, mutation)
}
return nv, nil
}
// setContextOp returns a new context with the given QueryContext attached (including its op) in case it does not exist.
func setContextOp(ctx context.Context, qc *QueryContext, op string) context.Context {
if ent.QueryFromContext(ctx) == nil {
qc.Op = op
ctx = ent.NewQueryContext(ctx, qc)
}
return ctx
}
func querierAll[V Value, Q interface {
sqlAll(context.Context, ...queryHook) (V, error)
}]() Querier {
return QuerierFunc(func(ctx context.Context, q Query) (Value, error) {
query, ok := q.(Q)
if !ok {
return nil, fmt.Errorf("unexpected query type %T", q)
}
return query.sqlAll(ctx)
})
}
func querierCount[Q interface {
sqlCount(context.Context) (int, error)
}]() Querier {
return QuerierFunc(func(ctx context.Context, q Query) (Value, error) {
query, ok := q.(Q)
if !ok {
return nil, fmt.Errorf("unexpected query type %T", q)
}
return query.sqlCount(ctx)
})
}
func withInterceptors[V Value](ctx context.Context, q Query, qr Querier, inters []Interceptor) (v V, err error) {
for i := len(inters) - 1; i >= 0; i-- {
qr = inters[i].Intercept(qr)
}
rv, err := qr.Query(ctx, q)
if err != nil {
return v, err
}
vt, ok := rv.(V)
if !ok {
return v, fmt.Errorf("unexpected type %T returned from %T. expected type: %T", vt, q, v)
}
return vt, nil
}
func scanWithInterceptors[Q1 ent.Query, Q2 interface {
sqlScan(context.Context, Q1, any) error
}](ctx context.Context, rootQuery Q1, selectOrGroup Q2, inters []Interceptor, v any) error {
rv := reflect.ValueOf(v)
var qr Querier = QuerierFunc(func(ctx context.Context, q Query) (Value, error) {
query, ok := q.(Q1)
if !ok {
return nil, fmt.Errorf("unexpected query type %T", q)
}
if err := selectOrGroup.sqlScan(ctx, query, v); err != nil {
return nil, err
}
if k := rv.Kind(); k == reflect.Pointer && rv.Elem().CanInterface() {
return rv.Elem().Interface(), nil
}
return v, nil
})
for i := len(inters) - 1; i >= 0; i-- {
qr = inters[i].Intercept(qr)
}
vv, err := qr.Query(ctx, rootQuery)
if err != nil {
return err
}
switch rv2 := reflect.ValueOf(vv); {
case rv.IsNil(), rv2.IsNil(), rv.Kind() != reflect.Pointer:
case rv.Type() == rv2.Type():
rv.Elem().Set(rv2.Elem())
case rv.Elem().Type() == rv2.Type():
rv.Elem().Set(rv2)
}
return nil
}
// queryHook describes an internal hook for the different sqlAll methods.
type queryHook func(context.Context, *sqlgraph.QuerySpec)

View file

@ -7,6 +7,7 @@ import (
"strings"
"time"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/alert"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/event"
@ -29,7 +30,8 @@ type Event struct {
AlertEvents int `json:"alert_events,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the EventQuery when eager-loading is set.
Edges EventEdges `json:"edges"`
Edges EventEdges `json:"edges"`
selectValues sql.SelectValues
}
// EventEdges holds the relations/edges for other nodes in the graph.
@ -66,7 +68,7 @@ func (*Event) scanValues(columns []string) ([]any, error) {
case event.FieldCreatedAt, event.FieldUpdatedAt, event.FieldTime:
values[i] = new(sql.NullTime)
default:
return nil, fmt.Errorf("unexpected column %q for type Event", columns[i])
values[i] = new(sql.UnknownType)
}
}
return values, nil
@ -118,21 +120,29 @@ func (e *Event) assignValues(columns []string, values []any) error {
} else if value.Valid {
e.AlertEvents = int(value.Int64)
}
default:
e.selectValues.Set(columns[i], values[i])
}
}
return nil
}
// Value returns the ent.Value that was dynamically selected and assigned to the Event.
// This includes values selected through modifiers, order, etc.
func (e *Event) Value(name string) (ent.Value, error) {
return e.selectValues.Get(name)
}
// QueryOwner queries the "owner" edge of the Event entity.
func (e *Event) QueryOwner() *AlertQuery {
return (&EventClient{config: e.config}).QueryOwner(e)
return NewEventClient(e.config).QueryOwner(e)
}
// Update returns a builder for updating this Event.
// Note that you need to call Event.Unwrap() before calling this method if this Event
// was returned from a transaction, and the transaction was committed or rolled back.
func (e *Event) Update() *EventUpdateOne {
return (&EventClient{config: e.config}).UpdateOne(e)
return NewEventClient(e.config).UpdateOne(e)
}
// Unwrap unwraps the Event entity that was returned from a transaction after it was closed,
@ -175,9 +185,3 @@ func (e *Event) String() string {
// Events is a parsable slice of Event.
type Events []*Event
func (e Events) config(cfg config) {
for _i := range e {
e[_i].config = cfg
}
}

View file

@ -4,6 +4,9 @@ package event
import (
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
)
const (
@ -66,3 +69,50 @@ var (
// SerializedValidator is a validator for the "serialized" field. It is called by the builders before save.
SerializedValidator func(string) error
)
// OrderOption defines the ordering options for the Event queries.
type OrderOption func(*sql.Selector)
// ByID orders the results by the id field.
func ByID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldID, opts...).ToFunc()
}
// ByCreatedAt orders the results by the created_at field.
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
}
// ByUpdatedAt orders the results by the updated_at field.
func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
}
// ByTime orders the results by the time field.
func ByTime(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldTime, opts...).ToFunc()
}
// BySerialized orders the results by the serialized field.
func BySerialized(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSerialized, opts...).ToFunc()
}
// ByAlertEvents orders the results by the alert_events field.
func ByAlertEvents(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldAlertEvents, opts...).ToFunc()
}
// ByOwnerField orders the results by owner field.
func ByOwnerField(field string, opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newOwnerStep(), sql.OrderByField(field, opts...))
}
}
func newOwnerStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(OwnerInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn),
)
}

View file

@ -12,477 +12,307 @@ import (
// ID filters vertices based on their ID field.
func ID(id int) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.EQ(s.C(FieldID), id))
})
return predicate.Event(sql.FieldEQ(FieldID, id))
}
// IDEQ applies the EQ predicate on the ID field.
func IDEQ(id int) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.EQ(s.C(FieldID), id))
})
return predicate.Event(sql.FieldEQ(FieldID, id))
}
// IDNEQ applies the NEQ predicate on the ID field.
func IDNEQ(id int) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.NEQ(s.C(FieldID), id))
})
return predicate.Event(sql.FieldNEQ(FieldID, id))
}
// IDIn applies the In predicate on the ID field.
func IDIn(ids ...int) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
v := make([]any, len(ids))
for i := range v {
v[i] = ids[i]
}
s.Where(sql.In(s.C(FieldID), v...))
})
return predicate.Event(sql.FieldIn(FieldID, ids...))
}
// IDNotIn applies the NotIn predicate on the ID field.
func IDNotIn(ids ...int) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
v := make([]any, len(ids))
for i := range v {
v[i] = ids[i]
}
s.Where(sql.NotIn(s.C(FieldID), v...))
})
return predicate.Event(sql.FieldNotIn(FieldID, ids...))
}
// IDGT applies the GT predicate on the ID field.
func IDGT(id int) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.GT(s.C(FieldID), id))
})
return predicate.Event(sql.FieldGT(FieldID, id))
}
// IDGTE applies the GTE predicate on the ID field.
func IDGTE(id int) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.GTE(s.C(FieldID), id))
})
return predicate.Event(sql.FieldGTE(FieldID, id))
}
// IDLT applies the LT predicate on the ID field.
func IDLT(id int) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.LT(s.C(FieldID), id))
})
return predicate.Event(sql.FieldLT(FieldID, id))
}
// IDLTE applies the LTE predicate on the ID field.
func IDLTE(id int) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.LTE(s.C(FieldID), id))
})
return predicate.Event(sql.FieldLTE(FieldID, id))
}
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
func CreatedAt(v time.Time) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.EQ(s.C(FieldCreatedAt), v))
})
return predicate.Event(sql.FieldEQ(FieldCreatedAt, v))
}
// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
func UpdatedAt(v time.Time) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.EQ(s.C(FieldUpdatedAt), v))
})
return predicate.Event(sql.FieldEQ(FieldUpdatedAt, v))
}
// Time applies equality check predicate on the "time" field. It's identical to TimeEQ.
func Time(v time.Time) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.EQ(s.C(FieldTime), v))
})
return predicate.Event(sql.FieldEQ(FieldTime, v))
}
// Serialized applies equality check predicate on the "serialized" field. It's identical to SerializedEQ.
func Serialized(v string) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.EQ(s.C(FieldSerialized), v))
})
return predicate.Event(sql.FieldEQ(FieldSerialized, v))
}
// AlertEvents applies equality check predicate on the "alert_events" field. It's identical to AlertEventsEQ.
func AlertEvents(v int) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.EQ(s.C(FieldAlertEvents), v))
})
return predicate.Event(sql.FieldEQ(FieldAlertEvents, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.EQ(s.C(FieldCreatedAt), v))
})
return predicate.Event(sql.FieldEQ(FieldCreatedAt, v))
}
// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
func CreatedAtNEQ(v time.Time) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.NEQ(s.C(FieldCreatedAt), v))
})
return predicate.Event(sql.FieldNEQ(FieldCreatedAt, v))
}
// CreatedAtIn applies the In predicate on the "created_at" field.
func CreatedAtIn(vs ...time.Time) predicate.Event {
v := make([]any, len(vs))
for i := range v {
v[i] = vs[i]
}
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.In(s.C(FieldCreatedAt), v...))
})
return predicate.Event(sql.FieldIn(FieldCreatedAt, vs...))
}
// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
func CreatedAtNotIn(vs ...time.Time) predicate.Event {
v := make([]any, len(vs))
for i := range v {
v[i] = vs[i]
}
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.NotIn(s.C(FieldCreatedAt), v...))
})
return predicate.Event(sql.FieldNotIn(FieldCreatedAt, vs...))
}
// CreatedAtGT applies the GT predicate on the "created_at" field.
func CreatedAtGT(v time.Time) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.GT(s.C(FieldCreatedAt), v))
})
return predicate.Event(sql.FieldGT(FieldCreatedAt, v))
}
// CreatedAtGTE applies the GTE predicate on the "created_at" field.
func CreatedAtGTE(v time.Time) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.GTE(s.C(FieldCreatedAt), v))
})
return predicate.Event(sql.FieldGTE(FieldCreatedAt, v))
}
// CreatedAtLT applies the LT predicate on the "created_at" field.
func CreatedAtLT(v time.Time) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.LT(s.C(FieldCreatedAt), v))
})
return predicate.Event(sql.FieldLT(FieldCreatedAt, v))
}
// CreatedAtLTE applies the LTE predicate on the "created_at" field.
func CreatedAtLTE(v time.Time) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.LTE(s.C(FieldCreatedAt), v))
})
return predicate.Event(sql.FieldLTE(FieldCreatedAt, v))
}
// CreatedAtIsNil applies the IsNil predicate on the "created_at" field.
func CreatedAtIsNil() predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.IsNull(s.C(FieldCreatedAt)))
})
return predicate.Event(sql.FieldIsNull(FieldCreatedAt))
}
// CreatedAtNotNil applies the NotNil predicate on the "created_at" field.
func CreatedAtNotNil() predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.NotNull(s.C(FieldCreatedAt)))
})
return predicate.Event(sql.FieldNotNull(FieldCreatedAt))
}
// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
func UpdatedAtEQ(v time.Time) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.EQ(s.C(FieldUpdatedAt), v))
})
return predicate.Event(sql.FieldEQ(FieldUpdatedAt, v))
}
// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
func UpdatedAtNEQ(v time.Time) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.NEQ(s.C(FieldUpdatedAt), v))
})
return predicate.Event(sql.FieldNEQ(FieldUpdatedAt, v))
}
// UpdatedAtIn applies the In predicate on the "updated_at" field.
func UpdatedAtIn(vs ...time.Time) predicate.Event {
v := make([]any, len(vs))
for i := range v {
v[i] = vs[i]
}
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.In(s.C(FieldUpdatedAt), v...))
})
return predicate.Event(sql.FieldIn(FieldUpdatedAt, vs...))
}
// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
func UpdatedAtNotIn(vs ...time.Time) predicate.Event {
v := make([]any, len(vs))
for i := range v {
v[i] = vs[i]
}
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.NotIn(s.C(FieldUpdatedAt), v...))
})
return predicate.Event(sql.FieldNotIn(FieldUpdatedAt, vs...))
}
// UpdatedAtGT applies the GT predicate on the "updated_at" field.
func UpdatedAtGT(v time.Time) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.GT(s.C(FieldUpdatedAt), v))
})
return predicate.Event(sql.FieldGT(FieldUpdatedAt, v))
}
// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
func UpdatedAtGTE(v time.Time) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.GTE(s.C(FieldUpdatedAt), v))
})
return predicate.Event(sql.FieldGTE(FieldUpdatedAt, v))
}
// UpdatedAtLT applies the LT predicate on the "updated_at" field.
func UpdatedAtLT(v time.Time) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.LT(s.C(FieldUpdatedAt), v))
})
return predicate.Event(sql.FieldLT(FieldUpdatedAt, v))
}
// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
func UpdatedAtLTE(v time.Time) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.LTE(s.C(FieldUpdatedAt), v))
})
return predicate.Event(sql.FieldLTE(FieldUpdatedAt, v))
}
// UpdatedAtIsNil applies the IsNil predicate on the "updated_at" field.
func UpdatedAtIsNil() predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.IsNull(s.C(FieldUpdatedAt)))
})
return predicate.Event(sql.FieldIsNull(FieldUpdatedAt))
}
// UpdatedAtNotNil applies the NotNil predicate on the "updated_at" field.
func UpdatedAtNotNil() predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.NotNull(s.C(FieldUpdatedAt)))
})
return predicate.Event(sql.FieldNotNull(FieldUpdatedAt))
}
// TimeEQ applies the EQ predicate on the "time" field.
func TimeEQ(v time.Time) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.EQ(s.C(FieldTime), v))
})
return predicate.Event(sql.FieldEQ(FieldTime, v))
}
// TimeNEQ applies the NEQ predicate on the "time" field.
func TimeNEQ(v time.Time) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.NEQ(s.C(FieldTime), v))
})
return predicate.Event(sql.FieldNEQ(FieldTime, v))
}
// TimeIn applies the In predicate on the "time" field.
func TimeIn(vs ...time.Time) predicate.Event {
v := make([]any, len(vs))
for i := range v {
v[i] = vs[i]
}
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.In(s.C(FieldTime), v...))
})
return predicate.Event(sql.FieldIn(FieldTime, vs...))
}
// TimeNotIn applies the NotIn predicate on the "time" field.
func TimeNotIn(vs ...time.Time) predicate.Event {
v := make([]any, len(vs))
for i := range v {
v[i] = vs[i]
}
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.NotIn(s.C(FieldTime), v...))
})
return predicate.Event(sql.FieldNotIn(FieldTime, vs...))
}
// TimeGT applies the GT predicate on the "time" field.
func TimeGT(v time.Time) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.GT(s.C(FieldTime), v))
})
return predicate.Event(sql.FieldGT(FieldTime, v))
}
// TimeGTE applies the GTE predicate on the "time" field.
func TimeGTE(v time.Time) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.GTE(s.C(FieldTime), v))
})
return predicate.Event(sql.FieldGTE(FieldTime, v))
}
// TimeLT applies the LT predicate on the "time" field.
func TimeLT(v time.Time) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.LT(s.C(FieldTime), v))
})
return predicate.Event(sql.FieldLT(FieldTime, v))
}
// TimeLTE applies the LTE predicate on the "time" field.
func TimeLTE(v time.Time) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.LTE(s.C(FieldTime), v))
})
return predicate.Event(sql.FieldLTE(FieldTime, v))
}
// SerializedEQ applies the EQ predicate on the "serialized" field.
func SerializedEQ(v string) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.EQ(s.C(FieldSerialized), v))
})
return predicate.Event(sql.FieldEQ(FieldSerialized, v))
}
// SerializedNEQ applies the NEQ predicate on the "serialized" field.
func SerializedNEQ(v string) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.NEQ(s.C(FieldSerialized), v))
})
return predicate.Event(sql.FieldNEQ(FieldSerialized, v))
}
// SerializedIn applies the In predicate on the "serialized" field.
func SerializedIn(vs ...string) predicate.Event {
v := make([]any, len(vs))
for i := range v {
v[i] = vs[i]
}
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.In(s.C(FieldSerialized), v...))
})
return predicate.Event(sql.FieldIn(FieldSerialized, vs...))
}
// SerializedNotIn applies the NotIn predicate on the "serialized" field.
func SerializedNotIn(vs ...string) predicate.Event {
v := make([]any, len(vs))
for i := range v {
v[i] = vs[i]
}
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.NotIn(s.C(FieldSerialized), v...))
})
return predicate.Event(sql.FieldNotIn(FieldSerialized, vs...))
}
// SerializedGT applies the GT predicate on the "serialized" field.
func SerializedGT(v string) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.GT(s.C(FieldSerialized), v))
})
return predicate.Event(sql.FieldGT(FieldSerialized, v))
}
// SerializedGTE applies the GTE predicate on the "serialized" field.
func SerializedGTE(v string) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.GTE(s.C(FieldSerialized), v))
})
return predicate.Event(sql.FieldGTE(FieldSerialized, v))
}
// SerializedLT applies the LT predicate on the "serialized" field.
func SerializedLT(v string) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.LT(s.C(FieldSerialized), v))
})
return predicate.Event(sql.FieldLT(FieldSerialized, v))
}
// SerializedLTE applies the LTE predicate on the "serialized" field.
func SerializedLTE(v string) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.LTE(s.C(FieldSerialized), v))
})
return predicate.Event(sql.FieldLTE(FieldSerialized, v))
}
// SerializedContains applies the Contains predicate on the "serialized" field.
func SerializedContains(v string) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.Contains(s.C(FieldSerialized), v))
})
return predicate.Event(sql.FieldContains(FieldSerialized, v))
}
// SerializedHasPrefix applies the HasPrefix predicate on the "serialized" field.
func SerializedHasPrefix(v string) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.HasPrefix(s.C(FieldSerialized), v))
})
return predicate.Event(sql.FieldHasPrefix(FieldSerialized, v))
}
// SerializedHasSuffix applies the HasSuffix predicate on the "serialized" field.
func SerializedHasSuffix(v string) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.HasSuffix(s.C(FieldSerialized), v))
})
return predicate.Event(sql.FieldHasSuffix(FieldSerialized, v))
}
// SerializedEqualFold applies the EqualFold predicate on the "serialized" field.
func SerializedEqualFold(v string) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.EqualFold(s.C(FieldSerialized), v))
})
return predicate.Event(sql.FieldEqualFold(FieldSerialized, v))
}
// SerializedContainsFold applies the ContainsFold predicate on the "serialized" field.
func SerializedContainsFold(v string) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.ContainsFold(s.C(FieldSerialized), v))
})
return predicate.Event(sql.FieldContainsFold(FieldSerialized, v))
}
// AlertEventsEQ applies the EQ predicate on the "alert_events" field.
func AlertEventsEQ(v int) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.EQ(s.C(FieldAlertEvents), v))
})
return predicate.Event(sql.FieldEQ(FieldAlertEvents, v))
}
// AlertEventsNEQ applies the NEQ predicate on the "alert_events" field.
func AlertEventsNEQ(v int) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.NEQ(s.C(FieldAlertEvents), v))
})
return predicate.Event(sql.FieldNEQ(FieldAlertEvents, v))
}
// AlertEventsIn applies the In predicate on the "alert_events" field.
func AlertEventsIn(vs ...int) predicate.Event {
v := make([]any, len(vs))
for i := range v {
v[i] = vs[i]
}
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.In(s.C(FieldAlertEvents), v...))
})
return predicate.Event(sql.FieldIn(FieldAlertEvents, vs...))
}
// AlertEventsNotIn applies the NotIn predicate on the "alert_events" field.
func AlertEventsNotIn(vs ...int) predicate.Event {
v := make([]any, len(vs))
for i := range v {
v[i] = vs[i]
}
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.NotIn(s.C(FieldAlertEvents), v...))
})
return predicate.Event(sql.FieldNotIn(FieldAlertEvents, vs...))
}
// AlertEventsIsNil applies the IsNil predicate on the "alert_events" field.
func AlertEventsIsNil() predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.IsNull(s.C(FieldAlertEvents)))
})
return predicate.Event(sql.FieldIsNull(FieldAlertEvents))
}
// AlertEventsNotNil applies the NotNil predicate on the "alert_events" field.
func AlertEventsNotNil() predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s.Where(sql.NotNull(s.C(FieldAlertEvents)))
})
return predicate.Event(sql.FieldNotNull(FieldAlertEvents))
}
// HasOwner applies the HasEdge predicate on the "owner" edge.
@ -490,7 +320,6 @@ func HasOwner() predicate.Event {
return predicate.Event(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(OwnerTable, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn),
)
sqlgraph.HasNeighbors(s, step)
@ -500,11 +329,7 @@ func HasOwner() predicate.Event {
// HasOwnerWith applies the HasEdge predicate on the "owner" edge with a given conditions (other predicates).
func HasOwnerWith(preds ...predicate.Alert) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
sqlgraph.To(OwnerInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, OwnerTable, OwnerColumn),
)
step := newOwnerStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds {
p(s)
@ -515,32 +340,15 @@ func HasOwnerWith(preds ...predicate.Alert) predicate.Event {
// And groups predicates with the AND operator between them.
func And(predicates ...predicate.Event) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s1 := s.Clone().SetP(nil)
for _, p := range predicates {
p(s1)
}
s.Where(s1.P())
})
return predicate.Event(sql.AndPredicates(predicates...))
}
// Or groups predicates with the OR operator between them.
func Or(predicates ...predicate.Event) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
s1 := s.Clone().SetP(nil)
for i, p := range predicates {
if i > 0 {
s1.Or()
}
p(s1)
}
s.Where(s1.P())
})
return predicate.Event(sql.OrPredicates(predicates...))
}
// Not applies the not operator on the given predicate.
func Not(p predicate.Event) predicate.Event {
return predicate.Event(func(s *sql.Selector) {
p(s.Not())
})
return predicate.Event(sql.NotPredicates(p))
}

View file

@ -101,50 +101,8 @@ func (ec *EventCreate) Mutation() *EventMutation {
// Save creates the Event in the database.
func (ec *EventCreate) Save(ctx context.Context) (*Event, error) {
var (
err error
node *Event
)
ec.defaults()
if len(ec.hooks) == 0 {
if err = ec.check(); err != nil {
return nil, err
}
node, err = ec.sqlSave(ctx)
} else {
var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
mutation, ok := m.(*EventMutation)
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T", m)
}
if err = ec.check(); err != nil {
return nil, err
}
ec.mutation = mutation
if node, err = ec.sqlSave(ctx); err != nil {
return nil, err
}
mutation.id = &node.ID
mutation.done = true
return node, err
})
for i := len(ec.hooks) - 1; i >= 0; i-- {
if ec.hooks[i] == nil {
return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)")
}
mut = ec.hooks[i](mut)
}
v, err := mut.Mutate(ctx, ec.mutation)
if err != nil {
return nil, err
}
nv, ok := v.(*Event)
if !ok {
return nil, fmt.Errorf("unexpected node type %T returned from EventMutation", v)
}
node = nv
}
return node, err
return withHooks(ctx, ec.sqlSave, ec.mutation, ec.hooks)
}
// SaveX calls Save and panics if Save returns an error.
@ -198,6 +156,9 @@ func (ec *EventCreate) check() error {
}
func (ec *EventCreate) sqlSave(ctx context.Context) (*Event, error) {
if err := ec.check(); err != nil {
return nil, err
}
_node, _spec := ec.createSpec()
if err := sqlgraph.CreateNode(ctx, ec.driver, _spec); err != nil {
if sqlgraph.IsConstraintError(err) {
@ -207,50 +168,30 @@ func (ec *EventCreate) sqlSave(ctx context.Context) (*Event, error) {
}
id := _spec.ID.Value.(int64)
_node.ID = int(id)
ec.mutation.id = &_node.ID
ec.mutation.done = true
return _node, nil
}
func (ec *EventCreate) createSpec() (*Event, *sqlgraph.CreateSpec) {
var (
_node = &Event{config: ec.config}
_spec = &sqlgraph.CreateSpec{
Table: event.Table,
ID: &sqlgraph.FieldSpec{
Type: field.TypeInt,
Column: event.FieldID,
},
}
_spec = sqlgraph.NewCreateSpec(event.Table, sqlgraph.NewFieldSpec(event.FieldID, field.TypeInt))
)
if value, ok := ec.mutation.CreatedAt(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Value: value,
Column: event.FieldCreatedAt,
})
_spec.SetField(event.FieldCreatedAt, field.TypeTime, value)
_node.CreatedAt = &value
}
if value, ok := ec.mutation.UpdatedAt(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Value: value,
Column: event.FieldUpdatedAt,
})
_spec.SetField(event.FieldUpdatedAt, field.TypeTime, value)
_node.UpdatedAt = &value
}
if value, ok := ec.mutation.Time(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Value: value,
Column: event.FieldTime,
})
_spec.SetField(event.FieldTime, field.TypeTime, value)
_node.Time = value
}
if value, ok := ec.mutation.Serialized(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: event.FieldSerialized,
})
_spec.SetField(event.FieldSerialized, field.TypeString, value)
_node.Serialized = value
}
if nodes := ec.mutation.OwnerIDs(); len(nodes) > 0 {
@ -261,10 +202,7 @@ func (ec *EventCreate) createSpec() (*Event, *sqlgraph.CreateSpec) {
Columns: []string{event.OwnerColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: &sqlgraph.FieldSpec{
Type: field.TypeInt,
Column: alert.FieldID,
},
IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt),
},
}
for _, k := range nodes {
@ -279,11 +217,15 @@ func (ec *EventCreate) createSpec() (*Event, *sqlgraph.CreateSpec) {
// EventCreateBulk is the builder for creating many Event entities in bulk.
type EventCreateBulk struct {
config
err error
builders []*EventCreate
}
// Save creates the Event entities in the database.
func (ecb *EventCreateBulk) Save(ctx context.Context) ([]*Event, error) {
if ecb.err != nil {
return nil, ecb.err
}
specs := make([]*sqlgraph.CreateSpec, len(ecb.builders))
nodes := make([]*Event, len(ecb.builders))
mutators := make([]Mutator, len(ecb.builders))
@ -300,8 +242,8 @@ func (ecb *EventCreateBulk) Save(ctx context.Context) ([]*Event, error) {
return nil, err
}
builder.mutation = mutation
nodes[i], specs[i] = builder.createSpec()
var err error
nodes[i], specs[i] = builder.createSpec()
if i < len(mutators)-1 {
_, err = mutators[i+1].Mutate(root, ecb.builders[i+1].mutation)
} else {

View file

@ -4,7 +4,6 @@ package ent
import (
"context"
"fmt"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
@ -28,34 +27,7 @@ func (ed *EventDelete) Where(ps ...predicate.Event) *EventDelete {
// Exec executes the deletion query and returns how many vertices were deleted.
func (ed *EventDelete) Exec(ctx context.Context) (int, error) {
var (
err error
affected int
)
if len(ed.hooks) == 0 {
affected, err = ed.sqlExec(ctx)
} else {
var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
mutation, ok := m.(*EventMutation)
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T", m)
}
ed.mutation = mutation
affected, err = ed.sqlExec(ctx)
mutation.done = true
return affected, err
})
for i := len(ed.hooks) - 1; i >= 0; i-- {
if ed.hooks[i] == nil {
return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)")
}
mut = ed.hooks[i](mut)
}
if _, err := mut.Mutate(ctx, ed.mutation); err != nil {
return 0, err
}
}
return affected, err
return withHooks(ctx, ed.sqlExec, ed.mutation, ed.hooks)
}
// ExecX is like Exec, but panics if an error occurs.
@ -68,15 +40,7 @@ func (ed *EventDelete) ExecX(ctx context.Context) int {
}
func (ed *EventDelete) sqlExec(ctx context.Context) (int, error) {
_spec := &sqlgraph.DeleteSpec{
Node: &sqlgraph.NodeSpec{
Table: event.Table,
ID: &sqlgraph.FieldSpec{
Type: field.TypeInt,
Column: event.FieldID,
},
},
}
_spec := sqlgraph.NewDeleteSpec(event.Table, sqlgraph.NewFieldSpec(event.FieldID, field.TypeInt))
if ps := ed.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
@ -88,6 +52,7 @@ func (ed *EventDelete) sqlExec(ctx context.Context) (int, error) {
if err != nil && sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
ed.mutation.done = true
return affected, err
}
@ -96,6 +61,12 @@ type EventDeleteOne struct {
ed *EventDelete
}
// Where appends a list predicates to the EventDelete builder.
func (edo *EventDeleteOne) Where(ps ...predicate.Event) *EventDeleteOne {
edo.ed.mutation.Where(ps...)
return edo
}
// Exec executes the deletion query.
func (edo *EventDeleteOne) Exec(ctx context.Context) error {
n, err := edo.ed.Exec(ctx)
@ -111,5 +82,7 @@ func (edo *EventDeleteOne) Exec(ctx context.Context) error {
// ExecX is like Exec, but panics if an error occurs.
func (edo *EventDeleteOne) ExecX(ctx context.Context) {
edo.ed.ExecX(ctx)
if err := edo.Exec(ctx); err != nil {
panic(err)
}
}

View file

@ -18,11 +18,9 @@ import (
// EventQuery is the builder for querying Event entities.
type EventQuery struct {
config
limit *int
offset *int
unique *bool
order []OrderFunc
fields []string
ctx *QueryContext
order []event.OrderOption
inters []Interceptor
predicates []predicate.Event
withOwner *AlertQuery
// intermediate query (i.e. traversal path).
@ -36,34 +34,34 @@ func (eq *EventQuery) Where(ps ...predicate.Event) *EventQuery {
return eq
}
// Limit adds a limit step to the query.
// Limit the number of records to be returned by this query.
func (eq *EventQuery) Limit(limit int) *EventQuery {
eq.limit = &limit
eq.ctx.Limit = &limit
return eq
}
// Offset adds an offset step to the query.
// Offset to start from.
func (eq *EventQuery) Offset(offset int) *EventQuery {
eq.offset = &offset
eq.ctx.Offset = &offset
return eq
}
// Unique configures the query builder to filter duplicate records on query.
// By default, unique is set to true, and can be disabled using this method.
func (eq *EventQuery) Unique(unique bool) *EventQuery {
eq.unique = &unique
eq.ctx.Unique = &unique
return eq
}
// Order adds an order step to the query.
func (eq *EventQuery) Order(o ...OrderFunc) *EventQuery {
// Order specifies how the records should be ordered.
func (eq *EventQuery) Order(o ...event.OrderOption) *EventQuery {
eq.order = append(eq.order, o...)
return eq
}
// QueryOwner chains the current query on the "owner" edge.
func (eq *EventQuery) QueryOwner() *AlertQuery {
query := &AlertQuery{config: eq.config}
query := (&AlertClient{config: eq.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := eq.prepareQuery(ctx); err != nil {
return nil, err
@ -86,7 +84,7 @@ func (eq *EventQuery) QueryOwner() *AlertQuery {
// First returns the first Event entity from the query.
// Returns a *NotFoundError when no Event was found.
func (eq *EventQuery) First(ctx context.Context) (*Event, error) {
nodes, err := eq.Limit(1).All(ctx)
nodes, err := eq.Limit(1).All(setContextOp(ctx, eq.ctx, "First"))
if err != nil {
return nil, err
}
@ -109,7 +107,7 @@ func (eq *EventQuery) FirstX(ctx context.Context) *Event {
// Returns a *NotFoundError when no Event ID was found.
func (eq *EventQuery) FirstID(ctx context.Context) (id int, err error) {
var ids []int
if ids, err = eq.Limit(1).IDs(ctx); err != nil {
if ids, err = eq.Limit(1).IDs(setContextOp(ctx, eq.ctx, "FirstID")); err != nil {
return
}
if len(ids) == 0 {
@ -132,7 +130,7 @@ func (eq *EventQuery) FirstIDX(ctx context.Context) int {
// Returns a *NotSingularError when more than one Event entity is found.
// Returns a *NotFoundError when no Event entities are found.
func (eq *EventQuery) Only(ctx context.Context) (*Event, error) {
nodes, err := eq.Limit(2).All(ctx)
nodes, err := eq.Limit(2).All(setContextOp(ctx, eq.ctx, "Only"))
if err != nil {
return nil, err
}
@ -160,7 +158,7 @@ func (eq *EventQuery) OnlyX(ctx context.Context) *Event {
// Returns a *NotFoundError when no entities are found.
func (eq *EventQuery) OnlyID(ctx context.Context) (id int, err error) {
var ids []int
if ids, err = eq.Limit(2).IDs(ctx); err != nil {
if ids, err = eq.Limit(2).IDs(setContextOp(ctx, eq.ctx, "OnlyID")); err != nil {
return
}
switch len(ids) {
@ -185,10 +183,12 @@ func (eq *EventQuery) OnlyIDX(ctx context.Context) int {
// All executes the query and returns a list of Events.
func (eq *EventQuery) All(ctx context.Context) ([]*Event, error) {
ctx = setContextOp(ctx, eq.ctx, "All")
if err := eq.prepareQuery(ctx); err != nil {
return nil, err
}
return eq.sqlAll(ctx)
qr := querierAll[[]*Event, *EventQuery]()
return withInterceptors[[]*Event](ctx, eq, qr, eq.inters)
}
// AllX is like All, but panics if an error occurs.
@ -201,9 +201,12 @@ func (eq *EventQuery) AllX(ctx context.Context) []*Event {
}
// IDs executes the query and returns a list of Event IDs.
func (eq *EventQuery) IDs(ctx context.Context) ([]int, error) {
var ids []int
if err := eq.Select(event.FieldID).Scan(ctx, &ids); err != nil {
func (eq *EventQuery) IDs(ctx context.Context) (ids []int, err error) {
if eq.ctx.Unique == nil && eq.path != nil {
eq.Unique(true)
}
ctx = setContextOp(ctx, eq.ctx, "IDs")
if err = eq.Select(event.FieldID).Scan(ctx, &ids); err != nil {
return nil, err
}
return ids, nil
@ -220,10 +223,11 @@ func (eq *EventQuery) IDsX(ctx context.Context) []int {
// Count returns the count of the given query.
func (eq *EventQuery) Count(ctx context.Context) (int, error) {
ctx = setContextOp(ctx, eq.ctx, "Count")
if err := eq.prepareQuery(ctx); err != nil {
return 0, err
}
return eq.sqlCount(ctx)
return withInterceptors[int](ctx, eq, querierCount[*EventQuery](), eq.inters)
}
// CountX is like Count, but panics if an error occurs.
@ -237,10 +241,15 @@ func (eq *EventQuery) CountX(ctx context.Context) int {
// Exist returns true if the query has elements in the graph.
func (eq *EventQuery) Exist(ctx context.Context) (bool, error) {
if err := eq.prepareQuery(ctx); err != nil {
return false, err
ctx = setContextOp(ctx, eq.ctx, "Exist")
switch _, err := eq.FirstID(ctx); {
case IsNotFound(err):
return false, nil
case err != nil:
return false, fmt.Errorf("ent: check existence: %w", err)
default:
return true, nil
}
return eq.sqlExist(ctx)
}
// ExistX is like Exist, but panics if an error occurs.
@ -260,22 +269,21 @@ func (eq *EventQuery) Clone() *EventQuery {
}
return &EventQuery{
config: eq.config,
limit: eq.limit,
offset: eq.offset,
order: append([]OrderFunc{}, eq.order...),
ctx: eq.ctx.Clone(),
order: append([]event.OrderOption{}, eq.order...),
inters: append([]Interceptor{}, eq.inters...),
predicates: append([]predicate.Event{}, eq.predicates...),
withOwner: eq.withOwner.Clone(),
// clone intermediate query.
sql: eq.sql.Clone(),
path: eq.path,
unique: eq.unique,
sql: eq.sql.Clone(),
path: eq.path,
}
}
// WithOwner tells the query-builder to eager-load the nodes that are connected to
// the "owner" edge. The optional arguments are used to configure the query builder of the edge.
func (eq *EventQuery) WithOwner(opts ...func(*AlertQuery)) *EventQuery {
query := &AlertQuery{config: eq.config}
query := (&AlertClient{config: eq.config}).Query()
for _, opt := range opts {
opt(query)
}
@ -298,16 +306,11 @@ func (eq *EventQuery) WithOwner(opts ...func(*AlertQuery)) *EventQuery {
// Aggregate(ent.Count()).
// Scan(ctx, &v)
func (eq *EventQuery) GroupBy(field string, fields ...string) *EventGroupBy {
grbuild := &EventGroupBy{config: eq.config}
grbuild.fields = append([]string{field}, fields...)
grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) {
if err := eq.prepareQuery(ctx); err != nil {
return nil, err
}
return eq.sqlQuery(ctx), nil
}
eq.ctx.Fields = append([]string{field}, fields...)
grbuild := &EventGroupBy{build: eq}
grbuild.flds = &eq.ctx.Fields
grbuild.label = event.Label
grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan
grbuild.scan = grbuild.Scan
return grbuild
}
@ -324,15 +327,30 @@ func (eq *EventQuery) GroupBy(field string, fields ...string) *EventGroupBy {
// Select(event.FieldCreatedAt).
// Scan(ctx, &v)
func (eq *EventQuery) Select(fields ...string) *EventSelect {
eq.fields = append(eq.fields, fields...)
selbuild := &EventSelect{EventQuery: eq}
selbuild.label = event.Label
selbuild.flds, selbuild.scan = &eq.fields, selbuild.Scan
return selbuild
eq.ctx.Fields = append(eq.ctx.Fields, fields...)
sbuild := &EventSelect{EventQuery: eq}
sbuild.label = event.Label
sbuild.flds, sbuild.scan = &eq.ctx.Fields, sbuild.Scan
return sbuild
}
// Aggregate returns a EventSelect configured with the given aggregations.
func (eq *EventQuery) Aggregate(fns ...AggregateFunc) *EventSelect {
return eq.Select().Aggregate(fns...)
}
func (eq *EventQuery) prepareQuery(ctx context.Context) error {
for _, f := range eq.fields {
for _, inter := range eq.inters {
if inter == nil {
return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
}
if trv, ok := inter.(Traverser); ok {
if err := trv.Traverse(ctx, eq); err != nil {
return err
}
}
}
for _, f := range eq.ctx.Fields {
if !event.ValidColumn(f) {
return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
}
@ -392,6 +410,9 @@ func (eq *EventQuery) loadOwner(ctx context.Context, query *AlertQuery, nodes []
}
nodeids[fk] = append(nodeids[fk], nodes[i])
}
if len(ids) == 0 {
return nil
}
query.Where(alert.IDIn(ids...))
neighbors, err := query.All(ctx)
if err != nil {
@ -411,41 +432,22 @@ func (eq *EventQuery) loadOwner(ctx context.Context, query *AlertQuery, nodes []
func (eq *EventQuery) sqlCount(ctx context.Context) (int, error) {
_spec := eq.querySpec()
_spec.Node.Columns = eq.fields
if len(eq.fields) > 0 {
_spec.Unique = eq.unique != nil && *eq.unique
_spec.Node.Columns = eq.ctx.Fields
if len(eq.ctx.Fields) > 0 {
_spec.Unique = eq.ctx.Unique != nil && *eq.ctx.Unique
}
return sqlgraph.CountNodes(ctx, eq.driver, _spec)
}
func (eq *EventQuery) sqlExist(ctx context.Context) (bool, error) {
switch _, err := eq.FirstID(ctx); {
case IsNotFound(err):
return false, nil
case err != nil:
return false, fmt.Errorf("ent: check existence: %w", err)
default:
return true, nil
}
}
func (eq *EventQuery) querySpec() *sqlgraph.QuerySpec {
_spec := &sqlgraph.QuerySpec{
Node: &sqlgraph.NodeSpec{
Table: event.Table,
Columns: event.Columns,
ID: &sqlgraph.FieldSpec{
Type: field.TypeInt,
Column: event.FieldID,
},
},
From: eq.sql,
Unique: true,
}
if unique := eq.unique; unique != nil {
_spec := sqlgraph.NewQuerySpec(event.Table, event.Columns, sqlgraph.NewFieldSpec(event.FieldID, field.TypeInt))
_spec.From = eq.sql
if unique := eq.ctx.Unique; unique != nil {
_spec.Unique = *unique
} else if eq.path != nil {
_spec.Unique = true
}
if fields := eq.fields; len(fields) > 0 {
if fields := eq.ctx.Fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, event.FieldID)
for i := range fields {
@ -453,6 +455,9 @@ func (eq *EventQuery) querySpec() *sqlgraph.QuerySpec {
_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
}
}
if eq.withOwner != nil {
_spec.Node.AddColumnOnce(event.FieldAlertEvents)
}
}
if ps := eq.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
@ -461,10 +466,10 @@ func (eq *EventQuery) querySpec() *sqlgraph.QuerySpec {
}
}
}
if limit := eq.limit; limit != nil {
if limit := eq.ctx.Limit; limit != nil {
_spec.Limit = *limit
}
if offset := eq.offset; offset != nil {
if offset := eq.ctx.Offset; offset != nil {
_spec.Offset = *offset
}
if ps := eq.order; len(ps) > 0 {
@ -480,7 +485,7 @@ func (eq *EventQuery) querySpec() *sqlgraph.QuerySpec {
func (eq *EventQuery) sqlQuery(ctx context.Context) *sql.Selector {
builder := sql.Dialect(eq.driver.Dialect())
t1 := builder.Table(event.Table)
columns := eq.fields
columns := eq.ctx.Fields
if len(columns) == 0 {
columns = event.Columns
}
@ -489,7 +494,7 @@ func (eq *EventQuery) sqlQuery(ctx context.Context) *sql.Selector {
selector = eq.sql
selector.Select(selector.Columns(columns...)...)
}
if eq.unique != nil && *eq.unique {
if eq.ctx.Unique != nil && *eq.ctx.Unique {
selector.Distinct()
}
for _, p := range eq.predicates {
@ -498,12 +503,12 @@ func (eq *EventQuery) sqlQuery(ctx context.Context) *sql.Selector {
for _, p := range eq.order {
p(selector)
}
if offset := eq.offset; offset != nil {
if offset := eq.ctx.Offset; offset != nil {
// limit is mandatory for offset clause. We start
// with default value, and override it below if needed.
selector.Offset(*offset).Limit(math.MaxInt32)
}
if limit := eq.limit; limit != nil {
if limit := eq.ctx.Limit; limit != nil {
selector.Limit(*limit)
}
return selector
@ -511,13 +516,8 @@ func (eq *EventQuery) sqlQuery(ctx context.Context) *sql.Selector {
// EventGroupBy is the group-by builder for Event entities.
type EventGroupBy struct {
config
selector
fields []string
fns []AggregateFunc
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
build *EventQuery
}
// Aggregate adds the given aggregation functions to the group-by query.
@ -526,74 +526,77 @@ func (egb *EventGroupBy) Aggregate(fns ...AggregateFunc) *EventGroupBy {
return egb
}
// Scan applies the group-by query and scans the result into the given value.
// Scan applies the selector query and scans the result into the given value.
func (egb *EventGroupBy) Scan(ctx context.Context, v any) error {
query, err := egb.path(ctx)
if err != nil {
ctx = setContextOp(ctx, egb.build.ctx, "GroupBy")
if err := egb.build.prepareQuery(ctx); err != nil {
return err
}
egb.sql = query
return egb.sqlScan(ctx, v)
return scanWithInterceptors[*EventQuery, *EventGroupBy](ctx, egb.build, egb, egb.build.inters, v)
}
func (egb *EventGroupBy) sqlScan(ctx context.Context, v any) error {
for _, f := range egb.fields {
if !event.ValidColumn(f) {
return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)}
}
}
selector := egb.sqlQuery()
if err := selector.Err(); err != nil {
return err
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := egb.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}
func (egb *EventGroupBy) sqlQuery() *sql.Selector {
selector := egb.sql.Select()
func (egb *EventGroupBy) sqlScan(ctx context.Context, root *EventQuery, v any) error {
selector := root.sqlQuery(ctx).Select()
aggregation := make([]string, 0, len(egb.fns))
for _, fn := range egb.fns {
aggregation = append(aggregation, fn(selector))
}
// If no columns were selected in a custom aggregation function, the default
// selection is the fields used for "group-by", and the aggregation functions.
if len(selector.SelectedColumns()) == 0 {
columns := make([]string, 0, len(egb.fields)+len(egb.fns))
for _, f := range egb.fields {
columns := make([]string, 0, len(*egb.flds)+len(egb.fns))
for _, f := range *egb.flds {
columns = append(columns, selector.C(f))
}
columns = append(columns, aggregation...)
selector.Select(columns...)
}
return selector.GroupBy(selector.Columns(egb.fields...)...)
selector.GroupBy(selector.Columns(*egb.flds...)...)
if err := selector.Err(); err != nil {
return err
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := egb.build.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// EventSelect is the builder for selecting fields of Event entities.
type EventSelect struct {
*EventQuery
selector
// intermediate query (i.e. traversal path).
sql *sql.Selector
}
// Aggregate adds the given aggregation functions to the selector query.
func (es *EventSelect) Aggregate(fns ...AggregateFunc) *EventSelect {
es.fns = append(es.fns, fns...)
return es
}
// Scan applies the selector query and scans the result into the given value.
func (es *EventSelect) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, es.ctx, "Select")
if err := es.prepareQuery(ctx); err != nil {
return err
}
es.sql = es.EventQuery.sqlQuery(ctx)
return es.sqlScan(ctx, v)
return scanWithInterceptors[*EventQuery, *EventSelect](ctx, es.EventQuery, es, es.inters, v)
}
func (es *EventSelect) sqlScan(ctx context.Context, v any) error {
func (es *EventSelect) sqlScan(ctx context.Context, root *EventQuery, v any) error {
selector := root.sqlQuery(ctx)
aggregation := make([]string, 0, len(es.fns))
for _, fn := range es.fns {
aggregation = append(aggregation, fn(selector))
}
switch n := len(*es.selector.flds); {
case n == 0 && len(aggregation) > 0:
selector.Select(aggregation...)
case n != 0 && len(aggregation) > 0:
selector.AppendSelect(aggregation...)
}
rows := &sql.Rows{}
query, args := es.sql.Query()
query, args := selector.Query()
if err := es.driver.Query(ctx, query, args, rows); err != nil {
return err
}

View file

@ -117,41 +117,8 @@ func (eu *EventUpdate) ClearOwner() *EventUpdate {
// Save executes the query and returns the number of nodes affected by the update operation.
func (eu *EventUpdate) Save(ctx context.Context) (int, error) {
var (
err error
affected int
)
eu.defaults()
if len(eu.hooks) == 0 {
if err = eu.check(); err != nil {
return 0, err
}
affected, err = eu.sqlSave(ctx)
} else {
var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
mutation, ok := m.(*EventMutation)
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T", m)
}
if err = eu.check(); err != nil {
return 0, err
}
eu.mutation = mutation
affected, err = eu.sqlSave(ctx)
mutation.done = true
return affected, err
})
for i := len(eu.hooks) - 1; i >= 0; i-- {
if eu.hooks[i] == nil {
return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)")
}
mut = eu.hooks[i](mut)
}
if _, err := mut.Mutate(ctx, eu.mutation); err != nil {
return 0, err
}
}
return affected, err
return withHooks(ctx, eu.sqlSave, eu.mutation, eu.hooks)
}
// SaveX is like Save, but panics if an error occurs.
@ -199,16 +166,10 @@ func (eu *EventUpdate) check() error {
}
func (eu *EventUpdate) sqlSave(ctx context.Context) (n int, err error) {
_spec := &sqlgraph.UpdateSpec{
Node: &sqlgraph.NodeSpec{
Table: event.Table,
Columns: event.Columns,
ID: &sqlgraph.FieldSpec{
Type: field.TypeInt,
Column: event.FieldID,
},
},
if err := eu.check(); err != nil {
return n, err
}
_spec := sqlgraph.NewUpdateSpec(event.Table, event.Columns, sqlgraph.NewFieldSpec(event.FieldID, field.TypeInt))
if ps := eu.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
@ -217,44 +178,22 @@ func (eu *EventUpdate) sqlSave(ctx context.Context) (n int, err error) {
}
}
if value, ok := eu.mutation.CreatedAt(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Value: value,
Column: event.FieldCreatedAt,
})
_spec.SetField(event.FieldCreatedAt, field.TypeTime, value)
}
if eu.mutation.CreatedAtCleared() {
_spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Column: event.FieldCreatedAt,
})
_spec.ClearField(event.FieldCreatedAt, field.TypeTime)
}
if value, ok := eu.mutation.UpdatedAt(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Value: value,
Column: event.FieldUpdatedAt,
})
_spec.SetField(event.FieldUpdatedAt, field.TypeTime, value)
}
if eu.mutation.UpdatedAtCleared() {
_spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Column: event.FieldUpdatedAt,
})
_spec.ClearField(event.FieldUpdatedAt, field.TypeTime)
}
if value, ok := eu.mutation.Time(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Value: value,
Column: event.FieldTime,
})
_spec.SetField(event.FieldTime, field.TypeTime, value)
}
if value, ok := eu.mutation.Serialized(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: event.FieldSerialized,
})
_spec.SetField(event.FieldSerialized, field.TypeString, value)
}
if eu.mutation.OwnerCleared() {
edge := &sqlgraph.EdgeSpec{
@ -264,10 +203,7 @@ func (eu *EventUpdate) sqlSave(ctx context.Context) (n int, err error) {
Columns: []string{event.OwnerColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: &sqlgraph.FieldSpec{
Type: field.TypeInt,
Column: alert.FieldID,
},
IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
@ -280,10 +216,7 @@ func (eu *EventUpdate) sqlSave(ctx context.Context) (n int, err error) {
Columns: []string{event.OwnerColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: &sqlgraph.FieldSpec{
Type: field.TypeInt,
Column: alert.FieldID,
},
IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt),
},
}
for _, k := range nodes {
@ -299,6 +232,7 @@ func (eu *EventUpdate) sqlSave(ctx context.Context) (n int, err error) {
}
return 0, err
}
eu.mutation.done = true
return n, nil
}
@ -396,6 +330,12 @@ func (euo *EventUpdateOne) ClearOwner() *EventUpdateOne {
return euo
}
// Where appends a list predicates to the EventUpdate builder.
func (euo *EventUpdateOne) Where(ps ...predicate.Event) *EventUpdateOne {
euo.mutation.Where(ps...)
return euo
}
// Select allows selecting one or more fields (columns) of the returned entity.
// The default is selecting all fields defined in the entity schema.
func (euo *EventUpdateOne) Select(field string, fields ...string) *EventUpdateOne {
@ -405,47 +345,8 @@ func (euo *EventUpdateOne) Select(field string, fields ...string) *EventUpdateOn
// Save executes the query and returns the updated Event entity.
func (euo *EventUpdateOne) Save(ctx context.Context) (*Event, error) {
var (
err error
node *Event
)
euo.defaults()
if len(euo.hooks) == 0 {
if err = euo.check(); err != nil {
return nil, err
}
node, err = euo.sqlSave(ctx)
} else {
var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
mutation, ok := m.(*EventMutation)
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T", m)
}
if err = euo.check(); err != nil {
return nil, err
}
euo.mutation = mutation
node, err = euo.sqlSave(ctx)
mutation.done = true
return node, err
})
for i := len(euo.hooks) - 1; i >= 0; i-- {
if euo.hooks[i] == nil {
return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)")
}
mut = euo.hooks[i](mut)
}
v, err := mut.Mutate(ctx, euo.mutation)
if err != nil {
return nil, err
}
nv, ok := v.(*Event)
if !ok {
return nil, fmt.Errorf("unexpected node type %T returned from EventMutation", v)
}
node = nv
}
return node, err
return withHooks(ctx, euo.sqlSave, euo.mutation, euo.hooks)
}
// SaveX is like Save, but panics if an error occurs.
@ -493,16 +394,10 @@ func (euo *EventUpdateOne) check() error {
}
func (euo *EventUpdateOne) sqlSave(ctx context.Context) (_node *Event, err error) {
_spec := &sqlgraph.UpdateSpec{
Node: &sqlgraph.NodeSpec{
Table: event.Table,
Columns: event.Columns,
ID: &sqlgraph.FieldSpec{
Type: field.TypeInt,
Column: event.FieldID,
},
},
if err := euo.check(); err != nil {
return _node, err
}
_spec := sqlgraph.NewUpdateSpec(event.Table, event.Columns, sqlgraph.NewFieldSpec(event.FieldID, field.TypeInt))
id, ok := euo.mutation.ID()
if !ok {
return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Event.id" for update`)}
@ -528,44 +423,22 @@ func (euo *EventUpdateOne) sqlSave(ctx context.Context) (_node *Event, err error
}
}
if value, ok := euo.mutation.CreatedAt(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Value: value,
Column: event.FieldCreatedAt,
})
_spec.SetField(event.FieldCreatedAt, field.TypeTime, value)
}
if euo.mutation.CreatedAtCleared() {
_spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Column: event.FieldCreatedAt,
})
_spec.ClearField(event.FieldCreatedAt, field.TypeTime)
}
if value, ok := euo.mutation.UpdatedAt(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Value: value,
Column: event.FieldUpdatedAt,
})
_spec.SetField(event.FieldUpdatedAt, field.TypeTime, value)
}
if euo.mutation.UpdatedAtCleared() {
_spec.Fields.Clear = append(_spec.Fields.Clear, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Column: event.FieldUpdatedAt,
})
_spec.ClearField(event.FieldUpdatedAt, field.TypeTime)
}
if value, ok := euo.mutation.Time(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeTime,
Value: value,
Column: event.FieldTime,
})
_spec.SetField(event.FieldTime, field.TypeTime, value)
}
if value, ok := euo.mutation.Serialized(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: event.FieldSerialized,
})
_spec.SetField(event.FieldSerialized, field.TypeString, value)
}
if euo.mutation.OwnerCleared() {
edge := &sqlgraph.EdgeSpec{
@ -575,10 +448,7 @@ func (euo *EventUpdateOne) sqlSave(ctx context.Context) (_node *Event, err error
Columns: []string{event.OwnerColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: &sqlgraph.FieldSpec{
Type: field.TypeInt,
Column: alert.FieldID,
},
IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt),
},
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
@ -591,10 +461,7 @@ func (euo *EventUpdateOne) sqlSave(ctx context.Context) (_node *Event, err error
Columns: []string{event.OwnerColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: &sqlgraph.FieldSpec{
Type: field.TypeInt,
Column: alert.FieldID,
},
IDSpec: sqlgraph.NewFieldSpec(alert.FieldID, field.TypeInt),
},
}
for _, k := range nodes {
@ -613,5 +480,6 @@ func (euo *EventUpdateOne) sqlSave(ctx context.Context) (_node *Event, err error
}
return nil, err
}
euo.mutation.done = true
return _node, nil
}

View file

@ -15,11 +15,10 @@ type AlertFunc func(context.Context, *ent.AlertMutation) (ent.Value, error)
// Mutate calls f(ctx, m).
func (f AlertFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
mv, ok := m.(*ent.AlertMutation)
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AlertMutation", m)
if mv, ok := m.(*ent.AlertMutation); ok {
return f(ctx, mv)
}
return f(ctx, mv)
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AlertMutation", m)
}
// The BouncerFunc type is an adapter to allow the use of ordinary
@ -28,11 +27,10 @@ type BouncerFunc func(context.Context, *ent.BouncerMutation) (ent.Value, error)
// Mutate calls f(ctx, m).
func (f BouncerFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
mv, ok := m.(*ent.BouncerMutation)
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.BouncerMutation", m)
if mv, ok := m.(*ent.BouncerMutation); ok {
return f(ctx, mv)
}
return f(ctx, mv)
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.BouncerMutation", m)
}
// The ConfigItemFunc type is an adapter to allow the use of ordinary
@ -41,11 +39,10 @@ type ConfigItemFunc func(context.Context, *ent.ConfigItemMutation) (ent.Value, e
// Mutate calls f(ctx, m).
func (f ConfigItemFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
mv, ok := m.(*ent.ConfigItemMutation)
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.ConfigItemMutation", m)
if mv, ok := m.(*ent.ConfigItemMutation); ok {
return f(ctx, mv)
}
return f(ctx, mv)
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.ConfigItemMutation", m)
}
// The DecisionFunc type is an adapter to allow the use of ordinary
@ -54,11 +51,10 @@ type DecisionFunc func(context.Context, *ent.DecisionMutation) (ent.Value, error
// Mutate calls f(ctx, m).
func (f DecisionFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
mv, ok := m.(*ent.DecisionMutation)
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.DecisionMutation", m)
if mv, ok := m.(*ent.DecisionMutation); ok {
return f(ctx, mv)
}
return f(ctx, mv)
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.DecisionMutation", m)
}
// The EventFunc type is an adapter to allow the use of ordinary
@ -67,11 +63,22 @@ type EventFunc func(context.Context, *ent.EventMutation) (ent.Value, error)
// Mutate calls f(ctx, m).
func (f EventFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
mv, ok := m.(*ent.EventMutation)
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.EventMutation", m)
if mv, ok := m.(*ent.EventMutation); ok {
return f(ctx, mv)
}
return f(ctx, mv)
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.EventMutation", m)
}
// The LockFunc type is an adapter to allow the use of ordinary
// function as Lock mutator.
type LockFunc func(context.Context, *ent.LockMutation) (ent.Value, error)
// Mutate calls f(ctx, m).
func (f LockFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
if mv, ok := m.(*ent.LockMutation); ok {
return f(ctx, mv)
}
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.LockMutation", m)
}
// The MachineFunc type is an adapter to allow the use of ordinary
@ -80,11 +87,10 @@ type MachineFunc func(context.Context, *ent.MachineMutation) (ent.Value, error)
// Mutate calls f(ctx, m).
func (f MachineFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
mv, ok := m.(*ent.MachineMutation)
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.MachineMutation", m)
if mv, ok := m.(*ent.MachineMutation); ok {
return f(ctx, mv)
}
return f(ctx, mv)
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.MachineMutation", m)
}
// The MetaFunc type is an adapter to allow the use of ordinary
@ -93,11 +99,10 @@ type MetaFunc func(context.Context, *ent.MetaMutation) (ent.Value, error)
// Mutate calls f(ctx, m).
func (f MetaFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
mv, ok := m.(*ent.MetaMutation)
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.MetaMutation", m)
if mv, ok := m.(*ent.MetaMutation); ok {
return f(ctx, mv)
}
return f(ctx, mv)
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.MetaMutation", m)
}
// Condition is a hook condition function.

117
pkg/database/ent/lock.go Normal file
View file

@ -0,0 +1,117 @@
// Code generated by ent, DO NOT EDIT.
package ent
import (
"fmt"
"strings"
"time"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/lock"
)
// Lock is the model entity for the Lock schema.
type Lock struct {
config `json:"-"`
// ID of the ent.
ID int `json:"id,omitempty"`
// Name holds the value of the "name" field.
Name string `json:"name"`
// CreatedAt holds the value of the "created_at" field.
CreatedAt time.Time `json:"created_at"`
selectValues sql.SelectValues
}
// scanValues returns the types for scanning values from sql.Rows.
func (*Lock) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case lock.FieldID:
values[i] = new(sql.NullInt64)
case lock.FieldName:
values[i] = new(sql.NullString)
case lock.FieldCreatedAt:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
}
}
return values, nil
}
// assignValues assigns the values that were returned from sql.Rows (after scanning)
// to the Lock fields.
func (l *Lock) assignValues(columns []string, values []any) error {
if m, n := len(values), len(columns); m < n {
return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
}
for i := range columns {
switch columns[i] {
case lock.FieldID:
value, ok := values[i].(*sql.NullInt64)
if !ok {
return fmt.Errorf("unexpected type %T for field id", value)
}
l.ID = int(value.Int64)
case lock.FieldName:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field name", values[i])
} else if value.Valid {
l.Name = value.String
}
case lock.FieldCreatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field created_at", values[i])
} else if value.Valid {
l.CreatedAt = value.Time
}
default:
l.selectValues.Set(columns[i], values[i])
}
}
return nil
}
// Value returns the ent.Value that was dynamically selected and assigned to the Lock.
// This includes values selected through modifiers, order, etc.
func (l *Lock) Value(name string) (ent.Value, error) {
return l.selectValues.Get(name)
}
// Update returns a builder for updating this Lock.
// Note that you need to call Lock.Unwrap() before calling this method if this Lock
// was returned from a transaction, and the transaction was committed or rolled back.
func (l *Lock) Update() *LockUpdateOne {
return NewLockClient(l.config).UpdateOne(l)
}
// Unwrap unwraps the Lock entity that was returned from a transaction after it was closed,
// so that all future queries will be executed through the driver which created the transaction.
func (l *Lock) Unwrap() *Lock {
_tx, ok := l.config.driver.(*txDriver)
if !ok {
panic("ent: Lock is not a transactional entity")
}
l.config.driver = _tx.drv
return l
}
// String implements the fmt.Stringer.
func (l *Lock) String() string {
var builder strings.Builder
builder.WriteString("Lock(")
builder.WriteString(fmt.Sprintf("id=%v, ", l.ID))
builder.WriteString("name=")
builder.WriteString(l.Name)
builder.WriteString(", ")
builder.WriteString("created_at=")
builder.WriteString(l.CreatedAt.Format(time.ANSIC))
builder.WriteByte(')')
return builder.String()
}
// Locks is a parsable slice of Lock.
type Locks []*Lock

Some files were not shown because too many files have changed in this diff Show more