Compare commits
4 commits
master
...
fix-decisi
Author | SHA1 | Date | |
---|---|---|---|
|
319957e134 | ||
|
42d5dd8e03 | ||
|
ecece8f4a5 | ||
|
aabd258f2b |
9 changed files with 291 additions and 36 deletions
|
@ -47,6 +47,7 @@ func LoadTestConfig() csconfig.Config {
|
|||
Type: "sqlite",
|
||||
DbPath: filepath.Join(tempDir, "ent"),
|
||||
Flush: &flushConfig,
|
||||
//LogLevel: &log.AllLevels[log.DebugLevel],
|
||||
}
|
||||
apiServerConfig := csconfig.LocalApiServerCfg{
|
||||
ListenURI: "http://127.0.0.1:8080",
|
||||
|
|
|
@ -281,7 +281,6 @@ func TestStreamStartDecisionDedup(t *testing.T) {
|
|||
assert.Equal(t, int64(2), decisions["new"][0].ID)
|
||||
assert.Equal(t, "test", *decisions["new"][0].Origin)
|
||||
assert.Equal(t, "127.0.0.1", *decisions["new"][0].Value)
|
||||
|
||||
// We delete another decision, yet don't receive it in stream, since there's another decision on same IP
|
||||
w = lapi.RecordResponse("DELETE", "/v1/decisions/2", emptyBody, PASSWORD)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
|
@ -1012,6 +1011,104 @@ func TestStreamDecision(t *testing.T) {
|
|||
NewChecks: []DecisionCheck{},
|
||||
},
|
||||
},
|
||||
"test startup with scenarios containing": {
|
||||
{
|
||||
TestName: "get stream",
|
||||
Method: "GET",
|
||||
Route: "/v1/decisions/stream?startup=true&scenarios_containing=ssh_bf",
|
||||
CheckCodeOnly: false,
|
||||
Code: 200,
|
||||
LenNew: 2,
|
||||
LenDeleted: 0,
|
||||
AuthType: APIKEY,
|
||||
DelChecks: []DecisionCheck{},
|
||||
NewChecks: []DecisionCheck{
|
||||
{
|
||||
ID: int64(2),
|
||||
Origin: "another_origin",
|
||||
Scenario: "crowdsecurity/ssh_bf",
|
||||
Value: "127.0.0.1",
|
||||
Duration: "2h59",
|
||||
Type: "ban",
|
||||
},
|
||||
{
|
||||
ID: int64(5),
|
||||
Origin: "test",
|
||||
Scenario: "crowdsecurity/ssh_bf",
|
||||
Value: "127.0.0.2",
|
||||
Duration: "2h59",
|
||||
Type: "ban",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
TestName: "delete decisions 3 (127.0.0.1)",
|
||||
Method: "DELETE",
|
||||
Route: "/v1/decisions/3",
|
||||
CheckCodeOnly: true,
|
||||
Code: 200,
|
||||
LenNew: 0,
|
||||
LenDeleted: 0,
|
||||
AuthType: PASSWORD,
|
||||
DelChecks: []DecisionCheck{},
|
||||
NewChecks: []DecisionCheck{},
|
||||
},
|
||||
{
|
||||
TestName: "check that 127.0.0.1 is not in deleted IP",
|
||||
Method: "GET",
|
||||
Route: "/v1/decisions/stream?startup=true&scenarios_containing=ssh_bf",
|
||||
CheckCodeOnly: false,
|
||||
Code: 200,
|
||||
LenNew: 2,
|
||||
LenDeleted: 0,
|
||||
AuthType: APIKEY,
|
||||
DelChecks: []DecisionCheck{},
|
||||
NewChecks: []DecisionCheck{},
|
||||
},
|
||||
{
|
||||
TestName: "delete decisions 2 (127.0.0.1)",
|
||||
Method: "DELETE",
|
||||
Route: "/v1/decisions/2",
|
||||
CheckCodeOnly: true,
|
||||
Code: 200,
|
||||
LenNew: 0,
|
||||
LenDeleted: 0,
|
||||
AuthType: PASSWORD,
|
||||
DelChecks: []DecisionCheck{},
|
||||
NewChecks: []DecisionCheck{},
|
||||
},
|
||||
{
|
||||
TestName: "check that 127.0.0.1 is deleted (decision for ssh_bf was with ID 2)",
|
||||
Method: "GET",
|
||||
Route: "/v1/decisions/stream?startup=true&scenarios_containing=ssh_bf",
|
||||
CheckCodeOnly: false,
|
||||
Code: 200,
|
||||
LenNew: 1,
|
||||
LenDeleted: 1,
|
||||
AuthType: APIKEY,
|
||||
DelChecks: []DecisionCheck{
|
||||
{
|
||||
ID: int64(2),
|
||||
Origin: "another_origin",
|
||||
Scenario: "crowdsecurity/ssh_bf",
|
||||
Value: "127.0.0.1",
|
||||
Duration: "-",
|
||||
|
||||
Type: "ban",
|
||||
},
|
||||
},
|
||||
NewChecks: []DecisionCheck{
|
||||
{
|
||||
ID: int64(5),
|
||||
Origin: "test",
|
||||
Scenario: "crowdsecurity/ssh_bf",
|
||||
Value: "127.0.0.2",
|
||||
Duration: "2h59",
|
||||
Type: "ban",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"test with scenarios containing": {
|
||||
{
|
||||
TestName: "get stream",
|
||||
|
|
|
@ -43,7 +43,7 @@ func BuildDecisionRequestWithFilter(query *ent.DecisionQuery, filter map[string]
|
|||
} else {
|
||||
query = query.Where(decision.SimulatedEQ(false))
|
||||
}
|
||||
t := sql.Table(decision.Table)
|
||||
t := sql.Table(decision.Table).As("t1")
|
||||
joinPredicate := make([]*sql.Predicate, 0)
|
||||
for param, value := range filter {
|
||||
switch param {
|
||||
|
@ -199,33 +199,55 @@ func (c *Client) QueryDecisionCountByScenario(filters map[string][]string) ([]*D
|
|||
func (c *Client) QueryExpiredDecisionsWithFilters(filters map[string][]string) ([]*ent.Decision, error) {
|
||||
now := time.Now().UTC()
|
||||
query := c.Ent.Decision.Query().Where(
|
||||
decision.UntilLT(time.Now().UTC()),
|
||||
decision.UntilLTE(now),
|
||||
)
|
||||
query, predicates, err := BuildDecisionRequestWithFilter(query, filters)
|
||||
if err != nil {
|
||||
c.Log.Warningf("QueryExpiredDecisionsWithFilters : %s", err)
|
||||
return []*ent.Decision{}, errors.Wrap(QueryFail, "get expired decisions with filters")
|
||||
}
|
||||
query = query.Where(func(s *sql.Selector) {
|
||||
|
||||
/*query = query.Where(func(s *sql.Selector) {
|
||||
t := sql.Table(decision.Table).As("t1")
|
||||
|
||||
subQuery := sql.Select(t.C(decision.FieldValue)).From(t).Where(sql.GT(t.C(decision.FieldUntil), now))
|
||||
for _, predicate := range predicates {
|
||||
subQuery.Where(predicate)
|
||||
subquery := sql.Select(s.C(decision.FieldValue)).From(t)
|
||||
for _, pred := range predicates {
|
||||
subquery.Where(pred)
|
||||
}
|
||||
subQuery.Where(sql.And(
|
||||
sql.ColumnsEQ(t.C(decision.FieldType), s.C(decision.FieldType)),
|
||||
sql.ColumnsEQ(t.C(decision.FieldScope), s.C(decision.FieldScope)),
|
||||
))
|
||||
s.Where(
|
||||
sql.NotIn(
|
||||
s.C(decision.FieldValue),
|
||||
subQuery,
|
||||
|
||||
subquery = subquery.Where(
|
||||
sql.And(
|
||||
sql.ColumnsEQ(s.C(decision.FieldScope), t.C(decision.FieldScope)),
|
||||
sql.ColumnsEQ(s.C(decision.FieldType), t.C(decision.FieldType)),
|
||||
sql.ColumnsEQ(s.C(decision.FieldValue), t.C(decision.FieldValue)),
|
||||
sql.GT(t.C(decision.FieldUntil), now),
|
||||
),
|
||||
)
|
||||
s.Where(sql.NotExists(subquery))
|
||||
})
|
||||
|
||||
data, err := query.Order(ent.Asc(decision.FieldValue), ent.Desc(decision.FieldUntil)).All(c.CTX)
|
||||
data, err := query.Order(ent.Asc(decision.FieldValue), ent.Desc(decision.FieldUntil)).All(c.CTX)*/
|
||||
|
||||
query.Modify(func(s *sql.Selector) {
|
||||
t := sql.Table(decision.Table).As("t1")
|
||||
p := []*sql.Predicate{
|
||||
sql.ColumnsEQ(s.C(decision.FieldScope), t.C(decision.FieldScope)),
|
||||
sql.ColumnsEQ(s.C(decision.FieldType), t.C(decision.FieldType)),
|
||||
sql.ColumnsEQ(s.C(decision.FieldValue), t.C(decision.FieldValue)),
|
||||
sql.GTE(t.C(decision.FieldUntil), now),
|
||||
}
|
||||
p = append(p, predicates...)
|
||||
s.LeftJoin(t).
|
||||
OnP(
|
||||
sql.And(
|
||||
p...,
|
||||
)).
|
||||
GroupBy(s.C(decision.FieldValue)).
|
||||
Where(sql.IsNull(t.C(decision.FieldValue)))
|
||||
})
|
||||
|
||||
data, err := query.All(c.CTX)
|
||||
|
||||
if err != nil {
|
||||
c.Log.Warningf("QueryExpiredDecisionsWithFilters : %s", err)
|
||||
return []*ent.Decision{}, errors.Wrap(QueryFail, "expired decisions")
|
||||
|
@ -242,37 +264,40 @@ func (c *Client) QueryExpiredDecisionsSinceWithFilters(since time.Time, filters
|
|||
now := time.Now().UTC()
|
||||
|
||||
query := c.Ent.Decision.Query().Where(
|
||||
decision.UntilGT(since),
|
||||
decision.UntilGTE(since),
|
||||
decision.UntilLTE(now),
|
||||
)
|
||||
query, _, err := BuildDecisionRequestWithFilter(query, filters)
|
||||
query, predicates, err := BuildDecisionRequestWithFilter(query, filters)
|
||||
if err != nil {
|
||||
c.Log.Warningf("QueryExpiredDecisionsSinceWithFilters : %s", err)
|
||||
return []*ent.Decision{}, errors.Wrap(QueryFail, "expired decisions with filters")
|
||||
}
|
||||
|
||||
data, err := query.Order(ent.Asc(decision.FieldValue), ent.Asc(decision.FieldUntil)).All(c.CTX)
|
||||
query.Modify(func(s *sql.Selector) {
|
||||
t := sql.Table(decision.Table).As("t1")
|
||||
p := []*sql.Predicate{
|
||||
sql.ColumnsEQ(s.C(decision.FieldScope), t.C(decision.FieldScope)),
|
||||
sql.ColumnsEQ(s.C(decision.FieldType), t.C(decision.FieldType)),
|
||||
sql.ColumnsEQ(s.C(decision.FieldValue), t.C(decision.FieldValue)),
|
||||
sql.GTE(t.C(decision.FieldUntil), now),
|
||||
}
|
||||
p = append(p, predicates...)
|
||||
s.LeftJoin(t).
|
||||
OnP(
|
||||
sql.And(
|
||||
p...,
|
||||
)).
|
||||
GroupBy(s.C(decision.FieldValue)).
|
||||
Where(sql.IsNull(t.C(decision.FieldValue)))
|
||||
})
|
||||
|
||||
data, err := query.All(c.CTX)
|
||||
if err != nil {
|
||||
c.Log.Warningf("QueryExpiredDecisionsSinceWithFilters : %s", err)
|
||||
return []*ent.Decision{}, errors.Wrap(QueryFail, "expired decisions with filters")
|
||||
}
|
||||
|
||||
ret := make([]*ent.Decision, 0)
|
||||
deletedDecisions := make(map[string]*ent.Decision)
|
||||
for _, d := range data {
|
||||
key := fmt.Sprintf("%s:%s:%s", d.Scope, d.Type, d.Value)
|
||||
if d.Until.Before(now) {
|
||||
deletedDecisions[key] = d
|
||||
}
|
||||
if d.Until.After(now) {
|
||||
delete(deletedDecisions, key)
|
||||
}
|
||||
}
|
||||
|
||||
for _, d := range deletedDecisions {
|
||||
ret = append(ret, d)
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (c *Client) QueryNewDecisionsSinceWithFilters(since time.Time, filters map[string][]string) ([]*ent.Decision, error) {
|
||||
|
|
|
@ -35,6 +35,7 @@ type AlertQuery struct {
|
|||
withEvents *EventQuery
|
||||
withMetas *MetaQuery
|
||||
withFKs bool
|
||||
modifiers []func(s *sql.Selector)
|
||||
// intermediate query (i.e. traversal path).
|
||||
sql *sql.Selector
|
||||
path func(context.Context) (*sql.Selector, error)
|
||||
|
@ -487,6 +488,9 @@ func (aq *AlertQuery) sqlAll(ctx context.Context) ([]*Alert, error) {
|
|||
node.Edges.loadedTypes = loadedTypes
|
||||
return node.assignValues(columns, values)
|
||||
}
|
||||
if len(aq.modifiers) > 0 {
|
||||
_spec.Modifiers = aq.modifiers
|
||||
}
|
||||
if err := sqlgraph.QueryNodes(ctx, aq.driver, _spec); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -615,6 +619,9 @@ func (aq *AlertQuery) sqlAll(ctx context.Context) ([]*Alert, error) {
|
|||
|
||||
func (aq *AlertQuery) sqlCount(ctx context.Context) (int, error) {
|
||||
_spec := aq.querySpec()
|
||||
if len(aq.modifiers) > 0 {
|
||||
_spec.Modifiers = aq.modifiers
|
||||
}
|
||||
_spec.Node.Columns = aq.fields
|
||||
if len(aq.fields) > 0 {
|
||||
_spec.Unique = aq.unique != nil && *aq.unique
|
||||
|
@ -693,6 +700,9 @@ func (aq *AlertQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
|||
if aq.unique != nil && *aq.unique {
|
||||
selector.Distinct()
|
||||
}
|
||||
for _, m := range aq.modifiers {
|
||||
m(selector)
|
||||
}
|
||||
for _, p := range aq.predicates {
|
||||
p(selector)
|
||||
}
|
||||
|
@ -710,6 +720,12 @@ func (aq *AlertQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
|||
return selector
|
||||
}
|
||||
|
||||
// Modify adds a query modifier for attaching custom logic to queries.
|
||||
func (aq *AlertQuery) Modify(modifiers ...func(s *sql.Selector)) *AlertSelect {
|
||||
aq.modifiers = append(aq.modifiers, modifiers...)
|
||||
return aq.Select()
|
||||
}
|
||||
|
||||
// AlertGroupBy is the group-by builder for Alert entities.
|
||||
type AlertGroupBy struct {
|
||||
config
|
||||
|
@ -1197,3 +1213,9 @@ func (as *AlertSelect) sqlScan(ctx context.Context, v interface{}) error {
|
|||
defer rows.Close()
|
||||
return sql.ScanSlice(rows, v)
|
||||
}
|
||||
|
||||
// Modify adds a query modifier for attaching custom logic to queries.
|
||||
func (as *AlertSelect) Modify(modifiers ...func(s *sql.Selector)) *AlertSelect {
|
||||
as.modifiers = append(as.modifiers, modifiers...)
|
||||
return as
|
||||
}
|
||||
|
|
|
@ -24,6 +24,7 @@ type BouncerQuery struct {
|
|||
order []OrderFunc
|
||||
fields []string
|
||||
predicates []predicate.Bouncer
|
||||
modifiers []func(s *sql.Selector)
|
||||
// intermediate query (i.e. traversal path).
|
||||
sql *sql.Selector
|
||||
path func(context.Context) (*sql.Selector, error)
|
||||
|
@ -326,6 +327,9 @@ func (bq *BouncerQuery) sqlAll(ctx context.Context) ([]*Bouncer, error) {
|
|||
node := nodes[len(nodes)-1]
|
||||
return node.assignValues(columns, values)
|
||||
}
|
||||
if len(bq.modifiers) > 0 {
|
||||
_spec.Modifiers = bq.modifiers
|
||||
}
|
||||
if err := sqlgraph.QueryNodes(ctx, bq.driver, _spec); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -337,6 +341,9 @@ func (bq *BouncerQuery) sqlAll(ctx context.Context) ([]*Bouncer, error) {
|
|||
|
||||
func (bq *BouncerQuery) sqlCount(ctx context.Context) (int, error) {
|
||||
_spec := bq.querySpec()
|
||||
if len(bq.modifiers) > 0 {
|
||||
_spec.Modifiers = bq.modifiers
|
||||
}
|
||||
_spec.Node.Columns = bq.fields
|
||||
if len(bq.fields) > 0 {
|
||||
_spec.Unique = bq.unique != nil && *bq.unique
|
||||
|
@ -415,6 +422,9 @@ func (bq *BouncerQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
|||
if bq.unique != nil && *bq.unique {
|
||||
selector.Distinct()
|
||||
}
|
||||
for _, m := range bq.modifiers {
|
||||
m(selector)
|
||||
}
|
||||
for _, p := range bq.predicates {
|
||||
p(selector)
|
||||
}
|
||||
|
@ -432,6 +442,12 @@ func (bq *BouncerQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
|||
return selector
|
||||
}
|
||||
|
||||
// Modify adds a query modifier for attaching custom logic to queries.
|
||||
func (bq *BouncerQuery) Modify(modifiers ...func(s *sql.Selector)) *BouncerSelect {
|
||||
bq.modifiers = append(bq.modifiers, modifiers...)
|
||||
return bq.Select()
|
||||
}
|
||||
|
||||
// BouncerGroupBy is the group-by builder for Bouncer entities.
|
||||
type BouncerGroupBy struct {
|
||||
config
|
||||
|
@ -919,3 +935,9 @@ func (bs *BouncerSelect) sqlScan(ctx context.Context, v interface{}) error {
|
|||
defer rows.Close()
|
||||
return sql.ScanSlice(rows, v)
|
||||
}
|
||||
|
||||
// Modify adds a query modifier for attaching custom logic to queries.
|
||||
func (bs *BouncerSelect) Modify(modifiers ...func(s *sql.Selector)) *BouncerSelect {
|
||||
bs.modifiers = append(bs.modifiers, modifiers...)
|
||||
return bs
|
||||
}
|
||||
|
|
|
@ -28,6 +28,7 @@ type DecisionQuery struct {
|
|||
// eager-loading edges.
|
||||
withOwner *AlertQuery
|
||||
withFKs bool
|
||||
modifiers []func(s *sql.Selector)
|
||||
// intermediate query (i.e. traversal path).
|
||||
sql *sql.Selector
|
||||
path func(context.Context) (*sql.Selector, error)
|
||||
|
@ -375,6 +376,9 @@ func (dq *DecisionQuery) sqlAll(ctx context.Context) ([]*Decision, error) {
|
|||
node.Edges.loadedTypes = loadedTypes
|
||||
return node.assignValues(columns, values)
|
||||
}
|
||||
if len(dq.modifiers) > 0 {
|
||||
_spec.Modifiers = dq.modifiers
|
||||
}
|
||||
if err := sqlgraph.QueryNodes(ctx, dq.driver, _spec); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -416,6 +420,9 @@ func (dq *DecisionQuery) sqlAll(ctx context.Context) ([]*Decision, error) {
|
|||
|
||||
func (dq *DecisionQuery) sqlCount(ctx context.Context) (int, error) {
|
||||
_spec := dq.querySpec()
|
||||
if len(dq.modifiers) > 0 {
|
||||
_spec.Modifiers = dq.modifiers
|
||||
}
|
||||
_spec.Node.Columns = dq.fields
|
||||
if len(dq.fields) > 0 {
|
||||
_spec.Unique = dq.unique != nil && *dq.unique
|
||||
|
@ -494,6 +501,9 @@ func (dq *DecisionQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
|||
if dq.unique != nil && *dq.unique {
|
||||
selector.Distinct()
|
||||
}
|
||||
for _, m := range dq.modifiers {
|
||||
m(selector)
|
||||
}
|
||||
for _, p := range dq.predicates {
|
||||
p(selector)
|
||||
}
|
||||
|
@ -511,6 +521,12 @@ func (dq *DecisionQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
|||
return selector
|
||||
}
|
||||
|
||||
// Modify adds a query modifier for attaching custom logic to queries.
|
||||
func (dq *DecisionQuery) Modify(modifiers ...func(s *sql.Selector)) *DecisionSelect {
|
||||
dq.modifiers = append(dq.modifiers, modifiers...)
|
||||
return dq.Select()
|
||||
}
|
||||
|
||||
// DecisionGroupBy is the group-by builder for Decision entities.
|
||||
type DecisionGroupBy struct {
|
||||
config
|
||||
|
@ -998,3 +1014,9 @@ func (ds *DecisionSelect) sqlScan(ctx context.Context, v interface{}) error {
|
|||
defer rows.Close()
|
||||
return sql.ScanSlice(rows, v)
|
||||
}
|
||||
|
||||
// Modify adds a query modifier for attaching custom logic to queries.
|
||||
func (ds *DecisionSelect) Modify(modifiers ...func(s *sql.Selector)) *DecisionSelect {
|
||||
ds.modifiers = append(ds.modifiers, modifiers...)
|
||||
return ds
|
||||
}
|
||||
|
|
|
@ -28,6 +28,7 @@ type EventQuery struct {
|
|||
// eager-loading edges.
|
||||
withOwner *AlertQuery
|
||||
withFKs bool
|
||||
modifiers []func(s *sql.Selector)
|
||||
// intermediate query (i.e. traversal path).
|
||||
sql *sql.Selector
|
||||
path func(context.Context) (*sql.Selector, error)
|
||||
|
@ -375,6 +376,9 @@ func (eq *EventQuery) sqlAll(ctx context.Context) ([]*Event, error) {
|
|||
node.Edges.loadedTypes = loadedTypes
|
||||
return node.assignValues(columns, values)
|
||||
}
|
||||
if len(eq.modifiers) > 0 {
|
||||
_spec.Modifiers = eq.modifiers
|
||||
}
|
||||
if err := sqlgraph.QueryNodes(ctx, eq.driver, _spec); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -416,6 +420,9 @@ func (eq *EventQuery) sqlAll(ctx context.Context) ([]*Event, error) {
|
|||
|
||||
func (eq *EventQuery) sqlCount(ctx context.Context) (int, error) {
|
||||
_spec := eq.querySpec()
|
||||
if len(eq.modifiers) > 0 {
|
||||
_spec.Modifiers = eq.modifiers
|
||||
}
|
||||
_spec.Node.Columns = eq.fields
|
||||
if len(eq.fields) > 0 {
|
||||
_spec.Unique = eq.unique != nil && *eq.unique
|
||||
|
@ -494,6 +501,9 @@ func (eq *EventQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
|||
if eq.unique != nil && *eq.unique {
|
||||
selector.Distinct()
|
||||
}
|
||||
for _, m := range eq.modifiers {
|
||||
m(selector)
|
||||
}
|
||||
for _, p := range eq.predicates {
|
||||
p(selector)
|
||||
}
|
||||
|
@ -511,6 +521,12 @@ func (eq *EventQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
|||
return selector
|
||||
}
|
||||
|
||||
// Modify adds a query modifier for attaching custom logic to queries.
|
||||
func (eq *EventQuery) Modify(modifiers ...func(s *sql.Selector)) *EventSelect {
|
||||
eq.modifiers = append(eq.modifiers, modifiers...)
|
||||
return eq.Select()
|
||||
}
|
||||
|
||||
// EventGroupBy is the group-by builder for Event entities.
|
||||
type EventGroupBy struct {
|
||||
config
|
||||
|
@ -998,3 +1014,9 @@ func (es *EventSelect) sqlScan(ctx context.Context, v interface{}) error {
|
|||
defer rows.Close()
|
||||
return sql.ScanSlice(rows, v)
|
||||
}
|
||||
|
||||
// Modify adds a query modifier for attaching custom logic to queries.
|
||||
func (es *EventSelect) Modify(modifiers ...func(s *sql.Selector)) *EventSelect {
|
||||
es.modifiers = append(es.modifiers, modifiers...)
|
||||
return es
|
||||
}
|
||||
|
|
|
@ -28,6 +28,7 @@ type MachineQuery struct {
|
|||
predicates []predicate.Machine
|
||||
// eager-loading edges.
|
||||
withAlerts *AlertQuery
|
||||
modifiers []func(s *sql.Selector)
|
||||
// intermediate query (i.e. traversal path).
|
||||
sql *sql.Selector
|
||||
path func(context.Context) (*sql.Selector, error)
|
||||
|
@ -368,6 +369,9 @@ func (mq *MachineQuery) sqlAll(ctx context.Context) ([]*Machine, error) {
|
|||
node.Edges.loadedTypes = loadedTypes
|
||||
return node.assignValues(columns, values)
|
||||
}
|
||||
if len(mq.modifiers) > 0 {
|
||||
_spec.Modifiers = mq.modifiers
|
||||
}
|
||||
if err := sqlgraph.QueryNodes(ctx, mq.driver, _spec); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -409,6 +413,9 @@ func (mq *MachineQuery) sqlAll(ctx context.Context) ([]*Machine, error) {
|
|||
|
||||
func (mq *MachineQuery) sqlCount(ctx context.Context) (int, error) {
|
||||
_spec := mq.querySpec()
|
||||
if len(mq.modifiers) > 0 {
|
||||
_spec.Modifiers = mq.modifiers
|
||||
}
|
||||
_spec.Node.Columns = mq.fields
|
||||
if len(mq.fields) > 0 {
|
||||
_spec.Unique = mq.unique != nil && *mq.unique
|
||||
|
@ -487,6 +494,9 @@ func (mq *MachineQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
|||
if mq.unique != nil && *mq.unique {
|
||||
selector.Distinct()
|
||||
}
|
||||
for _, m := range mq.modifiers {
|
||||
m(selector)
|
||||
}
|
||||
for _, p := range mq.predicates {
|
||||
p(selector)
|
||||
}
|
||||
|
@ -504,6 +514,12 @@ func (mq *MachineQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
|||
return selector
|
||||
}
|
||||
|
||||
// Modify adds a query modifier for attaching custom logic to queries.
|
||||
func (mq *MachineQuery) Modify(modifiers ...func(s *sql.Selector)) *MachineSelect {
|
||||
mq.modifiers = append(mq.modifiers, modifiers...)
|
||||
return mq.Select()
|
||||
}
|
||||
|
||||
// MachineGroupBy is the group-by builder for Machine entities.
|
||||
type MachineGroupBy struct {
|
||||
config
|
||||
|
@ -991,3 +1007,9 @@ func (ms *MachineSelect) sqlScan(ctx context.Context, v interface{}) error {
|
|||
defer rows.Close()
|
||||
return sql.ScanSlice(rows, v)
|
||||
}
|
||||
|
||||
// Modify adds a query modifier for attaching custom logic to queries.
|
||||
func (ms *MachineSelect) Modify(modifiers ...func(s *sql.Selector)) *MachineSelect {
|
||||
ms.modifiers = append(ms.modifiers, modifiers...)
|
||||
return ms
|
||||
}
|
||||
|
|
|
@ -28,6 +28,7 @@ type MetaQuery struct {
|
|||
// eager-loading edges.
|
||||
withOwner *AlertQuery
|
||||
withFKs bool
|
||||
modifiers []func(s *sql.Selector)
|
||||
// intermediate query (i.e. traversal path).
|
||||
sql *sql.Selector
|
||||
path func(context.Context) (*sql.Selector, error)
|
||||
|
@ -375,6 +376,9 @@ func (mq *MetaQuery) sqlAll(ctx context.Context) ([]*Meta, error) {
|
|||
node.Edges.loadedTypes = loadedTypes
|
||||
return node.assignValues(columns, values)
|
||||
}
|
||||
if len(mq.modifiers) > 0 {
|
||||
_spec.Modifiers = mq.modifiers
|
||||
}
|
||||
if err := sqlgraph.QueryNodes(ctx, mq.driver, _spec); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -416,6 +420,9 @@ func (mq *MetaQuery) sqlAll(ctx context.Context) ([]*Meta, error) {
|
|||
|
||||
func (mq *MetaQuery) sqlCount(ctx context.Context) (int, error) {
|
||||
_spec := mq.querySpec()
|
||||
if len(mq.modifiers) > 0 {
|
||||
_spec.Modifiers = mq.modifiers
|
||||
}
|
||||
_spec.Node.Columns = mq.fields
|
||||
if len(mq.fields) > 0 {
|
||||
_spec.Unique = mq.unique != nil && *mq.unique
|
||||
|
@ -494,6 +501,9 @@ func (mq *MetaQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
|||
if mq.unique != nil && *mq.unique {
|
||||
selector.Distinct()
|
||||
}
|
||||
for _, m := range mq.modifiers {
|
||||
m(selector)
|
||||
}
|
||||
for _, p := range mq.predicates {
|
||||
p(selector)
|
||||
}
|
||||
|
@ -511,6 +521,12 @@ func (mq *MetaQuery) sqlQuery(ctx context.Context) *sql.Selector {
|
|||
return selector
|
||||
}
|
||||
|
||||
// Modify adds a query modifier for attaching custom logic to queries.
|
||||
func (mq *MetaQuery) Modify(modifiers ...func(s *sql.Selector)) *MetaSelect {
|
||||
mq.modifiers = append(mq.modifiers, modifiers...)
|
||||
return mq.Select()
|
||||
}
|
||||
|
||||
// MetaGroupBy is the group-by builder for Meta entities.
|
||||
type MetaGroupBy struct {
|
||||
config
|
||||
|
@ -998,3 +1014,9 @@ func (ms *MetaSelect) sqlScan(ctx context.Context, v interface{}) error {
|
|||
defer rows.Close()
|
||||
return sql.ScanSlice(rows, v)
|
||||
}
|
||||
|
||||
// Modify adds a query modifier for attaching custom logic to queries.
|
||||
func (ms *MetaSelect) Modify(modifiers ...func(s *sql.Selector)) *MetaSelect {
|
||||
ms.modifiers = append(ms.modifiers, modifiers...)
|
||||
return ms
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue