Browse Source

libnetwork_test: overhaul TestParallel

TestParallel has been written in an unusual style which relies on the
testing package's intra-test parallelism feature and lots of global
state to test one thing using three cooperating parallel tests. It is
complicated to reason about and quite brittle. For example, the command

    go test -run TestParallel1 ./libnetwork

would deadlock, waiting until the test timeout for TestParallel2 and
TestParallel3 to run. And the test would be skipped if the
'-test.parallel' flag was less than three, either explicitly or
implicitly (default: GOMAXPROCS).

Overhaul TestParallel to address the aforementioned deficiencies and
get rid of mutable global state.

Signed-off-by: Cory Snider <csnider@mirantis.com>
Cory Snider 2 years ago
parent
commit
d0096bba21
2 changed files with 87 additions and 201 deletions
  1. 87 181
      libnetwork/libnetwork_linux_test.go
  2. 0 20
      libnetwork/libnetwork_test.go

+ 87 - 181
libnetwork/libnetwork_linux_test.go

@@ -3,13 +3,10 @@ package libnetwork_test
 import (
 	"bytes"
 	"encoding/json"
-	"flag"
 	"fmt"
 	"net"
 	"os"
 	"os/exec"
-	"runtime"
-	"strconv"
 	"strings"
 	"sync"
 	"testing"
@@ -22,23 +19,17 @@ import (
 	"github.com/docker/docker/libnetwork/testutils"
 	"github.com/docker/docker/libnetwork/types"
 	"github.com/docker/docker/pkg/reexec"
+	"github.com/pkg/errors"
 	"github.com/sirupsen/logrus"
 	"github.com/vishvananda/netlink"
 	"github.com/vishvananda/netns"
+	"golang.org/x/sync/errgroup"
 )
 
 const (
 	bridgeNetType = "bridge"
 )
 
-// Shared state for createGlobalInstance() and runParallelTests().
-var (
-	origins = netns.None()
-	testns  = netns.None()
-
-	controller libnetwork.NetworkController
-)
-
 func makeTesthostNetwork(t *testing.T, c libnetwork.NetworkController) libnetwork.Network {
 	t.Helper()
 	n, err := createTestNetwork(c, "host", "testhost", options.Generic{}, nil, nil)
@@ -48,60 +39,6 @@ func makeTesthostNetwork(t *testing.T, c libnetwork.NetworkController) libnetwor
 	return n
 }
 
-func createGlobalInstance(t *testing.T) {
-	var err error
-	defer close(start)
-
-	origins, err = netns.Get()
-	if err != nil {
-		t.Fatal(err)
-	}
-
-	testns, err = netns.New()
-	if err != nil {
-		t.Fatal(err)
-	}
-
-	controller = newController(t)
-	t.Cleanup(controller.Stop)
-
-	netOption := options.Generic{
-		netlabel.GenericData: options.Generic{
-			"BridgeName": "network",
-		},
-	}
-
-	net1 := makeTesthostNetwork(t, controller)
-	net2, err := createTestNetwork(controller, "bridge", "network2", netOption, nil, nil)
-	if err != nil {
-		t.Fatal(err)
-	}
-
-	_, err = net1.CreateEndpoint("pep1")
-	if err != nil {
-		t.Fatal(err)
-	}
-
-	_, err = net2.CreateEndpoint("pep2")
-	if err != nil {
-		t.Fatal(err)
-	}
-
-	_, err = net2.CreateEndpoint("pep3")
-	if err != nil {
-		t.Fatal(err)
-	}
-
-	if sboxes[first-1], err = controller.NewSandbox(fmt.Sprintf("%drace", first), libnetwork.OptionUseDefaultSandbox()); err != nil {
-		t.Fatal(err)
-	}
-	for thd := first + 1; thd <= last; thd++ {
-		if sboxes[thd-1], err = controller.NewSandbox(fmt.Sprintf("%drace", thd)); err != nil {
-			t.Fatal(err)
-		}
-	}
-}
-
 func TestHost(t *testing.T) {
 	defer testutils.SetupTestOSContext(t)()
 	controller := newController(t)
@@ -906,160 +843,133 @@ func TestResolvConf(t *testing.T) {
 	}
 }
 
-func parallelJoin(t *testing.T, rc libnetwork.Sandbox, ep libnetwork.Endpoint, thrNumber int) {
-	debugf("J%d.", thrNumber)
-	var err error
-
-	sb := sboxes[thrNumber-1]
-	err = ep.Join(sb)
-
-	if err != nil {
-		if _, ok := err.(types.ForbiddenError); !ok {
-			t.Fatalf("thread %d: %v", thrNumber, err)
-		}
-		debugf("JE%d(%v).", thrNumber, err)
-	}
-	debugf("JD%d.", thrNumber)
-}
-
-func parallelLeave(t *testing.T, rc libnetwork.Sandbox, ep libnetwork.Endpoint, thrNumber int) {
-	debugf("L%d.", thrNumber)
-	var err error
-
-	sb := sboxes[thrNumber-1]
-
-	err = ep.Leave(sb)
-	if err != nil {
-		if _, ok := err.(types.ForbiddenError); !ok {
-			t.Fatalf("thread %d: %v", thrNumber, err)
-		}
-		debugf("LE%d(%v).", thrNumber, err)
-	}
-	debugf("LD%d.", thrNumber)
+type parallelTester struct {
+	osctx      *testutils.OSContext
+	controller libnetwork.NetworkController
+	net1, net2 libnetwork.Network
+	iterCnt    int
 }
 
-func runParallelTests(t *testing.T, thrNumber int) {
+func (pt parallelTester) Do(t *testing.T, thrNumber int) error {
 	var (
 		ep  libnetwork.Endpoint
 		sb  libnetwork.Sandbox
 		err error
 	)
 
-	t.Parallel()
+	teardown, err := pt.osctx.Set()
+	if err != nil {
+		return err
+	}
+	defer teardown(t)
 
-	pTest := flag.Lookup("test.parallel")
-	if pTest == nil {
-		t.Skip("Skipped because test.parallel flag not set;")
+	epName := fmt.Sprintf("pep%d", thrNumber)
+
+	if thrNumber == 1 {
+		ep, err = pt.net1.EndpointByName(epName)
+	} else {
+		ep, err = pt.net2.EndpointByName(epName)
 	}
-	numParallel, err := strconv.Atoi(pTest.Value.String())
+
 	if err != nil {
-		t.Fatal(err)
+		return errors.WithStack(err)
 	}
-	if numParallel < numThreads {
-		t.Skip("Skipped because t.parallel was less than ", numThreads)
+	if ep == nil {
+		return errors.New("got nil ep with no error")
 	}
 
-	runtime.LockOSThread()
-	if thrNumber == first {
-		createGlobalInstance(t)
-	} else {
-		<-start
-
-		thrdone := make(chan struct{})
-		done <- thrdone
-		defer close(thrdone)
+	cid := fmt.Sprintf("%drace", thrNumber)
+	pt.controller.WalkSandboxes(libnetwork.SandboxContainerWalker(&sb, cid))
+	if sb == nil {
+		return errors.Errorf("got nil sandbox for container: %s", cid)
+	}
 
-		if thrNumber == last {
-			defer close(done)
+	for i := 0; i < pt.iterCnt; i++ {
+		if err := ep.Join(sb); err != nil {
+			if _, ok := err.(types.ForbiddenError); !ok {
+				return errors.Wrapf(err, "thread %d", thrNumber)
+			}
 		}
-
-		err = netns.Set(testns)
-		if err != nil {
-			runtime.UnlockOSThread()
-			t.Fatal(err)
+		if err := ep.Leave(sb); err != nil {
+			if _, ok := err.(types.ForbiddenError); !ok {
+				return errors.Wrapf(err, "thread %d", thrNumber)
+			}
 		}
 	}
-	defer func() {
-		if err := netns.Set(origins); err != nil {
-			t.Fatalf("Error restoring the current thread's netns: %v", err)
-		} else {
-			runtime.UnlockOSThread()
-		}
-	}()
 
-	net1, err := controller.NetworkByName("testhost")
-	if err != nil {
-		t.Fatal(err)
-	}
-	if net1 == nil {
-		t.Fatal("Could not find testhost")
+	if err := errors.WithStack(sb.Delete()); err != nil {
+		return err
 	}
+	return errors.WithStack(ep.Delete(false))
+}
 
-	net2, err := controller.NetworkByName("network2")
-	if err != nil {
-		t.Fatal(err)
-	}
-	if net2 == nil {
-		t.Fatal("Could not find network2")
-	}
+func TestParallel(t *testing.T) {
+	const (
+		first      = 1
+		last       = 3
+		numThreads = last - first + 1
+		iterCnt    = 25
+	)
 
-	epName := fmt.Sprintf("pep%d", thrNumber)
+	osctx := testutils.SetupTestOSContextEx(t)
+	defer osctx.Cleanup(t)
+	controller := newController(t)
 
-	if thrNumber == first {
-		ep, err = net1.EndpointByName(epName)
-	} else {
-		ep, err = net2.EndpointByName(epName)
+	netOption := options.Generic{
+		netlabel.GenericData: options.Generic{
+			"BridgeName": "network",
+		},
 	}
 
+	net1 := makeTesthostNetwork(t, controller)
+	defer net1.Delete()
+	net2, err := createTestNetwork(controller, "bridge", "network2", netOption, nil, nil)
 	if err != nil {
 		t.Fatal(err)
 	}
-	if ep == nil {
-		t.Fatal("Got nil ep with no error")
-	}
+	defer net2.Delete()
 
-	cid := fmt.Sprintf("%drace", thrNumber)
-	controller.WalkSandboxes(libnetwork.SandboxContainerWalker(&sb, cid))
-	if sb == nil {
-		t.Fatalf("Got nil sandbox for container: %s", cid)
+	_, err = net1.CreateEndpoint("pep1")
+	if err != nil {
+		t.Fatal(err)
 	}
 
-	for i := 0; i < iterCnt; i++ {
-		parallelJoin(t, sb, ep, thrNumber)
-		parallelLeave(t, sb, ep, thrNumber)
+	_, err = net2.CreateEndpoint("pep2")
+	if err != nil {
+		t.Fatal(err)
 	}
 
-	debugf("\n")
-
-	err = sb.Delete()
+	_, err = net2.CreateEndpoint("pep3")
 	if err != nil {
 		t.Fatal(err)
 	}
-	if thrNumber == first {
-		for thrdone := range done {
-			<-thrdone
-		}
 
-		if testns != origins {
-			testns.Close()
-		}
-		if err := net2.Delete(); err != nil {
-			t.Fatal(err)
-		}
-	} else {
-		err = ep.Delete(false)
-		if err != nil {
+	sboxes := make([]libnetwork.Sandbox, numThreads)
+	if sboxes[first-1], err = controller.NewSandbox(fmt.Sprintf("%drace", first), libnetwork.OptionUseDefaultSandbox()); err != nil {
+		t.Fatal(err)
+	}
+	for thd := first + 1; thd <= last; thd++ {
+		if sboxes[thd-1], err = controller.NewSandbox(fmt.Sprintf("%drace", thd)); err != nil {
 			t.Fatal(err)
 		}
 	}
-}
 
-func TestParallel1(t *testing.T) {
-	runParallelTests(t, 1)
-}
+	pt := parallelTester{
+		osctx:      osctx,
+		controller: controller,
+		net1:       net1,
+		net2:       net2,
+		iterCnt:    iterCnt,
+	}
 
-func TestParallel2(t *testing.T) {
-	runParallelTests(t, 2)
+	var eg errgroup.Group
+	for i := first; i <= last; i++ {
+		i := i
+		eg.Go(func() error { return pt.Do(t, i) })
+	}
+	if err := eg.Wait(); err != nil {
+		t.Fatalf("%+v", err)
+	}
 }
 
 func TestBridge(t *testing.T) {
@@ -1150,10 +1060,6 @@ func isV6Listenable() bool {
 	return v6ListenableCached
 }
 
-func TestParallel3(t *testing.T) {
-	runParallelTests(t, 3)
-}
-
 func TestNullIpam(t *testing.T) {
 	defer testutils.SetupTestOSContext(t)()
 	controller := newController(t)

+ 0 - 20
libnetwork/libnetwork_test.go

@@ -1332,23 +1332,3 @@ func TestValidRemoteDriver(t *testing.T) {
 		}
 	}()
 }
-
-var (
-	start  = make(chan struct{})
-	done   = make(chan chan struct{}, numThreads-1)
-	sboxes = make([]libnetwork.Sandbox, numThreads)
-)
-
-const (
-	iterCnt    = 25
-	numThreads = 3
-	first      = 1
-	last       = numThreads
-	debug      = false
-)
-
-func debugf(format string, a ...interface{}) {
-	if debug {
-		fmt.Printf(format, a...)
-	}
-}