dispatcher.go 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438
  1. package dispatcher
  2. import (
  3. "fmt"
  4. "net"
  5. "strconv"
  6. "sync"
  7. "time"
  8. "google.golang.org/grpc"
  9. "google.golang.org/grpc/codes"
  10. "google.golang.org/grpc/transport"
  11. "github.com/Sirupsen/logrus"
  12. "github.com/docker/go-events"
  13. "github.com/docker/swarmkit/api"
  14. "github.com/docker/swarmkit/api/equality"
  15. "github.com/docker/swarmkit/ca"
  16. "github.com/docker/swarmkit/log"
  17. "github.com/docker/swarmkit/manager/state"
  18. "github.com/docker/swarmkit/manager/state/store"
  19. "github.com/docker/swarmkit/remotes"
  20. "github.com/docker/swarmkit/watch"
  21. gogotypes "github.com/gogo/protobuf/types"
  22. "github.com/pkg/errors"
  23. "golang.org/x/net/context"
  24. )
const (
	// DefaultHeartBeatPeriod is used for setting default value in cluster config
	// and in case if cluster config is missing.
	DefaultHeartBeatPeriod = 5 * time.Second
	// defaultHeartBeatEpsilon is the default slack added on top of the
	// heartbeat period before a node is considered late.
	defaultHeartBeatEpsilon = 500 * time.Millisecond
	// defaultGracePeriodMultiplier scales the heartbeat period to produce
	// the grace period after which an unresponsive node is expired.
	defaultGracePeriodMultiplier = 3
	// defaultRateLimitPeriod bounds how often a node with the same ID may
	// attempt to register a new session.
	defaultRateLimitPeriod = 8 * time.Second

	// maxBatchItems is the threshold of queued writes that should
	// trigger an actual transaction to commit them to the shared store.
	maxBatchItems = 10000

	// maxBatchInterval needs to strike a balance between keeping
	// latency low, and realizing opportunities to combine many writes
	// into a single transaction. A fraction of a second feels about
	// right.
	maxBatchInterval = 100 * time.Millisecond

	// modificationBatchLimit caps how many task events are coalesced into
	// a single snapshot sent on a Tasks stream.
	modificationBatchLimit = 100
	// batchingWaitTime is how long a Tasks stream waits for further events
	// before flushing the current batch to the agent.
	batchingWaitTime = 100 * time.Millisecond

	// defaultNodeDownPeriod specifies the default time period we
	// wait before moving tasks assigned to down nodes to ORPHANED
	// state.
	defaultNodeDownPeriod = 24 * time.Hour
)
var (
	// ErrNodeAlreadyRegistered returned if node with same ID was already
	// registered with this dispatcher.
	ErrNodeAlreadyRegistered = errors.New("node already registered")
	// ErrNodeNotRegistered returned if node with such ID wasn't registered
	// with this dispatcher.
	ErrNodeNotRegistered = errors.New("node not registered")
	// ErrSessionInvalid returned when the session in use is no longer valid.
	// The node should re-register and start a new session.
	ErrSessionInvalid = errors.New("session invalid")
	// ErrNodeNotFound returned when the Node doesn't exist in raft.
	ErrNodeNotFound = errors.New("node not found")
)
// Config is configuration for Dispatcher. For default you should use
// DefaultConfig.
type Config struct {
	// HeartbeatPeriod is the interval at which agents are expected to
	// send heartbeats to keep their session alive.
	HeartbeatPeriod time.Duration
	// HeartbeatEpsilon is extra slack allowed on top of HeartbeatPeriod
	// before a heartbeat is considered missed.
	HeartbeatEpsilon time.Duration
	// RateLimitPeriod specifies how often node with same ID can try to register
	// new session.
	RateLimitPeriod time.Duration
	// GracePeriodMultiplier scales HeartbeatPeriod to derive the period
	// after which an unresponsive node is expired.
	GracePeriodMultiplier int
}
  70. // DefaultConfig returns default config for Dispatcher.
  71. func DefaultConfig() *Config {
  72. return &Config{
  73. HeartbeatPeriod: DefaultHeartBeatPeriod,
  74. HeartbeatEpsilon: defaultHeartBeatEpsilon,
  75. RateLimitPeriod: defaultRateLimitPeriod,
  76. GracePeriodMultiplier: defaultGracePeriodMultiplier,
  77. }
  78. }
// Cluster is interface which represent raft cluster. manager/state/raft.Node
// is implements it. This interface needed only for easier unit-testing.
type Cluster interface {
	// GetMemberlist returns the current raft membership keyed by raft ID.
	GetMemberlist() map[uint64]*api.RaftMember
	// SubscribePeers returns a channel of peer-change events and a
	// function to cancel the subscription.
	SubscribePeers() (chan events.Event, func())
	// MemoryStore returns the shared in-memory object store.
	MemoryStore() *store.MemoryStore
}
// nodeUpdate provides a new status and/or description to apply to a node
// object.
type nodeUpdate struct {
	// status, when non-nil, replaces the node's State/Message (and Addr
	// when set) on the next batch commit.
	status *api.NodeStatus
	// description, when non-nil, replaces the node's description.
	description *api.NodeDescription
}
// Dispatcher is responsible for dispatching tasks and tracking agent health.
type Dispatcher struct {
	mu sync.Mutex
	// wg tracks in-flight handlers that touch raft; Stop waits on it.
	wg    sync.WaitGroup
	nodes *nodeStore
	store *store.MemoryStore
	// mgrQueue broadcasts manager membership changes to sessions.
	mgrQueue             *watch.Queue
	lastSeenManagers     []*api.WeightedPeer
	networkBootstrapKeys []*api.EncryptionKey
	// keyMgrQueue broadcasts network bootstrap key rotations.
	keyMgrQueue *watch.Queue
	config      *Config
	cluster     Cluster
	ctx         context.Context
	cancel      context.CancelFunc

	taskUpdates     map[string]*api.TaskStatus // indexed by task ID
	taskUpdatesLock sync.Mutex

	nodeUpdates     map[string]nodeUpdate // indexed by node ID
	nodeUpdatesLock sync.Mutex

	// downNodes tracks nodes known to be down so their tasks can be
	// orphaned after defaultNodeDownPeriod.
	downNodes *nodeStore

	// processUpdatesTrigger asks the Run loop to flush queued updates early.
	processUpdatesTrigger chan struct{}

	// for waiting for the next task/node batch update
	processUpdatesLock sync.Mutex
	processUpdatesCond *sync.Cond
}
  116. // New returns Dispatcher with cluster interface(usually raft.Node).
  117. // NOTE: each handler which does something with raft must add to Dispatcher.wg
  118. func New(cluster Cluster, c *Config) *Dispatcher {
  119. d := &Dispatcher{
  120. nodes: newNodeStore(c.HeartbeatPeriod, c.HeartbeatEpsilon, c.GracePeriodMultiplier, c.RateLimitPeriod),
  121. downNodes: newNodeStore(defaultNodeDownPeriod, 0, 1, 0),
  122. store: cluster.MemoryStore(),
  123. cluster: cluster,
  124. taskUpdates: make(map[string]*api.TaskStatus),
  125. nodeUpdates: make(map[string]nodeUpdate),
  126. processUpdatesTrigger: make(chan struct{}, 1),
  127. config: c,
  128. }
  129. d.processUpdatesCond = sync.NewCond(&d.processUpdatesLock)
  130. return d
  131. }
  132. func getWeightedPeers(cluster Cluster) []*api.WeightedPeer {
  133. members := cluster.GetMemberlist()
  134. var mgrs []*api.WeightedPeer
  135. for _, m := range members {
  136. mgrs = append(mgrs, &api.WeightedPeer{
  137. Peer: &api.Peer{
  138. NodeID: m.NodeID,
  139. Addr: m.Addr,
  140. },
  141. // TODO(stevvooe): Calculate weight of manager selection based on
  142. // cluster-level observations, such as number of connections and
  143. // load.
  144. Weight: remotes.DefaultObservationWeight,
  145. })
  146. }
  147. return mgrs
  148. }
// Run runs dispatcher tasks which should be run on leader dispatcher.
// Dispatcher can be stopped with cancelling ctx or calling Stop().
func (d *Dispatcher) Run(ctx context.Context) error {
	d.mu.Lock()
	if d.isRunning() {
		d.mu.Unlock()
		return errors.New("dispatcher is already running")
	}
	ctx = log.WithModule(ctx, "dispatcher")
	// Best-effort: a failure to reset node state is logged but does not
	// prevent the dispatcher from starting.
	if err := d.markNodesUnknown(ctx); err != nil {
		log.G(ctx).Errorf(`failed to move all nodes to "unknown" state: %v`, err)
	}
	// Read the current cluster config and subscribe to subsequent cluster
	// updates in one consistent step.
	configWatcher, cancel, err := store.ViewAndWatch(
		d.store,
		func(readTx store.ReadTx) error {
			clusters, err := store.FindClusters(readTx, store.ByName(store.DefaultClusterName))
			if err != nil {
				return err
			}
			// NOTE(review): err is always nil here (checked above), so the
			// `err == nil` half of this condition looks redundant — confirm.
			if err == nil && len(clusters) == 1 {
				heartbeatPeriod, err := gogotypes.DurationFromProto(clusters[0].Spec.Dispatcher.HeartbeatPeriod)
				if err == nil && heartbeatPeriod > 0 {
					d.config.HeartbeatPeriod = heartbeatPeriod
				}
				if clusters[0].NetworkBootstrapKeys != nil {
					d.networkBootstrapKeys = clusters[0].NetworkBootstrapKeys
				}
			}
			return nil
		},
		state.EventUpdateCluster{},
	)
	if err != nil {
		d.mu.Unlock()
		return err
	}
	// set queues here to guarantee that Close will close them
	d.mgrQueue = watch.NewQueue()
	d.keyMgrQueue = watch.NewQueue()

	peerWatcher, peerCancel := d.cluster.SubscribePeers()
	defer peerCancel()
	d.lastSeenManagers = getWeightedPeers(d.cluster)

	defer cancel()
	// Derive the dispatcher-lifetime context; isRunning() and handlers key
	// off d.ctx being set and not yet canceled.
	d.ctx, d.cancel = context.WithCancel(ctx)
	ctx = d.ctx
	d.wg.Add(1)
	defer d.wg.Done()
	d.mu.Unlock()

	// publishManagers records the latest manager set and fans it out to
	// all open sessions via mgrQueue.
	publishManagers := func(peers []*api.Peer) {
		var mgrs []*api.WeightedPeer
		for _, p := range peers {
			mgrs = append(mgrs, &api.WeightedPeer{
				Peer:   p,
				Weight: remotes.DefaultObservationWeight,
			})
		}
		d.mu.Lock()
		d.lastSeenManagers = mgrs
		d.mu.Unlock()
		// Publish outside the lock; queue handles its own synchronization.
		d.mgrQueue.Publish(mgrs)
	}

	batchTimer := time.NewTimer(maxBatchInterval)
	defer batchTimer.Stop()
	for {
		select {
		case ev := <-peerWatcher:
			publishManagers(ev.([]*api.Peer))
		case <-d.processUpdatesTrigger:
			// Early flush requested because the queued-update count hit
			// maxBatchItems; restart the interval timer afterwards.
			d.processUpdates(ctx)
			batchTimer.Reset(maxBatchInterval)
		case <-batchTimer.C:
			d.processUpdates(ctx)
			batchTimer.Reset(maxBatchInterval)
		case v := <-configWatcher:
			cluster := v.(state.EventUpdateCluster)
			d.mu.Lock()
			if cluster.Cluster.Spec.Dispatcher.HeartbeatPeriod != nil {
				// ignore error, since Spec has passed validation before
				heartbeatPeriod, _ := gogotypes.DurationFromProto(cluster.Cluster.Spec.Dispatcher.HeartbeatPeriod)
				if heartbeatPeriod != d.config.HeartbeatPeriod {
					// only call d.nodes.updatePeriod when heartbeatPeriod changes
					d.config.HeartbeatPeriod = heartbeatPeriod
					d.nodes.updatePeriod(d.config.HeartbeatPeriod, d.config.HeartbeatEpsilon, d.config.GracePeriodMultiplier)
				}
			}
			d.networkBootstrapKeys = cluster.Cluster.NetworkBootstrapKeys
			d.mu.Unlock()
			d.keyMgrQueue.Publish(cluster.Cluster.NetworkBootstrapKeys)
		case <-ctx.Done():
			return nil
		}
	}
}
// Stop stops dispatcher and closes all grpc streams.
func (d *Dispatcher) Stop() error {
	d.mu.Lock()
	if !d.isRunning() {
		d.mu.Unlock()
		return errors.New("dispatcher is already stopped")
	}
	// Cancel the dispatcher context first so handlers observe shutdown.
	d.cancel()
	d.mu.Unlock()
	d.nodes.Clean()

	d.processUpdatesLock.Lock()
	// In case there are any waiters. There is no chance of any starting
	// after this point, because they check if the context is canceled
	// before waiting.
	d.processUpdatesCond.Broadcast()
	d.processUpdatesLock.Unlock()

	// Closing the queues terminates any subscribers' streams.
	d.mgrQueue.Close()
	d.keyMgrQueue.Close()

	// Wait for all in-flight handlers (tracked via d.wg) to drain.
	d.wg.Wait()

	return nil
}
  263. func (d *Dispatcher) isRunningLocked() (context.Context, error) {
  264. d.mu.Lock()
  265. if !d.isRunning() {
  266. d.mu.Unlock()
  267. return nil, grpc.Errorf(codes.Aborted, "dispatcher is stopped")
  268. }
  269. ctx := d.ctx
  270. d.mu.Unlock()
  271. return ctx, nil
  272. }
// markNodesUnknown moves all non-DOWN nodes to UNKNOWN state after a
// leadership change, registering heartbeat expirations that demote them to
// DOWN if they fail to re-register in time. DOWN nodes are instead tracked
// in downNodes so their tasks can eventually be orphaned.
func (d *Dispatcher) markNodesUnknown(ctx context.Context) error {
	log := log.G(ctx).WithField("method", "(*Dispatcher).markNodesUnknown")
	var nodes []*api.Node
	var err error
	d.store.View(func(tx store.ReadTx) {
		nodes, err = store.FindNodes(tx, store.All)
	})
	if err != nil {
		return errors.Wrap(err, "failed to get list of nodes")
	}
	_, err = d.store.Batch(func(batch *store.Batch) error {
		for _, n := range nodes {
			err := batch.Update(func(tx store.Tx) error {
				// check if node is still here
				node := store.GetNode(tx, n.ID)
				if node == nil {
					return nil
				}
				// do not try to resurrect down nodes
				if node.Status.State == api.NodeStatus_DOWN {
					// NOTE(review): node is a pointer, so this copies the
					// pointer, not the struct — confirm that is intended.
					nodeCopy := node
					expireFunc := func() {
						// When the down-node entry expires, orphan its
						// tasks and drop it from downNodes.
						if err := d.moveTasksToOrphaned(nodeCopy.ID); err != nil {
							log.WithError(err).Error(`failed to move all tasks to "ORPHANED" state`)
						}
						d.downNodes.Delete(nodeCopy.ID)
					}
					d.downNodes.Add(nodeCopy, expireFunc)
					return nil
				}

				node.Status.State = api.NodeStatus_UNKNOWN
				node.Status.Message = `Node moved to "unknown" state due to leadership change in cluster`

				nodeID := node.ID

				// If the node does not heartbeat before expiration, mark
				// it DOWN for real.
				expireFunc := func() {
					log := log.WithField("node", nodeID)
					log.Debugf("heartbeat expiration for unknown node")
					if err := d.markNodeNotReady(nodeID, api.NodeStatus_DOWN, `heartbeat failure for node in "unknown" state`); err != nil {
						log.WithError(err).Errorf(`failed deregistering node after heartbeat expiration for node in "unknown" state`)
					}
				}
				if err := d.nodes.AddUnknown(node, expireFunc); err != nil {
					return errors.Wrap(err, `adding node in "unknown" state to node store failed`)
				}
				if err := store.UpdateNode(tx, node); err != nil {
					return errors.Wrap(err, "update failed")
				}
				return nil
			})
			if err != nil {
				// Per-node failures are logged; the batch continues with
				// the remaining nodes.
				log.WithField("node", n.ID).WithError(err).Errorf(`failed to move node to "unknown" state`)
			}
		}
		return nil
	})
	return err
}
  329. func (d *Dispatcher) isRunning() bool {
  330. if d.ctx == nil {
  331. return false
  332. }
  333. select {
  334. case <-d.ctx.Done():
  335. return false
  336. default:
  337. }
  338. return true
  339. }
// markNodeReady updates the description of a node, updates its address, and sets status to READY
// this is used during registration when a new node description is provided
// and during node updates when the node description changes
func (d *Dispatcher) markNodeReady(ctx context.Context, nodeID string, description *api.NodeDescription, addr string) error {
	d.nodeUpdatesLock.Lock()
	// Queue the READY status/description; it is committed by the next
	// processUpdates batch, not immediately.
	d.nodeUpdates[nodeID] = nodeUpdate{
		status: &api.NodeStatus{
			State: api.NodeStatus_READY,
			Addr:  addr,
		},
		description: description,
	}
	numUpdates := len(d.nodeUpdates)
	d.nodeUpdatesLock.Unlock()

	// Node is marked ready. Remove the node from down nodes if it
	// is there.
	d.downNodes.Delete(nodeID)

	if numUpdates >= maxBatchItems {
		// Queue is full enough to warrant an early flush; signal the Run
		// loop (channel has capacity 1, so a pending signal suffices).
		select {
		case d.processUpdatesTrigger <- struct{}{}:
		case <-ctx.Done():
			return ctx.Err()
		}
	}

	// Wait until the node update batch happens before unblocking register.
	d.processUpdatesLock.Lock()
	// Check cancellation while holding the lock: Stop broadcasts under the
	// same lock, so a waiter cannot start after shutdown unnoticed.
	select {
	case <-ctx.Done():
		return ctx.Err()
	default:
	}
	// Wait releases the lock while blocked and reacquires it on wakeup
	// (broadcast from processUpdates or Stop).
	d.processUpdatesCond.Wait()
	d.processUpdatesLock.Unlock()

	return nil
}
  375. // gets the node IP from the context of a grpc call
  376. func nodeIPFromContext(ctx context.Context) (string, error) {
  377. nodeInfo, err := ca.RemoteNode(ctx)
  378. if err != nil {
  379. return "", err
  380. }
  381. addr, _, err := net.SplitHostPort(nodeInfo.RemoteAddr)
  382. if err != nil {
  383. return "", errors.Wrap(err, "unable to get ip from addr:port")
  384. }
  385. return addr, nil
  386. }
// register is used for registration of node with particular dispatcher.
func (d *Dispatcher) register(ctx context.Context, nodeID string, description *api.NodeDescription) (string, error) {
	// prevent register until we're ready to accept it
	dctx, err := d.isRunningLocked()
	if err != nil {
		return "", err
	}

	// Reject nodes that re-register too frequently.
	if err := d.nodes.CheckRateLimit(nodeID); err != nil {
		return "", err
	}

	// TODO(stevvooe): Validate node specification.
	var node *api.Node
	d.store.View(func(tx store.ReadTx) {
		node = store.GetNode(tx, nodeID)
	})
	if node == nil {
		return "", ErrNodeNotFound
	}

	// Failure to resolve the caller's IP is tolerated: addr stays empty
	// and only the debug log records the cause.
	addr, err := nodeIPFromContext(ctx)
	if err != nil {
		log.G(ctx).Debugf(err.Error())
	}

	// Blocks until the READY status has been committed by a batch.
	if err := d.markNodeReady(dctx, nodeID, description, addr); err != nil {
		return "", err
	}

	// expireFunc demotes the node to DOWN when its heartbeats stop.
	expireFunc := func() {
		log.G(ctx).Debugf("heartbeat expiration")
		if err := d.markNodeNotReady(nodeID, api.NodeStatus_DOWN, "heartbeat failure"); err != nil {
			log.G(ctx).WithError(err).Errorf("failed deregistering node after heartbeat expiration")
		}
	}

	rn := d.nodes.Add(node, expireFunc)

	// NOTE(stevvooe): We need be a little careful with re-registration. The
	// current implementation just matches the node id and then gives away the
	// sessionID. If we ever want to use sessionID as a secret, which we may
	// want to, this is giving away the keys to the kitchen.
	//
	// The right behavior is going to be informed by identity. Basically, each
	// time a node registers, we invalidate the session and issue a new
	// session, once identity is proven. This will cause misbehaved agents to
	// be kicked when multiple connections are made.
	return rn.SessionID, nil
}
// UpdateTaskStatus updates status of task. Node should send such updates
// on every status change of its tasks.
func (d *Dispatcher) UpdateTaskStatus(ctx context.Context, r *api.UpdateTaskStatusRequest) (*api.UpdateTaskStatusResponse, error) {
	nodeInfo, err := ca.RemoteNode(ctx)
	if err != nil {
		return nil, err
	}
	nodeID := nodeInfo.NodeID
	fields := logrus.Fields{
		"node.id":      nodeID,
		"node.session": r.SessionID,
		"method":       "(*Dispatcher).UpdateTaskStatus",
	}
	if nodeInfo.ForwardedBy != nil {
		fields["forwarder.id"] = nodeInfo.ForwardedBy.NodeID
	}
	log := log.G(ctx).WithFields(fields)

	dctx, err := d.isRunningLocked()
	if err != nil {
		return nil, err
	}

	// The caller must hold a valid session for this node.
	if _, err := d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
		return nil, err
	}

	// Validate task updates
	for _, u := range r.Updates {
		if u.Status == nil {
			log.WithField("task.id", u.TaskID).Warn("task report has nil status")
			continue
		}

		var t *api.Task
		d.store.View(func(tx store.ReadTx) {
			t = store.GetTask(tx, u.TaskID)
		})
		if t == nil {
			log.WithField("task.id", u.TaskID).Warn("cannot find target task in store")
			continue
		}

		// A node may only report on tasks assigned to it.
		if t.NodeID != nodeID {
			err := grpc.Errorf(codes.PermissionDenied, "cannot update a task not assigned this node")
			log.WithField("task.id", u.TaskID).Error(err)
			return nil, err
		}
	}

	d.taskUpdatesLock.Lock()
	// Enqueue task updates
	for _, u := range r.Updates {
		if u.Status == nil {
			continue
		}
		// Later updates for the same task overwrite earlier queued ones.
		d.taskUpdates[u.TaskID] = u.Status
	}

	numUpdates := len(d.taskUpdates)
	d.taskUpdatesLock.Unlock()

	if numUpdates >= maxBatchItems {
		// Ask the Run loop to flush early; if the dispatcher is shutting
		// down, the enqueued updates are deliberately left as-is.
		select {
		case d.processUpdatesTrigger <- struct{}{}:
		case <-dctx.Done():
		}
	}
	return nil, nil
}
// processUpdates drains the queued task and node updates and commits them
// to the store in a single batch, then wakes any goroutines (e.g.
// markNodeReady callers) waiting on the batch via processUpdatesCond.
func (d *Dispatcher) processUpdates(ctx context.Context) {
	var (
		taskUpdates map[string]*api.TaskStatus
		nodeUpdates map[string]nodeUpdate
	)

	// Swap out the queues under their locks so producers can keep
	// enqueueing while this batch commits.
	d.taskUpdatesLock.Lock()
	if len(d.taskUpdates) != 0 {
		taskUpdates = d.taskUpdates
		d.taskUpdates = make(map[string]*api.TaskStatus)
	}
	d.taskUpdatesLock.Unlock()

	d.nodeUpdatesLock.Lock()
	if len(d.nodeUpdates) != 0 {
		nodeUpdates = d.nodeUpdates
		d.nodeUpdates = make(map[string]nodeUpdate)
	}
	d.nodeUpdatesLock.Unlock()

	if len(taskUpdates) == 0 && len(nodeUpdates) == 0 {
		return
	}

	log := log.G(ctx).WithFields(logrus.Fields{
		"method": "(*Dispatcher).processUpdates",
	})

	_, err := d.store.Batch(func(batch *store.Batch) error {
		for taskID, status := range taskUpdates {
			err := batch.Update(func(tx store.Tx) error {
				logger := log.WithField("task.id", taskID)
				task := store.GetTask(tx, taskID)
				if task == nil {
					// Task deleted since the update was queued; drop it.
					logger.Errorf("task unavailable")
					return nil
				}

				logger = logger.WithField("state.transition", fmt.Sprintf("%v->%v", task.Status.State, status.State))

				if task.Status == *status {
					logger.Debug("task status identical, ignoring")
					return nil
				}

				// Task state may only move forward; reject regressions.
				if task.Status.State > status.State {
					logger.Debug("task status invalid transition")
					return nil
				}

				task.Status = *status
				if err := store.UpdateTask(tx, task); err != nil {
					logger.WithError(err).Error("failed to update task status")
					return nil
				}
				logger.Debug("task status updated")
				return nil
			})
			if err != nil {
				log.WithError(err).Error("dispatcher task update transaction failed")
			}
		}

		for nodeID, nodeUpdate := range nodeUpdates {
			err := batch.Update(func(tx store.Tx) error {
				logger := log.WithField("node.id", nodeID)
				node := store.GetNode(tx, nodeID)
				if node == nil {
					logger.Errorf("node unavailable")
					return nil
				}

				if nodeUpdate.status != nil {
					node.Status.State = nodeUpdate.status.State
					node.Status.Message = nodeUpdate.status.Message
					// Keep the previously-known address when the update
					// does not carry one.
					if nodeUpdate.status.Addr != "" {
						node.Status.Addr = nodeUpdate.status.Addr
					}
				}
				if nodeUpdate.description != nil {
					node.Description = nodeUpdate.description
				}

				if err := store.UpdateNode(tx, node); err != nil {
					logger.WithError(err).Error("failed to update node status")
					return nil
				}
				logger.Debug("node status updated")
				return nil
			})
			if err != nil {
				log.WithError(err).Error("dispatcher node update transaction failed")
			}
		}
		return nil
	})
	if err != nil {
		log.WithError(err).Error("dispatcher batch failed")
	}

	// Wake everyone blocked in markNodeReady waiting for this batch.
	d.processUpdatesCond.Broadcast()
}
// Tasks is a stream of tasks state for node. Each message contains full list
// of tasks which should be run on node, if task is not present in that list,
// it should be terminated.
func (d *Dispatcher) Tasks(r *api.TasksRequest, stream api.Dispatcher_TasksServer) error {
	nodeInfo, err := ca.RemoteNode(stream.Context())
	if err != nil {
		return err
	}
	nodeID := nodeInfo.NodeID

	dctx, err := d.isRunningLocked()
	if err != nil {
		return err
	}

	fields := logrus.Fields{
		"node.id":      nodeID,
		"node.session": r.SessionID,
		"method":       "(*Dispatcher).Tasks",
	}
	if nodeInfo.ForwardedBy != nil {
		fields["forwarder.id"] = nodeInfo.ForwardedBy.NodeID
	}
	log.G(stream.Context()).WithFields(fields).Debugf("")

	if _, err = d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
		return err
	}

	tasksMap := make(map[string]*api.Task)
	// Load the node's current tasks and subscribe to create/update/delete
	// events for tasks on this node in one consistent step.
	nodeTasks, cancel, err := store.ViewAndWatch(
		d.store,
		func(readTx store.ReadTx) error {
			tasks, err := store.FindTasks(readTx, store.ByNodeID(nodeID))
			if err != nil {
				return err
			}
			for _, t := range tasks {
				tasksMap[t.ID] = t
			}
			return nil
		},
		state.EventCreateTask{Task: &api.Task{NodeID: nodeID},
			Checks: []state.TaskCheckFunc{state.TaskCheckNodeID}},
		state.EventUpdateTask{Task: &api.Task{NodeID: nodeID},
			Checks: []state.TaskCheckFunc{state.TaskCheckNodeID}},
		state.EventDeleteTask{Task: &api.Task{NodeID: nodeID},
			Checks: []state.TaskCheckFunc{state.TaskCheckNodeID}},
	)
	if err != nil {
		return err
	}
	defer cancel()

	for {
		// Re-validate the session on every iteration; an invalidated
		// session terminates the stream.
		if _, err := d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
			return err
		}

		var tasks []*api.Task
		for _, t := range tasksMap {
			// dispatcher only sends tasks that have been assigned to a node
			if t != nil && t.Status.State >= api.TaskStateAssigned {
				tasks = append(tasks, t)
			}
		}

		// Each message is a full snapshot of the node's task list.
		if err := stream.Send(&api.TasksMessage{Tasks: tasks}); err != nil {
			return err
		}

		// bursty events should be processed in batches and sent out snapshot
		var (
			modificationCnt int
			batchingTimer   *time.Timer
			batchingTimeout <-chan time.Time
		)

	batchingLoop:
		for modificationCnt < modificationBatchLimit {
			select {
			case event := <-nodeTasks:
				switch v := event.(type) {
				case state.EventCreateTask:
					tasksMap[v.Task.ID] = v.Task
					modificationCnt++
				case state.EventUpdateTask:
					if oldTask, exists := tasksMap[v.Task.ID]; exists {
						// States ASSIGNED and below are set by the orchestrator/scheduler,
						// not the agent, so tasks in these states need to be sent to the
						// agent even if nothing else has changed.
						if equality.TasksEqualStable(oldTask, v.Task) && v.Task.Status.State > api.TaskStateAssigned {
							// this update should not trigger action at agent
							tasksMap[v.Task.ID] = v.Task
							continue
						}
					}
					tasksMap[v.Task.ID] = v.Task
					modificationCnt++
				case state.EventDeleteTask:
					delete(tasksMap, v.Task.ID)
					modificationCnt++
				}
				// Start (or extend) the quiet-period timer: a snapshot is
				// flushed after batchingWaitTime without further events.
				if batchingTimer != nil {
					batchingTimer.Reset(batchingWaitTime)
				} else {
					batchingTimer = time.NewTimer(batchingWaitTime)
					batchingTimeout = batchingTimer.C
				}
			case <-batchingTimeout:
				break batchingLoop
			case <-stream.Context().Done():
				return stream.Context().Err()
			case <-dctx.Done():
				return dctx.Err()
			}
		}

		if batchingTimer != nil {
			batchingTimer.Stop()
		}
	}
}
  694. // Assignments is a stream of assignments for a node. Each message contains
  695. // either full list of tasks and secrets for the node, or an incremental update.
  696. func (d *Dispatcher) Assignments(r *api.AssignmentsRequest, stream api.Dispatcher_AssignmentsServer) error {
  697. nodeInfo, err := ca.RemoteNode(stream.Context())
  698. if err != nil {
  699. return err
  700. }
  701. nodeID := nodeInfo.NodeID
  702. dctx, err := d.isRunningLocked()
  703. if err != nil {
  704. return err
  705. }
  706. fields := logrus.Fields{
  707. "node.id": nodeID,
  708. "node.session": r.SessionID,
  709. "method": "(*Dispatcher).Assignments",
  710. }
  711. if nodeInfo.ForwardedBy != nil {
  712. fields["forwarder.id"] = nodeInfo.ForwardedBy.NodeID
  713. }
  714. log := log.G(stream.Context()).WithFields(fields)
  715. log.Debugf("")
  716. if _, err = d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
  717. return err
  718. }
  719. var (
  720. sequence int64
  721. appliesTo string
  722. initial api.AssignmentsMessage
  723. )
  724. tasksMap := make(map[string]*api.Task)
  725. tasksUsingSecret := make(map[string]map[string]struct{})
  726. sendMessage := func(msg api.AssignmentsMessage, assignmentType api.AssignmentsMessage_Type) error {
  727. sequence++
  728. msg.AppliesTo = appliesTo
  729. msg.ResultsIn = strconv.FormatInt(sequence, 10)
  730. appliesTo = msg.ResultsIn
  731. msg.Type = assignmentType
  732. if err := stream.Send(&msg); err != nil {
  733. return err
  734. }
  735. return nil
  736. }
  737. // returns a slice of new secrets to send down
  738. addSecretsForTask := func(readTx store.ReadTx, t *api.Task) []*api.Secret {
  739. container := t.Spec.GetContainer()
  740. if container == nil {
  741. return nil
  742. }
  743. var newSecrets []*api.Secret
  744. for _, secretRef := range container.Secrets {
  745. // Empty ID prefix will return all secrets. Bail if there is no SecretID
  746. if secretRef.SecretID == "" {
  747. log.Debugf("invalid secret reference")
  748. continue
  749. }
  750. secretID := secretRef.SecretID
  751. log := log.WithFields(logrus.Fields{
  752. "secret.id": secretID,
  753. "secret.name": secretRef.SecretName,
  754. })
  755. if len(tasksUsingSecret[secretID]) == 0 {
  756. tasksUsingSecret[secretID] = make(map[string]struct{})
  757. secrets, err := store.FindSecrets(readTx, store.ByIDPrefix(secretID))
  758. if err != nil {
  759. log.WithError(err).Errorf("error retrieving secret")
  760. continue
  761. }
  762. if len(secrets) != 1 {
  763. log.Debugf("secret not found")
  764. continue
  765. }
  766. // If the secret was found and there was one result
  767. // (there should never be more than one because of the
  768. // uniqueness constraint), add this secret to our
  769. // initial set that we send down.
  770. newSecrets = append(newSecrets, secrets[0])
  771. }
  772. tasksUsingSecret[secretID][t.ID] = struct{}{}
  773. }
  774. return newSecrets
  775. }
  776. // TODO(aaronl): Also send node secrets that should be exposed to
  777. // this node.
  778. nodeTasks, cancel, err := store.ViewAndWatch(
  779. d.store,
  780. func(readTx store.ReadTx) error {
  781. tasks, err := store.FindTasks(readTx, store.ByNodeID(nodeID))
  782. if err != nil {
  783. return err
  784. }
  785. for _, t := range tasks {
  786. // We only care about tasks that are ASSIGNED or
  787. // higher. If the state is below ASSIGNED, the
  788. // task may not meet the constraints for this
  789. // node, so we have to be careful about sending
  790. // secrets associated with it.
  791. if t.Status.State < api.TaskStateAssigned {
  792. continue
  793. }
  794. tasksMap[t.ID] = t
  795. taskChange := &api.AssignmentChange{
  796. Assignment: &api.Assignment{
  797. Item: &api.Assignment_Task{
  798. Task: t,
  799. },
  800. },
  801. Action: api.AssignmentChange_AssignmentActionUpdate,
  802. }
  803. initial.Changes = append(initial.Changes, taskChange)
  804. // Only send secrets down if these tasks are in < RUNNING
  805. if t.Status.State <= api.TaskStateRunning {
  806. newSecrets := addSecretsForTask(readTx, t)
  807. for _, secret := range newSecrets {
  808. secretChange := &api.AssignmentChange{
  809. Assignment: &api.Assignment{
  810. Item: &api.Assignment_Secret{
  811. Secret: secret,
  812. },
  813. },
  814. Action: api.AssignmentChange_AssignmentActionUpdate,
  815. }
  816. initial.Changes = append(initial.Changes, secretChange)
  817. }
  818. }
  819. }
  820. return nil
  821. },
  822. state.EventUpdateTask{Task: &api.Task{NodeID: nodeID},
  823. Checks: []state.TaskCheckFunc{state.TaskCheckNodeID}},
  824. state.EventDeleteTask{Task: &api.Task{NodeID: nodeID},
  825. Checks: []state.TaskCheckFunc{state.TaskCheckNodeID}},
  826. state.EventUpdateSecret{},
  827. state.EventDeleteSecret{},
  828. )
  829. if err != nil {
  830. return err
  831. }
  832. defer cancel()
  833. if err := sendMessage(initial, api.AssignmentsMessage_COMPLETE); err != nil {
  834. return err
  835. }
  836. for {
  837. // Check for session expiration
  838. if _, err := d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
  839. return err
  840. }
  841. // bursty events should be processed in batches and sent out together
  842. var (
  843. update api.AssignmentsMessage
  844. modificationCnt int
  845. batchingTimer *time.Timer
  846. batchingTimeout <-chan time.Time
  847. updateTasks = make(map[string]*api.Task)
  848. updateSecrets = make(map[string]*api.Secret)
  849. removeTasks = make(map[string]struct{})
  850. removeSecrets = make(map[string]struct{})
  851. )
  852. oneModification := func() {
  853. modificationCnt++
  854. if batchingTimer != nil {
  855. batchingTimer.Reset(batchingWaitTime)
  856. } else {
  857. batchingTimer = time.NewTimer(batchingWaitTime)
  858. batchingTimeout = batchingTimer.C
  859. }
  860. }
  861. // Release the secrets references from this task
  862. releaseSecretsForTask := func(t *api.Task) bool {
  863. var modified bool
  864. container := t.Spec.GetContainer()
  865. if container == nil {
  866. return modified
  867. }
  868. for _, secretRef := range container.Secrets {
  869. secretID := secretRef.SecretID
  870. delete(tasksUsingSecret[secretID], t.ID)
  871. if len(tasksUsingSecret[secretID]) == 0 {
  872. // No tasks are using the secret anymore
  873. delete(tasksUsingSecret, secretID)
  874. removeSecrets[secretID] = struct{}{}
  875. modified = true
  876. }
  877. }
  878. return modified
  879. }
  880. // The batching loop waits for 50 ms after the most recent
  881. // change, or until modificationBatchLimit is reached. The
  882. // worst case latency is modificationBatchLimit * batchingWaitTime,
  883. // which is 10 seconds.
  884. batchingLoop:
  885. for modificationCnt < modificationBatchLimit {
  886. select {
  887. case event := <-nodeTasks:
  888. switch v := event.(type) {
  889. // We don't monitor EventCreateTask because tasks are
  890. // never created in the ASSIGNED state. First tasks are
  891. // created by the orchestrator, then the scheduler moves
  892. // them to ASSIGNED. If this ever changes, we will need
  893. // to monitor task creations as well.
  894. case state.EventUpdateTask:
  895. // We only care about tasks that are ASSIGNED or
  896. // higher.
  897. if v.Task.Status.State < api.TaskStateAssigned {
  898. continue
  899. }
  900. if oldTask, exists := tasksMap[v.Task.ID]; exists {
  901. // States ASSIGNED and below are set by the orchestrator/scheduler,
  902. // not the agent, so tasks in these states need to be sent to the
  903. // agent even if nothing else has changed.
  904. if equality.TasksEqualStable(oldTask, v.Task) && v.Task.Status.State > api.TaskStateAssigned {
  905. // this update should not trigger a task change for the agent
  906. tasksMap[v.Task.ID] = v.Task
  907. // If this task got updated to a final state, let's release
  908. // the secrets that are being used by the task
  909. if v.Task.Status.State > api.TaskStateRunning {
  910. // If releasing the secrets caused a secret to be
  911. // removed from an agent, mark one modification
  912. if releaseSecretsForTask(v.Task) {
  913. oneModification()
  914. }
  915. }
  916. continue
  917. }
  918. } else if v.Task.Status.State <= api.TaskStateRunning {
  919. // If this task wasn't part of the assignment set before, and it's <= RUNNING
  920. // add the secrets it references to the secrets assignment.
  921. // Task states > RUNNING are worker reported only, are never created in
  922. // a > RUNNING state.
  923. var newSecrets []*api.Secret
  924. d.store.View(func(readTx store.ReadTx) {
  925. newSecrets = addSecretsForTask(readTx, v.Task)
  926. })
  927. for _, secret := range newSecrets {
  928. updateSecrets[secret.ID] = secret
  929. }
  930. }
  931. tasksMap[v.Task.ID] = v.Task
  932. updateTasks[v.Task.ID] = v.Task
  933. oneModification()
  934. case state.EventDeleteTask:
  935. if _, exists := tasksMap[v.Task.ID]; !exists {
  936. continue
  937. }
  938. removeTasks[v.Task.ID] = struct{}{}
  939. delete(tasksMap, v.Task.ID)
  940. // Release the secrets being used by this task
  941. // Ignoring the return here. We will always mark
  942. // this as a modification, since a task is being
  943. // removed.
  944. releaseSecretsForTask(v.Task)
  945. oneModification()
  946. // TODO(aaronl): For node secrets, we'll need to handle
  947. // EventCreateSecret.
  948. case state.EventUpdateSecret:
  949. if _, exists := tasksUsingSecret[v.Secret.ID]; !exists {
  950. continue
  951. }
  952. log.Debugf("Secret %s (ID: %d) was updated though it was still referenced by one or more tasks",
  953. v.Secret.Spec.Annotations.Name, v.Secret.ID)
  954. case state.EventDeleteSecret:
  955. if _, exists := tasksUsingSecret[v.Secret.ID]; !exists {
  956. continue
  957. }
  958. log.Debugf("Secret %s (ID: %d) was deleted though it was still referenced by one or more tasks",
  959. v.Secret.Spec.Annotations.Name, v.Secret.ID)
  960. }
  961. case <-batchingTimeout:
  962. break batchingLoop
  963. case <-stream.Context().Done():
  964. return stream.Context().Err()
  965. case <-dctx.Done():
  966. return dctx.Err()
  967. }
  968. }
  969. if batchingTimer != nil {
  970. batchingTimer.Stop()
  971. }
  972. if modificationCnt > 0 {
  973. for id, task := range updateTasks {
  974. if _, ok := removeTasks[id]; !ok {
  975. taskChange := &api.AssignmentChange{
  976. Assignment: &api.Assignment{
  977. Item: &api.Assignment_Task{
  978. Task: task,
  979. },
  980. },
  981. Action: api.AssignmentChange_AssignmentActionUpdate,
  982. }
  983. update.Changes = append(update.Changes, taskChange)
  984. }
  985. }
  986. for id, secret := range updateSecrets {
  987. // If, due to multiple updates, this secret is no longer in use,
  988. // don't send it down.
  989. if len(tasksUsingSecret[id]) == 0 {
  990. // delete this secret for the secrets to be updated
  991. // so that deleteSecrets knows the current list
  992. delete(updateSecrets, id)
  993. continue
  994. }
  995. secretChange := &api.AssignmentChange{
  996. Assignment: &api.Assignment{
  997. Item: &api.Assignment_Secret{
  998. Secret: secret,
  999. },
  1000. },
  1001. Action: api.AssignmentChange_AssignmentActionUpdate,
  1002. }
  1003. update.Changes = append(update.Changes, secretChange)
  1004. }
  1005. for id := range removeTasks {
  1006. taskChange := &api.AssignmentChange{
  1007. Assignment: &api.Assignment{
  1008. Item: &api.Assignment_Task{
  1009. Task: &api.Task{ID: id},
  1010. },
  1011. },
  1012. Action: api.AssignmentChange_AssignmentActionRemove,
  1013. }
  1014. update.Changes = append(update.Changes, taskChange)
  1015. }
  1016. for id := range removeSecrets {
  1017. // If this secret is also being sent on the updated set
  1018. // don't also add it to the removed set
  1019. if _, ok := updateSecrets[id]; ok {
  1020. continue
  1021. }
  1022. secretChange := &api.AssignmentChange{
  1023. Assignment: &api.Assignment{
  1024. Item: &api.Assignment_Secret{
  1025. Secret: &api.Secret{ID: id},
  1026. },
  1027. },
  1028. Action: api.AssignmentChange_AssignmentActionRemove,
  1029. }
  1030. update.Changes = append(update.Changes, secretChange)
  1031. }
  1032. if err := sendMessage(update, api.AssignmentsMessage_INCREMENTAL); err != nil {
  1033. return err
  1034. }
  1035. }
  1036. }
  1037. }
  1038. func (d *Dispatcher) moveTasksToOrphaned(nodeID string) error {
  1039. _, err := d.store.Batch(func(batch *store.Batch) error {
  1040. var (
  1041. tasks []*api.Task
  1042. err error
  1043. )
  1044. d.store.View(func(tx store.ReadTx) {
  1045. tasks, err = store.FindTasks(tx, store.ByNodeID(nodeID))
  1046. })
  1047. if err != nil {
  1048. return err
  1049. }
  1050. for _, task := range tasks {
  1051. if task.Status.State < api.TaskStateOrphaned {
  1052. task.Status.State = api.TaskStateOrphaned
  1053. }
  1054. if err := batch.Update(func(tx store.Tx) error {
  1055. err := store.UpdateTask(tx, task)
  1056. if err != nil {
  1057. return err
  1058. }
  1059. return nil
  1060. }); err != nil {
  1061. return err
  1062. }
  1063. }
  1064. return nil
  1065. })
  1066. return err
  1067. }
// markNodeNotReady sets the node state to some state other than READY
func (d *Dispatcher) markNodeNotReady(id string, state api.NodeStatus_State, message string) error {
	dctx, err := d.isRunningLocked()
	if err != nil {
		return err
	}

	// Node is down. Add it to down nodes so that we can keep
	// track of tasks assigned to the node.
	var node *api.Node
	d.store.View(func(readTx store.ReadTx) {
		node = store.GetNode(readTx, id)
		if node == nil {
			// NOTE: err is assigned inside the View closure and checked
			// immediately after it returns (View runs synchronously).
			err = fmt.Errorf("could not find node %s while trying to add to down nodes store", id)
		}
	})
	if err != nil {
		return err
	}

	// When the down-node entry expires, orphan the node's tasks and drop
	// it from the down-node tracking set.
	expireFunc := func() {
		if err := d.moveTasksToOrphaned(id); err != nil {
			log.G(dctx).WithError(err).Error(`failed to move all tasks to "ORPHANED" state`)
		}

		d.downNodes.Delete(id)
	}

	d.downNodes.Add(node, expireFunc)

	status := &api.NodeStatus{
		State:   state,
		Message: message,
	}

	d.nodeUpdatesLock.Lock()
	// pluck the description out of nodeUpdates. this protects against a case
	// where a node is marked ready and a description is added, but then the
	// node is immediately marked not ready. this preserves that description
	d.nodeUpdates[id] = nodeUpdate{status: status, description: d.nodeUpdates[id].description}
	numUpdates := len(d.nodeUpdates)
	d.nodeUpdatesLock.Unlock()

	// If we've accumulated a full batch of pending updates, nudge the
	// update-processing goroutine (non-blocking if dispatcher is stopping).
	if numUpdates >= maxBatchItems {
		select {
		case d.processUpdatesTrigger <- struct{}{}:
		case <-dctx.Done():
		}
	}

	// Remove the node from the set of registered (live) nodes; a missing
	// entry means it was never registered or already removed.
	if rn := d.nodes.Delete(id); rn == nil {
		return errors.Errorf("node %s is not found in local storage", id)
	}

	return nil
}
  1115. // Heartbeat is heartbeat method for nodes. It returns new TTL in response.
  1116. // Node should send new heartbeat earlier than now + TTL, otherwise it will
  1117. // be deregistered from dispatcher and its status will be updated to NodeStatus_DOWN
  1118. func (d *Dispatcher) Heartbeat(ctx context.Context, r *api.HeartbeatRequest) (*api.HeartbeatResponse, error) {
  1119. nodeInfo, err := ca.RemoteNode(ctx)
  1120. if err != nil {
  1121. return nil, err
  1122. }
  1123. period, err := d.nodes.Heartbeat(nodeInfo.NodeID, r.SessionID)
  1124. return &api.HeartbeatResponse{Period: period}, err
  1125. }
  1126. func (d *Dispatcher) getManagers() []*api.WeightedPeer {
  1127. d.mu.Lock()
  1128. defer d.mu.Unlock()
  1129. return d.lastSeenManagers
  1130. }
  1131. func (d *Dispatcher) getNetworkBootstrapKeys() []*api.EncryptionKey {
  1132. d.mu.Lock()
  1133. defer d.mu.Unlock()
  1134. return d.networkBootstrapKeys
  1135. }
// Session is a stream which controls agent connection.
// Each message contains list of backup Managers with weights. Also there is
// a special boolean field Disconnect which if true indicates that node should
// reconnect to another Manager immediately.
func (d *Dispatcher) Session(r *api.SessionRequest, stream api.Dispatcher_SessionServer) error {
	ctx := stream.Context()
	nodeInfo, err := ca.RemoteNode(ctx)
	if err != nil {
		return err
	}
	nodeID := nodeInfo.NodeID

	dctx, err := d.isRunningLocked()
	if err != nil {
		return err
	}

	// If the node's session is stale or absent, register it anew; otherwise
	// keep its session and refresh its description/IP.
	var sessionID string
	if _, err := d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
		// register the node.
		sessionID, err = d.register(ctx, nodeID, r.Description)
		if err != nil {
			return err
		}
	} else {
		sessionID = r.SessionID
		// get the node IP addr
		addr, err := nodeIPFromContext(stream.Context())
		if err != nil {
			// Best-effort: a missing address is logged, not fatal.
			log.G(ctx).Debugf(err.Error())
		}
		// update the node description
		if err := d.markNodeReady(dctx, nodeID, r.Description, addr); err != nil {
			return err
		}
	}

	fields := logrus.Fields{
		"node.id":      nodeID,
		"node.session": sessionID,
		"method":       "(*Dispatcher).Session",
	}
	if nodeInfo.ForwardedBy != nil {
		fields["forwarder.id"] = nodeInfo.ForwardedBy.NodeID
	}
	log := log.G(ctx).WithFields(fields)

	// Load the node object and watch for updates to it, so the stream can
	// push fresh node state to the agent.
	var nodeObj *api.Node
	nodeUpdates, cancel, err := store.ViewAndWatch(d.store, func(readTx store.ReadTx) error {
		nodeObj = store.GetNode(readTx, nodeID)
		return nil
	}, state.EventUpdateNode{Node: &api.Node{ID: nodeID},
		Checks: []state.NodeCheckFunc{state.NodeCheckID}},
	)
	if cancel != nil {
		defer cancel()
	}

	if err != nil {
		// NOTE(review): the error is logged but the stream continues with a
		// possibly-nil nodeObj and nil nodeUpdates channel — presumably
		// intentional best-effort behavior; confirm before changing.
		log.WithError(err).Error("ViewAndWatch Node failed")
	}

	// Re-check the session after registration/refresh before sending.
	if _, err = d.nodes.GetWithSession(nodeID, sessionID); err != nil {
		return err
	}

	// Initial message: session ID, node object, manager list, and network keys.
	if err := stream.Send(&api.SessionMessage{
		SessionID:            sessionID,
		Node:                 nodeObj,
		Managers:             d.getManagers(),
		NetworkBootstrapKeys: d.getNetworkBootstrapKeys(),
	}); err != nil {
		return err
	}

	managerUpdates, mgrCancel := d.mgrQueue.Watch()
	defer mgrCancel()
	keyMgrUpdates, keyMgrCancel := d.keyMgrQueue.Watch()
	defer keyMgrCancel()

	// disconnectNode is a helper forcibly shutdown connection
	disconnectNode := func() error {
		// force disconnect by shutting down the stream.
		transportStream, ok := transport.StreamFromContext(stream.Context())
		if ok {
			// if we have the transport stream, we can signal a disconnect
			// in the client.
			if err := transportStream.ServerTransport().Close(); err != nil {
				log.WithError(err).Error("session end")
			}
		}

		if err := d.markNodeNotReady(nodeID, api.NodeStatus_DISCONNECTED, "node is currently trying to find new manager"); err != nil {
			log.WithError(err).Error("failed to remove node")
		}
		// still return an abort if the transport closure was ineffective.
		return grpc.Errorf(codes.Aborted, "node must disconnect")
	}

	for {
		// After each message send, we need to check the nodes sessionID hasn't
		// changed. If it has, we will shut down the stream and make the node
		// re-register.
		node, err := d.nodes.GetWithSession(nodeID, sessionID)
		if err != nil {
			return err
		}

		var (
			disconnect bool
			mgrs       []*api.WeightedPeer
			netKeys    []*api.EncryptionKey
		)

		// Block until something changes: manager set, node object, network
		// keys, explicit disconnect, or stream/dispatcher shutdown.
		select {
		case ev := <-managerUpdates:
			mgrs = ev.([]*api.WeightedPeer)
		case ev := <-nodeUpdates:
			nodeObj = ev.(state.EventUpdateNode).Node
		case <-stream.Context().Done():
			return stream.Context().Err()
		case <-node.Disconnect:
			disconnect = true
		case <-dctx.Done():
			disconnect = true
		case ev := <-keyMgrUpdates:
			netKeys = ev.([]*api.EncryptionKey)
		}
		// Fall back to the cached values for whatever the event did not carry.
		if mgrs == nil {
			mgrs = d.getManagers()
		}
		if netKeys == nil {
			netKeys = d.getNetworkBootstrapKeys()
		}

		if err := stream.Send(&api.SessionMessage{
			SessionID:            sessionID,
			Node:                 nodeObj,
			Managers:             mgrs,
			NetworkBootstrapKeys: netKeys,
		}); err != nil {
			return err
		}
		// Send the final message before tearing the session down so the
		// agent sees the latest state.
		if disconnect {
			return disconnectNode()
		}
	}
}