Переглянути джерело

Merge pull request #1177 from lopter/udp-support-final

* Network: Add UDP support
Victor Vieux 12 роки тому
батько
коміт
b7a62f1f1b
6 змінених файлів з 784 додано та 149 видалено
  1. 13 4
      container.go
  2. 130 91
      network.go
  3. 257 0
      network_proxy.go
  4. 221 0
      network_proxy_test.go
  5. 75 6
      network_test.go
  6. 88 48
      runtime_test.go

+ 13 - 4
container.go

@@ -202,20 +202,25 @@ func ParseRun(args []string, capabilities *Capabilities) (*Config, *HostConfig,
 	return config, hostConfig, cmd, nil
 }
 
+type portMapping map[string]string
+
 type NetworkSettings struct {
 	IPAddress   string
 	IPPrefixLen int
 	Gateway     string
 	Bridge      string
-	PortMapping map[string]string
+	PortMapping map[string]portMapping
 }
 
 // String returns a human-readable description of the port mapping defined in the settings
 func (settings *NetworkSettings) PortMappingHuman() string {
 	var mapping []string
-	for private, public := range settings.PortMapping {
+	for private, public := range settings.PortMapping["Tcp"] {
 		mapping = append(mapping, fmt.Sprintf("%s->%s", public, private))
 	}
+	for private, public := range settings.PortMapping["Udp"] {
+		mapping = append(mapping, fmt.Sprintf("%s->%s/udp", public, private))
+	}
 	sort.Strings(mapping)
 	return strings.Join(mapping, ", ")
 }
@@ -688,14 +693,18 @@ func (container *Container) allocateNetwork() error {
 	if err != nil {
 		return err
 	}
-	container.NetworkSettings.PortMapping = make(map[string]string)
+	container.NetworkSettings.PortMapping = make(map[string]portMapping)
+	container.NetworkSettings.PortMapping["Tcp"] = make(portMapping)
+	container.NetworkSettings.PortMapping["Udp"] = make(portMapping)
 	for _, spec := range container.Config.PortSpecs {
 		nat, err := iface.AllocatePort(spec)
 		if err != nil {
 			iface.Release()
 			return err
 		}
-		container.NetworkSettings.PortMapping[strconv.Itoa(nat.Backend)] = strconv.Itoa(nat.Frontend)
+		proto := strings.Title(nat.Proto)
+		backend, frontend := strconv.Itoa(nat.Backend), strconv.Itoa(nat.Frontend)
+		container.NetworkSettings.PortMapping[proto][backend] = frontend
 	}
 	container.network = iface
 	container.NetworkSettings.Bridge = container.runtime.networkManager.bridgeIface

+ 130 - 91
network.go

@@ -5,7 +5,6 @@ import (
 	"errors"
 	"fmt"
 	"github.com/dotcloud/docker/utils"
-	"io"
 	"log"
 	"net"
 	"os/exec"
@@ -183,8 +182,10 @@ func getIfaceAddr(name string) (net.Addr, error) {
 // up iptables rules.
 // It keeps track of all mappings and is able to unmap at will
 type PortMapper struct {
-	mapping map[int]net.TCPAddr
-	proxies map[int]net.Listener
+	tcpMapping map[int]*net.TCPAddr
+	tcpProxies map[int]Proxy
+	udpMapping map[int]*net.UDPAddr
+	udpProxies map[int]Proxy
 }
 
 func (mapper *PortMapper) cleanup() error {
@@ -197,8 +198,10 @@ func (mapper *PortMapper) cleanup() error {
 	iptables("-t", "nat", "-D", "OUTPUT", "-j", "DOCKER")
 	iptables("-t", "nat", "-F", "DOCKER")
 	iptables("-t", "nat", "-X", "DOCKER")
-	mapper.mapping = make(map[int]net.TCPAddr)
-	mapper.proxies = make(map[int]net.Listener)
+	mapper.tcpMapping = make(map[int]*net.TCPAddr)
+	mapper.tcpProxies = make(map[int]Proxy)
+	mapper.udpMapping = make(map[int]*net.UDPAddr)
+	mapper.udpProxies = make(map[int]Proxy)
 	return nil
 }
 
@@ -215,76 +218,72 @@ func (mapper *PortMapper) setup() error {
 	return nil
 }
 
-func (mapper *PortMapper) iptablesForward(rule string, port int, dest net.TCPAddr) error {
-	return iptables("-t", "nat", rule, "DOCKER", "-p", "tcp", "--dport", strconv.Itoa(port),
-		"-j", "DNAT", "--to-destination", net.JoinHostPort(dest.IP.String(), strconv.Itoa(dest.Port)))
+func (mapper *PortMapper) iptablesForward(rule string, port int, proto string, dest_addr string, dest_port int) error {
+	return iptables("-t", "nat", rule, "DOCKER", "-p", proto, "--dport", strconv.Itoa(port),
+		"-j", "DNAT", "--to-destination", net.JoinHostPort(dest_addr, strconv.Itoa(dest_port)))
 }
 
-func (mapper *PortMapper) Map(port int, dest net.TCPAddr) error {
-	if err := mapper.iptablesForward("-A", port, dest); err != nil {
-		return err
-	}
-
-	mapper.mapping[port] = dest
-	listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port))
-	if err != nil {
-		mapper.Unmap(port)
-		return err
-	}
-	mapper.proxies[port] = listener
-	go proxy(listener, "tcp", dest.String())
-	return nil
-}
-
-// proxy listens for socket connections on `listener`, and forwards them unmodified
-// to `proto:address`
-func proxy(listener net.Listener, proto, address string) error {
-	utils.Debugf("proxying to %s:%s", proto, address)
-	defer utils.Debugf("Done proxying to %s:%s", proto, address)
-	for {
-		utils.Debugf("Listening on %s", listener)
-		src, err := listener.Accept()
+func (mapper *PortMapper) Map(port int, backendAddr net.Addr) error {
+	if _, isTCP := backendAddr.(*net.TCPAddr); isTCP {
+		backendPort := backendAddr.(*net.TCPAddr).Port
+		backendIP := backendAddr.(*net.TCPAddr).IP
+		if err := mapper.iptablesForward("-A", port, "tcp", backendIP.String(), backendPort); err != nil {
+			return err
+		}
+		mapper.tcpMapping[port] = backendAddr.(*net.TCPAddr)
+		proxy, err := NewProxy(&net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: port}, backendAddr)
 		if err != nil {
+			mapper.Unmap(port, "tcp")
 			return err
 		}
-		utils.Debugf("Connecting to %s:%s", proto, address)
-		dst, err := net.Dial(proto, address)
+		mapper.tcpProxies[port] = proxy
+		go proxy.Run()
+	} else {
+		backendPort := backendAddr.(*net.UDPAddr).Port
+		backendIP := backendAddr.(*net.UDPAddr).IP
+		if err := mapper.iptablesForward("-A", port, "udp", backendIP.String(), backendPort); err != nil {
+			return err
+		}
+		mapper.udpMapping[port] = backendAddr.(*net.UDPAddr)
+		proxy, err := NewProxy(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: port}, backendAddr)
 		if err != nil {
-			log.Printf("Error connecting to %s:%s: %s", proto, address, err)
-			src.Close()
-			continue
+			mapper.Unmap(port, "udp")
+			return err
 		}
-		utils.Debugf("Connected to backend, splicing")
-		splice(src, dst)
+		mapper.udpProxies[port] = proxy
+		go proxy.Run()
 	}
+	return nil
 }
 
-func halfSplice(dst, src net.Conn) error {
-	_, err := io.Copy(dst, src)
-	// FIXME: on EOF from a tcp connection, pass WriteClose()
-	dst.Close()
-	src.Close()
-	return err
-}
-
-func splice(a, b net.Conn) {
-	go halfSplice(a, b)
-	go halfSplice(b, a)
-}
-
-func (mapper *PortMapper) Unmap(port int) error {
-	dest, ok := mapper.mapping[port]
-	if !ok {
-		return errors.New("Port is not mapped")
-	}
-	if proxy, exists := mapper.proxies[port]; exists {
-		proxy.Close()
-		delete(mapper.proxies, port)
-	}
-	if err := mapper.iptablesForward("-D", port, dest); err != nil {
-		return err
+func (mapper *PortMapper) Unmap(port int, proto string) error {
+	if proto == "tcp" {
+		backendAddr, ok := mapper.tcpMapping[port]
+		if !ok {
+			return fmt.Errorf("Port tcp/%v is not mapped", port)
+		}
+		if proxy, exists := mapper.tcpProxies[port]; exists {
+			proxy.Close()
+			delete(mapper.tcpProxies, port)
+		}
+		if err := mapper.iptablesForward("-D", port, proto, backendAddr.IP.String(), backendAddr.Port); err != nil {
+			return err
+		}
+		delete(mapper.tcpMapping, port)
+	} else {
+		backendAddr, ok := mapper.udpMapping[port]
+		if !ok {
+			return fmt.Errorf("Port udp/%v is not mapped", port)
+		}
+		if proxy, exists := mapper.udpProxies[port]; exists {
+			proxy.Close()
+			delete(mapper.udpProxies, port)
+		}
+		if err := mapper.iptablesForward("-D", port, proto, backendAddr.IP.String(), backendAddr.Port); err != nil {
+			return err
+		}
+		delete(mapper.udpMapping, port)
 	}
-	delete(mapper.mapping, port)
 	return nil
 }
 
@@ -453,7 +452,7 @@ type NetworkInterface struct {
 	Gateway net.IP
 
 	manager  *NetworkManager
-	extPorts []int
+	extPorts []*Nat
 }
 
 // Allocate an external TCP port and map it to the interface
@@ -462,17 +461,32 @@ func (iface *NetworkInterface) AllocatePort(spec string) (*Nat, error) {
 	if err != nil {
 		return nil, err
 	}
-	// Allocate a random port if Frontend==0
-	extPort, err := iface.manager.portAllocator.Acquire(nat.Frontend)
-	if err != nil {
-		return nil, err
-	}
-	nat.Frontend = extPort
-	if err := iface.manager.portMapper.Map(nat.Frontend, net.TCPAddr{IP: iface.IPNet.IP, Port: nat.Backend}); err != nil {
-		iface.manager.portAllocator.Release(nat.Frontend)
-		return nil, err
+
+	if nat.Proto == "tcp" {
+		extPort, err := iface.manager.tcpPortAllocator.Acquire(nat.Frontend)
+		if err != nil {
+			return nil, err
+		}
+		backend := &net.TCPAddr{IP: iface.IPNet.IP, Port: nat.Backend}
+		if err := iface.manager.portMapper.Map(extPort, backend); err != nil {
+			iface.manager.tcpPortAllocator.Release(extPort)
+			return nil, err
+		}
+		nat.Frontend = extPort
+	} else {
+		extPort, err := iface.manager.udpPortAllocator.Acquire(nat.Frontend)
+		if err != nil {
+			return nil, err
+		}
+		backend := &net.UDPAddr{IP: iface.IPNet.IP, Port: nat.Backend}
+		if err := iface.manager.portMapper.Map(extPort, backend); err != nil {
+			iface.manager.udpPortAllocator.Release(extPort)
+			return nil, err
+		}
+		nat.Frontend = extPort
 	}
-	iface.extPorts = append(iface.extPorts, nat.Frontend)
+	iface.extPorts = append(iface.extPorts, nat)
+
 	return nat, nil
 }
 
@@ -485,6 +499,21 @@ type Nat struct {
 func parseNat(spec string) (*Nat, error) {
 	var nat Nat
 
+	if strings.Contains(spec, "/") {
+		specParts := strings.Split(spec, "/")
+		if len(specParts) != 2 {
+			return nil, fmt.Errorf("Invalid port format.")
+		}
+		proto := specParts[1]
+		spec = specParts[0]
+		if proto != "tcp" && proto != "udp" {
+			return nil, fmt.Errorf("Invalid port format: unknown protocol %v.", proto)
+		}
+		nat.Proto = proto
+	} else {
+		nat.Proto = "tcp"
+	}
+
 	if strings.Contains(spec, ":") {
 		specParts := strings.Split(spec, ":")
 		if len(specParts) != 2 {
@@ -517,20 +546,24 @@ func parseNat(spec string) (*Nat, error) {
 		}
 		nat.Backend = int(port)
 	}
-	nat.Proto = "tcp"
+
 	return &nat, nil
 }
 
 // Release: Network cleanup - release all resources
 func (iface *NetworkInterface) Release() {
-	for _, port := range iface.extPorts {
-		if err := iface.manager.portMapper.Unmap(port); err != nil {
-			log.Printf("Unable to unmap port %v: %v", port, err)
+	for _, nat := range iface.extPorts {
+		utils.Debugf("Unmaping %v/%v", nat.Proto, nat.Frontend)
+		if err := iface.manager.portMapper.Unmap(nat.Frontend, nat.Proto); err != nil {
+			log.Printf("Unable to unmap port %v/%v: %v", nat.Proto, nat.Frontend, err)
 		}
-		if err := iface.manager.portAllocator.Release(port); err != nil {
-			log.Printf("Unable to release port %v: %v", port, err)
+		if nat.Proto == "tcp" {
+			if err := iface.manager.tcpPortAllocator.Release(nat.Frontend); err != nil {
+				log.Printf("Unable to release port tcp/%v: %v", nat.Frontend, err)
+			}
+		} else if err := iface.manager.udpPortAllocator.Release(nat.Frontend); err != nil {
+			log.Printf("Unable to release port udp/%v: %v", nat.Frontend, err)
 		}
-
 	}
 
 	iface.manager.ipAllocator.Release(iface.IPNet.IP)
@@ -542,9 +575,10 @@ type NetworkManager struct {
 	bridgeIface   string
 	bridgeNetwork *net.IPNet
 
-	ipAllocator   *IPAllocator
-	portAllocator *PortAllocator
-	portMapper    *PortMapper
+	ipAllocator      *IPAllocator
+	tcpPortAllocator *PortAllocator
+	udpPortAllocator *PortAllocator
+	portMapper       *PortMapper
 }
 
 // Allocate a network interface
@@ -577,7 +611,11 @@ func newNetworkManager(bridgeIface string) (*NetworkManager, error) {
 
 	ipAllocator := newIPAllocator(network)
 
-	portAllocator, err := newPortAllocator()
+	tcpPortAllocator, err := newPortAllocator()
+	if err != nil {
+		return nil, err
+	}
+	udpPortAllocator, err := newPortAllocator()
 	if err != nil {
 		return nil, err
 	}
@@ -588,11 +626,12 @@ func newNetworkManager(bridgeIface string) (*NetworkManager, error) {
 	}
 
 	manager := &NetworkManager{
-		bridgeIface:   bridgeIface,
-		bridgeNetwork: network,
-		ipAllocator:   ipAllocator,
-		portAllocator: portAllocator,
-		portMapper:    portMapper,
+		bridgeIface:      bridgeIface,
+		bridgeNetwork:    network,
+		ipAllocator:      ipAllocator,
+		tcpPortAllocator: tcpPortAllocator,
+		udpPortAllocator: udpPortAllocator,
+		portMapper:       portMapper,
 	}
 	return manager, nil
 }

+ 257 - 0
network_proxy.go

@@ -0,0 +1,257 @@
+package docker
+
+import (
+	"encoding/binary"
+	"fmt"
+	"github.com/dotcloud/docker/utils"
+	"io"
+	"log"
+	"net"
+	"sync"
+	"syscall"
+	"time"
+)
+
+const (
+	UDPConnTrackTimeout = 90 * time.Second
+	UDPBufSize          = 2048
+)
+
+// Proxy forwards traffic between a frontend address it listens on and a
+// backend address. TCPProxy and UDPProxy implement it.
+type Proxy interface {
+	// Run starts forwarding traffic back and forth between the front-end
+	// and back-end addresses.
+	Run()
+	// Close stops forwarding traffic and closes both ends of the Proxy.
+	Close()
+	// FrontendAddr returns the address on which the proxy is listening.
+	FrontendAddr() net.Addr
+	// BackendAddr returns the proxied address.
+	BackendAddr() net.Addr
+}
+
+// TCPProxy is the Proxy implementation for TCP: it accepts connections on
+// listener and relays each of them to backendAddr.
+type TCPProxy struct {
+	listener     *net.TCPListener // listening socket bound to the frontend address
+	frontendAddr *net.TCPAddr     // address actually bound by listener
+	backendAddr  *net.TCPAddr     // address every accepted connection is forwarded to
+}
+
+// NewTCPProxy starts listening on frontendAddr and returns a TCPProxy that
+// forwards to backendAddr. An error from net.ListenTCP is returned unchanged.
+func NewTCPProxy(frontendAddr, backendAddr *net.TCPAddr) (*TCPProxy, error) {
+	ln, listenErr := net.ListenTCP("tcp", frontendAddr)
+	if listenErr != nil {
+		return nil, listenErr
+	}
+	proxy := &TCPProxy{listener: ln, backendAddr: backendAddr}
+	// When frontendAddr asks for port 0, ListenTCP picks a port itself, so
+	// the effective address has to be read back from the listener:
+	proxy.frontendAddr = ln.Addr().(*net.TCPAddr)
+	return proxy, nil
+}
+
+// clientLoop relays traffic between an accepted client connection and a
+// freshly dialed connection to the backend. It returns once both copy
+// directions have finished, or once quit is signalled (Run closes quit when
+// it returns).
+func (proxy *TCPProxy) clientLoop(client *net.TCPConn, quit chan bool) {
+	backend, err := net.DialTCP("tcp", nil, proxy.backendAddr)
+	if err != nil {
+		log.Printf("Can't forward traffic to backend tcp/%v: %v\n", proxy.backendAddr, err.Error())
+		client.Close()
+		return
+	}
+
+	// Each broker reports the number of bytes it copied on this channel
+	// exactly once, when its io.Copy returns.
+	event := make(chan int64)
+	var broker = func(to, from *net.TCPConn) {
+		written, err := io.Copy(to, from)
+		if err != nil {
+			err, ok := err.(*net.OpError)
+			// If the socket we are writing to is shutdown with
+			// SHUT_WR, forward it to the other end of the pipe:
+			if ok && err.Err == syscall.EPIPE {
+				from.CloseWrite()
+			}
+		}
+		event <- written
+	}
+	utils.Debugf("Forwarding traffic between tcp/%v and tcp/%v", client.RemoteAddr(), backend.RemoteAddr())
+	// One broker per direction.
+	go broker(client, backend)
+	go broker(backend, client)
+
+	// Wait for both brokers to report, or for quit; i counts the brokers
+	// that have been joined so far.
+	var transferred int64 = 0
+	for i := 0; i < 2; i++ {
+		select {
+		case written := <-event:
+			transferred += written
+		case <-quit:
+			// Interrupt the two brokers and "join" them.
+			client.Close()
+			backend.Close()
+			// Collect the reports of the brokers not joined yet, so
+			// neither stays blocked on its send to event.
+			for ; i < 2; i++ {
+				transferred += <-event
+			}
+			goto done
+		}
+	}
+	// Both directions finished on their own: release the sockets.
+	client.Close()
+	backend.Close()
+done:
+	utils.Debugf("%v bytes transferred between tcp/%v and tcp/%v", transferred, client.RemoteAddr(), backend.RemoteAddr())
+}
+
+// Run accepts connections on the frontend listener and hands each of them to
+// its own clientLoop goroutine. It returns when Accept fails; the quit
+// channel shared with every clientLoop is closed on the way out.
+func (proxy *TCPProxy) Run() {
+	quit := make(chan bool)
+	defer close(quit)
+	utils.Debugf("Starting proxy on tcp/%v for tcp/%v", proxy.frontendAddr, proxy.backendAddr)
+	for {
+		conn, acceptErr := proxy.listener.Accept()
+		if acceptErr == nil {
+			go proxy.clientLoop(conn.(*net.TCPConn), quit)
+			continue
+		}
+		utils.Debugf("Stopping proxy on tcp/%v for tcp/%v (%v)", proxy.frontendAddr, proxy.backendAddr, acceptErr.Error())
+		return
+	}
+}
+
+// Close closes the frontend listener; FrontendAddr and BackendAddr return the
+// addresses stored on the proxy by NewTCPProxy.
+func (proxy *TCPProxy) Close()                 { proxy.listener.Close() }
+func (proxy *TCPProxy) FrontendAddr() net.Addr { return proxy.frontendAddr }
+func (proxy *TCPProxy) BackendAddr() net.Addr  { return proxy.backendAddr }
+
+// connTrackKey is a comparable stand-in for a net.UDPAddr: the IP is split
+// into two integer fields so that the struct can be used as a key in a map.
+type connTrackKey struct {
+	IPHigh uint64 // first 8 bytes of a 16-byte IP; 0 for a 4-byte IP
+	IPLow  uint64 // last 8 bytes of a 16-byte IP, or the whole 4-byte IP
+	Port   int
+}
+
+// newConnTrackKey builds the map key for addr. A 4-byte IP is stored in IPLow
+// with IPHigh left at zero; any other IP is read as 16 bytes and split into
+// two big-endian halves.
+func newConnTrackKey(addr *net.UDPAddr) *connTrackKey {
+	key := &connTrackKey{Port: addr.Port}
+	if len(addr.IP) != net.IPv4len {
+		key.IPHigh = binary.BigEndian.Uint64(addr.IP[:8])
+		key.IPLow = binary.BigEndian.Uint64(addr.IP[8:])
+		return key
+	}
+	key.IPLow = uint64(binary.BigEndian.Uint32(addr.IP))
+	return key
+}
+
+// connTrackMap associates a client key with the UDP connection to the
+// backend that is used on behalf of that client.
+type connTrackMap map[connTrackKey]*net.UDPConn
+
+// UDPProxy is the Proxy implementation for UDP. It keeps one connection to
+// the backend per client address so that replies can be routed back.
+type UDPProxy struct {
+	listener       *net.UDPConn // frontend socket; also used to send replies to clients
+	frontendAddr   *net.UDPAddr // local address of listener
+	backendAddr    *net.UDPAddr // address datagrams are forwarded to
+	connTrackTable connTrackMap // per-client connections to the backend
+	connTrackLock  sync.Mutex   // guards connTrackTable
+}
+
+// NewUDPProxy binds a UDP socket on frontendAddr and returns a UDPProxy that
+// forwards to backendAddr, starting with an empty connection-tracking table.
+// An error from net.ListenUDP is returned unchanged.
+func NewUDPProxy(frontendAddr, backendAddr *net.UDPAddr) (*UDPProxy, error) {
+	conn, listenErr := net.ListenUDP("udp", frontendAddr)
+	if listenErr != nil {
+		return nil, listenErr
+	}
+	proxy := &UDPProxy{
+		listener:       conn,
+		backendAddr:    backendAddr,
+		connTrackTable: make(connTrackMap),
+	}
+	// Read the bound address back from the socket rather than trusting
+	// frontendAddr.
+	proxy.frontendAddr = conn.LocalAddr().(*net.UDPAddr)
+	return proxy, nil
+}
+
+// replyLoop reads datagrams coming back from the backend on proxyConn and
+// sends them to clientAddr through the frontend socket. It ends when a read
+// fails (including when the read deadline expires) or a write to the client
+// fails; on exit the client's entry is removed from the connection-tracking
+// table and proxyConn is closed.
+func (proxy *UDPProxy) replyLoop(proxyConn *net.UDPConn, clientAddr *net.UDPAddr, clientKey *connTrackKey) {
+	defer func() {
+		proxy.connTrackLock.Lock()
+		delete(proxy.connTrackTable, *clientKey)
+		proxy.connTrackLock.Unlock()
+		utils.Debugf("Done proxying between udp/%v and udp/%v", clientAddr.String(), proxy.backendAddr.String())
+		proxyConn.Close()
+	}()
+
+	readBuf := make([]byte, UDPBufSize)
+	for {
+		// The deadline is refreshed only here, i.e. after a datagram was
+		// relayed; the "goto again" retry below keeps the current one.
+		proxyConn.SetReadDeadline(time.Now().Add(UDPConnTrackTimeout))
+	again:
+		read, err := proxyConn.Read(readBuf)
+		if err != nil {
+			if err, ok := err.(*net.OpError); ok && err.Err == syscall.ECONNREFUSED {
+				// This will happen if the last write failed
+				// (e.g: nothing is actually listening on the
+				// proxied port on the container), ignore it
+				// and continue until UDPConnTrackTimeout
+				// expires:
+				goto again
+			}
+			return
+		}
+		// Keep writing until the whole datagram has been handed over.
+		for i := 0; i != read; {
+			written, err := proxy.listener.WriteToUDP(readBuf[i:read], clientAddr)
+			if err != nil {
+				return
+			}
+			i += written
+			utils.Debugf("Forwarded %v/%v bytes to udp/%v", i, read, clientAddr.String())
+		}
+	}
+}
+
+// Run reads datagrams from the frontend socket and forwards each of them to
+// the backend. The first datagram from a given client address dials a new
+// connection to the backend, records it in the connection-tracking table and
+// starts a replyLoop goroutine that relays the answers back to that client.
+// Run returns when reading from the frontend socket fails.
+func (proxy *UDPProxy) Run() {
+	readBuf := make([]byte, UDPBufSize)
+	utils.Debugf("Starting proxy on udp/%v for udp/%v", proxy.frontendAddr, proxy.backendAddr)
+	for {
+		read, from, err := proxy.listener.ReadFromUDP(readBuf)
+		if err != nil {
+			// NOTE: Apparently ReadFrom doesn't return
+			// ECONNREFUSED like Read does (see comment in
+			// UDPProxy.replyLoop)
+			utils.Debugf("Stopping proxy on udp/%v for udp/%v (%v)", proxy.frontendAddr, proxy.backendAddr, err.Error())
+			break
+		}
+
+		fromKey := newConnTrackKey(from)
+		proxy.connTrackLock.Lock()
+		proxyConn, hit := proxy.connTrackTable[*fromKey]
+		if !hit {
+			proxyConn, err = net.DialUDP("udp", nil, proxy.backendAddr)
+			if err != nil {
+				// Release the lock before dropping this datagram:
+				// skipping the Unlock below would make the next
+				// iteration (and replyLoop/Close) block forever.
+				proxy.connTrackLock.Unlock()
+				log.Printf("Can't proxy a datagram to udp/%s: %v\n", proxy.backendAddr.String(), err)
+				continue
+			}
+			proxy.connTrackTable[*fromKey] = proxyConn
+			go proxy.replyLoop(proxyConn, from, fromKey)
+		}
+		proxy.connTrackLock.Unlock()
+		// Keep writing until the whole datagram has been handed over; a
+		// write error only drops this datagram, the proxy keeps running.
+		for i := 0; i != read; {
+			written, err := proxyConn.Write(readBuf[i:read])
+			if err != nil {
+				log.Printf("Can't proxy a datagram to udp/%s: %v\n", proxy.backendAddr.String(), err)
+				break
+			}
+			i += written
+			utils.Debugf("Forwarded %v/%v bytes to udp/%v", i, read, proxy.backendAddr.String())
+		}
+	}
+}
+
+// Close closes the frontend socket and then, while holding the table lock,
+// every backend connection currently recorded in the connection-tracking
+// table.
+func (proxy *UDPProxy) Close() {
+	proxy.listener.Close()
+	proxy.connTrackLock.Lock()
+	for key := range proxy.connTrackTable {
+		proxy.connTrackTable[key].Close()
+	}
+	proxy.connTrackLock.Unlock()
+}
+
+// FrontendAddr and BackendAddr return the addresses stored on the proxy by
+// NewUDPProxy.
+func (proxy *UDPProxy) FrontendAddr() net.Addr { return proxy.frontendAddr }
+func (proxy *UDPProxy) BackendAddr() net.Addr  { return proxy.backendAddr }
+
+// NewProxy creates the proxy matching the concrete type of frontendAddr: a
+// UDPProxy for a *net.UDPAddr, a TCPProxy for a *net.TCPAddr. backendAddr
+// must be of the same concrete type as frontendAddr (the type assertion
+// panics otherwise), and any other frontend type panics with "Unsupported
+// protocol".
+//
+// On failure the returned Proxy is a literal nil: returning the constructors'
+// results directly would wrap a nil *UDPProxy or *TCPProxy in a non-nil
+// interface value.
+func NewProxy(frontendAddr, backendAddr net.Addr) (Proxy, error) {
+	switch frontend := frontendAddr.(type) {
+	case *net.UDPAddr:
+		proxy, err := NewUDPProxy(frontend, backendAddr.(*net.UDPAddr))
+		if err != nil {
+			return nil, err
+		}
+		return proxy, nil
+	case *net.TCPAddr:
+		proxy, err := NewTCPProxy(frontend, backendAddr.(*net.TCPAddr))
+		if err != nil {
+			return nil, err
+		}
+		return proxy, nil
+	default:
+		panic(fmt.Errorf("Unsupported protocol"))
+	}
+}

+ 221 - 0
network_proxy_test.go

@@ -0,0 +1,221 @@
+package docker
+
+import (
+	"bytes"
+	"fmt"
+	"io"
+	"net"
+	"strings"
+	"testing"
+	"time"
+)
+
+var testBuf = []byte("Buffalo buffalo Buffalo buffalo buffalo buffalo Buffalo buffalo")
+var testBufSize = len(testBuf)
+
+// EchoServer is the test backend used behind the proxies; TCPEchoServer and
+// UDPEchoServer implement it.
+type EchoServer interface {
+	Run()
+	Close()
+	LocalAddr() net.Addr
+}
+
+// TCPEchoServer echoes back whatever it receives on accepted connections.
+type TCPEchoServer struct {
+	listener net.Listener
+	testCtx  *testing.T // used for logging from the server goroutines
+}
+
+// UDPEchoServer sends every datagram it receives back to its sender.
+type UDPEchoServer struct {
+	conn    net.PacketConn
+	testCtx *testing.T // used for logging from the server goroutine
+}
+
+// NewEchoServer binds an echo server on address. A proto starting with "tcp"
+// yields a TCPEchoServer, anything else is handed to net.ListenPacket and
+// yields a UDPEchoServer. A bind failure aborts the test through t.Fatal.
+func NewEchoServer(t *testing.T, proto, address string) EchoServer {
+	var server EchoServer
+	switch {
+	case strings.HasPrefix(proto, "tcp"):
+		ln, err := net.Listen(proto, address)
+		if err != nil {
+			t.Fatal(err)
+		}
+		server = &TCPEchoServer{listener: ln, testCtx: t}
+	default:
+		pc, err := net.ListenPacket(proto, address)
+		if err != nil {
+			t.Fatal(err)
+		}
+		server = &UDPEchoServer{conn: pc, testCtx: t}
+	}
+	t.Logf("EchoServer listening on %v/%v\n", proto, server.LocalAddr().String())
+	return server
+}
+
+// Run starts the accept loop in a goroutine and returns immediately. Each
+// accepted connection gets its own goroutine that copies the connection back
+// onto itself and then closes it. The accept loop ends when Accept fails.
+func (server *TCPEchoServer) Run() {
+	go func() {
+		for {
+			client, err := server.listener.Accept()
+			if err != nil {
+				return
+			}
+			go func(client net.Conn) {
+				server.testCtx.Logf("TCP client accepted on the EchoServer\n")
+				// Reading from and writing to the same connection is
+				// what makes this an echo.
+				written, err := io.Copy(client, client)
+				server.testCtx.Logf("%v bytes echoed back to the client\n", written)
+				if err != nil {
+					server.testCtx.Logf("can't echo to the client: %v\n", err.Error())
+				}
+				client.Close()
+			}(client)
+		}
+	}()
+}
+
+// LocalAddr returns the address the TCP listener is bound to.
+func (server *TCPEchoServer) LocalAddr() net.Addr { return server.listener.Addr() }
+
+// Close closes the listener, which makes the accept loop started by Run
+// return. (It previously called listener.Addr(), so nothing was closed.)
+func (server *TCPEchoServer) Close() { server.listener.Close() }
+
+// Run starts the echo loop in a goroutine and returns immediately. Every
+// datagram read from the socket is written back to its sender; the loop ends
+// when reading fails.
+func (server *UDPEchoServer) Run() {
+	go func() {
+		readBuf := make([]byte, 1024)
+		for {
+			read, from, err := server.conn.ReadFrom(readBuf)
+			if err != nil {
+				return
+			}
+			server.testCtx.Logf("Writing UDP datagram back")
+			// Keep writing until everything read has been sent; a
+			// write error only abandons this datagram.
+			for i := 0; i != read; {
+				written, err := server.conn.WriteTo(readBuf[i:read], from)
+				if err != nil {
+					break
+				}
+				i += written
+			}
+		}
+	}()
+}
+
+// LocalAddr returns the address the UDP socket is bound to; Close closes
+// that socket.
+func (server *UDPEchoServer) LocalAddr() net.Addr { return server.conn.LocalAddr() }
+func (server *UDPEchoServer) Close()              { server.conn.Close() }
+
+// testProxyAt runs proxy, connects to addr over proto, sends testBuf and
+// checks that exactly the same bytes come back within 10 seconds. The proxy
+// is closed when the function returns.
+func testProxyAt(t *testing.T, proto string, proxy Proxy, addr string) {
+	defer proxy.Close()
+	go proxy.Run()
+	client, err := net.Dial(proto, addr)
+	if err != nil {
+		t.Fatalf("Can't connect to the proxy: %v", err)
+	}
+	defer client.Close()
+	client.SetDeadline(time.Now().Add(10 * time.Second))
+	if _, err = client.Write(testBuf); err != nil {
+		t.Fatal(err)
+	}
+	recvBuf := make([]byte, testBufSize)
+	// A single Read on a TCP stream may legitimately return fewer bytes
+	// than were echoed; ReadFull keeps reading until recvBuf is filled
+	// (or the deadline fires).
+	if _, err = io.ReadFull(client, recvBuf); err != nil {
+		t.Fatal(err)
+	}
+	if !bytes.Equal(testBuf, recvBuf) {
+		t.Fatalf("Expected [%v] but got [%v]", testBuf, recvBuf)
+	}
+}
+
+// testProxy runs testProxyAt against the proxy's own frontend address.
+func testProxy(t *testing.T, proto string, proxy Proxy) {
+	testProxyAt(t, proto, proxy, proxy.FrontendAddr().String())
+}
+
+// TestTCP4Proxy proxies TCP between an IPv4 loopback frontend and an echo
+// server on IPv4 loopback; port 0 on both sides lets the system pick ports.
+func TestTCP4Proxy(t *testing.T) {
+	backend := NewEchoServer(t, "tcp", "127.0.0.1:0")
+	defer backend.Close()
+	backend.Run()
+	frontendAddr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}
+	proxy, err := NewProxy(frontendAddr, backend.LocalAddr())
+	if err != nil {
+		t.Fatal(err)
+	}
+	testProxy(t, "tcp", proxy)
+}
+
+// TestTCP6Proxy is the IPv6 loopback counterpart of the TCP proxy test.
+func TestTCP6Proxy(t *testing.T) {
+	backend := NewEchoServer(t, "tcp", "[::1]:0")
+	defer backend.Close()
+	backend.Run()
+	frontendAddr := &net.TCPAddr{IP: net.IPv6loopback, Port: 0}
+	proxy, err := NewProxy(frontendAddr, backend.LocalAddr())
+	if err != nil {
+		t.Fatal(err)
+	}
+	testProxy(t, "tcp", proxy)
+}
+
+// TestTCPDualStackProxy would connect over IPv4 to a proxy listening on the
+// IPv6 loopback address. It is currently skipped.
+func TestTCPDualStackProxy(t *testing.T) {
+	// If I understand `godoc -src net favoriteAddrFamily` (used by the
+	// net.Listen* functions) correctly this should work, but it doesn't.
+	t.Skip("No support for dual stack yet")
+	backend := NewEchoServer(t, "tcp", "[::1]:0")
+	defer backend.Close()
+	backend.Run()
+	frontendAddr := &net.TCPAddr{IP: net.IPv6loopback, Port: 0}
+	proxy, err := NewProxy(frontendAddr, backend.LocalAddr())
+	if err != nil {
+		t.Fatal(err)
+	}
+	// Same port as the proxy, but reached through the IPv4 loopback.
+	ipv4ProxyAddr := &net.TCPAddr{
+		IP:   net.IPv4(127, 0, 0, 1),
+		Port: proxy.FrontendAddr().(*net.TCPAddr).Port,
+	}
+	testProxyAt(t, "tcp", proxy, ipv4ProxyAddr.String())
+}
+
+// TestUDP4Proxy proxies UDP between an IPv4 loopback frontend and an echo
+// server on IPv4 loopback.
+func TestUDP4Proxy(t *testing.T) {
+	backend := NewEchoServer(t, "udp", "127.0.0.1:0")
+	defer backend.Close()
+	backend.Run()
+	frontendAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}
+	proxy, err := NewProxy(frontendAddr, backend.LocalAddr())
+	if err != nil {
+		t.Fatal(err)
+	}
+	testProxy(t, "udp", proxy)
+}
+
+// TestUDP6Proxy is the IPv6 loopback counterpart of the UDP proxy test.
+func TestUDP6Proxy(t *testing.T) {
+	backend := NewEchoServer(t, "udp", "[::1]:0")
+	defer backend.Close()
+	backend.Run()
+	frontendAddr := &net.UDPAddr{IP: net.IPv6loopback, Port: 0}
+	proxy, err := NewProxy(frontendAddr, backend.LocalAddr())
+	if err != nil {
+		t.Fatal(err)
+	}
+	testProxy(t, "udp", proxy)
+}
+
+// TestUDPWriteError checks that the UDP proxy keeps running when datagrams
+// are sent while nothing listens on the backend port, and that it forwards
+// traffic normally once a backend shows up.
+func TestUDPWriteError(t *testing.T) {
+	frontendAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}
+	// Hopefully, this port will be free:
+	backendAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 25587}
+	proxy, err := NewProxy(frontendAddr, backendAddr)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer proxy.Close()
+	go proxy.Run()
+	// Talk to the proxy's frontend; dialing the backend port directly
+	// would bypass the proxy and leave it untested.
+	client, err := net.Dial("udp", proxy.FrontendAddr().String())
+	if err != nil {
+		t.Fatalf("Can't connect to the proxy: %v", err)
+	}
+	defer client.Close()
+	// Make sure the proxy doesn't stop when there is no actual backend:
+	client.Write(testBuf)
+	client.Write(testBuf)
+	backend := NewEchoServer(t, "udp", "127.0.0.1:25587")
+	defer backend.Close()
+	backend.Run()
+	client.SetDeadline(time.Now().Add(10 * time.Second))
+	if _, err = client.Write(testBuf); err != nil {
+		t.Fatal(err)
+	}
+	recvBuf := make([]byte, testBufSize)
+	if _, err = client.Read(recvBuf); err != nil {
+		t.Fatal(err)
+	}
+	if !bytes.Equal(testBuf, recvBuf) {
+		t.Fatalf("Expected [%v] but got [%v]", testBuf, recvBuf)
+	}
+}

+ 75 - 6
network_test.go

@@ -20,28 +20,97 @@ func TestIptables(t *testing.T) {
 
 func TestParseNat(t *testing.T) {
 	if nat, err := parseNat("4500"); err == nil {
-		if nat.Frontend != 0 || nat.Backend != 4500 {
-			t.Errorf("-p 4500 should produce 0->4500, got %d->%d", nat.Frontend, nat.Backend)
+		if nat.Frontend != 0 || nat.Backend != 4500 || nat.Proto != "tcp" {
+			t.Errorf("-p 4500 should produce 0->4500/tcp, got %d->%d/%s",
+				nat.Frontend, nat.Backend, nat.Proto)
 		}
 	} else {
 		t.Fatal(err)
 	}
 
 	if nat, err := parseNat(":4501"); err == nil {
-		if nat.Frontend != 4501 || nat.Backend != 4501 {
-			t.Errorf("-p :4501 should produce 4501->4501, got %d->%d", nat.Frontend, nat.Backend)
+		if nat.Frontend != 4501 || nat.Backend != 4501 || nat.Proto != "tcp" {
+			t.Errorf("-p :4501 should produce 4501->4501/tcp, got %d->%d/%s",
+				nat.Frontend, nat.Backend, nat.Proto)
 		}
 	} else {
 		t.Fatal(err)
 	}
 
 	if nat, err := parseNat("4502:4503"); err == nil {
-		if nat.Frontend != 4502 || nat.Backend != 4503 {
-			t.Errorf("-p 4502:4503 should produce 4502->4503, got %d->%d", nat.Frontend, nat.Backend)
+		if nat.Frontend != 4502 || nat.Backend != 4503 || nat.Proto != "tcp" {
+			t.Errorf("-p 4502:4503 should produce 4502->4503/tcp, got %d->%d/%s",
+				nat.Frontend, nat.Backend, nat.Proto)
 		}
 	} else {
 		t.Fatal(err)
 	}
+
+	if nat, err := parseNat("4502:4503/tcp"); err == nil {
+		if nat.Frontend != 4502 || nat.Backend != 4503 || nat.Proto != "tcp" {
+			t.Errorf("-p 4502:4503/tcp should produce 4502->4503/tcp, got %d->%d/%s",
+				nat.Frontend, nat.Backend, nat.Proto)
+		}
+	} else {
+		t.Fatal(err)
+	}
+
+	if nat, err := parseNat("4502:4503/udp"); err == nil {
+		if nat.Frontend != 4502 || nat.Backend != 4503 || nat.Proto != "udp" {
+			t.Errorf("-p 4502:4503/udp should produce 4502->4503/udp, got %d->%d/%s",
+				nat.Frontend, nat.Backend, nat.Proto)
+		}
+	} else {
+		t.Fatal(err)
+	}
+
+	if nat, err := parseNat(":4503/udp"); err == nil {
+		if nat.Frontend != 4503 || nat.Backend != 4503 || nat.Proto != "udp" {
+			t.Errorf("-p :4503/udp should produce 4503->4503/udp, got %d->%d/%s",
+				nat.Frontend, nat.Backend, nat.Proto)
+		}
+	} else {
+		t.Fatal(err)
+	}
+
+	if nat, err := parseNat(":4503/tcp"); err == nil {
+		if nat.Frontend != 4503 || nat.Backend != 4503 || nat.Proto != "tcp" {
+			t.Errorf("-p :4503/tcp should produce 4503->4503/tcp, got %d->%d/%s",
+				nat.Frontend, nat.Backend, nat.Proto)
+		}
+	} else {
+		t.Fatal(err)
+	}
+
+	if nat, err := parseNat("4503/tcp"); err == nil {
+		if nat.Frontend != 0 || nat.Backend != 4503 || nat.Proto != "tcp" {
+			t.Errorf("-p 4503/tcp should produce 0->4503/tcp, got %d->%d/%s",
+				nat.Frontend, nat.Backend, nat.Proto)
+		}
+	} else {
+		t.Fatal(err)
+	}
+
+	if nat, err := parseNat("4503/udp"); err == nil {
+		if nat.Frontend != 0 || nat.Backend != 4503 || nat.Proto != "udp" {
+			t.Errorf("-p 4503/udp should produce 0->4503/udp, got %d->%d/%s",
+				nat.Frontend, nat.Backend, nat.Proto)
+		}
+	} else {
+		t.Fatal(err)
+	}
+
+	if _, err := parseNat("4503/tcpgarbage"); err == nil {
+		t.Fatal(err)
+	}
+
+	if _, err := parseNat("4503/tcp/udp"); err == nil {
+		t.Fatal(err)
+	}
+
+	if _, err := parseNat("4503/"); err == nil {
+		t.Fatal(err)
+	}
 }
 
 func TestPortAllocation(t *testing.T) {

+ 88 - 48
runtime_test.go

@@ -1,6 +1,7 @@
 package docker
 
 import (
+	"bytes"
 	"fmt"
 	"github.com/dotcloud/docker/utils"
 	"io"
@@ -17,12 +18,12 @@ import (
 )
 
 const (
-	unitTestImageName     = "docker-unit-tests"
-	unitTestImageID       = "e9aa60c60128cad1"
-	unitTestNetworkBridge = "testdockbr0"
-	unitTestStoreBase     = "/var/lib/docker/unit-tests"
-	testDaemonAddr        = "127.0.0.1:4270"
-	testDaemonProto       = "tcp"
+	unitTestImageName	= "docker-test-image"
+	unitTestImageID		= "83599e29c455eb719f77d799bc7c51521b9551972f5a850d7ad265bc1b5292f6" // 1.0
+	unitTestNetworkBridge	= "testdockbr0"
+	unitTestStoreBase	= "/var/lib/docker/unit-tests"
+	testDaemonAddr		= "127.0.0.1:4270"
+	testDaemonProto		= "tcp"
 )
 
 var globalRuntime *Runtime
@@ -321,52 +322,47 @@ func TestGet(t *testing.T) {
 
 }
 
-func findAvailalblePort(runtime *Runtime, port int) (*Container, error) {
-	strPort := strconv.Itoa(port)
-	container, err := NewBuilder(runtime).Create(&Config{
-		Image:     GetTestImage(runtime).ID,
-		Cmd:       []string{"sh", "-c", "echo well hello there | nc -l -p " + strPort},
-		PortSpecs: []string{strPort},
-	},
-	)
-	if err != nil {
-		return nil, err
-	}
-	hostConfig := &HostConfig{}
-	if err := container.Start(hostConfig); err != nil {
-		if strings.Contains(err.Error(), "address already in use") {
-			return nil, nil
-		}
-		return nil, err
-	}
-	return container, nil
-}
-
-// Run a container with a TCP port allocated, and test that it can receive connections on localhost
-func TestAllocatePortLocalhost(t *testing.T) {
+func startEchoServerContainer(t *testing.T, proto string) (*Runtime, *Container, string) {
 	runtime, err := newTestRuntime()
 	if err != nil {
 		t.Fatal(err)
 	}
-	port := 5554
 
+	port := 5554
 	var container *Container
+	var strPort string
 	for {
 		port += 1
-		log.Println("Trying port", port)
-		t.Log("Trying port", port)
-		container, err = findAvailalblePort(runtime, port)
+		strPort = strconv.Itoa(port)
+		var cmd string
+		if proto == "tcp" {
+			cmd = "socat TCP-LISTEN:" + strPort + ",reuseaddr,fork EXEC:/bin/cat"
+		} else if proto == "udp" {
+			cmd = "socat UDP-RECVFROM:" + strPort + ",fork EXEC:/bin/cat"
+		} else {
+			t.Fatal(fmt.Errorf("Unknown protocol %v", proto))
+		}
+		t.Log("Trying port", strPort)
+		container, err = NewBuilder(runtime).Create(&Config{
+			Image:     GetTestImage(runtime).ID,
+			Cmd:       []string{"sh", "-c", cmd},
+			PortSpecs: []string{fmt.Sprintf("%s/%s", strPort, proto)},
+		})
 		if container != nil {
 			break
 		}
 		if err != nil {
+			nuke(runtime)
 			t.Fatal(err)
 		}
-		log.Println("Port", port, "already in use")
-		t.Log("Port", port, "already in use")
+		t.Logf("Port %v already in use", strPort)
 	}
 
-	defer container.Kill()
+	hostConfig := &HostConfig{}
+	if err := container.Start(hostConfig); err != nil {
+		nuke(runtime)
+		t.Fatal(err)
+	}
 
 	setTimeout(t, "Waiting for the container to be started timed out", 2*time.Second, func() {
 		for !container.State.Running {
@@ -377,26 +373,70 @@ func TestAllocatePortLocalhost(t *testing.T) {
 	// Even if the state is running, lets give some time to lxc to spawn the process
 	container.WaitTimeout(500 * time.Millisecond)
 
-	conn, err := net.Dial("tcp",
-		fmt.Sprintf(
-			"localhost:%s", container.NetworkSettings.PortMapping[strconv.Itoa(port)],
-		),
-	)
+	strPort = container.NetworkSettings.PortMapping[strings.Title(proto)][strPort]
+	return runtime, container, strPort
+}
+
+// Run a container with a TCP port allocated, and test that it can receive connections on localhost
+func TestAllocateTCPPortLocalhost(t *testing.T) {
+	runtime, container, port := startEchoServerContainer(t, "tcp")
+	defer nuke(runtime)
+	defer container.Kill()
+
+	conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%v", port))
 	if err != nil {
 		t.Fatal(err)
 	}
 	defer conn.Close()
-	output, err := ioutil.ReadAll(conn)
+
+	input := bytes.NewBufferString("well hello there\n")
+	_, err = conn.Write(input.Bytes())
 	if err != nil {
 		t.Fatal(err)
 	}
-	if string(output) != "well hello there\n" {
-		t.Fatalf("Received wrong output from network connection: should be '%s', not '%s'",
-			"well hello there\n",
-			string(output),
-		)
+	buf := make([]byte, 16)
+	read := 0
+	conn.SetReadDeadline(time.Now().Add(2 * time.Second))
+	read, err = conn.Read(buf)
+	if err != nil {
+		t.Fatal(err)
 	}
-	container.Wait()
+	output := string(buf[:read])
+	if !strings.Contains(output, "well hello there") {
+		t.Fatal(fmt.Errorf("[%v] doesn't contain [well hello there]", output))
+	}
+}
+
+// Run a container with a UDP port allocated, and test that it can receive datagrams on localhost
+func TestAllocateUDPPortLocalhost(t *testing.T) {
+	runtime, container, port := startEchoServerContainer(t, "udp")
+	defer nuke(runtime)
+	defer container.Kill()
+
+	conn, err := net.Dial("udp", fmt.Sprintf("localhost:%v", port))
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer conn.Close()
+
+	input := bytes.NewBufferString("well hello there\n")
+	buf := make([]byte, 16)
+	for i := 0; i != 10; i++ {
+		_, err := conn.Write(input.Bytes())
+		if err != nil {
+			t.Fatal(err)
+		}
+		conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond))
+		read, err := conn.Read(buf)
+		if err == nil {
+			output := string(buf[:read])
+			if strings.Contains(output, "well hello there") {
+				return
+			}
+		}
+	}
+
+	t.Fatal("No reply from the container")
 }
 
 func TestRestore(t *testing.T) {