Procházet zdrojové kódy

Merge pull request #324 from mrjana/cnm

Add LeaveAll support
Madhu Venugopal před 10 roky
rodič
revize
c489e329af

+ 3 - 0
libnetwork/controller.go

@@ -87,6 +87,9 @@ type NetworkController interface {
 	// NetworkByID returns the Network which has the passed id. If not found, the error ErrNoSuchNetwork is returned.
 	// NetworkByID returns the Network which has the passed id. If not found, the error ErrNoSuchNetwork is returned.
 	NetworkByID(id string) (Network, error)
 	NetworkByID(id string) (Network, error)
 
 
+	// LeaveAll accepts a container id and attempts to leave all endpoints that the container has joined
+	LeaveAll(id string) error
+
 	// GC triggers immediate garbage collection of resources which are garbage collected.
 	// GC triggers immediate garbage collection of resources which are garbage collected.
 	GC()
 	GC()
 }
 }

+ 8 - 6
libnetwork/endpoint.go

@@ -349,11 +349,6 @@ func (ep *endpoint) Join(containerID string, options ...EndpointOption) error {
 	ep.joinLeaveStart()
 	ep.joinLeaveStart()
 	defer func() {
 	defer func() {
 		ep.joinLeaveEnd()
 		ep.joinLeaveEnd()
-		if err != nil {
-			if e := ep.Leave(containerID); e != nil {
-				log.Warnf("couldnt leave endpoint : %v", ep.name, err)
-			}
-		}
 	}()
 	}()
 
 
 	ep.Lock()
 	ep.Lock()
@@ -403,6 +398,13 @@ func (ep *endpoint) Join(containerID string, options ...EndpointOption) error {
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
+	defer func() {
+		if err != nil {
+			if err = driver.Leave(nid, epid); err != nil {
+				log.Warnf("driver leave failed while rolling back join: %v", err)
+			}
+		}
+	}()
 
 
 	err = ep.buildHostsFiles()
 	err = ep.buildHostsFiles()
 	if err != nil {
 	if err != nil {
@@ -421,7 +423,7 @@ func (ep *endpoint) Join(containerID string, options ...EndpointOption) error {
 
 
 	sb, err := ctrlr.sandboxAdd(sboxKey, !container.config.useDefaultSandBox, ep)
 	sb, err := ctrlr.sandboxAdd(sboxKey, !container.config.useDefaultSandBox, ep)
 	if err != nil {
 	if err != nil {
-		return err
+		return fmt.Errorf("failed sandbox add: %v", err)
 	}
 	}
 	defer func() {
 	defer func() {
 		if err != nil {
 		if err != nil {

+ 1 - 1
libnetwork/error.go

@@ -51,7 +51,7 @@ func (ij ErrInvalidJoin) BadRequest() {}
 type ErrNoContainer struct{}
 type ErrNoContainer struct{}
 
 
 func (nc ErrNoContainer) Error() string {
 func (nc ErrNoContainer) Error() string {
-	return "a container has already joined the endpoint"
+	return "no container is attached to the endpoint"
 }
 }
 
 
 // Maskable denotes the type of this error
 // Maskable denotes the type of this error

+ 95 - 19
libnetwork/libnetwork_test.go

@@ -1009,6 +1009,8 @@ func TestEndpointJoin(t *testing.T) {
 		t.Fatalf("Expected an empty sandbox key for an empty endpoint. Instead found a non-empty sandbox key: %s", info.SandboxKey())
 		t.Fatalf("Expected an empty sandbox key for an empty endpoint. Instead found a non-empty sandbox key: %s", info.SandboxKey())
 	}
 	}
 
 
+	defer controller.LeaveAll(containerID)
+
 	err = ep1.Join(containerID,
 	err = ep1.Join(containerID,
 		libnetwork.JoinOptionHostname("test"),
 		libnetwork.JoinOptionHostname("test"),
 		libnetwork.JoinOptionDomainname("docker.io"),
 		libnetwork.JoinOptionDomainname("docker.io"),
@@ -1017,7 +1019,6 @@ func TestEndpointJoin(t *testing.T) {
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
-
 	defer func() {
 	defer func() {
 		err = ep1.Leave(containerID)
 		err = ep1.Leave(containerID)
 		runtime.LockOSThread()
 		runtime.LockOSThread()
@@ -1072,19 +1073,21 @@ func TestEndpointJoin(t *testing.T) {
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
-
-	if ep1.ContainerInfo().ID() != ep2.ContainerInfo().ID() {
-		t.Fatalf("ep1 and ep2 returned different container info")
-	}
-
+	runtime.LockOSThread()
 	defer func() {
 	defer func() {
 		err = ep2.Leave(containerID)
 		err = ep2.Leave(containerID)
+		runtime.LockOSThread()
 		if err != nil {
 		if err != nil {
 			t.Fatal(err)
 			t.Fatal(err)
 		}
 		}
 	}()
 	}()
 
 
+	if ep1.ContainerInfo().ID() != ep2.ContainerInfo().ID() {
+		t.Fatalf("ep1 and ep2 returned different container info")
+	}
+
 	checkSandbox(t, info)
 	checkSandbox(t, info)
+
 }
 }
 
 
 func TestEndpointJoinInvalidContainerId(t *testing.T) {
 func TestEndpointJoinInvalidContainerId(t *testing.T) {
@@ -1151,6 +1154,14 @@ func TestEndpointDeleteWithActiveContainer(t *testing.T) {
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
+	defer func() {
+		err = ep.Delete()
+		if err != nil {
+			t.Fatal(err)
+		}
+	}()
+
+	defer controller.LeaveAll(containerID)
 
 
 	err = ep.Join(containerID,
 	err = ep.Join(containerID,
 		libnetwork.JoinOptionHostname("test"),
 		libnetwork.JoinOptionHostname("test"),
@@ -1166,11 +1177,6 @@ func TestEndpointDeleteWithActiveContainer(t *testing.T) {
 		if err != nil {
 		if err != nil {
 			t.Fatal(err)
 			t.Fatal(err)
 		}
 		}
-
-		err = ep.Delete()
-		if err != nil {
-			t.Fatal(err)
-		}
 	}()
 	}()
 
 
 	err = ep.Delete()
 	err = ep.Delete()
@@ -1213,6 +1219,8 @@ func TestEndpointMultipleJoins(t *testing.T) {
 		}
 		}
 	}()
 	}()
 
 
+	defer controller.LeaveAll(containerID)
+
 	err = ep.Join(containerID,
 	err = ep.Join(containerID,
 		libnetwork.JoinOptionHostname("test"),
 		libnetwork.JoinOptionHostname("test"),
 		libnetwork.JoinOptionDomainname("docker.io"),
 		libnetwork.JoinOptionDomainname("docker.io"),
@@ -1239,6 +1247,71 @@ func TestEndpointMultipleJoins(t *testing.T) {
 	}
 	}
 }
 }
 
 
+func TestLeaveAll(t *testing.T) {
+	if !netutils.IsRunningInContainer() {
+		defer netutils.SetupTestNetNS(t)()
+	}
+
+	n, err := createTestNetwork(bridgeNetType, "testnetwork", options.Generic{
+		netlabel.GenericData: options.Generic{
+			"BridgeName":            "testnetwork",
+			"AllowNonDefaultBridge": true,
+		},
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer func() {
+		if err := n.Delete(); err != nil {
+			t.Fatal(err)
+		}
+	}()
+
+	ep1, err := n.CreateEndpoint("ep1")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer func() {
+		if err := ep1.Delete(); err != nil {
+			t.Fatal(err)
+		}
+	}()
+
+	ep2, err := n.CreateEndpoint("ep2")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer func() {
+		if err := ep2.Delete(); err != nil {
+			t.Fatal(err)
+		}
+	}()
+
+	err = ep1.Join("leaveall")
+	if err != nil {
+		t.Fatalf("Failed to join ep1: %v", err)
+	}
+	runtime.LockOSThread()
+
+	err = ep2.Join("leaveall")
+	if err != nil {
+		t.Fatalf("Failed to join ep2: %v", err)
+	}
+	runtime.LockOSThread()
+
+	err = ep1.Leave("leaveall")
+	if err != nil {
+		t.Fatalf("Failed to leave ep1: %v", err)
+	}
+	runtime.LockOSThread()
+
+	err = controller.LeaveAll("leaveall")
+	if err != nil {
+		t.Fatal(err)
+	}
+	runtime.LockOSThread()
+}
+
 func TestEndpointInvalidLeave(t *testing.T) {
 func TestEndpointInvalidLeave(t *testing.T) {
 	if !netutils.IsRunningInContainer() {
 	if !netutils.IsRunningInContainer() {
 		defer netutils.SetupTestNetNS(t)()
 		defer netutils.SetupTestNetNS(t)()
@@ -1280,6 +1353,8 @@ func TestEndpointInvalidLeave(t *testing.T) {
 		}
 		}
 	}
 	}
 
 
+	defer controller.LeaveAll(containerID)
+
 	err = ep.Join(containerID,
 	err = ep.Join(containerID,
 		libnetwork.JoinOptionHostname("test"),
 		libnetwork.JoinOptionHostname("test"),
 		libnetwork.JoinOptionDomainname("docker.io"),
 		libnetwork.JoinOptionDomainname("docker.io"),
@@ -1313,7 +1388,6 @@ func TestEndpointInvalidLeave(t *testing.T) {
 	if _, ok := err.(libnetwork.InvalidContainerIDError); !ok {
 	if _, ok := err.(libnetwork.InvalidContainerIDError); !ok {
 		t.Fatalf("Failed for unexpected reason: %v", err)
 		t.Fatalf("Failed for unexpected reason: %v", err)
 	}
 	}
-
 }
 }
 
 
 func TestEndpointUpdateParent(t *testing.T) {
 func TestEndpointUpdateParent(t *testing.T) {
@@ -1346,6 +1420,7 @@ func TestEndpointUpdateParent(t *testing.T) {
 		}
 		}
 	}()
 	}()
 
 
+	defer controller.LeaveAll(containerID)
 	err = ep1.Join(containerID,
 	err = ep1.Join(containerID,
 		libnetwork.JoinOptionHostname("test1"),
 		libnetwork.JoinOptionHostname("test1"),
 		libnetwork.JoinOptionDomainname("docker.io"),
 		libnetwork.JoinOptionDomainname("docker.io"),
@@ -1372,6 +1447,7 @@ func TestEndpointUpdateParent(t *testing.T) {
 		}
 		}
 	}()
 	}()
 
 
+	defer controller.LeaveAll("container2")
 	err = ep2.Join("container2",
 	err = ep2.Join("container2",
 		libnetwork.JoinOptionHostname("test2"),
 		libnetwork.JoinOptionHostname("test2"),
 		libnetwork.JoinOptionDomainname("docker.io"),
 		libnetwork.JoinOptionDomainname("docker.io"),
@@ -1382,13 +1458,11 @@ func TestEndpointUpdateParent(t *testing.T) {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
 
 
-	defer func() {
-		err = ep2.Leave("container2")
-		runtime.LockOSThread()
-		if err != nil {
-			t.Fatal(err)
-		}
-	}()
+	err = ep2.Leave("container2")
+	runtime.LockOSThread()
+	if err != nil {
+		t.Fatal(err)
+	}
 
 
 }
 }
 
 
@@ -1452,6 +1526,7 @@ func TestEnableIPv6(t *testing.T) {
 	resolvConfPath := "/tmp/libnetwork_test/resolv.conf"
 	resolvConfPath := "/tmp/libnetwork_test/resolv.conf"
 	defer os.Remove(resolvConfPath)
 	defer os.Remove(resolvConfPath)
 
 
+	defer controller.LeaveAll(containerID)
 	err = ep1.Join(containerID,
 	err = ep1.Join(containerID,
 		libnetwork.JoinOptionResolvConfPath(resolvConfPath))
 		libnetwork.JoinOptionResolvConfPath(resolvConfPath))
 	runtime.LockOSThread()
 	runtime.LockOSThread()
@@ -1536,6 +1611,7 @@ func TestResolvConf(t *testing.T) {
 	resolvConfPath := "/tmp/libnetwork_test/resolv.conf"
 	resolvConfPath := "/tmp/libnetwork_test/resolv.conf"
 	defer os.Remove(resolvConfPath)
 	defer os.Remove(resolvConfPath)
 
 
+	defer controller.LeaveAll(containerID)
 	err = ep1.Join(containerID,
 	err = ep1.Join(containerID,
 		libnetwork.JoinOptionResolvConfPath(resolvConfPath))
 		libnetwork.JoinOptionResolvConfPath(resolvConfPath))
 	runtime.LockOSThread()
 	runtime.LockOSThread()

+ 2 - 2
libnetwork/sandbox/route_linux.go

@@ -81,7 +81,7 @@ func programGateway(path string, gw net.IP, isAdd bool) error {
 	return nsInvoke(path, func(nsFD int) error { return nil }, func(callerFD int) error {
 	return nsInvoke(path, func(nsFD int) error { return nil }, func(callerFD int) error {
 		gwRoutes, err := netlink.RouteGet(gw)
 		gwRoutes, err := netlink.RouteGet(gw)
 		if err != nil {
 		if err != nil {
-			return fmt.Errorf("route for the gateway could not be found: %v", err)
+			return fmt.Errorf("route for the gateway %s could not be found: %v", gw, err)
 		}
 		}
 
 
 		if isAdd {
 		if isAdd {
@@ -105,7 +105,7 @@ func programRoute(path string, dest *net.IPNet, nh net.IP) error {
 	return nsInvoke(path, func(nsFD int) error { return nil }, func(callerFD int) error {
 	return nsInvoke(path, func(nsFD int) error { return nil }, func(callerFD int) error {
 		gwRoutes, err := netlink.RouteGet(nh)
 		gwRoutes, err := netlink.RouteGet(nh)
 		if err != nil {
 		if err != nil {
-			return fmt.Errorf("route for the next hop could not be found: %v", err)
+			return fmt.Errorf("route for the next hop %s could not be found: %v", nh, err)
 		}
 		}
 
 
 		return netlink.RouteAdd(&netlink.Route{
 		return netlink.RouteAdd(&netlink.Route{

+ 39 - 32
libnetwork/sandboxdata.go

@@ -2,6 +2,7 @@ package libnetwork
 
 
 import (
 import (
 	"container/heap"
 	"container/heap"
+	"fmt"
 	"sync"
 	"sync"
 
 
 	"github.com/Sirupsen/logrus"
 	"github.com/Sirupsen/logrus"
@@ -48,13 +49,9 @@ func (eh *epHeap) Pop() interface{} {
 
 
 func (s *sandboxData) updateGateway(ep *endpoint) error {
 func (s *sandboxData) updateGateway(ep *endpoint) error {
 	sb := s.sandbox()
 	sb := s.sandbox()
-	if err := sb.UnsetGateway(); err != nil {
-		return err
-	}
 
 
-	if err := sb.UnsetGatewayIPv6(); err != nil {
-		return err
-	}
+	sb.UnsetGateway()
+	sb.UnsetGatewayIPv6()
 
 
 	if ep == nil {
 	if ep == nil {
 		return nil
 		return nil
@@ -65,11 +62,11 @@ func (s *sandboxData) updateGateway(ep *endpoint) error {
 	ep.Unlock()
 	ep.Unlock()
 
 
 	if err := sb.SetGateway(joinInfo.gw); err != nil {
 	if err := sb.SetGateway(joinInfo.gw); err != nil {
-		return err
+		return fmt.Errorf("failed to set gateway while updating gateway: %v", err)
 	}
 	}
 
 
 	if err := sb.SetGatewayIPv6(joinInfo.gw6); err != nil {
 	if err := sb.SetGatewayIPv6(joinInfo.gw6); err != nil {
-		return err
+		return fmt.Errorf("failed to set IPv6 gateway while updating gateway: %v", err)
 	}
 	}
 
 
 	return nil
 	return nil
@@ -93,7 +90,7 @@ func (s *sandboxData) addEndpoint(ep *endpoint) error {
 		}
 		}
 
 
 		if err := sb.AddInterface(i.srcName, i.dstPrefix, ifaceOptions...); err != nil {
 		if err := sb.AddInterface(i.srcName, i.dstPrefix, ifaceOptions...); err != nil {
-			return err
+			return fmt.Errorf("failed to add interface %s to sandbox: %v", i.srcName, err)
 		}
 		}
 	}
 	}
 
 
@@ -101,7 +98,7 @@ func (s *sandboxData) addEndpoint(ep *endpoint) error {
 		// Set up non-interface routes.
 		// Set up non-interface routes.
 		for _, r := range ep.joinInfo.StaticRoutes {
 		for _, r := range ep.joinInfo.StaticRoutes {
 			if err := sb.AddStaticRoute(r); err != nil {
 			if err := sb.AddStaticRoute(r); err != nil {
-				return err
+				return fmt.Errorf("failed to add static route %s: %v", r.Destination.String(), err)
 			}
 			}
 		}
 		}
 	}
 	}
@@ -117,14 +114,10 @@ func (s *sandboxData) addEndpoint(ep *endpoint) error {
 		}
 		}
 	}
 	}
 
 
-	s.Lock()
-	s.refCnt++
-	s.Unlock()
-
 	return nil
 	return nil
 }
 }
 
 
-func (s *sandboxData) rmEndpoint(ep *endpoint) int {
+func (s *sandboxData) rmEndpoint(ep *endpoint) {
 	ep.Lock()
 	ep.Lock()
 	joinInfo := ep.joinInfo
 	joinInfo := ep.joinInfo
 	ep.Unlock()
 	ep.Unlock()
@@ -171,17 +164,6 @@ func (s *sandboxData) rmEndpoint(ep *endpoint) int {
 	if highEpBefore != highEpAfter {
 	if highEpBefore != highEpAfter {
 		s.updateGateway(highEpAfter)
 		s.updateGateway(highEpAfter)
 	}
 	}
-
-	s.Lock()
-	s.refCnt--
-	refCnt := s.refCnt
-	s.Unlock()
-
-	if refCnt == 0 {
-		s.sandbox().Destroy()
-	}
-
-	return refCnt
 }
 }
 
 
 func (s *sandboxData) sandbox() sandbox.Sandbox {
 func (s *sandboxData) sandbox() sandbox.Sandbox {
@@ -199,7 +181,7 @@ func (c *controller) sandboxAdd(key string, create bool, ep *endpoint) (sandbox.
 	if !ok {
 	if !ok {
 		sb, err := sandbox.NewSandbox(key, create)
 		sb, err := sandbox.NewSandbox(key, create)
 		if err != nil {
 		if err != nil {
-			return nil, err
+			return nil, fmt.Errorf("failed to create new sandbox: %v", err)
 		}
 		}
 
 
 		sData = &sandboxData{
 		sData = &sandboxData{
@@ -225,11 +207,7 @@ func (c *controller) sandboxRm(key string, ep *endpoint) {
 	sData := c.sandboxes[key]
 	sData := c.sandboxes[key]
 	c.Unlock()
 	c.Unlock()
 
 
-	if sData.rmEndpoint(ep) == 0 {
-		c.Lock()
-		delete(c.sandboxes, key)
-		c.Unlock()
-	}
+	sData.rmEndpoint(ep)
 }
 }
 
 
 func (c *controller) sandboxGet(key string) sandbox.Sandbox {
 func (c *controller) sandboxGet(key string) sandbox.Sandbox {
@@ -243,3 +221,32 @@ func (c *controller) sandboxGet(key string) sandbox.Sandbox {
 
 
 	return sData.sandbox()
 	return sData.sandbox()
 }
 }
+
+func (c *controller) LeaveAll(id string) error {
+	c.Lock()
+	sData, ok := c.sandboxes[sandbox.GenerateKey(id)]
+	c.Unlock()
+
+	if !ok {
+		c.Unlock()
+		return fmt.Errorf("could not find sandbox for container id %s", id)
+	}
+
+	sData.Lock()
+	eps := make([]*endpoint, len(sData.endpoints))
+	for i, ep := range sData.endpoints {
+		eps[i] = ep
+	}
+	sData.Unlock()
+
+	for _, ep := range eps {
+		if err := ep.Leave(id); err != nil {
+			logrus.Warnf("Failed leaving endpoint id %s: %v\n", ep.ID(), err)
+		}
+	}
+
+	sData.sandbox().Destroy()
+	delete(c.sandboxes, sandbox.GenerateKey(id))
+
+	return nil
+}

+ 35 - 34
libnetwork/sandboxdata_test.go

@@ -22,16 +22,13 @@ func TestSandboxAddEmpty(t *testing.T) {
 	ctrlr := createEmptyCtrlr()
 	ctrlr := createEmptyCtrlr()
 	ep := createEmptyEndpoint()
 	ep := createEmptyEndpoint()
 
 
-	if _, err := ctrlr.sandboxAdd("sandbox1", true, ep); err != nil {
+	if _, err := ctrlr.sandboxAdd(sandbox.GenerateKey("sandbox1"), true, ep); err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
 
 
-	if ctrlr.sandboxes["sandbox1"].refCnt != 1 {
-		t.Fatalf("Unexpected sandbox ref count. Expected 1, got %d",
-			ctrlr.sandboxes["sandbox1"].refCnt)
-	}
+	ctrlr.sandboxRm(sandbox.GenerateKey("sandbox1"), ep)
 
 
-	ctrlr.sandboxRm("sandbox1", ep)
+	ctrlr.LeaveAll("sandbox1")
 	if len(ctrlr.sandboxes) != 0 {
 	if len(ctrlr.sandboxes) != 0 {
 		t.Fatalf("controller sandboxes is not empty. len = %d", len(ctrlr.sandboxes))
 		t.Fatalf("controller sandboxes is not empty. len = %d", len(ctrlr.sandboxes))
 	}
 	}
@@ -49,50 +46,52 @@ func TestSandboxAddMultiPrio(t *testing.T) {
 	ep2.container.config.prio = 2
 	ep2.container.config.prio = 2
 	ep3.container.config.prio = 3
 	ep3.container.config.prio = 3
 
 
-	if _, err := ctrlr.sandboxAdd("sandbox1", true, ep1); err != nil {
-		t.Fatal(err)
-	}
+	sKey := sandbox.GenerateKey("sandbox1")
 
 
-	if _, err := ctrlr.sandboxAdd("sandbox1", true, ep2); err != nil {
+	if _, err := ctrlr.sandboxAdd(sKey, true, ep1); err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
 
 
-	if _, err := ctrlr.sandboxAdd("sandbox1", true, ep3); err != nil {
+	if _, err := ctrlr.sandboxAdd(sKey, true, ep2); err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
 
 
-	if ctrlr.sandboxes["sandbox1"].refCnt != 3 {
-		t.Fatalf("Unexpected sandbox ref count. Expected 3, got %d",
-			ctrlr.sandboxes["sandbox1"].refCnt)
+	if _, err := ctrlr.sandboxAdd(sKey, true, ep3); err != nil {
+		t.Fatal(err)
 	}
 	}
 
 
-	if ctrlr.sandboxes["sandbox1"].endpoints[0] != ep3 {
+	if ctrlr.sandboxes[sKey].endpoints[0] != ep3 {
 		t.Fatal("Expected ep3 to be at the top of the heap. But did not find ep3 at the top of the heap")
 		t.Fatal("Expected ep3 to be at the top of the heap. But did not find ep3 at the top of the heap")
 	}
 	}
 
 
-	ctrlr.sandboxRm("sandbox1", ep3)
+	ctrlr.sandboxRm(sKey, ep3)
 
 
-	if ctrlr.sandboxes["sandbox1"].endpoints[0] != ep2 {
+	if ctrlr.sandboxes[sKey].endpoints[0] != ep2 {
 		t.Fatal("Expected ep2 to be at the top of the heap after removing ep3. But did not find ep2 at the top of the heap")
 		t.Fatal("Expected ep2 to be at the top of the heap after removing ep3. But did not find ep2 at the top of the heap")
 	}
 	}
 
 
-	ctrlr.sandboxRm("sandbox1", ep2)
+	ctrlr.sandboxRm(sKey, ep2)
 
 
-	if ctrlr.sandboxes["sandbox1"].endpoints[0] != ep1 {
+	if ctrlr.sandboxes[sKey].endpoints[0] != ep1 {
 		t.Fatal("Expected ep1 to be at the top of the heap after removing ep2. But did not find ep1 at the top of the heap")
 		t.Fatal("Expected ep1 to be at the top of the heap after removing ep2. But did not find ep1 at the top of the heap")
 	}
 	}
 
 
 	// Re-add ep3 back
 	// Re-add ep3 back
-	if _, err := ctrlr.sandboxAdd("sandbox1", true, ep3); err != nil {
+	if _, err := ctrlr.sandboxAdd(sKey, true, ep3); err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
 
 
-	if ctrlr.sandboxes["sandbox1"].endpoints[0] != ep3 {
+	if ctrlr.sandboxes[sKey].endpoints[0] != ep3 {
 		t.Fatal("Expected ep3 to be at the top of the heap after adding ep3 back. But did not find ep3 at the top of the heap")
 		t.Fatal("Expected ep3 to be at the top of the heap after adding ep3 back. But did not find ep3 at the top of the heap")
 	}
 	}
 
 
-	ctrlr.sandboxRm("sandbox1", ep3)
-	ctrlr.sandboxRm("sandbox1", ep1)
+	ctrlr.sandboxRm(sKey, ep3)
+	ctrlr.sandboxRm(sKey, ep1)
+
+	if err := ctrlr.LeaveAll("sandbox1"); err != nil {
+		t.Fatal(err)
+	}
+
 	if len(ctrlr.sandboxes) != 0 {
 	if len(ctrlr.sandboxes) != 0 {
 		t.Fatalf("controller sandboxes is not empty. len = %d", len(ctrlr.sandboxes))
 		t.Fatalf("controller sandboxes is not empty. len = %d", len(ctrlr.sandboxes))
 	}
 	}
@@ -108,30 +107,32 @@ func TestSandboxAddSamePrio(t *testing.T) {
 	ep1.network = &network{name: "aaa"}
 	ep1.network = &network{name: "aaa"}
 	ep2.network = &network{name: "bbb"}
 	ep2.network = &network{name: "bbb"}
 
 
-	if _, err := ctrlr.sandboxAdd("sandbox1", true, ep1); err != nil {
-		t.Fatal(err)
-	}
+	sKey := sandbox.GenerateKey("sandbox1")
 
 
-	if _, err := ctrlr.sandboxAdd("sandbox1", true, ep2); err != nil {
+	if _, err := ctrlr.sandboxAdd(sKey, true, ep1); err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
 
 
-	if ctrlr.sandboxes["sandbox1"].refCnt != 2 {
-		t.Fatalf("Unexpected sandbox ref count. Expected 2, got %d",
-			ctrlr.sandboxes["sandbox1"].refCnt)
+	if _, err := ctrlr.sandboxAdd(sKey, true, ep2); err != nil {
+		t.Fatal(err)
 	}
 	}
 
 
-	if ctrlr.sandboxes["sandbox1"].endpoints[0] != ep1 {
+	if ctrlr.sandboxes[sKey].endpoints[0] != ep1 {
 		t.Fatal("Expected ep1 to be at the top of the heap. But did not find ep1 at the top of the heap")
 		t.Fatal("Expected ep1 to be at the top of the heap. But did not find ep1 at the top of the heap")
 	}
 	}
 
 
-	ctrlr.sandboxRm("sandbox1", ep1)
+	ctrlr.sandboxRm(sKey, ep1)
 
 
-	if ctrlr.sandboxes["sandbox1"].endpoints[0] != ep2 {
+	if ctrlr.sandboxes[sKey].endpoints[0] != ep2 {
 		t.Fatal("Expected ep2 to be at the top of the heap after removing ep3. But did not find ep2 at the top of the heap")
 		t.Fatal("Expected ep2 to be at the top of the heap after removing ep3. But did not find ep2 at the top of the heap")
 	}
 	}
 
 
-	ctrlr.sandboxRm("sandbox1", ep2)
+	ctrlr.sandboxRm(sKey, ep2)
+
+	if err := ctrlr.LeaveAll("sandbox1"); err != nil {
+		t.Fatal(err)
+	}
+
 	if len(ctrlr.sandboxes) != 0 {
 	if len(ctrlr.sandboxes) != 0 {
 		t.Fatalf("controller sandboxes is not empty. len = %d", len(ctrlr.sandboxes))
 		t.Fatalf("controller sandboxes is not empty. len = %d", len(ctrlr.sandboxes))
 	}
 	}