Jelajahi Sumber

Merge pull request #44356 from corhere/libnetwork-namespace-correctness

libnetwork: fix restoring thread network namespaces
Sebastiaan van Stijn 2 tahun lalu
induk
komit
6c829007cc

+ 0 - 16
libnetwork/drivers/bridge/bridge.go

@@ -22,7 +22,6 @@ import (
 	"github.com/docker/docker/libnetwork/netutils"
 	"github.com/docker/docker/libnetwork/ns"
 	"github.com/docker/docker/libnetwork/options"
-	"github.com/docker/docker/libnetwork/osl"
 	"github.com/docker/docker/libnetwork/portmapper"
 	"github.com/docker/docker/libnetwork/types"
 	"github.com/sirupsen/logrus"
@@ -671,8 +670,6 @@ func (d *driver) checkConflict(config *networkConfiguration) error {
 }
 
 func (d *driver) createNetwork(config *networkConfiguration) (err error) {
-	defer osl.InitOSContext()()
-
 	// Initialize handle when needed
 	d.Lock()
 	if d.nlh == nil {
@@ -811,7 +808,6 @@ func (d *driver) DeleteNetwork(nid string) error {
 func (d *driver) deleteNetwork(nid string) error {
 	var err error
 
-	defer osl.InitOSContext()()
 	// Get network handler and remove it from driver
 	d.Lock()
 	n, ok := d.networks[nid]
@@ -934,8 +930,6 @@ func setHairpinMode(nlh *netlink.Handle, link netlink.Link, enable bool) error {
 }
 
 func (d *driver) CreateEndpoint(nid, eid string, ifInfo driverapi.InterfaceInfo, epOptions map[string]interface{}) error {
-	defer osl.InitOSContext()()
-
 	if ifInfo == nil {
 		return errors.New("invalid interface info passed")
 	}
@@ -1121,8 +1115,6 @@ func (d *driver) CreateEndpoint(nid, eid string, ifInfo driverapi.InterfaceInfo,
 func (d *driver) DeleteEndpoint(nid, eid string) error {
 	var err error
 
-	defer osl.InitOSContext()()
-
 	// Get the network handler and make sure it exists
 	d.Lock()
 	n, ok := d.networks[nid]
@@ -1242,8 +1234,6 @@ func (d *driver) EndpointOperInfo(nid, eid string) (map[string]interface{}, erro
 
 // Join method is invoked when a Sandbox is attached to an endpoint.
 func (d *driver) Join(nid, eid string, sboxKey string, jinfo driverapi.JoinInfo, options map[string]interface{}) error {
-	defer osl.InitOSContext()()
-
 	network, err := d.getNetwork(nid)
 	if err != nil {
 		return err
@@ -1288,8 +1278,6 @@ func (d *driver) Join(nid, eid string, sboxKey string, jinfo driverapi.JoinInfo,
 
 // Leave method is invoked when a Sandbox detaches from an endpoint.
 func (d *driver) Leave(nid, eid string) error {
-	defer osl.InitOSContext()()
-
 	network, err := d.getNetwork(nid)
 	if err != nil {
 		return types.InternalMaskableErrorf("%s", err)
@@ -1314,8 +1302,6 @@ func (d *driver) Leave(nid, eid string) error {
 }
 
 func (d *driver) ProgramExternalConnectivity(nid, eid string, options map[string]interface{}) error {
-	defer osl.InitOSContext()()
-
 	network, err := d.getNetwork(nid)
 	if err != nil {
 		return err
@@ -1368,8 +1354,6 @@ func (d *driver) ProgramExternalConnectivity(nid, eid string, options map[string
 }
 
 func (d *driver) RevokeExternalConnectivity(nid, eid string) error {
-	defer osl.InitOSContext()()
-
 	network, err := d.getNetwork(nid)
 	if err != nil {
 		return err

+ 0 - 3
libnetwork/drivers/ipvlan/ipvlan_endpoint.go

@@ -9,7 +9,6 @@ import (
 	"github.com/docker/docker/libnetwork/driverapi"
 	"github.com/docker/docker/libnetwork/netlabel"
 	"github.com/docker/docker/libnetwork/ns"
-	"github.com/docker/docker/libnetwork/osl"
 	"github.com/docker/docker/libnetwork/types"
 	"github.com/sirupsen/logrus"
 )
@@ -17,7 +16,6 @@ import (
 // CreateEndpoint assigns the mac, ip and endpoint id for the new container
 func (d *driver) CreateEndpoint(nid, eid string, ifInfo driverapi.InterfaceInfo,
 	epOptions map[string]interface{}) error {
-	defer osl.InitOSContext()()
 
 	if err := validateID(nid, eid); err != nil {
 		return err
@@ -66,7 +64,6 @@ func (d *driver) CreateEndpoint(nid, eid string, ifInfo driverapi.InterfaceInfo,
 
 // DeleteEndpoint remove the endpoint and associated netlink interface
 func (d *driver) DeleteEndpoint(nid, eid string) error {
-	defer osl.InitOSContext()()
 	if err := validateID(nid, eid); err != nil {
 		return err
 	}

+ 0 - 3
libnetwork/drivers/ipvlan/ipvlan_joinleave.go

@@ -10,7 +10,6 @@ import (
 	"github.com/docker/docker/libnetwork/driverapi"
 	"github.com/docker/docker/libnetwork/netutils"
 	"github.com/docker/docker/libnetwork/ns"
-	"github.com/docker/docker/libnetwork/osl"
 	"github.com/docker/docker/libnetwork/types"
 	"github.com/sirupsen/logrus"
 )
@@ -28,7 +27,6 @@ const (
 
 // Join method is invoked when a Sandbox is attached to an endpoint.
 func (d *driver) Join(nid, eid string, sboxKey string, jinfo driverapi.JoinInfo, options map[string]interface{}) error {
-	defer osl.InitOSContext()()
 	n, err := d.getNetwork(nid)
 	if err != nil {
 		return err
@@ -139,7 +137,6 @@ func (d *driver) Join(nid, eid string, sboxKey string, jinfo driverapi.JoinInfo,
 
 // Leave method is invoked when a Sandbox detaches from an endpoint.
 func (d *driver) Leave(nid, eid string) error {
-	defer osl.InitOSContext()()
 	network, err := d.getNetwork(nid)
 	if err != nil {
 		return err

+ 0 - 3
libnetwork/drivers/ipvlan/ipvlan_network.go

@@ -10,7 +10,6 @@ import (
 	"github.com/docker/docker/libnetwork/netlabel"
 	"github.com/docker/docker/libnetwork/ns"
 	"github.com/docker/docker/libnetwork/options"
-	"github.com/docker/docker/libnetwork/osl"
 	"github.com/docker/docker/libnetwork/types"
 	"github.com/docker/docker/pkg/parsers/kernel"
 	"github.com/docker/docker/pkg/stringid"
@@ -19,7 +18,6 @@ import (
 
 // CreateNetwork the network for the specified driver type
 func (d *driver) CreateNetwork(nid string, option map[string]interface{}, nInfo driverapi.NetworkInfo, ipV4Data, ipV6Data []driverapi.IPAMData) error {
-	defer osl.InitOSContext()()
 	kv, err := kernel.GetKernelVersion()
 	if err != nil {
 		return fmt.Errorf("failed to check kernel version for ipvlan driver support: %v", err)
@@ -118,7 +116,6 @@ func (d *driver) createNetwork(config *configuration) (bool, error) {
 
 // DeleteNetwork deletes the network for the specified driver type
 func (d *driver) DeleteNetwork(nid string) error {
-	defer osl.InitOSContext()()
 	n := d.network(nid)
 	if n == nil {
 		return fmt.Errorf("network id %s not found", nid)

+ 0 - 3
libnetwork/drivers/macvlan/macvlan_endpoint.go

@@ -10,7 +10,6 @@ import (
 	"github.com/docker/docker/libnetwork/netlabel"
 	"github.com/docker/docker/libnetwork/netutils"
 	"github.com/docker/docker/libnetwork/ns"
-	"github.com/docker/docker/libnetwork/osl"
 	"github.com/docker/docker/libnetwork/types"
 	"github.com/sirupsen/logrus"
 )
@@ -18,7 +17,6 @@ import (
 // CreateEndpoint assigns the mac, ip and endpoint id for the new container
 func (d *driver) CreateEndpoint(nid, eid string, ifInfo driverapi.InterfaceInfo,
 	epOptions map[string]interface{}) error {
-	defer osl.InitOSContext()()
 
 	if err := validateID(nid, eid); err != nil {
 		return err
@@ -71,7 +69,6 @@ func (d *driver) CreateEndpoint(nid, eid string, ifInfo driverapi.InterfaceInfo,
 
 // DeleteEndpoint removes the endpoint and associated netlink interface
 func (d *driver) DeleteEndpoint(nid, eid string) error {
-	defer osl.InitOSContext()()
 	if err := validateID(nid, eid); err != nil {
 		return err
 	}

+ 0 - 3
libnetwork/drivers/macvlan/macvlan_joinleave.go

@@ -10,13 +10,11 @@ import (
 	"github.com/docker/docker/libnetwork/driverapi"
 	"github.com/docker/docker/libnetwork/netutils"
 	"github.com/docker/docker/libnetwork/ns"
-	"github.com/docker/docker/libnetwork/osl"
 	"github.com/sirupsen/logrus"
 )
 
 // Join method is invoked when a Sandbox is attached to an endpoint.
 func (d *driver) Join(nid, eid string, sboxKey string, jinfo driverapi.JoinInfo, options map[string]interface{}) error {
-	defer osl.InitOSContext()()
 	n, err := d.getNetwork(nid)
 	if err != nil {
 		return err
@@ -100,7 +98,6 @@ func (d *driver) Join(nid, eid string, sboxKey string, jinfo driverapi.JoinInfo,
 
 // Leave method is invoked when a Sandbox detaches from an endpoint.
 func (d *driver) Leave(nid, eid string) error {
-	defer osl.InitOSContext()()
 	network, err := d.getNetwork(nid)
 	if err != nil {
 		return err

+ 0 - 4
libnetwork/drivers/macvlan/macvlan_network.go

@@ -10,7 +10,6 @@ import (
 	"github.com/docker/docker/libnetwork/netlabel"
 	"github.com/docker/docker/libnetwork/ns"
 	"github.com/docker/docker/libnetwork/options"
-	"github.com/docker/docker/libnetwork/osl"
 	"github.com/docker/docker/libnetwork/types"
 	"github.com/docker/docker/pkg/stringid"
 	"github.com/sirupsen/logrus"
@@ -18,8 +17,6 @@ import (
 
 // CreateNetwork the network for the specified driver type
 func (d *driver) CreateNetwork(nid string, option map[string]interface{}, nInfo driverapi.NetworkInfo, ipV4Data, ipV6Data []driverapi.IPAMData) error {
-	defer osl.InitOSContext()()
-
 	// reject a null v4 network
 	if len(ipV4Data) == 0 || ipV4Data[0].Pool.String() == "0.0.0.0/0" {
 		return fmt.Errorf("ipv4 pool is empty")
@@ -109,7 +106,6 @@ func (d *driver) createNetwork(config *configuration) (bool, error) {
 
 // DeleteNetwork deletes the network for the specified driver type
 func (d *driver) DeleteNetwork(nid string) error {
-	defer osl.InitOSContext()()
 	n := d.network(nid)
 	if n == nil {
 		return fmt.Errorf("network id %s not found", nid)

+ 28 - 4
libnetwork/drivers/overlay/ov_network.go

@@ -76,6 +76,14 @@ type network struct {
 
 func init() {
 	reexec.Register("set-default-vlan", setDefaultVlan)
+
+	// Lock main() to the initial thread to exclude the goroutines executing
+	// func (*network).watchMiss() from being scheduled onto that thread.
+	// Changes to the network namespace of the initial thread alter
+	// /proc/self/ns/net, which would break any code which (incorrectly)
+	// assumes that that file is a handle to the network namespace for the
+	// thread it is currently executing on.
+	runtime.LockOSThread()
 }
 
 func setDefaultVlan() {
@@ -779,20 +787,36 @@ func (n *network) initSandbox(restore bool) error {
 func (n *network) watchMiss(nlSock *nl.NetlinkSocket, nsPath string) {
 	// With the new version of the netlink library the deserialize function makes
 	// requests about the interface of the netlink message. This can succeed only
-	// if this go routine is in the target namespace. For this reason following we
-	// lock the thread on that namespace
-	runtime.LockOSThread()
-	defer runtime.UnlockOSThread()
+	// if this go routine is in the target namespace.
+	origNs, err := netns.Get()
+	if err != nil {
+		logrus.WithError(err).Error("failed to get the initial network namespace")
+		return
+	}
+	defer origNs.Close()
 	newNs, err := netns.GetFromPath(nsPath)
 	if err != nil {
 		logrus.WithError(err).Errorf("failed to get the namespace %s", nsPath)
 		return
 	}
 	defer newNs.Close()
+
+	runtime.LockOSThread()
 	if err = netns.Set(newNs); err != nil {
 		logrus.WithError(err).Errorf("failed to enter the namespace %s", nsPath)
+		runtime.UnlockOSThread()
 		return
 	}
+	defer func() {
+		if err := netns.Set(origNs); err != nil {
+			logrus.WithError(err).Error("failed to restore the thread's initial network namespace")
+			// The error is only fatal for the current thread. Keep this
+			// goroutine locked to the thread to make the runtime replace it
+			// with a clean thread once this goroutine terminates.
+		} else {
+			runtime.UnlockOSThread()
+		}
+	}()
 	for {
 		msgs, _, err := nlSock.Receive()
 		if err != nil {

+ 0 - 10
libnetwork/drivers/overlay/ov_utils.go

@@ -11,7 +11,6 @@ import (
 	"github.com/docker/docker/libnetwork/drivers/overlay/overlayutils"
 	"github.com/docker/docker/libnetwork/netutils"
 	"github.com/docker/docker/libnetwork/ns"
-	"github.com/docker/docker/libnetwork/osl"
 	"github.com/sirupsen/logrus"
 	"github.com/vishvananda/netlink"
 	"github.com/vishvananda/netns"
@@ -32,7 +31,6 @@ func validateID(nid, eid string) error {
 }
 
 func createVethPair() (string, string, error) {
-	defer osl.InitOSContext()()
 	nlh := ns.NlHandle()
 
 	// Generate a name for what will be the host side pipe interface
@@ -59,8 +57,6 @@ func createVethPair() (string, string, error) {
 }
 
 func createVxlan(name string, vni uint32, mtu int) error {
-	defer osl.InitOSContext()()
-
 	vxlan := &netlink.Vxlan{
 		LinkAttrs: netlink.LinkAttrs{Name: name, MTU: mtu},
 		VxlanId:   int(vni),
@@ -79,8 +75,6 @@ func createVxlan(name string, vni uint32, mtu int) error {
 }
 
 func deleteInterfaceBySubnet(brPrefix string, s *subnet) error {
-	defer osl.InitOSContext()()
-
 	nlh := ns.NlHandle()
 	links, err := nlh.LinkList()
 	if err != nil {
@@ -109,8 +103,6 @@ func deleteInterfaceBySubnet(brPrefix string, s *subnet) error {
 }
 
 func deleteInterface(name string) error {
-	defer osl.InitOSContext()()
-
 	link, err := ns.NlHandle().LinkByName(name)
 	if err != nil {
 		return fmt.Errorf("failed to find interface with name %s: %v", name, err)
@@ -124,8 +116,6 @@ func deleteInterface(name string) error {
 }
 
 func deleteVxlanByVNI(path string, vni uint32) error {
-	defer osl.InitOSContext()()
-
 	nlh := ns.NlHandle()
 	if path != "" {
 		ns, err := netns.GetFromPath(path)

+ 5 - 14
libnetwork/libnetwork_linux_test.go

@@ -917,7 +917,6 @@ func parallelJoin(t *testing.T, rc libnetwork.Sandbox, ep libnetwork.Endpoint, t
 	sb := sboxes[thrNumber-1]
 	err = ep.Join(sb)
 
-	runtime.LockOSThread()
 	if err != nil {
 		if _, ok := err.(types.ForbiddenError); !ok {
 			t.Fatalf("thread %d: %v", thrNumber, err)
@@ -934,7 +933,6 @@ func parallelLeave(t *testing.T, rc libnetwork.Sandbox, ep libnetwork.Endpoint,
 	sb := sboxes[thrNumber-1]
 
 	err = ep.Leave(sb)
-	runtime.LockOSThread()
 	if err != nil {
 		if _, ok := err.(types.ForbiddenError); !ok {
 			t.Fatalf("thread %d: %v", thrNumber, err)
@@ -966,13 +964,9 @@ func runParallelTests(t *testing.T, thrNumber int) {
 	}
 
 	runtime.LockOSThread()
-	defer runtime.UnlockOSThread()
-
 	if thrNumber == first {
 		createGlobalInstance(t)
-	}
-
-	if thrNumber != first {
+	} else {
 		<-start
 
 		thrdone := make(chan struct{})
@@ -985,18 +979,15 @@ func runParallelTests(t *testing.T, thrNumber int) {
 
 		err = netns.Set(testns)
 		if err != nil {
+			runtime.UnlockOSThread()
 			t.Fatal(err)
 		}
 	}
 	defer func() {
 		if err := netns.Set(origins); err != nil {
-			// NOTE(@cpuguy83): This...
-			// I touched this code because the linter found that we weren't checking the error...
-			// It returns an error because "origins" is a closed file handle *unless* createGlobalInstance is called.
-			// Which... this test is run in parallel and `createGlobalInstance` modifies `origins` without synchronization.
-			// I'm not sure what exactly the *intent* of this code was, but it looks very broken.
-			// Anyway that's why I'm only logging the error and not failing the test.
-			t.Log(err)
+			t.Fatalf("Error restoring the current thread's netns: %v", err)
+		} else {
+			runtime.UnlockOSThread()
 		}
 	}()
 

+ 0 - 3
libnetwork/netutils/utils_linux.go

@@ -12,7 +12,6 @@ import (
 
 	"github.com/docker/docker/libnetwork/ipamutils"
 	"github.com/docker/docker/libnetwork/ns"
-	"github.com/docker/docker/libnetwork/osl"
 	"github.com/docker/docker/libnetwork/resolvconf"
 	"github.com/docker/docker/libnetwork/types"
 	"github.com/pkg/errors"
@@ -73,8 +72,6 @@ func GenerateIfaceName(nlh *netlink.Handle, prefix string, len int) (string, err
 func ElectInterfaceAddresses(name string) ([]*net.IPNet, []*net.IPNet, error) {
 	var v4Nets, v6Nets []*net.IPNet
 
-	defer osl.InitOSContext()()
-
 	link, _ := ns.NlHandle().LinkByName(name)
 	if link != nil {
 		v4addr, err := ns.NlHandle().AddrList(link, netlink.FAMILY_V4)

+ 0 - 18
libnetwork/ns/init_linux.go

@@ -2,7 +2,6 @@ package ns
 
 import (
 	"fmt"
-	"os"
 	"os/exec"
 	"strings"
 	"sync"
@@ -39,19 +38,6 @@ func Init() {
 	}
 }
 
-// SetNamespace sets the initial namespace handler
-func SetNamespace() error {
-	initOnce.Do(Init)
-	if err := netns.Set(initNs); err != nil {
-		linkInfo, linkErr := getLink()
-		if linkErr != nil {
-			linkInfo = linkErr.Error()
-		}
-		return fmt.Errorf("failed to set to initial namespace, %v, initns fd %d: %v", linkInfo, initNs, err)
-	}
-	return nil
-}
-
 // ParseHandlerInt transforms the namespace handler into an integer
 func ParseHandlerInt() int {
 	return int(getHandler())
@@ -63,10 +49,6 @@ func getHandler() netns.NsHandle {
 	return initNs
 }
 
-func getLink() (string, error) {
-	return os.Readlink(fmt.Sprintf("/proc/%d/task/%d/ns/net", os.Getpid(), syscall.Gettid()))
-}
-
 // NlHandle returns the netlink handler
 func NlHandle() *netlink.Handle {
 	initOnce.Do(Init)

+ 38 - 32
libnetwork/osl/namespace_linux.go

@@ -28,6 +28,14 @@ const defaultPrefix = "/var/run/docker"
 
 func init() {
 	reexec.Register("set-ipv6", reexecSetIPv6)
+
+	// Lock main() to the initial thread to exclude the goroutines spawned
+	// by func (*networkNamespace) InvokeFunc() from being scheduled onto
+	// that thread. Changes to the network namespace of the initial thread
+	// alter /proc/self/ns/net, which would break any code which
+	// (incorrectly) assumes that that file is a handle to the network
+	// namespace for the thread it is currently executing on.
+	runtime.LockOSThread()
 }
 
 var (
@@ -412,43 +420,41 @@ func (n *networkNamespace) DisableARPForVIP(srcName string) (Err error) {
 }
 
 func (n *networkNamespace) InvokeFunc(f func()) error {
-	return nsInvoke(n.nsPath(), func(nsFD int) error { return nil }, func(callerFD int) error {
-		f()
-		return nil
-	})
-}
-
-// InitOSContext initializes OS context while configuring network resources
-func InitOSContext() func() {
-	runtime.LockOSThread()
-	if err := ns.SetNamespace(); err != nil {
-		logrus.Error(err)
-	}
-	return runtime.UnlockOSThread
-}
-
-func nsInvoke(path string, prefunc func(nsFD int) error, postfunc func(callerFD int) error) error {
-	defer InitOSContext()()
-
-	newNs, err := netns.GetFromPath(path)
+	origNS, err := netns.Get()
 	if err != nil {
-		return fmt.Errorf("failed get network namespace %q: %v", path, err)
-	}
-	defer newNs.Close()
-
-	// Invoked before the namespace switch happens but after the namespace file
-	// handle is obtained.
-	if err := prefunc(int(newNs)); err != nil {
-		return fmt.Errorf("failed in prefunc: %v", err)
+		return fmt.Errorf("failed to get original network namespace: %w", err)
 	}
+	defer origNS.Close()
 
-	if err = netns.Set(newNs); err != nil {
-		return err
+	path := n.nsPath()
+	newNS, err := netns.GetFromPath(path)
+	if err != nil {
+		return fmt.Errorf("failed get network namespace %q: %w", path, err)
 	}
-	defer ns.SetNamespace()
+	defer newNS.Close()
 
-	// Invoked after the namespace switch.
-	return postfunc(ns.ParseHandlerInt())
+	done := make(chan error, 1)
+	go func() {
+		runtime.LockOSThread()
+		if err := netns.Set(newNS); err != nil {
+			runtime.UnlockOSThread()
+			done <- err
+			return
+		}
+		defer func() {
+			close(done)
+			if err := netns.Set(origNS); err != nil {
+				logrus.WithError(err).Warn("failed to restore thread's network namespace")
+				// Recover from the error by leaving this goroutine locked to
+				// the thread. The runtime will terminate the thread and replace
+				// it with a clean one when this goroutine returns.
+			} else {
+				runtime.UnlockOSThread()
+			}
+		}()
+		f()
+	}()
+	return <-done
 }
 
 func (n *networkNamespace) nsPath() string {

+ 0 - 5
libnetwork/osl/namespace_windows.go

@@ -21,11 +21,6 @@ func GetSandboxForExternalKey(path string, key string) (Sandbox, error) {
 func GC() {
 }
 
-// InitOSContext initializes OS context while configuring network resources
-func InitOSContext() func() {
-	return func() {}
-}
-
 // SetBasePath sets the base url prefix for the ns path
 func SetBasePath(path string) {
 }

+ 0 - 5
libnetwork/osl/sandbox_freebsd.go

@@ -27,11 +27,6 @@ func GetSandboxForExternalKey(path string, key string) (Sandbox, error) {
 func GC() {
 }
 
-// InitOSContext initializes OS context while configuring network resources
-func InitOSContext() func() {
-	return func() {}
-}
-
 // SetBasePath sets the base url prefix for the ns path
 func SetBasePath(path string) {
 }

+ 0 - 7
libnetwork/osl/sandbox_linux_test.go

@@ -7,7 +7,6 @@ import (
 	"net"
 	"os"
 	"path/filepath"
-	"runtime"
 	"strings"
 	"syscall"
 	"testing"
@@ -197,7 +196,6 @@ func TestDisableIPv6DAD(t *testing.T) {
 	if err != nil {
 		t.Fatalf("Failed to create a new sandbox: %v", err)
 	}
-	runtime.LockOSThread()
 	defer destroyTest(t, s)
 
 	n, ok := s.(*networkNamespace)
@@ -257,7 +255,6 @@ func TestSetInterfaceIP(t *testing.T) {
 	if err != nil {
 		t.Fatalf("Failed to create a new sandbox: %v", err)
 	}
-	runtime.LockOSThread()
 	defer destroyTest(t, s)
 
 	n, ok := s.(*networkNamespace)
@@ -332,7 +329,6 @@ func TestLiveRestore(t *testing.T) {
 	if err != nil {
 		t.Fatalf("Failed to create a new sandbox: %v", err)
 	}
-	runtime.LockOSThread()
 	defer destroyTest(t, s)
 
 	n, ok := s.(*networkNamespace)
@@ -482,7 +478,6 @@ func TestSandboxCreateTwice(t *testing.T) {
 	if err != nil {
 		t.Fatalf("Failed to create a new sandbox: %v", err)
 	}
-	runtime.LockOSThread()
 
 	// Create another sandbox with the same key to see if we handle it
 	// gracefully.
@@ -490,7 +485,6 @@ func TestSandboxCreateTwice(t *testing.T) {
 	if err != nil {
 		t.Fatalf("Failed to create a new sandbox: %v", err)
 	}
-	runtime.LockOSThread()
 
 	err = s.Destroy()
 	if err != nil {
@@ -532,7 +526,6 @@ func TestAddRemoveInterface(t *testing.T) {
 	if err != nil {
 		t.Fatalf("Failed to create a new sandbox: %v", err)
 	}
-	runtime.LockOSThread()
 
 	if s.Key() != key {
 		t.Fatalf("s.Key() returned %s. Expected %s", s.Key(), key)

+ 24 - 10
libnetwork/testutils/context_unix.go

@@ -5,10 +5,10 @@ package testutils
 
 import (
 	"runtime"
-	"syscall"
 	"testing"
 
 	"github.com/docker/docker/libnetwork/ns"
+	"github.com/vishvananda/netns"
 )
 
 // SetupTestOSContext joins a new network namespace, and returns its associated
@@ -18,26 +18,40 @@ import (
 //
 //	defer SetupTestOSContext(t)()
 func SetupTestOSContext(t *testing.T) func() {
-	runtime.LockOSThread()
-	if err := syscall.Unshare(syscall.CLONE_NEWNET); err != nil {
-		t.Fatalf("Failed to enter netns: %v", err)
+	origNS, err := netns.Get()
+	if err != nil {
+		t.Fatalf("Failed to open initial netns: %v", err)
+	}
+	restore := func() {
+		if err := netns.Set(origNS); err != nil {
+			t.Logf("Warning: failed to restore thread netns (%v)", err)
+		} else {
+			runtime.UnlockOSThread()
+		}
+
+		if err := origNS.Close(); err != nil {
+			t.Logf("Warning: netns closing failed (%v)", err)
+		}
 	}
 
-	fd, err := syscall.Open("/proc/self/ns/net", syscall.O_RDONLY, 0)
+	runtime.LockOSThread()
+	newNS, err := netns.New()
 	if err != nil {
-		t.Fatal("Failed to open netns file")
+		// netns.New() is not atomic: it could have encountered an error
+		// after unsharing the current thread's network namespace.
+		restore()
+		t.Fatalf("Failed to enter netns: %v", err)
 	}
 
 	// Since we are switching to a new test namespace make
 	// sure to re-initialize initNs context
 	ns.Init()
 
-	runtime.LockOSThread()
-
 	return func() {
-		if err := syscall.Close(fd); err != nil {
+		if err := newNS.Close(); err != nil {
 			t.Logf("Warning: netns closing failed (%v)", err)
 		}
-		runtime.UnlockOSThread()
+		restore()
+		ns.Init()
 	}
 }