encryption.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570
  1. package overlay
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "encoding/hex"
  6. "fmt"
  7. "hash/fnv"
  8. "net"
  9. "sync"
  10. "syscall"
  11. "strconv"
  12. "github.com/Sirupsen/logrus"
  13. "github.com/docker/libnetwork/iptables"
  14. "github.com/docker/libnetwork/ns"
  15. "github.com/docker/libnetwork/types"
  16. "github.com/vishvananda/netlink"
  17. )
  18. const (
  19. mark = uint32(0xD0C4E3)
  20. timeout = 30
  21. pktExpansion = 26 // SPI(4) + SeqN(4) + IV(8) + PadLength(1) + NextHeader(1) + ICV(8)
  22. )
  23. const (
  24. forward = iota + 1
  25. reverse
  26. bidir
  27. )
  28. type key struct {
  29. value []byte
  30. tag uint32
  31. }
  32. func (k *key) String() string {
  33. if k != nil {
  34. return fmt.Sprintf("(key: %s, tag: 0x%x)", hex.EncodeToString(k.value)[0:5], k.tag)
  35. }
  36. return ""
  37. }
  38. type spi struct {
  39. forward int
  40. reverse int
  41. }
  42. func (s *spi) String() string {
  43. return fmt.Sprintf("SPI(FWD: 0x%x, REV: 0x%x)", uint32(s.forward), uint32(s.reverse))
  44. }
  45. type encrMap struct {
  46. nodes map[string][]*spi
  47. sync.Mutex
  48. }
  49. func (e *encrMap) String() string {
  50. e.Lock()
  51. defer e.Unlock()
  52. b := new(bytes.Buffer)
  53. for k, v := range e.nodes {
  54. b.WriteString("\n")
  55. b.WriteString(k)
  56. b.WriteString(":")
  57. b.WriteString("[")
  58. for _, s := range v {
  59. b.WriteString(s.String())
  60. b.WriteString(",")
  61. }
  62. b.WriteString("]")
  63. }
  64. return b.String()
  65. }
  66. func (d *driver) checkEncryption(nid string, rIP net.IP, vxlanID uint32, isLocal, add bool) error {
  67. logrus.Debugf("checkEncryption(%s, %v, %d, %t)", nid[0:7], rIP, vxlanID, isLocal)
  68. n := d.network(nid)
  69. if n == nil || !n.secure {
  70. return nil
  71. }
  72. if len(d.keys) == 0 {
  73. return types.ForbiddenErrorf("encryption key is not present")
  74. }
  75. lIP := net.ParseIP(d.bindAddress)
  76. aIP := net.ParseIP(d.advertiseAddress)
  77. nodes := map[string]net.IP{}
  78. switch {
  79. case isLocal:
  80. if err := d.peerDbNetworkWalk(nid, func(pKey *peerKey, pEntry *peerEntry) bool {
  81. if !aIP.Equal(pEntry.vtep) {
  82. nodes[pEntry.vtep.String()] = pEntry.vtep
  83. }
  84. return false
  85. }); err != nil {
  86. logrus.Warnf("Failed to retrieve list of participating nodes in overlay network %s: %v", nid[0:5], err)
  87. }
  88. default:
  89. if len(d.network(nid).endpoints) > 0 {
  90. nodes[rIP.String()] = rIP
  91. }
  92. }
  93. logrus.Debugf("List of nodes: %s", nodes)
  94. if add {
  95. for _, rIP := range nodes {
  96. if err := setupEncryption(lIP, aIP, rIP, vxlanID, d.secMap, d.keys); err != nil {
  97. logrus.Warnf("Failed to program network encryption between %s and %s: %v", lIP, rIP, err)
  98. }
  99. }
  100. } else {
  101. if len(nodes) == 0 {
  102. if err := removeEncryption(lIP, rIP, d.secMap); err != nil {
  103. logrus.Warnf("Failed to remove network encryption between %s and %s: %v", lIP, rIP, err)
  104. }
  105. }
  106. }
  107. return nil
  108. }
  109. func setupEncryption(localIP, advIP, remoteIP net.IP, vni uint32, em *encrMap, keys []*key) error {
  110. logrus.Debugf("Programming encryption for vxlan %d between %s and %s", vni, localIP, remoteIP)
  111. rIPs := remoteIP.String()
  112. indices := make([]*spi, 0, len(keys))
  113. err := programMangle(vni, true)
  114. if err != nil {
  115. logrus.Warn(err)
  116. }
  117. for i, k := range keys {
  118. spis := &spi{buildSPI(advIP, remoteIP, k.tag), buildSPI(remoteIP, advIP, k.tag)}
  119. dir := reverse
  120. if i == 0 {
  121. dir = bidir
  122. }
  123. fSA, rSA, err := programSA(localIP, remoteIP, spis, k, dir, true)
  124. if err != nil {
  125. logrus.Warn(err)
  126. }
  127. indices = append(indices, spis)
  128. if i != 0 {
  129. continue
  130. }
  131. err = programSP(fSA, rSA, true)
  132. if err != nil {
  133. logrus.Warn(err)
  134. }
  135. }
  136. em.Lock()
  137. em.nodes[rIPs] = indices
  138. em.Unlock()
  139. return nil
  140. }
  141. func removeEncryption(localIP, remoteIP net.IP, em *encrMap) error {
  142. em.Lock()
  143. indices, ok := em.nodes[remoteIP.String()]
  144. em.Unlock()
  145. if !ok {
  146. return nil
  147. }
  148. for i, idxs := range indices {
  149. dir := reverse
  150. if i == 0 {
  151. dir = bidir
  152. }
  153. fSA, rSA, err := programSA(localIP, remoteIP, idxs, nil, dir, false)
  154. if err != nil {
  155. logrus.Warn(err)
  156. }
  157. if i != 0 {
  158. continue
  159. }
  160. err = programSP(fSA, rSA, false)
  161. if err != nil {
  162. logrus.Warn(err)
  163. }
  164. }
  165. return nil
  166. }
  167. func programMangle(vni uint32, add bool) (err error) {
  168. var (
  169. p = strconv.FormatUint(uint64(vxlanPort), 10)
  170. c = fmt.Sprintf("0>>22&0x3C@12&0xFFFFFF00=%d", int(vni)<<8)
  171. m = strconv.FormatUint(uint64(mark), 10)
  172. chain = "OUTPUT"
  173. rule = []string{"-p", "udp", "--dport", p, "-m", "u32", "--u32", c, "-j", "MARK", "--set-mark", m}
  174. a = "-A"
  175. action = "install"
  176. )
  177. if add == iptables.Exists(iptables.Mangle, chain, rule...) {
  178. return
  179. }
  180. if !add {
  181. a = "-D"
  182. action = "remove"
  183. }
  184. if err = iptables.RawCombinedOutput(append([]string{"-t", string(iptables.Mangle), a, chain}, rule...)...); err != nil {
  185. logrus.Warnf("could not %s mangle rule: %v", action, err)
  186. }
  187. return
  188. }
  189. func programSA(localIP, remoteIP net.IP, spi *spi, k *key, dir int, add bool) (fSA *netlink.XfrmState, rSA *netlink.XfrmState, err error) {
  190. var (
  191. action = "Removing"
  192. xfrmProgram = ns.NlHandle().XfrmStateDel
  193. )
  194. if add {
  195. action = "Adding"
  196. xfrmProgram = ns.NlHandle().XfrmStateAdd
  197. }
  198. if dir&reverse > 0 {
  199. rSA = &netlink.XfrmState{
  200. Src: remoteIP,
  201. Dst: localIP,
  202. Proto: netlink.XFRM_PROTO_ESP,
  203. Spi: spi.reverse,
  204. Mode: netlink.XFRM_MODE_TRANSPORT,
  205. }
  206. if add {
  207. rSA.Aead = buildAeadAlgo(k, spi.reverse)
  208. }
  209. exists, err := saExists(rSA)
  210. if err != nil {
  211. exists = !add
  212. }
  213. if add != exists {
  214. logrus.Debugf("%s: rSA{%s}", action, rSA)
  215. if err := xfrmProgram(rSA); err != nil {
  216. logrus.Warnf("Failed %s rSA{%s}: %v", action, rSA, err)
  217. }
  218. }
  219. }
  220. if dir&forward > 0 {
  221. fSA = &netlink.XfrmState{
  222. Src: localIP,
  223. Dst: remoteIP,
  224. Proto: netlink.XFRM_PROTO_ESP,
  225. Spi: spi.forward,
  226. Mode: netlink.XFRM_MODE_TRANSPORT,
  227. }
  228. if add {
  229. fSA.Aead = buildAeadAlgo(k, spi.forward)
  230. }
  231. exists, err := saExists(fSA)
  232. if err != nil {
  233. exists = !add
  234. }
  235. if add != exists {
  236. logrus.Debugf("%s fSA{%s}", action, fSA)
  237. if err := xfrmProgram(fSA); err != nil {
  238. logrus.Warnf("Failed %s fSA{%s}: %v.", action, fSA, err)
  239. }
  240. }
  241. }
  242. return
  243. }
  244. func programSP(fSA *netlink.XfrmState, rSA *netlink.XfrmState, add bool) error {
  245. action := "Removing"
  246. xfrmProgram := ns.NlHandle().XfrmPolicyDel
  247. if add {
  248. action = "Adding"
  249. xfrmProgram = ns.NlHandle().XfrmPolicyAdd
  250. }
  251. // Create a congruent cidr
  252. s := types.GetMinimalIP(fSA.Src)
  253. d := types.GetMinimalIP(fSA.Dst)
  254. fullMask := net.CIDRMask(8*len(s), 8*len(s))
  255. fPol := &netlink.XfrmPolicy{
  256. Src: &net.IPNet{IP: s, Mask: fullMask},
  257. Dst: &net.IPNet{IP: d, Mask: fullMask},
  258. Dir: netlink.XFRM_DIR_OUT,
  259. Proto: 17,
  260. DstPort: 4789,
  261. Mark: &netlink.XfrmMark{
  262. Value: mark,
  263. },
  264. Tmpls: []netlink.XfrmPolicyTmpl{
  265. {
  266. Src: fSA.Src,
  267. Dst: fSA.Dst,
  268. Proto: netlink.XFRM_PROTO_ESP,
  269. Mode: netlink.XFRM_MODE_TRANSPORT,
  270. Spi: fSA.Spi,
  271. },
  272. },
  273. }
  274. exists, err := spExists(fPol)
  275. if err != nil {
  276. exists = !add
  277. }
  278. if add != exists {
  279. logrus.Debugf("%s fSP{%s}", action, fPol)
  280. if err := xfrmProgram(fPol); err != nil {
  281. logrus.Warnf("%s fSP{%s}: %v", action, fPol, err)
  282. }
  283. }
  284. return nil
  285. }
  286. func saExists(sa *netlink.XfrmState) (bool, error) {
  287. _, err := ns.NlHandle().XfrmStateGet(sa)
  288. switch err {
  289. case nil:
  290. return true, nil
  291. case syscall.ESRCH:
  292. return false, nil
  293. default:
  294. err = fmt.Errorf("Error while checking for SA existence: %v", err)
  295. logrus.Warn(err)
  296. return false, err
  297. }
  298. }
  299. func spExists(sp *netlink.XfrmPolicy) (bool, error) {
  300. _, err := ns.NlHandle().XfrmPolicyGet(sp)
  301. switch err {
  302. case nil:
  303. return true, nil
  304. case syscall.ENOENT:
  305. return false, nil
  306. default:
  307. err = fmt.Errorf("Error while checking for SP existence: %v", err)
  308. logrus.Warn(err)
  309. return false, err
  310. }
  311. }
  312. func buildSPI(src, dst net.IP, st uint32) int {
  313. b := make([]byte, 4)
  314. binary.BigEndian.PutUint32(b, st)
  315. h := fnv.New32a()
  316. h.Write(src)
  317. h.Write(b)
  318. h.Write(dst)
  319. return int(binary.BigEndian.Uint32(h.Sum(nil)))
  320. }
  321. func buildAeadAlgo(k *key, s int) *netlink.XfrmStateAlgo {
  322. salt := make([]byte, 4)
  323. binary.BigEndian.PutUint32(salt, uint32(s))
  324. return &netlink.XfrmStateAlgo{
  325. Name: "rfc4106(gcm(aes))",
  326. Key: append(k.value, salt...),
  327. ICVLen: 64,
  328. }
  329. }
  330. func (d *driver) secMapWalk(f func(string, []*spi) ([]*spi, bool)) error {
  331. d.secMap.Lock()
  332. for node, indices := range d.secMap.nodes {
  333. idxs, stop := f(node, indices)
  334. if idxs != nil {
  335. d.secMap.nodes[node] = idxs
  336. }
  337. if stop {
  338. break
  339. }
  340. }
  341. d.secMap.Unlock()
  342. return nil
  343. }
  344. func (d *driver) setKeys(keys []*key) error {
  345. // Accept the encryption keys and clear any stale encryption map
  346. d.Lock()
  347. d.keys = keys
  348. d.secMap = &encrMap{nodes: map[string][]*spi{}}
  349. d.Unlock()
  350. logrus.Debugf("Initial encryption keys: %v", d.keys)
  351. return nil
  352. }
  353. // updateKeys allows to add a new key and/or change the primary key and/or prune an existing key
  354. // The primary key is the key used in transmission and will go in first position in the list.
  355. func (d *driver) updateKeys(newKey, primary, pruneKey *key) error {
  356. logrus.Debugf("Updating Keys. New: %v, Primary: %v, Pruned: %v", newKey, primary, pruneKey)
  357. logrus.Debugf("Current: %v", d.keys)
  358. var (
  359. newIdx = -1
  360. priIdx = -1
  361. delIdx = -1
  362. lIP = net.ParseIP(d.bindAddress)
  363. )
  364. d.Lock()
  365. // add new
  366. if newKey != nil {
  367. d.keys = append(d.keys, newKey)
  368. newIdx += len(d.keys)
  369. }
  370. for i, k := range d.keys {
  371. if primary != nil && k.tag == primary.tag {
  372. priIdx = i
  373. }
  374. if pruneKey != nil && k.tag == pruneKey.tag {
  375. delIdx = i
  376. }
  377. }
  378. d.Unlock()
  379. if (newKey != nil && newIdx == -1) ||
  380. (primary != nil && priIdx == -1) ||
  381. (pruneKey != nil && delIdx == -1) {
  382. return types.BadRequestErrorf("cannot find proper key indices while processing key update:"+
  383. "(newIdx,priIdx,delIdx):(%d, %d, %d)", newIdx, priIdx, delIdx)
  384. }
  385. d.secMapWalk(func(rIPs string, spis []*spi) ([]*spi, bool) {
  386. rIP := net.ParseIP(rIPs)
  387. return updateNodeKey(lIP, rIP, spis, d.keys, newIdx, priIdx, delIdx), false
  388. })
  389. d.Lock()
  390. // swap primary
  391. if priIdx != -1 {
  392. swp := d.keys[0]
  393. d.keys[0] = d.keys[priIdx]
  394. d.keys[priIdx] = swp
  395. }
  396. // prune
  397. if delIdx != -1 {
  398. if delIdx == 0 {
  399. delIdx = priIdx
  400. }
  401. d.keys = append(d.keys[:delIdx], d.keys[delIdx+1:]...)
  402. }
  403. d.Unlock()
  404. logrus.Debugf("Updated: %v", d.keys)
  405. return nil
  406. }
  407. /********************************************************
  408. * Steady state: rSA0, rSA1, rSA2, fSA1, fSP1
  409. * Rotation --> -rSA0, +rSA3, +fSA2, +fSP2/-fSP1, -fSA1
  410. * Steady state: rSA1, rSA2, rSA3, fSA2, fSP2
  411. *********************************************************/
  412. // Spis and keys are sorted in such away the one in position 0 is the primary
  413. func updateNodeKey(lIP, rIP net.IP, idxs []*spi, curKeys []*key, newIdx, priIdx, delIdx int) []*spi {
  414. logrus.Debugf("Updating keys for node: %s (%d,%d,%d)", rIP, newIdx, priIdx, delIdx)
  415. spis := idxs
  416. logrus.Debugf("Current: %v", spis)
  417. // add new
  418. if newIdx != -1 {
  419. spis = append(spis, &spi{
  420. forward: buildSPI(lIP, rIP, curKeys[newIdx].tag),
  421. reverse: buildSPI(rIP, lIP, curKeys[newIdx].tag),
  422. })
  423. }
  424. if delIdx != -1 {
  425. // -rSA0
  426. programSA(lIP, rIP, spis[delIdx], nil, reverse, false)
  427. }
  428. if newIdx > -1 {
  429. // +rSA2
  430. programSA(lIP, rIP, spis[newIdx], curKeys[newIdx], reverse, true)
  431. }
  432. if priIdx > 0 {
  433. // +fSA2
  434. fSA2, _, _ := programSA(lIP, rIP, spis[priIdx], curKeys[priIdx], forward, true)
  435. // +fSP2, -fSP1
  436. s := types.GetMinimalIP(fSA2.Src)
  437. d := types.GetMinimalIP(fSA2.Dst)
  438. fullMask := net.CIDRMask(8*len(s), 8*len(s))
  439. fSP1 := &netlink.XfrmPolicy{
  440. Src: &net.IPNet{IP: s, Mask: fullMask},
  441. Dst: &net.IPNet{IP: d, Mask: fullMask},
  442. Dir: netlink.XFRM_DIR_OUT,
  443. Proto: 17,
  444. DstPort: 4789,
  445. Mark: &netlink.XfrmMark{
  446. Value: mark,
  447. },
  448. Tmpls: []netlink.XfrmPolicyTmpl{
  449. {
  450. Src: fSA2.Src,
  451. Dst: fSA2.Dst,
  452. Proto: netlink.XFRM_PROTO_ESP,
  453. Mode: netlink.XFRM_MODE_TRANSPORT,
  454. Spi: fSA2.Spi,
  455. },
  456. },
  457. }
  458. logrus.Debugf("Updating fSP{%s}", fSP1)
  459. if err := ns.NlHandle().XfrmPolicyUpdate(fSP1); err != nil {
  460. logrus.Warnf("Failed to update fSP{%s}: %v", fSP1, err)
  461. }
  462. // -fSA1
  463. programSA(lIP, rIP, spis[0], nil, forward, false)
  464. }
  465. // swap
  466. if priIdx > 0 {
  467. swp := spis[0]
  468. spis[0] = spis[priIdx]
  469. spis[priIdx] = swp
  470. }
  471. // prune
  472. if delIdx != -1 {
  473. if delIdx == 0 {
  474. delIdx = priIdx
  475. }
  476. spis = append(spis[:delIdx], spis[delIdx+1:]...)
  477. }
  478. logrus.Debugf("Updated: %v", spis)
  479. return spis
  480. }
  481. func (n *network) maxMTU() int {
  482. mtu := 1500
  483. if n.mtu != 0 {
  484. mtu = n.mtu
  485. }
  486. mtu -= vxlanEncap
  487. if n.secure {
  488. // In case of encryption account for the
  489. // esp packet espansion and padding
  490. mtu -= pktExpansion
  491. mtu -= (mtu % 4)
  492. }
  493. return mtu
  494. }