浏览代码

api/types/filters: fix errors not being matched by errors.Is()

I found that the errors returned weren't matched with `errors.Is()` when
wrapped.

Signed-off-by: Sebastiaan van Stijn <github@gone.nl>
Sebastiaan van Stijn 1 年之前
父节点
当前提交
490fee7d45
共有 2 个文件被更改,包括 36 次插入24 次删除
  1. 5 5
      api/types/filters/parse.go
  2. 31 19
      api/types/filters/parse_test.go

+ 5 - 5
api/types/filters/parse.go

@@ -98,7 +98,7 @@ func FromJSON(p string) (Args, error) {
 	// Fallback to parsing arguments in the legacy slice format
 	deprecated := map[string][]string{}
 	if legacyErr := json.Unmarshal(raw, &deprecated); legacyErr != nil {
-		return args, invalidFilter{}
+		return args, &invalidFilter{}
 	}
 
 	args.fields = deprecatedArgs(deprecated)
@@ -206,7 +206,7 @@ func (args Args) GetBoolOrDefault(key string, defaultValue bool) (bool, error) {
 	}
 
 	if len(fieldValues) == 0 {
-		return defaultValue, invalidFilter{key, nil}
+		return defaultValue, &invalidFilter{key, nil}
 	}
 
 	isFalse := fieldValues["0"] || fieldValues["false"]
@@ -216,7 +216,7 @@ func (args Args) GetBoolOrDefault(key string, defaultValue bool) (bool, error) {
 	invalid := !isFalse && !isTrue
 
 	if conflicting || invalid {
-		return defaultValue, invalidFilter{key, args.Get(key)}
+		return defaultValue, &invalidFilter{key, args.Get(key)}
 	} else if isFalse {
 		return false, nil
 	} else if isTrue {
@@ -224,7 +224,7 @@ func (args Args) GetBoolOrDefault(key string, defaultValue bool) (bool, error) {
 	}
 
 	// This code shouldn't be reached.
-	return defaultValue, unreachableCode{Filter: key, Value: args.Get(key)}
+	return defaultValue, &unreachableCode{Filter: key, Value: args.Get(key)}
 }
 
 // ExactMatch returns true if the source matches exactly one of the values.
@@ -282,7 +282,7 @@ func (args Args) Contains(field string) bool {
 func (args Args) Validate(accepted map[string]bool) error {
 	for name := range args.fields {
 		if !accepted[name] {
-			return invalidFilter{name, nil}
+			return &invalidFilter{name, nil}
 		}
 	}
 	return nil

+ 31 - 19
api/types/filters/parse_test.go

@@ -3,6 +3,7 @@ package filters // import "github.com/docker/docker/api/types/filters"
 import (
 	"encoding/json"
 	"errors"
+	"fmt"
 	"sort"
 	"testing"
 
@@ -95,15 +96,19 @@ func TestFromJSON(t *testing.T) {
 		if err == nil {
 			t.Fatalf("Expected an error with %v, got nothing", invalid)
 		}
-		var invalidFilterError invalidFilter
+		var invalidFilterError *invalidFilter
 		if !errors.As(err, &invalidFilterError) {
 			t.Fatalf("Expected an invalidFilter error, got %T", err)
 		}
+		wrappedErr := fmt.Errorf("something went wrong: %w", err)
+		if !errors.Is(wrappedErr, err) {
+			t.Errorf("Expected a wrapped error to be detected as invalidFilter")
+		}
 	}
 
 	for expectedArgs, matchers := range valid {
-		for _, json := range matchers {
-			args, err := FromJSON(json)
+		for _, jsonString := range matchers {
+			args, err := FromJSON(jsonString)
 			if err != nil {
 				t.Fatal(err)
 			}
@@ -388,9 +393,13 @@ func TestValidate(t *testing.T) {
 	if err == nil {
 		t.Fatal("Expected to return an error, got nil")
 	}
-	var invalidFilterError invalidFilter
+	var invalidFilterError *invalidFilter
 	if !errors.As(err, &invalidFilterError) {
-		t.Fatalf("Expected an invalidFilter error, got %T", err)
+		t.Errorf("Expected an invalidFilter error, got %T", err)
+	}
+	wrappedErr := fmt.Errorf("something went wrong: %w", err)
+	if !errors.Is(wrappedErr, err) {
+		t.Errorf("Expected a wrapped error to be detected as invalidFilter")
 	}
 }
 
@@ -451,7 +460,7 @@ func TestClone(t *testing.T) {
 }
 
 func TestGetBoolOrDefault(t *testing.T) {
-	for _, tC := range []struct {
+	for _, tc := range []struct {
 		name          string
 		args          map[string][]string
 		defValue      bool
@@ -482,7 +491,7 @@ func TestGetBoolOrDefault(t *testing.T) {
 				"dangling": {"potato"},
 			},
 			defValue:      true,
-			expectedErr:   invalidFilter{Filter: "dangling", Value: []string{"potato"}},
+			expectedErr:   &invalidFilter{Filter: "dangling", Value: []string{"potato"}},
 			expectedValue: true,
 		},
 		{
@@ -491,7 +500,7 @@ func TestGetBoolOrDefault(t *testing.T) {
 				"dangling": {"banana", "potato"},
 			},
 			defValue:      true,
-			expectedErr:   invalidFilter{Filter: "dangling", Value: []string{"banana", "potato"}},
+			expectedErr:   &invalidFilter{Filter: "dangling", Value: []string{"banana", "potato"}},
 			expectedValue: true,
 		},
 		{
@@ -500,7 +509,7 @@ func TestGetBoolOrDefault(t *testing.T) {
 				"dangling": {"false", "true"},
 			},
 			defValue:      false,
-			expectedErr:   invalidFilter{Filter: "dangling", Value: []string{"false", "true"}},
+			expectedErr:   &invalidFilter{Filter: "dangling", Value: []string{"false", "true"}},
 			expectedValue: false,
 		},
 		{
@@ -509,7 +518,7 @@ func TestGetBoolOrDefault(t *testing.T) {
 				"dangling": {"false", "true", "1"},
 			},
 			defValue:      true,
-			expectedErr:   invalidFilter{Filter: "dangling", Value: []string{"false", "true", "1"}},
+			expectedErr:   &invalidFilter{Filter: "dangling", Value: []string{"false", "true", "1"}},
 			expectedValue: true,
 		},
 		{
@@ -531,35 +540,38 @@ func TestGetBoolOrDefault(t *testing.T) {
 			expectedValue: false,
 		},
 	} {
-		tC := tC
-		t.Run(tC.name, func(t *testing.T) {
+		tc := tc
+		t.Run(tc.name, func(t *testing.T) {
 			a := NewArgs()
 
-			for key, values := range tC.args {
+			for key, values := range tc.args {
 				for _, value := range values {
 					a.Add(key, value)
 				}
 			}
 
-			value, err := a.GetBoolOrDefault("dangling", tC.defValue)
+			value, err := a.GetBoolOrDefault("dangling", tc.defValue)
 
-			if tC.expectedErr == nil {
+			if tc.expectedErr == nil {
 				assert.Check(t, is.Nil(err))
 			} else {
-				assert.Check(t, is.ErrorType(err, tC.expectedErr))
+				assert.Check(t, is.ErrorType(err, tc.expectedErr))
 
 				// Check if error is the same.
-				expected := tC.expectedErr.(invalidFilter)
-				actual := err.(invalidFilter)
+				expected := tc.expectedErr.(*invalidFilter)
+				actual := err.(*invalidFilter)
 
 				assert.Check(t, is.Equal(expected.Filter, actual.Filter))
 
 				sort.Strings(expected.Value)
 				sort.Strings(actual.Value)
 				assert.Check(t, is.DeepEqual(expected.Value, actual.Value))
+
+				wrappedErr := fmt.Errorf("something went wrong: %w", err)
+				assert.Check(t, errors.Is(wrappedErr, err), "Expected a wrapped error to be detected as invalidFilter")
 			}
 
-			assert.Check(t, is.Equal(tC.expectedValue, value))
+			assert.Check(t, is.Equal(tc.expectedValue, value))
 		})
 	}
 }