netlink_linux.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578
  1. package netlink
  2. import (
  3. "encoding/binary"
  4. "fmt"
  5. "net"
  6. "syscall"
  7. "unsafe"
  8. )
  // nextSeqNr is the sequence number most recently handed out by getSeq.
  // NOTE(review): it is read and written without any synchronization, so
  // concurrent calls into the request-building functions race on it.
  var nextSeqNr int

  // getSeq increments nextSeqNr and returns the new value, used as the Seq
  // of each outgoing netlink request.
  func getSeq() int {
  	nextSeqNr++
  	return nextSeqNr
  }

  // nativeEndian reports the host byte order by looking at which byte of a
  // known 16-bit value is stored first in memory.
  func nativeEndian() binary.ByteOrder {
  	var probe uint16 = 0x0102
  	if first := *(*byte)(unsafe.Pointer(&probe)); first == 0x01 {
  		return binary.BigEndian
  	}
  	return binary.LittleEndian
  }
  // getIpFamily picks the netlink address family for ip. Inputs of at most
  // net.IPv4len bytes (including nil) and 16-byte IPv4-mapped addresses map
  // to AF_INET; everything else maps to AF_INET6.
  func getIpFamily(ip net.IP) int {
  	if len(ip) > net.IPv4len && ip.To4() == nil {
  		return syscall.AF_INET6
  	}
  	return syscall.AF_INET
  }
  // NetlinkRequestData is a payload that can be appended to a netlink
  // request. ToWireFormat returns the payload's exact on-the-wire bytes,
  // which are copied verbatim after the message header.
  type NetlinkRequestData interface {
  	ToWireFormat() []byte
  }
  33. type IfInfomsg struct {
  34. syscall.IfInfomsg
  35. }
  36. func newIfInfomsg(family int) *IfInfomsg {
  37. msg := &IfInfomsg{}
  38. msg.Family = uint8(family)
  39. msg.Type = uint16(0)
  40. msg.Index = int32(0)
  41. msg.Flags = uint32(0)
  42. msg.Change = uint32(0)
  43. return msg
  44. }
  45. func (msg *IfInfomsg) ToWireFormat() []byte {
  46. native := nativeEndian()
  47. len := syscall.SizeofIfInfomsg
  48. b := make([]byte, len)
  49. b[0] = msg.Family
  50. b[1] = 0
  51. native.PutUint16(b[2:4], msg.Type)
  52. native.PutUint32(b[4:8], uint32(msg.Index))
  53. native.PutUint32(b[8:12], msg.Flags)
  54. native.PutUint32(b[12:16], msg.Change)
  55. return b
  56. }
  57. type IfAddrmsg struct {
  58. syscall.IfAddrmsg
  59. }
  60. func newIfAddrmsg(family int) *IfAddrmsg {
  61. msg := &IfAddrmsg{}
  62. msg.Family = uint8(family)
  63. msg.Prefixlen = uint8(0)
  64. msg.Flags = uint8(0)
  65. msg.Scope = uint8(0)
  66. msg.Index = uint32(0)
  67. return msg
  68. }
  69. func (msg *IfAddrmsg) ToWireFormat() []byte {
  70. native := nativeEndian()
  71. len := syscall.SizeofIfAddrmsg
  72. b := make([]byte, len)
  73. b[0] = msg.Family
  74. b[1] = msg.Prefixlen
  75. b[2] = msg.Flags
  76. b[3] = msg.Scope
  77. native.PutUint32(b[4:8], msg.Index)
  78. return b
  79. }
  80. type RtMsg struct {
  81. syscall.RtMsg
  82. }
  83. func newRtMsg(family int) *RtMsg {
  84. msg := &RtMsg{}
  85. msg.Family = uint8(family)
  86. msg.Table = syscall.RT_TABLE_MAIN
  87. msg.Scope = syscall.RT_SCOPE_UNIVERSE
  88. msg.Protocol = syscall.RTPROT_BOOT
  89. msg.Type = syscall.RTN_UNICAST
  90. return msg
  91. }
  92. func (msg *RtMsg) ToWireFormat() []byte {
  93. native := nativeEndian()
  94. len := syscall.SizeofRtMsg
  95. b := make([]byte, len)
  96. b[0] = msg.Family
  97. b[1] = msg.Dst_len
  98. b[2] = msg.Src_len
  99. b[3] = msg.Tos
  100. b[4] = msg.Table
  101. b[5] = msg.Protocol
  102. b[6] = msg.Scope
  103. b[7] = msg.Type
  104. native.PutUint32(b[8:12], msg.Flags)
  105. return b
  106. }
  107. func rtaAlignOf(attrlen int) int {
  108. return (attrlen + syscall.RTA_ALIGNTO - 1) & ^(syscall.RTA_ALIGNTO - 1)
  109. }
  110. type RtAttr struct {
  111. syscall.RtAttr
  112. Data []byte
  113. }
  114. func newRtAttr(attrType int, data []byte) *RtAttr {
  115. attr := &RtAttr{}
  116. attr.Type = uint16(attrType)
  117. attr.Data = data
  118. return attr
  119. }
  120. func (attr *RtAttr) ToWireFormat() []byte {
  121. native := nativeEndian()
  122. len := syscall.SizeofRtAttr + len(attr.Data)
  123. b := make([]byte, rtaAlignOf(len))
  124. native.PutUint16(b[0:2], uint16(len))
  125. native.PutUint16(b[2:4], attr.Type)
  126. for i, d := range attr.Data {
  127. b[4+i] = d
  128. }
  129. return b
  130. }
  131. type NetlinkRequest struct {
  132. syscall.NlMsghdr
  133. Data []NetlinkRequestData
  134. }
  135. func (rr *NetlinkRequest) ToWireFormat() []byte {
  136. native := nativeEndian()
  137. length := rr.Len
  138. dataBytes := make([][]byte, len(rr.Data))
  139. for i, data := range rr.Data {
  140. dataBytes[i] = data.ToWireFormat()
  141. length = length + uint32(len(dataBytes[i]))
  142. }
  143. b := make([]byte, length)
  144. native.PutUint32(b[0:4], length)
  145. native.PutUint16(b[4:6], rr.Type)
  146. native.PutUint16(b[6:8], rr.Flags)
  147. native.PutUint32(b[8:12], rr.Seq)
  148. native.PutUint32(b[12:16], rr.Pid)
  149. i := 16
  150. for _, data := range dataBytes {
  151. for _, dataByte := range data {
  152. b[i] = dataByte
  153. i = i + 1
  154. }
  155. }
  156. return b
  157. }
  158. func (rr *NetlinkRequest) AddData(data NetlinkRequestData) {
  159. rr.Data = append(rr.Data, data)
  160. }
  161. func newNetlinkRequest(proto, flags int) *NetlinkRequest {
  162. rr := &NetlinkRequest{}
  163. rr.Len = uint32(syscall.NLMSG_HDRLEN)
  164. rr.Type = uint16(proto)
  165. rr.Flags = syscall.NLM_F_REQUEST | uint16(flags)
  166. rr.Seq = uint32(getSeq())
  167. return rr
  168. }
  // NetlinkSocket is a NETLINK_ROUTE socket descriptor together with the
  // netlink address it was bound to (also used as the Sendto destination).
  type NetlinkSocket struct {
  	fd  int
  	lsa syscall.SockaddrNetlink
  }

  // getNetlinkSocket opens a raw NETLINK_ROUTE socket and binds it to an
  // address with only Family set (Pid and Groups zero; per netlink(7) the
  // kernel then assigns the port ID). If bind fails the descriptor is closed
  // before returning. The caller must Close the returned socket.
  func getNetlinkSocket() (*NetlinkSocket, error) {
  	// SOCK_CLOEXEC keeps the descriptor from leaking into child processes
  	// that are exec'd while the socket is open.
  	fd, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_RAW|syscall.SOCK_CLOEXEC, syscall.NETLINK_ROUTE)
  	if err != nil {
  		return nil, err
  	}
  	s := &NetlinkSocket{fd: fd}
  	s.lsa.Family = syscall.AF_NETLINK
  	if err := syscall.Bind(fd, &s.lsa); err != nil {
  		syscall.Close(fd)
  		return nil, err
  	}
  	return s, nil
  }

  // Close closes the socket's descriptor; any error from close(2) is
  // discarded.
  func (s *NetlinkSocket) Close() {
  	syscall.Close(s.fd)
  }
  191. func (s *NetlinkSocket) Send(request *NetlinkRequest) error {
  192. if err := syscall.Sendto(s.fd, request.ToWireFormat(), 0, &s.lsa); err != nil {
  193. return err
  194. }
  195. return nil
  196. }
  197. func (s *NetlinkSocket) Recieve() ([]syscall.NetlinkMessage, error) {
  198. rb := make([]byte, syscall.Getpagesize())
  199. nr, _, err := syscall.Recvfrom(s.fd, rb, 0)
  200. if err != nil {
  201. return nil, err
  202. }
  203. if nr < syscall.NLMSG_HDRLEN {
  204. return nil, fmt.Errorf("Got short response from netlink")
  205. }
  206. rb = rb[:nr]
  207. return syscall.ParseNetlinkMessage(rb)
  208. }
  209. func (s *NetlinkSocket) GetPid() (uint32, error) {
  210. lsa, err := syscall.Getsockname(s.fd)
  211. if err != nil {
  212. return 0, err
  213. }
  214. switch v := lsa.(type) {
  215. case *syscall.SockaddrNetlink:
  216. return v.Pid, nil
  217. }
  218. return 0, fmt.Errorf("Wrong socket type")
  219. }
  220. func (s *NetlinkSocket) HandleAck(seq uint32) error {
  221. native := nativeEndian()
  222. pid, err := s.GetPid()
  223. if err != nil {
  224. return err
  225. }
  226. done:
  227. for {
  228. msgs, err := s.Recieve()
  229. if err != nil {
  230. return err
  231. }
  232. for _, m := range msgs {
  233. if m.Header.Seq != seq {
  234. return fmt.Errorf("Wrong Seq nr %d, expected %d", m.Header.Seq, seq)
  235. }
  236. if m.Header.Pid != pid {
  237. return fmt.Errorf("Wrong pid %d, expected %d", m.Header.Pid, pid)
  238. }
  239. if m.Header.Type == syscall.NLMSG_DONE {
  240. break done
  241. }
  242. if m.Header.Type == syscall.NLMSG_ERROR {
  243. error := int32(native.Uint32(m.Data[0:4]))
  244. if error == 0 {
  245. break done
  246. }
  247. return syscall.Errno(-error)
  248. }
  249. }
  250. }
  251. return nil
  252. }
  253. // Add a new default gateway. Identical to:
  254. // ip route add default via $ip
  255. func AddDefaultGw(ip net.IP) error {
  256. s, err := getNetlinkSocket()
  257. if err != nil {
  258. return err
  259. }
  260. defer s.Close()
  261. family := getIpFamily(ip)
  262. wb := newNetlinkRequest(syscall.RTM_NEWROUTE, syscall.NLM_F_CREATE|syscall.NLM_F_EXCL|syscall.NLM_F_ACK)
  263. msg := newRtMsg(family)
  264. wb.AddData(msg)
  265. var ipData []byte
  266. if family == syscall.AF_INET {
  267. ipData = ip.To4()
  268. } else {
  269. ipData = ip.To16()
  270. }
  271. gateway := newRtAttr(syscall.RTA_GATEWAY, ipData)
  272. wb.AddData(gateway)
  273. if err := s.Send(wb); err != nil {
  274. return err
  275. }
  276. return s.HandleAck(wb.Seq)
  277. }
  278. // Bring up a particular network interface
  279. func NetworkLinkUp(iface *net.Interface) error {
  280. s, err := getNetlinkSocket()
  281. if err != nil {
  282. return err
  283. }
  284. defer s.Close()
  285. wb := newNetlinkRequest(syscall.RTM_NEWLINK, syscall.NLM_F_ACK)
  286. msg := newIfInfomsg(syscall.AF_UNSPEC)
  287. msg.Change = syscall.IFF_UP
  288. msg.Flags = syscall.IFF_UP
  289. msg.Index = int32(iface.Index)
  290. wb.AddData(msg)
  291. if err := s.Send(wb); err != nil {
  292. return err
  293. }
  294. return s.HandleAck(wb.Seq)
  295. }
  296. func NetworkSetMTU(iface *net.Interface, mtu int) error {
  297. s, err := getNetlinkSocket()
  298. if err != nil {
  299. return err
  300. }
  301. defer s.Close()
  302. wb := newNetlinkRequest(syscall.RTM_SETLINK, syscall.NLM_F_ACK)
  303. msg := newIfInfomsg(syscall.AF_UNSPEC)
  304. msg.Type = syscall.RTM_SETLINK
  305. msg.Flags = syscall.NLM_F_REQUEST
  306. msg.Index = int32(iface.Index)
  307. msg.Change = 0xFFFFFFFF
  308. wb.AddData(msg)
  309. var (
  310. b = make([]byte, 4)
  311. native = nativeEndian()
  312. )
  313. native.PutUint32(b, uint32(mtu))
  314. data := newRtAttr(syscall.IFLA_MTU, b)
  315. wb.AddData(data)
  316. if err := s.Send(wb); err != nil {
  317. return err
  318. }
  319. return s.HandleAck(wb.Seq)
  320. }
  321. // Add an Ip address to an interface. This is identical to:
  322. // ip addr add $ip/$ipNet dev $iface
  323. func NetworkLinkAddIp(iface *net.Interface, ip net.IP, ipNet *net.IPNet) error {
  324. s, err := getNetlinkSocket()
  325. if err != nil {
  326. return err
  327. }
  328. defer s.Close()
  329. family := getIpFamily(ip)
  330. wb := newNetlinkRequest(syscall.RTM_NEWADDR, syscall.NLM_F_CREATE|syscall.NLM_F_EXCL|syscall.NLM_F_ACK)
  331. msg := newIfAddrmsg(family)
  332. msg.Index = uint32(iface.Index)
  333. prefixLen, _ := ipNet.Mask.Size()
  334. msg.Prefixlen = uint8(prefixLen)
  335. wb.AddData(msg)
  336. var ipData []byte
  337. if family == syscall.AF_INET {
  338. ipData = ip.To4()
  339. } else {
  340. ipData = ip.To16()
  341. }
  342. localData := newRtAttr(syscall.IFA_LOCAL, ipData)
  343. wb.AddData(localData)
  344. addrData := newRtAttr(syscall.IFA_ADDRESS, ipData)
  345. wb.AddData(addrData)
  346. if err := s.Send(wb); err != nil {
  347. return err
  348. }
  349. return s.HandleAck(wb.Seq)
  350. }
  // zeroTerminated returns a fresh copy of s's bytes followed by a single NUL
  // byte; it is used for the IFLA_IFNAME attribute.
  func zeroTerminated(s string) []byte {
  	return append([]byte(s), 0)
  }
  // nonZeroTerminated returns a fresh copy of s's bytes with no trailing NUL;
  // it is used for the link-kind attribute payload.
  func nonZeroTerminated(s string) []byte {
  	return []byte(s)
  }
  366. // Add a new network link of a specified type. This is identical to
  367. // running: ip add link $name type $linkType
  368. func NetworkLinkAdd(name string, linkType string) error {
  369. s, err := getNetlinkSocket()
  370. if err != nil {
  371. return err
  372. }
  373. defer s.Close()
  374. wb := newNetlinkRequest(syscall.RTM_NEWLINK, syscall.NLM_F_CREATE|syscall.NLM_F_EXCL|syscall.NLM_F_ACK)
  375. msg := newIfInfomsg(syscall.AF_UNSPEC)
  376. wb.AddData(msg)
  377. nameData := newRtAttr(syscall.IFLA_IFNAME, zeroTerminated(name))
  378. wb.AddData(nameData)
  379. IFLA_INFO_KIND := 1
  380. kindData := newRtAttr(IFLA_INFO_KIND, nonZeroTerminated(linkType))
  381. infoData := newRtAttr(syscall.IFLA_LINKINFO, kindData.ToWireFormat())
  382. wb.AddData(infoData)
  383. if err := s.Send(wb); err != nil {
  384. return err
  385. }
  386. return s.HandleAck(wb.Seq)
  387. }
  388. // Returns an array of IPNet for all the currently routed subnets on ipv4
  389. // This is similar to the first column of "ip route" output
  390. func NetworkGetRoutes() ([]*net.IPNet, error) {
  391. native := nativeEndian()
  392. s, err := getNetlinkSocket()
  393. if err != nil {
  394. return nil, err
  395. }
  396. defer s.Close()
  397. wb := newNetlinkRequest(syscall.RTM_GETROUTE, syscall.NLM_F_DUMP)
  398. msg := newIfInfomsg(syscall.AF_UNSPEC)
  399. wb.AddData(msg)
  400. if err := s.Send(wb); err != nil {
  401. return nil, err
  402. }
  403. pid, err := s.GetPid()
  404. if err != nil {
  405. return nil, err
  406. }
  407. res := make([]*net.IPNet, 0)
  408. done:
  409. for {
  410. msgs, err := s.Recieve()
  411. if err != nil {
  412. return nil, err
  413. }
  414. for _, m := range msgs {
  415. if m.Header.Seq != wb.Seq {
  416. return nil, fmt.Errorf("Wrong Seq nr %d, expected 1", m.Header.Seq)
  417. }
  418. if m.Header.Pid != pid {
  419. return nil, fmt.Errorf("Wrong pid %d, expected %d", m.Header.Pid, pid)
  420. }
  421. if m.Header.Type == syscall.NLMSG_DONE {
  422. break done
  423. }
  424. if m.Header.Type == syscall.NLMSG_ERROR {
  425. error := int32(native.Uint32(m.Data[0:4]))
  426. if error == 0 {
  427. break done
  428. }
  429. return nil, syscall.Errno(-error)
  430. }
  431. if m.Header.Type != syscall.RTM_NEWROUTE {
  432. continue
  433. }
  434. var iface *net.Interface = nil
  435. var ipNet *net.IPNet = nil
  436. msg := (*RtMsg)(unsafe.Pointer(&m.Data[0:syscall.SizeofRtMsg][0]))
  437. if msg.Flags&syscall.RTM_F_CLONED != 0 {
  438. // Ignore cloned routes
  439. continue
  440. }
  441. if msg.Table != syscall.RT_TABLE_MAIN {
  442. // Ignore non-main tables
  443. continue
  444. }
  445. if msg.Family != syscall.AF_INET {
  446. // Ignore non-ipv4 routes
  447. continue
  448. }
  449. if msg.Dst_len == 0 {
  450. // Ignore default routes
  451. continue
  452. }
  453. attrs, err := syscall.ParseNetlinkRouteAttr(&m)
  454. if err != nil {
  455. return nil, err
  456. }
  457. for _, attr := range attrs {
  458. switch attr.Attr.Type {
  459. case syscall.RTA_DST:
  460. ip := attr.Value
  461. ipNet = &net.IPNet{
  462. IP: ip,
  463. Mask: net.CIDRMask(int(msg.Dst_len), 8*len(ip)),
  464. }
  465. case syscall.RTA_OIF:
  466. index := int(native.Uint32(attr.Value[0:4]))
  467. iface, _ = net.InterfaceByIndex(index)
  468. _ = iface
  469. }
  470. }
  471. if ipNet != nil {
  472. res = append(res, ipNet)
  473. }
  474. }
  475. }
  476. return res, nil
  477. }