diff --git a/libnetwork/drivers/bridge/bridge_test.go b/libnetwork/drivers/bridge/bridge_test.go index ab6bdf5846..e3600f9d5f 100644 --- a/libnetwork/drivers/bridge/bridge_test.go +++ b/libnetwork/drivers/bridge/bridge_test.go @@ -1075,7 +1075,8 @@ func TestCreateWithExistingBridge(t *testing.T) { } func TestCreateParallel(t *testing.T) { - defer testutils.SetupTestOSContext(t)() + c := testutils.SetupTestOSContextEx(t) + defer c.Cleanup(t) d := newDriver() @@ -1085,7 +1086,8 @@ func TestCreateParallel(t *testing.T) { ch := make(chan error, 100) for i := 0; i < 100; i++ { - go func(name string, ch chan<- error) { + name := "net" + strconv.Itoa(i) + c.Go(t, func() { config := &networkConfiguration{BridgeName: name} genericOption := make(map[string]interface{}) genericOption[netlabel.GenericData] = config @@ -1098,7 +1100,7 @@ func TestCreateParallel(t *testing.T) { return } ch <- nil - }("net"+strconv.Itoa(i), ch) + }) } // wait for the go routines var success int diff --git a/libnetwork/resolver_test.go b/libnetwork/resolver_test.go index b2ba88914c..f83551def0 100644 --- a/libnetwork/resolver_test.go +++ b/libnetwork/resolver_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/docker/docker/libnetwork/testutils" "github.com/miekg/dns" "github.com/sirupsen/logrus" "gotest.tools/v3/skip" @@ -214,6 +215,9 @@ func waitForLocalDNSServer(t *testing.T) { func TestDNSProxyServFail(t *testing.T) { skip.If(t, runtime.GOOS == "windows", "test only works on linux") + osctx := testutils.SetupTestOSContextEx(t) + defer osctx.Cleanup(t) + c, err := New() if err != nil { t.Fatal(err) @@ -247,9 +251,9 @@ func TestDNSProxyServFail(t *testing.T) { // use TCP for predictable results. Connection tests (to figure out DNS server initialization) don't work with UDP server := &dns.Server{Addr: "127.0.0.1:53", Net: "tcp"} srvErrCh := make(chan error, 1) - go func() { + osctx.Go(t, func() { srvErrCh <- server.ListenAndServe() - }() + }) defer func() { server.Shutdown() //nolint:errcheck if err := <-srvErrCh; err != nil { diff --git a/libnetwork/testutils/context.go b/libnetwork/testutils/context.go new file mode 100644 index 0000000000..16f25bbd5f --- /dev/null +++ b/libnetwork/testutils/context.go @@ -0,0 +1,41 @@ +package testutils + +import "testing" + +// Logger is used to log non-fatal messages during tests. +type Logger interface { + Logf(format string, args ...any) +} + +var _ Logger = (*testing.T)(nil) + +// SetupTestOSContext joins the current goroutine to a new network namespace, +// and returns its associated teardown function. +// +// Example usage: +// +// defer SetupTestOSContext(t)() +func SetupTestOSContext(t *testing.T) func() { + c := SetupTestOSContextEx(t) + return func() { c.Cleanup(t) } +} + +// Go starts running fn in a new goroutine inside the test OS context. +func (c *OSContext) Go(t *testing.T, fn func()) { + t.Helper() + errCh := make(chan error, 1) + go func() { + teardown, err := c.Set() + if err != nil { + errCh <- err + return + } + defer teardown(t) + close(errCh) + fn() + }() + + if err := <-errCh; err != nil { + t.Fatalf("%+v", err) + } +} diff --git a/libnetwork/testutils/context_unix.go b/libnetwork/testutils/context_unix.go index 1317e17af9..c60cd915c3 100644 --- a/libnetwork/testutils/context_unix.go +++ b/libnetwork/testutils/context_unix.go @@ -4,42 +4,52 @@ package testutils import ( + "fmt" "runtime" + "strconv" "testing" "github.com/docker/docker/libnetwork/ns" + "github.com/pkg/errors" "github.com/vishvananda/netns" + "golang.org/x/sys/unix" ) -// SetupTestOSContext joins a new network namespace, and returns its associated -// teardown function. +// OSContext is a handle to a test OS context. +type OSContext struct { + origNS, newNS netns.NsHandle + + tid int + caller string // The file:line where SetupTestOSContextEx was called, for interpolating into error messages. +} + +// SetupTestOSContextEx joins the current goroutine to a new network namespace. +// +// Compared to [SetupTestOSContext], this function allows goroutines to be +// spawned which are associated with the same OS context via the returned +// OSContext value. // // Example usage: // -// defer SetupTestOSContext(t)() -func SetupTestOSContext(t *testing.T) func() { +// c := SetupTestOSContext(t) +// defer c.Cleanup(t) +func SetupTestOSContextEx(t *testing.T) *OSContext { + runtime.LockOSThread() origNS, err := netns.Get() if err != nil { + runtime.UnlockOSThread() t.Fatalf("Failed to open initial netns: %v", err) } - restore := func() { - if err := netns.Set(origNS); err != nil { - t.Logf("Warning: failed to restore thread netns (%v)", err) - } else { - runtime.UnlockOSThread() - } - if err := origNS.Close(); err != nil { - t.Logf("Warning: netns closing failed (%v)", err) - } + c := OSContext{ + tid: unix.Gettid(), + origNS: origNS, } - - runtime.LockOSThread() - newNS, err := netns.New() + c.newNS, err = netns.New() if err != nil { // netns.New() is not atomic: it could have encountered an error // after unsharing the current thread's network namespace. - restore() + c.restore(t) t.Fatalf("Failed to enter netns: %v", err) } @@ -50,19 +60,110 @@ func SetupTestOSContext(t *testing.T) func() { nl := ns.NlHandle() lo, err := nl.LinkByName("lo") if err != nil { - restore() + c.restore(t) t.Fatalf("Failed to get handle to loopback interface 'lo' in new netns: %v", err) } if err := nl.LinkSetUp(lo); err != nil { - restore() + c.restore(t) t.Fatalf("Failed to enable loopback interface in new netns: %v", err) } - return func() { - if err := newNS.Close(); err != nil { - t.Logf("Warning: netns closing failed (%v)", err) - } - restore() - ns.Init() + _, file, line, ok := runtime.Caller(0) + if ok { + c.caller = file + ":" + strconv.Itoa(line) + } + + return &c +} + +// Cleanup tears down the OS context. It must be called from the same goroutine +// as the [SetupTestOSContextEx] call which returned c. +// +// Explicit cleanup is required as (*testing.T).Cleanup() makes no guarantees +// about which goroutine the cleanup functions are invoked on. +func (c *OSContext) Cleanup(t *testing.T) { + t.Helper() + if unix.Gettid() != c.tid { + t.Fatalf("c.Cleanup() must be called from the same goroutine as SetupTestOSContextEx() (%s)", c.caller) + } + if err := c.newNS.Close(); err != nil { + t.Logf("Warning: netns closing failed (%v)", err) + } + c.restore(t) + ns.Init() +} + +func (c *OSContext) restore(t *testing.T) { + t.Helper() + if err := netns.Set(c.origNS); err != nil { + t.Logf("Warning: failed to restore thread netns (%v)", err) + } else { + runtime.UnlockOSThread() + } + + if err := c.origNS.Close(); err != nil { + t.Logf("Warning: netns closing failed (%v)", err) } } + +// Set sets the OS context of the calling goroutine to c and returns a teardown +// function to restore the calling goroutine's OS context and release resources. +// The teardown function accepts an optional Logger argument. +// +// This is a lower-level interface which is less ergonomic than c.Go() but more +// composable with other goroutine-spawning utilities such as [sync.WaitGroup] +// or [golang.org/x/sync/errgroup.Group]. +// +// Example usage: +// +// func TestFoo(t *testing.T) { +// osctx := testutils.SetupTestOSContextEx(t) +// defer osctx.Cleanup(t) +// var eg errgroup.Group +// eg.Go(func() error { +// teardown, err := osctx.Set() +// if err != nil { +// return err +// } +// defer teardown(t) +// // ... +// }) +// if err := eg.Wait(); err != nil { +// t.Fatalf("%+v", err) +// } +// } +func (c *OSContext) Set() (func(Logger), error) { + runtime.LockOSThread() + orig, err := netns.Get() + if err != nil { + runtime.UnlockOSThread() + return nil, errors.Wrap(err, "failed to open initial netns for goroutine") + } + if err := errors.WithStack(netns.Set(c.newNS)); err != nil { + runtime.UnlockOSThread() + return nil, errors.Wrap(err, "failed to set goroutine network namespace") + } + + tid := unix.Gettid() + _, file, line, callerOK := runtime.Caller(0) + + return func(log Logger) { + if unix.Gettid() != tid { + msg := "teardown function must be called from the same goroutine as c.Set()" + if callerOK { + msg += fmt.Sprintf(" (%s:%d)", file, line) + } + panic(msg) + } + + if err := netns.Set(orig); err != nil && log != nil { + log.Logf("Warning: failed to restore goroutine thread netns (%v)", err) + } else { + runtime.UnlockOSThread() + } + + if err := orig.Close(); err != nil && log != nil { + log.Logf("Warning: netns closing failed (%v)", err) + } + }, nil +} diff --git a/libnetwork/testutils/context_windows.go b/libnetwork/testutils/context_windows.go index 4fa3372962..3ddb01560c 100644 --- a/libnetwork/testutils/context_windows.go +++ b/libnetwork/testutils/context_windows.go @@ -2,12 +2,14 @@ package testutils import "testing" -// SetupTestOSContext joins a new network namespace, and returns its associated -// teardown function. -// -// Example usage: -// -// defer SetupTestOSContext(t)() -func SetupTestOSContext(t *testing.T) func() { - return func() {} +type OSContext struct{} + +func SetupTestOSContextEx(*testing.T) *OSContext { + return nil +} + +func (*OSContext) Cleanup(t *testing.T) {} + +func (*OSContext) Set() (func(Logger), error) { + return func(Logger) {}, nil }