Browse Source

Merge pull request #2604 from arkodg/fix-port-forwarding

Fix IPv6 Port Forwarding for the Bridge Driver
Arko Dasgupta 4 years ago
parent
commit
44e9db7e85

+ 10 - 0
libnetwork/cmd/proxy/proxy.go

@@ -8,6 +8,16 @@ import (
 	"github.com/ishidawataru/sctp"
 )
 
+// ipVersion refers to IP version - v4 or v6
+type ipVersion string
+
+const (
+	// IPv4 is version 4
+	ipv4 ipVersion = "4"
+	// IPv4 is version 6
+	ipv6 ipVersion = "6"
+)
+
 // Proxy defines the behavior of a proxy. It forwards traffic back and forth
 // between two endpoints : the frontend and the backend.
 // It can be used to do software port-mapping between two addresses.

+ 6 - 1
libnetwork/cmd/proxy/sctp_proxy.go

@@ -19,7 +19,12 @@ type SCTPProxy struct {
 
 // NewSCTPProxy creates a new SCTPProxy.
 func NewSCTPProxy(frontendAddr, backendAddr *sctp.SCTPAddr) (*SCTPProxy, error) {
-	listener, err := sctp.ListenSCTP("sctp", frontendAddr)
+	// detect version of hostIP to bind only to correct version
+	ipVersion := ipv4
+	if frontendAddr.IPAddrs[0].IP.To4() == nil {
+		ipVersion = ipv6
+	}
+	listener, err := sctp.ListenSCTP("sctp"+string(ipVersion), frontendAddr)
 	if err != nil {
 		return nil, err
 	}

+ 6 - 1
libnetwork/cmd/proxy/tcp_proxy.go

@@ -17,7 +17,12 @@ type TCPProxy struct {
 
 // NewTCPProxy creates a new TCPProxy.
 func NewTCPProxy(frontendAddr, backendAddr *net.TCPAddr) (*TCPProxy, error) {
-	listener, err := net.ListenTCP("tcp", frontendAddr)
+	// detect version of hostIP to bind only to correct version
+	ipVersion := ipv4
+	if frontendAddr.IP.To4() == nil {
+		ipVersion = ipv6
+	}
+	listener, err := net.ListenTCP("tcp"+string(ipVersion), frontendAddr)
 	if err != nil {
 		return nil, err
 	}

+ 6 - 1
libnetwork/cmd/proxy/udp_proxy.go

@@ -55,7 +55,12 @@ type UDPProxy struct {
 
 // NewUDPProxy creates a new UDPProxy.
 func NewUDPProxy(frontendAddr, backendAddr *net.UDPAddr) (*UDPProxy, error) {
-	listener, err := net.ListenUDP("udp", frontendAddr)
+	// detect version of hostIP to bind only to correct version
+	ipVersion := ipv4
+	if frontendAddr.IP.To4() == nil {
+		ipVersion = ipv6
+	}
+	listener, err := net.ListenUDP("udp"+string(ipVersion), frontendAddr)
 	if err != nil {
 		return nil, err
 	}

+ 81 - 39
libnetwork/drivers/bridge/port_mapping.go

@@ -11,71 +11,113 @@ import (
 	"github.com/sirupsen/logrus"
 )
 
-var (
-	defaultBindingIP   = net.IPv4(0, 0, 0, 0)
-	defaultBindingIPV6 = net.ParseIP("::")
-)
-
 func (n *bridgeNetwork) allocatePorts(ep *bridgeEndpoint, reqDefBindIP net.IP, ulPxyEnabled bool) ([]types.PortBinding, error) {
 	if ep.extConnConfig == nil || ep.extConnConfig.PortBindings == nil {
 		return nil, nil
 	}
 
-	defHostIP := defaultBindingIP
+	defHostIP := net.IPv4zero // 0.0.0.0
 	if reqDefBindIP != nil {
 		defHostIP = reqDefBindIP
 	}
 
-	// IPv4 port binding including user land proxy
-	pb, err := n.allocatePortsInternal(ep.extConnConfig.PortBindings, ep.addr.IP, defHostIP, ulPxyEnabled)
-	if err != nil {
-		return nil, err
+	var containerIPv6 net.IP
+	if ep.addrv6 != nil {
+		containerIPv6 = ep.addrv6.IP
 	}
 
-	// IPv6 port binding excluding user land proxy
-	if n.driver.config.EnableIP6Tables && ep.addrv6 != nil {
-		// TODO IPv6 custom default binding IP
-		pbv6, err := n.allocatePortsInternal(ep.extConnConfig.PortBindings, ep.addrv6.IP, defaultBindingIPV6, false)
-		if err != nil {
-			// ensure we clear the previous allocated IPv4 ports
-			n.releasePortsInternal(pb)
-			return nil, err
-		}
-
-		pb = append(pb, pbv6...)
+	pb, err := n.allocatePortsInternal(ep.extConnConfig.PortBindings, ep.addr.IP, containerIPv6, defHostIP, ulPxyEnabled)
+	if err != nil {
+		return nil, err
 	}
 	return pb, nil
 }
 
-func (n *bridgeNetwork) allocatePortsInternal(bindings []types.PortBinding, containerIP, defHostIP net.IP, ulPxyEnabled bool) ([]types.PortBinding, error) {
+func (n *bridgeNetwork) allocatePortsInternal(bindings []types.PortBinding, containerIPv4, containerIPv6, defHostIP net.IP, ulPxyEnabled bool) ([]types.PortBinding, error) {
 	bs := make([]types.PortBinding, 0, len(bindings))
 	for _, c := range bindings {
-		b := c.GetCopy()
-		if err := n.allocatePort(&b, containerIP, defHostIP, ulPxyEnabled); err != nil {
-			// On allocation failure, release previously allocated ports. On cleanup error, just log a warning message
-			if cuErr := n.releasePortsInternal(bs); cuErr != nil {
-				logrus.Warnf("Upon allocation failure for %v, failed to clear previously allocated port bindings: %v", b, cuErr)
+		bIPv4 := c.GetCopy()
+		bIPv6 := c.GetCopy()
+		// Allocate IPv4 Port mappings
+		if ok := n.validatePortBindingIPv4(&bIPv4, containerIPv4, defHostIP); ok {
+			if err := n.allocatePort(&bIPv4, ulPxyEnabled); err != nil {
+				// On allocation failure, release previously allocated ports. On cleanup error, just log a warning message
+				if cuErr := n.releasePortsInternal(bs); cuErr != nil {
+					logrus.Warnf("allocation failure for %v, failed to clear previously allocated ipv4 port bindings: %v", bIPv4, cuErr)
+				}
+				return nil, err
 			}
-			return nil, err
+			bs = append(bs, bIPv4)
+		}
+		// Allocate IPv6 Port mappings
+		if ok := n.validatePortBindingIPv6(&bIPv6, containerIPv6, defHostIP); ok {
+			if err := n.allocatePort(&bIPv6, ulPxyEnabled); err != nil {
+				// On allocation failure, release previously allocated ports. On cleanup error, just log a warning message
+				if cuErr := n.releasePortsInternal(bs); cuErr != nil {
+					logrus.Warnf("allocation failure for %v, failed to clear previously allocated ipv6 port bindings: %v", bIPv6, cuErr)
+				}
+				return nil, err
+			}
+			bs = append(bs, bIPv6)
 		}
-		bs = append(bs, b)
 	}
 	return bs, nil
 }
 
-func (n *bridgeNetwork) allocatePort(bnd *types.PortBinding, containerIP, defHostIP net.IP, ulPxyEnabled bool) error {
-	var (
-		host net.Addr
-		err  error
-	)
-
-	// Store the container interface address in the operational binding
-	bnd.IP = containerIP
-
+// validatePortBindingIPv4 validates the port binding, populates the missing Host IP field and returns true
+// if this is a valid IPv4 binding, else returns false
+func (n *bridgeNetwork) validatePortBindingIPv4(bnd *types.PortBinding, containerIPv4, defHostIP net.IP) bool {
+	//Return early if there is a valid Host IP, but its not a IPv6 address
+	if len(bnd.HostIP) > 0 && bnd.HostIP.To4() == nil {
+		return false
+	}
 	// Adjust the host address in the operational binding
 	if len(bnd.HostIP) == 0 {
+		// Return early if the default binding address is an IPv6 address
+		if defHostIP.To4() == nil {
+			return false
+		}
 		bnd.HostIP = defHostIP
 	}
+	bnd.IP = containerIPv4
+	return true
+
+}
+
+// validatePortBindingIPv6 validates the port binding, populates the missing Host IP field and returns true
+// if this is a valid IP6v binding, else returns false
+func (n *bridgeNetwork) validatePortBindingIPv6(bnd *types.PortBinding, containerIPv6, defHostIP net.IP) bool {
+	// Return early if there is no IPv6 container endpoint
+	if containerIPv6 == nil {
+		return false
+	}
+	// Return early if there is a valid Host IP, which is a IPv4 address
+	if len(bnd.HostIP) > 0 && bnd.HostIP.To4() != nil {
+		return false
+	}
+
+	// Setup a binding to  "::" if Host IP is empty and the default binding IP is 0.0.0.0
+	if len(bnd.HostIP) == 0 {
+		if defHostIP.Equal(net.IPv4zero) {
+			bnd.HostIP = net.IPv6zero
+			// If the default binding IP is an IPv6 address, use it
+		} else if defHostIP.To4() == nil {
+			bnd.HostIP = defHostIP
+			// Return false if default binding ip is an IPv4 address
+		} else {
+			return false
+		}
+	}
+	bnd.IP = containerIPv6
+	return true
+
+}
+
+func (n *bridgeNetwork) allocatePort(bnd *types.PortBinding, ulPxyEnabled bool) error {
+	var (
+		host net.Addr
+		err  error
+	)
 
 	// Adjust HostPortEnd if this is not a range.
 	if bnd.HostPortEnd == 0 {
@@ -90,7 +132,7 @@ func (n *bridgeNetwork) allocatePort(bnd *types.PortBinding, containerIP, defHos
 
 	portmapper := n.portMapper
 
-	if containerIP.To4() == nil {
+	if bnd.IP.To4() == nil {
 		portmapper = n.portMapperV6
 	}
 

+ 71 - 0
libnetwork/drivers/bridge/port_mapping_test.go

@@ -95,3 +95,74 @@ func TestPortMappingConfig(t *testing.T) {
 		t.Fatal(err)
 	}
 }
+
+func TestPortMappingV6Config(t *testing.T) {
+	defer testutils.SetupTestOSContext(t)()
+	d := newDriver()
+
+	config := &configuration{
+		EnableIPTables:  true,
+		EnableIP6Tables: true,
+	}
+	genericOption := make(map[string]interface{})
+	genericOption[netlabel.GenericData] = config
+
+	if err := d.configure(genericOption); err != nil {
+		t.Fatalf("Failed to setup driver config: %v", err)
+	}
+
+	portBindings := []types.PortBinding{
+		{Proto: types.UDP, Port: uint16(400), HostPort: uint16(54000)},
+		{Proto: types.TCP, Port: uint16(500), HostPort: uint16(65000)},
+		{Proto: types.SCTP, Port: uint16(500), HostPort: uint16(65000)},
+	}
+
+	sbOptions := make(map[string]interface{})
+	sbOptions[netlabel.PortMap] = portBindings
+	netConfig := &networkConfiguration{
+		BridgeName: DefaultBridgeName,
+		EnableIPv6: true,
+	}
+	netOptions := make(map[string]interface{})
+	netOptions[netlabel.GenericData] = netConfig
+
+	ipdList := getIPv4Data(t, "")
+	err := d.CreateNetwork("dummy", netOptions, nil, ipdList, nil)
+	if err != nil {
+		t.Fatalf("Failed to create bridge: %v", err)
+	}
+
+	te := newTestEndpoint(ipdList[0].Pool, 11)
+	err = d.CreateEndpoint("dummy", "ep1", te.Interface(), nil)
+	if err != nil {
+		t.Fatalf("Failed to create the endpoint: %s", err.Error())
+	}
+
+	if err = d.Join("dummy", "ep1", "sbox", te, sbOptions); err != nil {
+		t.Fatalf("Failed to join the endpoint: %v", err)
+	}
+
+	if err = d.ProgramExternalConnectivity("dummy", "ep1", sbOptions); err != nil {
+		t.Fatalf("Failed to program external connectivity: %v", err)
+	}
+
+	network, ok := d.networks["dummy"]
+	if !ok {
+		t.Fatalf("Cannot find network %s inside driver", "dummy")
+	}
+	ep, _ := network.endpoints["ep1"]
+	if len(ep.portMapping) != 6 {
+		t.Fatalf("Failed to store the port bindings into the sandbox info. Found: %v", ep.portMapping)
+	}
+
+	// release host mapped ports
+	err = d.Leave("dummy", "ep1")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = d.RevokeExternalConnectivity("dummy", "ep1")
+	if err != nil {
+		t.Fatal(err)
+	}
+}

+ 1 - 1
libnetwork/libnetwork_test.go

@@ -199,7 +199,7 @@ func TestBridge(t *testing.T) {
 	if !ok {
 		t.Fatalf("Unexpected format for port mapping in endpoint operational data")
 	}
-	if len(pm) != 5 {
+	if len(pm) != 10 {
 		t.Fatalf("Incomplete data for port mapping in endpoint operational data: %d", len(pm))
 	}
 }

+ 5 - 9
libnetwork/portmapper/mapper.go

@@ -151,20 +151,16 @@ func (pm *PortMapper) MapRange(container net.Addr, hostIP net.IP, hostPortStart,
 	}
 
 	containerIP, containerPort := getIPAndPort(m.container)
-	if pm.checkIP(hostIP) {
-		if err := pm.AppendForwardingTableEntry(m.proto, hostIP, allocatedHostPort, containerIP.String(), containerPort); err != nil {
-			return nil, err
-		}
+	if err := pm.AppendForwardingTableEntry(m.proto, hostIP, allocatedHostPort, containerIP.String(), containerPort); err != nil {
+		return nil, err
 	}
 
 	cleanup := func() error {
 		// need to undo the iptables rules before we return
 		m.userlandProxy.Stop()
-		if pm.checkIP(hostIP) {
-			pm.DeleteForwardingTableEntry(m.proto, hostIP, allocatedHostPort, containerIP.String(), containerPort)
-			if err := pm.Allocator.ReleasePort(hostIP, m.proto, allocatedHostPort); err != nil {
-				return err
-			}
+		pm.DeleteForwardingTableEntry(m.proto, hostIP, allocatedHostPort, containerIP.String(), containerPort)
+		if err := pm.Allocator.ReleasePort(hostIP, m.proto, allocatedHostPort); err != nil {
+			return err
 		}
 
 		return nil

+ 0 - 8
libnetwork/portmapper/mapper_linux.go

@@ -44,11 +44,3 @@ func (pm *PortMapper) forward(action iptables.Action, proto string, sourceIP net
 	}
 	return pm.chain.Forward(action, sourceIP, sourcePort, proto, containerIP, containerPort, pm.bridgeName)
 }
-
-// checkIP checks if IP is valid and matching to chain version
-func (pm *PortMapper) checkIP(ip net.IP) bool {
-	if pm.chain == nil || pm.chain.IPTable.Version == iptables.IPv4 {
-		return ip.To4() != nil
-	}
-	return ip.To16() != nil
-}

+ 24 - 8
libnetwork/portmapper/proxy.go

@@ -19,6 +19,16 @@ type userlandProxy interface {
 	Stop() error
 }
 
+// ipVersion refers to IP version - v4 or v6
+type ipVersion string
+
+const (
+	// IPv4 is version 4
+	ipv4 ipVersion = "4"
+	// IPv4 is version 6
+	ipv6 ipVersion = "6"
+)
+
 // proxyCommand wraps an exec.Cmd to run the userland TCP and UDP
 // proxies as separate processes.
 type proxyCommand struct {
@@ -77,21 +87,27 @@ func (p *proxyCommand) Stop() error {
 // port allocations on bound port, because without userland proxy we using
 // iptables rules and not net.Listen
 type dummyProxy struct {
-	listener io.Closer
-	addr     net.Addr
+	listener  io.Closer
+	addr      net.Addr
+	ipVersion ipVersion
 }
 
 func newDummyProxy(proto string, hostIP net.IP, hostPort int) (userlandProxy, error) {
+	// detect version of hostIP to bind only to correct version
+	version := ipv4
+	if hostIP.To4() == nil {
+		version = ipv6
+	}
 	switch proto {
 	case "tcp":
 		addr := &net.TCPAddr{IP: hostIP, Port: hostPort}
-		return &dummyProxy{addr: addr}, nil
+		return &dummyProxy{addr: addr, ipVersion: version}, nil
 	case "udp":
 		addr := &net.UDPAddr{IP: hostIP, Port: hostPort}
-		return &dummyProxy{addr: addr}, nil
+		return &dummyProxy{addr: addr, ipVersion: version}, nil
 	case "sctp":
 		addr := &sctp.SCTPAddr{IPAddrs: []net.IPAddr{{IP: hostIP}}, Port: hostPort}
-		return &dummyProxy{addr: addr}, nil
+		return &dummyProxy{addr: addr, ipVersion: version}, nil
 	default:
 		return nil, fmt.Errorf("Unknown addr type: %s", proto)
 	}
@@ -100,19 +116,19 @@ func newDummyProxy(proto string, hostIP net.IP, hostPort int) (userlandProxy, er
 func (p *dummyProxy) Start() error {
 	switch addr := p.addr.(type) {
 	case *net.TCPAddr:
-		l, err := net.ListenTCP("tcp", addr)
+		l, err := net.ListenTCP("tcp"+string(p.ipVersion), addr)
 		if err != nil {
 			return err
 		}
 		p.listener = l
 	case *net.UDPAddr:
-		l, err := net.ListenUDP("udp", addr)
+		l, err := net.ListenUDP("udp"+string(p.ipVersion), addr)
 		if err != nil {
 			return err
 		}
 		p.listener = l
 	case *sctp.SCTPAddr:
-		l, err := sctp.ListenSCTP("sctp", addr)
+		l, err := sctp.ListenSCTP("sctp"+string(p.ipVersion), addr)
 		if err != nil {
 			return err
 		}