Browse Source

Refactor port allocator to not have ANY global state

Signed-off-by: Michael Crosby <crosbymichael@gmail.com>
Michael Crosby 10 years ago
parent
commit
43a50b0618

+ 63 - 76
daemon/networkdriver/portallocator/portallocator.go

@@ -16,59 +16,14 @@ const (
 	DefaultPortRangeEnd   = 65535
 )
 
-var (
-	beginPortRange = DefaultPortRangeStart
-	endPortRange   = DefaultPortRangeEnd
-)
-
-type portMap struct {
-	p    map[int]struct{}
-	last int
-}
-
-func newPortMap() *portMap {
-	return &portMap{
-		p:    map[int]struct{}{},
-		last: endPortRange,
-	}
-}
-
-type protoMap map[string]*portMap
-
-func newProtoMap() protoMap {
-	return protoMap{
-		"tcp": newPortMap(),
-		"udp": newPortMap(),
-	}
-}
-
 type ipMapping map[string]protoMap
 
 var (
 	ErrAllPortsAllocated = errors.New("all ports are allocated")
 	ErrUnknownProtocol   = errors.New("unknown protocol")
+	defaultIP            = net.ParseIP("0.0.0.0")
 )
 
-var (
-	defaultIP = net.ParseIP("0.0.0.0")
-
-	DefaultPortAllocator = New()
-	RequestPort          = DefaultPortAllocator.RequestPort
-	ReleasePort          = DefaultPortAllocator.ReleasePort
-	ReleaseAll           = DefaultPortAllocator.ReleaseAll
-)
-
-type PortAllocator struct {
-	mutex sync.Mutex
-	ipMap ipMapping
-}
-
-func New() *PortAllocator {
-	return &PortAllocator{
-		ipMap: ipMapping{},
-	}
-}
-
 type ErrPortAlreadyAllocated struct {
 	ip   string
 	port int
@@ -81,32 +36,6 @@ func NewErrPortAlreadyAllocated(ip string, port int) ErrPortAlreadyAllocated {
 	}
 }
 
-func init() {
-	const portRangeKernelParam = "/proc/sys/net/ipv4/ip_local_port_range"
-	portRangeFallback := fmt.Sprintf("using fallback port range %d-%d", beginPortRange, endPortRange)
-
-	file, err := os.Open(portRangeKernelParam)
-	if err != nil {
-		logrus.Warnf("port allocator - %s due to error: %v", portRangeFallback, err)
-		return
-	}
-	var start, end int
-	n, err := fmt.Fscanf(bufio.NewReader(file), "%d\t%d", &start, &end)
-	if n != 2 || err != nil {
-		if err == nil {
-			err = fmt.Errorf("unexpected count of parsed numbers (%d)", n)
-		}
-		logrus.Errorf("port allocator - failed to parse system ephemeral port range from %s - %s: %v", portRangeKernelParam, portRangeFallback, err)
-		return
-	}
-	beginPortRange = start
-	endPortRange = end
-}
-
-func PortRange() (int, int) {
-	return beginPortRange, endPortRange
-}
-
 func (e ErrPortAlreadyAllocated) IP() string {
 	return e.ip
 }
@@ -123,6 +52,51 @@ func (e ErrPortAlreadyAllocated) Error() string {
 	return fmt.Sprintf("Bind for %s:%d failed: port is already allocated", e.ip, e.port)
 }
 
+type (
+	PortAllocator struct {
+		mutex sync.Mutex
+		ipMap ipMapping
+		Begin int
+		End   int
+	}
+	portMap struct {
+		p          map[int]struct{}
+		begin, end int
+		last       int
+	}
+	protoMap map[string]*portMap
+)
+
+func New() *PortAllocator {
+	start, end, err := getDynamicPortRange()
+	if err != nil {
+		logrus.Warn(err)
+		start, end = DefaultPortRangeStart, DefaultPortRangeEnd
+	}
+	return &PortAllocator{
+		ipMap: ipMapping{},
+		Begin: start,
+		End:   end,
+	}
+}
+
+func getDynamicPortRange() (start int, end int, err error) {
+	const portRangeKernelParam = "/proc/sys/net/ipv4/ip_local_port_range"
+	portRangeFallback := fmt.Sprintf("using fallback port range %d-%d", DefaultPortRangeStart, DefaultPortRangeEnd)
+	file, err := os.Open(portRangeKernelParam)
+	if err != nil {
+		return 0, 0, fmt.Errorf("port allocator - %s due to error: %v", portRangeFallback, err)
+	}
+	n, err := fmt.Fscanf(bufio.NewReader(file), "%d\t%d", &start, &end)
+	if n != 2 || err != nil {
+		if err == nil {
+			err = fmt.Errorf("unexpected count of parsed numbers (%d)", n)
+		}
+		return 0, 0, fmt.Errorf("port allocator - failed to parse system ephemeral port range from %s - %s: %v", portRangeKernelParam, portRangeFallback, err)
+	}
+	return start, end, nil
+}
+
 // RequestPort requests new port from global ports pool for specified ip and proto.
 // If port is 0 it returns first free port. Otherwise it cheks port availability
 // in pool and return that port or error if port is already busy.
@@ -140,7 +114,11 @@ func (p *PortAllocator) RequestPort(ip net.IP, proto string, port int) (int, err
 	ipstr := ip.String()
 	protomap, ok := p.ipMap[ipstr]
 	if !ok {
-		protomap = newProtoMap()
+		protomap = protoMap{
+			"tcp": p.newPortMap(),
+			"udp": p.newPortMap(),
+		}
+
 		p.ipMap[ipstr] = protomap
 	}
 	mapping := protomap[proto]
@@ -175,6 +153,15 @@ func (p *PortAllocator) ReleasePort(ip net.IP, proto string, port int) error {
 	return nil
 }
 
+func (p *PortAllocator) newPortMap() *portMap {
+	return &portMap{
+		p:     map[int]struct{}{},
+		begin: p.Begin,
+		end:   p.End,
+		last:  p.End,
+	}
+}
+
 // ReleaseAll releases all ports for all ips.
 func (p *PortAllocator) ReleaseAll() error {
 	p.mutex.Lock()
@@ -185,10 +172,10 @@ func (p *PortAllocator) ReleaseAll() error {
 
 func (pm *portMap) findPort() (int, error) {
 	port := pm.last
-	for i := 0; i <= endPortRange-beginPortRange; i++ {
+	for i := 0; i <= pm.end-pm.begin; i++ {
 		port++
-		if port > endPortRange {
-			port = beginPortRange
+		if port > pm.end {
+			port = pm.begin
 		}
 
 		if _, ok := pm.p[port]; !ok {

+ 10 - 15
daemon/networkdriver/portallocator/portallocator_test.go

@@ -5,11 +5,6 @@ import (
 	"testing"
 )
 
-func init() {
-	beginPortRange = DefaultPortRangeStart
-	endPortRange = DefaultPortRangeEnd
-}
-
 func TestRequestNewPort(t *testing.T) {
 	p := New()
 
@@ -18,7 +13,7 @@ func TestRequestNewPort(t *testing.T) {
 		t.Fatal(err)
 	}
 
-	if expected := beginPortRange; port != expected {
+	if expected := p.Begin; port != expected {
 		t.Fatalf("Expected port %d got %d", expected, port)
 	}
 }
@@ -101,13 +96,13 @@ func TestUnknowProtocol(t *testing.T) {
 func TestAllocateAllPorts(t *testing.T) {
 	p := New()
 
-	for i := 0; i <= endPortRange-beginPortRange; i++ {
+	for i := 0; i <= p.End-p.Begin; i++ {
 		port, err := p.RequestPort(defaultIP, "tcp", 0)
 		if err != nil {
 			t.Fatal(err)
 		}
 
-		if expected := beginPortRange + i; port != expected {
+		if expected := p.Begin + i; port != expected {
 			t.Fatalf("Expected port %d got %d", expected, port)
 		}
 	}
@@ -122,7 +117,7 @@ func TestAllocateAllPorts(t *testing.T) {
 	}
 
 	// release a port in the middle and ensure we get another tcp port
-	port := beginPortRange + 5
+	port := p.Begin + 5
 	if err := p.ReleasePort(defaultIP, "tcp", port); err != nil {
 		t.Fatal(err)
 	}
@@ -152,13 +147,13 @@ func BenchmarkAllocatePorts(b *testing.B) {
 	p := New()
 
 	for i := 0; i < b.N; i++ {
-		for i := 0; i <= endPortRange-beginPortRange; i++ {
+		for i := 0; i <= p.End-p.Begin; i++ {
 			port, err := p.RequestPort(defaultIP, "tcp", 0)
 			if err != nil {
 				b.Fatal(err)
 			}
 
-			if expected := beginPortRange + i; port != expected {
+			if expected := p.Begin + i; port != expected {
 				b.Fatalf("Expected port %d got %d", expected, port)
 			}
 		}
@@ -230,15 +225,15 @@ func TestPortAllocation(t *testing.T) {
 func TestNoDuplicateBPR(t *testing.T) {
 	p := New()
 
-	if port, err := p.RequestPort(defaultIP, "tcp", beginPortRange); err != nil {
+	if port, err := p.RequestPort(defaultIP, "tcp", p.Begin); err != nil {
 		t.Fatal(err)
-	} else if port != beginPortRange {
-		t.Fatalf("Expected port %d got %d", beginPortRange, port)
+	} else if port != p.Begin {
+		t.Fatalf("Expected port %d got %d", p.Begin, port)
 	}
 
 	if port, err := p.RequestPort(defaultIP, "tcp", 0); err != nil {
 		t.Fatal(err)
-	} else if port == beginPortRange {
+	} else if port == p.Begin {
 		t.Fatalf("Acquire(0) allocated the same port twice: %d", port)
 	}
 }