瀏覽代碼

Merge pull request #324 from mrjana/cnm

Add LeaveAll support
Madhu Venugopal 10 年之前
父節點
當前提交
c489e329af

+ 3 - 0
libnetwork/controller.go

@@ -87,6 +87,9 @@ type NetworkController interface {
 	// NetworkByID returns the Network which has the passed id. If not found, the error ErrNoSuchNetwork is returned.
 	NetworkByID(id string) (Network, error)
 
+	// LeaveAll accepts a container id and attempts to leave all endpoints that the container has joined
+	LeaveAll(id string) error
+
 	// GC triggers immediate garbage collection of resources which are garbage collected.
 	GC()
 }

+ 8 - 6
libnetwork/endpoint.go

@@ -349,11 +349,6 @@ func (ep *endpoint) Join(containerID string, options ...EndpointOption) error {
 	ep.joinLeaveStart()
 	defer func() {
 		ep.joinLeaveEnd()
-		if err != nil {
-			if e := ep.Leave(containerID); e != nil {
-				log.Warnf("couldnt leave endpoint : %v", ep.name, err)
-			}
-		}
 	}()
 
 	ep.Lock()
@@ -403,6 +398,13 @@ func (ep *endpoint) Join(containerID string, options ...EndpointOption) error {
 	if err != nil {
 		return err
 	}
+	defer func() {
+		if err != nil {
+			if err = driver.Leave(nid, epid); err != nil {
+				log.Warnf("driver leave failed while rolling back join: %v", err)
+			}
+		}
+	}()
 
 	err = ep.buildHostsFiles()
 	if err != nil {
@@ -421,7 +423,7 @@ func (ep *endpoint) Join(containerID string, options ...EndpointOption) error {
 
 	sb, err := ctrlr.sandboxAdd(sboxKey, !container.config.useDefaultSandBox, ep)
 	if err != nil {
-		return err
+		return fmt.Errorf("failed sandbox add: %v", err)
 	}
 	defer func() {
 		if err != nil {

+ 1 - 1
libnetwork/error.go

@@ -51,7 +51,7 @@ func (ij ErrInvalidJoin) BadRequest() {}
 type ErrNoContainer struct{}
 
 func (nc ErrNoContainer) Error() string {
-	return "a container has already joined the endpoint"
+	return "no container is attached to the endpoint"
 }
 
 // Maskable denotes the type of this error

+ 95 - 19
libnetwork/libnetwork_test.go

@@ -1009,6 +1009,8 @@ func TestEndpointJoin(t *testing.T) {
 		t.Fatalf("Expected an empty sandbox key for an empty endpoint. Instead found a non-empty sandbox key: %s", info.SandboxKey())
 	}
 
+	defer controller.LeaveAll(containerID)
+
 	err = ep1.Join(containerID,
 		libnetwork.JoinOptionHostname("test"),
 		libnetwork.JoinOptionDomainname("docker.io"),
@@ -1017,7 +1019,6 @@ func TestEndpointJoin(t *testing.T) {
 	if err != nil {
 		t.Fatal(err)
 	}
-
 	defer func() {
 		err = ep1.Leave(containerID)
 		runtime.LockOSThread()
@@ -1072,19 +1073,21 @@ func TestEndpointJoin(t *testing.T) {
 	if err != nil {
 		t.Fatal(err)
 	}
-
-	if ep1.ContainerInfo().ID() != ep2.ContainerInfo().ID() {
-		t.Fatalf("ep1 and ep2 returned different container info")
-	}
-
+	runtime.LockOSThread()
 	defer func() {
 		err = ep2.Leave(containerID)
+		runtime.LockOSThread()
 		if err != nil {
 			t.Fatal(err)
 		}
 	}()
 
+	if ep1.ContainerInfo().ID() != ep2.ContainerInfo().ID() {
+		t.Fatalf("ep1 and ep2 returned different container info")
+	}
+
 	checkSandbox(t, info)
+
 }
 
 func TestEndpointJoinInvalidContainerId(t *testing.T) {
@@ -1151,6 +1154,14 @@ func TestEndpointDeleteWithActiveContainer(t *testing.T) {
 	if err != nil {
 		t.Fatal(err)
 	}
+	defer func() {
+		err = ep.Delete()
+		if err != nil {
+			t.Fatal(err)
+		}
+	}()
+
+	defer controller.LeaveAll(containerID)
 
 	err = ep.Join(containerID,
 		libnetwork.JoinOptionHostname("test"),
@@ -1166,11 +1177,6 @@ func TestEndpointDeleteWithActiveContainer(t *testing.T) {
 		if err != nil {
 			t.Fatal(err)
 		}
-
-		err = ep.Delete()
-		if err != nil {
-			t.Fatal(err)
-		}
 	}()
 
 	err = ep.Delete()
@@ -1213,6 +1219,8 @@ func TestEndpointMultipleJoins(t *testing.T) {
 		}
 	}()
 
+	defer controller.LeaveAll(containerID)
+
 	err = ep.Join(containerID,
 		libnetwork.JoinOptionHostname("test"),
 		libnetwork.JoinOptionDomainname("docker.io"),
@@ -1239,6 +1247,71 @@ func TestEndpointMultipleJoins(t *testing.T) {
 	}
 }
 
+func TestLeaveAll(t *testing.T) {
+	if !netutils.IsRunningInContainer() {
+		defer netutils.SetupTestNetNS(t)()
+	}
+
+	n, err := createTestNetwork(bridgeNetType, "testnetwork", options.Generic{
+		netlabel.GenericData: options.Generic{
+			"BridgeName":            "testnetwork",
+			"AllowNonDefaultBridge": true,
+		},
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer func() {
+		if err := n.Delete(); err != nil {
+			t.Fatal(err)
+		}
+	}()
+
+	ep1, err := n.CreateEndpoint("ep1")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer func() {
+		if err := ep1.Delete(); err != nil {
+			t.Fatal(err)
+		}
+	}()
+
+	ep2, err := n.CreateEndpoint("ep2")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer func() {
+		if err := ep2.Delete(); err != nil {
+			t.Fatal(err)
+		}
+	}()
+
+	err = ep1.Join("leaveall")
+	if err != nil {
+		t.Fatalf("Failed to join ep1: %v", err)
+	}
+	runtime.LockOSThread()
+
+	err = ep2.Join("leaveall")
+	if err != nil {
+		t.Fatalf("Failed to join ep2: %v", err)
+	}
+	runtime.LockOSThread()
+
+	err = ep1.Leave("leaveall")
+	if err != nil {
+		t.Fatalf("Failed to leave ep1: %v", err)
+	}
+	runtime.LockOSThread()
+
+	err = controller.LeaveAll("leaveall")
+	if err != nil {
+		t.Fatal(err)
+	}
+	runtime.LockOSThread()
+}
+
 func TestEndpointInvalidLeave(t *testing.T) {
 	if !netutils.IsRunningInContainer() {
 		defer netutils.SetupTestNetNS(t)()
@@ -1280,6 +1353,8 @@ func TestEndpointInvalidLeave(t *testing.T) {
 		}
 	}
 
+	defer controller.LeaveAll(containerID)
+
 	err = ep.Join(containerID,
 		libnetwork.JoinOptionHostname("test"),
 		libnetwork.JoinOptionDomainname("docker.io"),
@@ -1313,7 +1388,6 @@ func TestEndpointInvalidLeave(t *testing.T) {
 	if _, ok := err.(libnetwork.InvalidContainerIDError); !ok {
 		t.Fatalf("Failed for unexpected reason: %v", err)
 	}
-
 }
 
 func TestEndpointUpdateParent(t *testing.T) {
@@ -1346,6 +1420,7 @@ func TestEndpointUpdateParent(t *testing.T) {
 		}
 	}()
 
+	defer controller.LeaveAll(containerID)
 	err = ep1.Join(containerID,
 		libnetwork.JoinOptionHostname("test1"),
 		libnetwork.JoinOptionDomainname("docker.io"),
@@ -1372,6 +1447,7 @@ func TestEndpointUpdateParent(t *testing.T) {
 		}
 	}()
 
+	defer controller.LeaveAll("container2")
 	err = ep2.Join("container2",
 		libnetwork.JoinOptionHostname("test2"),
 		libnetwork.JoinOptionDomainname("docker.io"),
@@ -1382,13 +1458,11 @@ func TestEndpointUpdateParent(t *testing.T) {
 		t.Fatal(err)
 	}
 
-	defer func() {
-		err = ep2.Leave("container2")
-		runtime.LockOSThread()
-		if err != nil {
-			t.Fatal(err)
-		}
-	}()
+	err = ep2.Leave("container2")
+	runtime.LockOSThread()
+	if err != nil {
+		t.Fatal(err)
+	}
 
 }
 
@@ -1452,6 +1526,7 @@ func TestEnableIPv6(t *testing.T) {
 	resolvConfPath := "/tmp/libnetwork_test/resolv.conf"
 	defer os.Remove(resolvConfPath)
 
+	defer controller.LeaveAll(containerID)
 	err = ep1.Join(containerID,
 		libnetwork.JoinOptionResolvConfPath(resolvConfPath))
 	runtime.LockOSThread()
@@ -1536,6 +1611,7 @@ func TestResolvConf(t *testing.T) {
 	resolvConfPath := "/tmp/libnetwork_test/resolv.conf"
 	defer os.Remove(resolvConfPath)
 
+	defer controller.LeaveAll(containerID)
 	err = ep1.Join(containerID,
 		libnetwork.JoinOptionResolvConfPath(resolvConfPath))
 	runtime.LockOSThread()

+ 2 - 2
libnetwork/sandbox/route_linux.go

@@ -81,7 +81,7 @@ func programGateway(path string, gw net.IP, isAdd bool) error {
 	return nsInvoke(path, func(nsFD int) error { return nil }, func(callerFD int) error {
 		gwRoutes, err := netlink.RouteGet(gw)
 		if err != nil {
-			return fmt.Errorf("route for the gateway could not be found: %v", err)
+			return fmt.Errorf("route for the gateway %s could not be found: %v", gw, err)
 		}
 
 		if isAdd {
@@ -105,7 +105,7 @@ func programRoute(path string, dest *net.IPNet, nh net.IP) error {
 	return nsInvoke(path, func(nsFD int) error { return nil }, func(callerFD int) error {
 		gwRoutes, err := netlink.RouteGet(nh)
 		if err != nil {
-			return fmt.Errorf("route for the next hop could not be found: %v", err)
+			return fmt.Errorf("route for the next hop %s could not be found: %v", nh, err)
 		}
 
 		return netlink.RouteAdd(&netlink.Route{

+ 39 - 32
libnetwork/sandboxdata.go

@@ -2,6 +2,7 @@ package libnetwork
 
 import (
 	"container/heap"
+	"fmt"
 	"sync"
 
 	"github.com/Sirupsen/logrus"
@@ -48,13 +49,9 @@ func (eh *epHeap) Pop() interface{} {
 
 func (s *sandboxData) updateGateway(ep *endpoint) error {
 	sb := s.sandbox()
-	if err := sb.UnsetGateway(); err != nil {
-		return err
-	}
 
-	if err := sb.UnsetGatewayIPv6(); err != nil {
-		return err
-	}
+	sb.UnsetGateway()
+	sb.UnsetGatewayIPv6()
 
 	if ep == nil {
 		return nil
@@ -65,11 +62,11 @@ func (s *sandboxData) updateGateway(ep *endpoint) error {
 	ep.Unlock()
 
 	if err := sb.SetGateway(joinInfo.gw); err != nil {
-		return err
+		return fmt.Errorf("failed to set gateway while updating gateway: %v", err)
 	}
 
 	if err := sb.SetGatewayIPv6(joinInfo.gw6); err != nil {
-		return err
+		return fmt.Errorf("failed to set IPv6 gateway while updating gateway: %v", err)
 	}
 
 	return nil
@@ -93,7 +90,7 @@ func (s *sandboxData) addEndpoint(ep *endpoint) error {
 		}
 
 		if err := sb.AddInterface(i.srcName, i.dstPrefix, ifaceOptions...); err != nil {
-			return err
+			return fmt.Errorf("failed to add interface %s to sandbox: %v", i.srcName, err)
 		}
 	}
 
@@ -101,7 +98,7 @@ func (s *sandboxData) addEndpoint(ep *endpoint) error {
 		// Set up non-interface routes.
 		for _, r := range ep.joinInfo.StaticRoutes {
 			if err := sb.AddStaticRoute(r); err != nil {
-				return err
+				return fmt.Errorf("failed to add static route %s: %v", r.Destination.String(), err)
 			}
 		}
 	}
@@ -117,14 +114,10 @@ func (s *sandboxData) addEndpoint(ep *endpoint) error {
 		}
 	}
 
-	s.Lock()
-	s.refCnt++
-	s.Unlock()
-
 	return nil
 }
 
-func (s *sandboxData) rmEndpoint(ep *endpoint) int {
+func (s *sandboxData) rmEndpoint(ep *endpoint) {
 	ep.Lock()
 	joinInfo := ep.joinInfo
 	ep.Unlock()
@@ -171,17 +164,6 @@ func (s *sandboxData) rmEndpoint(ep *endpoint) int {
 	if highEpBefore != highEpAfter {
 		s.updateGateway(highEpAfter)
 	}
-
-	s.Lock()
-	s.refCnt--
-	refCnt := s.refCnt
-	s.Unlock()
-
-	if refCnt == 0 {
-		s.sandbox().Destroy()
-	}
-
-	return refCnt
 }
 
 func (s *sandboxData) sandbox() sandbox.Sandbox {
@@ -199,7 +181,7 @@ func (c *controller) sandboxAdd(key string, create bool, ep *endpoint) (sandbox.
 	if !ok {
 		sb, err := sandbox.NewSandbox(key, create)
 		if err != nil {
-			return nil, err
+			return nil, fmt.Errorf("failed to create new sandbox: %v", err)
 		}
 
 		sData = &sandboxData{
@@ -225,11 +207,7 @@ func (c *controller) sandboxRm(key string, ep *endpoint) {
 	sData := c.sandboxes[key]
 	c.Unlock()
 
-	if sData.rmEndpoint(ep) == 0 {
-		c.Lock()
-		delete(c.sandboxes, key)
-		c.Unlock()
-	}
+	sData.rmEndpoint(ep)
 }
 
 func (c *controller) sandboxGet(key string) sandbox.Sandbox {
@@ -243,3 +221,32 @@ func (c *controller) sandboxGet(key string) sandbox.Sandbox {
 
 	return sData.sandbox()
 }
+
+func (c *controller) LeaveAll(id string) error {
+	c.Lock()
+	sData, ok := c.sandboxes[sandbox.GenerateKey(id)]
+	c.Unlock()
+
+	if !ok {
+		c.Unlock()
+		return fmt.Errorf("could not find sandbox for container id %s", id)
+	}
+
+	sData.Lock()
+	eps := make([]*endpoint, len(sData.endpoints))
+	for i, ep := range sData.endpoints {
+		eps[i] = ep
+	}
+	sData.Unlock()
+
+	for _, ep := range eps {
+		if err := ep.Leave(id); err != nil {
+			logrus.Warnf("Failed leaving endpoint id %s: %v\n", ep.ID(), err)
+		}
+	}
+
+	sData.sandbox().Destroy()
+	delete(c.sandboxes, sandbox.GenerateKey(id))
+
+	return nil
+}

+ 35 - 34
libnetwork/sandboxdata_test.go

@@ -22,16 +22,13 @@ func TestSandboxAddEmpty(t *testing.T) {
 	ctrlr := createEmptyCtrlr()
 	ep := createEmptyEndpoint()
 
-	if _, err := ctrlr.sandboxAdd("sandbox1", true, ep); err != nil {
+	if _, err := ctrlr.sandboxAdd(sandbox.GenerateKey("sandbox1"), true, ep); err != nil {
 		t.Fatal(err)
 	}
 
-	if ctrlr.sandboxes["sandbox1"].refCnt != 1 {
-		t.Fatalf("Unexpected sandbox ref count. Expected 1, got %d",
-			ctrlr.sandboxes["sandbox1"].refCnt)
-	}
+	ctrlr.sandboxRm(sandbox.GenerateKey("sandbox1"), ep)
 
-	ctrlr.sandboxRm("sandbox1", ep)
+	ctrlr.LeaveAll("sandbox1")
 	if len(ctrlr.sandboxes) != 0 {
 		t.Fatalf("controller sandboxes is not empty. len = %d", len(ctrlr.sandboxes))
 	}
@@ -49,50 +46,52 @@ func TestSandboxAddMultiPrio(t *testing.T) {
 	ep2.container.config.prio = 2
 	ep3.container.config.prio = 3
 
-	if _, err := ctrlr.sandboxAdd("sandbox1", true, ep1); err != nil {
-		t.Fatal(err)
-	}
+	sKey := sandbox.GenerateKey("sandbox1")
 
-	if _, err := ctrlr.sandboxAdd("sandbox1", true, ep2); err != nil {
+	if _, err := ctrlr.sandboxAdd(sKey, true, ep1); err != nil {
 		t.Fatal(err)
 	}
 
-	if _, err := ctrlr.sandboxAdd("sandbox1", true, ep3); err != nil {
+	if _, err := ctrlr.sandboxAdd(sKey, true, ep2); err != nil {
 		t.Fatal(err)
 	}
 
-	if ctrlr.sandboxes["sandbox1"].refCnt != 3 {
-		t.Fatalf("Unexpected sandbox ref count. Expected 3, got %d",
-			ctrlr.sandboxes["sandbox1"].refCnt)
+	if _, err := ctrlr.sandboxAdd(sKey, true, ep3); err != nil {
+		t.Fatal(err)
 	}
 
-	if ctrlr.sandboxes["sandbox1"].endpoints[0] != ep3 {
+	if ctrlr.sandboxes[sKey].endpoints[0] != ep3 {
 		t.Fatal("Expected ep3 to be at the top of the heap. But did not find ep3 at the top of the heap")
 	}
 
-	ctrlr.sandboxRm("sandbox1", ep3)
+	ctrlr.sandboxRm(sKey, ep3)
 
-	if ctrlr.sandboxes["sandbox1"].endpoints[0] != ep2 {
+	if ctrlr.sandboxes[sKey].endpoints[0] != ep2 {
 		t.Fatal("Expected ep2 to be at the top of the heap after removing ep3. But did not find ep2 at the top of the heap")
 	}
 
-	ctrlr.sandboxRm("sandbox1", ep2)
+	ctrlr.sandboxRm(sKey, ep2)
 
-	if ctrlr.sandboxes["sandbox1"].endpoints[0] != ep1 {
+	if ctrlr.sandboxes[sKey].endpoints[0] != ep1 {
 		t.Fatal("Expected ep1 to be at the top of the heap after removing ep2. But did not find ep1 at the top of the heap")
 	}
 
 	// Re-add ep3 back
-	if _, err := ctrlr.sandboxAdd("sandbox1", true, ep3); err != nil {
+	if _, err := ctrlr.sandboxAdd(sKey, true, ep3); err != nil {
 		t.Fatal(err)
 	}
 
-	if ctrlr.sandboxes["sandbox1"].endpoints[0] != ep3 {
+	if ctrlr.sandboxes[sKey].endpoints[0] != ep3 {
 		t.Fatal("Expected ep3 to be at the top of the heap after adding ep3 back. But did not find ep3 at the top of the heap")
 	}
 
-	ctrlr.sandboxRm("sandbox1", ep3)
-	ctrlr.sandboxRm("sandbox1", ep1)
+	ctrlr.sandboxRm(sKey, ep3)
+	ctrlr.sandboxRm(sKey, ep1)
+
+	if err := ctrlr.LeaveAll("sandbox1"); err != nil {
+		t.Fatal(err)
+	}
+
 	if len(ctrlr.sandboxes) != 0 {
 		t.Fatalf("controller sandboxes is not empty. len = %d", len(ctrlr.sandboxes))
 	}
@@ -108,30 +107,32 @@ func TestSandboxAddSamePrio(t *testing.T) {
 	ep1.network = &network{name: "aaa"}
 	ep2.network = &network{name: "bbb"}
 
-	if _, err := ctrlr.sandboxAdd("sandbox1", true, ep1); err != nil {
-		t.Fatal(err)
-	}
+	sKey := sandbox.GenerateKey("sandbox1")
 
-	if _, err := ctrlr.sandboxAdd("sandbox1", true, ep2); err != nil {
+	if _, err := ctrlr.sandboxAdd(sKey, true, ep1); err != nil {
 		t.Fatal(err)
 	}
 
-	if ctrlr.sandboxes["sandbox1"].refCnt != 2 {
-		t.Fatalf("Unexpected sandbox ref count. Expected 2, got %d",
-			ctrlr.sandboxes["sandbox1"].refCnt)
+	if _, err := ctrlr.sandboxAdd(sKey, true, ep2); err != nil {
+		t.Fatal(err)
 	}
 
-	if ctrlr.sandboxes["sandbox1"].endpoints[0] != ep1 {
+	if ctrlr.sandboxes[sKey].endpoints[0] != ep1 {
 		t.Fatal("Expected ep1 to be at the top of the heap. But did not find ep1 at the top of the heap")
 	}
 
-	ctrlr.sandboxRm("sandbox1", ep1)
+	ctrlr.sandboxRm(sKey, ep1)
 
-	if ctrlr.sandboxes["sandbox1"].endpoints[0] != ep2 {
+	if ctrlr.sandboxes[sKey].endpoints[0] != ep2 {
 		t.Fatal("Expected ep2 to be at the top of the heap after removing ep3. But did not find ep2 at the top of the heap")
 	}
 
-	ctrlr.sandboxRm("sandbox1", ep2)
+	ctrlr.sandboxRm(sKey, ep2)
+
+	if err := ctrlr.LeaveAll("sandbox1"); err != nil {
+		t.Fatal(err)
+	}
+
 	if len(ctrlr.sandboxes) != 0 {
 		t.Fatalf("controller sandboxes is not empty. len = %d", len(ctrlr.sandboxes))
 	}