Просмотр исходного кода

Merge pull request #11660 from paulbellamy/11626-portallocator

Refactoring portallocator away from a global var
Alexander Morozov 10 лет назад
Родитель
Сommit
d852e5fde8

+ 40 - 19
daemon/networkdriver/portallocator/portallocator.go

@@ -50,11 +50,20 @@ var (
 )
 
 var (
+	defaultIP            = net.ParseIP("0.0.0.0")
+	defaultPortAllocator = New()
+)
+
+type PortAllocator struct {
 	mutex sync.Mutex
+	ipMap ipMapping
+}
 
-	defaultIP = net.ParseIP("0.0.0.0")
-	globalMap = ipMapping{}
-)
+func New() *PortAllocator {
+	return &PortAllocator{
+		ipMap: ipMapping{},
+	}
+}
 
 type ErrPortAlreadyAllocated struct {
 	ip   string
@@ -109,12 +118,9 @@ func (e ErrPortAlreadyAllocated) Error() string {
 	return fmt.Sprintf("Bind for %s:%d failed: port is already allocated", e.ip, e.port)
 }
 
-// RequestPort requests new port from global ports pool for specified ip and proto.
-// If port is 0 it returns first free port. Otherwise it cheks port availability
-// in pool and return that port or error if port is already busy.
-func RequestPort(ip net.IP, proto string, port int) (int, error) {
-	mutex.Lock()
-	defer mutex.Unlock()
+func (p *PortAllocator) RequestPort(ip net.IP, proto string, port int) (int, error) {
+	p.mutex.Lock()
+	defer p.mutex.Unlock()
 
 	if proto != "tcp" && proto != "udp" {
 		return 0, ErrUnknownProtocol
@@ -124,10 +130,10 @@ func RequestPort(ip net.IP, proto string, port int) (int, error) {
 		ip = defaultIP
 	}
 	ipstr := ip.String()
-	protomap, ok := globalMap[ipstr]
+	protomap, ok := p.ipMap[ipstr]
 	if !ok {
 		protomap = newProtoMap()
-		globalMap[ipstr] = protomap
+		p.ipMap[ipstr] = protomap
 	}
 	mapping := protomap[proto]
 	if port > 0 {
@@ -145,15 +151,22 @@ func RequestPort(ip net.IP, proto string, port int) (int, error) {
 	return port, nil
 }
 
+// RequestPort requests new port from global ports pool for specified ip and proto.
+// If port is 0 it returns first free port. Otherwise it cheks port availability
+// in pool and return that port or error if port is already busy.
+func RequestPort(ip net.IP, proto string, port int) (int, error) {
+	return defaultPortAllocator.RequestPort(ip, proto, port)
+}
+
 // ReleasePort releases port from global ports pool for specified ip and proto.
-func ReleasePort(ip net.IP, proto string, port int) error {
-	mutex.Lock()
-	defer mutex.Unlock()
+func (p *PortAllocator) ReleasePort(ip net.IP, proto string, port int) error {
+	p.mutex.Lock()
+	defer p.mutex.Unlock()
 
 	if ip == nil {
 		ip = defaultIP
 	}
-	protomap, ok := globalMap[ip.String()]
+	protomap, ok := p.ipMap[ip.String()]
 	if !ok {
 		return nil
 	}
@@ -161,14 +174,22 @@ func ReleasePort(ip net.IP, proto string, port int) error {
 	return nil
 }
 
+func ReleasePort(ip net.IP, proto string, port int) error {
+	return defaultPortAllocator.ReleasePort(ip, proto, port)
+}
+
 // ReleaseAll releases all ports for all ips.
-func ReleaseAll() error {
-	mutex.Lock()
-	globalMap = ipMapping{}
-	mutex.Unlock()
+func (p *PortAllocator) ReleaseAll() error {
+	p.mutex.Lock()
+	p.ipMap = ipMapping{}
+	p.mutex.Unlock()
 	return nil
 }
 
+func ReleaseAll() error {
+	return defaultPortAllocator.ReleaseAll()
+}
+
 func (pm *portMap) findPort() (int, error) {
 	port := pm.last
 	for i := 0; i <= endPortRange-beginPortRange; i++ {

+ 42 - 48
daemon/networkdriver/portallocator/portallocator_test.go

@@ -10,14 +10,10 @@ func init() {
 	endPortRange = DefaultPortRangeEnd
 }
 
-func reset() {
-	ReleaseAll()
-}
-
 func TestRequestNewPort(t *testing.T) {
-	defer reset()
+	p := New()
 
-	port, err := RequestPort(defaultIP, "tcp", 0)
+	port, err := p.RequestPort(defaultIP, "tcp", 0)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -28,9 +24,9 @@ func TestRequestNewPort(t *testing.T) {
 }
 
 func TestRequestSpecificPort(t *testing.T) {
-	defer reset()
+	p := New()
 
-	port, err := RequestPort(defaultIP, "tcp", 5000)
+	port, err := p.RequestPort(defaultIP, "tcp", 5000)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -40,9 +36,9 @@ func TestRequestSpecificPort(t *testing.T) {
 }
 
 func TestReleasePort(t *testing.T) {
-	defer reset()
+	p := New()
 
-	port, err := RequestPort(defaultIP, "tcp", 5000)
+	port, err := p.RequestPort(defaultIP, "tcp", 5000)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -50,15 +46,15 @@ func TestReleasePort(t *testing.T) {
 		t.Fatalf("Expected port 5000 got %d", port)
 	}
 
-	if err := ReleasePort(defaultIP, "tcp", 5000); err != nil {
+	if err := p.ReleasePort(defaultIP, "tcp", 5000); err != nil {
 		t.Fatal(err)
 	}
 }
 
 func TestReuseReleasedPort(t *testing.T) {
-	defer reset()
+	p := New()
 
-	port, err := RequestPort(defaultIP, "tcp", 5000)
+	port, err := p.RequestPort(defaultIP, "tcp", 5000)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -66,20 +62,20 @@ func TestReuseReleasedPort(t *testing.T) {
 		t.Fatalf("Expected port 5000 got %d", port)
 	}
 
-	if err := ReleasePort(defaultIP, "tcp", 5000); err != nil {
+	if err := p.ReleasePort(defaultIP, "tcp", 5000); err != nil {
 		t.Fatal(err)
 	}
 
-	port, err = RequestPort(defaultIP, "tcp", 5000)
+	port, err = p.RequestPort(defaultIP, "tcp", 5000)
 	if err != nil {
 		t.Fatal(err)
 	}
 }
 
 func TestReleaseUnreadledPort(t *testing.T) {
-	defer reset()
+	p := New()
 
-	port, err := RequestPort(defaultIP, "tcp", 5000)
+	port, err := p.RequestPort(defaultIP, "tcp", 5000)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -87,7 +83,7 @@ func TestReleaseUnreadledPort(t *testing.T) {
 		t.Fatalf("Expected port 5000 got %d", port)
 	}
 
-	port, err = RequestPort(defaultIP, "tcp", 5000)
+	port, err = p.RequestPort(defaultIP, "tcp", 5000)
 
 	switch err.(type) {
 	case ErrPortAlreadyAllocated:
@@ -97,18 +93,16 @@ func TestReleaseUnreadledPort(t *testing.T) {
 }
 
 func TestUnknowProtocol(t *testing.T) {
-	defer reset()
-
-	if _, err := RequestPort(defaultIP, "tcpp", 0); err != ErrUnknownProtocol {
+	if _, err := New().RequestPort(defaultIP, "tcpp", 0); err != ErrUnknownProtocol {
 		t.Fatalf("Expected error %s got %s", ErrUnknownProtocol, err)
 	}
 }
 
 func TestAllocateAllPorts(t *testing.T) {
-	defer reset()
+	p := New()
 
 	for i := 0; i <= endPortRange-beginPortRange; i++ {
-		port, err := RequestPort(defaultIP, "tcp", 0)
+		port, err := p.RequestPort(defaultIP, "tcp", 0)
 		if err != nil {
 			t.Fatal(err)
 		}
@@ -118,21 +112,21 @@ func TestAllocateAllPorts(t *testing.T) {
 		}
 	}
 
-	if _, err := RequestPort(defaultIP, "tcp", 0); err != ErrAllPortsAllocated {
+	if _, err := p.RequestPort(defaultIP, "tcp", 0); err != ErrAllPortsAllocated {
 		t.Fatalf("Expected error %s got %s", ErrAllPortsAllocated, err)
 	}
 
-	_, err := RequestPort(defaultIP, "udp", 0)
+	_, err := p.RequestPort(defaultIP, "udp", 0)
 	if err != nil {
 		t.Fatal(err)
 	}
 
 	// release a port in the middle and ensure we get another tcp port
 	port := beginPortRange + 5
-	if err := ReleasePort(defaultIP, "tcp", port); err != nil {
+	if err := p.ReleasePort(defaultIP, "tcp", port); err != nil {
 		t.Fatal(err)
 	}
-	newPort, err := RequestPort(defaultIP, "tcp", 0)
+	newPort, err := p.RequestPort(defaultIP, "tcp", 0)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -142,10 +136,10 @@ func TestAllocateAllPorts(t *testing.T) {
 
 	// now pm.last == newPort, release it so that it's the only free port of
 	// the range, and ensure we get it back
-	if err := ReleasePort(defaultIP, "tcp", newPort); err != nil {
+	if err := p.ReleasePort(defaultIP, "tcp", newPort); err != nil {
 		t.Fatal(err)
 	}
-	port, err = RequestPort(defaultIP, "tcp", 0)
+	port, err = p.RequestPort(defaultIP, "tcp", 0)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -155,11 +149,11 @@ func TestAllocateAllPorts(t *testing.T) {
 }
 
 func BenchmarkAllocatePorts(b *testing.B) {
-	defer reset()
+	p := New()
 
 	for i := 0; i < b.N; i++ {
 		for i := 0; i <= endPortRange-beginPortRange; i++ {
-			port, err := RequestPort(defaultIP, "tcp", 0)
+			port, err := p.RequestPort(defaultIP, "tcp", 0)
 			if err != nil {
 				b.Fatal(err)
 			}
@@ -168,21 +162,21 @@ func BenchmarkAllocatePorts(b *testing.B) {
 				b.Fatalf("Expected port %d got %d", expected, port)
 			}
 		}
-		reset()
+		p.ReleaseAll()
 	}
 }
 
 func TestPortAllocation(t *testing.T) {
-	defer reset()
+	p := New()
 
 	ip := net.ParseIP("192.168.0.1")
 	ip2 := net.ParseIP("192.168.0.2")
-	if port, err := RequestPort(ip, "tcp", 80); err != nil {
+	if port, err := p.RequestPort(ip, "tcp", 80); err != nil {
 		t.Fatal(err)
 	} else if port != 80 {
 		t.Fatalf("Acquire(80) should return 80, not %d", port)
 	}
-	port, err := RequestPort(ip, "tcp", 0)
+	port, err := p.RequestPort(ip, "tcp", 0)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -190,41 +184,41 @@ func TestPortAllocation(t *testing.T) {
 		t.Fatalf("Acquire(0) should return a non-zero port")
 	}
 
-	if _, err := RequestPort(ip, "tcp", port); err == nil {
+	if _, err := p.RequestPort(ip, "tcp", port); err == nil {
 		t.Fatalf("Acquiring a port already in use should return an error")
 	}
 
-	if newPort, err := RequestPort(ip, "tcp", 0); err != nil {
+	if newPort, err := p.RequestPort(ip, "tcp", 0); err != nil {
 		t.Fatal(err)
 	} else if newPort == port {
 		t.Fatalf("Acquire(0) allocated the same port twice: %d", port)
 	}
 
-	if _, err := RequestPort(ip, "tcp", 80); err == nil {
+	if _, err := p.RequestPort(ip, "tcp", 80); err == nil {
 		t.Fatalf("Acquiring a port already in use should return an error")
 	}
-	if _, err := RequestPort(ip2, "tcp", 80); err != nil {
+	if _, err := p.RequestPort(ip2, "tcp", 80); err != nil {
 		t.Fatalf("It should be possible to allocate the same port on a different interface")
 	}
-	if _, err := RequestPort(ip2, "tcp", 80); err == nil {
+	if _, err := p.RequestPort(ip2, "tcp", 80); err == nil {
 		t.Fatalf("Acquiring a port already in use should return an error")
 	}
-	if err := ReleasePort(ip, "tcp", 80); err != nil {
+	if err := p.ReleasePort(ip, "tcp", 80); err != nil {
 		t.Fatal(err)
 	}
-	if _, err := RequestPort(ip, "tcp", 80); err != nil {
+	if _, err := p.RequestPort(ip, "tcp", 80); err != nil {
 		t.Fatal(err)
 	}
 
-	port, err = RequestPort(ip, "tcp", 0)
+	port, err = p.RequestPort(ip, "tcp", 0)
 	if err != nil {
 		t.Fatal(err)
 	}
-	port2, err := RequestPort(ip, "tcp", port+1)
+	port2, err := p.RequestPort(ip, "tcp", port+1)
 	if err != nil {
 		t.Fatal(err)
 	}
-	port3, err := RequestPort(ip, "tcp", 0)
+	port3, err := p.RequestPort(ip, "tcp", 0)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -234,15 +228,15 @@ func TestPortAllocation(t *testing.T) {
 }
 
 func TestNoDuplicateBPR(t *testing.T) {
-	defer reset()
+	p := New()
 
-	if port, err := RequestPort(defaultIP, "tcp", beginPortRange); err != nil {
+	if port, err := p.RequestPort(defaultIP, "tcp", beginPortRange); err != nil {
 		t.Fatal(err)
 	} else if port != beginPortRange {
 		t.Fatalf("Expected port %d got %d", beginPortRange, port)
 	}
 
-	if port, err := RequestPort(defaultIP, "tcp", 0); err != nil {
+	if port, err := p.RequestPort(defaultIP, "tcp", 0); err != nil {
 		t.Fatal(err)
 	} else if port == beginPortRange {
 		t.Fatalf("Acquire(0) allocated the same port twice: %d", port)