diff --git a/runtime/networkdriver/portallocator/portallocator.go b/runtime/networkdriver/portallocator/portallocator.go index 71cac82703..4d698f2de2 100644 --- a/runtime/networkdriver/portallocator/portallocator.go +++ b/runtime/networkdriver/portallocator/portallocator.go @@ -100,22 +100,30 @@ func ReleaseAll() error { } func registerDynamicPort(ip net.IP, proto string) (int, error) { - allocated := defaultAllocatedPorts[proto] - - port := nextPort(proto) - if port > EndPortRange { - return 0, ErrPortExceedsRange - } if !equalsDefault(ip) { registerIP(ip) ipAllocated := otherAllocatedPorts[ip.String()][proto] + + port, err := findNextPort(proto, ipAllocated) + if err != nil { + return 0, err + } ipAllocated.Push(port) + return port, nil + } else { + + allocated := defaultAllocatedPorts[proto] + + port, err := findNextPort(proto, allocated) + if err != nil { + return 0, err + } allocated.Push(port) + return port, nil } - return port, nil } func registerSetPort(ip net.IP, proto string, port int) error { @@ -142,6 +150,17 @@ func equalsDefault(ip net.IP) bool { return ip == nil || ip.Equal(defaultIP) } +func findNextPort(proto string, allocated *collections.OrderedIntSet) (int, error) { + port := nextPort(proto) + for allocated.Exists(port) { + port = nextPort(proto) + } + if port > EndPortRange { + return 0, ErrPortExceedsRange + } + return port, nil +} + func nextPort(proto string) int { c := currentDynamicPort[proto] + 1 currentDynamicPort[proto] = c diff --git a/runtime/networkdriver/portallocator/portallocator_test.go b/runtime/networkdriver/portallocator/portallocator_test.go index 603bd03bd7..f01bcfc99e 100644 --- a/runtime/networkdriver/portallocator/portallocator_test.go +++ b/runtime/networkdriver/portallocator/portallocator_test.go @@ -181,4 +181,20 @@ func TestPortAllocation(t *testing.T) { if _, err := RequestPort(ip, "tcp", 80); err != nil { t.Fatal(err) } + + port, err = RequestPort(ip, "tcp", 0) + if err != nil { + t.Fatal(err) + } + port2, err := RequestPort(ip, "tcp", port+1) + if err != nil { + t.Fatal(err) + } + port3, err := RequestPort(ip, "tcp", 0) + if err != nil { + t.Fatal(err) + } + if port3 == port2 { + t.Fatal("Requesting a dynamic port should never allocate a used port") + } }