sockaddrs.go 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. package sockaddr
  2. import (
  3. "bytes"
  4. "sort"
  5. )
  6. // SockAddrs is a slice of SockAddrs
  7. type SockAddrs []SockAddr
  8. func (s SockAddrs) Len() int { return len(s) }
  9. func (s SockAddrs) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
  10. // CmpAddrFunc is the function signature that must be met to be used in the
  11. // OrderedAddrBy multiAddrSorter
  12. type CmpAddrFunc func(p1, p2 *SockAddr) int
  13. // multiAddrSorter implements the Sort interface, sorting the SockAddrs within.
  14. type multiAddrSorter struct {
  15. addrs SockAddrs
  16. cmp []CmpAddrFunc
  17. }
  18. // Sort sorts the argument slice according to the Cmp functions passed to
  19. // OrderedAddrBy.
  20. func (ms *multiAddrSorter) Sort(sockAddrs SockAddrs) {
  21. ms.addrs = sockAddrs
  22. sort.Sort(ms)
  23. }
  24. // OrderedAddrBy sorts SockAddr by the list of sort function pointers.
  25. func OrderedAddrBy(cmpFuncs ...CmpAddrFunc) *multiAddrSorter {
  26. return &multiAddrSorter{
  27. cmp: cmpFuncs,
  28. }
  29. }
  30. // Len is part of sort.Interface.
  31. func (ms *multiAddrSorter) Len() int {
  32. return len(ms.addrs)
  33. }
  34. // Less is part of sort.Interface. It is implemented by looping along the
  35. // Cmp() functions until it finds a comparison that is either less than,
  36. // equal to, or greater than.
  37. func (ms *multiAddrSorter) Less(i, j int) bool {
  38. p, q := &ms.addrs[i], &ms.addrs[j]
  39. // Try all but the last comparison.
  40. var k int
  41. for k = 0; k < len(ms.cmp)-1; k++ {
  42. cmp := ms.cmp[k]
  43. x := cmp(p, q)
  44. switch x {
  45. case -1:
  46. // p < q, so we have a decision.
  47. return true
  48. case 1:
  49. // p > q, so we have a decision.
  50. return false
  51. }
  52. // p == q; try the next comparison.
  53. }
  54. // All comparisons to here said "equal", so just return whatever the
  55. // final comparison reports.
  56. switch ms.cmp[k](p, q) {
  57. case -1:
  58. return true
  59. case 1:
  60. return false
  61. default:
  62. // Still a tie! Now what?
  63. return false
  64. }
  65. }
  66. // Swap is part of sort.Interface.
  67. func (ms *multiAddrSorter) Swap(i, j int) {
  68. ms.addrs[i], ms.addrs[j] = ms.addrs[j], ms.addrs[i]
  69. }
  70. const (
  71. // NOTE (sean@): These constants are here for code readability only and
  72. // are sprucing up the code for readability purposes. Some of the
  73. // Cmp*() variants have confusing logic (especially when dealing with
  74. // mixed-type comparisons) and this, I think, has made it easier to grok
  75. // the code faster.
  76. sortReceiverBeforeArg = -1
  77. sortDeferDecision = 0
  78. sortArgBeforeReceiver = 1
  79. )
  80. // AscAddress is a sorting function to sort SockAddrs by their respective
  81. // address type. Non-equal types are deferred in the sort.
  82. func AscAddress(p1Ptr, p2Ptr *SockAddr) int {
  83. p1 := *p1Ptr
  84. p2 := *p2Ptr
  85. switch v := p1.(type) {
  86. case IPv4Addr:
  87. return v.CmpAddress(p2)
  88. case IPv6Addr:
  89. return v.CmpAddress(p2)
  90. case UnixSock:
  91. return v.CmpAddress(p2)
  92. default:
  93. return sortDeferDecision
  94. }
  95. }
  96. // AscPort is a sorting function to sort SockAddrs by their respective address
  97. // type. Non-equal types are deferred in the sort.
  98. func AscPort(p1Ptr, p2Ptr *SockAddr) int {
  99. p1 := *p1Ptr
  100. p2 := *p2Ptr
  101. switch v := p1.(type) {
  102. case IPv4Addr:
  103. return v.CmpPort(p2)
  104. case IPv6Addr:
  105. return v.CmpPort(p2)
  106. default:
  107. return sortDeferDecision
  108. }
  109. }
  110. // AscPrivate is a sorting function to sort "more secure" private values before
  111. // "more public" values. Both IPv4 and IPv6 are compared against RFC6890
  112. // (RFC6890 includes, and is not limited to, RFC1918 and RFC6598 for IPv4, and
  113. // IPv6 includes RFC4193).
  114. func AscPrivate(p1Ptr, p2Ptr *SockAddr) int {
  115. p1 := *p1Ptr
  116. p2 := *p2Ptr
  117. switch v := p1.(type) {
  118. case IPv4Addr, IPv6Addr:
  119. return v.CmpRFC(6890, p2)
  120. default:
  121. return sortDeferDecision
  122. }
  123. }
  124. // AscNetworkSize is a sorting function to sort SockAddrs based on their network
  125. // size. Non-equal types are deferred in the sort.
  126. func AscNetworkSize(p1Ptr, p2Ptr *SockAddr) int {
  127. p1 := *p1Ptr
  128. p2 := *p2Ptr
  129. p1Type := p1.Type()
  130. p2Type := p2.Type()
  131. // Network size operations on non-IP types make no sense
  132. if p1Type != p2Type && p1Type != TypeIP {
  133. return sortDeferDecision
  134. }
  135. ipA := p1.(IPAddr)
  136. ipB := p2.(IPAddr)
  137. return bytes.Compare([]byte(*ipA.NetIPMask()), []byte(*ipB.NetIPMask()))
  138. }
  139. // AscType is a sorting function to sort "more secure" types before
  140. // "less-secure" types.
  141. func AscType(p1Ptr, p2Ptr *SockAddr) int {
  142. p1 := *p1Ptr
  143. p2 := *p2Ptr
  144. p1Type := p1.Type()
  145. p2Type := p2.Type()
  146. switch {
  147. case p1Type < p2Type:
  148. return sortReceiverBeforeArg
  149. case p1Type == p2Type:
  150. return sortDeferDecision
  151. case p1Type > p2Type:
  152. return sortArgBeforeReceiver
  153. default:
  154. return sortDeferDecision
  155. }
  156. }
  157. // FilterByType returns two lists: a list of matched and unmatched SockAddrs
  158. func (sas SockAddrs) FilterByType(type_ SockAddrType) (matched, excluded SockAddrs) {
  159. matched = make(SockAddrs, 0, len(sas))
  160. excluded = make(SockAddrs, 0, len(sas))
  161. for _, sa := range sas {
  162. if sa.Type()&type_ != 0 {
  163. matched = append(matched, sa)
  164. } else {
  165. excluded = append(excluded, sa)
  166. }
  167. }
  168. return matched, excluded
  169. }