浏览代码

Overlay driver to support network layer encryption

Signed-off-by: Alessandro Boch <aboch@docker.com>
Alessandro Boch 9 年之前
父节点
当前提交
93b5073a7d

+ 59 - 0
libnetwork/agent.go

@@ -3,10 +3,12 @@ package libnetwork
 //go:generate protoc -I.:Godeps/_workspace/src/github.com/gogo/protobuf  --gogo_out=import_path=github.com/docker/libnetwork,Mgogoproto/gogo.proto=github.com/gogo/protobuf/gogoproto:. agent.proto
 
 import (
+	"encoding/hex"
 	"fmt"
 	"net"
 	"os"
 	"sort"
+	"strconv"
 
 	"github.com/Sirupsen/logrus"
 	"github.com/docker/go-events"
@@ -72,6 +74,8 @@ func resolveAddr(addrOrInterface string) (string, error) {
 }
 
 func (c *controller) handleKeyChange(keys []*types.EncryptionKey) error {
+	drvEnc := discoverapi.DriverEncryptionUpdate{}
+
 	// Find the new key and add it to the key ring
 	a := c.agent
 	for _, key := range keys {
@@ -86,6 +90,10 @@ func (c *controller) handleKeyChange(keys []*types.EncryptionKey) error {
 			if key.Subsystem == "networking:gossip" {
 				a.networkDB.SetKey(key.Key)
 			}
+			if key.Subsystem == "networking:gossip" /*"networking:ipsec"*/ {
+				drvEnc.Key = hex.EncodeToString(key.Key)
+				drvEnc.Tag = strconv.FormatUint(key.LamportTime, 10)
+			}
 			break
 		}
 	}
@@ -103,6 +111,10 @@ func (c *controller) handleKeyChange(keys []*types.EncryptionKey) error {
 			if cKey.Subsystem == "networking:gossip" {
 				deleted = cKey.Key
 			}
+			if cKey.Subsystem == "networking:gossip" /*"networking:ipsec"*/ {
+				drvEnc.Prune = hex.EncodeToString(cKey.Key)
+				drvEnc.PruneTag = strconv.FormatUint(cKey.LamportTime, 10)
+			}
 			c.keys = append(c.keys[:i], c.keys[i+1:]...)
 			break
 		}
@@ -115,9 +127,25 @@ func (c *controller) handleKeyChange(keys []*types.EncryptionKey) error {
 			break
 		}
 	}
+	for _, key := range c.keys {
+		if key.Subsystem == "networking:gossip" /*"networking:ipsec"*/ {
+			drvEnc.Primary = hex.EncodeToString(key.Key)
+			drvEnc.PrimaryTag = strconv.FormatUint(key.LamportTime, 10)
+			break
+		}
+	}
 	if len(deleted) > 0 {
 		a.networkDB.RemoveKey(deleted)
 	}
+
+	c.drvRegistry.WalkDrivers(func(name string, driver driverapi.Driver, capability driverapi.Capability) bool {
+		err := driver.DiscoverNew(discoverapi.EncryptionKeysUpdate, drvEnc)
+		if err != nil {
+			logrus.Warnf("Failed to update datapath keys in driver %s: %v", name, err)
+		}
+		return false
+	})
+
 	return nil
 }
 
@@ -170,6 +198,8 @@ func (c *controller) agentInit(bindAddrOrInterface string) error {
 		return nil
 	}
 
+	drvEnc := discoverapi.DriverEncryptionConfig{}
+
 	// sort the keys by lamport time
 	sort.Sort(ByTime(c.keys))
 
@@ -178,6 +208,10 @@ func (c *controller) agentInit(bindAddrOrInterface string) error {
 		if key.Subsystem == "networking:gossip" {
 			gossipkey = append(gossipkey, key.Key)
 		}
+		if key.Subsystem == "networking:gossip" /*"networking:ipsec"*/ {
+			drvEnc.Keys = append(drvEnc.Keys, hex.EncodeToString(key.Key))
+			drvEnc.Tags = append(drvEnc.Tags, strconv.FormatUint(key.LamportTime, 10))
+		}
 	}
 
 	bindAddr, err := resolveAddr(bindAddrOrInterface)
@@ -206,6 +240,15 @@ func (c *controller) agentInit(bindAddrOrInterface string) error {
 	}
 
 	go c.handleTableEvents(ch, c.handleEpTableEvent)
+
+	c.drvRegistry.WalkDrivers(func(name string, driver driverapi.Driver, capability driverapi.Capability) bool {
+		err := driver.DiscoverNew(discoverapi.EncryptionKeysConfig, drvEnc)
+		if err != nil {
+			logrus.Warnf("Failed to set datapath keys in driver %s: %v", name, err)
+		}
+		return false
+	})
+
 	return nil
 }
 
@@ -226,6 +269,22 @@ func (c *controller) agentDriverNotify(d driverapi.Driver) {
 		Address: c.agent.bindAddr,
 		Self:    true,
 	})
+
+	drvEnc := discoverapi.DriverEncryptionConfig{}
+	for _, key := range c.keys {
+		if key.Subsystem == "networking:gossip" /*"networking:ipsec"*/ {
+			drvEnc.Keys = append(drvEnc.Keys, hex.EncodeToString(key.Key))
+			drvEnc.Tags = append(drvEnc.Tags, strconv.FormatUint(key.LamportTime, 10))
+		}
+	}
+	c.drvRegistry.WalkDrivers(func(name string, driver driverapi.Driver, capability driverapi.Capability) bool {
+		err := driver.DiscoverNew(discoverapi.EncryptionKeysConfig, drvEnc)
+		if err != nil {
+			logrus.Warnf("Failed to set datapath keys in driver %s: %v", name, err)
+		}
+		return false
+	})
+
 }
 
 func (c *controller) agentClose() {

+ 24 - 0
libnetwork/discoverapi/discoverapi.go

@@ -18,6 +18,10 @@ const (
 	NodeDiscovery = iota + 1
 	// DatastoreConfig represents an add/remove datastore event
 	DatastoreConfig
+	// EncryptionKeysConfig represents the initial key(s) for performing datapath encryption
+	EncryptionKeysConfig
+	// EncryptionKeysUpdate represents an update to the datapath encryption key(s)
+	EncryptionKeysUpdate
 )
 
 // NodeDiscoveryData represents the structure backing the node discovery data json string
@@ -33,3 +37,23 @@ type DatastoreConfigData struct {
 	Address  string
 	Config   interface{}
 }
+
+// DriverEncryptionConfig contains the initial datapath encryption key(s)
+// Key in first position is the primary key, the one to be used in tx.
+// Original key and tag types are []byte and uint64
+type DriverEncryptionConfig struct {
+	Keys []string
+	Tags []string
+}
+
+// DriverEncryptionUpdate carries an update to the encryption key(s) as:
+// a new key and/or set a primary key and/or a removal of an existing key.
+// Original key and tag types are []byte and uint64
+type DriverEncryptionUpdate struct {
+	Key        string
+	Tag        string
+	Primary    string
+	PrimaryTag string
+	Prune      string
+	PruneTag   string
+}

+ 578 - 0
libnetwork/drivers/overlay/encryption.go

@@ -0,0 +1,578 @@
+package overlay
+
+import (
+	"bytes"
+	"encoding/hex"
+	"fmt"
+	"net"
+	"sync"
+	"syscall"
+
+	log "github.com/Sirupsen/logrus"
+	"github.com/docker/libnetwork/iptables"
+	"github.com/docker/libnetwork/types"
+	"github.com/vishvananda/netlink"
+	"strconv"
+)
+
+const (
+	mark    = uint32(0xD0C4E3)
+	timeout = 30
+)
+
+const (
+	forward = iota + 1
+	reverse
+	bidir
+)
+
+type key struct {
+	value []byte
+	tag   uint32
+}
+
+func (k *key) String() string {
+	return fmt.Sprintf("(key: %s, tag: 0x%x)", hex.EncodeToString(k.value)[0:5], k.tag)
+}
+
+type spi struct {
+	forward int
+	reverse int
+}
+
+func (s *spi) String() string {
+	return fmt.Sprintf("SPI(FWD: 0x%x, REV: 0x%x)", uint32(s.forward), uint32(s.reverse))
+}
+
+type encrMap struct {
+	nodes map[string][]*spi
+	sync.Mutex
+}
+
+func (e *encrMap) String() string {
+	e.Lock()
+	defer e.Unlock()
+	b := new(bytes.Buffer)
+	for k, v := range e.nodes {
+		b.WriteString("\n")
+		b.WriteString(k)
+		b.WriteString(":")
+		b.WriteString("[")
+		for _, s := range v {
+			b.WriteString(s.String())
+			b.WriteString(",")
+		}
+		b.WriteString("]")
+
+	}
+	return b.String()
+}
+
+func (d *driver) checkEncryption(nid string, rIP net.IP, vxlanID uint32, isLocal, add bool) error {
+	log.Infof("checkEncryption(%s, %v, %d, %t)", nid[0:7], rIP, vxlanID, isLocal)
+
+	n := d.network(nid)
+	if n == nil || !n.secure {
+		return nil
+	}
+
+	if len(d.keys) == 0 {
+		return types.ForbiddenErrorf("encryption key is not present")
+	}
+
+	lIP := types.GetMinimalIP(net.ParseIP(d.bindAddress))
+	nodes := map[string]net.IP{}
+
+	switch {
+	case isLocal:
+		if err := d.peerDbNetworkWalk(nid, func(pKey *peerKey, pEntry *peerEntry) bool {
+			if !lIP.Equal(pEntry.vtep) {
+				nodes[pEntry.vtep.String()] = types.GetMinimalIP(pEntry.vtep)
+			}
+			return false
+		}); err != nil {
+			log.Warnf("Failed to retrieve list of participating nodes in overlay network %s: %v", nid[0:5], err)
+		}
+	default:
+		if len(d.network(nid).endpoints) > 0 {
+			nodes[rIP.String()] = types.GetMinimalIP(rIP)
+		}
+	}
+
+	log.Debugf("List of nodes: %s", nodes)
+
+	if add {
+		for _, rIP := range nodes {
+			if err := setupEncryption(lIP, rIP, vxlanID, d.secMap, d.keys); err != nil {
+				log.Warnf("Failed to program network encryption between %s and %s: %v", lIP, rIP, err)
+			}
+		}
+	} else {
+		if len(nodes) == 0 {
+			if err := removeEncryption(lIP, rIP, d.secMap); err != nil {
+				log.Warnf("Failed to remove network encryption between %s and %s: %v", lIP, rIP, err)
+			}
+		}
+	}
+
+	return nil
+}
+
+func setupEncryption(localIP, remoteIP net.IP, vni uint32, em *encrMap, keys []*key) error {
+	log.Infof("Programming encryption for vxlan %d between %s and %s", vni, localIP, remoteIP)
+	rIPs := remoteIP.String()
+
+	indices := make([]*spi, 0, len(keys))
+
+	err := programMangle(vni, true)
+	if err != nil {
+		log.Warn(err)
+	}
+
+	for i, k := range keys {
+		spis := &spi{buildSPI(localIP, remoteIP, k.tag), buildSPI(remoteIP, localIP, k.tag)}
+		dir := reverse
+		if i == 0 {
+			dir = bidir
+		}
+		fSA, rSA, err := programSA(localIP, remoteIP, spis, k, dir, true)
+		if err != nil {
+			log.Warn(err)
+		}
+		indices = append(indices, spis)
+		if i != 0 {
+			continue
+		}
+		err = programSP(fSA, rSA, true)
+		if err != nil {
+			log.Warn(err)
+		}
+	}
+
+	em.Lock()
+	em.nodes[rIPs] = indices
+	em.Unlock()
+
+	return nil
+}
+
+func removeEncryption(localIP, remoteIP net.IP, em *encrMap) error {
+	em.Lock()
+	indices, ok := em.nodes[remoteIP.String()]
+	em.Unlock()
+	if !ok {
+		return nil
+	}
+	for i, idxs := range indices {
+		dir := reverse
+		if i == 0 {
+			dir = bidir
+		}
+		fSA, rSA, err := programSA(localIP, remoteIP, idxs, nil, dir, false)
+		if err != nil {
+			log.Warn(err)
+		}
+		if i != 0 {
+			continue
+		}
+		err = programSP(fSA, rSA, false)
+		if err != nil {
+			log.Warn(err)
+		}
+	}
+	return nil
+}
+
+func programMangle(vni uint32, add bool) (err error) {
+	var (
+		p      = strconv.FormatUint(uint64(vxlanPort), 10)
+		c      = fmt.Sprintf("0>>22&0x3C@12&0xFFFFFF00=%d", int(vni)<<8)
+		m      = strconv.FormatUint(uint64(mark), 10)
+		chain  = "OUTPUT"
+		rule   = []string{"-p", "udp", "--dport", p, "-m", "u32", "--u32", c, "-j", "MARK", "--set-mark", m}
+		a      = "-A"
+		action = "install"
+	)
+
+	if add == iptables.Exists(iptables.Mangle, chain, rule...) {
+		return
+	}
+
+	if !add {
+		a = "-D"
+		action = "remove"
+	}
+
+	if err = iptables.RawCombinedOutput(append([]string{"-t", string(iptables.Mangle), a, chain}, rule...)...); err != nil {
+		log.Warnf("could not %s mangle rule: %v", action, err)
+	}
+
+	return
+}
+
+func programSA(localIP, remoteIP net.IP, spi *spi, k *key, dir int, add bool) (fSA *netlink.XfrmState, rSA *netlink.XfrmState, err error) {
+	var (
+		crypt       *netlink.XfrmStateAlgo
+		action      = "Removing"
+		xfrmProgram = netlink.XfrmStateDel
+	)
+
+	if add {
+		action = "Adding"
+		xfrmProgram = netlink.XfrmStateAdd
+		crypt = &netlink.XfrmStateAlgo{Name: "cbc(aes)", Key: k.value}
+	}
+
+	if dir&reverse > 0 {
+		rSA = &netlink.XfrmState{
+			Src:   remoteIP,
+			Dst:   localIP,
+			Proto: netlink.XFRM_PROTO_ESP,
+			Spi:   spi.reverse,
+			Mode:  netlink.XFRM_MODE_TRANSPORT,
+		}
+		if add {
+			rSA.Crypt = crypt
+		}
+
+		exists, err := saExists(rSA)
+		if err != nil {
+			exists = !add
+		}
+
+		if add != exists {
+			log.Infof("%s: rSA{%s}", action, rSA)
+			if err := xfrmProgram(rSA); err != nil {
+				log.Warnf("Failed %s rSA{%s}: %v", action, rSA, err)
+			}
+		}
+	}
+
+	if dir&forward > 0 {
+		fSA = &netlink.XfrmState{
+			Src:   localIP,
+			Dst:   remoteIP,
+			Proto: netlink.XFRM_PROTO_ESP,
+			Spi:   spi.forward,
+			Mode:  netlink.XFRM_MODE_TRANSPORT,
+		}
+		if add {
+			fSA.Crypt = crypt
+		}
+
+		exists, err := saExists(fSA)
+		if err != nil {
+			exists = !add
+		}
+
+		if add != exists {
+			log.Infof("%s fSA{%s}", action, fSA)
+			if err := xfrmProgram(fSA); err != nil {
+				log.Warnf("Failed %s fSA{%s}: %v.", action, fSA, err)
+			}
+		}
+	}
+
+	return
+}
+
+func programSP(fSA *netlink.XfrmState, rSA *netlink.XfrmState, add bool) error {
+	action := "Removing"
+	xfrmProgram := netlink.XfrmPolicyDel
+	if add {
+		action = "Adding"
+		xfrmProgram = netlink.XfrmPolicyAdd
+	}
+
+	fullMask := net.CIDRMask(8*len(fSA.Src), 8*len(fSA.Src))
+
+	fPol := &netlink.XfrmPolicy{
+		Src:     &net.IPNet{IP: fSA.Src, Mask: fullMask},
+		Dst:     &net.IPNet{IP: fSA.Dst, Mask: fullMask},
+		Dir:     netlink.XFRM_DIR_OUT,
+		Proto:   17,
+		DstPort: 4789,
+		Mark: &netlink.XfrmMark{
+			Value: mark,
+		},
+		Tmpls: []netlink.XfrmPolicyTmpl{
+			{
+				Src:   fSA.Src,
+				Dst:   fSA.Dst,
+				Proto: netlink.XFRM_PROTO_ESP,
+				Mode:  netlink.XFRM_MODE_TRANSPORT,
+				Spi:   fSA.Spi,
+			},
+		},
+	}
+
+	exists, err := spExists(fPol)
+	if err != nil {
+		exists = !add
+	}
+
+	if add != exists {
+		log.Infof("%s fSP{%s}", action, fPol)
+		if err := xfrmProgram(fPol); err != nil {
+			log.Warnf("%s fSP{%s}: %v", action, fPol, err)
+		}
+	}
+
+	return nil
+}
+
+func saExists(sa *netlink.XfrmState) (bool, error) {
+	_, err := netlink.XfrmStateGet(sa)
+	switch err {
+	case nil:
+		return true, nil
+	case syscall.ESRCH:
+		return false, nil
+	default:
+		err = fmt.Errorf("Error while checking for SA existence: %v", err)
+		log.Debug(err)
+		return false, err
+	}
+}
+
+func spExists(sp *netlink.XfrmPolicy) (bool, error) {
+	_, err := netlink.XfrmPolicyGet(sp)
+	switch err {
+	case nil:
+		return true, nil
+	case syscall.ENOENT:
+		return false, nil
+	default:
+		err = fmt.Errorf("Error while checking for SP existence: %v", err)
+		log.Debug(err)
+		return false, err
+	}
+}
+
+func buildSPI(src, dst net.IP, st uint32) int {
+	spi := int(st)
+	f := src[len(src)-4:]
+	t := dst[len(dst)-4:]
+	for i := 0; i < 4; i++ {
+		spi = spi ^ (int(f[i])^int(t[3-i]))<<uint32(8*i)
+	}
+	return spi
+}
+
+func (d *driver) secMapWalk(f func(string, []*spi) ([]*spi, bool)) error {
+	d.secMap.Lock()
+	for node, indices := range d.secMap.nodes {
+		idxs, stop := f(node, indices)
+		if idxs != nil {
+			d.secMap.nodes[node] = idxs
+		}
+		if stop {
+			break
+		}
+	}
+	d.secMap.Unlock()
+	return nil
+}
+
+func (d *driver) setKeys(keys []*key) error {
+	if d.keys != nil {
+		return types.ForbiddenErrorf("initial keys are already present")
+	}
+	d.keys = keys
+	log.Infof("Initial encryption keys: %v", d.keys)
+	return nil
+}
+
+// updateKeys allows to add a new key and/or change the primary key and/or prune an existing key
+// The primary key is the key used in transmission and will go in first position in the list.
+func (d *driver) updateKeys(newKey, primary, pruneKey *key) error {
+	log.Infof("Updating Keys. New: %v, Primary: %v, Pruned: %v", newKey, primary, pruneKey)
+
+	log.Infof("Current: %v", d.keys)
+
+	var (
+		newIdx = -1
+		priIdx = -1
+		delIdx = -1
+		lIP    = types.GetMinimalIP(net.ParseIP(d.bindAddress))
+	)
+
+	d.Lock()
+	// add new
+	if newKey != nil {
+		d.keys = append(d.keys, newKey)
+		newIdx += len(d.keys)
+	}
+	for i, k := range d.keys {
+		if primary != nil && k.tag == primary.tag {
+			priIdx = i
+		}
+		if pruneKey != nil && k.tag == pruneKey.tag {
+			delIdx = i
+		}
+	}
+	d.Unlock()
+
+	if (newKey != nil && newIdx == -1) ||
+		(primary != nil && priIdx == -1) ||
+		(pruneKey != nil && delIdx == -1) {
+		err := types.BadRequestErrorf("cannot find proper key indices while processing key update:"+
+			"(newIdx,priIdx,delIdx):(%d, %d, %d)", newIdx, priIdx, delIdx)
+		log.Warn(err)
+		return err
+	}
+
+	d.secMapWalk(func(rIPs string, spis []*spi) ([]*spi, bool) {
+		rIP := types.GetMinimalIP(net.ParseIP(rIPs))
+		return updateNodeKey(lIP, rIP, spis, d.keys, newIdx, priIdx, delIdx), false
+	})
+
+	d.Lock()
+	// swap primary
+	if priIdx != -1 {
+		swp := d.keys[0]
+		d.keys[0] = d.keys[priIdx]
+		d.keys[priIdx] = swp
+	}
+	// prune
+	if delIdx != -1 {
+		if delIdx == 0 {
+			delIdx = priIdx
+		}
+		d.keys = append(d.keys[:delIdx], d.keys[delIdx+1:]...)
+	}
+	d.Unlock()
+
+	log.Infof("Updated: %v", d.keys)
+
+	return nil
+}
+
+/********************************************************
+ * Steady state: rSA0, rSA1, fSA0, fSP0
+ * Rotation --> %rSA0, +rSA2, +fSA1, +fSP1/-fSP0, -fSA0,
+ * Half state:   rSA0, rSA1, rSA2, fSA1, fSP1
+ * Steady state: rSA1, rSA2, fSA1, fSP1
+ *********************************************************/
+
+// Spis and keys are sorted in such away the one in position 0 is the primary
+func updateNodeKey(lIP, rIP net.IP, idxs []*spi, curKeys []*key, newIdx, priIdx, delIdx int) []*spi {
+	log.Infof("Updating keys for node: %s (%d,%d,%d)", rIP, newIdx, priIdx, delIdx)
+
+	spis := idxs
+	log.Infof("Current: %v", spis)
+
+	// add new
+	if newIdx != -1 {
+		spis = append(spis, &spi{
+			forward: buildSPI(lIP, rIP, curKeys[newIdx].tag),
+			reverse: buildSPI(rIP, lIP, curKeys[newIdx].tag),
+		})
+	}
+
+	if delIdx != -1 {
+		// %rSA0
+		rSA0 := &netlink.XfrmState{
+			Src:    rIP,
+			Dst:    lIP,
+			Proto:  netlink.XFRM_PROTO_ESP,
+			Spi:    spis[delIdx].reverse,
+			Mode:   netlink.XFRM_MODE_TRANSPORT,
+			Crypt:  &netlink.XfrmStateAlgo{Name: "cbc(aes)", Key: curKeys[delIdx].value},
+			Limits: netlink.XfrmStateLimits{TimeSoft: timeout},
+		}
+		log.Infof("Updating rSA0{%s}", rSA0)
+		if err := netlink.XfrmStateUpdate(rSA0); err != nil {
+			log.Warnf("Failed to update rSA0{%s}: %v", rSA0, err)
+		}
+	}
+
+	if newIdx > -1 {
+		// +RSA2
+		programSA(lIP, rIP, spis[newIdx], curKeys[newIdx], reverse, true)
+	}
+
+	if priIdx > 0 {
+		// +fSA1
+		fSA1, _, _ := programSA(lIP, rIP, spis[priIdx], curKeys[priIdx], forward, true)
+
+		// +fSP1, -fSP0
+		fullMask := net.CIDRMask(8*len(fSA1.Src), 8*len(fSA1.Src))
+		fSP1 := &netlink.XfrmPolicy{
+			Src:     &net.IPNet{IP: fSA1.Src, Mask: fullMask},
+			Dst:     &net.IPNet{IP: fSA1.Dst, Mask: fullMask},
+			Dir:     netlink.XFRM_DIR_OUT,
+			Proto:   17,
+			DstPort: 4789,
+			Mark: &netlink.XfrmMark{
+				Value: mark,
+			},
+			Tmpls: []netlink.XfrmPolicyTmpl{
+				{
+					Src:   fSA1.Src,
+					Dst:   fSA1.Dst,
+					Proto: netlink.XFRM_PROTO_ESP,
+					Mode:  netlink.XFRM_MODE_TRANSPORT,
+					Spi:   fSA1.Spi,
+				},
+			},
+		}
+		log.Infof("Updating fSP{%s}", fSP1)
+		if err := netlink.XfrmPolicyUpdate(fSP1); err != nil {
+			log.Warnf("Failed to update fSP{%s}: %v", fSP1, err)
+		}
+
+		// -fSA0
+		fSA0 := &netlink.XfrmState{
+			Src:    lIP,
+			Dst:    rIP,
+			Proto:  netlink.XFRM_PROTO_ESP,
+			Spi:    spis[0].forward,
+			Mode:   netlink.XFRM_MODE_TRANSPORT,
+			Crypt:  &netlink.XfrmStateAlgo{Name: "cbc(aes)", Key: curKeys[0].value},
+			Limits: netlink.XfrmStateLimits{TimeHard: timeout},
+		}
+		log.Infof("Removing fSA0{%s}", fSA0)
+		if err := netlink.XfrmStateUpdate(fSA0); err != nil {
+			log.Warnf("Failed to remove fSA0{%s}: %v", fSA0, err)
+		}
+	}
+
+	// swap
+	if priIdx > 0 {
+		swp := spis[0]
+		spis[0] = spis[priIdx]
+		spis[priIdx] = swp
+	}
+	// prune
+	if delIdx != -1 {
+		if delIdx == 0 {
+			delIdx = priIdx
+		}
+		spis = append(spis[:delIdx], spis[delIdx+1:]...)
+	}
+
+	log.Infof("Updated: %v", spis)
+
+	return spis
+}
+
+func parseEncryptionKey(value, tag string) (*key, error) {
+	var (
+		k   *key
+		err error
+	)
+	if value == "" {
+		return nil, nil
+	}
+	k = &key{}
+	if k.value, err = hex.DecodeString(value); err != nil {
+		return nil, types.BadRequestErrorf("failed to decode key (%s): %v", value, err)
+	}
+	t, err := strconv.ParseUint(tag, 10, 64)
+	if err != nil {
+		return nil, types.BadRequestErrorf("failed to decode tag (%s): %v", tag, err)
+	}
+	k.tag = uint32(t)
+	return k, nil
+}

+ 12 - 0
libnetwork/drivers/overlay/joinleave.go

@@ -27,6 +27,10 @@ func (d *driver) Join(nid, eid string, sboxKey string, jinfo driverapi.JoinInfo,
 		return fmt.Errorf("could not find endpoint with id %s", eid)
 	}
 
+	if n.secure && len(d.keys) == 0 {
+		return fmt.Errorf("cannot join secure network: encryption keys not present")
+	}
+
 	s := n.getSubnetforIP(ep.addr)
 	if s == nil {
 		return fmt.Errorf("could not find subnet for endpoint %s", eid)
@@ -106,6 +110,10 @@ func (d *driver) Join(nid, eid string, sboxKey string, jinfo driverapi.JoinInfo,
 	d.peerDbAdd(nid, eid, ep.addr.IP, ep.addr.Mask, ep.mac,
 		net.ParseIP(d.bindAddress), true)
 
+	if err := d.checkEncryption(nid, nil, n.vxlanID(s), true, true); err != nil {
+		log.Warn(err)
+	}
+
 	buf, err := proto.Marshal(&PeerRecord{
 		EndpointIP:       ep.addr.String(),
 		EndpointMAC:      ep.mac.String(),
@@ -197,5 +205,9 @@ func (d *driver) Leave(nid, eid string) error {
 
 	n.leaveSandbox()
 
+	if err := d.checkEncryption(nid, nil, 0, true, false); err != nil {
+		log.Warn(err)
+	}
+
 	return nil
 }

+ 60 - 14
libnetwork/drivers/overlay/ov_network.go

@@ -61,6 +61,7 @@ type network struct {
 	initEpoch int
 	initErr   error
 	subnets   []*subnet
+	secure    bool
 	sync.Mutex
 }
 
@@ -109,6 +110,9 @@ func (d *driver) CreateNetwork(id string, option map[string]interface{}, nInfo d
 				vnis = append(vnis, uint32(vni))
 			}
 		}
+		if _, ok := optMap["secure"]; ok {
+			n.secure = true
+		}
 	}
 
 	// If we are getting vnis from libnetwork, either we get for
@@ -162,7 +166,18 @@ func (d *driver) DeleteNetwork(nid string) error {
 
 	d.deleteNetwork(nid)
 
-	return n.releaseVxlanID()
+	vnis, err := n.releaseVxlanID()
+	if err != nil {
+		return err
+	}
+
+	if n.secure {
+		for _, vni := range vnis {
+			programMangle(vni, false)
+		}
+	}
+
+	return nil
 }
 
 func (d *driver) ProgramExternalConnectivity(nid, eid string, options map[string]interface{}) error {
@@ -618,6 +633,8 @@ func (n *network) KeyPrefix() []string {
 }
 
 func (n *network) Value() []byte {
+	m := map[string]interface{}{}
+
 	netJSON := []*subnetJSON{}
 
 	for _, s := range n.subnets {
@@ -630,10 +647,17 @@ func (n *network) Value() []byte {
 	}
 
 	b, err := json.Marshal(netJSON)
+	if err != nil {
+		return []byte{}
+	}
 
+	m["secure"] = n.secure
+	m["subnets"] = netJSON
+	b, err = json.Marshal(m)
 	if err != nil {
 		return []byte{}
 	}
+
 	return b
 }
 
@@ -655,18 +679,38 @@ func (n *network) Skip() bool {
 }
 
 func (n *network) SetValue(value []byte) error {
-	var newNet bool
-	netJSON := []*subnetJSON{}
-
-	err := json.Unmarshal(value, &netJSON)
-	if err != nil {
-		return err
+	var (
+		m       map[string]interface{}
+		newNet  bool
+		isMap   = true
+		netJSON = []*subnetJSON{}
+	)
+
+	if err := json.Unmarshal(value, &m); err != nil {
+		err := json.Unmarshal(value, &netJSON)
+		if err != nil {
+			return err
+		}
+		isMap = false
 	}
 
 	if len(n.subnets) == 0 {
 		newNet = true
 	}
 
+	if isMap {
+		if val, ok := m["secure"]; ok {
+			n.secure = val.(bool)
+		}
+		bytes, err := json.Marshal(m["subnets"])
+		if err != nil {
+			return err
+		}
+		if err := json.Unmarshal(bytes, &netJSON); err != nil {
+			return err
+		}
+	}
+
 	for _, sj := range netJSON {
 		subnetIPstr := sj.SubnetIP
 		gwIPstr := sj.GwIP
@@ -705,9 +749,9 @@ func (n *network) writeToStore() error {
 	return n.driver.store.PutObjectAtomic(n)
 }
 
-func (n *network) releaseVxlanID() error {
+func (n *network) releaseVxlanID() ([]uint32, error) {
 	if len(n.subnets) == 0 {
-		return nil
+		return nil, nil
 	}
 
 	if n.driver.store != nil {
@@ -715,22 +759,24 @@ func (n *network) releaseVxlanID() error {
 			if err == datastore.ErrKeyModified || err == datastore.ErrKeyNotFound {
 				// In both the above cases we can safely assume that the key has been removed by some other
 				// instance and so simply get out of here
-				return nil
+				return nil, nil
 			}
 
-			return fmt.Errorf("failed to delete network to vxlan id map: %v", err)
+			return nil, fmt.Errorf("failed to delete network to vxlan id map: %v", err)
 		}
 	}
-
+	var vnis []uint32
 	for _, s := range n.subnets {
 		if n.driver.vxlanIdm != nil {
-			n.driver.vxlanIdm.Release(uint64(n.vxlanID(s)))
+			vni := n.vxlanID(s)
+			vnis = append(vnis, vni)
+			n.driver.vxlanIdm.Release(uint64(vni))
 		}
 
 		n.setVxlanID(s, 0)
 	}
 
-	return nil
+	return vnis, nil
 }
 
 func (n *network) obtainVxlanID(s *subnet) error {

+ 37 - 2
libnetwork/drivers/overlay/overlay.go

@@ -37,12 +37,14 @@ type driver struct {
 	neighIP      string
 	config       map[string]interface{}
 	peerDb       peerNetworkMap
+	secMap       *encrMap
 	serfInstance *serf.Serf
 	networks     networkTable
 	store        datastore.DataStore
 	vxlanIdm     *idm.Idm
 	once         sync.Once
 	joinOnce     sync.Once
+	keys         []*key
 	sync.Mutex
 }
 
@@ -51,12 +53,12 @@ func Init(dc driverapi.DriverCallback, config map[string]interface{}) error {
 	c := driverapi.Capability{
 		DataScope: datastore.GlobalScope,
 	}
-
 	d := &driver{
 		networks: networkTable{},
 		peerDb: peerNetworkMap{
 			mp: map[string]*peerMap{},
 		},
+		secMap: &encrMap{nodes: map[string][]*spi{}},
 		config: config,
 	}
 
@@ -209,6 +211,7 @@ func (d *driver) pushLocalEndpointEvent(action, nid, eid string) {
 
 // DiscoverNew is a notification for a new discovery event, such as a new node joining a cluster
 func (d *driver) DiscoverNew(dType discoverapi.DiscoveryType, data interface{}) error {
+	var err error
 	switch dType {
 	case discoverapi.NodeDiscovery:
 		nodeData, ok := data.(discoverapi.NodeDiscoveryData)
@@ -217,7 +220,6 @@ func (d *driver) DiscoverNew(dType discoverapi.DiscoveryType, data interface{})
 		}
 		d.nodeJoin(nodeData.Address, nodeData.Self)
 	case discoverapi.DatastoreConfig:
-		var err error
 		if d.store != nil {
 			return types.ForbiddenErrorf("cannot accept datastore configuration: Overlay driver has a datastore configured already")
 		}
@@ -229,6 +231,39 @@ func (d *driver) DiscoverNew(dType discoverapi.DiscoveryType, data interface{})
 		if err != nil {
 			return types.InternalErrorf("failed to initialize data store: %v", err)
 		}
+	case discoverapi.EncryptionKeysConfig:
+		encrData, ok := data.(discoverapi.DriverEncryptionConfig)
+		if !ok {
+			return fmt.Errorf("invalid encryption key notification data")
+		}
+		keys := make([]*key, 0, len(encrData.Keys))
+		for i := 0; i < len(encrData.Keys); i++ {
+			k, err := parseEncryptionKey(encrData.Keys[i], encrData.Tags[i])
+			if err != nil {
+				return err
+			}
+			keys = append(keys, k)
+		}
+		d.setKeys(keys)
+	case discoverapi.EncryptionKeysUpdate:
+		var newKey, delKey, priKey *key
+		encrData, ok := data.(discoverapi.DriverEncryptionUpdate)
+		if !ok {
+			return fmt.Errorf("invalid encryption key notification data")
+		}
+		newKey, err = parseEncryptionKey(encrData.Key, encrData.Tag)
+		if err != nil {
+			return err
+		}
+		priKey, err = parseEncryptionKey(encrData.Primary, encrData.PrimaryTag)
+		if err != nil {
+			return err
+		}
+		delKey, err = parseEncryptionKey(encrData.Prune, encrData.PruneTag)
+		if err != nil {
+			return err
+		}
+		d.updateKeys(newKey, priKey, delKey)
 	default:
 	}
 	return nil

+ 11 - 1
libnetwork/drivers/overlay/peerdb.go

@@ -5,6 +5,8 @@ import (
 	"net"
 	"sync"
 	"syscall"
+
+	log "github.com/Sirupsen/logrus"
 )
 
 const ovPeerTable = "overlay_peer_table"
@@ -88,7 +90,7 @@ func (d *driver) peerDbNetworkWalk(nid string, f func(*peerKey, *peerEntry) bool
 	for pKeyStr, pEntry := range pMap.mp {
 		var pKey peerKey
 		if _, err := fmt.Sscan(pKeyStr, &pKey); err != nil {
-			fmt.Printf("peer key scan failed: %v", err)
+			log.Warnf("Peer key scan on network %s failed: %v", nid, err)
 		}
 
 		if f(&pKey, &pEntry) {
@@ -273,6 +275,10 @@ func (d *driver) peerAdd(nid, eid string, peerIP net.IP, peerIPMask net.IPMask,
 		return fmt.Errorf("subnet sandbox join failed for %q: %v", s.subnetIP.String(), err)
 	}
 
+	if err := d.checkEncryption(nid, vtep, n.vxlanID(s), false, true); err != nil {
+		log.Warn(err)
+	}
+
 	// Add neighbor entry for the peer IP
 	if err := sbox.AddNeighbor(peerIP, peerMac, sbox.NeighborOptions().LinkName(s.vxlanName)); err != nil {
 		return fmt.Errorf("could not add neigbor entry into the sandbox: %v", err)
@@ -318,6 +324,10 @@ func (d *driver) peerDelete(nid, eid string, peerIP net.IP, peerIPMask net.IPMas
 		return fmt.Errorf("could not delete neigbor entry into the sandbox: %v", err)
 	}
 
+	if err := d.checkEncryption(nid, vtep, 0, false, false); err != nil {
+		log.Warn(err)
+	}
+
 	return nil
 }