Quellcode durchsuchen

libnetwork/testutils: restore netns on teardown

testutils.SetupTestOSContext() sets the calling thread's network
namespace but neglected to restore it on teardown. This was not a
problem in practice as it called runtime.LockOSThread() twice but
runtime.UnlockOSThread() only once, so the tampered threads would be
terminated by the runtime when the test case returned and replaced with
a clean thread. Correct the utility so it restores the thread's network
namespace during teardown and unlocks the goroutine from the thread on
success.

Remove unnecessary runtime.LockOSThread() calls peppering test cases
which leverage testutils.SetupTestOSContext().

Signed-off-by: Cory Snider <csnider@mirantis.com>
Cory Snider vor 2 Jahren
Ursprung
Commit
afa41b16ea
2 geänderte Dateien mit 24 neuen und 17 gelöschten Zeilen
  1. 0 7
      libnetwork/osl/sandbox_linux_test.go
  2. 24 10
      libnetwork/testutils/context_unix.go

+ 0 - 7
libnetwork/osl/sandbox_linux_test.go

@@ -7,7 +7,6 @@ import (
 	"net"
 	"net"
 	"os"
 	"os"
 	"path/filepath"
 	"path/filepath"
-	"runtime"
 	"strings"
 	"strings"
 	"syscall"
 	"syscall"
 	"testing"
 	"testing"
@@ -197,7 +196,6 @@ func TestDisableIPv6DAD(t *testing.T) {
 	if err != nil {
 	if err != nil {
 		t.Fatalf("Failed to create a new sandbox: %v", err)
 		t.Fatalf("Failed to create a new sandbox: %v", err)
 	}
 	}
-	runtime.LockOSThread()
 	defer destroyTest(t, s)
 	defer destroyTest(t, s)
 
 
 	n, ok := s.(*networkNamespace)
 	n, ok := s.(*networkNamespace)
@@ -257,7 +255,6 @@ func TestSetInterfaceIP(t *testing.T) {
 	if err != nil {
 	if err != nil {
 		t.Fatalf("Failed to create a new sandbox: %v", err)
 		t.Fatalf("Failed to create a new sandbox: %v", err)
 	}
 	}
-	runtime.LockOSThread()
 	defer destroyTest(t, s)
 	defer destroyTest(t, s)
 
 
 	n, ok := s.(*networkNamespace)
 	n, ok := s.(*networkNamespace)
@@ -332,7 +329,6 @@ func TestLiveRestore(t *testing.T) {
 	if err != nil {
 	if err != nil {
 		t.Fatalf("Failed to create a new sandbox: %v", err)
 		t.Fatalf("Failed to create a new sandbox: %v", err)
 	}
 	}
-	runtime.LockOSThread()
 	defer destroyTest(t, s)
 	defer destroyTest(t, s)
 
 
 	n, ok := s.(*networkNamespace)
 	n, ok := s.(*networkNamespace)
@@ -482,7 +478,6 @@ func TestSandboxCreateTwice(t *testing.T) {
 	if err != nil {
 	if err != nil {
 		t.Fatalf("Failed to create a new sandbox: %v", err)
 		t.Fatalf("Failed to create a new sandbox: %v", err)
 	}
 	}
-	runtime.LockOSThread()
 
 
 	// Create another sandbox with the same key to see if we handle it
 	// Create another sandbox with the same key to see if we handle it
 	// gracefully.
 	// gracefully.
@@ -490,7 +485,6 @@ func TestSandboxCreateTwice(t *testing.T) {
 	if err != nil {
 	if err != nil {
 		t.Fatalf("Failed to create a new sandbox: %v", err)
 		t.Fatalf("Failed to create a new sandbox: %v", err)
 	}
 	}
-	runtime.LockOSThread()
 
 
 	err = s.Destroy()
 	err = s.Destroy()
 	if err != nil {
 	if err != nil {
@@ -532,7 +526,6 @@ func TestAddRemoveInterface(t *testing.T) {
 	if err != nil {
 	if err != nil {
 		t.Fatalf("Failed to create a new sandbox: %v", err)
 		t.Fatalf("Failed to create a new sandbox: %v", err)
 	}
 	}
-	runtime.LockOSThread()
 
 
 	if s.Key() != key {
 	if s.Key() != key {
 		t.Fatalf("s.Key() returned %s. Expected %s", s.Key(), key)
 		t.Fatalf("s.Key() returned %s. Expected %s", s.Key(), key)

+ 24 - 10
libnetwork/testutils/context_unix.go

@@ -5,10 +5,10 @@ package testutils
 
 
 import (
 import (
 	"runtime"
 	"runtime"
-	"syscall"
 	"testing"
 	"testing"
 
 
 	"github.com/docker/docker/libnetwork/ns"
 	"github.com/docker/docker/libnetwork/ns"
+	"github.com/vishvananda/netns"
 )
 )
 
 
 // SetupTestOSContext joins a new network namespace, and returns its associated
 // SetupTestOSContext joins a new network namespace, and returns its associated
@@ -18,26 +18,40 @@ import (
 //
 //
 //	defer SetupTestOSContext(t)()
 //	defer SetupTestOSContext(t)()
 func SetupTestOSContext(t *testing.T) func() {
 func SetupTestOSContext(t *testing.T) func() {
-	runtime.LockOSThread()
-	if err := syscall.Unshare(syscall.CLONE_NEWNET); err != nil {
-		t.Fatalf("Failed to enter netns: %v", err)
+	origNS, err := netns.Get()
+	if err != nil {
+		t.Fatalf("Failed to open initial netns: %v", err)
+	}
+	restore := func() {
+		if err := netns.Set(origNS); err != nil {
+			t.Logf("Warning: failed to restore thread netns (%v)", err)
+		} else {
+			runtime.UnlockOSThread()
+		}
+
+		if err := origNS.Close(); err != nil {
+			t.Logf("Warning: netns closing failed (%v)", err)
+		}
 	}
 	}
 
 
-	fd, err := syscall.Open("/proc/self/ns/net", syscall.O_RDONLY, 0)
+	runtime.LockOSThread()
+	newNS, err := netns.New()
 	if err != nil {
 	if err != nil {
-		t.Fatal("Failed to open netns file")
+		// netns.New() is not atomic: it could have encountered an error
+		// after unsharing the current thread's network namespace.
+		restore()
+		t.Fatalf("Failed to enter netns: %v", err)
 	}
 	}
 
 
 	// Since we are switching to a new test namespace make
 	// Since we are switching to a new test namespace make
 	// sure to re-initialize initNs context
 	// sure to re-initialize initNs context
 	ns.Init()
 	ns.Init()
 
 
-	runtime.LockOSThread()
-
 	return func() {
 	return func() {
-		if err := syscall.Close(fd); err != nil {
+		if err := newNS.Close(); err != nil {
 			t.Logf("Warning: netns closing failed (%v)", err)
 			t.Logf("Warning: netns closing failed (%v)", err)
 		}
 		}
-		runtime.UnlockOSThread()
+		restore()
+		ns.Init()
 	}
 	}
 }
 }