netlink.go 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. // +build linux
  2. package ipvs
  3. import (
  4. "bytes"
  5. "encoding/binary"
  6. "fmt"
  7. "net"
  8. "sync"
  9. "syscall"
  10. "unsafe"
  11. "github.com/vishvananda/netlink/nl"
  12. "github.com/vishvananda/netns"
  13. )
  14. var (
  15. native = nl.NativeEndian()
  16. ipvsFamily int
  17. ipvsOnce sync.Once
  18. )
  19. type genlMsgHdr struct {
  20. cmd uint8
  21. version uint8
  22. reserved uint16
  23. }
  24. type ipvsFlags struct {
  25. flags uint32
  26. mask uint32
  27. }
  28. func deserializeGenlMsg(b []byte) (hdr *genlMsgHdr) {
  29. return (*genlMsgHdr)(unsafe.Pointer(&b[0:unsafe.Sizeof(*hdr)][0]))
  30. }
  31. func (hdr *genlMsgHdr) Serialize() []byte {
  32. return (*(*[unsafe.Sizeof(*hdr)]byte)(unsafe.Pointer(hdr)))[:]
  33. }
  34. func (hdr *genlMsgHdr) Len() int {
  35. return int(unsafe.Sizeof(*hdr))
  36. }
  37. func (f *ipvsFlags) Serialize() []byte {
  38. return (*(*[unsafe.Sizeof(*f)]byte)(unsafe.Pointer(f)))[:]
  39. }
  40. func (f *ipvsFlags) Len() int {
  41. return int(unsafe.Sizeof(*f))
  42. }
  43. func setup() {
  44. ipvsOnce.Do(func() {
  45. var err error
  46. ipvsFamily, err = getIPVSFamily()
  47. if err != nil {
  48. panic("could not get ipvs family")
  49. }
  50. })
  51. }
  52. func fillService(s *Service) nl.NetlinkRequestData {
  53. cmdAttr := nl.NewRtAttr(ipvsCmdAttrService, nil)
  54. nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrAddressFamily, nl.Uint16Attr(s.AddressFamily))
  55. if s.FWMark != 0 {
  56. nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrFWMark, nl.Uint32Attr(s.FWMark))
  57. } else {
  58. nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrProtocol, nl.Uint16Attr(s.Protocol))
  59. nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrAddress, rawIPData(s.Address))
  60. // Port needs to be in network byte order.
  61. portBuf := new(bytes.Buffer)
  62. binary.Write(portBuf, binary.BigEndian, s.Port)
  63. nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrPort, portBuf.Bytes())
  64. }
  65. nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrSchedName, nl.ZeroTerminated(s.SchedName))
  66. if s.PEName != "" {
  67. nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrPEName, nl.ZeroTerminated(s.PEName))
  68. }
  69. f := &ipvsFlags{
  70. flags: s.Flags,
  71. mask: 0xFFFFFFFF,
  72. }
  73. nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrFlags, f.Serialize())
  74. nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrTimeout, nl.Uint32Attr(s.Timeout))
  75. nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrNetmask, nl.Uint32Attr(s.Netmask))
  76. return cmdAttr
  77. }
  78. func fillDestinaton(d *Destination) nl.NetlinkRequestData {
  79. cmdAttr := nl.NewRtAttr(ipvsCmdAttrDest, nil)
  80. nl.NewRtAttrChild(cmdAttr, ipvsDestAttrAddress, rawIPData(d.Address))
  81. // Port needs to be in network byte order.
  82. portBuf := new(bytes.Buffer)
  83. binary.Write(portBuf, binary.BigEndian, d.Port)
  84. nl.NewRtAttrChild(cmdAttr, ipvsDestAttrPort, portBuf.Bytes())
  85. nl.NewRtAttrChild(cmdAttr, ipvsDestAttrForwardingMethod, nl.Uint32Attr(d.ConnectionFlags&ConnectionFlagFwdMask))
  86. nl.NewRtAttrChild(cmdAttr, ipvsDestAttrWeight, nl.Uint32Attr(uint32(d.Weight)))
  87. nl.NewRtAttrChild(cmdAttr, ipvsDestAttrUpperThreshold, nl.Uint32Attr(d.UpperThreshold))
  88. nl.NewRtAttrChild(cmdAttr, ipvsDestAttrLowerThreshold, nl.Uint32Attr(d.LowerThreshold))
  89. return cmdAttr
  90. }
  91. func (i *Handle) doCmd(s *Service, d *Destination, cmd uint8) error {
  92. req := newIPVSRequest(cmd)
  93. req.AddData(fillService(s))
  94. if d != nil {
  95. req.AddData(fillDestinaton(d))
  96. }
  97. if _, err := execute(i.sock, req, 0); err != nil {
  98. return err
  99. }
  100. return nil
  101. }
  102. func getIPVSFamily() (int, error) {
  103. sock, err := nl.GetNetlinkSocketAt(netns.None(), netns.None(), syscall.NETLINK_GENERIC)
  104. if err != nil {
  105. return 0, err
  106. }
  107. req := newGenlRequest(genlCtrlID, genlCtrlCmdGetFamily)
  108. req.AddData(nl.NewRtAttr(genlCtrlAttrFamilyName, nl.ZeroTerminated("IPVS")))
  109. msgs, err := execute(sock, req, 0)
  110. if err != nil {
  111. return 0, err
  112. }
  113. for _, m := range msgs {
  114. hdr := deserializeGenlMsg(m)
  115. attrs, err := nl.ParseRouteAttr(m[hdr.Len():])
  116. if err != nil {
  117. return 0, err
  118. }
  119. for _, attr := range attrs {
  120. switch int(attr.Attr.Type) {
  121. case genlCtrlAttrFamilyID:
  122. return int(native.Uint16(attr.Value[0:2])), nil
  123. }
  124. }
  125. }
  126. return 0, fmt.Errorf("no family id in the netlink response")
  127. }
  128. func rawIPData(ip net.IP) []byte {
  129. family := nl.GetIPFamily(ip)
  130. if family == nl.FAMILY_V4 {
  131. return ip.To4()
  132. }
  133. return ip
  134. }
  135. func newIPVSRequest(cmd uint8) *nl.NetlinkRequest {
  136. return newGenlRequest(ipvsFamily, cmd)
  137. }
  138. func newGenlRequest(familyID int, cmd uint8) *nl.NetlinkRequest {
  139. req := nl.NewNetlinkRequest(familyID, syscall.NLM_F_ACK)
  140. req.AddData(&genlMsgHdr{cmd: cmd, version: 1})
  141. return req
  142. }
  143. func execute(s *nl.NetlinkSocket, req *nl.NetlinkRequest, resType uint16) ([][]byte, error) {
  144. var (
  145. err error
  146. )
  147. if err := s.Send(req); err != nil {
  148. return nil, err
  149. }
  150. pid, err := s.GetPid()
  151. if err != nil {
  152. return nil, err
  153. }
  154. var res [][]byte
  155. done:
  156. for {
  157. msgs, err := s.Receive()
  158. if err != nil {
  159. return nil, err
  160. }
  161. for _, m := range msgs {
  162. if m.Header.Seq != req.Seq {
  163. return nil, fmt.Errorf("Wrong Seq nr %d, expected %d", m.Header.Seq, req.Seq)
  164. }
  165. if m.Header.Pid != pid {
  166. return nil, fmt.Errorf("Wrong pid %d, expected %d", m.Header.Pid, pid)
  167. }
  168. if m.Header.Type == syscall.NLMSG_DONE {
  169. break done
  170. }
  171. if m.Header.Type == syscall.NLMSG_ERROR {
  172. error := int32(native.Uint32(m.Data[0:4]))
  173. if error == 0 {
  174. break done
  175. }
  176. return nil, syscall.Errno(-error)
  177. }
  178. if resType != 0 && m.Header.Type != resType {
  179. continue
  180. }
  181. res = append(res, m.Data)
  182. if m.Header.Flags&syscall.NLM_F_MULTI == 0 {
  183. break done
  184. }
  185. }
  186. }
  187. return res, nil
  188. }