Ver Fonte

Merge pull request #140 from mrjana/cnm_integ

Make endpoint Join and Leave multi-thread safe
Madhu Venugopal há 10 anos atrás
pai
commit
f9ef08c30f

+ 1 - 1
libnetwork/Makefile

@@ -35,7 +35,7 @@ run-tests:
 	@echo "mode: count" > coverage.coverprofile
 	@for dir in $$(find . -maxdepth 10 -not -path './.git*' -not -path '*/_*' -type d); do \
 	    if ls $$dir/*.go &> /dev/null; then \
-            	$(shell which godep) go test -test.v -covermode=count -coverprofile=$$dir/profile.tmp $$dir ; \
+            	$(shell which godep) go test -test.parallel 3 -test.v -covermode=count -coverprofile=$$dir/profile.tmp $$dir ; \
 		if [ $$? -ne 0 ]; then exit $$?; fi ;\
 	        if [ -f $$dir/profile.tmp ]; then \
 		        cat $$dir/profile.tmp | tail -n +2 >> coverage.coverprofile ; \

+ 218 - 73
libnetwork/endpoint.go

@@ -4,6 +4,7 @@ import (
 	"io/ioutil"
 	"os"
 	"path/filepath"
+	"sync"
 
 	"github.com/Sirupsen/logrus"
 	"github.com/docker/libnetwork/driverapi"
@@ -96,33 +97,47 @@ type containerInfo struct {
 }
 
 type endpoint struct {
-	name         string
-	id           types.UUID
-	network      *network
-	sandboxInfo  *sandbox.Info
-	sandBox      sandbox.Sandbox
-	joinInfo     *driverapi.JoinInfo
-	container    *containerInfo
-	exposedPorts []netutils.TransportPort
-	generic      map[string]interface{}
-	context      map[string]interface{}
+	name          string
+	id            types.UUID
+	network       *network
+	sandboxInfo   *sandbox.Info
+	sandBox       sandbox.Sandbox
+	joinInfo      *driverapi.JoinInfo
+	container     *containerInfo
+	exposedPorts  []netutils.TransportPort
+	generic       map[string]interface{}
+	context       map[string]interface{}
+	joinLeaveDone chan struct{}
+	sync.Mutex
 }
 
 const defaultPrefix = "/var/lib/docker/network/files"
 
 func (ep *endpoint) ID() string {
+	ep.Lock()
+	defer ep.Unlock()
+
 	return string(ep.id)
 }
 
 func (ep *endpoint) Name() string {
+	ep.Lock()
+	defer ep.Unlock()
+
 	return ep.name
 }
 
 func (ep *endpoint) Network() string {
+	ep.Lock()
+	defer ep.Unlock()
+
 	return ep.network.name
 }
 
 func (ep *endpoint) SandboxInfo() *sandbox.Info {
+	ep.Lock()
+	defer ep.Unlock()
+
 	if ep.sandboxInfo == nil {
 		return nil
 	}
@@ -130,10 +145,23 @@ func (ep *endpoint) SandboxInfo() *sandbox.Info {
 }
 
 func (ep *endpoint) Info() (map[string]interface{}, error) {
-	return ep.network.driver.EndpointInfo(ep.network.id, ep.id)
+	ep.Lock()
+	network := ep.network
+	epid := ep.id
+	ep.Unlock()
+
+	network.Lock()
+	driver := network.driver
+	nid := network.id
+	network.Unlock()
+
+	return driver.EndpointInfo(nid, epid)
 }
 
 func (ep *endpoint) processOptions(options ...EndpointOption) {
+	ep.Lock()
+	defer ep.Unlock()
+
 	for _, opt := range options {
 		if opt != nil {
 			opt(ep)
@@ -167,6 +195,38 @@ func createFile(path string) error {
 	return err
 }
 
+// joinLeaveStart waits to ensure there are no joins or leaves in progress and
+// marks this join/leave in progress without race
+func (ep *endpoint) joinLeaveStart() {
+	ep.Lock()
+	defer ep.Unlock()
+
+	for ep.joinLeaveDone != nil {
+		joinLeaveDone := ep.joinLeaveDone
+		ep.Unlock()
+
+		select {
+		case <-joinLeaveDone:
+		}
+
+		ep.Lock()
+	}
+
+	ep.joinLeaveDone = make(chan struct{})
+}
+
+// joinLeaveEnd marks the end of this join/leave operation and
+// signals the same without race to other join and leave waiters
+func (ep *endpoint) joinLeaveEnd() {
+	ep.Lock()
+	defer ep.Unlock()
+
+	if ep.joinLeaveDone != nil {
+		close(ep.joinLeaveDone)
+		ep.joinLeaveDone = nil
+	}
+}
+
 func (ep *endpoint) Join(containerID string, options ...EndpointOption) (*ContainerData, error) {
 	var err error
 
@@ -174,44 +234,58 @@ func (ep *endpoint) Join(containerID string, options ...EndpointOption) (*Contai
 		return nil, InvalidContainerIDError(containerID)
 	}
 
+	ep.joinLeaveStart()
+	defer ep.joinLeaveEnd()
+
+	ep.Lock()
 	if ep.container != nil {
+		ep.Unlock()
 		return nil, ErrInvalidJoin
 	}
 
 	ep.container = &containerInfo{
+		id: containerID,
 		config: containerConfig{
 			hostsPathConfig: hostsPathConfig{
 				extraHosts:    []extraHost{},
 				parentUpdates: []parentUpdate{},
 			},
 		}}
+
+	container := ep.container
+	network := ep.network
+	epid := ep.id
+
+	ep.Unlock()
 	defer func() {
+		ep.Lock()
 		if err != nil {
 			ep.container = nil
 		}
+		ep.Unlock()
 	}()
 
-	ep.processOptions(options...)
-
-	if ep.container.config.hostsPath == "" {
-		ep.container.config.hostsPath = defaultPrefix + "/" + containerID + "/hosts"
-	}
+	network.Lock()
+	driver := network.driver
+	nid := network.id
+	ctrlr := network.ctrlr
+	network.Unlock()
 
-	if ep.container.config.resolvConfPath == "" {
-		ep.container.config.resolvConfPath = defaultPrefix + "/" + containerID + "/resolv.conf"
-	}
+	ep.processOptions(options...)
 
 	sboxKey := sandbox.GenerateKey(containerID)
-	if ep.container.config.useDefaultSandBox {
+	if container.config.useDefaultSandBox {
 		sboxKey = sandbox.GenerateKey("default")
 	}
 
-	joinInfo, err := ep.network.driver.Join(ep.network.id, ep.id,
-		sboxKey, ep.container.config.generic)
+	joinInfo, err := driver.Join(nid, epid, sboxKey, container.config.generic)
 	if err != nil {
 		return nil, err
 	}
+
+	ep.Lock()
 	ep.joinInfo = joinInfo
+	ep.Unlock()
 
 	err = ep.buildHostsFiles()
 	if err != nil {
@@ -228,14 +302,13 @@ func (ep *endpoint) Join(containerID string, options ...EndpointOption) (*Contai
 		return nil, err
 	}
 
-	sb, err := ep.network.ctrlr.sandboxAdd(sboxKey,
-		!ep.container.config.useDefaultSandBox)
+	sb, err := ctrlr.sandboxAdd(sboxKey, !container.config.useDefaultSandBox)
 	if err != nil {
 		return nil, err
 	}
 	defer func() {
 		if err != nil {
-			ep.network.ctrlr.sandboxRm(sboxKey)
+			ctrlr.sandboxRm(sboxKey)
 		}
 	}()
 
@@ -259,27 +332,50 @@ func (ep *endpoint) Join(containerID string, options ...EndpointOption) (*Contai
 		}
 	}
 
-	ep.container.id = containerID
-	ep.container.data.SandboxKey = sb.Key()
+	container.data.SandboxKey = sb.Key()
+	cData := container.data
 
-	cData := ep.container.data
 	return &cData, nil
 }
 
 func (ep *endpoint) Leave(containerID string, options ...EndpointOption) error {
-	if ep.container == nil || ep.container.id == "" ||
-		containerID == "" || ep.container.id != containerID {
-		return InvalidContainerIDError(containerID)
-	}
+	var err error
+
+	ep.joinLeaveStart()
+	defer ep.joinLeaveEnd()
 
 	ep.processOptions(options...)
 
+	ep.Lock()
+	container := ep.container
 	n := ep.network
-	err := n.driver.Leave(n.id, ep.id, ep.context)
+	context := ep.context
+
+	if container == nil || container.id == "" ||
+		containerID == "" || container.id != containerID {
+		if container == nil {
+			err = ErrNoContainer
+		} else {
+			err = InvalidContainerIDError(containerID)
+		}
+
+		ep.Unlock()
+		return err
+	}
+	ep.container = nil
+	ep.context = nil
+	ep.Unlock()
+
+	n.Lock()
+	driver := n.driver
+	ctrlr := n.ctrlr
+	n.Unlock()
+
+	err = driver.Leave(n.id, ep.id, context)
 
 	sinfo := ep.SandboxInfo()
 	if sinfo != nil {
-		sb := ep.network.ctrlr.sandboxGet(ep.container.data.SandboxKey)
+		sb := ctrlr.sandboxGet(container.data.SandboxKey)
 		for _, i := range sinfo.Interfaces {
 			err = sb.RemoveInterface(i)
 			if err != nil {
@@ -288,93 +384,129 @@ func (ep *endpoint) Leave(containerID string, options ...EndpointOption) error {
 		}
 	}
 
-	ep.network.ctrlr.sandboxRm(ep.container.data.SandboxKey)
-	ep.container = nil
-	ep.context = nil
+	ctrlr.sandboxRm(container.data.SandboxKey)
+
 	return err
 }
 
 func (ep *endpoint) Delete() error {
 	var err error
+
+	ep.Lock()
+	epid := ep.id
+	name := ep.name
 	if ep.container != nil {
-		return &ActiveContainerError{name: ep.name, id: string(ep.id)}
+		ep.Unlock()
+		return &ActiveContainerError{name: name, id: string(epid)}
 	}
 
 	n := ep.network
+	ep.Unlock()
+
 	n.Lock()
-	_, ok := n.endpoints[ep.id]
+	_, ok := n.endpoints[epid]
 	if !ok {
 		n.Unlock()
-		return &UnknownEndpointError{name: ep.name, id: string(ep.id)}
+		return &UnknownEndpointError{name: name, id: string(epid)}
 	}
 
-	delete(n.endpoints, ep.id)
+	nid := n.id
+	driver := n.driver
+	delete(n.endpoints, epid)
 	n.Unlock()
 	defer func() {
 		if err != nil {
 			n.Lock()
-			n.endpoints[ep.id] = ep
+			n.endpoints[epid] = ep
 			n.Unlock()
 		}
 	}()
 
-	err = n.driver.DeleteEndpoint(n.id, ep.id)
+	err = driver.DeleteEndpoint(nid, epid)
 	return err
 }
 
 func (ep *endpoint) buildHostsFiles() error {
 	var extraContent []etchosts.Record
 
-	dir, _ := filepath.Split(ep.container.config.hostsPath)
+	ep.Lock()
+	container := ep.container
+	joinInfo := ep.joinInfo
+	ep.Unlock()
+
+	if container == nil {
+		return ErrNoContainer
+	}
+
+	if container.config.hostsPath == "" {
+		container.config.hostsPath = defaultPrefix + "/" + container.id + "/hosts"
+	}
+
+	dir, _ := filepath.Split(container.config.hostsPath)
 	err := createBasePath(dir)
 	if err != nil {
 		return err
 	}
 
-	if ep.joinInfo != nil && ep.joinInfo.HostsPath != "" {
-		content, err := ioutil.ReadFile(ep.joinInfo.HostsPath)
+	if joinInfo != nil && joinInfo.HostsPath != "" {
+		content, err := ioutil.ReadFile(joinInfo.HostsPath)
 		if err != nil && !os.IsNotExist(err) {
 			return err
 		}
 
 		if err == nil {
-			return ioutil.WriteFile(ep.container.config.hostsPath, content, 0644)
+			return ioutil.WriteFile(container.config.hostsPath, content, 0644)
 		}
 	}
 
-	name := ep.container.config.hostName
-	if ep.container.config.domainName != "" {
-		name = name + "." + ep.container.config.domainName
+	name := container.config.hostName
+	if container.config.domainName != "" {
+		name = name + "." + container.config.domainName
 	}
 
-	for _, extraHost := range ep.container.config.extraHosts {
+	for _, extraHost := range container.config.extraHosts {
 		extraContent = append(extraContent,
 			etchosts.Record{Hosts: extraHost.name, IP: extraHost.IP})
 	}
 
 	IP := ""
-	if ep.sandboxInfo != nil && ep.sandboxInfo.Interfaces[0] != nil &&
-		ep.sandboxInfo.Interfaces[0].Address != nil {
-		IP = ep.sandboxInfo.Interfaces[0].Address.IP.String()
+	sinfo := ep.SandboxInfo()
+	if sinfo != nil && sinfo.Interfaces[0] != nil &&
+		sinfo.Interfaces[0].Address != nil {
+		IP = sinfo.Interfaces[0].Address.IP.String()
 	}
 
-	return etchosts.Build(ep.container.config.hostsPath, IP, ep.container.config.hostName,
-		ep.container.config.domainName, extraContent)
+	return etchosts.Build(container.config.hostsPath, IP, container.config.hostName,
+		container.config.domainName, extraContent)
 }
 
 func (ep *endpoint) updateParentHosts() error {
-	for _, update := range ep.container.config.parentUpdates {
-		ep.network.Lock()
-		pep, ok := ep.network.endpoints[types.UUID(update.eid)]
+	ep.Lock()
+	container := ep.container
+	network := ep.network
+	ep.Unlock()
+
+	if container == nil {
+		return ErrNoContainer
+	}
+
+	for _, update := range container.config.parentUpdates {
+		network.Lock()
+		pep, ok := network.endpoints[types.UUID(update.eid)]
 		if !ok {
-			ep.network.Unlock()
+			network.Unlock()
 			continue
 		}
-		ep.network.Unlock()
+		network.Unlock()
 
-		if err := etchosts.Update(pep.container.config.hostsPath,
-			update.ip, update.name); err != nil {
-			return err
+		pep.Lock()
+		pContainer := pep.container
+		pep.Unlock()
+
+		if pContainer != nil {
+			if err := etchosts.Update(pContainer.config.hostsPath, update.ip, update.name); err != nil {
+				return err
+			}
 		}
 	}
 
@@ -382,7 +514,20 @@ func (ep *endpoint) updateParentHosts() error {
 }
 
 func (ep *endpoint) setupDNS() error {
-	dir, _ := filepath.Split(ep.container.config.resolvConfPath)
+	ep.Lock()
+	container := ep.container
+	network := ep.network
+	ep.Unlock()
+
+	if container == nil {
+		return ErrNoContainer
+	}
+
+	if container.config.resolvConfPath == "" {
+		container.config.resolvConfPath = defaultPrefix + "/" + container.id + "/resolv.conf"
+	}
+
+	dir, _ := filepath.Split(container.config.resolvConfPath)
 	err := createBasePath(dir)
 	if err != nil {
 		return err
@@ -393,26 +538,26 @@ func (ep *endpoint) setupDNS() error {
 		return err
 	}
 
-	if len(ep.container.config.dnsList) > 0 ||
-		len(ep.container.config.dnsSearchList) > 0 {
+	if len(container.config.dnsList) > 0 ||
+		len(container.config.dnsSearchList) > 0 {
 		var (
 			dnsList       = resolvconf.GetNameservers(resolvConf)
 			dnsSearchList = resolvconf.GetSearchDomains(resolvConf)
 		)
 
-		if len(ep.container.config.dnsList) > 0 {
-			dnsList = ep.container.config.dnsList
+		if len(container.config.dnsList) > 0 {
+			dnsList = container.config.dnsList
 		}
 
-		if len(ep.container.config.dnsSearchList) > 0 {
-			dnsSearchList = ep.container.config.dnsSearchList
+		if len(container.config.dnsSearchList) > 0 {
+			dnsSearchList = container.config.dnsSearchList
 		}
 
-		return resolvconf.Build(ep.container.config.resolvConfPath, dnsList, dnsSearchList)
+		return resolvconf.Build(container.config.resolvConfPath, dnsList, dnsSearchList)
 	}
 
 	// replace any localhost/127.* but always discard IPv6 entries for now.
-	resolvConf, _ = resolvconf.FilterResolvDNS(resolvConf, ep.network.enableIPv6)
+	resolvConf, _ = resolvconf.FilterResolvDNS(resolvConf, network.enableIPv6)
 	return ioutil.WriteFile(ep.container.config.resolvConfPath, resolvConf, 0644)
 }
 

+ 3 - 0
libnetwork/error.go

@@ -15,6 +15,9 @@ var (
 	// ErrInvalidJoin is returned if a join is attempted on an endpoint
 	// which already has a container joined.
 	ErrInvalidJoin = errors.New("A container has already joined the endpoint")
+	// ErrNoContainer is returned when the endpoint has no container
+	// attached to it.
+	ErrNoContainer = errors.New("no container attached to the endpoint")
 )
 
 // NetworkTypeError type is returned when the network type string is not

+ 181 - 1
libnetwork/libnetwork_test.go

@@ -2,9 +2,14 @@ package libnetwork_test
 
 import (
 	"bytes"
+	"flag"
+	"fmt"
 	"io/ioutil"
 	"net"
 	"os"
+	"runtime"
+	"strconv"
+	"sync"
 	"testing"
 
 	log "github.com/Sirupsen/logrus"
@@ -13,6 +18,7 @@ import (
 	"github.com/docker/libnetwork/netutils"
 	"github.com/docker/libnetwork/pkg/netlabel"
 	"github.com/docker/libnetwork/pkg/options"
+	"github.com/vishvananda/netns"
 )
 
 const (
@@ -736,7 +742,9 @@ func TestEndpointInvalidLeave(t *testing.T) {
 	}
 
 	if _, ok := err.(libnetwork.InvalidContainerIDError); !ok {
-		t.Fatalf("Failed for unexpected reason: %v", err)
+		if err != libnetwork.ErrNoContainer {
+			t.Fatalf("Failed for unexpected reason: %v", err)
+		}
 	}
 
 	_, err = ep.Join(containerID,
@@ -947,3 +955,175 @@ func TestNoEnableIPv6(t *testing.T) {
 		t.Fatal(err)
 	}
 }
+
+var (
+	once   sync.Once
+	ctrlr  libnetwork.NetworkController
+	start  = make(chan struct{})
+	done   = make(chan chan struct{}, numThreads-1)
+	origns = netns.None()
+	testns = netns.None()
+)
+
+const (
+	iterCnt    = 25
+	numThreads = 3
+	first      = 1
+	last       = numThreads
+	debug      = false
+)
+
+func createGlobalInstance(t *testing.T) {
+	var err error
+	defer close(start)
+
+	origns, err = netns.Get()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	//testns = origns
+	testns, err = netns.New()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	ctrlr = libnetwork.New()
+
+	err = ctrlr.ConfigureNetworkDriver(bridgeNetType, getEmptyGenericOption())
+	if err != nil {
+		t.Fatal("configure driver")
+	}
+
+	net, err := ctrlr.NewNetwork(bridgeNetType, "network1")
+	if err != nil {
+		t.Fatal("new network")
+	}
+
+	_, err = net.CreateEndpoint("ep1")
+	if err != nil {
+		t.Fatal("createendpoint")
+	}
+}
+
+func debugf(format string, a ...interface{}) (int, error) {
+	if debug {
+		return fmt.Printf(format, a...)
+	}
+
+	return 0, nil
+}
+
+func parallelJoin(t *testing.T, ep libnetwork.Endpoint, thrNumber int) {
+	debugf("J%d.", thrNumber)
+	_, err := ep.Join("racing_container")
+	runtime.LockOSThread()
+	if err != nil {
+		if err != libnetwork.ErrNoContainer && err != libnetwork.ErrInvalidJoin {
+			t.Fatal(err)
+		}
+		debugf("JE%d(%v).", thrNumber, err)
+	}
+	debugf("JD%d.", thrNumber)
+}
+
+func parallelLeave(t *testing.T, ep libnetwork.Endpoint, thrNumber int) {
+	debugf("L%d.", thrNumber)
+	err := ep.Leave("racing_container")
+	runtime.LockOSThread()
+	if err != nil {
+		if err != libnetwork.ErrNoContainer && err != libnetwork.ErrInvalidJoin {
+			t.Fatal(err)
+		}
+		debugf("LE%d(%v).", thrNumber, err)
+	}
+	debugf("LD%d.", thrNumber)
+}
+
+func runParallelTests(t *testing.T, thrNumber int) {
+	var err error
+
+	t.Parallel()
+
+	pTest := flag.Lookup("test.parallel")
+	if pTest == nil {
+		t.Skip("Skipped because test.parallel flag not set;")
+	}
+	numParallel, err := strconv.Atoi(pTest.Value.String())
+	if err != nil {
+		t.Fatal(err)
+	}
+	if numParallel < numThreads {
+		t.Skip("Skipped because t.parallel was less than ", numThreads)
+	}
+
+	runtime.LockOSThread()
+	defer runtime.UnlockOSThread()
+
+	if thrNumber == first {
+		createGlobalInstance(t)
+	}
+
+	if thrNumber != first {
+		select {
+		case <-start:
+		}
+
+		thrdone := make(chan struct{})
+		done <- thrdone
+		defer close(thrdone)
+
+		if thrNumber == last {
+			defer close(done)
+		}
+
+		err = netns.Set(testns)
+		if err != nil {
+			t.Fatal(err)
+		}
+	}
+	defer netns.Set(origns)
+
+	net := ctrlr.NetworkByName("network1")
+	if net == nil {
+		t.Fatal("Could not find network1")
+	}
+
+	ep := net.EndpointByName("ep1")
+	if ep == nil {
+		t.Fatal("Could not find ep1")
+	}
+
+	for i := 0; i < iterCnt; i++ {
+		parallelJoin(t, ep, thrNumber)
+		parallelLeave(t, ep, thrNumber)
+	}
+
+	debugf("\n")
+
+	if thrNumber == first {
+		for thrdone := range done {
+			select {
+			case <-thrdone:
+			}
+		}
+
+		testns.Close()
+		err = ep.Delete()
+		if err != nil {
+			t.Fatal(err)
+		}
+	}
+}
+
+func TestParallel1(t *testing.T) {
+	runParallelTests(t, 1)
+}
+
+func TestParallel2(t *testing.T) {
+	runParallelTests(t, 2)
+}
+
+func TestParallel3(t *testing.T) {
+	runParallelTests(t, 3)
+}

+ 3 - 3
libnetwork/sandbox/namespace_linux.go

@@ -12,7 +12,7 @@ import (
 	"github.com/vishvananda/netns"
 )
 
-const prefix = "/var/lib/docker/netns"
+const prefix = "/var/run/netns"
 
 var once sync.Once
 
@@ -148,8 +148,8 @@ func (n *networkNamespace) RemoveInterface(i *Interface) error {
 		return err
 	}
 
-	// Move the network interface to init namespace.
-	if err := netlink.LinkSetNsPid(iface, 1); err != nil {
+	// Move the network interface to caller namespace.
+	if err := netlink.LinkSetNsFd(iface, int(origns)); err != nil {
 		fmt.Println("LinkSetNsPid failed: ", err)
 		return err
 	}