Merge pull request #140 from mrjana/cnm_integ

Make endpoint Join and Leave multi-thread safe
This commit is contained in:
Madhu Venugopal 2015-05-10 10:53:43 -07:00
commit f9ef08c30f
5 changed files with 406 additions and 78 deletions

View file

@ -35,7 +35,7 @@ run-tests:
@echo "mode: count" > coverage.coverprofile
@for dir in $$(find . -maxdepth 10 -not -path './.git*' -not -path '*/_*' -type d); do \
if ls $$dir/*.go &> /dev/null; then \
$(shell which godep) go test -test.v -covermode=count -coverprofile=$$dir/profile.tmp $$dir ; \
$(shell which godep) go test -test.parallel 3 -test.v -covermode=count -coverprofile=$$dir/profile.tmp $$dir ; \
if [ $$? -ne 0 ]; then exit $$?; fi ;\
if [ -f $$dir/profile.tmp ]; then \
cat $$dir/profile.tmp | tail -n +2 >> coverage.coverprofile ; \

View file

@ -4,6 +4,7 @@ import (
"io/ioutil"
"os"
"path/filepath"
"sync"
"github.com/Sirupsen/logrus"
"github.com/docker/libnetwork/driverapi"
@ -96,33 +97,47 @@ type containerInfo struct {
}
type endpoint struct {
name string
id types.UUID
network *network
sandboxInfo *sandbox.Info
sandBox sandbox.Sandbox
joinInfo *driverapi.JoinInfo
container *containerInfo
exposedPorts []netutils.TransportPort
generic map[string]interface{}
context map[string]interface{}
name string
id types.UUID
network *network
sandboxInfo *sandbox.Info
sandBox sandbox.Sandbox
joinInfo *driverapi.JoinInfo
container *containerInfo
exposedPorts []netutils.TransportPort
generic map[string]interface{}
context map[string]interface{}
joinLeaveDone chan struct{}
sync.Mutex
}
const defaultPrefix = "/var/lib/docker/network/files"
func (ep *endpoint) ID() string {
ep.Lock()
defer ep.Unlock()
return string(ep.id)
}
func (ep *endpoint) Name() string {
ep.Lock()
defer ep.Unlock()
return ep.name
}
func (ep *endpoint) Network() string {
ep.Lock()
defer ep.Unlock()
return ep.network.name
}
func (ep *endpoint) SandboxInfo() *sandbox.Info {
ep.Lock()
defer ep.Unlock()
if ep.sandboxInfo == nil {
return nil
}
@ -130,10 +145,23 @@ func (ep *endpoint) SandboxInfo() *sandbox.Info {
}
func (ep *endpoint) Info() (map[string]interface{}, error) {
return ep.network.driver.EndpointInfo(ep.network.id, ep.id)
ep.Lock()
network := ep.network
epid := ep.id
ep.Unlock()
network.Lock()
driver := network.driver
nid := network.id
network.Unlock()
return driver.EndpointInfo(nid, epid)
}
func (ep *endpoint) processOptions(options ...EndpointOption) {
ep.Lock()
defer ep.Unlock()
for _, opt := range options {
if opt != nil {
opt(ep)
@ -167,6 +195,38 @@ func createFile(path string) error {
return err
}
// joinLeaveStart waits to ensure there are no joins or leaves in progress and
// marks this join/leave in progress without race
func (ep *endpoint) joinLeaveStart() {
ep.Lock()
defer ep.Unlock()
for ep.joinLeaveDone != nil {
joinLeaveDone := ep.joinLeaveDone
ep.Unlock()
select {
case <-joinLeaveDone:
}
ep.Lock()
}
ep.joinLeaveDone = make(chan struct{})
}
// joinLeaveEnd marks the end of this join/leave operation and
// signals the same without race to other join and leave waiters
func (ep *endpoint) joinLeaveEnd() {
ep.Lock()
defer ep.Unlock()
if ep.joinLeaveDone != nil {
close(ep.joinLeaveDone)
ep.joinLeaveDone = nil
}
}
func (ep *endpoint) Join(containerID string, options ...EndpointOption) (*ContainerData, error) {
var err error
@ -174,44 +234,58 @@ func (ep *endpoint) Join(containerID string, options ...EndpointOption) (*Contai
return nil, InvalidContainerIDError(containerID)
}
ep.joinLeaveStart()
defer ep.joinLeaveEnd()
ep.Lock()
if ep.container != nil {
ep.Unlock()
return nil, ErrInvalidJoin
}
ep.container = &containerInfo{
id: containerID,
config: containerConfig{
hostsPathConfig: hostsPathConfig{
extraHosts: []extraHost{},
parentUpdates: []parentUpdate{},
},
}}
container := ep.container
network := ep.network
epid := ep.id
ep.Unlock()
defer func() {
ep.Lock()
if err != nil {
ep.container = nil
}
ep.Unlock()
}()
network.Lock()
driver := network.driver
nid := network.id
ctrlr := network.ctrlr
network.Unlock()
ep.processOptions(options...)
if ep.container.config.hostsPath == "" {
ep.container.config.hostsPath = defaultPrefix + "/" + containerID + "/hosts"
}
if ep.container.config.resolvConfPath == "" {
ep.container.config.resolvConfPath = defaultPrefix + "/" + containerID + "/resolv.conf"
}
sboxKey := sandbox.GenerateKey(containerID)
if ep.container.config.useDefaultSandBox {
if container.config.useDefaultSandBox {
sboxKey = sandbox.GenerateKey("default")
}
joinInfo, err := ep.network.driver.Join(ep.network.id, ep.id,
sboxKey, ep.container.config.generic)
joinInfo, err := driver.Join(nid, epid, sboxKey, container.config.generic)
if err != nil {
return nil, err
}
ep.Lock()
ep.joinInfo = joinInfo
ep.Unlock()
err = ep.buildHostsFiles()
if err != nil {
@ -228,14 +302,13 @@ func (ep *endpoint) Join(containerID string, options ...EndpointOption) (*Contai
return nil, err
}
sb, err := ep.network.ctrlr.sandboxAdd(sboxKey,
!ep.container.config.useDefaultSandBox)
sb, err := ctrlr.sandboxAdd(sboxKey, !container.config.useDefaultSandBox)
if err != nil {
return nil, err
}
defer func() {
if err != nil {
ep.network.ctrlr.sandboxRm(sboxKey)
ctrlr.sandboxRm(sboxKey)
}
}()
@ -259,27 +332,50 @@ func (ep *endpoint) Join(containerID string, options ...EndpointOption) (*Contai
}
}
ep.container.id = containerID
ep.container.data.SandboxKey = sb.Key()
container.data.SandboxKey = sb.Key()
cData := container.data
cData := ep.container.data
return &cData, nil
}
func (ep *endpoint) Leave(containerID string, options ...EndpointOption) error {
if ep.container == nil || ep.container.id == "" ||
containerID == "" || ep.container.id != containerID {
return InvalidContainerIDError(containerID)
}
var err error
ep.joinLeaveStart()
defer ep.joinLeaveEnd()
ep.processOptions(options...)
ep.Lock()
container := ep.container
n := ep.network
err := n.driver.Leave(n.id, ep.id, ep.context)
context := ep.context
if container == nil || container.id == "" ||
containerID == "" || container.id != containerID {
if container == nil {
err = ErrNoContainer
} else {
err = InvalidContainerIDError(containerID)
}
ep.Unlock()
return err
}
ep.container = nil
ep.context = nil
ep.Unlock()
n.Lock()
driver := n.driver
ctrlr := n.ctrlr
n.Unlock()
err = driver.Leave(n.id, ep.id, context)
sinfo := ep.SandboxInfo()
if sinfo != nil {
sb := ep.network.ctrlr.sandboxGet(ep.container.data.SandboxKey)
sb := ctrlr.sandboxGet(container.data.SandboxKey)
for _, i := range sinfo.Interfaces {
err = sb.RemoveInterface(i)
if err != nil {
@ -288,93 +384,129 @@ func (ep *endpoint) Leave(containerID string, options ...EndpointOption) error {
}
}
ep.network.ctrlr.sandboxRm(ep.container.data.SandboxKey)
ep.container = nil
ep.context = nil
ctrlr.sandboxRm(container.data.SandboxKey)
return err
}
func (ep *endpoint) Delete() error {
var err error
ep.Lock()
epid := ep.id
name := ep.name
if ep.container != nil {
return &ActiveContainerError{name: ep.name, id: string(ep.id)}
ep.Unlock()
return &ActiveContainerError{name: name, id: string(epid)}
}
n := ep.network
ep.Unlock()
n.Lock()
_, ok := n.endpoints[ep.id]
_, ok := n.endpoints[epid]
if !ok {
n.Unlock()
return &UnknownEndpointError{name: ep.name, id: string(ep.id)}
return &UnknownEndpointError{name: name, id: string(epid)}
}
delete(n.endpoints, ep.id)
nid := n.id
driver := n.driver
delete(n.endpoints, epid)
n.Unlock()
defer func() {
if err != nil {
n.Lock()
n.endpoints[ep.id] = ep
n.endpoints[epid] = ep
n.Unlock()
}
}()
err = n.driver.DeleteEndpoint(n.id, ep.id)
err = driver.DeleteEndpoint(nid, epid)
return err
}
func (ep *endpoint) buildHostsFiles() error {
var extraContent []etchosts.Record
dir, _ := filepath.Split(ep.container.config.hostsPath)
ep.Lock()
container := ep.container
joinInfo := ep.joinInfo
ep.Unlock()
if container == nil {
return ErrNoContainer
}
if container.config.hostsPath == "" {
container.config.hostsPath = defaultPrefix + "/" + container.id + "/hosts"
}
dir, _ := filepath.Split(container.config.hostsPath)
err := createBasePath(dir)
if err != nil {
return err
}
if ep.joinInfo != nil && ep.joinInfo.HostsPath != "" {
content, err := ioutil.ReadFile(ep.joinInfo.HostsPath)
if joinInfo != nil && joinInfo.HostsPath != "" {
content, err := ioutil.ReadFile(joinInfo.HostsPath)
if err != nil && !os.IsNotExist(err) {
return err
}
if err == nil {
return ioutil.WriteFile(ep.container.config.hostsPath, content, 0644)
return ioutil.WriteFile(container.config.hostsPath, content, 0644)
}
}
name := ep.container.config.hostName
if ep.container.config.domainName != "" {
name = name + "." + ep.container.config.domainName
name := container.config.hostName
if container.config.domainName != "" {
name = name + "." + container.config.domainName
}
for _, extraHost := range ep.container.config.extraHosts {
for _, extraHost := range container.config.extraHosts {
extraContent = append(extraContent,
etchosts.Record{Hosts: extraHost.name, IP: extraHost.IP})
}
IP := ""
if ep.sandboxInfo != nil && ep.sandboxInfo.Interfaces[0] != nil &&
ep.sandboxInfo.Interfaces[0].Address != nil {
IP = ep.sandboxInfo.Interfaces[0].Address.IP.String()
sinfo := ep.SandboxInfo()
if sinfo != nil && sinfo.Interfaces[0] != nil &&
sinfo.Interfaces[0].Address != nil {
IP = sinfo.Interfaces[0].Address.IP.String()
}
return etchosts.Build(ep.container.config.hostsPath, IP, ep.container.config.hostName,
ep.container.config.domainName, extraContent)
return etchosts.Build(container.config.hostsPath, IP, container.config.hostName,
container.config.domainName, extraContent)
}
func (ep *endpoint) updateParentHosts() error {
for _, update := range ep.container.config.parentUpdates {
ep.network.Lock()
pep, ok := ep.network.endpoints[types.UUID(update.eid)]
ep.Lock()
container := ep.container
network := ep.network
ep.Unlock()
if container == nil {
return ErrNoContainer
}
for _, update := range container.config.parentUpdates {
network.Lock()
pep, ok := network.endpoints[types.UUID(update.eid)]
if !ok {
ep.network.Unlock()
network.Unlock()
continue
}
ep.network.Unlock()
network.Unlock()
if err := etchosts.Update(pep.container.config.hostsPath,
update.ip, update.name); err != nil {
return err
pep.Lock()
pContainer := pep.container
pep.Unlock()
if pContainer != nil {
if err := etchosts.Update(pContainer.config.hostsPath, update.ip, update.name); err != nil {
return err
}
}
}
@ -382,7 +514,20 @@ func (ep *endpoint) updateParentHosts() error {
}
func (ep *endpoint) setupDNS() error {
dir, _ := filepath.Split(ep.container.config.resolvConfPath)
ep.Lock()
container := ep.container
network := ep.network
ep.Unlock()
if container == nil {
return ErrNoContainer
}
if container.config.resolvConfPath == "" {
container.config.resolvConfPath = defaultPrefix + "/" + container.id + "/resolv.conf"
}
dir, _ := filepath.Split(container.config.resolvConfPath)
err := createBasePath(dir)
if err != nil {
return err
@ -393,26 +538,26 @@ func (ep *endpoint) setupDNS() error {
return err
}
if len(ep.container.config.dnsList) > 0 ||
len(ep.container.config.dnsSearchList) > 0 {
if len(container.config.dnsList) > 0 ||
len(container.config.dnsSearchList) > 0 {
var (
dnsList = resolvconf.GetNameservers(resolvConf)
dnsSearchList = resolvconf.GetSearchDomains(resolvConf)
)
if len(ep.container.config.dnsList) > 0 {
dnsList = ep.container.config.dnsList
if len(container.config.dnsList) > 0 {
dnsList = container.config.dnsList
}
if len(ep.container.config.dnsSearchList) > 0 {
dnsSearchList = ep.container.config.dnsSearchList
if len(container.config.dnsSearchList) > 0 {
dnsSearchList = container.config.dnsSearchList
}
return resolvconf.Build(ep.container.config.resolvConfPath, dnsList, dnsSearchList)
return resolvconf.Build(container.config.resolvConfPath, dnsList, dnsSearchList)
}
// replace any localhost/127.* but always discard IPv6 entries for now.
resolvConf, _ = resolvconf.FilterResolvDNS(resolvConf, ep.network.enableIPv6)
resolvConf, _ = resolvconf.FilterResolvDNS(resolvConf, network.enableIPv6)
return ioutil.WriteFile(ep.container.config.resolvConfPath, resolvConf, 0644)
}

View file

@ -15,6 +15,9 @@ var (
// ErrInvalidJoin is returned if a join is attempted on an endpoint
// which already has a container joined.
ErrInvalidJoin = errors.New("A container has already joined the endpoint")
// ErrNoContainer is returned when the endpoint has no container
// attached to it.
ErrNoContainer = errors.New("no container attached to the endpoint")
)
// NetworkTypeError type is returned when the network type string is not

View file

@ -2,9 +2,14 @@ package libnetwork_test
import (
"bytes"
"flag"
"fmt"
"io/ioutil"
"net"
"os"
"runtime"
"strconv"
"sync"
"testing"
log "github.com/Sirupsen/logrus"
@ -13,6 +18,7 @@ import (
"github.com/docker/libnetwork/netutils"
"github.com/docker/libnetwork/pkg/netlabel"
"github.com/docker/libnetwork/pkg/options"
"github.com/vishvananda/netns"
)
const (
@ -736,7 +742,9 @@ func TestEndpointInvalidLeave(t *testing.T) {
}
if _, ok := err.(libnetwork.InvalidContainerIDError); !ok {
t.Fatalf("Failed for unexpected reason: %v", err)
if err != libnetwork.ErrNoContainer {
t.Fatalf("Failed for unexpected reason: %v", err)
}
}
_, err = ep.Join(containerID,
@ -947,3 +955,175 @@ func TestNoEnableIPv6(t *testing.T) {
t.Fatal(err)
}
}
var (
once sync.Once
ctrlr libnetwork.NetworkController
start = make(chan struct{})
done = make(chan chan struct{}, numThreads-1)
origns = netns.None()
testns = netns.None()
)
const (
iterCnt = 25
numThreads = 3
first = 1
last = numThreads
debug = false
)
func createGlobalInstance(t *testing.T) {
var err error
defer close(start)
origns, err = netns.Get()
if err != nil {
t.Fatal(err)
}
//testns = origns
testns, err = netns.New()
if err != nil {
t.Fatal(err)
}
ctrlr = libnetwork.New()
err = ctrlr.ConfigureNetworkDriver(bridgeNetType, getEmptyGenericOption())
if err != nil {
t.Fatal("configure driver")
}
net, err := ctrlr.NewNetwork(bridgeNetType, "network1")
if err != nil {
t.Fatal("new network")
}
_, err = net.CreateEndpoint("ep1")
if err != nil {
t.Fatal("createendpoint")
}
}
func debugf(format string, a ...interface{}) (int, error) {
if debug {
return fmt.Printf(format, a...)
}
return 0, nil
}
func parallelJoin(t *testing.T, ep libnetwork.Endpoint, thrNumber int) {
debugf("J%d.", thrNumber)
_, err := ep.Join("racing_container")
runtime.LockOSThread()
if err != nil {
if err != libnetwork.ErrNoContainer && err != libnetwork.ErrInvalidJoin {
t.Fatal(err)
}
debugf("JE%d(%v).", thrNumber, err)
}
debugf("JD%d.", thrNumber)
}
func parallelLeave(t *testing.T, ep libnetwork.Endpoint, thrNumber int) {
debugf("L%d.", thrNumber)
err := ep.Leave("racing_container")
runtime.LockOSThread()
if err != nil {
if err != libnetwork.ErrNoContainer && err != libnetwork.ErrInvalidJoin {
t.Fatal(err)
}
debugf("LE%d(%v).", thrNumber, err)
}
debugf("LD%d.", thrNumber)
}
func runParallelTests(t *testing.T, thrNumber int) {
var err error
t.Parallel()
pTest := flag.Lookup("test.parallel")
if pTest == nil {
t.Skip("Skipped because test.parallel flag not set;")
}
numParallel, err := strconv.Atoi(pTest.Value.String())
if err != nil {
t.Fatal(err)
}
if numParallel < numThreads {
t.Skip("Skipped because t.parallel was less than ", numThreads)
}
runtime.LockOSThread()
defer runtime.UnlockOSThread()
if thrNumber == first {
createGlobalInstance(t)
}
if thrNumber != first {
select {
case <-start:
}
thrdone := make(chan struct{})
done <- thrdone
defer close(thrdone)
if thrNumber == last {
defer close(done)
}
err = netns.Set(testns)
if err != nil {
t.Fatal(err)
}
}
defer netns.Set(origns)
net := ctrlr.NetworkByName("network1")
if net == nil {
t.Fatal("Could not find network1")
}
ep := net.EndpointByName("ep1")
if ep == nil {
t.Fatal("Could not find ep1")
}
for i := 0; i < iterCnt; i++ {
parallelJoin(t, ep, thrNumber)
parallelLeave(t, ep, thrNumber)
}
debugf("\n")
if thrNumber == first {
for thrdone := range done {
select {
case <-thrdone:
}
}
testns.Close()
err = ep.Delete()
if err != nil {
t.Fatal(err)
}
}
}
func TestParallel1(t *testing.T) {
runParallelTests(t, 1)
}
func TestParallel2(t *testing.T) {
runParallelTests(t, 2)
}
func TestParallel3(t *testing.T) {
runParallelTests(t, 3)
}

View file

@ -12,7 +12,7 @@ import (
"github.com/vishvananda/netns"
)
const prefix = "/var/lib/docker/netns"
const prefix = "/var/run/netns"
var once sync.Once
@ -148,8 +148,8 @@ func (n *networkNamespace) RemoveInterface(i *Interface) error {
return err
}
// Move the network interface to init namespace.
if err := netlink.LinkSetNsPid(iface, 1); err != nil {
// Move the network interface to caller namespace.
if err := netlink.LinkSetNsFd(iface, int(origns)); err != nil {
fmt.Println("LinkSetNsPid failed: ", err)
return err
}