Bläddra i källkod

Merge pull request #47030 from knight42/feat/check-wrapped-err

Make errdefs.Is<SomeError> helper functions work with wrapped errors
Sebastiaan van Stijn 1 år sedan
förälder
incheckning
7082aecd54
2 ändrade filer med 72 tillägg och 0 borttagningar
  1. 66 0
      errdefs/helpers_test.go
  2. 6 0
      errdefs/is.go

+ 66 - 0
errdefs/helpers_test.go

@@ -2,6 +2,7 @@ package errdefs
 
 import (
 	"errors"
+	"fmt"
 	"testing"
 )
 
@@ -25,6 +26,11 @@ func TestNotFound(t *testing.T) {
 	if !errors.Is(e, errTest) {
 		t.Fatalf("expected not found error to match errTest")
 	}
+
+	wrapped := fmt.Errorf("foo: %w", e)
+	if !IsNotFound(wrapped) {
+		t.Fatalf("expected not found error, got: %T", wrapped)
+	}
 }
 
 func TestConflict(t *testing.T) {
@@ -41,6 +47,11 @@ func TestConflict(t *testing.T) {
 	if !errors.Is(e, errTest) {
 		t.Fatalf("expected conflict error to match errTest")
 	}
+
+	wrapped := fmt.Errorf("foo: %w", e)
+	if !IsConflict(wrapped) {
+		t.Fatalf("expected conflict error, got: %T", wrapped)
+	}
 }
 
 func TestForbidden(t *testing.T) {
@@ -57,6 +68,11 @@ func TestForbidden(t *testing.T) {
 	if !errors.Is(e, errTest) {
 		t.Fatalf("expected forbidden error to match errTest")
 	}
+
+	wrapped := fmt.Errorf("foo: %w", e)
+	if !IsForbidden(wrapped) {
+		t.Fatalf("expected forbidden error, got: %T", wrapped)
+	}
 }
 
 func TestInvalidParameter(t *testing.T) {
@@ -73,6 +89,11 @@ func TestInvalidParameter(t *testing.T) {
 	if !errors.Is(e, errTest) {
 		t.Fatalf("expected invalid argument error to match errTest")
 	}
+
+	wrapped := fmt.Errorf("foo: %w", e)
+	if !IsInvalidParameter(wrapped) {
+		t.Fatalf("expected invalid argument error, got: %T", wrapped)
+	}
 }
 
 func TestNotImplemented(t *testing.T) {
@@ -89,6 +110,11 @@ func TestNotImplemented(t *testing.T) {
 	if !errors.Is(e, errTest) {
 		t.Fatalf("expected not implemented error to match errTest")
 	}
+
+	wrapped := fmt.Errorf("foo: %w", e)
+	if !IsNotImplemented(wrapped) {
+		t.Fatalf("expected not implemented error, got: %T", wrapped)
+	}
 }
 
 func TestNotModified(t *testing.T) {
@@ -105,6 +131,11 @@ func TestNotModified(t *testing.T) {
 	if !errors.Is(e, errTest) {
 		t.Fatalf("expected not modified error to match errTest")
 	}
+
+	wrapped := fmt.Errorf("foo: %w", e)
+	if !IsNotModified(wrapped) {
+		t.Fatalf("expected not modified error, got: %T", wrapped)
+	}
 }
 
 func TestUnauthorized(t *testing.T) {
@@ -121,6 +152,11 @@ func TestUnauthorized(t *testing.T) {
 	if !errors.Is(e, errTest) {
 		t.Fatalf("expected unauthorized error to match errTest")
 	}
+
+	wrapped := fmt.Errorf("foo: %w", e)
+	if !IsUnauthorized(wrapped) {
+		t.Fatalf("expected unauthorized error, got: %T", wrapped)
+	}
 }
 
 func TestUnknown(t *testing.T) {
@@ -137,6 +173,11 @@ func TestUnknown(t *testing.T) {
 	if !errors.Is(e, errTest) {
 		t.Fatalf("expected unknown error to match errTest")
 	}
+
+	wrapped := fmt.Errorf("foo: %w", e)
+	if !IsUnknown(wrapped) {
+		t.Fatalf("expected unknown error, got: %T", wrapped)
+	}
 }
 
 func TestCancelled(t *testing.T) {
@@ -153,6 +194,11 @@ func TestCancelled(t *testing.T) {
 	if !errors.Is(e, errTest) {
 		t.Fatalf("expected cancelled error to match errTest")
 	}
+
+	wrapped := fmt.Errorf("foo: %w", e)
+	if !IsCancelled(wrapped) {
+		t.Fatalf("expected cancelled error, got: %T", wrapped)
+	}
 }
 
 func TestDeadline(t *testing.T) {
@@ -169,6 +215,11 @@ func TestDeadline(t *testing.T) {
 	if !errors.Is(e, errTest) {
 		t.Fatalf("expected deadline error to match errTest")
 	}
+
+	wrapped := fmt.Errorf("foo: %w", e)
+	if !IsDeadline(wrapped) {
+		t.Fatalf("expected deadline error, got: %T", wrapped)
+	}
 }
 
 func TestDataLoss(t *testing.T) {
@@ -185,6 +236,11 @@ func TestDataLoss(t *testing.T) {
 	if !errors.Is(e, errTest) {
 		t.Fatalf("expected data loss error to match errTest")
 	}
+
+	wrapped := fmt.Errorf("foo: %w", e)
+	if !IsDataLoss(wrapped) {
+		t.Fatalf("expected data loss error, got: %T", wrapped)
+	}
 }
 
 func TestUnavailable(t *testing.T) {
@@ -201,6 +257,11 @@ func TestUnavailable(t *testing.T) {
 	if !errors.Is(e, errTest) {
 		t.Fatalf("expected unavaillable error to match errTest")
 	}
+
+	wrapped := fmt.Errorf("foo: %w", e)
+	if !IsUnavailable(wrapped) {
+		t.Fatalf("expected unavaillable error, got: %T", wrapped)
+	}
 }
 
 func TestSystem(t *testing.T) {
@@ -217,4 +278,9 @@ func TestSystem(t *testing.T) {
 	if !errors.Is(e, errTest) {
 		t.Fatalf("expected system error to match errTest")
 	}
+
+	wrapped := fmt.Errorf("foo: %w", e)
+	if !IsSystem(wrapped) {
+		t.Fatalf("expected system error, got: %T", wrapped)
+	}
 }

+ 6 - 0
errdefs/is.go

@@ -9,6 +9,10 @@ type causer interface {
 	Cause() error
 }
 
+type wrapErr interface {
+	Unwrap() error
+}
+
 func getImplementer(err error) error {
 	switch e := err.(type) {
 	case
@@ -28,6 +32,8 @@ func getImplementer(err error) error {
 		return err
 	case causer:
 		return getImplementer(e.Cause())
+	case wrapErr:
+		return getImplementer(e.Unwrap())
 	default:
 		return err
 	}