diff --git a/daemon/networkdriver/portallocator/portallocator.go b/daemon/networkdriver/portallocator/portallocator.go index a7c183d9da..01533419bd 100644 --- a/daemon/networkdriver/portallocator/portallocator.go +++ b/daemon/networkdriver/portallocator/portallocator.go @@ -50,8 +50,12 @@ var ( ) var ( - defaultIP = net.ParseIP("0.0.0.0") - defaultPortAllocator = New() + defaultIP = net.ParseIP("0.0.0.0") + + DefaultPortAllocator = New() + RequestPort = DefaultPortAllocator.RequestPort + ReleasePort = DefaultPortAllocator.ReleasePort + ReleaseAll = DefaultPortAllocator.ReleaseAll ) type PortAllocator struct { @@ -119,6 +123,9 @@ func (e ErrPortAlreadyAllocated) Error() string { return fmt.Sprintf("Bind for %s:%d failed: port is already allocated", e.ip, e.port) } +// RequestPort requests new port from global ports pool for specified ip and proto. +// If port is 0 it returns first free port. Otherwise it cheks port availability +// in pool and return that port or error if port is already busy. func (p *PortAllocator) RequestPort(ip net.IP, proto string, port int) (int, error) { p.mutex.Lock() defer p.mutex.Unlock() @@ -152,13 +159,6 @@ func (p *PortAllocator) RequestPort(ip net.IP, proto string, port int) (int, err return port, nil } -// RequestPort requests new port from global ports pool for specified ip and proto. -// If port is 0 it returns first free port. Otherwise it cheks port availability -// in pool and return that port or error if port is already busy. -func RequestPort(ip net.IP, proto string, port int) (int, error) { - return defaultPortAllocator.RequestPort(ip, proto, port) -} - // ReleasePort releases port from global ports pool for specified ip and proto. func (p *PortAllocator) ReleasePort(ip net.IP, proto string, port int) error { p.mutex.Lock() @@ -175,10 +175,6 @@ func (p *PortAllocator) ReleasePort(ip net.IP, proto string, port int) error { return nil } -func ReleasePort(ip net.IP, proto string, port int) error { - return defaultPortAllocator.ReleasePort(ip, proto, port) -} - // ReleaseAll releases all ports for all ips. func (p *PortAllocator) ReleaseAll() error { p.mutex.Lock() @@ -187,10 +183,6 @@ func (p *PortAllocator) ReleaseAll() error { return nil } -func ReleaseAll() error { - return defaultPortAllocator.ReleaseAll() -} - func (pm *portMap) findPort() (int, error) { port := pm.last for i := 0; i <= endPortRange-beginPortRange; i++ { diff --git a/daemon/networkdriver/portmapper/mapper.go b/daemon/networkdriver/portmapper/mapper.go index 9f2ca5a754..74b329e2f7 100644 --- a/daemon/networkdriver/portmapper/mapper.go +++ b/daemon/networkdriver/portmapper/mapper.go @@ -19,13 +19,12 @@ type mapping struct { } var ( - chain *iptables.Chain - lock sync.Mutex - - // udp:ip:port - currentMappings = make(map[string]*mapping) - NewProxy = NewProxyCommand + + DefaultPortMapper = NewWithPortAllocator(portallocator.DefaultPortAllocator) + SetIptablesChain = DefaultPortMapper.SetIptablesChain + Map = DefaultPortMapper.Map + Unmap = DefaultPortMapper.Unmap ) var ( @@ -34,13 +33,34 @@ var ( ErrPortNotMapped = errors.New("port is not mapped") ) -func SetIptablesChain(c *iptables.Chain) { - chain = c +type PortMapper struct { + chain *iptables.Chain + + // udp:ip:port + currentMappings map[string]*mapping + lock sync.Mutex + + allocator *portallocator.PortAllocator } -func Map(container net.Addr, hostIP net.IP, hostPort int) (host net.Addr, err error) { - lock.Lock() - defer lock.Unlock() +func New() *PortMapper { + return NewWithPortAllocator(portallocator.New()) +} + +func NewWithPortAllocator(allocator *portallocator.PortAllocator) *PortMapper { + return &PortMapper{ + currentMappings: make(map[string]*mapping), + allocator: allocator, + } +} + +func (pm *PortMapper) SetIptablesChain(c *iptables.Chain) { + pm.chain = c +} + +func (pm *PortMapper) Map(container net.Addr, hostIP net.IP, hostPort int) (host net.Addr, err error) { + pm.lock.Lock() + defer pm.lock.Unlock() var ( m *mapping @@ -52,7 +72,7 @@ func Map(container net.Addr, hostIP net.IP, hostPort int) (host net.Addr, err er switch container.(type) { case *net.TCPAddr: proto = "tcp" - if allocatedHostPort, err = portallocator.RequestPort(hostIP, proto, hostPort); err != nil { + if allocatedHostPort, err = pm.allocator.RequestPort(hostIP, proto, hostPort); err != nil { return nil, err } @@ -65,7 +85,7 @@ func Map(container net.Addr, hostIP net.IP, hostPort int) (host net.Addr, err er proxy = NewProxy(proto, hostIP, allocatedHostPort, container.(*net.TCPAddr).IP, container.(*net.TCPAddr).Port) case *net.UDPAddr: proto = "udp" - if allocatedHostPort, err = portallocator.RequestPort(hostIP, proto, hostPort); err != nil { + if allocatedHostPort, err = pm.allocator.RequestPort(hostIP, proto, hostPort); err != nil { return nil, err } @@ -83,25 +103,25 @@ func Map(container net.Addr, hostIP net.IP, hostPort int) (host net.Addr, err er // release the allocated port on any further error during return. defer func() { if err != nil { - portallocator.ReleasePort(hostIP, proto, allocatedHostPort) + pm.allocator.ReleasePort(hostIP, proto, allocatedHostPort) } }() key := getKey(m.host) - if _, exists := currentMappings[key]; exists { + if _, exists := pm.currentMappings[key]; exists { return nil, ErrPortMappedForIP } containerIP, containerPort := getIPAndPort(m.container) - if err := forward(iptables.Append, m.proto, hostIP, allocatedHostPort, containerIP.String(), containerPort); err != nil { + if err := pm.forward(iptables.Append, m.proto, hostIP, allocatedHostPort, containerIP.String(), containerPort); err != nil { return nil, err } cleanup := func() error { // need to undo the iptables rules before we return proxy.Stop() - forward(iptables.Delete, m.proto, hostIP, allocatedHostPort, containerIP.String(), containerPort) - if err := portallocator.ReleasePort(hostIP, m.proto, allocatedHostPort); err != nil { + pm.forward(iptables.Delete, m.proto, hostIP, allocatedHostPort, containerIP.String(), containerPort) + if err := pm.allocator.ReleasePort(hostIP, m.proto, allocatedHostPort); err != nil { return err } @@ -115,35 +135,35 @@ func Map(container net.Addr, hostIP net.IP, hostPort int) (host net.Addr, err er return nil, err } m.userlandProxy = proxy - currentMappings[key] = m + pm.currentMappings[key] = m return m.host, nil } -func Unmap(host net.Addr) error { - lock.Lock() - defer lock.Unlock() +func (pm *PortMapper) Unmap(host net.Addr) error { + pm.lock.Lock() + defer pm.lock.Unlock() key := getKey(host) - data, exists := currentMappings[key] + data, exists := pm.currentMappings[key] if !exists { return ErrPortNotMapped } data.userlandProxy.Stop() - delete(currentMappings, key) + delete(pm.currentMappings, key) containerIP, containerPort := getIPAndPort(data.container) hostIP, hostPort := getIPAndPort(data.host) - if err := forward(iptables.Delete, data.proto, hostIP, hostPort, containerIP.String(), containerPort); err != nil { + if err := pm.forward(iptables.Delete, data.proto, hostIP, hostPort, containerIP.String(), containerPort); err != nil { log.Errorf("Error on iptables delete: %s", err) } switch a := host.(type) { case *net.TCPAddr: - return portallocator.ReleasePort(a.IP, "tcp", a.Port) + return pm.allocator.ReleasePort(a.IP, "tcp", a.Port) case *net.UDPAddr: - return portallocator.ReleasePort(a.IP, "udp", a.Port) + return pm.allocator.ReleasePort(a.IP, "udp", a.Port) } return nil } @@ -168,9 +188,9 @@ func getIPAndPort(a net.Addr) (net.IP, int) { return nil, 0 } -func forward(action iptables.Action, proto string, sourceIP net.IP, sourcePort int, containerIP string, containerPort int) error { - if chain == nil { +func (pm *PortMapper) forward(action iptables.Action, proto string, sourceIP net.IP, sourcePort int, containerIP string, containerPort int) error { + if pm.chain == nil { return nil } - return chain.Forward(action, sourceIP, sourcePort, proto, containerIP, containerPort) + return pm.chain.Forward(action, sourceIP, sourcePort, proto, containerIP, containerPort) } diff --git a/daemon/networkdriver/portmapper/mapper_test.go b/daemon/networkdriver/portmapper/mapper_test.go index fa7bdecdbf..4082a6002b 100644 --- a/daemon/networkdriver/portmapper/mapper_test.go +++ b/daemon/networkdriver/portmapper/mapper_test.go @@ -13,30 +13,26 @@ func init() { NewProxy = NewMockProxyCommand } -func reset() { - chain = nil - currentMappings = make(map[string]*mapping) -} - func TestSetIptablesChain(t *testing.T) { - defer reset() + pm := New() c := &iptables.Chain{ Name: "TEST", Bridge: "192.168.1.1", } - if chain != nil { + if pm.chain != nil { t.Fatal("chain should be nil at init") } - SetIptablesChain(c) - if chain == nil { + pm.SetIptablesChain(c) + if pm.chain == nil { t.Fatal("chain should not be nil after set") } } func TestMapPorts(t *testing.T) { + pm := New() dstIp1 := net.ParseIP("192.168.0.1") dstIp2 := net.ParseIP("192.168.0.2") dstAddr1 := &net.TCPAddr{IP: dstIp1, Port: 80} @@ -49,34 +45,34 @@ func TestMapPorts(t *testing.T) { return (addr1.Network() == addr2.Network()) && (addr1.String() == addr2.String()) } - if host, err := Map(srcAddr1, dstIp1, 80); err != nil { + if host, err := pm.Map(srcAddr1, dstIp1, 80); err != nil { t.Fatalf("Failed to allocate port: %s", err) } else if !addrEqual(dstAddr1, host) { t.Fatalf("Incorrect mapping result: expected %s:%s, got %s:%s", dstAddr1.String(), dstAddr1.Network(), host.String(), host.Network()) } - if _, err := Map(srcAddr1, dstIp1, 80); err == nil { + if _, err := pm.Map(srcAddr1, dstIp1, 80); err == nil { t.Fatalf("Port is in use - mapping should have failed") } - if _, err := Map(srcAddr2, dstIp1, 80); err == nil { + if _, err := pm.Map(srcAddr2, dstIp1, 80); err == nil { t.Fatalf("Port is in use - mapping should have failed") } - if _, err := Map(srcAddr2, dstIp2, 80); err != nil { + if _, err := pm.Map(srcAddr2, dstIp2, 80); err != nil { t.Fatalf("Failed to allocate port: %s", err) } - if Unmap(dstAddr1) != nil { + if pm.Unmap(dstAddr1) != nil { t.Fatalf("Failed to release port") } - if Unmap(dstAddr2) != nil { + if pm.Unmap(dstAddr2) != nil { t.Fatalf("Failed to release port") } - if Unmap(dstAddr2) == nil { + if pm.Unmap(dstAddr2) == nil { t.Fatalf("Port already released, but no error reported") } } @@ -115,6 +111,7 @@ func TestGetUDPIPAndPort(t *testing.T) { } func TestMapAllPortsSingleInterface(t *testing.T) { + pm := New() dstIp1 := net.ParseIP("0.0.0.0") srcAddr1 := &net.TCPAddr{Port: 1080, IP: net.ParseIP("172.16.0.1")} @@ -124,26 +121,26 @@ func TestMapAllPortsSingleInterface(t *testing.T) { defer func() { for _, val := range hosts { - Unmap(val) + pm.Unmap(val) } }() for i := 0; i < 10; i++ { start, end := portallocator.PortRange() for i := start; i < end; i++ { - if host, err = Map(srcAddr1, dstIp1, 0); err != nil { + if host, err = pm.Map(srcAddr1, dstIp1, 0); err != nil { t.Fatal(err) } hosts = append(hosts, host) } - if _, err := Map(srcAddr1, dstIp1, start); err == nil { + if _, err := pm.Map(srcAddr1, dstIp1, start); err == nil { t.Fatalf("Port %d should be bound but is not", start) } for _, val := range hosts { - if err := Unmap(val); err != nil { + if err := pm.Unmap(val); err != nil { t.Fatal(err) } }