diff --git a/libnetwork/libnetwork_linux_test.go b/libnetwork/libnetwork_linux_test.go index c67e4255eb..90b3b0fbde 100644 --- a/libnetwork/libnetwork_linux_test.go +++ b/libnetwork/libnetwork_linux_test.go @@ -3,13 +3,10 @@ package libnetwork_test import ( "bytes" "encoding/json" - "flag" "fmt" "net" "os" "os/exec" - "runtime" - "strconv" "strings" "sync" "testing" @@ -22,23 +19,17 @@ import ( "github.com/docker/docker/libnetwork/testutils" "github.com/docker/docker/libnetwork/types" "github.com/docker/docker/pkg/reexec" + "github.com/pkg/errors" "github.com/sirupsen/logrus" "github.com/vishvananda/netlink" "github.com/vishvananda/netns" + "golang.org/x/sync/errgroup" ) const ( bridgeNetType = "bridge" ) -// Shared state for createGlobalInstance() and runParallelTests(). -var ( - origins = netns.None() - testns = netns.None() - - controller libnetwork.NetworkController -) - func makeTesthostNetwork(t *testing.T, c libnetwork.NetworkController) libnetwork.Network { t.Helper() n, err := createTestNetwork(c, "host", "testhost", options.Generic{}, nil, nil) @@ -48,60 +39,6 @@ func makeTesthostNetwork(t *testing.T, c libnetwork.NetworkController) libnetwor return n } -func createGlobalInstance(t *testing.T) { - var err error - defer close(start) - - origins, err = netns.Get() - if err != nil { - t.Fatal(err) - } - - testns, err = netns.New() - if err != nil { - t.Fatal(err) - } - - controller = newController(t) - t.Cleanup(controller.Stop) - - netOption := options.Generic{ - netlabel.GenericData: options.Generic{ - "BridgeName": "network", - }, - } - - net1 := makeTesthostNetwork(t, controller) - net2, err := createTestNetwork(controller, "bridge", "network2", netOption, nil, nil) - if err != nil { - t.Fatal(err) - } - - _, err = net1.CreateEndpoint("pep1") - if err != nil { - t.Fatal(err) - } - - _, err = net2.CreateEndpoint("pep2") - if err != nil { - t.Fatal(err) - } - - _, err = net2.CreateEndpoint("pep3") - if err != nil { - t.Fatal(err) - } - - if sboxes[first-1], err = controller.NewSandbox(fmt.Sprintf("%drace", first), libnetwork.OptionUseDefaultSandbox()); err != nil { - t.Fatal(err) - } - for thd := first + 1; thd <= last; thd++ { - if sboxes[thd-1], err = controller.NewSandbox(fmt.Sprintf("%drace", thd)); err != nil { - t.Fatal(err) - } - } -} - func TestHost(t *testing.T) { defer testutils.SetupTestOSContext(t)() controller := newController(t) @@ -906,160 +843,133 @@ func TestResolvConf(t *testing.T) { } } -func parallelJoin(t *testing.T, rc libnetwork.Sandbox, ep libnetwork.Endpoint, thrNumber int) { - debugf("J%d.", thrNumber) - var err error - - sb := sboxes[thrNumber-1] - err = ep.Join(sb) - - if err != nil { - if _, ok := err.(types.ForbiddenError); !ok { - t.Fatalf("thread %d: %v", thrNumber, err) - } - debugf("JE%d(%v).", thrNumber, err) - } - debugf("JD%d.", thrNumber) +type parallelTester struct { + osctx *testutils.OSContext + controller libnetwork.NetworkController + net1, net2 libnetwork.Network + iterCnt int } -func parallelLeave(t *testing.T, rc libnetwork.Sandbox, ep libnetwork.Endpoint, thrNumber int) { - debugf("L%d.", thrNumber) - var err error - - sb := sboxes[thrNumber-1] - - err = ep.Leave(sb) - if err != nil { - if _, ok := err.(types.ForbiddenError); !ok { - t.Fatalf("thread %d: %v", thrNumber, err) - } - debugf("LE%d(%v).", thrNumber, err) - } - debugf("LD%d.", thrNumber) -} - -func runParallelTests(t *testing.T, thrNumber int) { +func (pt parallelTester) Do(t *testing.T, thrNumber int) error { var ( ep libnetwork.Endpoint sb libnetwork.Sandbox err error ) - t.Parallel() - - pTest := flag.Lookup("test.parallel") - if pTest == nil { - t.Skip("Skipped because test.parallel flag not set;") - } - numParallel, err := strconv.Atoi(pTest.Value.String()) + teardown, err := pt.osctx.Set() if err != nil { - t.Fatal(err) - } - if numParallel < numThreads { - t.Skip("Skipped because t.parallel was less than ", numThreads) - } - - runtime.LockOSThread() - if thrNumber == first { - createGlobalInstance(t) - } else { - <-start - - thrdone := make(chan struct{}) - done <- thrdone - defer close(thrdone) - - if thrNumber == last { - defer close(done) - } - - err = netns.Set(testns) - if err != nil { - runtime.UnlockOSThread() - t.Fatal(err) - } - } - defer func() { - if err := netns.Set(origins); err != nil { - t.Fatalf("Error restoring the current thread's netns: %v", err) - } else { - runtime.UnlockOSThread() - } - }() - - net1, err := controller.NetworkByName("testhost") - if err != nil { - t.Fatal(err) - } - if net1 == nil { - t.Fatal("Could not find testhost") - } - - net2, err := controller.NetworkByName("network2") - if err != nil { - t.Fatal(err) - } - if net2 == nil { - t.Fatal("Could not find network2") + return err } + defer teardown(t) epName := fmt.Sprintf("pep%d", thrNumber) - if thrNumber == first { - ep, err = net1.EndpointByName(epName) + if thrNumber == 1 { + ep, err = pt.net1.EndpointByName(epName) } else { - ep, err = net2.EndpointByName(epName) + ep, err = pt.net2.EndpointByName(epName) } if err != nil { - t.Fatal(err) + return errors.WithStack(err) } if ep == nil { - t.Fatal("Got nil ep with no error") + return errors.New("got nil ep with no error") } cid := fmt.Sprintf("%drace", thrNumber) - controller.WalkSandboxes(libnetwork.SandboxContainerWalker(&sb, cid)) + pt.controller.WalkSandboxes(libnetwork.SandboxContainerWalker(&sb, cid)) if sb == nil { - t.Fatalf("Got nil sandbox for container: %s", cid) + return errors.Errorf("got nil sandbox for container: %s", cid) } - for i := 0; i < iterCnt; i++ { - parallelJoin(t, sb, ep, thrNumber) - parallelLeave(t, sb, ep, thrNumber) + for i := 0; i < pt.iterCnt; i++ { + if err := ep.Join(sb); err != nil { + if _, ok := err.(types.ForbiddenError); !ok { + return errors.Wrapf(err, "thread %d", thrNumber) + } + } + if err := ep.Leave(sb); err != nil { + if _, ok := err.(types.ForbiddenError); !ok { + return errors.Wrapf(err, "thread %d", thrNumber) + } + } } - debugf("\n") + if err := errors.WithStack(sb.Delete()); err != nil { + return err + } + return errors.WithStack(ep.Delete(false)) +} - err = sb.Delete() +func TestParallel(t *testing.T) { + const ( + first = 1 + last = 3 + numThreads = last - first + 1 + iterCnt = 25 + ) + + osctx := testutils.SetupTestOSContextEx(t) + defer osctx.Cleanup(t) + controller := newController(t) + + netOption := options.Generic{ + netlabel.GenericData: options.Generic{ + "BridgeName": "network", + }, + } + + net1 := makeTesthostNetwork(t, controller) + defer net1.Delete() + net2, err := createTestNetwork(controller, "bridge", "network2", netOption, nil, nil) if err != nil { t.Fatal(err) } - if thrNumber == first { - for thrdone := range done { - <-thrdone - } + defer net2.Delete() - if testns != origins { - testns.Close() - } - if err := net2.Delete(); err != nil { - t.Fatal(err) - } - } else { - err = ep.Delete(false) - if err != nil { + _, err = net1.CreateEndpoint("pep1") + if err != nil { + t.Fatal(err) + } + + _, err = net2.CreateEndpoint("pep2") + if err != nil { + t.Fatal(err) + } + + _, err = net2.CreateEndpoint("pep3") + if err != nil { + t.Fatal(err) + } + + sboxes := make([]libnetwork.Sandbox, numThreads) + if sboxes[first-1], err = controller.NewSandbox(fmt.Sprintf("%drace", first), libnetwork.OptionUseDefaultSandbox()); err != nil { + t.Fatal(err) + } + for thd := first + 1; thd <= last; thd++ { + if sboxes[thd-1], err = controller.NewSandbox(fmt.Sprintf("%drace", thd)); err != nil { t.Fatal(err) } } -} -func TestParallel1(t *testing.T) { - runParallelTests(t, 1) -} + pt := parallelTester{ + osctx: osctx, + controller: controller, + net1: net1, + net2: net2, + iterCnt: iterCnt, + } -func TestParallel2(t *testing.T) { - runParallelTests(t, 2) + var eg errgroup.Group + for i := first; i <= last; i++ { + i := i + eg.Go(func() error { return pt.Do(t, i) }) + } + if err := eg.Wait(); err != nil { + t.Fatalf("%+v", err) + } } func TestBridge(t *testing.T) { @@ -1150,10 +1060,6 @@ func isV6Listenable() bool { return v6ListenableCached } -func TestParallel3(t *testing.T) { - runParallelTests(t, 3) -} - func TestNullIpam(t *testing.T) { defer testutils.SetupTestOSContext(t)() controller := newController(t) diff --git a/libnetwork/libnetwork_test.go b/libnetwork/libnetwork_test.go index 6272f5d29e..7a8778de62 100644 --- a/libnetwork/libnetwork_test.go +++ b/libnetwork/libnetwork_test.go @@ -1332,23 +1332,3 @@ func TestValidRemoteDriver(t *testing.T) { } }() } - -var ( - start = make(chan struct{}) - done = make(chan chan struct{}, numThreads-1) - sboxes = make([]libnetwork.Sandbox, numThreads) -) - -const ( - iterCnt = 25 - numThreads = 3 - first = 1 - last = numThreads - debug = false -) - -func debugf(format string, a ...interface{}) { - if debug { - fmt.Printf(format, a...) - } -}