libnetwork_test: overhaul TestParallel

TestParallel has been written in an unusual style which relies on the
testing package's intra-test parallelism feature and lots of global
state to test one thing using three cooperating parallel tests. It is
complicated to reason about and quite brittle. For example, the command

    go test -run TestParallel1 ./libnetwork

would deadlock, waiting until the test timeout for TestParallel2 and
TestParallel3 to run. And the test would be skipped if the
'-test.parallel' flag was less than three, either explicitly or
implicitly (default: GOMAXPROCS).

Overhaul TestParallel to address the aforementioned deficiencies and
get rid of mutable global state.

Signed-off-by: Cory Snider <csnider@mirantis.com>
This commit is contained in:
Cory Snider 2022-11-08 17:02:32 -05:00
parent 32ace57479
commit d0096bba21
2 changed files with 92 additions and 206 deletions

View file

@ -3,13 +3,10 @@ package libnetwork_test
import (
"bytes"
"encoding/json"
"flag"
"fmt"
"net"
"os"
"os/exec"
"runtime"
"strconv"
"strings"
"sync"
"testing"
@ -22,23 +19,17 @@ import (
"github.com/docker/docker/libnetwork/testutils"
"github.com/docker/docker/libnetwork/types"
"github.com/docker/docker/pkg/reexec"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"github.com/vishvananda/netns"
"golang.org/x/sync/errgroup"
)
const (
bridgeNetType = "bridge"
)
// Shared state for createGlobalInstance() and runParallelTests().
var (
origins = netns.None()
testns = netns.None()
controller libnetwork.NetworkController
)
func makeTesthostNetwork(t *testing.T, c libnetwork.NetworkController) libnetwork.Network {
t.Helper()
n, err := createTestNetwork(c, "host", "testhost", options.Generic{}, nil, nil)
@ -48,60 +39,6 @@ func makeTesthostNetwork(t *testing.T, c libnetwork.NetworkController) libnetwor
return n
}
func createGlobalInstance(t *testing.T) {
var err error
defer close(start)
origins, err = netns.Get()
if err != nil {
t.Fatal(err)
}
testns, err = netns.New()
if err != nil {
t.Fatal(err)
}
controller = newController(t)
t.Cleanup(controller.Stop)
netOption := options.Generic{
netlabel.GenericData: options.Generic{
"BridgeName": "network",
},
}
net1 := makeTesthostNetwork(t, controller)
net2, err := createTestNetwork(controller, "bridge", "network2", netOption, nil, nil)
if err != nil {
t.Fatal(err)
}
_, err = net1.CreateEndpoint("pep1")
if err != nil {
t.Fatal(err)
}
_, err = net2.CreateEndpoint("pep2")
if err != nil {
t.Fatal(err)
}
_, err = net2.CreateEndpoint("pep3")
if err != nil {
t.Fatal(err)
}
if sboxes[first-1], err = controller.NewSandbox(fmt.Sprintf("%drace", first), libnetwork.OptionUseDefaultSandbox()); err != nil {
t.Fatal(err)
}
for thd := first + 1; thd <= last; thd++ {
if sboxes[thd-1], err = controller.NewSandbox(fmt.Sprintf("%drace", thd)); err != nil {
t.Fatal(err)
}
}
}
func TestHost(t *testing.T) {
defer testutils.SetupTestOSContext(t)()
controller := newController(t)
@ -906,160 +843,133 @@ func TestResolvConf(t *testing.T) {
}
}
func parallelJoin(t *testing.T, rc libnetwork.Sandbox, ep libnetwork.Endpoint, thrNumber int) {
debugf("J%d.", thrNumber)
var err error
sb := sboxes[thrNumber-1]
err = ep.Join(sb)
if err != nil {
if _, ok := err.(types.ForbiddenError); !ok {
t.Fatalf("thread %d: %v", thrNumber, err)
}
debugf("JE%d(%v).", thrNumber, err)
}
debugf("JD%d.", thrNumber)
type parallelTester struct {
osctx *testutils.OSContext
controller libnetwork.NetworkController
net1, net2 libnetwork.Network
iterCnt int
}
func parallelLeave(t *testing.T, rc libnetwork.Sandbox, ep libnetwork.Endpoint, thrNumber int) {
debugf("L%d.", thrNumber)
var err error
sb := sboxes[thrNumber-1]
err = ep.Leave(sb)
if err != nil {
if _, ok := err.(types.ForbiddenError); !ok {
t.Fatalf("thread %d: %v", thrNumber, err)
}
debugf("LE%d(%v).", thrNumber, err)
}
debugf("LD%d.", thrNumber)
}
func runParallelTests(t *testing.T, thrNumber int) {
func (pt parallelTester) Do(t *testing.T, thrNumber int) error {
var (
ep libnetwork.Endpoint
sb libnetwork.Sandbox
err error
)
t.Parallel()
pTest := flag.Lookup("test.parallel")
if pTest == nil {
t.Skip("Skipped because test.parallel flag not set;")
}
numParallel, err := strconv.Atoi(pTest.Value.String())
teardown, err := pt.osctx.Set()
if err != nil {
t.Fatal(err)
}
if numParallel < numThreads {
t.Skip("Skipped because t.parallel was less than ", numThreads)
}
runtime.LockOSThread()
if thrNumber == first {
createGlobalInstance(t)
} else {
<-start
thrdone := make(chan struct{})
done <- thrdone
defer close(thrdone)
if thrNumber == last {
defer close(done)
}
err = netns.Set(testns)
if err != nil {
runtime.UnlockOSThread()
t.Fatal(err)
}
}
defer func() {
if err := netns.Set(origins); err != nil {
t.Fatalf("Error restoring the current thread's netns: %v", err)
} else {
runtime.UnlockOSThread()
}
}()
net1, err := controller.NetworkByName("testhost")
if err != nil {
t.Fatal(err)
}
if net1 == nil {
t.Fatal("Could not find testhost")
}
net2, err := controller.NetworkByName("network2")
if err != nil {
t.Fatal(err)
}
if net2 == nil {
t.Fatal("Could not find network2")
return err
}
defer teardown(t)
epName := fmt.Sprintf("pep%d", thrNumber)
if thrNumber == first {
ep, err = net1.EndpointByName(epName)
if thrNumber == 1 {
ep, err = pt.net1.EndpointByName(epName)
} else {
ep, err = net2.EndpointByName(epName)
ep, err = pt.net2.EndpointByName(epName)
}
if err != nil {
t.Fatal(err)
return errors.WithStack(err)
}
if ep == nil {
t.Fatal("Got nil ep with no error")
return errors.New("got nil ep with no error")
}
cid := fmt.Sprintf("%drace", thrNumber)
controller.WalkSandboxes(libnetwork.SandboxContainerWalker(&sb, cid))
pt.controller.WalkSandboxes(libnetwork.SandboxContainerWalker(&sb, cid))
if sb == nil {
t.Fatalf("Got nil sandbox for container: %s", cid)
return errors.Errorf("got nil sandbox for container: %s", cid)
}
for i := 0; i < iterCnt; i++ {
parallelJoin(t, sb, ep, thrNumber)
parallelLeave(t, sb, ep, thrNumber)
for i := 0; i < pt.iterCnt; i++ {
if err := ep.Join(sb); err != nil {
if _, ok := err.(types.ForbiddenError); !ok {
return errors.Wrapf(err, "thread %d", thrNumber)
}
}
if err := ep.Leave(sb); err != nil {
if _, ok := err.(types.ForbiddenError); !ok {
return errors.Wrapf(err, "thread %d", thrNumber)
}
}
}
debugf("\n")
if err := errors.WithStack(sb.Delete()); err != nil {
return err
}
return errors.WithStack(ep.Delete(false))
}
err = sb.Delete()
func TestParallel(t *testing.T) {
const (
first = 1
last = 3
numThreads = last - first + 1
iterCnt = 25
)
osctx := testutils.SetupTestOSContextEx(t)
defer osctx.Cleanup(t)
controller := newController(t)
netOption := options.Generic{
netlabel.GenericData: options.Generic{
"BridgeName": "network",
},
}
net1 := makeTesthostNetwork(t, controller)
defer net1.Delete()
net2, err := createTestNetwork(controller, "bridge", "network2", netOption, nil, nil)
if err != nil {
t.Fatal(err)
}
if thrNumber == first {
for thrdone := range done {
<-thrdone
}
defer net2.Delete()
if testns != origins {
testns.Close()
}
if err := net2.Delete(); err != nil {
t.Fatal(err)
}
} else {
err = ep.Delete(false)
if err != nil {
_, err = net1.CreateEndpoint("pep1")
if err != nil {
t.Fatal(err)
}
_, err = net2.CreateEndpoint("pep2")
if err != nil {
t.Fatal(err)
}
_, err = net2.CreateEndpoint("pep3")
if err != nil {
t.Fatal(err)
}
sboxes := make([]libnetwork.Sandbox, numThreads)
if sboxes[first-1], err = controller.NewSandbox(fmt.Sprintf("%drace", first), libnetwork.OptionUseDefaultSandbox()); err != nil {
t.Fatal(err)
}
for thd := first + 1; thd <= last; thd++ {
if sboxes[thd-1], err = controller.NewSandbox(fmt.Sprintf("%drace", thd)); err != nil {
t.Fatal(err)
}
}
}
func TestParallel1(t *testing.T) {
runParallelTests(t, 1)
}
pt := parallelTester{
osctx: osctx,
controller: controller,
net1: net1,
net2: net2,
iterCnt: iterCnt,
}
func TestParallel2(t *testing.T) {
runParallelTests(t, 2)
var eg errgroup.Group
for i := first; i <= last; i++ {
i := i
eg.Go(func() error { return pt.Do(t, i) })
}
if err := eg.Wait(); err != nil {
t.Fatalf("%+v", err)
}
}
func TestBridge(t *testing.T) {
@ -1150,10 +1060,6 @@ func isV6Listenable() bool {
return v6ListenableCached
}
func TestParallel3(t *testing.T) {
runParallelTests(t, 3)
}
func TestNullIpam(t *testing.T) {
defer testutils.SetupTestOSContext(t)()
controller := newController(t)

View file

@ -1332,23 +1332,3 @@ func TestValidRemoteDriver(t *testing.T) {
}
}()
}
var (
start = make(chan struct{})
done = make(chan chan struct{}, numThreads-1)
sboxes = make([]libnetwork.Sandbox, numThreads)
)
const (
iterCnt = 25
numThreads = 3
first = 1
last = numThreads
debug = false
)
func debugf(format string, a ...interface{}) {
if debug {
fmt.Printf(format, a...)
}
}