portallocator.go 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. package portallocator
  2. import (
  3. "errors"
  4. "net"
  5. "sync"
  6. )
  7. type (
  8. portMap map[int]bool
  9. protocolMap map[string]portMap
  10. ipMapping map[string]protocolMap
  11. )
  12. const (
  13. BeginPortRange = 49153
  14. EndPortRange = 65535
  15. )
  16. var (
  17. ErrAllPortsAllocated = errors.New("all ports are allocated")
  18. ErrPortAlreadyAllocated = errors.New("port has already been allocated")
  19. ErrUnknownProtocol = errors.New("unknown protocol")
  20. )
  21. var (
  22. mutex sync.Mutex
  23. defaultIP = net.ParseIP("0.0.0.0")
  24. globalMap = ipMapping{}
  25. )
  26. func RequestPort(ip net.IP, proto string, port int) (int, error) {
  27. mutex.Lock()
  28. defer mutex.Unlock()
  29. if err := validateProto(proto); err != nil {
  30. return 0, err
  31. }
  32. ip = getDefault(ip)
  33. mapping := getOrCreate(ip)
  34. if port > 0 {
  35. if !mapping[proto][port] {
  36. mapping[proto][port] = true
  37. return port, nil
  38. } else {
  39. return 0, ErrPortAlreadyAllocated
  40. }
  41. } else {
  42. port, err := findPort(ip, proto)
  43. if err != nil {
  44. return 0, err
  45. }
  46. return port, nil
  47. }
  48. }
  49. func ReleasePort(ip net.IP, proto string, port int) error {
  50. mutex.Lock()
  51. defer mutex.Unlock()
  52. ip = getDefault(ip)
  53. mapping := getOrCreate(ip)
  54. delete(mapping[proto], port)
  55. return nil
  56. }
  57. func ReleaseAll() error {
  58. mutex.Lock()
  59. defer mutex.Unlock()
  60. globalMap = ipMapping{}
  61. return nil
  62. }
  63. func getOrCreate(ip net.IP) protocolMap {
  64. ipstr := ip.String()
  65. if _, ok := globalMap[ipstr]; !ok {
  66. globalMap[ipstr] = protocolMap{
  67. "tcp": portMap{},
  68. "udp": portMap{},
  69. }
  70. }
  71. return globalMap[ipstr]
  72. }
  73. func findPort(ip net.IP, proto string) (int, error) {
  74. port := BeginPortRange
  75. mapping := getOrCreate(ip)
  76. for mapping[proto][port] {
  77. port++
  78. if port > EndPortRange {
  79. return 0, ErrAllPortsAllocated
  80. }
  81. }
  82. mapping[proto][port] = true
  83. return port, nil
  84. }
  85. func getDefault(ip net.IP) net.IP {
  86. if ip == nil {
  87. return defaultIP
  88. }
  89. return ip
  90. }
  91. func validateProto(proto string) error {
  92. if proto != "tcp" && proto != "udp" {
  93. return ErrUnknownProtocol
  94. }
  95. return nil
  96. }