Merge pull request #44356 from corhere/libnetwork-namespace-correctness

libnetwork: fix restoring thread network namespaces
This commit is contained in:
Sebastiaan van Stijn 2022-11-03 22:33:29 +01:00 committed by GitHub
commit 6c829007cc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 95 additions and 143 deletions

View file

@ -22,7 +22,6 @@ import (
"github.com/docker/docker/libnetwork/netutils" "github.com/docker/docker/libnetwork/netutils"
"github.com/docker/docker/libnetwork/ns" "github.com/docker/docker/libnetwork/ns"
"github.com/docker/docker/libnetwork/options" "github.com/docker/docker/libnetwork/options"
"github.com/docker/docker/libnetwork/osl"
"github.com/docker/docker/libnetwork/portmapper" "github.com/docker/docker/libnetwork/portmapper"
"github.com/docker/docker/libnetwork/types" "github.com/docker/docker/libnetwork/types"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -671,8 +670,6 @@ func (d *driver) checkConflict(config *networkConfiguration) error {
} }
func (d *driver) createNetwork(config *networkConfiguration) (err error) { func (d *driver) createNetwork(config *networkConfiguration) (err error) {
defer osl.InitOSContext()()
// Initialize handle when needed // Initialize handle when needed
d.Lock() d.Lock()
if d.nlh == nil { if d.nlh == nil {
@ -811,7 +808,6 @@ func (d *driver) DeleteNetwork(nid string) error {
func (d *driver) deleteNetwork(nid string) error { func (d *driver) deleteNetwork(nid string) error {
var err error var err error
defer osl.InitOSContext()()
// Get network handler and remove it from driver // Get network handler and remove it from driver
d.Lock() d.Lock()
n, ok := d.networks[nid] n, ok := d.networks[nid]
@ -934,8 +930,6 @@ func setHairpinMode(nlh *netlink.Handle, link netlink.Link, enable bool) error {
} }
func (d *driver) CreateEndpoint(nid, eid string, ifInfo driverapi.InterfaceInfo, epOptions map[string]interface{}) error { func (d *driver) CreateEndpoint(nid, eid string, ifInfo driverapi.InterfaceInfo, epOptions map[string]interface{}) error {
defer osl.InitOSContext()()
if ifInfo == nil { if ifInfo == nil {
return errors.New("invalid interface info passed") return errors.New("invalid interface info passed")
} }
@ -1121,8 +1115,6 @@ func (d *driver) CreateEndpoint(nid, eid string, ifInfo driverapi.InterfaceInfo,
func (d *driver) DeleteEndpoint(nid, eid string) error { func (d *driver) DeleteEndpoint(nid, eid string) error {
var err error var err error
defer osl.InitOSContext()()
// Get the network handler and make sure it exists // Get the network handler and make sure it exists
d.Lock() d.Lock()
n, ok := d.networks[nid] n, ok := d.networks[nid]
@ -1242,8 +1234,6 @@ func (d *driver) EndpointOperInfo(nid, eid string) (map[string]interface{}, erro
// Join method is invoked when a Sandbox is attached to an endpoint. // Join method is invoked when a Sandbox is attached to an endpoint.
func (d *driver) Join(nid, eid string, sboxKey string, jinfo driverapi.JoinInfo, options map[string]interface{}) error { func (d *driver) Join(nid, eid string, sboxKey string, jinfo driverapi.JoinInfo, options map[string]interface{}) error {
defer osl.InitOSContext()()
network, err := d.getNetwork(nid) network, err := d.getNetwork(nid)
if err != nil { if err != nil {
return err return err
@ -1288,8 +1278,6 @@ func (d *driver) Join(nid, eid string, sboxKey string, jinfo driverapi.JoinInfo,
// Leave method is invoked when a Sandbox detaches from an endpoint. // Leave method is invoked when a Sandbox detaches from an endpoint.
func (d *driver) Leave(nid, eid string) error { func (d *driver) Leave(nid, eid string) error {
defer osl.InitOSContext()()
network, err := d.getNetwork(nid) network, err := d.getNetwork(nid)
if err != nil { if err != nil {
return types.InternalMaskableErrorf("%s", err) return types.InternalMaskableErrorf("%s", err)
@ -1314,8 +1302,6 @@ func (d *driver) Leave(nid, eid string) error {
} }
func (d *driver) ProgramExternalConnectivity(nid, eid string, options map[string]interface{}) error { func (d *driver) ProgramExternalConnectivity(nid, eid string, options map[string]interface{}) error {
defer osl.InitOSContext()()
network, err := d.getNetwork(nid) network, err := d.getNetwork(nid)
if err != nil { if err != nil {
return err return err
@ -1368,8 +1354,6 @@ func (d *driver) ProgramExternalConnectivity(nid, eid string, options map[string
} }
func (d *driver) RevokeExternalConnectivity(nid, eid string) error { func (d *driver) RevokeExternalConnectivity(nid, eid string) error {
defer osl.InitOSContext()()
network, err := d.getNetwork(nid) network, err := d.getNetwork(nid)
if err != nil { if err != nil {
return err return err

View file

@ -9,7 +9,6 @@ import (
"github.com/docker/docker/libnetwork/driverapi" "github.com/docker/docker/libnetwork/driverapi"
"github.com/docker/docker/libnetwork/netlabel" "github.com/docker/docker/libnetwork/netlabel"
"github.com/docker/docker/libnetwork/ns" "github.com/docker/docker/libnetwork/ns"
"github.com/docker/docker/libnetwork/osl"
"github.com/docker/docker/libnetwork/types" "github.com/docker/docker/libnetwork/types"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -17,7 +16,6 @@ import (
// CreateEndpoint assigns the mac, ip and endpoint id for the new container // CreateEndpoint assigns the mac, ip and endpoint id for the new container
func (d *driver) CreateEndpoint(nid, eid string, ifInfo driverapi.InterfaceInfo, func (d *driver) CreateEndpoint(nid, eid string, ifInfo driverapi.InterfaceInfo,
epOptions map[string]interface{}) error { epOptions map[string]interface{}) error {
defer osl.InitOSContext()()
if err := validateID(nid, eid); err != nil { if err := validateID(nid, eid); err != nil {
return err return err
@ -66,7 +64,6 @@ func (d *driver) CreateEndpoint(nid, eid string, ifInfo driverapi.InterfaceInfo,
// DeleteEndpoint remove the endpoint and associated netlink interface // DeleteEndpoint remove the endpoint and associated netlink interface
func (d *driver) DeleteEndpoint(nid, eid string) error { func (d *driver) DeleteEndpoint(nid, eid string) error {
defer osl.InitOSContext()()
if err := validateID(nid, eid); err != nil { if err := validateID(nid, eid); err != nil {
return err return err
} }

View file

@ -10,7 +10,6 @@ import (
"github.com/docker/docker/libnetwork/driverapi" "github.com/docker/docker/libnetwork/driverapi"
"github.com/docker/docker/libnetwork/netutils" "github.com/docker/docker/libnetwork/netutils"
"github.com/docker/docker/libnetwork/ns" "github.com/docker/docker/libnetwork/ns"
"github.com/docker/docker/libnetwork/osl"
"github.com/docker/docker/libnetwork/types" "github.com/docker/docker/libnetwork/types"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -28,7 +27,6 @@ const (
// Join method is invoked when a Sandbox is attached to an endpoint. // Join method is invoked when a Sandbox is attached to an endpoint.
func (d *driver) Join(nid, eid string, sboxKey string, jinfo driverapi.JoinInfo, options map[string]interface{}) error { func (d *driver) Join(nid, eid string, sboxKey string, jinfo driverapi.JoinInfo, options map[string]interface{}) error {
defer osl.InitOSContext()()
n, err := d.getNetwork(nid) n, err := d.getNetwork(nid)
if err != nil { if err != nil {
return err return err
@ -139,7 +137,6 @@ func (d *driver) Join(nid, eid string, sboxKey string, jinfo driverapi.JoinInfo,
// Leave method is invoked when a Sandbox detaches from an endpoint. // Leave method is invoked when a Sandbox detaches from an endpoint.
func (d *driver) Leave(nid, eid string) error { func (d *driver) Leave(nid, eid string) error {
defer osl.InitOSContext()()
network, err := d.getNetwork(nid) network, err := d.getNetwork(nid)
if err != nil { if err != nil {
return err return err

View file

@ -10,7 +10,6 @@ import (
"github.com/docker/docker/libnetwork/netlabel" "github.com/docker/docker/libnetwork/netlabel"
"github.com/docker/docker/libnetwork/ns" "github.com/docker/docker/libnetwork/ns"
"github.com/docker/docker/libnetwork/options" "github.com/docker/docker/libnetwork/options"
"github.com/docker/docker/libnetwork/osl"
"github.com/docker/docker/libnetwork/types" "github.com/docker/docker/libnetwork/types"
"github.com/docker/docker/pkg/parsers/kernel" "github.com/docker/docker/pkg/parsers/kernel"
"github.com/docker/docker/pkg/stringid" "github.com/docker/docker/pkg/stringid"
@ -19,7 +18,6 @@ import (
// CreateNetwork the network for the specified driver type // CreateNetwork the network for the specified driver type
func (d *driver) CreateNetwork(nid string, option map[string]interface{}, nInfo driverapi.NetworkInfo, ipV4Data, ipV6Data []driverapi.IPAMData) error { func (d *driver) CreateNetwork(nid string, option map[string]interface{}, nInfo driverapi.NetworkInfo, ipV4Data, ipV6Data []driverapi.IPAMData) error {
defer osl.InitOSContext()()
kv, err := kernel.GetKernelVersion() kv, err := kernel.GetKernelVersion()
if err != nil { if err != nil {
return fmt.Errorf("failed to check kernel version for ipvlan driver support: %v", err) return fmt.Errorf("failed to check kernel version for ipvlan driver support: %v", err)
@ -118,7 +116,6 @@ func (d *driver) createNetwork(config *configuration) (bool, error) {
// DeleteNetwork deletes the network for the specified driver type // DeleteNetwork deletes the network for the specified driver type
func (d *driver) DeleteNetwork(nid string) error { func (d *driver) DeleteNetwork(nid string) error {
defer osl.InitOSContext()()
n := d.network(nid) n := d.network(nid)
if n == nil { if n == nil {
return fmt.Errorf("network id %s not found", nid) return fmt.Errorf("network id %s not found", nid)

View file

@ -10,7 +10,6 @@ import (
"github.com/docker/docker/libnetwork/netlabel" "github.com/docker/docker/libnetwork/netlabel"
"github.com/docker/docker/libnetwork/netutils" "github.com/docker/docker/libnetwork/netutils"
"github.com/docker/docker/libnetwork/ns" "github.com/docker/docker/libnetwork/ns"
"github.com/docker/docker/libnetwork/osl"
"github.com/docker/docker/libnetwork/types" "github.com/docker/docker/libnetwork/types"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -18,7 +17,6 @@ import (
// CreateEndpoint assigns the mac, ip and endpoint id for the new container // CreateEndpoint assigns the mac, ip and endpoint id for the new container
func (d *driver) CreateEndpoint(nid, eid string, ifInfo driverapi.InterfaceInfo, func (d *driver) CreateEndpoint(nid, eid string, ifInfo driverapi.InterfaceInfo,
epOptions map[string]interface{}) error { epOptions map[string]interface{}) error {
defer osl.InitOSContext()()
if err := validateID(nid, eid); err != nil { if err := validateID(nid, eid); err != nil {
return err return err
@ -71,7 +69,6 @@ func (d *driver) CreateEndpoint(nid, eid string, ifInfo driverapi.InterfaceInfo,
// DeleteEndpoint removes the endpoint and associated netlink interface // DeleteEndpoint removes the endpoint and associated netlink interface
func (d *driver) DeleteEndpoint(nid, eid string) error { func (d *driver) DeleteEndpoint(nid, eid string) error {
defer osl.InitOSContext()()
if err := validateID(nid, eid); err != nil { if err := validateID(nid, eid); err != nil {
return err return err
} }

View file

@ -10,13 +10,11 @@ import (
"github.com/docker/docker/libnetwork/driverapi" "github.com/docker/docker/libnetwork/driverapi"
"github.com/docker/docker/libnetwork/netutils" "github.com/docker/docker/libnetwork/netutils"
"github.com/docker/docker/libnetwork/ns" "github.com/docker/docker/libnetwork/ns"
"github.com/docker/docker/libnetwork/osl"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
// Join method is invoked when a Sandbox is attached to an endpoint. // Join method is invoked when a Sandbox is attached to an endpoint.
func (d *driver) Join(nid, eid string, sboxKey string, jinfo driverapi.JoinInfo, options map[string]interface{}) error { func (d *driver) Join(nid, eid string, sboxKey string, jinfo driverapi.JoinInfo, options map[string]interface{}) error {
defer osl.InitOSContext()()
n, err := d.getNetwork(nid) n, err := d.getNetwork(nid)
if err != nil { if err != nil {
return err return err
@ -100,7 +98,6 @@ func (d *driver) Join(nid, eid string, sboxKey string, jinfo driverapi.JoinInfo,
// Leave method is invoked when a Sandbox detaches from an endpoint. // Leave method is invoked when a Sandbox detaches from an endpoint.
func (d *driver) Leave(nid, eid string) error { func (d *driver) Leave(nid, eid string) error {
defer osl.InitOSContext()()
network, err := d.getNetwork(nid) network, err := d.getNetwork(nid)
if err != nil { if err != nil {
return err return err

View file

@ -10,7 +10,6 @@ import (
"github.com/docker/docker/libnetwork/netlabel" "github.com/docker/docker/libnetwork/netlabel"
"github.com/docker/docker/libnetwork/ns" "github.com/docker/docker/libnetwork/ns"
"github.com/docker/docker/libnetwork/options" "github.com/docker/docker/libnetwork/options"
"github.com/docker/docker/libnetwork/osl"
"github.com/docker/docker/libnetwork/types" "github.com/docker/docker/libnetwork/types"
"github.com/docker/docker/pkg/stringid" "github.com/docker/docker/pkg/stringid"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -18,8 +17,6 @@ import (
// CreateNetwork the network for the specified driver type // CreateNetwork the network for the specified driver type
func (d *driver) CreateNetwork(nid string, option map[string]interface{}, nInfo driverapi.NetworkInfo, ipV4Data, ipV6Data []driverapi.IPAMData) error { func (d *driver) CreateNetwork(nid string, option map[string]interface{}, nInfo driverapi.NetworkInfo, ipV4Data, ipV6Data []driverapi.IPAMData) error {
defer osl.InitOSContext()()
// reject a null v4 network // reject a null v4 network
if len(ipV4Data) == 0 || ipV4Data[0].Pool.String() == "0.0.0.0/0" { if len(ipV4Data) == 0 || ipV4Data[0].Pool.String() == "0.0.0.0/0" {
return fmt.Errorf("ipv4 pool is empty") return fmt.Errorf("ipv4 pool is empty")
@ -109,7 +106,6 @@ func (d *driver) createNetwork(config *configuration) (bool, error) {
// DeleteNetwork deletes the network for the specified driver type // DeleteNetwork deletes the network for the specified driver type
func (d *driver) DeleteNetwork(nid string) error { func (d *driver) DeleteNetwork(nid string) error {
defer osl.InitOSContext()()
n := d.network(nid) n := d.network(nid)
if n == nil { if n == nil {
return fmt.Errorf("network id %s not found", nid) return fmt.Errorf("network id %s not found", nid)

View file

@ -76,6 +76,14 @@ type network struct {
func init() { func init() {
reexec.Register("set-default-vlan", setDefaultVlan) reexec.Register("set-default-vlan", setDefaultVlan)
// Lock main() to the initial thread to exclude the goroutines executing
// func (*network).watchMiss() from being scheduled onto that thread.
// Changes to the network namespace of the initial thread alter
// /proc/self/ns/net, which would break any code which (incorrectly)
// assumes that that file is a handle to the network namespace for the
// thread it is currently executing on.
runtime.LockOSThread()
} }
func setDefaultVlan() { func setDefaultVlan() {
@ -779,20 +787,36 @@ func (n *network) initSandbox(restore bool) error {
func (n *network) watchMiss(nlSock *nl.NetlinkSocket, nsPath string) { func (n *network) watchMiss(nlSock *nl.NetlinkSocket, nsPath string) {
// With the new version of the netlink library the deserialize function makes // With the new version of the netlink library the deserialize function makes
// requests about the interface of the netlink message. This can succeed only // requests about the interface of the netlink message. This can succeed only
// if this go routine is in the target namespace. For this reason following we // if this go routine is in the target namespace.
// lock the thread on that namespace origNs, err := netns.Get()
runtime.LockOSThread() if err != nil {
defer runtime.UnlockOSThread() logrus.WithError(err).Error("failed to get the initial network namespace")
return
}
defer origNs.Close()
newNs, err := netns.GetFromPath(nsPath) newNs, err := netns.GetFromPath(nsPath)
if err != nil { if err != nil {
logrus.WithError(err).Errorf("failed to get the namespace %s", nsPath) logrus.WithError(err).Errorf("failed to get the namespace %s", nsPath)
return return
} }
defer newNs.Close() defer newNs.Close()
runtime.LockOSThread()
if err = netns.Set(newNs); err != nil { if err = netns.Set(newNs); err != nil {
logrus.WithError(err).Errorf("failed to enter the namespace %s", nsPath) logrus.WithError(err).Errorf("failed to enter the namespace %s", nsPath)
runtime.UnlockOSThread()
return return
} }
defer func() {
if err := netns.Set(origNs); err != nil {
logrus.WithError(err).Error("failed to restore the thread's initial network namespace")
// The error is only fatal for the current thread. Keep this
// goroutine locked to the thread to make the runtime replace it
// with a clean thread once this goroutine terminates.
} else {
runtime.UnlockOSThread()
}
}()
for { for {
msgs, _, err := nlSock.Receive() msgs, _, err := nlSock.Receive()
if err != nil { if err != nil {

View file

@ -11,7 +11,6 @@ import (
"github.com/docker/docker/libnetwork/drivers/overlay/overlayutils" "github.com/docker/docker/libnetwork/drivers/overlay/overlayutils"
"github.com/docker/docker/libnetwork/netutils" "github.com/docker/docker/libnetwork/netutils"
"github.com/docker/docker/libnetwork/ns" "github.com/docker/docker/libnetwork/ns"
"github.com/docker/docker/libnetwork/osl"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
"github.com/vishvananda/netns" "github.com/vishvananda/netns"
@ -32,7 +31,6 @@ func validateID(nid, eid string) error {
} }
func createVethPair() (string, string, error) { func createVethPair() (string, string, error) {
defer osl.InitOSContext()()
nlh := ns.NlHandle() nlh := ns.NlHandle()
// Generate a name for what will be the host side pipe interface // Generate a name for what will be the host side pipe interface
@ -59,8 +57,6 @@ func createVethPair() (string, string, error) {
} }
func createVxlan(name string, vni uint32, mtu int) error { func createVxlan(name string, vni uint32, mtu int) error {
defer osl.InitOSContext()()
vxlan := &netlink.Vxlan{ vxlan := &netlink.Vxlan{
LinkAttrs: netlink.LinkAttrs{Name: name, MTU: mtu}, LinkAttrs: netlink.LinkAttrs{Name: name, MTU: mtu},
VxlanId: int(vni), VxlanId: int(vni),
@ -79,8 +75,6 @@ func createVxlan(name string, vni uint32, mtu int) error {
} }
func deleteInterfaceBySubnet(brPrefix string, s *subnet) error { func deleteInterfaceBySubnet(brPrefix string, s *subnet) error {
defer osl.InitOSContext()()
nlh := ns.NlHandle() nlh := ns.NlHandle()
links, err := nlh.LinkList() links, err := nlh.LinkList()
if err != nil { if err != nil {
@ -109,8 +103,6 @@ func deleteInterfaceBySubnet(brPrefix string, s *subnet) error {
} }
func deleteInterface(name string) error { func deleteInterface(name string) error {
defer osl.InitOSContext()()
link, err := ns.NlHandle().LinkByName(name) link, err := ns.NlHandle().LinkByName(name)
if err != nil { if err != nil {
return fmt.Errorf("failed to find interface with name %s: %v", name, err) return fmt.Errorf("failed to find interface with name %s: %v", name, err)
@ -124,8 +116,6 @@ func deleteInterface(name string) error {
} }
func deleteVxlanByVNI(path string, vni uint32) error { func deleteVxlanByVNI(path string, vni uint32) error {
defer osl.InitOSContext()()
nlh := ns.NlHandle() nlh := ns.NlHandle()
if path != "" { if path != "" {
ns, err := netns.GetFromPath(path) ns, err := netns.GetFromPath(path)

View file

@ -917,7 +917,6 @@ func parallelJoin(t *testing.T, rc libnetwork.Sandbox, ep libnetwork.Endpoint, t
sb := sboxes[thrNumber-1] sb := sboxes[thrNumber-1]
err = ep.Join(sb) err = ep.Join(sb)
runtime.LockOSThread()
if err != nil { if err != nil {
if _, ok := err.(types.ForbiddenError); !ok { if _, ok := err.(types.ForbiddenError); !ok {
t.Fatalf("thread %d: %v", thrNumber, err) t.Fatalf("thread %d: %v", thrNumber, err)
@ -934,7 +933,6 @@ func parallelLeave(t *testing.T, rc libnetwork.Sandbox, ep libnetwork.Endpoint,
sb := sboxes[thrNumber-1] sb := sboxes[thrNumber-1]
err = ep.Leave(sb) err = ep.Leave(sb)
runtime.LockOSThread()
if err != nil { if err != nil {
if _, ok := err.(types.ForbiddenError); !ok { if _, ok := err.(types.ForbiddenError); !ok {
t.Fatalf("thread %d: %v", thrNumber, err) t.Fatalf("thread %d: %v", thrNumber, err)
@ -966,13 +964,9 @@ func runParallelTests(t *testing.T, thrNumber int) {
} }
runtime.LockOSThread() runtime.LockOSThread()
defer runtime.UnlockOSThread()
if thrNumber == first { if thrNumber == first {
createGlobalInstance(t) createGlobalInstance(t)
} } else {
if thrNumber != first {
<-start <-start
thrdone := make(chan struct{}) thrdone := make(chan struct{})
@ -985,18 +979,15 @@ func runParallelTests(t *testing.T, thrNumber int) {
err = netns.Set(testns) err = netns.Set(testns)
if err != nil { if err != nil {
runtime.UnlockOSThread()
t.Fatal(err) t.Fatal(err)
} }
} }
defer func() { defer func() {
if err := netns.Set(origins); err != nil { if err := netns.Set(origins); err != nil {
// NOTE(@cpuguy83): This... t.Fatalf("Error restoring the current thread's netns: %v", err)
// I touched this code because the linter found that we weren't checking the error... } else {
// It returns an error because "origins" is a closed file handle *unless* createGlobalInstance is called. runtime.UnlockOSThread()
// Which... this test is run in parallel and `createGlobalInstance` modifies `origins` without synchronization.
// I'm not sure what exactly the *intent* of this code was, but it looks very broken.
// Anyway that's why I'm only logging the error and not failing the test.
t.Log(err)
} }
}() }()

View file

@ -12,7 +12,6 @@ import (
"github.com/docker/docker/libnetwork/ipamutils" "github.com/docker/docker/libnetwork/ipamutils"
"github.com/docker/docker/libnetwork/ns" "github.com/docker/docker/libnetwork/ns"
"github.com/docker/docker/libnetwork/osl"
"github.com/docker/docker/libnetwork/resolvconf" "github.com/docker/docker/libnetwork/resolvconf"
"github.com/docker/docker/libnetwork/types" "github.com/docker/docker/libnetwork/types"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -73,8 +72,6 @@ func GenerateIfaceName(nlh *netlink.Handle, prefix string, len int) (string, err
func ElectInterfaceAddresses(name string) ([]*net.IPNet, []*net.IPNet, error) { func ElectInterfaceAddresses(name string) ([]*net.IPNet, []*net.IPNet, error) {
var v4Nets, v6Nets []*net.IPNet var v4Nets, v6Nets []*net.IPNet
defer osl.InitOSContext()()
link, _ := ns.NlHandle().LinkByName(name) link, _ := ns.NlHandle().LinkByName(name)
if link != nil { if link != nil {
v4addr, err := ns.NlHandle().AddrList(link, netlink.FAMILY_V4) v4addr, err := ns.NlHandle().AddrList(link, netlink.FAMILY_V4)

View file

@ -2,7 +2,6 @@ package ns
import ( import (
"fmt" "fmt"
"os"
"os/exec" "os/exec"
"strings" "strings"
"sync" "sync"
@ -39,19 +38,6 @@ func Init() {
} }
} }
// SetNamespace sets the initial namespace handler
func SetNamespace() error {
initOnce.Do(Init)
if err := netns.Set(initNs); err != nil {
linkInfo, linkErr := getLink()
if linkErr != nil {
linkInfo = linkErr.Error()
}
return fmt.Errorf("failed to set to initial namespace, %v, initns fd %d: %v", linkInfo, initNs, err)
}
return nil
}
// ParseHandlerInt transforms the namespace handler into an integer // ParseHandlerInt transforms the namespace handler into an integer
func ParseHandlerInt() int { func ParseHandlerInt() int {
return int(getHandler()) return int(getHandler())
@ -63,10 +49,6 @@ func getHandler() netns.NsHandle {
return initNs return initNs
} }
func getLink() (string, error) {
return os.Readlink(fmt.Sprintf("/proc/%d/task/%d/ns/net", os.Getpid(), syscall.Gettid()))
}
// NlHandle returns the netlink handler // NlHandle returns the netlink handler
func NlHandle() *netlink.Handle { func NlHandle() *netlink.Handle {
initOnce.Do(Init) initOnce.Do(Init)

View file

@ -28,6 +28,14 @@ const defaultPrefix = "/var/run/docker"
func init() { func init() {
reexec.Register("set-ipv6", reexecSetIPv6) reexec.Register("set-ipv6", reexecSetIPv6)
// Lock main() to the initial thread to exclude the goroutines spawned
// by func (*networkNamespace) InvokeFunc() from being scheduled onto
// that thread. Changes to the network namespace of the initial thread
// alter /proc/self/ns/net, which would break any code which
// (incorrectly) assumes that that file is a handle to the network
// namespace for the thread it is currently executing on.
runtime.LockOSThread()
} }
var ( var (
@ -412,43 +420,41 @@ func (n *networkNamespace) DisableARPForVIP(srcName string) (Err error) {
} }
func (n *networkNamespace) InvokeFunc(f func()) error { func (n *networkNamespace) InvokeFunc(f func()) error {
return nsInvoke(n.nsPath(), func(nsFD int) error { return nil }, func(callerFD int) error { origNS, err := netns.Get()
f()
return nil
})
}
// InitOSContext initializes OS context while configuring network resources
func InitOSContext() func() {
runtime.LockOSThread()
if err := ns.SetNamespace(); err != nil {
logrus.Error(err)
}
return runtime.UnlockOSThread
}
func nsInvoke(path string, prefunc func(nsFD int) error, postfunc func(callerFD int) error) error {
defer InitOSContext()()
newNs, err := netns.GetFromPath(path)
if err != nil { if err != nil {
return fmt.Errorf("failed get network namespace %q: %v", path, err) return fmt.Errorf("failed to get original network namespace: %w", err)
} }
defer newNs.Close() defer origNS.Close()
// Invoked before the namespace switch happens but after the namespace file path := n.nsPath()
// handle is obtained. newNS, err := netns.GetFromPath(path)
if err := prefunc(int(newNs)); err != nil { if err != nil {
return fmt.Errorf("failed in prefunc: %v", err) return fmt.Errorf("failed get network namespace %q: %w", path, err)
} }
defer newNS.Close()
if err = netns.Set(newNs); err != nil { done := make(chan error, 1)
return err go func() {
runtime.LockOSThread()
if err := netns.Set(newNS); err != nil {
runtime.UnlockOSThread()
done <- err
return
} }
defer ns.SetNamespace() defer func() {
close(done)
// Invoked after the namespace switch. if err := netns.Set(origNS); err != nil {
return postfunc(ns.ParseHandlerInt()) logrus.WithError(err).Warn("failed to restore thread's network namespace")
// Recover from the error by leaving this goroutine locked to
// the thread. The runtime will terminate the thread and replace
// it with a clean one when this goroutine returns.
} else {
runtime.UnlockOSThread()
}
}()
f()
}()
return <-done
} }
func (n *networkNamespace) nsPath() string { func (n *networkNamespace) nsPath() string {

View file

@ -21,11 +21,6 @@ func GetSandboxForExternalKey(path string, key string) (Sandbox, error) {
func GC() { func GC() {
} }
// InitOSContext initializes OS context while configuring network resources
func InitOSContext() func() {
return func() {}
}
// SetBasePath sets the base url prefix for the ns path // SetBasePath sets the base url prefix for the ns path
func SetBasePath(path string) { func SetBasePath(path string) {
} }

View file

@ -27,11 +27,6 @@ func GetSandboxForExternalKey(path string, key string) (Sandbox, error) {
func GC() { func GC() {
} }
// InitOSContext initializes OS context while configuring network resources
func InitOSContext() func() {
return func() {}
}
// SetBasePath sets the base url prefix for the ns path // SetBasePath sets the base url prefix for the ns path
func SetBasePath(path string) { func SetBasePath(path string) {
} }

View file

@ -7,7 +7,6 @@ import (
"net" "net"
"os" "os"
"path/filepath" "path/filepath"
"runtime"
"strings" "strings"
"syscall" "syscall"
"testing" "testing"
@ -197,7 +196,6 @@ func TestDisableIPv6DAD(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Failed to create a new sandbox: %v", err) t.Fatalf("Failed to create a new sandbox: %v", err)
} }
runtime.LockOSThread()
defer destroyTest(t, s) defer destroyTest(t, s)
n, ok := s.(*networkNamespace) n, ok := s.(*networkNamespace)
@ -257,7 +255,6 @@ func TestSetInterfaceIP(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Failed to create a new sandbox: %v", err) t.Fatalf("Failed to create a new sandbox: %v", err)
} }
runtime.LockOSThread()
defer destroyTest(t, s) defer destroyTest(t, s)
n, ok := s.(*networkNamespace) n, ok := s.(*networkNamespace)
@ -332,7 +329,6 @@ func TestLiveRestore(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Failed to create a new sandbox: %v", err) t.Fatalf("Failed to create a new sandbox: %v", err)
} }
runtime.LockOSThread()
defer destroyTest(t, s) defer destroyTest(t, s)
n, ok := s.(*networkNamespace) n, ok := s.(*networkNamespace)
@ -482,7 +478,6 @@ func TestSandboxCreateTwice(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Failed to create a new sandbox: %v", err) t.Fatalf("Failed to create a new sandbox: %v", err)
} }
runtime.LockOSThread()
// Create another sandbox with the same key to see if we handle it // Create another sandbox with the same key to see if we handle it
// gracefully. // gracefully.
@ -490,7 +485,6 @@ func TestSandboxCreateTwice(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Failed to create a new sandbox: %v", err) t.Fatalf("Failed to create a new sandbox: %v", err)
} }
runtime.LockOSThread()
err = s.Destroy() err = s.Destroy()
if err != nil { if err != nil {
@ -532,7 +526,6 @@ func TestAddRemoveInterface(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Failed to create a new sandbox: %v", err) t.Fatalf("Failed to create a new sandbox: %v", err)
} }
runtime.LockOSThread()
if s.Key() != key { if s.Key() != key {
t.Fatalf("s.Key() returned %s. Expected %s", s.Key(), key) t.Fatalf("s.Key() returned %s. Expected %s", s.Key(), key)

View file

@ -5,10 +5,10 @@ package testutils
import ( import (
"runtime" "runtime"
"syscall"
"testing" "testing"
"github.com/docker/docker/libnetwork/ns" "github.com/docker/docker/libnetwork/ns"
"github.com/vishvananda/netns"
) )
// SetupTestOSContext joins a new network namespace, and returns its associated // SetupTestOSContext joins a new network namespace, and returns its associated
@ -18,26 +18,40 @@ import (
// //
// defer SetupTestOSContext(t)() // defer SetupTestOSContext(t)()
func SetupTestOSContext(t *testing.T) func() { func SetupTestOSContext(t *testing.T) func() {
runtime.LockOSThread() origNS, err := netns.Get()
if err := syscall.Unshare(syscall.CLONE_NEWNET); err != nil { if err != nil {
t.Fatalf("Failed to enter netns: %v", err) t.Fatalf("Failed to open initial netns: %v", err)
}
restore := func() {
if err := netns.Set(origNS); err != nil {
t.Logf("Warning: failed to restore thread netns (%v)", err)
} else {
runtime.UnlockOSThread()
} }
fd, err := syscall.Open("/proc/self/ns/net", syscall.O_RDONLY, 0) if err := origNS.Close(); err != nil {
t.Logf("Warning: netns closing failed (%v)", err)
}
}
runtime.LockOSThread()
newNS, err := netns.New()
if err != nil { if err != nil {
t.Fatal("Failed to open netns file") // netns.New() is not atomic: it could have encountered an error
// after unsharing the current thread's network namespace.
restore()
t.Fatalf("Failed to enter netns: %v", err)
} }
// Since we are switching to a new test namespace make // Since we are switching to a new test namespace make
// sure to re-initialize initNs context // sure to re-initialize initNs context
ns.Init() ns.Init()
runtime.LockOSThread()
return func() { return func() {
if err := syscall.Close(fd); err != nil { if err := newNS.Close(); err != nil {
t.Logf("Warning: netns closing failed (%v)", err) t.Logf("Warning: netns closing failed (%v)", err)
} }
runtime.UnlockOSThread() restore()
ns.Init()
} }
} }