dispatcher.go

package dispatcher

import (
	"fmt"
	"net"
	"strconv"
	"sync"
	"time"

	"github.com/docker/go-events"
	"github.com/docker/go-metrics"
	"github.com/docker/swarmkit/api"
	"github.com/docker/swarmkit/api/equality"
	"github.com/docker/swarmkit/ca"
	"github.com/docker/swarmkit/log"
	"github.com/docker/swarmkit/manager/drivers"
	"github.com/docker/swarmkit/manager/state/store"
	"github.com/docker/swarmkit/protobuf/ptypes"
	"github.com/docker/swarmkit/remotes"
	"github.com/docker/swarmkit/watch"
	gogotypes "github.com/gogo/protobuf/types"
	"github.com/pkg/errors"
	"github.com/sirupsen/logrus"
	"golang.org/x/net/context"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
	"google.golang.org/grpc/transport"
)

const (
	// DefaultHeartBeatPeriod is used for setting the default value in the
	// cluster config, and as a fallback in case the cluster config is missing.
	DefaultHeartBeatPeriod       = 5 * time.Second
	defaultHeartBeatEpsilon      = 500 * time.Millisecond
	defaultGracePeriodMultiplier = 3
	defaultRateLimitPeriod       = 8 * time.Second

	// maxBatchItems is the threshold of queued writes that should
	// trigger an actual transaction to commit them to the shared store.
	maxBatchItems = 10000

	// maxBatchInterval needs to strike a balance between keeping
	// latency low, and realizing opportunities to combine many writes
	// into a single transaction. A fraction of a second feels about
	// right.
	maxBatchInterval = 100 * time.Millisecond

	modificationBatchLimit = 100
	batchingWaitTime       = 100 * time.Millisecond

	// defaultNodeDownPeriod specifies the default time period we
	// wait before moving tasks assigned to down nodes to ORPHANED
	// state.
	defaultNodeDownPeriod = 24 * time.Hour
)

var (
	// ErrNodeAlreadyRegistered is returned if a node with the same ID was
	// already registered with this dispatcher.
	ErrNodeAlreadyRegistered = errors.New("node already registered")
	// ErrNodeNotRegistered is returned if a node with such an ID wasn't
	// registered with this dispatcher.
	ErrNodeNotRegistered = errors.New("node not registered")
	// ErrSessionInvalid is returned when the session in use is no longer valid.
	// The node should re-register and start a new session.
	ErrSessionInvalid = errors.New("session invalid")
	// ErrNodeNotFound is returned when the Node doesn't exist in raft.
	ErrNodeNotFound = errors.New("node not found")

	// Scheduling delay timer.
	schedulingDelayTimer metrics.Timer
)

func init() {
	ns := metrics.NewNamespace("swarm", "dispatcher", nil)
	schedulingDelayTimer = ns.NewTimer("scheduling_delay",
		"Scheduling delay is the time a task takes to go from NEW to RUNNING state.")
	metrics.Register(ns)
}

// Config is the configuration for the Dispatcher. For defaults, use
// DefaultConfig.
type Config struct {
	HeartbeatPeriod  time.Duration
	HeartbeatEpsilon time.Duration
	// RateLimitPeriod specifies how often a node with the same ID can try to
	// register a new session.
	RateLimitPeriod       time.Duration
	GracePeriodMultiplier int
}

// DefaultConfig returns the default config for the Dispatcher.
func DefaultConfig() *Config {
	return &Config{
		HeartbeatPeriod:       DefaultHeartBeatPeriod,
		HeartbeatEpsilon:      defaultHeartBeatEpsilon,
		RateLimitPeriod:       defaultRateLimitPeriod,
		GracePeriodMultiplier: defaultGracePeriodMultiplier,
	}
}

// Cluster is the interface that represents a raft cluster;
// manager/state/raft.Node implements it. The interface exists only to make
// unit testing easier.
type Cluster interface {
	GetMemberlist() map[uint64]*api.RaftMember
	SubscribePeers() (chan events.Event, func())
	MemoryStore() *store.MemoryStore
}
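
// Because Cluster is so small, tests can substitute a stub for a real
// raft.Node. A minimal sketch of such a stub (hypothetical; the actual tests
// may build their fakes differently) would let the test populate the member
// list, peer channel, and MemoryStore directly:
//
//	type stubCluster struct {
//		members map[uint64]*api.RaftMember
//		peerCh  chan events.Event
//		store   *store.MemoryStore
//	}
//
//	func (c *stubCluster) GetMemberlist() map[uint64]*api.RaftMember { return c.members }
//	func (c *stubCluster) SubscribePeers() (chan events.Event, func()) {
//		return c.peerCh, func() {}
//	}
//	func (c *stubCluster) MemoryStore() *store.MemoryStore { return c.store }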

// nodeUpdate provides a new status and/or description to apply to a node
// object.
type nodeUpdate struct {
	status      *api.NodeStatus
	description *api.NodeDescription
}

// clusterUpdate is an object that stores an update to the cluster that should
// trigger a new session message. These are pointers to indicate the difference
// between "there is no update" and "update this to nil".
type clusterUpdate struct {
	managerUpdate      *[]*api.WeightedPeer
	bootstrapKeyUpdate *[]*api.EncryptionKey
	rootCAUpdate       *[]byte
}

// Dispatcher is responsible for dispatching tasks and tracking agent health.
type Dispatcher struct {
	// mu is a lock to provide mutually exclusive access to dispatcher fields
	// e.g. lastSeenManagers, networkBootstrapKeys, lastSeenRootCert etc.
	mu sync.Mutex
	// shutdownWait is used by stop() to wait for existing operations to finish.
	shutdownWait sync.WaitGroup

	nodes                *nodeStore
	store                *store.MemoryStore
	lastSeenManagers     []*api.WeightedPeer
	networkBootstrapKeys []*api.EncryptionKey
	lastSeenRootCert     []byte
	config               *Config
	cluster              Cluster
	ctx                  context.Context
	cancel               context.CancelFunc
	clusterUpdateQueue   *watch.Queue
	dp                   *drivers.DriverProvider
	securityConfig       *ca.SecurityConfig

	taskUpdates     map[string]*api.TaskStatus // indexed by task ID
	taskUpdatesLock sync.Mutex

	nodeUpdates     map[string]nodeUpdate // indexed by node ID
	nodeUpdatesLock sync.Mutex

	downNodes *nodeStore

	processUpdatesTrigger chan struct{}

	// for waiting for the next task/node batch update
	processUpdatesLock sync.Mutex
	processUpdatesCond *sync.Cond
}

// New returns a Dispatcher with its cluster interface (usually a raft.Node).
func New(cluster Cluster, c *Config, dp *drivers.DriverProvider, securityConfig *ca.SecurityConfig) *Dispatcher {
	d := &Dispatcher{
		dp:                    dp,
		nodes:                 newNodeStore(c.HeartbeatPeriod, c.HeartbeatEpsilon, c.GracePeriodMultiplier, c.RateLimitPeriod),
		downNodes:             newNodeStore(defaultNodeDownPeriod, 0, 1, 0),
		store:                 cluster.MemoryStore(),
		cluster:               cluster,
		processUpdatesTrigger: make(chan struct{}, 1),
		config:                c,
		securityConfig:        securityConfig,
	}

	d.processUpdatesCond = sync.NewCond(&d.processUpdatesLock)

	return d
}
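
// A minimal usage sketch (assumptions: raftNode satisfies Cluster, and the
// caller already holds a drivers.DriverProvider and ca.SecurityConfig; error
// handling is elided):
//
//	d := New(raftNode, DefaultConfig(), driverProvider, securityConfig)
//	go func() {
//		if err := d.Run(ctx); err != nil {
//			log.G(ctx).WithError(err).Error("dispatcher exited")
//		}
//	}()
//	// ... later, e.g. when losing leadership:
//	_ = d.Stop()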

func getWeightedPeers(cluster Cluster) []*api.WeightedPeer {
	members := cluster.GetMemberlist()
	var mgrs []*api.WeightedPeer
	for _, m := range members {
		mgrs = append(mgrs, &api.WeightedPeer{
			Peer: &api.Peer{
				NodeID: m.NodeID,
				Addr:   m.Addr,
			},

			// TODO(stevvooe): Calculate weight of manager selection based on
			// cluster-level observations, such as number of connections and
			// load.
			Weight: remotes.DefaultObservationWeight,
		})
	}
	return mgrs
}

// Run runs the dispatcher's background tasks, which should only be run on the
// leader dispatcher. The dispatcher can be stopped by cancelling ctx or by
// calling Stop().
func (d *Dispatcher) Run(ctx context.Context) error {
	ctx = log.WithModule(ctx, "dispatcher")
	log.G(ctx).Info("dispatcher starting")

	d.taskUpdatesLock.Lock()
	d.taskUpdates = make(map[string]*api.TaskStatus)
	d.taskUpdatesLock.Unlock()

	d.nodeUpdatesLock.Lock()
	d.nodeUpdates = make(map[string]nodeUpdate)
	d.nodeUpdatesLock.Unlock()

	d.mu.Lock()
	if d.isRunning() {
		d.mu.Unlock()
		return errors.New("dispatcher is already running")
	}
	if err := d.markNodesUnknown(ctx); err != nil {
		log.G(ctx).Errorf(`failed to move all nodes to "unknown" state: %v`, err)
	}
	configWatcher, cancel, err := store.ViewAndWatch(
		d.store,
		func(readTx store.ReadTx) error {
			clusters, err := store.FindClusters(readTx, store.ByName(store.DefaultClusterName))
			if err != nil {
				return err
			}
			if len(clusters) == 1 {
				heartbeatPeriod, err := gogotypes.DurationFromProto(clusters[0].Spec.Dispatcher.HeartbeatPeriod)
				if err == nil && heartbeatPeriod > 0 {
					d.config.HeartbeatPeriod = heartbeatPeriod
				}
				if clusters[0].NetworkBootstrapKeys != nil {
					d.networkBootstrapKeys = clusters[0].NetworkBootstrapKeys
				}
				d.lastSeenRootCert = clusters[0].RootCA.CACert
			}
			return nil
		},
		api.EventUpdateCluster{},
	)
	if err != nil {
		d.mu.Unlock()
		return err
	}
	// set queue here to guarantee that Close will close it
	d.clusterUpdateQueue = watch.NewQueue()

	peerWatcher, peerCancel := d.cluster.SubscribePeers()
	defer peerCancel()
	d.lastSeenManagers = getWeightedPeers(d.cluster)

	defer cancel()
	d.ctx, d.cancel = context.WithCancel(ctx)
	ctx = d.ctx

	// If Stop() is called, it should wait
	// for Run() to complete.
	d.shutdownWait.Add(1)
	defer d.shutdownWait.Done()

	d.mu.Unlock()

	publishManagers := func(peers []*api.Peer) {
		var mgrs []*api.WeightedPeer
		for _, p := range peers {
			mgrs = append(mgrs, &api.WeightedPeer{
				Peer:   p,
				Weight: remotes.DefaultObservationWeight,
			})
		}
		d.mu.Lock()
		d.lastSeenManagers = mgrs
		d.mu.Unlock()
		d.clusterUpdateQueue.Publish(clusterUpdate{managerUpdate: &mgrs})
	}

	batchTimer := time.NewTimer(maxBatchInterval)
	defer batchTimer.Stop()

	for {
		select {
		case ev := <-peerWatcher:
			publishManagers(ev.([]*api.Peer))
		case <-d.processUpdatesTrigger:
			d.processUpdates(ctx)
			batchTimer.Reset(maxBatchInterval)
		case <-batchTimer.C:
			d.processUpdates(ctx)
			batchTimer.Reset(maxBatchInterval)
		case v := <-configWatcher:
			cluster := v.(api.EventUpdateCluster)
			d.mu.Lock()
			if cluster.Cluster.Spec.Dispatcher.HeartbeatPeriod != nil {
				// ignore error, since Spec has passed validation before
				heartbeatPeriod, _ := gogotypes.DurationFromProto(cluster.Cluster.Spec.Dispatcher.HeartbeatPeriod)
				if heartbeatPeriod != d.config.HeartbeatPeriod {
					// only call d.nodes.updatePeriod when heartbeatPeriod changes
					d.config.HeartbeatPeriod = heartbeatPeriod
					d.nodes.updatePeriod(d.config.HeartbeatPeriod, d.config.HeartbeatEpsilon, d.config.GracePeriodMultiplier)
				}
			}
			d.lastSeenRootCert = cluster.Cluster.RootCA.CACert
			d.networkBootstrapKeys = cluster.Cluster.NetworkBootstrapKeys
			d.mu.Unlock()
			d.clusterUpdateQueue.Publish(clusterUpdate{
				bootstrapKeyUpdate: &cluster.Cluster.NetworkBootstrapKeys,
				rootCAUpdate:       &cluster.Cluster.RootCA.CACert,
			})
		case <-ctx.Done():
			return nil
		}
	}
}

// Stop stops the dispatcher and closes all grpc streams.
func (d *Dispatcher) Stop() error {
	d.mu.Lock()
	if !d.isRunning() {
		d.mu.Unlock()
		return errors.New("dispatcher is already stopped")
	}

	// Cancel dispatcher context.
	// This should also close the streams in Tasks(), Assignments().
	d.cancel()
	d.mu.Unlock()

	// Wait for the RPCs that are in-progress to finish.
	d.shutdownWait.Wait()

	d.nodes.Clean()

	d.processUpdatesLock.Lock()
	// In case there are any waiters. There is no chance of any starting
	// after this point, because they check if the context is canceled
	// before waiting.
	d.processUpdatesCond.Broadcast()
	d.processUpdatesLock.Unlock()

	d.clusterUpdateQueue.Close()

	return nil
}

func (d *Dispatcher) isRunningLocked() (context.Context, error) {
	d.mu.Lock()
	if !d.isRunning() {
		d.mu.Unlock()
		return nil, status.Errorf(codes.Aborted, "dispatcher is stopped")
	}
	ctx := d.ctx
	d.mu.Unlock()
	return ctx, nil
}

func (d *Dispatcher) markNodesUnknown(ctx context.Context) error {
	log := log.G(ctx).WithField("method", "(*Dispatcher).markNodesUnknown")
	var nodes []*api.Node
	var err error
	d.store.View(func(tx store.ReadTx) {
		nodes, err = store.FindNodes(tx, store.All)
	})
	if err != nil {
		return errors.Wrap(err, "failed to get list of nodes")
	}
	err = d.store.Batch(func(batch *store.Batch) error {
		for _, n := range nodes {
			err := batch.Update(func(tx store.Tx) error {
				// check if node is still here
				node := store.GetNode(tx, n.ID)
				if node == nil {
					return nil
				}
				// do not try to resurrect down nodes
				if node.Status.State == api.NodeStatus_DOWN {
					nodeCopy := node
					expireFunc := func() {
						log.Infof("moving tasks to orphaned state for node: %s", nodeCopy.ID)
						if err := d.moveTasksToOrphaned(nodeCopy.ID); err != nil {
							log.WithError(err).Errorf(`failed to move all tasks for node %s to "ORPHANED" state`, node.ID)
						}

						d.downNodes.Delete(nodeCopy.ID)
					}

					log.Infof(`node %s was found to be down when marking unknown on dispatcher start`, node.ID)
					d.downNodes.Add(nodeCopy, expireFunc)
					return nil
				}

				node.Status.State = api.NodeStatus_UNKNOWN
				node.Status.Message = `Node moved to "unknown" state due to leadership change in cluster`

				nodeID := node.ID

				expireFunc := func() {
					log := log.WithField("node", nodeID)
					log.Infof(`heartbeat expiration for node %s in state "unknown"`, nodeID)
					if err := d.markNodeNotReady(nodeID, api.NodeStatus_DOWN, `heartbeat failure for node in "unknown" state`); err != nil {
						log.WithError(err).Error(`failed deregistering node after heartbeat expiration for node in "unknown" state`)
					}
				}
				if err := d.nodes.AddUnknown(node, expireFunc); err != nil {
					return errors.Wrapf(err, `adding node %s in "unknown" state to node store failed`, nodeID)
				}
				if err := store.UpdateNode(tx, node); err != nil {
					return errors.Wrapf(err, "update for node %s failed", nodeID)
				}
				return nil
			})
			if err != nil {
				log.WithField("node", n.ID).WithError(err).Error(`failed to move node to "unknown" state`)
			}
		}
		return nil
	})
	return err
}

func (d *Dispatcher) isRunning() bool {
	if d.ctx == nil {
		return false
	}
	select {
	case <-d.ctx.Done():
		return false
	default:
	}
	return true
}

// markNodeReady updates the description of a node, updates its address, and
// sets its status to READY. This is used during registration when a new node
// description is provided, and during node updates when the node description
// changes.
func (d *Dispatcher) markNodeReady(ctx context.Context, nodeID string, description *api.NodeDescription, addr string) error {
	d.nodeUpdatesLock.Lock()
	d.nodeUpdates[nodeID] = nodeUpdate{
		status: &api.NodeStatus{
			State: api.NodeStatus_READY,
			Addr:  addr,
		},
		description: description,
	}
	numUpdates := len(d.nodeUpdates)
	d.nodeUpdatesLock.Unlock()

	// Node is marked ready. Remove the node from down nodes if it
	// is there.
	d.downNodes.Delete(nodeID)

	if numUpdates >= maxBatchItems {
		select {
		case d.processUpdatesTrigger <- struct{}{}:
		case <-ctx.Done():
			return ctx.Err()
		}
	}

	// Wait until the node update batch happens before unblocking register.
	d.processUpdatesLock.Lock()
	defer d.processUpdatesLock.Unlock()

	select {
	case <-ctx.Done():
		return ctx.Err()
	default:
	}
	d.processUpdatesCond.Wait()

	return nil
}

// nodeIPFromContext gets the node IP from the context of a grpc call.
func nodeIPFromContext(ctx context.Context) (string, error) {
	nodeInfo, err := ca.RemoteNode(ctx)
	if err != nil {
		return "", err
	}
	addr, _, err := net.SplitHostPort(nodeInfo.RemoteAddr)
	if err != nil {
		return "", errors.Wrap(err, "unable to get ip from addr:port")
	}
	return addr, nil
}

// register is used to register a node with a particular dispatcher.
func (d *Dispatcher) register(ctx context.Context, nodeID string, description *api.NodeDescription) (string, error) {
	dctx, err := d.isRunningLocked()
	if err != nil {
		return "", err
	}

	logLocal := log.G(ctx).WithField("method", "(*Dispatcher).register")

	if err := d.nodes.CheckRateLimit(nodeID); err != nil {
		return "", err
	}

	// TODO(stevvooe): Validate node specification.
	var node *api.Node
	d.store.View(func(tx store.ReadTx) {
		node = store.GetNode(tx, nodeID)
	})
	if node == nil {
		return "", ErrNodeNotFound
	}

	addr, err := nodeIPFromContext(ctx)
	if err != nil {
		logLocal.WithError(err).Debug("failed to get remote node IP")
	}

	if err := d.markNodeReady(dctx, nodeID, description, addr); err != nil {
		return "", err
	}

	expireFunc := func() {
		log.G(ctx).Debugf("heartbeat expiration for worker %s, setting worker status to NodeStatus_DOWN", nodeID)
		if err := d.markNodeNotReady(nodeID, api.NodeStatus_DOWN, "heartbeat failure"); err != nil {
			log.G(ctx).WithError(err).Errorf("failed deregistering node after heartbeat expiration")
		}
	}

	rn := d.nodes.Add(node, expireFunc)
	logLocal.Infof("worker %s was successfully registered", nodeID)

	// NOTE(stevvooe): We need be a little careful with re-registration. The
	// current implementation just matches the node id and then gives away the
	// sessionID. If we ever want to use sessionID as a secret, which we may
	// want to, this is giving away the keys to the kitchen.
	//
	// The right behavior is going to be informed by identity. Basically, each
	// time a node registers, we invalidate the session and issue a new
	// session, once identity is proven. This will cause misbehaved agents to
	// be kicked when multiple connections are made.
	return rn.SessionID, nil
}

// UpdateTaskStatus updates the status of a task. A node should send such
// updates on every status change of its tasks.
func (d *Dispatcher) UpdateTaskStatus(ctx context.Context, r *api.UpdateTaskStatusRequest) (*api.UpdateTaskStatusResponse, error) {
	// shutdownWait.Add() followed by isRunning() ensures that
	// if this rpc sees the dispatcher running,
	// it will already have called Add() on the shutdownWait WaitGroup,
	// which ensures that Stop() will wait for this rpc to complete.
	// Note that Stop() first does Dispatcher.ctx.cancel() followed by
	// shutdownWait.Wait() to make sure new rpc's don't start before waiting
	// for existing ones to finish.
	d.shutdownWait.Add(1)
	defer d.shutdownWait.Done()

	dctx, err := d.isRunningLocked()
	if err != nil {
		return nil, err
	}

	nodeInfo, err := ca.RemoteNode(ctx)
	if err != nil {
		return nil, err
	}
	nodeID := nodeInfo.NodeID
	fields := logrus.Fields{
		"node.id":      nodeID,
		"node.session": r.SessionID,
		"method":       "(*Dispatcher).UpdateTaskStatus",
	}
	if nodeInfo.ForwardedBy != nil {
		fields["forwarder.id"] = nodeInfo.ForwardedBy.NodeID
	}
	log := log.G(ctx).WithFields(fields)

	if _, err := d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
		return nil, err
	}

	validTaskUpdates := make([]*api.UpdateTaskStatusRequest_TaskStatusUpdate, 0, len(r.Updates))

	// Validate task updates
	for _, u := range r.Updates {
		if u.Status == nil {
			log.WithField("task.id", u.TaskID).Warn("task report has nil status")
			continue
		}

		var t *api.Task
		d.store.View(func(tx store.ReadTx) {
			t = store.GetTask(tx, u.TaskID)
		})
		if t == nil {
			// Task may have been deleted
			log.WithField("task.id", u.TaskID).Debug("cannot find target task in store")
			continue
		}

		if t.NodeID != nodeID {
			err := status.Errorf(codes.PermissionDenied, "cannot update a task not assigned this node")
			log.WithField("task.id", u.TaskID).Error(err)
			return nil, err
		}

		validTaskUpdates = append(validTaskUpdates, u)
	}

	d.taskUpdatesLock.Lock()
	// Enqueue task updates
	for _, u := range validTaskUpdates {
		d.taskUpdates[u.TaskID] = u.Status
	}

	numUpdates := len(d.taskUpdates)
	d.taskUpdatesLock.Unlock()

	if numUpdates >= maxBatchItems {
		select {
		case d.processUpdatesTrigger <- struct{}{}:
		case <-dctx.Done():
		}
	}
	return nil, nil
}
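
// From a worker's point of view, a status report is just a batch of
// (TaskID, TaskStatus) pairs tied to the current session. A minimal sketch of
// building such a request (assuming client is an api.DispatcherClient and the
// task/session IDs are already known; error handling elided):
//
//	_, err := client.UpdateTaskStatus(ctx, &api.UpdateTaskStatusRequest{
//		SessionID: sessionID,
//		Updates: []*api.UpdateTaskStatusRequest_TaskStatusUpdate{
//			{TaskID: taskID, Status: &api.TaskStatus{State: api.TaskStateRunning}},
//		},
//	})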

func (d *Dispatcher) processUpdates(ctx context.Context) {
	var (
		taskUpdates map[string]*api.TaskStatus
		nodeUpdates map[string]nodeUpdate
	)
	d.taskUpdatesLock.Lock()
	if len(d.taskUpdates) != 0 {
		taskUpdates = d.taskUpdates
		d.taskUpdates = make(map[string]*api.TaskStatus)
	}
	d.taskUpdatesLock.Unlock()

	d.nodeUpdatesLock.Lock()
	if len(d.nodeUpdates) != 0 {
		nodeUpdates = d.nodeUpdates
		d.nodeUpdates = make(map[string]nodeUpdate)
	}
	d.nodeUpdatesLock.Unlock()

	if len(taskUpdates) == 0 && len(nodeUpdates) == 0 {
		return
	}

	log := log.G(ctx).WithFields(logrus.Fields{
		"method": "(*Dispatcher).processUpdates",
	})

	err := d.store.Batch(func(batch *store.Batch) error {
		for taskID, status := range taskUpdates {
			err := batch.Update(func(tx store.Tx) error {
				logger := log.WithField("task.id", taskID)
				task := store.GetTask(tx, taskID)
				if task == nil {
					// Task may have been deleted
					logger.Debug("cannot find target task in store")
					return nil
				}

				logger = logger.WithField("state.transition", fmt.Sprintf("%v->%v", task.Status.State, status.State))

				if task.Status == *status {
					logger.Debug("task status identical, ignoring")
					return nil
				}

				if task.Status.State > status.State {
					logger.Debug("task status invalid transition")
					return nil
				}

				// Update scheduling delay metric for running tasks.
				// We use the status update time on the leader to calculate the scheduling delay.
				// Because of this, the recorded scheduling delay will be an overestimate and include
				// the network delay between the worker and the leader.
				// This is not ideal, but it's a known overestimation, rather than using the status update time
				// from the worker node, which may cause unknown incorrect results due to possible clock skew.
				if status.State == api.TaskStateRunning {
					start := time.Unix(status.AppliedAt.GetSeconds(), int64(status.AppliedAt.GetNanos()))
					schedulingDelayTimer.UpdateSince(start)
				}

				task.Status = *status
				task.Status.AppliedBy = d.securityConfig.ClientTLSCreds.NodeID()
				task.Status.AppliedAt = ptypes.MustTimestampProto(time.Now())
				if err := store.UpdateTask(tx, task); err != nil {
					logger.WithError(err).Error("failed to update task status")
					return nil
				}
				logger.Debug("dispatcher committed status update to store")
				return nil
			})
			if err != nil {
				log.WithError(err).Error("dispatcher task update transaction failed")
			}
		}

		for nodeID, nodeUpdate := range nodeUpdates {
			err := batch.Update(func(tx store.Tx) error {
				logger := log.WithField("node.id", nodeID)
				node := store.GetNode(tx, nodeID)
				if node == nil {
					logger.Errorf("node unavailable")
					return nil
				}

				if nodeUpdate.status != nil {
					node.Status.State = nodeUpdate.status.State
					node.Status.Message = nodeUpdate.status.Message
					if nodeUpdate.status.Addr != "" {
						node.Status.Addr = nodeUpdate.status.Addr
					}
				}
				if nodeUpdate.description != nil {
					node.Description = nodeUpdate.description
				}

				if err := store.UpdateNode(tx, node); err != nil {
					logger.WithError(err).Error("failed to update node status")
					return nil
				}
				logger.Debug("node status updated")
				return nil
			})
			if err != nil {
				log.WithError(err).Error("dispatcher node update transaction failed")
			}
		}
		return nil
	})
	if err != nil {
		log.WithError(err).Error("dispatcher batch failed")
	}

	d.processUpdatesCond.Broadcast()
}

// Tasks is a stream of task states for a node. Each message contains the full
// list of tasks which should be run on the node; if a task is not present in
// that list, it should be terminated.
func (d *Dispatcher) Tasks(r *api.TasksRequest, stream api.Dispatcher_TasksServer) error {
	// shutdownWait.Add() followed by isRunning() ensures that
	// if this rpc sees the dispatcher running,
	// it will already have called Add() on the shutdownWait WaitGroup,
	// which ensures that Stop() will wait for this rpc to complete.
	// Note that Stop() first does Dispatcher.ctx.cancel() followed by
	// shutdownWait.Wait() to make sure new rpc's don't start before waiting
	// for existing ones to finish.
	d.shutdownWait.Add(1)
	defer d.shutdownWait.Done()

	dctx, err := d.isRunningLocked()
	if err != nil {
		return err
	}

	nodeInfo, err := ca.RemoteNode(stream.Context())
	if err != nil {
		return err
	}
	nodeID := nodeInfo.NodeID

	fields := logrus.Fields{
		"node.id":      nodeID,
		"node.session": r.SessionID,
		"method":       "(*Dispatcher).Tasks",
	}
	if nodeInfo.ForwardedBy != nil {
		fields["forwarder.id"] = nodeInfo.ForwardedBy.NodeID
	}
	log.G(stream.Context()).WithFields(fields).Debug("")

	if _, err = d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
		return err
	}

	tasksMap := make(map[string]*api.Task)
	nodeTasks, cancel, err := store.ViewAndWatch(
		d.store,
		func(readTx store.ReadTx) error {
			tasks, err := store.FindTasks(readTx, store.ByNodeID(nodeID))
			if err != nil {
				return err
			}
			for _, t := range tasks {
				tasksMap[t.ID] = t
			}
			return nil
		},
		api.EventCreateTask{Task: &api.Task{NodeID: nodeID},
			Checks: []api.TaskCheckFunc{api.TaskCheckNodeID}},
		api.EventUpdateTask{Task: &api.Task{NodeID: nodeID},
			Checks: []api.TaskCheckFunc{api.TaskCheckNodeID}},
		api.EventDeleteTask{Task: &api.Task{NodeID: nodeID},
			Checks: []api.TaskCheckFunc{api.TaskCheckNodeID}},
	)
	if err != nil {
		return err
	}
	defer cancel()

	for {
		if _, err := d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
			return err
		}

		var tasks []*api.Task
		for _, t := range tasksMap {
			// dispatcher only sends tasks that have been assigned to a node
			if t != nil && t.Status.State >= api.TaskStateAssigned {
				tasks = append(tasks, t)
			}
		}

		if err := stream.Send(&api.TasksMessage{Tasks: tasks}); err != nil {
			return err
		}

		// bursty events should be processed in batches and sent out as a snapshot
		var (
			modificationCnt int
			batchingTimer   *time.Timer
			batchingTimeout <-chan time.Time
		)

	batchingLoop:
		for modificationCnt < modificationBatchLimit {
			select {
			case event := <-nodeTasks:
				switch v := event.(type) {
				case api.EventCreateTask:
					tasksMap[v.Task.ID] = v.Task
					modificationCnt++
				case api.EventUpdateTask:
					if oldTask, exists := tasksMap[v.Task.ID]; exists {
						// States ASSIGNED and below are set by the orchestrator/scheduler,
						// not the agent, so tasks in these states need to be sent to the
						// agent even if nothing else has changed.
						if equality.TasksEqualStable(oldTask, v.Task) && v.Task.Status.State > api.TaskStateAssigned {
							// this update should not trigger action at agent
							tasksMap[v.Task.ID] = v.Task
							continue
						}
					}
					tasksMap[v.Task.ID] = v.Task
					modificationCnt++
				case api.EventDeleteTask:
					delete(tasksMap, v.Task.ID)
					modificationCnt++
				}
				if batchingTimer != nil {
					batchingTimer.Reset(batchingWaitTime)
				} else {
					batchingTimer = time.NewTimer(batchingWaitTime)
					batchingTimeout = batchingTimer.C
				}
			case <-batchingTimeout:
				break batchingLoop
			case <-stream.Context().Done():
				return stream.Context().Err()
			case <-dctx.Done():
				return dctx.Err()
			}
		}

		if batchingTimer != nil {
			batchingTimer.Stop()
		}
	}
}
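
// On the agent side, consuming this stream amounts to replacing the local task
// set with every received snapshot. A minimal sketch (assuming client is an
// api.DispatcherClient and reconcile is a hypothetical helper that starts or
// stops local tasks to match the snapshot; error handling elided):
//
//	stream, _ := client.Tasks(ctx, &api.TasksRequest{SessionID: sessionID})
//	for {
//		msg, err := stream.Recv()
//		if err != nil {
//			break // session ended; re-register and retry
//		}
//		reconcile(msg.Tasks)
//	}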

// Assignments is a stream of assignments for a node. Each message contains
// either the full list of tasks and secrets for the node, or an incremental
// update.
func (d *Dispatcher) Assignments(r *api.AssignmentsRequest, stream api.Dispatcher_AssignmentsServer) error {
	// shutdownWait.Add() followed by isRunning() ensures that
	// if this rpc sees the dispatcher running,
	// it will already have called Add() on the shutdownWait WaitGroup,
	// which ensures that Stop() will wait for this rpc to complete.
	// Note that Stop() first does Dispatcher.ctx.cancel() followed by
	// shutdownWait.Wait() to make sure new rpc's don't start before waiting
	// for existing ones to finish.
	d.shutdownWait.Add(1)
	defer d.shutdownWait.Done()

	dctx, err := d.isRunningLocked()
	if err != nil {
		return err
	}

	nodeInfo, err := ca.RemoteNode(stream.Context())
	if err != nil {
		return err
	}
	nodeID := nodeInfo.NodeID

	fields := logrus.Fields{
		"node.id":      nodeID,
		"node.session": r.SessionID,
		"method":       "(*Dispatcher).Assignments",
	}
	if nodeInfo.ForwardedBy != nil {
		fields["forwarder.id"] = nodeInfo.ForwardedBy.NodeID
	}
	log := log.G(stream.Context()).WithFields(fields)
	log.Debug("")

	if _, err = d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
		return err
	}

	var (
		sequence    int64
		appliesTo   string
		assignments = newAssignmentSet(log, d.dp)
	)
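
	// sendMessage chains messages together: each message records the sequence
	// it applies to (AppliesTo) and the sequence it results in (ResultsIn), so
	// the agent can tell whether an incremental update builds on the state it
	// currently holds.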
	sendMessage := func(msg api.AssignmentsMessage, assignmentType api.AssignmentsMessage_Type) error {
		sequence++
		msg.AppliesTo = appliesTo
		msg.ResultsIn = strconv.FormatInt(sequence, 10)
		appliesTo = msg.ResultsIn
		msg.Type = assignmentType

		return stream.Send(&msg)
	}

	// TODO(aaronl): Also send node secrets that should be exposed to
	// this node.
	nodeTasks, cancel, err := store.ViewAndWatch(
		d.store,
		func(readTx store.ReadTx) error {
			tasks, err := store.FindTasks(readTx, store.ByNodeID(nodeID))
			if err != nil {
				return err
			}

			for _, t := range tasks {
				assignments.addOrUpdateTask(readTx, t)
			}

			return nil
		},
		api.EventUpdateTask{Task: &api.Task{NodeID: nodeID},
			Checks: []api.TaskCheckFunc{api.TaskCheckNodeID}},
		api.EventDeleteTask{Task: &api.Task{NodeID: nodeID},
			Checks: []api.TaskCheckFunc{api.TaskCheckNodeID}},
	)
	if err != nil {
		return err
	}
	defer cancel()

	if err := sendMessage(assignments.message(), api.AssignmentsMessage_COMPLETE); err != nil {
		return err
	}

	for {
		// Check for session expiration
		if _, err := d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
			return err
		}

		// bursty events should be processed in batches and sent out together
		var (
			modificationCnt int
			batchingTimer   *time.Timer
			batchingTimeout <-chan time.Time
		)

		oneModification := func() {
			modificationCnt++

			if batchingTimer != nil {
				batchingTimer.Reset(batchingWaitTime)
			} else {
				batchingTimer = time.NewTimer(batchingWaitTime)
				batchingTimeout = batchingTimer.C
			}
		}

		// The batching loop waits for batchingWaitTime (100 ms) after the most
		// recent change, or until modificationBatchLimit is reached. The
		// worst case latency is modificationBatchLimit * batchingWaitTime,
		// which is 10 seconds.
	batchingLoop:
		for modificationCnt < modificationBatchLimit {
			select {
			case event := <-nodeTasks:
				switch v := event.(type) {
				// We don't monitor EventCreateTask because tasks are
				// never created in the ASSIGNED state. First tasks are
				// created by the orchestrator, then the scheduler moves
				// them to ASSIGNED. If this ever changes, we will need
				// to monitor task creations as well.
				case api.EventUpdateTask:
					d.store.View(func(readTx store.ReadTx) {
						if assignments.addOrUpdateTask(readTx, v.Task) {
							oneModification()
						}
					})
				case api.EventDeleteTask:
					if assignments.removeTask(v.Task) {
						oneModification()
					}
					// TODO(aaronl): For node secrets, we'll need to handle
					// EventCreateSecret.
				}
			case <-batchingTimeout:
				break batchingLoop
			case <-stream.Context().Done():
				return stream.Context().Err()
			case <-dctx.Done():
				return dctx.Err()
			}
		}

		if batchingTimer != nil {
			batchingTimer.Stop()
		}

		if modificationCnt > 0 {
			if err := sendMessage(assignments.message(), api.AssignmentsMessage_INCREMENTAL); err != nil {
				return err
			}
		}
	}
}

func (d *Dispatcher) moveTasksToOrphaned(nodeID string) error {
	err := d.store.Batch(func(batch *store.Batch) error {
		var (
			tasks []*api.Task
			err   error
		)

		d.store.View(func(tx store.ReadTx) {
			tasks, err = store.FindTasks(tx, store.ByNodeID(nodeID))
		})
		if err != nil {
			return err
		}

		for _, task := range tasks {
			// Tasks running on an unreachable node need to be marked as
			// orphaned since we have no idea whether the task is still running
			// or not.
			//
			// This only applies for tasks that could have made progress since
			// the agent became unreachable (assigned<->running)
			//
			// Tasks in a final state (e.g. rejected) *cannot* have made
			// progress, therefore there's no point in marking them as orphaned
			if task.Status.State >= api.TaskStateAssigned && task.Status.State <= api.TaskStateRunning {
				task.Status.State = api.TaskStateOrphaned
			}

			if err := batch.Update(func(tx store.Tx) error {
				return store.UpdateTask(tx, task)
			}); err != nil {
				return err
			}
		}

		return nil
	})

	return err
}

// markNodeNotReady sets the node state to some state other than READY.
func (d *Dispatcher) markNodeNotReady(id string, state api.NodeStatus_State, message string) error {
	logLocal := log.G(d.ctx).WithField("method", "(*Dispatcher).markNodeNotReady")

	dctx, err := d.isRunningLocked()
	if err != nil {
		return err
	}

	// Node is down. Add it to down nodes so that we can keep
	// track of tasks assigned to the node.
	var node *api.Node
	d.store.View(func(readTx store.ReadTx) {
		node = store.GetNode(readTx, id)
		if node == nil {
			err = fmt.Errorf("could not find node %s while trying to add to down nodes store", id)
		}
	})
	if err != nil {
		return err
	}

	expireFunc := func() {
		log.G(dctx).Debugf(`worker %s timed out in "down" state, moving all tasks to "ORPHANED" state`, id)
		if err := d.moveTasksToOrphaned(id); err != nil {
			log.G(dctx).WithError(err).Error(`failed to move all tasks to "ORPHANED" state`)
		}

		d.downNodes.Delete(id)
	}

	d.downNodes.Add(node, expireFunc)
	logLocal.Debugf("added node %s to down nodes list", node.ID)

	status := &api.NodeStatus{
		State:   state,
		Message: message,
	}

	d.nodeUpdatesLock.Lock()
	// pluck the description out of nodeUpdates. this protects against a case
	// where a node is marked ready and a description is added, but then the
	// node is immediately marked not ready. this preserves that description
	d.nodeUpdates[id] = nodeUpdate{status: status, description: d.nodeUpdates[id].description}
	numUpdates := len(d.nodeUpdates)
	d.nodeUpdatesLock.Unlock()

	if numUpdates >= maxBatchItems {
		select {
		case d.processUpdatesTrigger <- struct{}{}:
		case <-dctx.Done():
		}
	}

	if rn := d.nodes.Delete(id); rn == nil {
		return errors.Errorf("node %s is not found in local storage", id)
	}
	logLocal.Debugf("deleted node %s from node store", node.ID)

	return nil
}

// Heartbeat is the heartbeat method for nodes. It returns a new TTL in the
// response. The node should send its next heartbeat earlier than now + TTL,
// otherwise it will be deregistered from the dispatcher and its status will be
// updated to NodeStatus_DOWN.
func (d *Dispatcher) Heartbeat(ctx context.Context, r *api.HeartbeatRequest) (*api.HeartbeatResponse, error) {
	// shutdownWait.Add() followed by isRunning() ensures that
	// if this rpc sees the dispatcher running,
	// it will already have called Add() on the shutdownWait WaitGroup,
	// which ensures that Stop() will wait for this rpc to complete.
	// Note that Stop() first does Dispatcher.ctx.cancel() followed by
	// shutdownWait.Wait() to make sure new rpc's don't start before waiting
	// for existing ones to finish.
	d.shutdownWait.Add(1)
	defer d.shutdownWait.Done()

	// isRunningLocked() is not needed since it's OK if
	// the dispatcher context is cancelled while this call is in progress,
	// since Stop(), which cancels the dispatcher context, will wait for
	// Heartbeat() to complete.
	if !d.isRunning() {
		return nil, status.Errorf(codes.Aborted, "dispatcher is stopped")
	}

	nodeInfo, err := ca.RemoteNode(ctx)
	if err != nil {
		return nil, err
	}

	period, err := d.nodes.Heartbeat(nodeInfo.NodeID, r.SessionID)

	log.G(ctx).WithField("method", "(*Dispatcher).Heartbeat").Debugf("received heartbeat from worker %v, expect next heartbeat in %v", nodeInfo, period)
	return &api.HeartbeatResponse{Period: period}, err
}
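
// A worker's heartbeat loop is driven by the returned Period. A minimal sketch
// (assuming client is an api.DispatcherClient; on error the agent would
// normally open a new Session and re-register):
//
//	for {
//		resp, err := client.Heartbeat(ctx, &api.HeartbeatRequest{SessionID: sessionID})
//		if err != nil {
//			return err // session is gone; caller re-registers
//		}
//		select {
//		case <-time.After(resp.Period): // send the next heartbeat before the TTL expires
//		case <-ctx.Done():
//			return ctx.Err()
//		}
//	}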

func (d *Dispatcher) getManagers() []*api.WeightedPeer {
	d.mu.Lock()
	defer d.mu.Unlock()
	return d.lastSeenManagers
}

func (d *Dispatcher) getNetworkBootstrapKeys() []*api.EncryptionKey {
	d.mu.Lock()
	defer d.mu.Unlock()
	return d.networkBootstrapKeys
}

func (d *Dispatcher) getRootCACert() []byte {
	d.mu.Lock()
	defer d.mu.Unlock()
	return d.lastSeenRootCert
}

// Session is a stream which controls the agent connection. Each message
// contains a list of backup managers with weights. There is also a special
// boolean field Disconnect which, if true, indicates that the node should
// reconnect to another manager immediately.
func (d *Dispatcher) Session(r *api.SessionRequest, stream api.Dispatcher_SessionServer) error {
	// shutdownWait.Add() followed by isRunning() ensures that
	// if this rpc sees the dispatcher running,
	// it will already have called Add() on the shutdownWait WaitGroup,
	// which ensures that Stop() will wait for this rpc to complete.
	// Note that Stop() first does Dispatcher.ctx.cancel() followed by
	// shutdownWait.Wait() to make sure new rpc's don't start before waiting
	// for existing ones to finish.
	d.shutdownWait.Add(1)
	defer d.shutdownWait.Done()

	dctx, err := d.isRunningLocked()
	if err != nil {
		return err
	}

	ctx := stream.Context()
	nodeInfo, err := ca.RemoteNode(ctx)
	if err != nil {
		return err
	}
	nodeID := nodeInfo.NodeID

	var sessionID string
	if _, err := d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
		// register the node.
		sessionID, err = d.register(ctx, nodeID, r.Description)
		if err != nil {
			return err
		}
	} else {
		sessionID = r.SessionID
		// get the node IP addr
		addr, err := nodeIPFromContext(stream.Context())
		if err != nil {
			log.G(ctx).WithError(err).Debug("failed to get remote node IP")
		}
		// update the node description
		if err := d.markNodeReady(dctx, nodeID, r.Description, addr); err != nil {
			return err
		}
	}

	fields := logrus.Fields{
		"node.id":      nodeID,
		"node.session": sessionID,
		"method":       "(*Dispatcher).Session",
	}
	if nodeInfo.ForwardedBy != nil {
		fields["forwarder.id"] = nodeInfo.ForwardedBy.NodeID
	}
	log := log.G(ctx).WithFields(fields)

	var nodeObj *api.Node
	nodeUpdates, cancel, err := store.ViewAndWatch(d.store, func(readTx store.ReadTx) error {
		nodeObj = store.GetNode(readTx, nodeID)
		return nil
	}, api.EventUpdateNode{Node: &api.Node{ID: nodeID},
		Checks: []api.NodeCheckFunc{api.NodeCheckID}},
	)
	if cancel != nil {
		defer cancel()
	}

	if err != nil {
		log.WithError(err).Error("ViewAndWatch Node failed")
	}

	if _, err = d.nodes.GetWithSession(nodeID, sessionID); err != nil {
		return err
	}

	clusterUpdatesCh, clusterCancel := d.clusterUpdateQueue.Watch()
	defer clusterCancel()

	if err := stream.Send(&api.SessionMessage{
		SessionID:            sessionID,
		Node:                 nodeObj,
		Managers:             d.getManagers(),
		NetworkBootstrapKeys: d.getNetworkBootstrapKeys(),
		RootCA:               d.getRootCACert(),
	}); err != nil {
		return err
	}

	// disconnectNode is a helper that forcibly shuts down the connection.
	disconnectNode := func() error {
		// force disconnect by shutting down the stream.
		transportStream, ok := transport.StreamFromContext(stream.Context())
		if ok {
			// if we have the transport stream, we can signal a disconnect
			// in the client.
			if err := transportStream.ServerTransport().Close(); err != nil {
				log.WithError(err).Error("session end")
			}
		}

		log.Infof("dispatcher session dropped, marking node %s down", nodeID)
		if err := d.markNodeNotReady(nodeID, api.NodeStatus_DISCONNECTED, "node is currently trying to find new manager"); err != nil {
			log.WithError(err).Error("failed to remove node")
		}
		// still return an abort if the transport closure was ineffective.
		return status.Errorf(codes.Aborted, "node must disconnect")
	}

	for {
		// After each message send, we need to check that the node's sessionID
		// hasn't changed. If it has, we will shut down the stream and make the
		// node re-register.
		node, err := d.nodes.GetWithSession(nodeID, sessionID)
		if err != nil {
			return err
		}

		var (
			disconnect bool
			mgrs       []*api.WeightedPeer
			netKeys    []*api.EncryptionKey
			rootCert   []byte
		)

		select {
		case ev := <-clusterUpdatesCh:
			update := ev.(clusterUpdate)
			if update.managerUpdate != nil {
				mgrs = *update.managerUpdate
			}
			if update.bootstrapKeyUpdate != nil {
				netKeys = *update.bootstrapKeyUpdate
			}
			if update.rootCAUpdate != nil {
				rootCert = *update.rootCAUpdate
			}
		case ev := <-nodeUpdates:
			nodeObj = ev.(api.EventUpdateNode).Node
		case <-stream.Context().Done():
			return stream.Context().Err()
		case <-node.Disconnect:
			disconnect = true
		case <-dctx.Done():
			disconnect = true
		}

		if mgrs == nil {
			mgrs = d.getManagers()
		}
		if netKeys == nil {
			netKeys = d.getNetworkBootstrapKeys()
		}
		if rootCert == nil {
			rootCert = d.getRootCACert()
		}

		if err := stream.Send(&api.SessionMessage{
			SessionID:            sessionID,
			Node:                 nodeObj,
			Managers:             mgrs,
			NetworkBootstrapKeys: netKeys,
			RootCA:               rootCert,
		}); err != nil {
			return err
		}
		if disconnect {
			return disconnectNode()
		}
	}
}