diff --git a/libnetwork/Makefile b/libnetwork/Makefile index 62bc036a3a..ba565fabad 100644 --- a/libnetwork/Makefile +++ b/libnetwork/Makefile @@ -35,7 +35,7 @@ run-tests: @echo "mode: count" > coverage.coverprofile @for dir in $$(find . -maxdepth 10 -not -path './.git*' -not -path '*/_*' -type d); do \ if ls $$dir/*.go &> /dev/null; then \ - $(shell which godep) go test -test.v -covermode=count -coverprofile=$$dir/profile.tmp $$dir ; \ + $(shell which godep) go test -test.parallel 3 -test.v -covermode=count -coverprofile=$$dir/profile.tmp $$dir ; \ if [ $$? -ne 0 ]; then exit $$?; fi ;\ if [ -f $$dir/profile.tmp ]; then \ cat $$dir/profile.tmp | tail -n +2 >> coverage.coverprofile ; \ diff --git a/libnetwork/endpoint.go b/libnetwork/endpoint.go index 85ab1e0f05..5630bed231 100644 --- a/libnetwork/endpoint.go +++ b/libnetwork/endpoint.go @@ -4,6 +4,7 @@ import ( "io/ioutil" "os" "path/filepath" + "sync" "github.com/Sirupsen/logrus" "github.com/docker/libnetwork/driverapi" @@ -96,33 +97,47 @@ type containerInfo struct { } type endpoint struct { - name string - id types.UUID - network *network - sandboxInfo *sandbox.Info - sandBox sandbox.Sandbox - joinInfo *driverapi.JoinInfo - container *containerInfo - exposedPorts []netutils.TransportPort - generic map[string]interface{} - context map[string]interface{} + name string + id types.UUID + network *network + sandboxInfo *sandbox.Info + sandBox sandbox.Sandbox + joinInfo *driverapi.JoinInfo + container *containerInfo + exposedPorts []netutils.TransportPort + generic map[string]interface{} + context map[string]interface{} + joinLeaveDone chan struct{} + sync.Mutex } const defaultPrefix = "/var/lib/docker/network/files" func (ep *endpoint) ID() string { + ep.Lock() + defer ep.Unlock() + return string(ep.id) } func (ep *endpoint) Name() string { + ep.Lock() + defer ep.Unlock() + return ep.name } func (ep *endpoint) Network() string { + ep.Lock() + defer ep.Unlock() + return ep.network.name } func (ep *endpoint) SandboxInfo() *sandbox.Info { + ep.Lock() + defer ep.Unlock() + if ep.sandboxInfo == nil { return nil } @@ -130,10 +145,23 @@ func (ep *endpoint) SandboxInfo() *sandbox.Info { } func (ep *endpoint) Info() (map[string]interface{}, error) { - return ep.network.driver.EndpointInfo(ep.network.id, ep.id) + ep.Lock() + network := ep.network + epid := ep.id + ep.Unlock() + + network.Lock() + driver := network.driver + nid := network.id + network.Unlock() + + return driver.EndpointInfo(nid, epid) } func (ep *endpoint) processOptions(options ...EndpointOption) { + ep.Lock() + defer ep.Unlock() + for _, opt := range options { if opt != nil { opt(ep) @@ -167,6 +195,38 @@ func createFile(path string) error { return err } +// joinLeaveStart waits to ensure there are no joins or leaves in progress and +// marks this join/leave in progress without race +func (ep *endpoint) joinLeaveStart() { + ep.Lock() + defer ep.Unlock() + + for ep.joinLeaveDone != nil { + joinLeaveDone := ep.joinLeaveDone + ep.Unlock() + + select { + case <-joinLeaveDone: + } + + ep.Lock() + } + + ep.joinLeaveDone = make(chan struct{}) +} + +// joinLeaveEnd marks the end of this join/leave operation and +// signals the same without race to other join and leave waiters +func (ep *endpoint) joinLeaveEnd() { + ep.Lock() + defer ep.Unlock() + + if ep.joinLeaveDone != nil { + close(ep.joinLeaveDone) + ep.joinLeaveDone = nil + } +} + func (ep *endpoint) Join(containerID string, options ...EndpointOption) (*ContainerData, error) { var err error @@ -174,44 +234,58 @@ func (ep *endpoint) Join(containerID string, options ...EndpointOption) (*Contai return nil, InvalidContainerIDError(containerID) } + ep.joinLeaveStart() + defer ep.joinLeaveEnd() + + ep.Lock() if ep.container != nil { + ep.Unlock() return nil, ErrInvalidJoin } ep.container = &containerInfo{ + id: containerID, config: containerConfig{ hostsPathConfig: hostsPathConfig{ extraHosts: []extraHost{}, parentUpdates: []parentUpdate{}, }, }} + + container := ep.container + network := ep.network + epid := ep.id + + ep.Unlock() defer func() { + ep.Lock() if err != nil { ep.container = nil } + ep.Unlock() }() + network.Lock() + driver := network.driver + nid := network.id + ctrlr := network.ctrlr + network.Unlock() + ep.processOptions(options...) - if ep.container.config.hostsPath == "" { - ep.container.config.hostsPath = defaultPrefix + "/" + containerID + "/hosts" - } - - if ep.container.config.resolvConfPath == "" { - ep.container.config.resolvConfPath = defaultPrefix + "/" + containerID + "/resolv.conf" - } - sboxKey := sandbox.GenerateKey(containerID) - if ep.container.config.useDefaultSandBox { + if container.config.useDefaultSandBox { sboxKey = sandbox.GenerateKey("default") } - joinInfo, err := ep.network.driver.Join(ep.network.id, ep.id, - sboxKey, ep.container.config.generic) + joinInfo, err := driver.Join(nid, epid, sboxKey, container.config.generic) if err != nil { return nil, err } + + ep.Lock() ep.joinInfo = joinInfo + ep.Unlock() err = ep.buildHostsFiles() if err != nil { @@ -228,14 +302,13 @@ func (ep *endpoint) Join(containerID string, options ...EndpointOption) (*Contai return nil, err } - sb, err := ep.network.ctrlr.sandboxAdd(sboxKey, - !ep.container.config.useDefaultSandBox) + sb, err := ctrlr.sandboxAdd(sboxKey, !container.config.useDefaultSandBox) if err != nil { return nil, err } defer func() { if err != nil { - ep.network.ctrlr.sandboxRm(sboxKey) + ctrlr.sandboxRm(sboxKey) } }() @@ -259,27 +332,50 @@ func (ep *endpoint) Join(containerID string, options ...EndpointOption) (*Contai } } - ep.container.id = containerID - ep.container.data.SandboxKey = sb.Key() + container.data.SandboxKey = sb.Key() + cData := container.data - cData := ep.container.data return &cData, nil } func (ep *endpoint) Leave(containerID string, options ...EndpointOption) error { - if ep.container == nil || ep.container.id == "" || - containerID == "" || ep.container.id != containerID { - return InvalidContainerIDError(containerID) - } + var err error + + ep.joinLeaveStart() + defer ep.joinLeaveEnd() ep.processOptions(options...) + ep.Lock() + container := ep.container n := ep.network - err := n.driver.Leave(n.id, ep.id, ep.context) + context := ep.context + + if container == nil || container.id == "" || + containerID == "" || container.id != containerID { + if container == nil { + err = ErrNoContainer + } else { + err = InvalidContainerIDError(containerID) + } + + ep.Unlock() + return err + } + ep.container = nil + ep.context = nil + ep.Unlock() + + n.Lock() + driver := n.driver + ctrlr := n.ctrlr + n.Unlock() + + err = driver.Leave(n.id, ep.id, context) sinfo := ep.SandboxInfo() if sinfo != nil { - sb := ep.network.ctrlr.sandboxGet(ep.container.data.SandboxKey) + sb := ctrlr.sandboxGet(container.data.SandboxKey) for _, i := range sinfo.Interfaces { err = sb.RemoveInterface(i) if err != nil { @@ -288,93 +384,129 @@ func (ep *endpoint) Leave(containerID string, options ...EndpointOption) error { } } - ep.network.ctrlr.sandboxRm(ep.container.data.SandboxKey) - ep.container = nil - ep.context = nil + ctrlr.sandboxRm(container.data.SandboxKey) + return err } func (ep *endpoint) Delete() error { var err error + + ep.Lock() + epid := ep.id + name := ep.name if ep.container != nil { - return &ActiveContainerError{name: ep.name, id: string(ep.id)} + ep.Unlock() + return &ActiveContainerError{name: name, id: string(epid)} } n := ep.network + ep.Unlock() + n.Lock() - _, ok := n.endpoints[ep.id] + _, ok := n.endpoints[epid] if !ok { n.Unlock() - return &UnknownEndpointError{name: ep.name, id: string(ep.id)} + return &UnknownEndpointError{name: name, id: string(epid)} } - delete(n.endpoints, ep.id) + nid := n.id + driver := n.driver + delete(n.endpoints, epid) n.Unlock() defer func() { if err != nil { n.Lock() - n.endpoints[ep.id] = ep + n.endpoints[epid] = ep n.Unlock() } }() - err = n.driver.DeleteEndpoint(n.id, ep.id) + err = driver.DeleteEndpoint(nid, epid) return err } func (ep *endpoint) buildHostsFiles() error { var extraContent []etchosts.Record - dir, _ := filepath.Split(ep.container.config.hostsPath) + ep.Lock() + container := ep.container + joinInfo := ep.joinInfo + ep.Unlock() + + if container == nil { + return ErrNoContainer + } + + if container.config.hostsPath == "" { + container.config.hostsPath = defaultPrefix + "/" + container.id + "/hosts" + } + + dir, _ := filepath.Split(container.config.hostsPath) err := createBasePath(dir) if err != nil { return err } - if ep.joinInfo != nil && ep.joinInfo.HostsPath != "" { - content, err := ioutil.ReadFile(ep.joinInfo.HostsPath) + if joinInfo != nil && joinInfo.HostsPath != "" { + content, err := ioutil.ReadFile(joinInfo.HostsPath) if err != nil && !os.IsNotExist(err) { return err } if err == nil { - return ioutil.WriteFile(ep.container.config.hostsPath, content, 0644) + return ioutil.WriteFile(container.config.hostsPath, content, 0644) } } - name := ep.container.config.hostName - if ep.container.config.domainName != "" { - name = name + "." + ep.container.config.domainName + name := container.config.hostName + if container.config.domainName != "" { + name = name + "." + container.config.domainName } - for _, extraHost := range ep.container.config.extraHosts { + for _, extraHost := range container.config.extraHosts { extraContent = append(extraContent, etchosts.Record{Hosts: extraHost.name, IP: extraHost.IP}) } IP := "" - if ep.sandboxInfo != nil && ep.sandboxInfo.Interfaces[0] != nil && - ep.sandboxInfo.Interfaces[0].Address != nil { - IP = ep.sandboxInfo.Interfaces[0].Address.IP.String() + sinfo := ep.SandboxInfo() + if sinfo != nil && sinfo.Interfaces[0] != nil && + sinfo.Interfaces[0].Address != nil { + IP = sinfo.Interfaces[0].Address.IP.String() } - return etchosts.Build(ep.container.config.hostsPath, IP, ep.container.config.hostName, - ep.container.config.domainName, extraContent) + return etchosts.Build(container.config.hostsPath, IP, container.config.hostName, + container.config.domainName, extraContent) } func (ep *endpoint) updateParentHosts() error { - for _, update := range ep.container.config.parentUpdates { - ep.network.Lock() - pep, ok := ep.network.endpoints[types.UUID(update.eid)] + ep.Lock() + container := ep.container + network := ep.network + ep.Unlock() + + if container == nil { + return ErrNoContainer + } + + for _, update := range container.config.parentUpdates { + network.Lock() + pep, ok := network.endpoints[types.UUID(update.eid)] if !ok { - ep.network.Unlock() + network.Unlock() continue } - ep.network.Unlock() + network.Unlock() - if err := etchosts.Update(pep.container.config.hostsPath, - update.ip, update.name); err != nil { - return err + pep.Lock() + pContainer := pep.container + pep.Unlock() + + if pContainer != nil { + if err := etchosts.Update(pContainer.config.hostsPath, update.ip, update.name); err != nil { + return err + } } } @@ -382,7 +514,20 @@ func (ep *endpoint) updateParentHosts() error { } func (ep *endpoint) setupDNS() error { - dir, _ := filepath.Split(ep.container.config.resolvConfPath) + ep.Lock() + container := ep.container + network := ep.network + ep.Unlock() + + if container == nil { + return ErrNoContainer + } + + if container.config.resolvConfPath == "" { + container.config.resolvConfPath = defaultPrefix + "/" + container.id + "/resolv.conf" + } + + dir, _ := filepath.Split(container.config.resolvConfPath) err := createBasePath(dir) if err != nil { return err @@ -393,26 +538,26 @@ func (ep *endpoint) setupDNS() error { return err } - if len(ep.container.config.dnsList) > 0 || - len(ep.container.config.dnsSearchList) > 0 { + if len(container.config.dnsList) > 0 || + len(container.config.dnsSearchList) > 0 { var ( dnsList = resolvconf.GetNameservers(resolvConf) dnsSearchList = resolvconf.GetSearchDomains(resolvConf) ) - if len(ep.container.config.dnsList) > 0 { - dnsList = ep.container.config.dnsList + if len(container.config.dnsList) > 0 { + dnsList = container.config.dnsList } - if len(ep.container.config.dnsSearchList) > 0 { - dnsSearchList = ep.container.config.dnsSearchList + if len(container.config.dnsSearchList) > 0 { + dnsSearchList = container.config.dnsSearchList } - return resolvconf.Build(ep.container.config.resolvConfPath, dnsList, dnsSearchList) + return resolvconf.Build(container.config.resolvConfPath, dnsList, dnsSearchList) } // replace any localhost/127.* but always discard IPv6 entries for now. - resolvConf, _ = resolvconf.FilterResolvDNS(resolvConf, ep.network.enableIPv6) + resolvConf, _ = resolvconf.FilterResolvDNS(resolvConf, network.enableIPv6) return ioutil.WriteFile(ep.container.config.resolvConfPath, resolvConf, 0644) } diff --git a/libnetwork/error.go b/libnetwork/error.go index eaefe9e95a..601f0afefe 100644 --- a/libnetwork/error.go +++ b/libnetwork/error.go @@ -15,6 +15,9 @@ var ( // ErrInvalidJoin is returned if a join is attempted on an endpoint // which already has a container joined. ErrInvalidJoin = errors.New("A container has already joined the endpoint") + // ErrNoContainer is returned when the endpoint has no container + // attached to it. + ErrNoContainer = errors.New("no container attached to the endpoint") ) // NetworkTypeError type is returned when the network type string is not diff --git a/libnetwork/libnetwork_test.go b/libnetwork/libnetwork_test.go index 5a3204bf41..2df1444c1d 100644 --- a/libnetwork/libnetwork_test.go +++ b/libnetwork/libnetwork_test.go @@ -2,9 +2,14 @@ package libnetwork_test import ( "bytes" + "flag" + "fmt" "io/ioutil" "net" "os" + "runtime" + "strconv" + "sync" "testing" log "github.com/Sirupsen/logrus" @@ -13,6 +18,7 @@ import ( "github.com/docker/libnetwork/netutils" "github.com/docker/libnetwork/pkg/netlabel" "github.com/docker/libnetwork/pkg/options" + "github.com/vishvananda/netns" ) const ( @@ -736,7 +742,9 @@ func TestEndpointInvalidLeave(t *testing.T) { } if _, ok := err.(libnetwork.InvalidContainerIDError); !ok { - t.Fatalf("Failed for unexpected reason: %v", err) + if err != libnetwork.ErrNoContainer { + t.Fatalf("Failed for unexpected reason: %v", err) + } } _, err = ep.Join(containerID, @@ -947,3 +955,175 @@ func TestNoEnableIPv6(t *testing.T) { t.Fatal(err) } } + +var ( + once sync.Once + ctrlr libnetwork.NetworkController + start = make(chan struct{}) + done = make(chan chan struct{}, numThreads-1) + origns = netns.None() + testns = netns.None() +) + +const ( + iterCnt = 25 + numThreads = 3 + first = 1 + last = numThreads + debug = false +) + +func createGlobalInstance(t *testing.T) { + var err error + defer close(start) + + origns, err = netns.Get() + if err != nil { + t.Fatal(err) + } + + //testns = origns + testns, err = netns.New() + if err != nil { + t.Fatal(err) + } + + ctrlr = libnetwork.New() + + err = ctrlr.ConfigureNetworkDriver(bridgeNetType, getEmptyGenericOption()) + if err != nil { + t.Fatal("configure driver") + } + + net, err := ctrlr.NewNetwork(bridgeNetType, "network1") + if err != nil { + t.Fatal("new network") + } + + _, err = net.CreateEndpoint("ep1") + if err != nil { + t.Fatal("createendpoint") + } +} + +func debugf(format string, a ...interface{}) (int, error) { + if debug { + return fmt.Printf(format, a...) + } + + return 0, nil +} + +func parallelJoin(t *testing.T, ep libnetwork.Endpoint, thrNumber int) { + debugf("J%d.", thrNumber) + _, err := ep.Join("racing_container") + runtime.LockOSThread() + if err != nil { + if err != libnetwork.ErrNoContainer && err != libnetwork.ErrInvalidJoin { + t.Fatal(err) + } + debugf("JE%d(%v).", thrNumber, err) + } + debugf("JD%d.", thrNumber) +} + +func parallelLeave(t *testing.T, ep libnetwork.Endpoint, thrNumber int) { + debugf("L%d.", thrNumber) + err := ep.Leave("racing_container") + runtime.LockOSThread() + if err != nil { + if err != libnetwork.ErrNoContainer && err != libnetwork.ErrInvalidJoin { + t.Fatal(err) + } + debugf("LE%d(%v).", thrNumber, err) + } + debugf("LD%d.", thrNumber) +} + +func runParallelTests(t *testing.T, thrNumber int) { + var err error + + t.Parallel() + + pTest := flag.Lookup("test.parallel") + if pTest == nil { + t.Skip("Skipped because test.parallel flag not set;") + } + numParallel, err := strconv.Atoi(pTest.Value.String()) + if err != nil { + t.Fatal(err) + } + if numParallel < numThreads { + t.Skip("Skipped because t.parallel was less than ", numThreads) + } + + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + if thrNumber == first { + createGlobalInstance(t) + } + + if thrNumber != first { + select { + case <-start: + } + + thrdone := make(chan struct{}) + done <- thrdone + defer close(thrdone) + + if thrNumber == last { + defer close(done) + } + + err = netns.Set(testns) + if err != nil { + t.Fatal(err) + } + } + defer netns.Set(origns) + + net := ctrlr.NetworkByName("network1") + if net == nil { + t.Fatal("Could not find network1") + } + + ep := net.EndpointByName("ep1") + if ep == nil { + t.Fatal("Could not find ep1") + } + + for i := 0; i < iterCnt; i++ { + parallelJoin(t, ep, thrNumber) + parallelLeave(t, ep, thrNumber) + } + + debugf("\n") + + if thrNumber == first { + for thrdone := range done { + select { + case <-thrdone: + } + } + + testns.Close() + err = ep.Delete() + if err != nil { + t.Fatal(err) + } + } +} + +func TestParallel1(t *testing.T) { + runParallelTests(t, 1) +} + +func TestParallel2(t *testing.T) { + runParallelTests(t, 2) +} + +func TestParallel3(t *testing.T) { + runParallelTests(t, 3) +} diff --git a/libnetwork/sandbox/namespace_linux.go b/libnetwork/sandbox/namespace_linux.go index 4a7489625c..62fa2ddb6c 100644 --- a/libnetwork/sandbox/namespace_linux.go +++ b/libnetwork/sandbox/namespace_linux.go @@ -12,7 +12,7 @@ import ( "github.com/vishvananda/netns" ) -const prefix = "/var/lib/docker/netns" +const prefix = "/var/run/netns" var once sync.Once @@ -148,8 +148,8 @@ func (n *networkNamespace) RemoveInterface(i *Interface) error { return err } - // Move the network interface to init namespace. - if err := netlink.LinkSetNsPid(iface, 1); err != nil { + // Move the network interface to caller namespace. + if err := netlink.LinkSetNsFd(iface, int(origns)); err != nil { fmt.Println("LinkSetNsPid failed: ", err) return err }