api/types/filters: fix errors not being matched by errors.Is()

I found that the errors returned weren't matched with `errors.Is()` when
wrapped.

Signed-off-by: Sebastiaan van Stijn <github@gone.nl>
This commit is contained in:
Sebastiaan van Stijn 2023-07-28 16:27:36 +02:00
parent b2bde4a7d8
commit 490fee7d45
No known key found for this signature in database
GPG key ID: 76698F39D527CE8C
2 changed files with 36 additions and 24 deletions

View file

@ -98,7 +98,7 @@ func FromJSON(p string) (Args, error) {
// Fallback to parsing arguments in the legacy slice format // Fallback to parsing arguments in the legacy slice format
deprecated := map[string][]string{} deprecated := map[string][]string{}
if legacyErr := json.Unmarshal(raw, &deprecated); legacyErr != nil { if legacyErr := json.Unmarshal(raw, &deprecated); legacyErr != nil {
return args, invalidFilter{} return args, &invalidFilter{}
} }
args.fields = deprecatedArgs(deprecated) args.fields = deprecatedArgs(deprecated)
@ -206,7 +206,7 @@ func (args Args) GetBoolOrDefault(key string, defaultValue bool) (bool, error) {
} }
if len(fieldValues) == 0 { if len(fieldValues) == 0 {
return defaultValue, invalidFilter{key, nil} return defaultValue, &invalidFilter{key, nil}
} }
isFalse := fieldValues["0"] || fieldValues["false"] isFalse := fieldValues["0"] || fieldValues["false"]
@ -216,7 +216,7 @@ func (args Args) GetBoolOrDefault(key string, defaultValue bool) (bool, error) {
invalid := !isFalse && !isTrue invalid := !isFalse && !isTrue
if conflicting || invalid { if conflicting || invalid {
return defaultValue, invalidFilter{key, args.Get(key)} return defaultValue, &invalidFilter{key, args.Get(key)}
} else if isFalse { } else if isFalse {
return false, nil return false, nil
} else if isTrue { } else if isTrue {
@ -224,7 +224,7 @@ func (args Args) GetBoolOrDefault(key string, defaultValue bool) (bool, error) {
} }
// This code shouldn't be reached. // This code shouldn't be reached.
return defaultValue, unreachableCode{Filter: key, Value: args.Get(key)} return defaultValue, &unreachableCode{Filter: key, Value: args.Get(key)}
} }
// ExactMatch returns true if the source matches exactly one of the values. // ExactMatch returns true if the source matches exactly one of the values.
@ -282,7 +282,7 @@ func (args Args) Contains(field string) bool {
func (args Args) Validate(accepted map[string]bool) error { func (args Args) Validate(accepted map[string]bool) error {
for name := range args.fields { for name := range args.fields {
if !accepted[name] { if !accepted[name] {
return invalidFilter{name, nil} return &invalidFilter{name, nil}
} }
} }
return nil return nil

View file

@ -3,6 +3,7 @@ package filters // import "github.com/docker/docker/api/types/filters"
import ( import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"sort" "sort"
"testing" "testing"
@ -95,15 +96,19 @@ func TestFromJSON(t *testing.T) {
if err == nil { if err == nil {
t.Fatalf("Expected an error with %v, got nothing", invalid) t.Fatalf("Expected an error with %v, got nothing", invalid)
} }
var invalidFilterError invalidFilter var invalidFilterError *invalidFilter
if !errors.As(err, &invalidFilterError) { if !errors.As(err, &invalidFilterError) {
t.Fatalf("Expected an invalidFilter error, got %T", err) t.Fatalf("Expected an invalidFilter error, got %T", err)
} }
wrappedErr := fmt.Errorf("something went wrong: %w", err)
if !errors.Is(wrappedErr, err) {
t.Errorf("Expected a wrapped error to be detected as invalidFilter")
}
} }
for expectedArgs, matchers := range valid { for expectedArgs, matchers := range valid {
for _, json := range matchers { for _, jsonString := range matchers {
args, err := FromJSON(json) args, err := FromJSON(jsonString)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -388,9 +393,13 @@ func TestValidate(t *testing.T) {
if err == nil { if err == nil {
t.Fatal("Expected to return an error, got nil") t.Fatal("Expected to return an error, got nil")
} }
var invalidFilterError invalidFilter var invalidFilterError *invalidFilter
if !errors.As(err, &invalidFilterError) { if !errors.As(err, &invalidFilterError) {
t.Fatalf("Expected an invalidFilter error, got %T", err) t.Errorf("Expected an invalidFilter error, got %T", err)
}
wrappedErr := fmt.Errorf("something went wrong: %w", err)
if !errors.Is(wrappedErr, err) {
t.Errorf("Expected a wrapped error to be detected as invalidFilter")
} }
} }
@ -451,7 +460,7 @@ func TestClone(t *testing.T) {
} }
func TestGetBoolOrDefault(t *testing.T) { func TestGetBoolOrDefault(t *testing.T) {
for _, tC := range []struct { for _, tc := range []struct {
name string name string
args map[string][]string args map[string][]string
defValue bool defValue bool
@ -482,7 +491,7 @@ func TestGetBoolOrDefault(t *testing.T) {
"dangling": {"potato"}, "dangling": {"potato"},
}, },
defValue: true, defValue: true,
expectedErr: invalidFilter{Filter: "dangling", Value: []string{"potato"}}, expectedErr: &invalidFilter{Filter: "dangling", Value: []string{"potato"}},
expectedValue: true, expectedValue: true,
}, },
{ {
@ -491,7 +500,7 @@ func TestGetBoolOrDefault(t *testing.T) {
"dangling": {"banana", "potato"}, "dangling": {"banana", "potato"},
}, },
defValue: true, defValue: true,
expectedErr: invalidFilter{Filter: "dangling", Value: []string{"banana", "potato"}}, expectedErr: &invalidFilter{Filter: "dangling", Value: []string{"banana", "potato"}},
expectedValue: true, expectedValue: true,
}, },
{ {
@ -500,7 +509,7 @@ func TestGetBoolOrDefault(t *testing.T) {
"dangling": {"false", "true"}, "dangling": {"false", "true"},
}, },
defValue: false, defValue: false,
expectedErr: invalidFilter{Filter: "dangling", Value: []string{"false", "true"}}, expectedErr: &invalidFilter{Filter: "dangling", Value: []string{"false", "true"}},
expectedValue: false, expectedValue: false,
}, },
{ {
@ -509,7 +518,7 @@ func TestGetBoolOrDefault(t *testing.T) {
"dangling": {"false", "true", "1"}, "dangling": {"false", "true", "1"},
}, },
defValue: true, defValue: true,
expectedErr: invalidFilter{Filter: "dangling", Value: []string{"false", "true", "1"}}, expectedErr: &invalidFilter{Filter: "dangling", Value: []string{"false", "true", "1"}},
expectedValue: true, expectedValue: true,
}, },
{ {
@ -531,35 +540,38 @@ func TestGetBoolOrDefault(t *testing.T) {
expectedValue: false, expectedValue: false,
}, },
} { } {
tC := tC tc := tc
t.Run(tC.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
a := NewArgs() a := NewArgs()
for key, values := range tC.args { for key, values := range tc.args {
for _, value := range values { for _, value := range values {
a.Add(key, value) a.Add(key, value)
} }
} }
value, err := a.GetBoolOrDefault("dangling", tC.defValue) value, err := a.GetBoolOrDefault("dangling", tc.defValue)
if tC.expectedErr == nil { if tc.expectedErr == nil {
assert.Check(t, is.Nil(err)) assert.Check(t, is.Nil(err))
} else { } else {
assert.Check(t, is.ErrorType(err, tC.expectedErr)) assert.Check(t, is.ErrorType(err, tc.expectedErr))
// Check if error is the same. // Check if error is the same.
expected := tC.expectedErr.(invalidFilter) expected := tc.expectedErr.(*invalidFilter)
actual := err.(invalidFilter) actual := err.(*invalidFilter)
assert.Check(t, is.Equal(expected.Filter, actual.Filter)) assert.Check(t, is.Equal(expected.Filter, actual.Filter))
sort.Strings(expected.Value) sort.Strings(expected.Value)
sort.Strings(actual.Value) sort.Strings(actual.Value)
assert.Check(t, is.DeepEqual(expected.Value, actual.Value)) assert.Check(t, is.DeepEqual(expected.Value, actual.Value))
wrappedErr := fmt.Errorf("something went wrong: %w", err)
assert.Check(t, errors.Is(wrappedErr, err), "Expected a wrapped error to be detected as invalidFilter")
} }
assert.Check(t, is.Equal(tC.expectedValue, value)) assert.Check(t, is.Equal(tc.expectedValue, value))
}) })
} }
} }