netlink.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. // +build linux
  2. package ipvs
  3. import (
  4. "bytes"
  5. "encoding/binary"
  6. "fmt"
  7. "net"
  8. "os/exec"
  9. "strings"
  10. "sync"
  11. "sync/atomic"
  12. "syscall"
  13. "unsafe"
  14. "github.com/Sirupsen/logrus"
  15. "github.com/vishvananda/netlink/nl"
  16. "github.com/vishvananda/netns"
  17. )
  18. var (
  19. native = nl.NativeEndian()
  20. ipvsFamily int
  21. ipvsOnce sync.Once
  22. )
  23. type genlMsgHdr struct {
  24. cmd uint8
  25. version uint8
  26. reserved uint16
  27. }
  28. type ipvsFlags struct {
  29. flags uint32
  30. mask uint32
  31. }
  32. func deserializeGenlMsg(b []byte) (hdr *genlMsgHdr) {
  33. return (*genlMsgHdr)(unsafe.Pointer(&b[0:unsafe.Sizeof(*hdr)][0]))
  34. }
  35. func (hdr *genlMsgHdr) Serialize() []byte {
  36. return (*(*[unsafe.Sizeof(*hdr)]byte)(unsafe.Pointer(hdr)))[:]
  37. }
  38. func (hdr *genlMsgHdr) Len() int {
  39. return int(unsafe.Sizeof(*hdr))
  40. }
  41. func (f *ipvsFlags) Serialize() []byte {
  42. return (*(*[unsafe.Sizeof(*f)]byte)(unsafe.Pointer(f)))[:]
  43. }
  44. func (f *ipvsFlags) Len() int {
  45. return int(unsafe.Sizeof(*f))
  46. }
  47. func setup() {
  48. ipvsOnce.Do(func() {
  49. var err error
  50. if out, err := exec.Command("modprobe", "-va", "ip_vs").CombinedOutput(); err != nil {
  51. logrus.Warnf("Running modprobe ip_vs failed with message: `%s`, error: %v", strings.TrimSpace(string(out)), err)
  52. }
  53. ipvsFamily, err = getIPVSFamily()
  54. if err != nil {
  55. logrus.Error("Could not get ipvs family information from the kernel. It is possible that ipvs is not enabled in your kernel. Native loadbalancing will not work until this is fixed.")
  56. }
  57. })
  58. }
  59. func fillService(s *Service) nl.NetlinkRequestData {
  60. cmdAttr := nl.NewRtAttr(ipvsCmdAttrService, nil)
  61. nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrAddressFamily, nl.Uint16Attr(s.AddressFamily))
  62. if s.FWMark != 0 {
  63. nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrFWMark, nl.Uint32Attr(s.FWMark))
  64. } else {
  65. nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrProtocol, nl.Uint16Attr(s.Protocol))
  66. nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrAddress, rawIPData(s.Address))
  67. // Port needs to be in network byte order.
  68. portBuf := new(bytes.Buffer)
  69. binary.Write(portBuf, binary.BigEndian, s.Port)
  70. nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrPort, portBuf.Bytes())
  71. }
  72. nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrSchedName, nl.ZeroTerminated(s.SchedName))
  73. if s.PEName != "" {
  74. nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrPEName, nl.ZeroTerminated(s.PEName))
  75. }
  76. f := &ipvsFlags{
  77. flags: s.Flags,
  78. mask: 0xFFFFFFFF,
  79. }
  80. nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrFlags, f.Serialize())
  81. nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrTimeout, nl.Uint32Attr(s.Timeout))
  82. nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrNetmask, nl.Uint32Attr(s.Netmask))
  83. return cmdAttr
  84. }
  85. func fillDestinaton(d *Destination) nl.NetlinkRequestData {
  86. cmdAttr := nl.NewRtAttr(ipvsCmdAttrDest, nil)
  87. nl.NewRtAttrChild(cmdAttr, ipvsDestAttrAddress, rawIPData(d.Address))
  88. // Port needs to be in network byte order.
  89. portBuf := new(bytes.Buffer)
  90. binary.Write(portBuf, binary.BigEndian, d.Port)
  91. nl.NewRtAttrChild(cmdAttr, ipvsDestAttrPort, portBuf.Bytes())
  92. nl.NewRtAttrChild(cmdAttr, ipvsDestAttrForwardingMethod, nl.Uint32Attr(d.ConnectionFlags&ConnectionFlagFwdMask))
  93. nl.NewRtAttrChild(cmdAttr, ipvsDestAttrWeight, nl.Uint32Attr(uint32(d.Weight)))
  94. nl.NewRtAttrChild(cmdAttr, ipvsDestAttrUpperThreshold, nl.Uint32Attr(d.UpperThreshold))
  95. nl.NewRtAttrChild(cmdAttr, ipvsDestAttrLowerThreshold, nl.Uint32Attr(d.LowerThreshold))
  96. return cmdAttr
  97. }
  98. func (i *Handle) doCmd(s *Service, d *Destination, cmd uint8) error {
  99. req := newIPVSRequest(cmd)
  100. req.Seq = atomic.AddUint32(&i.seq, 1)
  101. req.AddData(fillService(s))
  102. if d != nil {
  103. req.AddData(fillDestinaton(d))
  104. }
  105. if _, err := execute(i.sock, req, 0); err != nil {
  106. return err
  107. }
  108. return nil
  109. }
  110. func getIPVSFamily() (int, error) {
  111. sock, err := nl.GetNetlinkSocketAt(netns.None(), netns.None(), syscall.NETLINK_GENERIC)
  112. if err != nil {
  113. return 0, err
  114. }
  115. defer sock.Close()
  116. req := newGenlRequest(genlCtrlID, genlCtrlCmdGetFamily)
  117. req.AddData(nl.NewRtAttr(genlCtrlAttrFamilyName, nl.ZeroTerminated("IPVS")))
  118. msgs, err := execute(sock, req, 0)
  119. if err != nil {
  120. return 0, err
  121. }
  122. for _, m := range msgs {
  123. hdr := deserializeGenlMsg(m)
  124. attrs, err := nl.ParseRouteAttr(m[hdr.Len():])
  125. if err != nil {
  126. return 0, err
  127. }
  128. for _, attr := range attrs {
  129. switch int(attr.Attr.Type) {
  130. case genlCtrlAttrFamilyID:
  131. return int(native.Uint16(attr.Value[0:2])), nil
  132. }
  133. }
  134. }
  135. return 0, fmt.Errorf("no family id in the netlink response")
  136. }
  137. func rawIPData(ip net.IP) []byte {
  138. family := nl.GetIPFamily(ip)
  139. if family == nl.FAMILY_V4 {
  140. return ip.To4()
  141. }
  142. return ip
  143. }
  144. func newIPVSRequest(cmd uint8) *nl.NetlinkRequest {
  145. return newGenlRequest(ipvsFamily, cmd)
  146. }
  147. func newGenlRequest(familyID int, cmd uint8) *nl.NetlinkRequest {
  148. req := nl.NewNetlinkRequest(familyID, syscall.NLM_F_ACK)
  149. req.AddData(&genlMsgHdr{cmd: cmd, version: 1})
  150. return req
  151. }
  152. func execute(s *nl.NetlinkSocket, req *nl.NetlinkRequest, resType uint16) ([][]byte, error) {
  153. var (
  154. err error
  155. )
  156. if err := s.Send(req); err != nil {
  157. return nil, err
  158. }
  159. pid, err := s.GetPid()
  160. if err != nil {
  161. return nil, err
  162. }
  163. var res [][]byte
  164. done:
  165. for {
  166. msgs, err := s.Receive()
  167. if err != nil {
  168. return nil, err
  169. }
  170. for _, m := range msgs {
  171. if m.Header.Seq != req.Seq {
  172. continue
  173. }
  174. if m.Header.Pid != pid {
  175. return nil, fmt.Errorf("Wrong pid %d, expected %d", m.Header.Pid, pid)
  176. }
  177. if m.Header.Type == syscall.NLMSG_DONE {
  178. break done
  179. }
  180. if m.Header.Type == syscall.NLMSG_ERROR {
  181. error := int32(native.Uint32(m.Data[0:4]))
  182. if error == 0 {
  183. break done
  184. }
  185. return nil, syscall.Errno(-error)
  186. }
  187. if resType != 0 && m.Header.Type != resType {
  188. continue
  189. }
  190. res = append(res, m.Data)
  191. if m.Header.Flags&syscall.NLM_F_MULTI == 0 {
  192. break done
  193. }
  194. }
  195. }
  196. return res, nil
  197. }