Jelajahi Sumber

Refactor global portallocator and portmapper state

Continuation of: #11660, working on issue #11626.

Wrapped portmapper global state into a struct. Now portallocator and
portmapper have no global state (except configuration, and a default
instance).

Unfortunately, removing the global default instances will break
```api/server/server.go:1539```, and ```daemon/daemon.go:832```, which
both call the global portallocator directly. Fixing that would be a much
bigger change, so for now, have postponed that.

Signed-off-by: Paul Bellamy <paul.a.bellamy@gmail.com>
Paul Bellamy 10 tahun lalu
induk
melakukan
87df5ab41b

+ 9 - 17
daemon/networkdriver/portallocator/portallocator.go

@@ -50,8 +50,12 @@ var (
 )
 )
 
 
 var (
 var (
-	defaultIP            = net.ParseIP("0.0.0.0")
-	defaultPortAllocator = New()
+	defaultIP = net.ParseIP("0.0.0.0")
+
+	DefaultPortAllocator = New()
+	RequestPort          = DefaultPortAllocator.RequestPort
+	ReleasePort          = DefaultPortAllocator.ReleasePort
+	ReleaseAll           = DefaultPortAllocator.ReleaseAll
 )
 )
 
 
 type PortAllocator struct {
 type PortAllocator struct {
@@ -119,6 +123,9 @@ func (e ErrPortAlreadyAllocated) Error() string {
 	return fmt.Sprintf("Bind for %s:%d failed: port is already allocated", e.ip, e.port)
 	return fmt.Sprintf("Bind for %s:%d failed: port is already allocated", e.ip, e.port)
 }
 }
 
 
+// RequestPort requests new port from global ports pool for specified ip and proto.
+// If port is 0 it returns first free port. Otherwise it cheks port availability
+// in pool and return that port or error if port is already busy.
 func (p *PortAllocator) RequestPort(ip net.IP, proto string, port int) (int, error) {
 func (p *PortAllocator) RequestPort(ip net.IP, proto string, port int) (int, error) {
 	p.mutex.Lock()
 	p.mutex.Lock()
 	defer p.mutex.Unlock()
 	defer p.mutex.Unlock()
@@ -152,13 +159,6 @@ func (p *PortAllocator) RequestPort(ip net.IP, proto string, port int) (int, err
 	return port, nil
 	return port, nil
 }
 }
 
 
-// RequestPort requests new port from global ports pool for specified ip and proto.
-// If port is 0 it returns first free port. Otherwise it cheks port availability
-// in pool and return that port or error if port is already busy.
-func RequestPort(ip net.IP, proto string, port int) (int, error) {
-	return defaultPortAllocator.RequestPort(ip, proto, port)
-}
-
 // ReleasePort releases port from global ports pool for specified ip and proto.
 // ReleasePort releases port from global ports pool for specified ip and proto.
 func (p *PortAllocator) ReleasePort(ip net.IP, proto string, port int) error {
 func (p *PortAllocator) ReleasePort(ip net.IP, proto string, port int) error {
 	p.mutex.Lock()
 	p.mutex.Lock()
@@ -175,10 +175,6 @@ func (p *PortAllocator) ReleasePort(ip net.IP, proto string, port int) error {
 	return nil
 	return nil
 }
 }
 
 
-func ReleasePort(ip net.IP, proto string, port int) error {
-	return defaultPortAllocator.ReleasePort(ip, proto, port)
-}
-
 // ReleaseAll releases all ports for all ips.
 // ReleaseAll releases all ports for all ips.
 func (p *PortAllocator) ReleaseAll() error {
 func (p *PortAllocator) ReleaseAll() error {
 	p.mutex.Lock()
 	p.mutex.Lock()
@@ -187,10 +183,6 @@ func (p *PortAllocator) ReleaseAll() error {
 	return nil
 	return nil
 }
 }
 
 
-func ReleaseAll() error {
-	return defaultPortAllocator.ReleaseAll()
-}
-
 func (pm *portMap) findPort() (int, error) {
 func (pm *portMap) findPort() (int, error) {
 	port := pm.last
 	port := pm.last
 	for i := 0; i <= endPortRange-beginPortRange; i++ {
 	for i := 0; i <= endPortRange-beginPortRange; i++ {

+ 50 - 30
daemon/networkdriver/portmapper/mapper.go

@@ -19,13 +19,12 @@ type mapping struct {
 }
 }
 
 
 var (
 var (
-	chain *iptables.Chain
-	lock  sync.Mutex
-
-	// udp:ip:port
-	currentMappings = make(map[string]*mapping)
-
 	NewProxy = NewProxyCommand
 	NewProxy = NewProxyCommand
+
+	DefaultPortMapper = NewWithPortAllocator(portallocator.DefaultPortAllocator)
+	SetIptablesChain  = DefaultPortMapper.SetIptablesChain
+	Map               = DefaultPortMapper.Map
+	Unmap             = DefaultPortMapper.Unmap
 )
 )
 
 
 var (
 var (
@@ -34,13 +33,34 @@ var (
 	ErrPortNotMapped             = errors.New("port is not mapped")
 	ErrPortNotMapped             = errors.New("port is not mapped")
 )
 )
 
 
-func SetIptablesChain(c *iptables.Chain) {
-	chain = c
+type PortMapper struct {
+	chain *iptables.Chain
+
+	// udp:ip:port
+	currentMappings map[string]*mapping
+	lock            sync.Mutex
+
+	allocator *portallocator.PortAllocator
+}
+
+func New() *PortMapper {
+	return NewWithPortAllocator(portallocator.New())
+}
+
+func NewWithPortAllocator(allocator *portallocator.PortAllocator) *PortMapper {
+	return &PortMapper{
+		currentMappings: make(map[string]*mapping),
+		allocator:       allocator,
+	}
+}
+
+func (pm *PortMapper) SetIptablesChain(c *iptables.Chain) {
+	pm.chain = c
 }
 }
 
 
-func Map(container net.Addr, hostIP net.IP, hostPort int) (host net.Addr, err error) {
-	lock.Lock()
-	defer lock.Unlock()
+func (pm *PortMapper) Map(container net.Addr, hostIP net.IP, hostPort int) (host net.Addr, err error) {
+	pm.lock.Lock()
+	defer pm.lock.Unlock()
 
 
 	var (
 	var (
 		m                 *mapping
 		m                 *mapping
@@ -52,7 +72,7 @@ func Map(container net.Addr, hostIP net.IP, hostPort int) (host net.Addr, err er
 	switch container.(type) {
 	switch container.(type) {
 	case *net.TCPAddr:
 	case *net.TCPAddr:
 		proto = "tcp"
 		proto = "tcp"
-		if allocatedHostPort, err = portallocator.RequestPort(hostIP, proto, hostPort); err != nil {
+		if allocatedHostPort, err = pm.allocator.RequestPort(hostIP, proto, hostPort); err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
 
 
@@ -65,7 +85,7 @@ func Map(container net.Addr, hostIP net.IP, hostPort int) (host net.Addr, err er
 		proxy = NewProxy(proto, hostIP, allocatedHostPort, container.(*net.TCPAddr).IP, container.(*net.TCPAddr).Port)
 		proxy = NewProxy(proto, hostIP, allocatedHostPort, container.(*net.TCPAddr).IP, container.(*net.TCPAddr).Port)
 	case *net.UDPAddr:
 	case *net.UDPAddr:
 		proto = "udp"
 		proto = "udp"
-		if allocatedHostPort, err = portallocator.RequestPort(hostIP, proto, hostPort); err != nil {
+		if allocatedHostPort, err = pm.allocator.RequestPort(hostIP, proto, hostPort); err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
 
 
@@ -83,25 +103,25 @@ func Map(container net.Addr, hostIP net.IP, hostPort int) (host net.Addr, err er
 	// release the allocated port on any further error during return.
 	// release the allocated port on any further error during return.
 	defer func() {
 	defer func() {
 		if err != nil {
 		if err != nil {
-			portallocator.ReleasePort(hostIP, proto, allocatedHostPort)
+			pm.allocator.ReleasePort(hostIP, proto, allocatedHostPort)
 		}
 		}
 	}()
 	}()
 
 
 	key := getKey(m.host)
 	key := getKey(m.host)
-	if _, exists := currentMappings[key]; exists {
+	if _, exists := pm.currentMappings[key]; exists {
 		return nil, ErrPortMappedForIP
 		return nil, ErrPortMappedForIP
 	}
 	}
 
 
 	containerIP, containerPort := getIPAndPort(m.container)
 	containerIP, containerPort := getIPAndPort(m.container)
-	if err := forward(iptables.Append, m.proto, hostIP, allocatedHostPort, containerIP.String(), containerPort); err != nil {
+	if err := pm.forward(iptables.Append, m.proto, hostIP, allocatedHostPort, containerIP.String(), containerPort); err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
 	cleanup := func() error {
 	cleanup := func() error {
 		// need to undo the iptables rules before we return
 		// need to undo the iptables rules before we return
 		proxy.Stop()
 		proxy.Stop()
-		forward(iptables.Delete, m.proto, hostIP, allocatedHostPort, containerIP.String(), containerPort)
-		if err := portallocator.ReleasePort(hostIP, m.proto, allocatedHostPort); err != nil {
+		pm.forward(iptables.Delete, m.proto, hostIP, allocatedHostPort, containerIP.String(), containerPort)
+		if err := pm.allocator.ReleasePort(hostIP, m.proto, allocatedHostPort); err != nil {
 			return err
 			return err
 		}
 		}
 
 
@@ -115,35 +135,35 @@ func Map(container net.Addr, hostIP net.IP, hostPort int) (host net.Addr, err er
 		return nil, err
 		return nil, err
 	}
 	}
 	m.userlandProxy = proxy
 	m.userlandProxy = proxy
-	currentMappings[key] = m
+	pm.currentMappings[key] = m
 	return m.host, nil
 	return m.host, nil
 }
 }
 
 
-func Unmap(host net.Addr) error {
-	lock.Lock()
-	defer lock.Unlock()
+func (pm *PortMapper) Unmap(host net.Addr) error {
+	pm.lock.Lock()
+	defer pm.lock.Unlock()
 
 
 	key := getKey(host)
 	key := getKey(host)
-	data, exists := currentMappings[key]
+	data, exists := pm.currentMappings[key]
 	if !exists {
 	if !exists {
 		return ErrPortNotMapped
 		return ErrPortNotMapped
 	}
 	}
 
 
 	data.userlandProxy.Stop()
 	data.userlandProxy.Stop()
 
 
-	delete(currentMappings, key)
+	delete(pm.currentMappings, key)
 
 
 	containerIP, containerPort := getIPAndPort(data.container)
 	containerIP, containerPort := getIPAndPort(data.container)
 	hostIP, hostPort := getIPAndPort(data.host)
 	hostIP, hostPort := getIPAndPort(data.host)
-	if err := forward(iptables.Delete, data.proto, hostIP, hostPort, containerIP.String(), containerPort); err != nil {
+	if err := pm.forward(iptables.Delete, data.proto, hostIP, hostPort, containerIP.String(), containerPort); err != nil {
 		log.Errorf("Error on iptables delete: %s", err)
 		log.Errorf("Error on iptables delete: %s", err)
 	}
 	}
 
 
 	switch a := host.(type) {
 	switch a := host.(type) {
 	case *net.TCPAddr:
 	case *net.TCPAddr:
-		return portallocator.ReleasePort(a.IP, "tcp", a.Port)
+		return pm.allocator.ReleasePort(a.IP, "tcp", a.Port)
 	case *net.UDPAddr:
 	case *net.UDPAddr:
-		return portallocator.ReleasePort(a.IP, "udp", a.Port)
+		return pm.allocator.ReleasePort(a.IP, "udp", a.Port)
 	}
 	}
 	return nil
 	return nil
 }
 }
@@ -168,9 +188,9 @@ func getIPAndPort(a net.Addr) (net.IP, int) {
 	return nil, 0
 	return nil, 0
 }
 }
 
 
-func forward(action iptables.Action, proto string, sourceIP net.IP, sourcePort int, containerIP string, containerPort int) error {
-	if chain == nil {
+func (pm *PortMapper) forward(action iptables.Action, proto string, sourceIP net.IP, sourcePort int, containerIP string, containerPort int) error {
+	if pm.chain == nil {
 		return nil
 		return nil
 	}
 	}
-	return chain.Forward(action, sourceIP, sourcePort, proto, containerIP, containerPort)
+	return pm.chain.Forward(action, sourceIP, sourcePort, proto, containerIP, containerPort)
 }
 }

+ 17 - 20
daemon/networkdriver/portmapper/mapper_test.go

@@ -13,30 +13,26 @@ func init() {
 	NewProxy = NewMockProxyCommand
 	NewProxy = NewMockProxyCommand
 }
 }
 
 
-func reset() {
-	chain = nil
-	currentMappings = make(map[string]*mapping)
-}
-
 func TestSetIptablesChain(t *testing.T) {
 func TestSetIptablesChain(t *testing.T) {
-	defer reset()
+	pm := New()
 
 
 	c := &iptables.Chain{
 	c := &iptables.Chain{
 		Name:   "TEST",
 		Name:   "TEST",
 		Bridge: "192.168.1.1",
 		Bridge: "192.168.1.1",
 	}
 	}
 
 
-	if chain != nil {
+	if pm.chain != nil {
 		t.Fatal("chain should be nil at init")
 		t.Fatal("chain should be nil at init")
 	}
 	}
 
 
-	SetIptablesChain(c)
-	if chain == nil {
+	pm.SetIptablesChain(c)
+	if pm.chain == nil {
 		t.Fatal("chain should not be nil after set")
 		t.Fatal("chain should not be nil after set")
 	}
 	}
 }
 }
 
 
 func TestMapPorts(t *testing.T) {
 func TestMapPorts(t *testing.T) {
+	pm := New()
 	dstIp1 := net.ParseIP("192.168.0.1")
 	dstIp1 := net.ParseIP("192.168.0.1")
 	dstIp2 := net.ParseIP("192.168.0.2")
 	dstIp2 := net.ParseIP("192.168.0.2")
 	dstAddr1 := &net.TCPAddr{IP: dstIp1, Port: 80}
 	dstAddr1 := &net.TCPAddr{IP: dstIp1, Port: 80}
@@ -49,34 +45,34 @@ func TestMapPorts(t *testing.T) {
 		return (addr1.Network() == addr2.Network()) && (addr1.String() == addr2.String())
 		return (addr1.Network() == addr2.Network()) && (addr1.String() == addr2.String())
 	}
 	}
 
 
-	if host, err := Map(srcAddr1, dstIp1, 80); err != nil {
+	if host, err := pm.Map(srcAddr1, dstIp1, 80); err != nil {
 		t.Fatalf("Failed to allocate port: %s", err)
 		t.Fatalf("Failed to allocate port: %s", err)
 	} else if !addrEqual(dstAddr1, host) {
 	} else if !addrEqual(dstAddr1, host) {
 		t.Fatalf("Incorrect mapping result: expected %s:%s, got %s:%s",
 		t.Fatalf("Incorrect mapping result: expected %s:%s, got %s:%s",
 			dstAddr1.String(), dstAddr1.Network(), host.String(), host.Network())
 			dstAddr1.String(), dstAddr1.Network(), host.String(), host.Network())
 	}
 	}
 
 
-	if _, err := Map(srcAddr1, dstIp1, 80); err == nil {
+	if _, err := pm.Map(srcAddr1, dstIp1, 80); err == nil {
 		t.Fatalf("Port is in use - mapping should have failed")
 		t.Fatalf("Port is in use - mapping should have failed")
 	}
 	}
 
 
-	if _, err := Map(srcAddr2, dstIp1, 80); err == nil {
+	if _, err := pm.Map(srcAddr2, dstIp1, 80); err == nil {
 		t.Fatalf("Port is in use - mapping should have failed")
 		t.Fatalf("Port is in use - mapping should have failed")
 	}
 	}
 
 
-	if _, err := Map(srcAddr2, dstIp2, 80); err != nil {
+	if _, err := pm.Map(srcAddr2, dstIp2, 80); err != nil {
 		t.Fatalf("Failed to allocate port: %s", err)
 		t.Fatalf("Failed to allocate port: %s", err)
 	}
 	}
 
 
-	if Unmap(dstAddr1) != nil {
+	if pm.Unmap(dstAddr1) != nil {
 		t.Fatalf("Failed to release port")
 		t.Fatalf("Failed to release port")
 	}
 	}
 
 
-	if Unmap(dstAddr2) != nil {
+	if pm.Unmap(dstAddr2) != nil {
 		t.Fatalf("Failed to release port")
 		t.Fatalf("Failed to release port")
 	}
 	}
 
 
-	if Unmap(dstAddr2) == nil {
+	if pm.Unmap(dstAddr2) == nil {
 		t.Fatalf("Port already released, but no error reported")
 		t.Fatalf("Port already released, but no error reported")
 	}
 	}
 }
 }
@@ -115,6 +111,7 @@ func TestGetUDPIPAndPort(t *testing.T) {
 }
 }
 
 
 func TestMapAllPortsSingleInterface(t *testing.T) {
 func TestMapAllPortsSingleInterface(t *testing.T) {
+	pm := New()
 	dstIp1 := net.ParseIP("0.0.0.0")
 	dstIp1 := net.ParseIP("0.0.0.0")
 	srcAddr1 := &net.TCPAddr{Port: 1080, IP: net.ParseIP("172.16.0.1")}
 	srcAddr1 := &net.TCPAddr{Port: 1080, IP: net.ParseIP("172.16.0.1")}
 
 
@@ -124,26 +121,26 @@ func TestMapAllPortsSingleInterface(t *testing.T) {
 
 
 	defer func() {
 	defer func() {
 		for _, val := range hosts {
 		for _, val := range hosts {
-			Unmap(val)
+			pm.Unmap(val)
 		}
 		}
 	}()
 	}()
 
 
 	for i := 0; i < 10; i++ {
 	for i := 0; i < 10; i++ {
 		start, end := portallocator.PortRange()
 		start, end := portallocator.PortRange()
 		for i := start; i < end; i++ {
 		for i := start; i < end; i++ {
-			if host, err = Map(srcAddr1, dstIp1, 0); err != nil {
+			if host, err = pm.Map(srcAddr1, dstIp1, 0); err != nil {
 				t.Fatal(err)
 				t.Fatal(err)
 			}
 			}
 
 
 			hosts = append(hosts, host)
 			hosts = append(hosts, host)
 		}
 		}
 
 
-		if _, err := Map(srcAddr1, dstIp1, start); err == nil {
+		if _, err := pm.Map(srcAddr1, dstIp1, start); err == nil {
 			t.Fatalf("Port %d should be bound but is not", start)
 			t.Fatalf("Port %d should be bound but is not", start)
 		}
 		}
 
 
 		for _, val := range hosts {
 		for _, val := range hosts {
-			if err := Unmap(val); err != nil {
+			if err := pm.Unmap(val); err != nil {
 				t.Fatal(err)
 				t.Fatal(err)
 			}
 			}
 		}
 		}