diff --git a/vendor.conf b/vendor.conf index 1cafd8252e..10967369eb 100644 --- a/vendor.conf +++ b/vendor.conf @@ -52,7 +52,7 @@ github.com/docker/go-events e31b211e4f1cd09aa76fe4ac2445 github.com/armon/go-radix e39d623f12e8e41c7b5529e9a9dd67a1e2261f80 github.com/armon/go-metrics eb0af217e5e9747e41dd5303755356b62d28e3ec github.com/hashicorp/go-msgpack 71c2886f5a673a35f909803f38ece5810165097b -github.com/hashicorp/memberlist 3d8438da9589e7b608a83ffac1ef8211486bcb7c +github.com/hashicorp/memberlist e6ff9b2d87a3f0f3f04abb5672ada3ac2a640223 # v0.4.0 github.com/sean-/seed e2103e2c35297fb7e17febb81e49b312087a2372 github.com/hashicorp/errwrap 8a6fb523712970c966eefc6b39ed2c5e74880354 # v1.0.0 github.com/hashicorp/go-sockaddr c7188e74f6acae5a989bdc959aa779f8b9f42faf # v1.0.2 diff --git a/vendor/github.com/hashicorp/memberlist/README.md b/vendor/github.com/hashicorp/memberlist/README.md index 0adc075e81..6a2caa30e0 100644 --- a/vendor/github.com/hashicorp/memberlist/README.md +++ b/vendor/github.com/hashicorp/memberlist/README.md @@ -1,4 +1,4 @@ -# memberlist [![GoDoc](https://godoc.org/github.com/hashicorp/memberlist?status.png)](https://godoc.org/github.com/hashicorp/memberlist) +# memberlist [![GoDoc](https://godoc.org/github.com/hashicorp/memberlist?status.png)](https://godoc.org/github.com/hashicorp/memberlist) [![CircleCI](https://circleci.com/gh/hashicorp/memberlist.svg?style=svg)](https://circleci.com/gh/hashicorp/memberlist) memberlist is a [Go](http://www.golang.org) library that manages cluster membership and member failure detection using a gossip based protocol. @@ -23,8 +23,6 @@ Please check your installation with: go version ``` -Run `make deps` to fetch dependencies before building - ## Usage Memberlist is surprisingly simple to use. An example is shown below: @@ -65,7 +63,7 @@ For complete documentation, see the associated [Godoc](http://godoc.org/github.c ## Protocol -memberlist is based on ["SWIM: Scalable Weakly-consistent Infection-style Process Group Membership Protocol"](http://www.cs.cornell.edu/~asdas/research/dsn02-swim.pdf). However, we extend the protocol in a number of ways: +memberlist is based on ["SWIM: Scalable Weakly-consistent Infection-style Process Group Membership Protocol"](http://ieeexplore.ieee.org/document/1028914/). However, we extend the protocol in a number of ways: * Several extensions are made to increase propagation speed and convergence rate. diff --git a/vendor/github.com/hashicorp/memberlist/alive_delegate.go b/vendor/github.com/hashicorp/memberlist/alive_delegate.go index 51a0ba9054..615f4a90a5 100644 --- a/vendor/github.com/hashicorp/memberlist/alive_delegate.go +++ b/vendor/github.com/hashicorp/memberlist/alive_delegate.go @@ -7,8 +7,8 @@ package memberlist // a node out and prevent it from being considered a peer // using application specific logic. type AliveDelegate interface { - // NotifyMerge is invoked when a merge could take place. - // Provides a list of the nodes known by the peer. If - // the return value is non-nil, the merge is canceled. + // NotifyAlive is invoked when a message about a live + // node is received from the network. Returning a non-nil + // error prevents the node from being considered a peer. NotifyAlive(peer *Node) error } diff --git a/vendor/github.com/hashicorp/memberlist/awareness.go b/vendor/github.com/hashicorp/memberlist/awareness.go index ea95c75388..53b1bb3136 100644 --- a/vendor/github.com/hashicorp/memberlist/awareness.go +++ b/vendor/github.com/hashicorp/memberlist/awareness.go @@ -21,13 +21,17 @@ type awareness struct { // score is the current awareness score. Lower values are healthier and // zero is the minimum value. score int + + // metricLabels is the slice of labels to put on all emitted metrics + metricLabels []metrics.Label } // newAwareness returns a new awareness object. -func newAwareness(max int) *awareness { +func newAwareness(max int, metricLabels []metrics.Label) *awareness { return &awareness{ - max: max, - score: 0, + max: max, + score: 0, + metricLabels: metricLabels, } } @@ -47,7 +51,7 @@ func (a *awareness) ApplyDelta(delta int) { a.Unlock() if initial != final { - metrics.SetGauge([]string{"memberlist", "health", "score"}, float32(final)) + metrics.SetGaugeWithLabels([]string{"memberlist", "health", "score"}, float32(final), a.metricLabels) } } diff --git a/vendor/github.com/hashicorp/memberlist/broadcast.go b/vendor/github.com/hashicorp/memberlist/broadcast.go index f7e85a119c..d07d41bb69 100644 --- a/vendor/github.com/hashicorp/memberlist/broadcast.go +++ b/vendor/github.com/hashicorp/memberlist/broadcast.go @@ -29,6 +29,11 @@ func (b *memberlistBroadcast) Invalidates(other Broadcast) bool { return b.node == mb.node } +// memberlist.NamedBroadcast optional interface +func (b *memberlistBroadcast) Name() string { + return b.node +} + func (b *memberlistBroadcast) Message() []byte { return b.msg } diff --git a/vendor/github.com/hashicorp/memberlist/config.go b/vendor/github.com/hashicorp/memberlist/config.go index c85b1657a2..d83a4f3fc1 100644 --- a/vendor/github.com/hashicorp/memberlist/config.go +++ b/vendor/github.com/hashicorp/memberlist/config.go @@ -1,10 +1,16 @@ package memberlist import ( + "fmt" "io" "log" + "net" "os" + "strings" "time" + + "github.com/armon/go-metrics" + multierror "github.com/hashicorp/go-multierror" ) type Config struct { @@ -16,6 +22,17 @@ type Config struct { // make a NetTransport using BindAddr and BindPort from this structure. Transport Transport + // Label is an optional set of bytes to include on the outside of each + // packet and stream. + // + // If gossip encryption is enabled and this is set it is treated as GCM + // authenticated data. + Label string + + // SkipInboundLabelCheck skips the check that inbound packets and gossip + // streams need to be label prefixed. + SkipInboundLabelCheck bool + // Configuration related to what address to bind to and ports to // listen on. The port is used for both UDP and TCP gossip. It is // assumed other nodes are running on this port, but they do not need @@ -116,6 +133,10 @@ type Config struct { // indirect UDP pings. DisableTcpPings bool + // DisableTcpPingsForNode is like DisableTcpPings, but lets you control + // whether to perform TCP pings on a node-by-node basis. + DisableTcpPingsForNode func(nodeName string) bool + // AwarenessMaxMultiplier will increase the probe interval if the node // becomes aware that it might be degraded and not meeting the soft real // time requirements to reliably probe other nodes. @@ -215,6 +236,48 @@ type Config struct { // This is a legacy name for backward compatibility but should really be // called PacketBufferSize now that we have generalized the transport. UDPBufferSize int + + // DeadNodeReclaimTime controls the time before a dead node's name can be + // reclaimed by one with a different address or port. By default, this is 0, + // meaning nodes cannot be reclaimed this way. + DeadNodeReclaimTime time.Duration + + // RequireNodeNames controls if the name of a node is required when sending + // a message to that node. + RequireNodeNames bool + + // CIDRsAllowed If nil, allow any connection (default), otherwise specify all networks + // allowed to connect (you must specify IPv6/IPv4 separately) + // Using [] will block all connections. + CIDRsAllowed []net.IPNet + + // MetricLabels is a map of optional labels to apply to all metrics emitted. + MetricLabels []metrics.Label +} + +// ParseCIDRs return a possible empty list of all Network that have been parsed +// In case of error, it returns succesfully parsed CIDRs and the last error found +func ParseCIDRs(v []string) ([]net.IPNet, error) { + nets := make([]net.IPNet, 0) + if v == nil { + return nets, nil + } + var errs error + hasErrors := false + for _, p := range v { + _, net, err := net.ParseCIDR(strings.TrimSpace(p)) + if err != nil { + err = fmt.Errorf("invalid cidr: %s", p) + errs = multierror.Append(errs, err) + hasErrors = true + } else { + nets = append(nets, *net) + } + } + if !hasErrors { + errs = nil + } + return nets, errs } // DefaultLANConfig returns a sane set of configurations for Memberlist. @@ -258,6 +321,7 @@ func DefaultLANConfig() *Config { HandoffQueueDepth: 1024, UDPBufferSize: 1400, + CIDRsAllowed: nil, // same as allow all } } @@ -277,6 +341,24 @@ func DefaultWANConfig() *Config { return conf } +// IPMustBeChecked return true if IPAllowed must be called +func (c *Config) IPMustBeChecked() bool { + return len(c.CIDRsAllowed) > 0 +} + +// IPAllowed return an error if access to memberlist is denied +func (c *Config) IPAllowed(ip net.IP) error { + if !c.IPMustBeChecked() { + return nil + } + for _, n := range c.CIDRsAllowed { + if n.Contains(ip) { + return nil + } + } + return fmt.Errorf("%s is not allowed", ip) +} + // DefaultLocalConfig works like DefaultConfig, however it returns a configuration // that is optimized for a local loopback environments. The default configuration is // still very conservative and errs on the side of caution. diff --git a/vendor/github.com/hashicorp/memberlist/event_delegate.go b/vendor/github.com/hashicorp/memberlist/event_delegate.go index 35e2a56fdd..352f98b43e 100644 --- a/vendor/github.com/hashicorp/memberlist/event_delegate.go +++ b/vendor/github.com/hashicorp/memberlist/event_delegate.go @@ -49,13 +49,16 @@ type NodeEvent struct { } func (c *ChannelEventDelegate) NotifyJoin(n *Node) { - c.Ch <- NodeEvent{NodeJoin, n} + node := *n + c.Ch <- NodeEvent{NodeJoin, &node} } func (c *ChannelEventDelegate) NotifyLeave(n *Node) { - c.Ch <- NodeEvent{NodeLeave, n} + node := *n + c.Ch <- NodeEvent{NodeLeave, &node} } func (c *ChannelEventDelegate) NotifyUpdate(n *Node) { - c.Ch <- NodeEvent{NodeUpdate, n} + node := *n + c.Ch <- NodeEvent{NodeUpdate, &node} } diff --git a/vendor/github.com/hashicorp/memberlist/go.mod b/vendor/github.com/hashicorp/memberlist/go.mod new file mode 100644 index 0000000000..454def3001 --- /dev/null +++ b/vendor/github.com/hashicorp/memberlist/go.mod @@ -0,0 +1,19 @@ +module github.com/hashicorp/memberlist + +go 1.12 + +require ( + github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c + github.com/hashicorp/go-immutable-radix v1.0.0 // indirect + github.com/hashicorp/go-msgpack v0.5.3 + github.com/hashicorp/go-multierror v1.0.0 + github.com/hashicorp/go-sockaddr v1.0.0 + github.com/miekg/dns v1.1.26 + github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529 + github.com/stretchr/testify v1.2.2 + golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 // indirect +) diff --git a/vendor/github.com/hashicorp/memberlist/label.go b/vendor/github.com/hashicorp/memberlist/label.go new file mode 100644 index 0000000000..bbe0163ab6 --- /dev/null +++ b/vendor/github.com/hashicorp/memberlist/label.go @@ -0,0 +1,178 @@ +package memberlist + +import ( + "bufio" + "fmt" + "io" + "net" +) + +// General approach is to prefix all packets and streams with the same structure: +// +// magic type byte (244): uint8 +// length of label name: uint8 (because labels can't be longer than 255 bytes) +// label name: []uint8 + +// LabelMaxSize is the maximum length of a packet or stream label. +const LabelMaxSize = 255 + +// AddLabelHeaderToPacket prefixes outgoing packets with the correct header if +// the label is not empty. +func AddLabelHeaderToPacket(buf []byte, label string) ([]byte, error) { + if label == "" { + return buf, nil + } + if len(label) > LabelMaxSize { + return nil, fmt.Errorf("label %q is too long", label) + } + + return makeLabelHeader(label, buf), nil +} + +// RemoveLabelHeaderFromPacket removes any label header from the provided +// packet and returns it along with the remaining packet contents. +func RemoveLabelHeaderFromPacket(buf []byte) (newBuf []byte, label string, err error) { + if len(buf) == 0 { + return buf, "", nil // can't possibly be labeled + } + + // [type:byte] [size:byte] [size bytes] + + msgType := messageType(buf[0]) + if msgType != hasLabelMsg { + return buf, "", nil + } + + if len(buf) < 2 { + return nil, "", fmt.Errorf("cannot decode label; packet has been truncated") + } + + size := int(buf[1]) + if size < 1 { + return nil, "", fmt.Errorf("label header cannot be empty when present") + } + + if len(buf) < 2+size { + return nil, "", fmt.Errorf("cannot decode label; packet has been truncated") + } + + label = string(buf[2 : 2+size]) + newBuf = buf[2+size:] + + return newBuf, label, nil +} + +// AddLabelHeaderToStream prefixes outgoing streams with the correct header if +// the label is not empty. +func AddLabelHeaderToStream(conn net.Conn, label string) error { + if label == "" { + return nil + } + if len(label) > LabelMaxSize { + return fmt.Errorf("label %q is too long", label) + } + + header := makeLabelHeader(label, nil) + + _, err := conn.Write(header) + return err +} + +// RemoveLabelHeaderFromStream removes any label header from the beginning of +// the stream if present and returns it along with an updated conn with that +// header removed. +// +// Note that on error it is the caller's responsibility to close the +// connection. +func RemoveLabelHeaderFromStream(conn net.Conn) (net.Conn, string, error) { + br := bufio.NewReader(conn) + + // First check for the type byte. + peeked, err := br.Peek(1) + if err != nil { + if err == io.EOF { + // It is safe to return the original net.Conn at this point because + // it never contained any data in the first place so we don't have + // to splice the buffer into the conn because both are empty. + return conn, "", nil + } + return nil, "", err + } + + msgType := messageType(peeked[0]) + if msgType != hasLabelMsg { + conn, err = newPeekedConnFromBufferedReader(conn, br, 0) + return conn, "", err + } + + // We are guaranteed to get a size byte as well. + peeked, err = br.Peek(2) + if err != nil { + if err == io.EOF { + return nil, "", fmt.Errorf("cannot decode label; stream has been truncated") + } + return nil, "", err + } + + size := int(peeked[1]) + if size < 1 { + return nil, "", fmt.Errorf("label header cannot be empty when present") + } + // NOTE: we don't have to check this against LabelMaxSize because a byte + // already has a max value of 255. + + // Once we know the size we can peek the label as well. Note that since we + // are using the default bufio.Reader size of 4096, the entire label header + // fits in the initial buffer fill so this should be free. + peeked, err = br.Peek(2 + size) + if err != nil { + if err == io.EOF { + return nil, "", fmt.Errorf("cannot decode label; stream has been truncated") + } + return nil, "", err + } + + label := string(peeked[2 : 2+size]) + + conn, err = newPeekedConnFromBufferedReader(conn, br, 2+size) + if err != nil { + return nil, "", err + } + + return conn, label, nil +} + +// newPeekedConnFromBufferedReader will splice the buffer contents after the +// offset into the provided net.Conn and return the result so that the rest of +// the buffer contents are returned first when reading from the returned +// peekedConn before moving on to the unbuffered conn contents. +func newPeekedConnFromBufferedReader(conn net.Conn, br *bufio.Reader, offset int) (*peekedConn, error) { + // Extract any of the readahead buffer. + peeked, err := br.Peek(br.Buffered()) + if err != nil { + return nil, err + } + + return &peekedConn{ + Peeked: peeked[offset:], + Conn: conn, + }, nil +} + +func makeLabelHeader(label string, rest []byte) []byte { + newBuf := make([]byte, 2, 2+len(label)+len(rest)) + newBuf[0] = byte(hasLabelMsg) + newBuf[1] = byte(len(label)) + newBuf = append(newBuf, []byte(label)...) + if len(rest) > 0 { + newBuf = append(newBuf, []byte(rest)...) + } + return newBuf +} + +func labelOverhead(label string) int { + if label == "" { + return 0 + } + return 2 + len(label) +} diff --git a/vendor/github.com/hashicorp/memberlist/logging.go b/vendor/github.com/hashicorp/memberlist/logging.go index f31acfb2fa..2ca2bab4e3 100644 --- a/vendor/github.com/hashicorp/memberlist/logging.go +++ b/vendor/github.com/hashicorp/memberlist/logging.go @@ -13,6 +13,14 @@ func LogAddress(addr net.Addr) string { return fmt.Sprintf("from=%s", addr.String()) } +func LogStringAddress(addr string) string { + if addr == "" { + return "from=" + } + + return fmt.Sprintf("from=%s", addr) +} + func LogConn(conn net.Conn) string { if conn == nil { return LogAddress(nil) diff --git a/vendor/github.com/hashicorp/memberlist/memberlist.go b/vendor/github.com/hashicorp/memberlist/memberlist.go index e9084f9fd4..512701dee1 100644 --- a/vendor/github.com/hashicorp/memberlist/memberlist.go +++ b/vendor/github.com/hashicorp/memberlist/memberlist.go @@ -15,6 +15,8 @@ multiple routes. package memberlist import ( + "container/list" + "errors" "fmt" "log" "net" @@ -25,15 +27,23 @@ import ( "sync/atomic" "time" + "github.com/armon/go-metrics" multierror "github.com/hashicorp/go-multierror" sockaddr "github.com/hashicorp/go-sockaddr" "github.com/miekg/dns" ) +var errNodeNamesAreRequired = errors.New("memberlist: node names are required by configuration but one was not provided") + type Memberlist struct { sequenceNum uint32 // Local sequence number incarnation uint32 // Local incarnation number numNodes uint32 // Number of known nodes (estimate) + pushPullReq uint32 // Number of push/pull requests + + advertiseLock sync.RWMutex + advertiseAddr net.IP + advertisePort uint16 config *Config shutdown int32 // Used as an atomic boolean value @@ -44,13 +54,17 @@ type Memberlist struct { shutdownLock sync.Mutex // Serializes calls to Shutdown leaveLock sync.Mutex // Serializes calls to Leave - transport Transport - handoff chan msgHandoff + transport NodeAwareTransport + + handoffCh chan struct{} + highPriorityMsgQueue *list.List + lowPriorityMsgQueue *list.List + msgQueueLock sync.Mutex nodeLock sync.RWMutex nodes []*nodeState // Known nodes - nodeMap map[string]*nodeState // Maps Addr.String() -> NodeState - nodeTimers map[string]*suspicion // Maps Addr.String() -> suspicion timer + nodeMap map[string]*nodeState // Maps Node.Name -> NodeState + nodeTimers map[string]*suspicion // Maps Node.Name -> suspicion timer awareness *awareness tickerLock sync.Mutex @@ -64,6 +78,18 @@ type Memberlist struct { broadcasts *TransmitLimitedQueue logger *log.Logger + + // metricLabels is the slice of labels to put on all emitted metrics + metricLabels []metrics.Label +} + +// BuildVsnArray creates the array of Vsn +func (conf *Config) BuildVsnArray() []uint8 { + return []uint8{ + ProtocolVersionMin, ProtocolVersionMax, conf.ProtocolVersion, + conf.DelegateProtocolMin, conf.DelegateProtocolMax, + conf.DelegateProtocolVersion, + } } // newMemberlist creates the network listeners. @@ -113,9 +139,10 @@ func newMemberlist(conf *Config) (*Memberlist, error) { transport := conf.Transport if transport == nil { nc := &NetTransportConfig{ - BindAddrs: []string{conf.BindAddr}, - BindPort: conf.BindPort, - Logger: logger, + BindAddrs: []string{conf.BindAddr}, + BindPort: conf.BindPort, + Logger: logger, + MetricLabels: conf.MetricLabels, } // See comment below for details about the retry in here. @@ -159,22 +186,50 @@ func newMemberlist(conf *Config) (*Memberlist, error) { transport = nt } + nodeAwareTransport, ok := transport.(NodeAwareTransport) + if !ok { + logger.Printf("[DEBUG] memberlist: configured Transport is not a NodeAwareTransport and some features may not work as desired") + nodeAwareTransport = &shimNodeAwareTransport{transport} + } + + if len(conf.Label) > LabelMaxSize { + return nil, fmt.Errorf("could not use %q as a label: too long", conf.Label) + } + + if conf.Label != "" { + nodeAwareTransport = &labelWrappedTransport{ + label: conf.Label, + NodeAwareTransport: nodeAwareTransport, + } + } + m := &Memberlist{ - config: conf, - shutdownCh: make(chan struct{}), - leaveBroadcast: make(chan struct{}, 1), - transport: transport, - handoff: make(chan msgHandoff, conf.HandoffQueueDepth), - nodeMap: make(map[string]*nodeState), - nodeTimers: make(map[string]*suspicion), - awareness: newAwareness(conf.AwarenessMaxMultiplier), - ackHandlers: make(map[uint32]*ackHandler), - broadcasts: &TransmitLimitedQueue{RetransmitMult: conf.RetransmitMult}, - logger: logger, + config: conf, + shutdownCh: make(chan struct{}), + leaveBroadcast: make(chan struct{}, 1), + transport: nodeAwareTransport, + handoffCh: make(chan struct{}, 1), + highPriorityMsgQueue: list.New(), + lowPriorityMsgQueue: list.New(), + nodeMap: make(map[string]*nodeState), + nodeTimers: make(map[string]*suspicion), + awareness: newAwareness(conf.AwarenessMaxMultiplier, conf.MetricLabels), + ackHandlers: make(map[uint32]*ackHandler), + broadcasts: &TransmitLimitedQueue{RetransmitMult: conf.RetransmitMult}, + logger: logger, + metricLabels: conf.MetricLabels, } m.broadcasts.NumNodes = func() int { return m.estNumNodes() } + + // Get the final advertise address from the transport, which may need + // to see which address we bound to. We'll refresh this each time we + // send out an alive message. + if _, _, err := m.refreshAdvertise(); err != nil { + return nil, err + } + go m.streamListen() go m.packetListen() go m.packetHandler() @@ -222,8 +277,9 @@ func (m *Memberlist) Join(existing []string) (int, error) { for _, addr := range addrs { hp := joinHostPort(addr.ip.String(), addr.port) - if err := m.pushPullNode(hp, true); err != nil { - err = fmt.Errorf("Failed to join %s: %v", addr.ip, err) + a := Address{Addr: hp, Name: addr.nodeName} + if err := m.pushPullNode(a, true); err != nil { + err = fmt.Errorf("Failed to join %s: %v", a.Addr, err) errs = multierror.Append(errs, err) m.logger.Printf("[DEBUG] memberlist: %v", err) continue @@ -240,8 +296,9 @@ func (m *Memberlist) Join(existing []string) (int, error) { // ipPort holds information about a node we want to try to join. type ipPort struct { - ip net.IP - port uint16 + ip net.IP + port uint16 + nodeName string // optional } // tcpLookupIP is a helper to initiate a TCP-based DNS lookup for the given host. @@ -250,7 +307,7 @@ type ipPort struct { // Consul's. By doing the TCP lookup directly, we get the best chance for the // largest list of hosts to join. Since joins are relatively rare events, it's ok // to do this rather expensive operation. -func (m *Memberlist) tcpLookupIP(host string, defaultPort uint16) ([]ipPort, error) { +func (m *Memberlist) tcpLookupIP(host string, defaultPort uint16, nodeName string) ([]ipPort, error) { // Don't attempt any TCP lookups against non-fully qualified domain // names, since those will likely come from the resolv.conf file. if !strings.Contains(host, ".") { @@ -292,9 +349,9 @@ func (m *Memberlist) tcpLookupIP(host string, defaultPort uint16) ([]ipPort, err for _, r := range in.Answer { switch rr := r.(type) { case (*dns.A): - ips = append(ips, ipPort{rr.A, defaultPort}) + ips = append(ips, ipPort{ip: rr.A, port: defaultPort, nodeName: nodeName}) case (*dns.AAAA): - ips = append(ips, ipPort{rr.AAAA, defaultPort}) + ips = append(ips, ipPort{ip: rr.AAAA, port: defaultPort, nodeName: nodeName}) case (*dns.CNAME): m.logger.Printf("[DEBUG] memberlist: Ignoring CNAME RR in TCP-first answer for '%s'", host) } @@ -308,6 +365,16 @@ func (m *Memberlist) tcpLookupIP(host string, defaultPort uint16) ([]ipPort, err // resolveAddr is used to resolve the address into an address, // port, and error. If no port is given, use the default func (m *Memberlist) resolveAddr(hostStr string) ([]ipPort, error) { + // First peel off any leading node name. This is optional. + nodeName := "" + if slashIdx := strings.Index(hostStr, "/"); slashIdx >= 0 { + if slashIdx == 0 { + return nil, fmt.Errorf("empty node name provided") + } + nodeName = hostStr[0:slashIdx] + hostStr = hostStr[slashIdx+1:] + } + // This captures the supplied port, or the default one. hostStr = ensurePort(hostStr, m.config.BindPort) host, sport, err := net.SplitHostPort(hostStr) @@ -324,13 +391,15 @@ func (m *Memberlist) resolveAddr(hostStr string) ([]ipPort, error) { // will make sure the host part is in good shape for parsing, even for // IPv6 addresses. if ip := net.ParseIP(host); ip != nil { - return []ipPort{ipPort{ip, port}}, nil + return []ipPort{ + ipPort{ip: ip, port: port, nodeName: nodeName}, + }, nil } // First try TCP so we have the best chance for the largest list of // hosts to join. If this fails it's not fatal since this isn't a standard // way to query DNS, and we have a fallback below. - ips, err := m.tcpLookupIP(host, port) + ips, err := m.tcpLookupIP(host, port, nodeName) if err != nil { m.logger.Printf("[DEBUG] memberlist: TCP-first lookup failed for '%s', falling back to UDP: %s", hostStr, err) } @@ -347,7 +416,7 @@ func (m *Memberlist) resolveAddr(hostStr string) ([]ipPort, error) { } ips = make([]ipPort, 0, len(ans)) for _, ip := range ans { - ips = append(ips, ipPort{ip, port}) + ips = append(ips, ipPort{ip: ip, port: port, nodeName: nodeName}) } return ips, nil } @@ -358,10 +427,9 @@ func (m *Memberlist) resolveAddr(hostStr string) ([]ipPort, error) { func (m *Memberlist) setAlive() error { // Get the final advertise address from the transport, which may need // to see which address we bound to. - addr, port, err := m.transport.FinalAdvertiseAddr( - m.config.AdvertiseAddr, m.config.AdvertisePort) + addr, port, err := m.refreshAdvertise() if err != nil { - return fmt.Errorf("Failed to get final advertise address: %v", err) + return err } // Check if this is a public address without encryption @@ -394,16 +462,36 @@ func (m *Memberlist) setAlive() error { Addr: addr, Port: uint16(port), Meta: meta, - Vsn: []uint8{ - ProtocolVersionMin, ProtocolVersionMax, m.config.ProtocolVersion, - m.config.DelegateProtocolMin, m.config.DelegateProtocolMax, - m.config.DelegateProtocolVersion, - }, + Vsn: m.config.BuildVsnArray(), } m.aliveNode(&a, nil, true) + return nil } +func (m *Memberlist) getAdvertise() (net.IP, uint16) { + m.advertiseLock.RLock() + defer m.advertiseLock.RUnlock() + return m.advertiseAddr, m.advertisePort +} + +func (m *Memberlist) setAdvertise(addr net.IP, port int) { + m.advertiseLock.Lock() + defer m.advertiseLock.Unlock() + m.advertiseAddr = addr + m.advertisePort = uint16(port) +} + +func (m *Memberlist) refreshAdvertise() (net.IP, int, error) { + addr, port, err := m.transport.FinalAdvertiseAddr( + m.config.AdvertiseAddr, m.config.AdvertisePort) + if err != nil { + return nil, 0, fmt.Errorf("Failed to get final advertise address: %v", err) + } + m.setAdvertise(addr, port) + return addr, port, nil +} + // LocalNode is used to return the local Node func (m *Memberlist) LocalNode() *Node { m.nodeLock.RLock() @@ -439,11 +527,7 @@ func (m *Memberlist) UpdateNode(timeout time.Duration) error { Addr: state.Addr, Port: state.Port, Meta: meta, - Vsn: []uint8{ - ProtocolVersionMin, ProtocolVersionMax, m.config.ProtocolVersion, - m.config.DelegateProtocolMin, m.config.DelegateProtocolMax, - m.config.DelegateProtocolVersion, - }, + Vsn: m.config.BuildVsnArray(), } notifyCh := make(chan struct{}) m.aliveNode(&a, notifyCh, true) @@ -463,24 +547,29 @@ func (m *Memberlist) UpdateNode(timeout time.Duration) error { return nil } -// SendTo is deprecated in favor of SendBestEffort, which requires a node to -// target. +// Deprecated: SendTo is deprecated in favor of SendBestEffort, which requires a node to +// target. If you don't have a node then use SendToAddress. func (m *Memberlist) SendTo(to net.Addr, msg []byte) error { + a := Address{Addr: to.String(), Name: ""} + return m.SendToAddress(a, msg) +} + +func (m *Memberlist) SendToAddress(a Address, msg []byte) error { // Encode as a user message buf := make([]byte, 1, len(msg)+1) buf[0] = byte(userMsg) buf = append(buf, msg...) // Send the message - return m.rawSendMsgPacket(to.String(), nil, buf) + return m.rawSendMsgPacket(a, nil, buf) } -// SendToUDP is deprecated in favor of SendBestEffort. +// Deprecated: SendToUDP is deprecated in favor of SendBestEffort. func (m *Memberlist) SendToUDP(to *Node, msg []byte) error { return m.SendBestEffort(to, msg) } -// SendToTCP is deprecated in favor of SendReliable. +// Deprecated: SendToTCP is deprecated in favor of SendReliable. func (m *Memberlist) SendToTCP(to *Node, msg []byte) error { return m.SendReliable(to, msg) } @@ -496,7 +585,8 @@ func (m *Memberlist) SendBestEffort(to *Node, msg []byte) error { buf = append(buf, msg...) // Send the message - return m.rawSendMsgPacket(to.Address(), to, buf) + a := Address{Addr: to.Address(), Name: to.Name} + return m.rawSendMsgPacket(a, to, buf) } // SendReliable uses the reliable stream-oriented interface of the transport to @@ -504,7 +594,7 @@ func (m *Memberlist) SendBestEffort(to *Node, msg []byte) error { // mechanism). Delivery is guaranteed if no error is returned, and there is no // limit on the size of the message. func (m *Memberlist) SendReliable(to *Node, msg []byte) error { - return m.sendUserMsg(to.Address(), msg) + return m.sendUserMsg(to.FullAddress(), msg) } // Members returns a list of all known live nodes. The node structures @@ -516,7 +606,7 @@ func (m *Memberlist) Members() []*Node { nodes := make([]*Node, 0, len(m.nodes)) for _, n := range m.nodes { - if n.State != stateDead { + if !n.DeadOrLeft() { nodes = append(nodes, &n.Node) } } @@ -533,7 +623,7 @@ func (m *Memberlist) NumMembers() (alive int) { defer m.nodeLock.RUnlock() for _, n := range m.nodes { - if n.State != stateDead { + if !n.DeadOrLeft() { alive++ } } @@ -570,9 +660,14 @@ func (m *Memberlist) Leave(timeout time.Duration) error { return nil } + // This dead message is special, because Node and From are the + // same. This helps other nodes figure out that a node left + // intentionally. When Node equals From, other nodes know for + // sure this node is gone. d := dead{ Incarnation: state.Incarnation, Node: state.Name, + From: state.Name, } m.deadNode(&d) @@ -598,7 +693,7 @@ func (m *Memberlist) anyAlive() bool { m.nodeLock.RLock() defer m.nodeLock.RUnlock() for _, n := range m.nodes { - if n.State != stateDead && n.Name != m.config.Name { + if !n.DeadOrLeft() && n.Name != m.config.Name { return true } } @@ -621,7 +716,7 @@ func (m *Memberlist) ProtocolVersion() uint8 { return m.config.ProtocolVersion } -// Shutdown will stop any background maintanence of network activity +// Shutdown will stop any background maintenance of network activity // for this memberlist, causing it to appear "dead". A leave message // will not be broadcasted prior, so the cluster being left will have // to detect this node's shutdown using probing. If you wish to more @@ -657,3 +752,27 @@ func (m *Memberlist) hasShutdown() bool { func (m *Memberlist) hasLeft() bool { return atomic.LoadInt32(&m.leave) == 1 } + +func (m *Memberlist) getNodeState(addr string) NodeStateType { + m.nodeLock.RLock() + defer m.nodeLock.RUnlock() + + n := m.nodeMap[addr] + return n.State +} + +func (m *Memberlist) getNodeStateChange(addr string) time.Time { + m.nodeLock.RLock() + defer m.nodeLock.RUnlock() + + n := m.nodeMap[addr] + return n.StateChange +} + +func (m *Memberlist) changeNode(addr string, f func(*nodeState)) { + m.nodeLock.Lock() + defer m.nodeLock.Unlock() + + n := m.nodeMap[addr] + f(n) +} diff --git a/vendor/github.com/hashicorp/memberlist/mock_transport.go b/vendor/github.com/hashicorp/memberlist/mock_transport.go index b8bafa8026..0a7d30a277 100644 --- a/vendor/github.com/hashicorp/memberlist/mock_transport.go +++ b/vendor/github.com/hashicorp/memberlist/mock_transport.go @@ -1,7 +1,9 @@ package memberlist import ( + "bytes" "fmt" + "io" "net" "strconv" "time" @@ -10,26 +12,33 @@ import ( // MockNetwork is used as a factory that produces MockTransport instances which // are uniquely addressed and wired up to talk to each other. type MockNetwork struct { - transports map[string]*MockTransport - port int + transportsByAddr map[string]*MockTransport + transportsByName map[string]*MockTransport + port int } // NewTransport returns a new MockTransport with a unique address, wired up to // talk to the other transports in the MockNetwork. -func (n *MockNetwork) NewTransport() *MockTransport { +func (n *MockNetwork) NewTransport(name string) *MockTransport { n.port += 1 addr := fmt.Sprintf("127.0.0.1:%d", n.port) transport := &MockTransport{ net: n, - addr: &MockAddress{addr}, + addr: &MockAddress{addr, name}, packetCh: make(chan *Packet), streamCh: make(chan net.Conn), } - if n.transports == nil { - n.transports = make(map[string]*MockTransport) + if n.transportsByAddr == nil { + n.transportsByAddr = make(map[string]*MockTransport) } - n.transports[addr] = transport + n.transportsByAddr[addr] = transport + + if n.transportsByName == nil { + n.transportsByName = make(map[string]*MockTransport) + } + n.transportsByName[name] = transport + return transport } @@ -37,6 +46,7 @@ func (n *MockNetwork) NewTransport() *MockTransport { // address scheme. type MockAddress struct { addr string + name string } // See net.Addr. @@ -57,6 +67,8 @@ type MockTransport struct { streamCh chan net.Conn } +var _ NodeAwareTransport = (*MockTransport)(nil) + // See Transport. func (t *MockTransport) FinalAdvertiseAddr(string, int) (net.IP, int, error) { host, portStr, err := net.SplitHostPort(t.addr.String()) @@ -79,9 +91,15 @@ func (t *MockTransport) FinalAdvertiseAddr(string, int) (net.IP, int, error) { // See Transport. func (t *MockTransport) WriteTo(b []byte, addr string) (time.Time, error) { - dest, ok := t.net.transports[addr] - if !ok { - return time.Time{}, fmt.Errorf("No route to %q", addr) + a := Address{Addr: addr, Name: ""} + return t.WriteToAddress(b, a) +} + +// See NodeAwareTransport. +func (t *MockTransport) WriteToAddress(b []byte, a Address) (time.Time, error) { + dest, err := t.getPeer(a) + if err != nil { + return time.Time{}, err } now := time.Now() @@ -98,11 +116,45 @@ func (t *MockTransport) PacketCh() <-chan *Packet { return t.packetCh } +// See NodeAwareTransport. +func (t *MockTransport) IngestPacket(conn net.Conn, addr net.Addr, now time.Time, shouldClose bool) error { + if shouldClose { + defer conn.Close() + } + + // Copy everything from the stream into packet buffer. + var buf bytes.Buffer + if _, err := io.Copy(&buf, conn); err != nil { + return fmt.Errorf("failed to read packet: %v", err) + } + + // Check the length - it needs to have at least one byte to be a proper + // message. This is checked elsewhere for writes coming in directly from + // the UDP socket. + if n := buf.Len(); n < 1 { + return fmt.Errorf("packet too short (%d bytes) %s", n, LogAddress(addr)) + } + + // Inject the packet. + t.packetCh <- &Packet{ + Buf: buf.Bytes(), + From: addr, + Timestamp: now, + } + return nil +} + // See Transport. func (t *MockTransport) DialTimeout(addr string, timeout time.Duration) (net.Conn, error) { - dest, ok := t.net.transports[addr] - if !ok { - return nil, fmt.Errorf("No route to %q", addr) + a := Address{Addr: addr, Name: ""} + return t.DialAddressTimeout(a, timeout) +} + +// See NodeAwareTransport. +func (t *MockTransport) DialAddressTimeout(a Address, timeout time.Duration) (net.Conn, error) { + dest, err := t.getPeer(a) + if err != nil { + return nil, err } p1, p2 := net.Pipe() @@ -115,7 +167,29 @@ func (t *MockTransport) StreamCh() <-chan net.Conn { return t.streamCh } +// See NodeAwareTransport. +func (t *MockTransport) IngestStream(conn net.Conn) error { + t.streamCh <- conn + return nil +} + // See Transport. func (t *MockTransport) Shutdown() error { return nil } + +func (t *MockTransport) getPeer(a Address) (*MockTransport, error) { + var ( + dest *MockTransport + ok bool + ) + if a.Name != "" { + dest, ok = t.net.transportsByName[a.Name] + } else { + dest, ok = t.net.transportsByAddr[a.Addr] + } + if !ok { + return nil, fmt.Errorf("No route to %s", a) + } + return dest, nil +} diff --git a/vendor/github.com/hashicorp/memberlist/net.go b/vendor/github.com/hashicorp/memberlist/net.go index a4330c4d20..a8291c4f38 100644 --- a/vendor/github.com/hashicorp/memberlist/net.go +++ b/vendor/github.com/hashicorp/memberlist/net.go @@ -7,10 +7,12 @@ import ( "fmt" "hash/crc32" "io" + "math" "net" + "sync/atomic" "time" - "github.com/armon/go-metrics" + metrics "github.com/armon/go-metrics" "github.com/hashicorp/go-msgpack/codec" ) @@ -41,6 +43,9 @@ const ( type messageType uint8 // The list of available message types. +// +// WARNING: ONLY APPEND TO THIS LIST! The numeric values are part of the +// protocol itself. const ( pingMsg messageType = iota indirectPingMsg @@ -58,6 +63,13 @@ const ( errMsg ) +const ( + // hasLabelMsg has a deliberately high value so that you can disambiguate + // it from the encryptionVersion header which is either 0/1 right now and + // also any of the existing messageTypes + hasLabelMsg messageType = 244 +) + // compressionType is used to specify the compression algorithm type compressionType uint8 @@ -71,7 +83,8 @@ const ( compoundOverhead = 2 // Assumed overhead per entry in compoundHeader userMsgOverhead = 1 blockingWarning = 10 * time.Millisecond // Warn if a UDP packet takes this long to process - maxPushStateBytes = 10 * 1024 * 1024 + maxPushStateBytes = 20 * 1024 * 1024 + maxPushPullRequests = 128 // Maximum number of concurrent push/pull requests ) // ping request sent directly to node @@ -82,15 +95,28 @@ type ping struct { // the intended recipient. This is to protect again an agent // restart with a new name. Node string + + SourceAddr []byte `codec:",omitempty"` // Source address, used for a direct reply + SourcePort uint16 `codec:",omitempty"` // Source port, used for a direct reply + SourceNode string `codec:",omitempty"` // Source name, used for a direct reply } -// indirect ping sent to an indirect ndoe +// indirect ping sent to an indirect node type indirectPingReq struct { SeqNo uint32 Target []byte Port uint16 - Node string - Nack bool // true if we'd like a nack back + + // Node is sent so the target can verify they are + // the intended recipient. This is to protect against an agent + // restart with a new name. + Node string + + Nack bool // true if we'd like a nack back + + SourceAddr []byte `codec:",omitempty"` // Source address, used for a direct reply + SourcePort uint16 `codec:",omitempty"` // Source port, used for a direct reply + SourceNode string `codec:",omitempty"` // Source name, used for a direct reply } // ack response is sent for a ping @@ -161,7 +187,7 @@ type pushNodeState struct { Port uint16 Meta []byte Incarnation uint32 - State nodeStateType + State NodeStateType Vsn []uint8 // Protocol versions } @@ -205,13 +231,38 @@ func (m *Memberlist) streamListen() { // handleConn handles a single incoming stream connection from the transport. func (m *Memberlist) handleConn(conn net.Conn) { + defer conn.Close() m.logger.Printf("[DEBUG] memberlist: Stream connection %s", LogConn(conn)) - defer conn.Close() - metrics.IncrCounter([]string{"memberlist", "tcp", "accept"}, 1) + metrics.IncrCounterWithLabels([]string{"memberlist", "tcp", "accept"}, 1, m.metricLabels) conn.SetDeadline(time.Now().Add(m.config.TCPTimeout)) - msgType, bufConn, dec, err := m.readStream(conn) + + var ( + streamLabel string + err error + ) + conn, streamLabel, err = RemoveLabelHeaderFromStream(conn) + if err != nil { + m.logger.Printf("[ERR] memberlist: failed to receive and remove the stream label header: %s %s", err, LogConn(conn)) + return + } + + if m.config.SkipInboundLabelCheck { + if streamLabel != "" { + m.logger.Printf("[ERR] memberlist: unexpected double stream label header: %s", LogConn(conn)) + return + } + // Set this from config so that the auth data assertions work below. + streamLabel = m.config.Label + } + + if m.config.Label != streamLabel { + m.logger.Printf("[ERR] memberlist: discarding stream with unacceptable label %q: %s", streamLabel, LogConn(conn)) + return + } + + msgType, bufConn, dec, err := m.readStream(conn, streamLabel) if err != nil { if err != io.EOF { m.logger.Printf("[ERR] memberlist: failed to receive: %s %s", err, LogConn(conn)) @@ -223,7 +274,7 @@ func (m *Memberlist) handleConn(conn net.Conn) { return } - err = m.rawSendMsgStream(conn, out.Bytes()) + err = m.rawSendMsgStream(conn, out.Bytes(), streamLabel) if err != nil { m.logger.Printf("[ERR] memberlist: Failed to send error: %s %s", err, LogConn(conn)) return @@ -238,13 +289,23 @@ func (m *Memberlist) handleConn(conn net.Conn) { m.logger.Printf("[ERR] memberlist: Failed to receive user message: %s %s", err, LogConn(conn)) } case pushPullMsg: + // Increment counter of pending push/pulls + numConcurrent := atomic.AddUint32(&m.pushPullReq, 1) + defer atomic.AddUint32(&m.pushPullReq, ^uint32(0)) + + // Check if we have too many open push/pull requests + if numConcurrent >= maxPushPullRequests { + m.logger.Printf("[ERR] memberlist: Too many pending push/pull requests") + return + } + join, remoteNodes, userState, err := m.readRemoteState(bufConn, dec) if err != nil { m.logger.Printf("[ERR] memberlist: Failed to read remote state: %s %s", err, LogConn(conn)) return } - if err := m.sendLocalState(conn, join); err != nil { + if err := m.sendLocalState(conn, join, streamLabel); err != nil { m.logger.Printf("[ERR] memberlist: Failed to push local state: %s %s", err, LogConn(conn)) return } @@ -272,7 +333,7 @@ func (m *Memberlist) handleConn(conn net.Conn) { return } - err = m.rawSendMsgStream(conn, out.Bytes()) + err = m.rawSendMsgStream(conn, out.Bytes(), streamLabel) if err != nil { m.logger.Printf("[ERR] memberlist: Failed to send ack: %s %s", err, LogConn(conn)) return @@ -297,10 +358,35 @@ func (m *Memberlist) packetListen() { } func (m *Memberlist) ingestPacket(buf []byte, from net.Addr, timestamp time.Time) { + var ( + packetLabel string + err error + ) + buf, packetLabel, err = RemoveLabelHeaderFromPacket(buf) + if err != nil { + m.logger.Printf("[ERR] memberlist: %v %s", err, LogAddress(from)) + return + } + + if m.config.SkipInboundLabelCheck { + if packetLabel != "" { + m.logger.Printf("[ERR] memberlist: unexpected double packet label header: %s", LogAddress(from)) + return + } + // Set this from config so that the auth data assertions work below. + packetLabel = m.config.Label + } + + if m.config.Label != packetLabel { + m.logger.Printf("[ERR] memberlist: discarding packet with unacceptable label %q: %s", packetLabel, LogAddress(from)) + return + } + // Check if encryption is enabled if m.config.EncryptionEnabled() { // Decrypt the payload - plain, err := decryptPayload(m.config.Keyring.GetKeys(), buf, nil) + authData := []byte(packetLabel) + plain, err := decryptPayload(m.config.Keyring.GetKeys(), buf, authData) if err != nil { if !m.config.GossipVerifyIncoming { // Treat the message as plaintext @@ -330,6 +416,10 @@ func (m *Memberlist) ingestPacket(buf []byte, from net.Addr, timestamp time.Time } func (m *Memberlist) handleCommand(buf []byte, from net.Addr, timestamp time.Time) { + if len(buf) < 1 { + m.logger.Printf("[ERR] memberlist: missing message type byte %s", LogAddress(from)) + return + } // Decode the message type msgType := messageType(buf[0]) buf = buf[1:] @@ -357,10 +447,25 @@ func (m *Memberlist) handleCommand(buf []byte, from net.Addr, timestamp time.Tim case deadMsg: fallthrough case userMsg: - select { - case m.handoff <- msgHandoff{msgType, buf, from}: - default: + // Determine the message queue, prioritize alive + queue := m.lowPriorityMsgQueue + if msgType == aliveMsg { + queue = m.highPriorityMsgQueue + } + + // Check for overflow and append if not full + m.msgQueueLock.Lock() + if queue.Len() >= m.config.HandoffQueueDepth { m.logger.Printf("[WARN] memberlist: handler queue full, dropping message (%d) %s", msgType, LogAddress(from)) + } else { + queue.PushBack(msgHandoff{msgType, buf, from}) + } + m.msgQueueLock.Unlock() + + // Notify of pending message + select { + case m.handoffCh <- struct{}{}: + default: } default: @@ -368,28 +473,51 @@ func (m *Memberlist) handleCommand(buf []byte, from net.Addr, timestamp time.Tim } } +// getNextMessage returns the next message to process in priority order, using LIFO +func (m *Memberlist) getNextMessage() (msgHandoff, bool) { + m.msgQueueLock.Lock() + defer m.msgQueueLock.Unlock() + + if el := m.highPriorityMsgQueue.Back(); el != nil { + m.highPriorityMsgQueue.Remove(el) + msg := el.Value.(msgHandoff) + return msg, true + } else if el := m.lowPriorityMsgQueue.Back(); el != nil { + m.lowPriorityMsgQueue.Remove(el) + msg := el.Value.(msgHandoff) + return msg, true + } + return msgHandoff{}, false +} + // packetHandler is a long running goroutine that processes messages received // over the packet interface, but is decoupled from the listener to avoid // blocking the listener which may cause ping/ack messages to be delayed. func (m *Memberlist) packetHandler() { for { select { - case msg := <-m.handoff: - msgType := msg.msgType - buf := msg.buf - from := msg.from + case <-m.handoffCh: + for { + msg, ok := m.getNextMessage() + if !ok { + break + } + msgType := msg.msgType + buf := msg.buf + from := msg.from - switch msgType { - case suspectMsg: - m.handleSuspect(buf, from) - case aliveMsg: - m.handleAlive(buf, from) - case deadMsg: - m.handleDead(buf, from) - case userMsg: - m.handleUser(buf, from) - default: - m.logger.Printf("[ERR] memberlist: Message type (%d) not supported %s (packet handler)", msgType, LogAddress(from)) + switch msgType { + case suspectMsg: + m.handleSuspect(buf, from) + case aliveMsg: + m.handleAlive(buf, from) + case deadMsg: + m.handleDead(buf, from) + case userMsg: + m.handleUser(buf, from) + default: + m.logger.Printf("[ERR] memberlist: Message type (%d) not supported %s (packet handler)", msgType, LogAddress(from)) + } } case <-m.shutdownCh: @@ -433,7 +561,19 @@ func (m *Memberlist) handlePing(buf []byte, from net.Addr) { if m.config.Ping != nil { ack.Payload = m.config.Ping.AckPayload() } - if err := m.encodeAndSendMsg(from.String(), ackRespMsg, &ack); err != nil { + + addr := "" + if len(p.SourceAddr) > 0 && p.SourcePort > 0 { + addr = joinHostPort(net.IP(p.SourceAddr).String(), p.SourcePort) + } else { + addr = from.String() + } + + a := Address{ + Addr: addr, + Name: p.SourceNode, + } + if err := m.encodeAndSendMsg(a, ackRespMsg, &ack); err != nil { m.logger.Printf("[ERR] memberlist: Failed to send ack: %s %s", err, LogAddress(from)) } } @@ -453,7 +593,25 @@ func (m *Memberlist) handleIndirectPing(buf []byte, from net.Addr) { // Send a ping to the correct host. localSeqNo := m.nextSeqNo() - ping := ping{SeqNo: localSeqNo, Node: ind.Node} + selfAddr, selfPort := m.getAdvertise() + ping := ping{ + SeqNo: localSeqNo, + Node: ind.Node, + // The outbound message is addressed FROM us. + SourceAddr: selfAddr, + SourcePort: selfPort, + SourceNode: m.config.Name, + } + + // Forward the ack back to the requestor. If the request encodes an origin + // use that otherwise assume that the other end of the UDP socket is + // usable. + indAddr := "" + if len(ind.SourceAddr) > 0 && ind.SourcePort > 0 { + indAddr = joinHostPort(net.IP(ind.SourceAddr).String(), ind.SourcePort) + } else { + indAddr = from.String() + } // Setup a response handler to relay the ack cancelCh := make(chan struct{}) @@ -461,18 +619,25 @@ func (m *Memberlist) handleIndirectPing(buf []byte, from net.Addr) { // Try to prevent the nack if we've caught it in time. close(cancelCh) - // Forward the ack back to the requestor. ack := ackResp{ind.SeqNo, nil} - if err := m.encodeAndSendMsg(from.String(), ackRespMsg, &ack); err != nil { - m.logger.Printf("[ERR] memberlist: Failed to forward ack: %s %s", err, LogAddress(from)) + a := Address{ + Addr: indAddr, + Name: ind.SourceNode, + } + if err := m.encodeAndSendMsg(a, ackRespMsg, &ack); err != nil { + m.logger.Printf("[ERR] memberlist: Failed to forward ack: %s %s", err, LogStringAddress(indAddr)) } } m.setAckHandler(localSeqNo, respHandler, m.config.ProbeTimeout) // Send the ping. addr := joinHostPort(net.IP(ind.Target).String(), ind.Port) - if err := m.encodeAndSendMsg(addr, pingMsg, &ping); err != nil { - m.logger.Printf("[ERR] memberlist: Failed to send ping: %s %s", err, LogAddress(from)) + a := Address{ + Addr: addr, + Name: ind.Node, + } + if err := m.encodeAndSendMsg(a, pingMsg, &ping); err != nil { + m.logger.Printf("[ERR] memberlist: Failed to send indirect ping: %s %s", err, LogStringAddress(indAddr)) } // Setup a timer to fire off a nack if no ack is seen in time. @@ -483,8 +648,12 @@ func (m *Memberlist) handleIndirectPing(buf []byte, from net.Addr) { return case <-time.After(m.config.ProbeTimeout): nack := nackResp{ind.SeqNo} - if err := m.encodeAndSendMsg(from.String(), nackRespMsg, &nack); err != nil { - m.logger.Printf("[ERR] memberlist: Failed to send nack: %s %s", err, LogAddress(from)) + a := Address{ + Addr: indAddr, + Name: ind.SourceNode, + } + if err := m.encodeAndSendMsg(a, nackRespMsg, &nack); err != nil { + m.logger.Printf("[ERR] memberlist: Failed to send nack: %s %s", err, LogStringAddress(indAddr)) } } }() @@ -518,12 +687,47 @@ func (m *Memberlist) handleSuspect(buf []byte, from net.Addr) { m.suspectNode(&sus) } +// ensureCanConnect return the IP from a RemoteAddress +// return error if this client must not connect +func (m *Memberlist) ensureCanConnect(from net.Addr) error { + if !m.config.IPMustBeChecked() { + return nil + } + source := from.String() + if source == "pipe" { + return nil + } + host, _, err := net.SplitHostPort(source) + if err != nil { + return err + } + + ip := net.ParseIP(host) + if ip == nil { + return fmt.Errorf("Cannot parse IP from %s", host) + } + return m.config.IPAllowed(ip) +} + func (m *Memberlist) handleAlive(buf []byte, from net.Addr) { + if err := m.ensureCanConnect(from); err != nil { + m.logger.Printf("[DEBUG] memberlist: Blocked alive message: %s %s", err, LogAddress(from)) + return + } var live alive if err := decode(buf, &live); err != nil { m.logger.Printf("[ERR] memberlist: Failed to decode alive message: %s %s", err, LogAddress(from)) return } + if m.config.IPMustBeChecked() { + innerIP := net.IP(live.Addr) + if innerIP != nil { + if err := m.config.IPAllowed(innerIP); err != nil { + m.logger.Printf("[DEBUG] memberlist: Blocked alive.Addr=%s message from: %s %s", innerIP.String(), err, LogAddress(from)) + return + } + } + } // For proto versions < 2, there is no port provided. Mask old // behavior by using the configured port @@ -565,12 +769,12 @@ func (m *Memberlist) handleCompressed(buf []byte, from net.Addr, timestamp time. } // encodeAndSendMsg is used to combine the encoding and sending steps -func (m *Memberlist) encodeAndSendMsg(addr string, msgType messageType, msg interface{}) error { +func (m *Memberlist) encodeAndSendMsg(a Address, msgType messageType, msg interface{}) error { out, err := encode(msgType, msg) if err != nil { return err } - if err := m.sendMsg(addr, out.Bytes()); err != nil { + if err := m.sendMsg(a, out.Bytes()); err != nil { return err } return nil @@ -578,9 +782,9 @@ func (m *Memberlist) encodeAndSendMsg(addr string, msgType messageType, msg inte // sendMsg is used to send a message via packet to another host. It will // opportunistically create a compoundMsg and piggy back other broadcasts. -func (m *Memberlist) sendMsg(addr string, msg []byte) error { +func (m *Memberlist) sendMsg(a Address, msg []byte) error { // Check if we can piggy back any messages - bytesAvail := m.config.UDPBufferSize - len(msg) - compoundHeaderOverhead + bytesAvail := m.config.UDPBufferSize - len(msg) - compoundHeaderOverhead - labelOverhead(m.config.Label) if m.config.EncryptionEnabled() && m.config.GossipVerifyOutgoing { bytesAvail -= encryptOverhead(m.encryptionVersion()) } @@ -588,7 +792,7 @@ func (m *Memberlist) sendMsg(addr string, msg []byte) error { // Fast path if nothing to piggypack if len(extra) == 0 { - return m.rawSendMsgPacket(addr, nil, msg) + return m.rawSendMsgPacket(a, nil, msg) } // Join all the messages @@ -600,12 +804,16 @@ func (m *Memberlist) sendMsg(addr string, msg []byte) error { compound := makeCompoundMessage(msgs) // Send the message - return m.rawSendMsgPacket(addr, nil, compound.Bytes()) + return m.rawSendMsgPacket(a, nil, compound.Bytes()) } // rawSendMsgPacket is used to send message via packet to another host without // modification, other than compression or encryption if enabled. -func (m *Memberlist) rawSendMsgPacket(addr string, node *Node, msg []byte) error { +func (m *Memberlist) rawSendMsgPacket(a Address, node *Node, msg []byte) error { + if a.Name == "" && m.config.RequireNodeNames { + return errNodeNamesAreRequired + } + // Check if we have compression enabled if m.config.EnableCompression { buf, err := compressPayload(msg) @@ -619,11 +827,12 @@ func (m *Memberlist) rawSendMsgPacket(addr string, node *Node, msg []byte) error } } - // Try to look up the destination node + // Try to look up the destination node. Note this will only work if the + // bare ip address is used as the node name, which is not guaranteed. if node == nil { - toAddr, _, err := net.SplitHostPort(addr) + toAddr, _, err := net.SplitHostPort(a.Addr) if err != nil { - m.logger.Printf("[ERR] memberlist: Failed to parse address %q: %v", addr, err) + m.logger.Printf("[ERR] memberlist: Failed to parse address %q: %v", a.Addr, err) return err } m.nodeLock.RLock() @@ -647,9 +856,12 @@ func (m *Memberlist) rawSendMsgPacket(addr string, node *Node, msg []byte) error // Check if we have encryption enabled if m.config.EncryptionEnabled() && m.config.GossipVerifyOutgoing { // Encrypt the payload - var buf bytes.Buffer - primaryKey := m.config.Keyring.GetPrimaryKey() - err := encryptPayload(m.encryptionVersion(), primaryKey, msg, nil, &buf) + var ( + primaryKey = m.config.Keyring.GetPrimaryKey() + packetLabel = []byte(m.config.Label) + buf bytes.Buffer + ) + err := encryptPayload(m.encryptionVersion(), primaryKey, msg, packetLabel, &buf) if err != nil { m.logger.Printf("[ERR] memberlist: Encryption of message failed: %v", err) return err @@ -657,15 +869,15 @@ func (m *Memberlist) rawSendMsgPacket(addr string, node *Node, msg []byte) error msg = buf.Bytes() } - metrics.IncrCounter([]string{"memberlist", "udp", "sent"}, float32(len(msg))) - _, err := m.transport.WriteTo(msg, addr) + metrics.IncrCounterWithLabels([]string{"memberlist", "udp", "sent"}, float32(len(msg)), m.metricLabels) + _, err := m.transport.WriteToAddress(msg, a) return err } // rawSendMsgStream is used to stream a message to another host without // modification, other than applying compression and encryption if enabled. -func (m *Memberlist) rawSendMsgStream(conn net.Conn, sendBuf []byte) error { - // Check if compresion is enabled +func (m *Memberlist) rawSendMsgStream(conn net.Conn, sendBuf []byte, streamLabel string) error { + // Check if compression is enabled if m.config.EnableCompression { compBuf, err := compressPayload(sendBuf) if err != nil { @@ -677,7 +889,7 @@ func (m *Memberlist) rawSendMsgStream(conn net.Conn, sendBuf []byte) error { // Check if encryption is enabled if m.config.EncryptionEnabled() && m.config.GossipVerifyOutgoing { - crypt, err := m.encryptLocalState(sendBuf) + crypt, err := m.encryptLocalState(sendBuf, streamLabel) if err != nil { m.logger.Printf("[ERROR] memberlist: Failed to encrypt local state: %v", err) return err @@ -686,7 +898,7 @@ func (m *Memberlist) rawSendMsgStream(conn net.Conn, sendBuf []byte) error { } // Write out the entire send buffer - metrics.IncrCounter([]string{"memberlist", "tcp", "sent"}, float32(len(sendBuf))) + metrics.IncrCounterWithLabels([]string{"memberlist", "tcp", "sent"}, float32(len(sendBuf)), m.metricLabels) if n, err := conn.Write(sendBuf); err != nil { return err @@ -698,8 +910,12 @@ func (m *Memberlist) rawSendMsgStream(conn net.Conn, sendBuf []byte) error { } // sendUserMsg is used to stream a user message to another host. -func (m *Memberlist) sendUserMsg(addr string, sendBuf []byte) error { - conn, err := m.transport.DialTimeout(addr, m.config.TCPTimeout) +func (m *Memberlist) sendUserMsg(a Address, sendBuf []byte) error { + if a.Name == "" && m.config.RequireNodeNames { + return errNodeNamesAreRequired + } + + conn, err := m.transport.DialAddressTimeout(a, m.config.TCPTimeout) if err != nil { return err } @@ -719,28 +935,33 @@ func (m *Memberlist) sendUserMsg(addr string, sendBuf []byte) error { if _, err := bufConn.Write(sendBuf); err != nil { return err } - return m.rawSendMsgStream(conn, bufConn.Bytes()) + + return m.rawSendMsgStream(conn, bufConn.Bytes(), m.config.Label) } // sendAndReceiveState is used to initiate a push/pull over a stream with a // remote host. -func (m *Memberlist) sendAndReceiveState(addr string, join bool) ([]pushNodeState, []byte, error) { +func (m *Memberlist) sendAndReceiveState(a Address, join bool) ([]pushNodeState, []byte, error) { + if a.Name == "" && m.config.RequireNodeNames { + return nil, nil, errNodeNamesAreRequired + } + // Attempt to connect - conn, err := m.transport.DialTimeout(addr, m.config.TCPTimeout) + conn, err := m.transport.DialAddressTimeout(a, m.config.TCPTimeout) if err != nil { return nil, nil, err } defer conn.Close() - m.logger.Printf("[DEBUG] memberlist: Initiating push/pull sync with: %s", conn.RemoteAddr()) - metrics.IncrCounter([]string{"memberlist", "tcp", "connect"}, 1) + m.logger.Printf("[DEBUG] memberlist: Initiating push/pull sync with: %s %s", a.Name, conn.RemoteAddr()) + metrics.IncrCounterWithLabels([]string{"memberlist", "tcp", "connect"}, 1, m.metricLabels) // Send our state - if err := m.sendLocalState(conn, join); err != nil { + if err := m.sendLocalState(conn, join, m.config.Label); err != nil { return nil, nil, err } conn.SetDeadline(time.Now().Add(m.config.TCPTimeout)) - msgType, bufConn, dec, err := m.readStream(conn) + msgType, bufConn, dec, err := m.readStream(conn, m.config.Label) if err != nil { return nil, nil, err } @@ -765,7 +986,7 @@ func (m *Memberlist) sendAndReceiveState(addr string, join bool) ([]pushNodeStat } // sendLocalState is invoked to send our local state over a stream connection. -func (m *Memberlist) sendLocalState(conn net.Conn, join bool) error { +func (m *Memberlist) sendLocalState(conn net.Conn, join bool, streamLabel string) error { // Setup a deadline conn.SetDeadline(time.Now().Add(m.config.TCPTimeout)) @@ -822,11 +1043,11 @@ func (m *Memberlist) sendLocalState(conn net.Conn, join bool) error { } // Get the send buffer - return m.rawSendMsgStream(conn, bufConn.Bytes()) + return m.rawSendMsgStream(conn, bufConn.Bytes(), streamLabel) } // encryptLocalState is used to help encrypt local state before sending -func (m *Memberlist) encryptLocalState(sendBuf []byte) ([]byte, error) { +func (m *Memberlist) encryptLocalState(sendBuf []byte, streamLabel string) ([]byte, error) { var buf bytes.Buffer // Write the encryptMsg byte @@ -839,9 +1060,15 @@ func (m *Memberlist) encryptLocalState(sendBuf []byte) ([]byte, error) { binary.BigEndian.PutUint32(sizeBuf, uint32(encLen)) buf.Write(sizeBuf) + // Authenticated Data is: + // + // [messageType; byte] [messageLength; uint32] [stream_label; optional] + // + dataBytes := appendBytes(buf.Bytes()[:5], []byte(streamLabel)) + // Write the encrypted cipher text to the buffer key := m.config.Keyring.GetPrimaryKey() - err := encryptPayload(encVsn, key, sendBuf, buf.Bytes()[:5], &buf) + err := encryptPayload(encVsn, key, sendBuf, dataBytes, &buf) if err != nil { return nil, err } @@ -849,7 +1076,7 @@ func (m *Memberlist) encryptLocalState(sendBuf []byte) ([]byte, error) { } // decryptRemoteState is used to help decrypt the remote state -func (m *Memberlist) decryptRemoteState(bufConn io.Reader) ([]byte, error) { +func (m *Memberlist) decryptRemoteState(bufConn io.Reader, streamLabel string) ([]byte, error) { // Read in enough to determine message length cipherText := bytes.NewBuffer(nil) cipherText.WriteByte(byte(encryptMsg)) @@ -863,6 +1090,12 @@ func (m *Memberlist) decryptRemoteState(bufConn io.Reader) ([]byte, error) { moreBytes := binary.BigEndian.Uint32(cipherText.Bytes()[1:5]) if moreBytes > maxPushStateBytes { return nil, fmt.Errorf("Remote node state is larger than limit (%d)", moreBytes) + + } + + //Start reporting the size before you cross the limit + if moreBytes > uint32(math.Floor(.6*maxPushStateBytes)) { + m.logger.Printf("[WARN] memberlist: Remote node state size is (%d) limit is (%d)", moreBytes, maxPushStateBytes) } // Read in the rest of the payload @@ -871,8 +1104,13 @@ func (m *Memberlist) decryptRemoteState(bufConn io.Reader) ([]byte, error) { return nil, err } - // Decrypt the cipherText - dataBytes := cipherText.Bytes()[:5] + // Decrypt the cipherText with some authenticated data + // + // Authenticated Data is: + // + // [messageType; byte] [messageLength; uint32] [label_data; optional] + // + dataBytes := appendBytes(cipherText.Bytes()[:5], []byte(streamLabel)) cipherBytes := cipherText.Bytes()[5:] // Decrypt the payload @@ -880,15 +1118,18 @@ func (m *Memberlist) decryptRemoteState(bufConn io.Reader) ([]byte, error) { return decryptPayload(keys, cipherBytes, dataBytes) } -// readStream is used to read from a stream connection, decrypting and +// readStream is used to read messages from a stream connection, decrypting and // decompressing the stream if necessary. -func (m *Memberlist) readStream(conn net.Conn) (messageType, io.Reader, *codec.Decoder, error) { +// +// The provided streamLabel if present will be authenticated during decryption +// of each message. +func (m *Memberlist) readStream(conn net.Conn, streamLabel string) (messageType, io.Reader, *codec.Decoder, error) { // Created a buffered reader var bufConn io.Reader = bufio.NewReader(conn) // Read the message type buf := [1]byte{0} - if _, err := bufConn.Read(buf[:]); err != nil { + if _, err := io.ReadFull(bufConn, buf[:]); err != nil { return 0, nil, nil, err } msgType := messageType(buf[0]) @@ -900,7 +1141,7 @@ func (m *Memberlist) readStream(conn net.Conn) (messageType, io.Reader, *codec.D fmt.Errorf("Remote state is encrypted and encryption is not configured") } - plain, err := m.decryptRemoteState(bufConn) + plain, err := m.decryptRemoteState(bufConn, streamLabel) if err != nil { return 0, nil, nil, err } @@ -996,16 +1237,17 @@ func (m *Memberlist) mergeRemoteState(join bool, remoteNodes []pushNodeState, us nodes := make([]*Node, len(remoteNodes)) for idx, n := range remoteNodes { nodes[idx] = &Node{ - Name: n.Name, - Addr: n.Addr, - Port: n.Port, - Meta: n.Meta, - PMin: n.Vsn[0], - PMax: n.Vsn[1], - PCur: n.Vsn[2], - DMin: n.Vsn[3], - DMax: n.Vsn[4], - DCur: n.Vsn[5], + Name: n.Name, + Addr: n.Addr, + Port: n.Port, + Meta: n.Meta, + State: n.State, + PMin: n.Vsn[0], + PMax: n.Vsn[1], + PCur: n.Vsn[2], + DMin: n.Vsn[3], + DMax: n.Vsn[4], + DCur: n.Vsn[5], } } if err := m.config.Merge.NotifyMerge(nodes); err != nil { @@ -1058,8 +1300,12 @@ func (m *Memberlist) readUserMsg(bufConn io.Reader, dec *codec.Decoder) error { // a ping, and waits for an ack. All of this is done as a series of blocking // operations, given the deadline. The bool return parameter is true if we // we able to round trip a ping to the other node. -func (m *Memberlist) sendPingAndWaitForAck(addr string, ping ping, deadline time.Time) (bool, error) { - conn, err := m.transport.DialTimeout(addr, deadline.Sub(time.Now())) +func (m *Memberlist) sendPingAndWaitForAck(a Address, ping ping, deadline time.Time) (bool, error) { + if a.Name == "" && m.config.RequireNodeNames { + return false, errNodeNamesAreRequired + } + + conn, err := m.transport.DialAddressTimeout(a, deadline.Sub(time.Now())) if err != nil { // If the node is actually dead we expect this to fail, so we // shouldn't spam the logs with it. After this point, errors @@ -1075,11 +1321,11 @@ func (m *Memberlist) sendPingAndWaitForAck(addr string, ping ping, deadline time return false, err } - if err = m.rawSendMsgStream(conn, out.Bytes()); err != nil { + if err = m.rawSendMsgStream(conn, out.Bytes(), m.config.Label); err != nil { return false, err } - msgType, _, dec, err := m.readStream(conn) + msgType, _, dec, err := m.readStream(conn, m.config.Label) if err != nil { return false, err } @@ -1094,7 +1340,7 @@ func (m *Memberlist) sendPingAndWaitForAck(addr string, ping ping, deadline time } if ack.SeqNo != ping.SeqNo { - return false, fmt.Errorf("Sequence number from ack (%d) doesn't match ping (%d)", ack.SeqNo, ping.SeqNo, LogConn(conn)) + return false, fmt.Errorf("Sequence number from ack (%d) doesn't match ping (%d)", ack.SeqNo, ping.SeqNo) } return true, nil diff --git a/vendor/github.com/hashicorp/memberlist/net_transport.go b/vendor/github.com/hashicorp/memberlist/net_transport.go index e7b88b01f6..a379c855c2 100644 --- a/vendor/github.com/hashicorp/memberlist/net_transport.go +++ b/vendor/github.com/hashicorp/memberlist/net_transport.go @@ -1,7 +1,9 @@ package memberlist import ( + "bytes" "fmt" + "io" "log" "net" "sync" @@ -33,6 +35,10 @@ type NetTransportConfig struct { // Logger is a logger for operator messages. Logger *log.Logger + + // MetricLabels is a map of optional labels to apply to all metrics + // emitted by this transport. + MetricLabels []metrics.Label } // NetTransport is a Transport implementation that uses connectionless UDP for @@ -46,8 +52,12 @@ type NetTransport struct { tcpListeners []*net.TCPListener udpListeners []*net.UDPConn shutdown int32 + + metricLabels []metrics.Label } +var _ NodeAwareTransport = (*NetTransport)(nil) + // NewNetTransport returns a net transport with the given configuration. On // success all the network listeners will be created and listening. func NewNetTransport(config *NetTransportConfig) (*NetTransport, error) { @@ -60,10 +70,11 @@ func NewNetTransport(config *NetTransportConfig) (*NetTransport, error) { // Build out the new transport. var ok bool t := NetTransport{ - config: config, - packetCh: make(chan *Packet), - streamCh: make(chan net.Conn), - logger: config.Logger, + config: config, + packetCh: make(chan *Packet), + streamCh: make(chan net.Conn), + logger: config.Logger, + metricLabels: config.MetricLabels, } // Clean up listeners if there's an error. @@ -170,6 +181,14 @@ func (t *NetTransport) FinalAdvertiseAddr(ip string, port int) (net.IP, int, err // See Transport. func (t *NetTransport) WriteTo(b []byte, addr string) (time.Time, error) { + a := Address{Addr: addr, Name: ""} + return t.WriteToAddress(b, a) +} + +// See NodeAwareTransport. +func (t *NetTransport) WriteToAddress(b []byte, a Address) (time.Time, error) { + addr := a.Addr + udpAddr, err := net.ResolveUDPAddr("udp", addr) if err != nil { return time.Time{}, err @@ -188,8 +207,44 @@ func (t *NetTransport) PacketCh() <-chan *Packet { return t.packetCh } +// See IngestionAwareTransport. +func (t *NetTransport) IngestPacket(conn net.Conn, addr net.Addr, now time.Time, shouldClose bool) error { + if shouldClose { + defer conn.Close() + } + + // Copy everything from the stream into packet buffer. + var buf bytes.Buffer + if _, err := io.Copy(&buf, conn); err != nil { + return fmt.Errorf("failed to read packet: %v", err) + } + + // Check the length - it needs to have at least one byte to be a proper + // message. This is checked elsewhere for writes coming in directly from + // the UDP socket. + if n := buf.Len(); n < 1 { + return fmt.Errorf("packet too short (%d bytes) %s", n, LogAddress(addr)) + } + + // Inject the packet. + t.packetCh <- &Packet{ + Buf: buf.Bytes(), + From: addr, + Timestamp: now, + } + return nil +} + // See Transport. func (t *NetTransport) DialTimeout(addr string, timeout time.Duration) (net.Conn, error) { + a := Address{Addr: addr, Name: ""} + return t.DialAddressTimeout(a, timeout) +} + +// See NodeAwareTransport. +func (t *NetTransport) DialAddressTimeout(a Address, timeout time.Duration) (net.Conn, error) { + addr := a.Addr + dialer := net.Dialer{Timeout: timeout} return dialer.Dial("tcp", addr) } @@ -199,6 +254,12 @@ func (t *NetTransport) StreamCh() <-chan net.Conn { return t.streamCh } +// See IngestionAwareTransport. +func (t *NetTransport) IngestStream(conn net.Conn) error { + t.streamCh <- conn + return nil +} + // See Transport. func (t *NetTransport) Shutdown() error { // This will avoid log spam about errors when we shut down. @@ -221,6 +282,16 @@ func (t *NetTransport) Shutdown() error { // and hands them off to the stream channel. func (t *NetTransport) tcpListen(tcpLn *net.TCPListener) { defer t.wg.Done() + + // baseDelay is the initial delay after an AcceptTCP() error before attempting again + const baseDelay = 5 * time.Millisecond + + // maxDelay is the maximum delay after an AcceptTCP() error before attempting again. + // In the case that tcpListen() is error-looping, it will delay the shutdown check. + // Therefore, changes to maxDelay may have an effect on the latency of shutdown. + const maxDelay = 1 * time.Second + + var loopDelay time.Duration for { conn, err := tcpLn.AcceptTCP() if err != nil { @@ -228,9 +299,22 @@ func (t *NetTransport) tcpListen(tcpLn *net.TCPListener) { break } + if loopDelay == 0 { + loopDelay = baseDelay + } else { + loopDelay *= 2 + } + + if loopDelay > maxDelay { + loopDelay = maxDelay + } + t.logger.Printf("[ERR] memberlist: Error accepting TCP connection: %v", err) + time.Sleep(loopDelay) continue } + // No error, reset loop delay + loopDelay = 0 t.streamCh <- conn } @@ -264,7 +348,7 @@ func (t *NetTransport) udpListen(udpLn *net.UDPConn) { } // Ingest the packet. - metrics.IncrCounter([]string{"memberlist", "udp", "received"}, float32(n)) + metrics.IncrCounterWithLabels([]string{"memberlist", "udp", "received"}, float32(n), t.metricLabels) t.packetCh <- &Packet{ Buf: buf[:n], From: addr, diff --git a/vendor/github.com/hashicorp/memberlist/peeked_conn.go b/vendor/github.com/hashicorp/memberlist/peeked_conn.go new file mode 100644 index 0000000000..3181d90cec --- /dev/null +++ b/vendor/github.com/hashicorp/memberlist/peeked_conn.go @@ -0,0 +1,48 @@ +// Copyright 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Originally from: https://github.com/google/tcpproxy/blob/master/tcpproxy.go +// at f5c09fbedceb69e4b238dec52cdf9f2fe9a815e2 + +package memberlist + +import "net" + +// peekedConn is an incoming connection that has had some bytes read from it +// to determine how to route the connection. The Read method stitches +// the peeked bytes and unread bytes back together. +type peekedConn struct { + // Peeked are the bytes that have been read from Conn for the + // purposes of route matching, but have not yet been consumed + // by Read calls. It set to nil by Read when fully consumed. + Peeked []byte + + // Conn is the underlying connection. + // It can be type asserted against *net.TCPConn or other types + // as needed. It should not be read from directly unless + // Peeked is nil. + net.Conn +} + +func (c *peekedConn) Read(p []byte) (n int, err error) { + if len(c.Peeked) > 0 { + n = copy(p, c.Peeked) + c.Peeked = c.Peeked[n:] + if len(c.Peeked) == 0 { + c.Peeked = nil + } + return n, nil + } + return c.Conn.Read(p) +} diff --git a/vendor/github.com/hashicorp/memberlist/queue.go b/vendor/github.com/hashicorp/memberlist/queue.go index 994b90ff10..c970176e18 100644 --- a/vendor/github.com/hashicorp/memberlist/queue.go +++ b/vendor/github.com/hashicorp/memberlist/queue.go @@ -1,8 +1,10 @@ package memberlist import ( - "sort" + "math" "sync" + + "github.com/google/btree" ) // TransmitLimitedQueue is used to queue messages to broadcast to @@ -19,15 +21,93 @@ type TransmitLimitedQueue struct { // number of retransmissions attempted. RetransmitMult int - sync.Mutex - bcQueue limitedBroadcasts + mu sync.Mutex + tq *btree.BTree // stores *limitedBroadcast as btree.Item + tm map[string]*limitedBroadcast + idGen int64 } type limitedBroadcast struct { - transmits int // Number of transmissions attempted. + transmits int // btree-key[0]: Number of transmissions attempted. + msgLen int64 // btree-key[1]: copied from len(b.Message()) + id int64 // btree-key[2]: unique incrementing id stamped at submission time b Broadcast + + name string // set if Broadcast is a NamedBroadcast +} + +// Less tests whether the current item is less than the given argument. +// +// This must provide a strict weak ordering. +// If !a.Less(b) && !b.Less(a), we treat this to mean a == b (i.e. we can only +// hold one of either a or b in the tree). +// +// default ordering is +// - [transmits=0, ..., transmits=inf] +// - [transmits=0:len=999, ..., transmits=0:len=2, ...] +// - [transmits=0:len=999,id=999, ..., transmits=0:len=999:id=1, ...] +func (b *limitedBroadcast) Less(than btree.Item) bool { + o := than.(*limitedBroadcast) + if b.transmits < o.transmits { + return true + } else if b.transmits > o.transmits { + return false + } + if b.msgLen > o.msgLen { + return true + } else if b.msgLen < o.msgLen { + return false + } + return b.id > o.id +} + +// for testing; emits in transmit order if reverse=false +func (q *TransmitLimitedQueue) orderedView(reverse bool) []*limitedBroadcast { + q.mu.Lock() + defer q.mu.Unlock() + + out := make([]*limitedBroadcast, 0, q.lenLocked()) + q.walkReadOnlyLocked(reverse, func(cur *limitedBroadcast) bool { + out = append(out, cur) + return true + }) + + return out +} + +// walkReadOnlyLocked calls f for each item in the queue traversing it in +// natural order (by Less) when reverse=false and the opposite when true. You +// must hold the mutex. +// +// This method panics if you attempt to mutate the item during traversal. The +// underlying btree should also not be mutated during traversal. +func (q *TransmitLimitedQueue) walkReadOnlyLocked(reverse bool, f func(*limitedBroadcast) bool) { + if q.lenLocked() == 0 { + return + } + + iter := func(item btree.Item) bool { + cur := item.(*limitedBroadcast) + + prevTransmits := cur.transmits + prevMsgLen := cur.msgLen + prevID := cur.id + + keepGoing := f(cur) + + if prevTransmits != cur.transmits || prevMsgLen != cur.msgLen || prevID != cur.id { + panic("edited queue while walking read only") + } + + return keepGoing + } + + if reverse { + q.tq.Descend(iter) // end with transmit 0 + } else { + q.tq.Ascend(iter) // start with transmit 0 + } } -type limitedBroadcasts []*limitedBroadcast // Broadcast is something that can be broadcasted via gossip to // the memberlist cluster. @@ -45,123 +125,298 @@ type Broadcast interface { Finished() } +// NamedBroadcast is an optional extension of the Broadcast interface that +// gives each message a unique string name, and that is used to optimize +// +// You shoud ensure that Invalidates() checks the same uniqueness as the +// example below: +// +// func (b *foo) Invalidates(other Broadcast) bool { +// nb, ok := other.(NamedBroadcast) +// if !ok { +// return false +// } +// return b.Name() == nb.Name() +// } +// +// Invalidates() isn't currently used for NamedBroadcasts, but that may change +// in the future. +type NamedBroadcast interface { + Broadcast + // The unique identity of this broadcast message. + Name() string +} + +// UniqueBroadcast is an optional interface that indicates that each message is +// intrinsically unique and there is no need to scan the broadcast queue for +// duplicates. +// +// You should ensure that Invalidates() always returns false if implementing +// this interface. Invalidates() isn't currently used for UniqueBroadcasts, but +// that may change in the future. +type UniqueBroadcast interface { + Broadcast + // UniqueBroadcast is just a marker method for this interface. + UniqueBroadcast() +} + // QueueBroadcast is used to enqueue a broadcast func (q *TransmitLimitedQueue) QueueBroadcast(b Broadcast) { - q.Lock() - defer q.Unlock() + q.queueBroadcast(b, 0) +} - // Check if this message invalidates another - n := len(q.bcQueue) - for i := 0; i < n; i++ { - if b.Invalidates(q.bcQueue[i].b) { - q.bcQueue[i].b.Finished() - copy(q.bcQueue[i:], q.bcQueue[i+1:]) - q.bcQueue[n-1] = nil - q.bcQueue = q.bcQueue[:n-1] - n-- +// lazyInit initializes internal data structures the first time they are +// needed. You must already hold the mutex. +func (q *TransmitLimitedQueue) lazyInit() { + if q.tq == nil { + q.tq = btree.New(32) + } + if q.tm == nil { + q.tm = make(map[string]*limitedBroadcast) + } +} + +// queueBroadcast is like QueueBroadcast but you can use a nonzero value for +// the initial transmit tier assigned to the message. This is meant to be used +// for unit testing. +func (q *TransmitLimitedQueue) queueBroadcast(b Broadcast, initialTransmits int) { + q.mu.Lock() + defer q.mu.Unlock() + + q.lazyInit() + + if q.idGen == math.MaxInt64 { + // it's super duper unlikely to wrap around within the retransmit limit + q.idGen = 1 + } else { + q.idGen++ + } + id := q.idGen + + lb := &limitedBroadcast{ + transmits: initialTransmits, + msgLen: int64(len(b.Message())), + id: id, + b: b, + } + unique := false + if nb, ok := b.(NamedBroadcast); ok { + lb.name = nb.Name() + } else if _, ok := b.(UniqueBroadcast); ok { + unique = true + } + + // Check if this message invalidates another. + if lb.name != "" { + if old, ok := q.tm[lb.name]; ok { + old.b.Finished() + q.deleteItem(old) + } + } else if !unique { + // Slow path, hopefully nothing hot hits this. + var remove []*limitedBroadcast + q.tq.Ascend(func(item btree.Item) bool { + cur := item.(*limitedBroadcast) + + // Special Broadcasts can only invalidate each other. + switch cur.b.(type) { + case NamedBroadcast: + // noop + case UniqueBroadcast: + // noop + default: + if b.Invalidates(cur.b) { + cur.b.Finished() + remove = append(remove, cur) + } + } + return true + }) + for _, cur := range remove { + q.deleteItem(cur) } } - // Append to the queue - q.bcQueue = append(q.bcQueue, &limitedBroadcast{0, b}) + // Append to the relevant queue. + q.addItem(lb) +} + +// deleteItem removes the given item from the overall datastructure. You +// must already hold the mutex. +func (q *TransmitLimitedQueue) deleteItem(cur *limitedBroadcast) { + _ = q.tq.Delete(cur) + if cur.name != "" { + delete(q.tm, cur.name) + } + + if q.tq.Len() == 0 { + // At idle there's no reason to let the id generator keep going + // indefinitely. + q.idGen = 0 + } +} + +// addItem adds the given item into the overall datastructure. You must already +// hold the mutex. +func (q *TransmitLimitedQueue) addItem(cur *limitedBroadcast) { + _ = q.tq.ReplaceOrInsert(cur) + if cur.name != "" { + q.tm[cur.name] = cur + } +} + +// getTransmitRange returns a pair of min/max values for transmit values +// represented by the current queue contents. Both values represent actual +// transmit values on the interval [0, len). You must already hold the mutex. +func (q *TransmitLimitedQueue) getTransmitRange() (minTransmit, maxTransmit int) { + if q.lenLocked() == 0 { + return 0, 0 + } + minItem, maxItem := q.tq.Min(), q.tq.Max() + if minItem == nil || maxItem == nil { + return 0, 0 + } + + min := minItem.(*limitedBroadcast).transmits + max := maxItem.(*limitedBroadcast).transmits + + return min, max } // GetBroadcasts is used to get a number of broadcasts, up to a byte limit // and applying a per-message overhead as provided. func (q *TransmitLimitedQueue) GetBroadcasts(overhead, limit int) [][]byte { - q.Lock() - defer q.Unlock() + q.mu.Lock() + defer q.mu.Unlock() // Fast path the default case - if len(q.bcQueue) == 0 { + if q.lenLocked() == 0 { return nil } transmitLimit := retransmitLimit(q.RetransmitMult, q.NumNodes()) - bytesUsed := 0 - var toSend [][]byte - for i := len(q.bcQueue) - 1; i >= 0; i-- { - // Check if this is within our limits - b := q.bcQueue[i] - msg := b.b.Message() - if bytesUsed+overhead+len(msg) > limit { + var ( + bytesUsed int + toSend [][]byte + reinsert []*limitedBroadcast + ) + + // Visit fresher items first, but only look at stuff that will fit. + // We'll go tier by tier, grabbing the largest items first. + minTr, maxTr := q.getTransmitRange() + for transmits := minTr; transmits <= maxTr; /*do not advance automatically*/ { + free := int64(limit - bytesUsed - overhead) + if free <= 0 { + break // bail out early + } + + // Search for the least element on a given tier (by transmit count) as + // defined in the limitedBroadcast.Less function that will fit into our + // remaining space. + greaterOrEqual := &limitedBroadcast{ + transmits: transmits, + msgLen: free, + id: math.MaxInt64, + } + lessThan := &limitedBroadcast{ + transmits: transmits + 1, + msgLen: math.MaxInt64, + id: math.MaxInt64, + } + var keep *limitedBroadcast + q.tq.AscendRange(greaterOrEqual, lessThan, func(item btree.Item) bool { + cur := item.(*limitedBroadcast) + // Check if this is within our limits + if int64(len(cur.b.Message())) > free { + // If this happens it's a bug in the datastructure or + // surrounding use doing something like having len(Message()) + // change over time. There's enough going on here that it's + // probably sane to just skip it and move on for now. + return true + } + keep = cur + return false + }) + if keep == nil { + // No more items of an appropriate size in the tier. + transmits++ continue } + msg := keep.b.Message() + // Add to slice to send bytesUsed += overhead + len(msg) toSend = append(toSend, msg) // Check if we should stop transmission - b.transmits++ - if b.transmits >= transmitLimit { - b.b.Finished() - n := len(q.bcQueue) - q.bcQueue[i], q.bcQueue[n-1] = q.bcQueue[n-1], nil - q.bcQueue = q.bcQueue[:n-1] + q.deleteItem(keep) + if keep.transmits+1 >= transmitLimit { + keep.b.Finished() + } else { + // We need to bump this item down to another transmit tier, but + // because it would be in the same direction that we're walking the + // tiers, we will have to delay the reinsertion until we are + // finished our search. Otherwise we'll possibly re-add the message + // when we ascend to the next tier. + keep.transmits++ + reinsert = append(reinsert, keep) } } - // If we are sending anything, we need to re-sort to deal - // with adjusted transmit counts - if len(toSend) > 0 { - q.bcQueue.Sort() + for _, cur := range reinsert { + q.addItem(cur) } + return toSend } // NumQueued returns the number of queued messages func (q *TransmitLimitedQueue) NumQueued() int { - q.Lock() - defer q.Unlock() - return len(q.bcQueue) + q.mu.Lock() + defer q.mu.Unlock() + return q.lenLocked() } -// Reset clears all the queued messages -func (q *TransmitLimitedQueue) Reset() { - q.Lock() - defer q.Unlock() - for _, b := range q.bcQueue { - b.b.Finished() +// lenLocked returns the length of the overall queue datastructure. You must +// hold the mutex. +func (q *TransmitLimitedQueue) lenLocked() int { + if q.tq == nil { + return 0 } - q.bcQueue = nil + return q.tq.Len() +} + +// Reset clears all the queued messages. Should only be used for tests. +func (q *TransmitLimitedQueue) Reset() { + q.mu.Lock() + defer q.mu.Unlock() + + q.walkReadOnlyLocked(false, func(cur *limitedBroadcast) bool { + cur.b.Finished() + return true + }) + + q.tq = nil + q.tm = nil + q.idGen = 0 } // Prune will retain the maxRetain latest messages, and the rest // will be discarded. This can be used to prevent unbounded queue sizes func (q *TransmitLimitedQueue) Prune(maxRetain int) { - q.Lock() - defer q.Unlock() + q.mu.Lock() + defer q.mu.Unlock() // Do nothing if queue size is less than the limit - n := len(q.bcQueue) - if n < maxRetain { - return + for q.tq.Len() > maxRetain { + item := q.tq.Max() + if item == nil { + break + } + cur := item.(*limitedBroadcast) + cur.b.Finished() + q.deleteItem(cur) } - - // Invalidate the messages we will be removing - for i := 0; i < n-maxRetain; i++ { - q.bcQueue[i].b.Finished() - } - - // Move the messages, and retain only the last maxRetain - copy(q.bcQueue[0:], q.bcQueue[n-maxRetain:]) - q.bcQueue = q.bcQueue[:maxRetain] -} - -func (b limitedBroadcasts) Len() int { - return len(b) -} - -func (b limitedBroadcasts) Less(i, j int) bool { - return b[i].transmits < b[j].transmits -} - -func (b limitedBroadcasts) Swap(i, j int) { - b[i], b[j] = b[j], b[i] -} - -func (b limitedBroadcasts) Sort() { - sort.Sort(sort.Reverse(b)) } diff --git a/vendor/github.com/hashicorp/memberlist/security.go b/vendor/github.com/hashicorp/memberlist/security.go index d90114eb0c..6831be3bc6 100644 --- a/vendor/github.com/hashicorp/memberlist/security.go +++ b/vendor/github.com/hashicorp/memberlist/security.go @@ -106,7 +106,10 @@ func encryptPayload(vsn encryptionVersion, key []byte, msg []byte, data []byte, dst.WriteByte(byte(vsn)) // Add a random nonce - io.CopyN(dst, rand.Reader, nonceSize) + _, err = io.CopyN(dst, rand.Reader, nonceSize) + if err != nil { + return err + } afterNonce := dst.Len() // Ensure we are correctly padded (only version 0) @@ -196,3 +199,22 @@ func decryptPayload(keys [][]byte, msg []byte, data []byte) ([]byte, error) { return nil, fmt.Errorf("No installed keys could decrypt the message") } + +func appendBytes(first []byte, second []byte) []byte { + hasFirst := len(first) > 0 + hasSecond := len(second) > 0 + + switch { + case hasFirst && hasSecond: + out := make([]byte, 0, len(first)+len(second)) + out = append(out, first...) + out = append(out, second...) + return out + case hasFirst: + return first + case hasSecond: + return second + default: + return nil + } +} diff --git a/vendor/github.com/hashicorp/memberlist/state.go b/vendor/github.com/hashicorp/memberlist/state.go index f51692de0a..a9ee889960 100644 --- a/vendor/github.com/hashicorp/memberlist/state.go +++ b/vendor/github.com/hashicorp/memberlist/state.go @@ -6,32 +6,35 @@ import ( "math" "math/rand" "net" + "strings" "sync/atomic" "time" - "github.com/armon/go-metrics" + metrics "github.com/armon/go-metrics" ) -type nodeStateType int +type NodeStateType int const ( - stateAlive nodeStateType = iota - stateSuspect - stateDead + StateAlive NodeStateType = iota + StateSuspect + StateDead + StateLeft ) // Node represents a node in the cluster. type Node struct { - Name string - Addr net.IP - Port uint16 - Meta []byte // Metadata from the delegate for this node. - PMin uint8 // Minimum protocol version this understands - PMax uint8 // Maximum protocol version this understands - PCur uint8 // Current version node is speaking - DMin uint8 // Min protocol version for the delegate to understand - DMax uint8 // Max protocol version for the delegate to understand - DCur uint8 // Current version delegate is speaking + Name string + Addr net.IP + Port uint16 + Meta []byte // Metadata from the delegate for this node. + State NodeStateType // State of the node. + PMin uint8 // Minimum protocol version this understands + PMax uint8 // Maximum protocol version this understands + PCur uint8 // Current version node is speaking + DMin uint8 // Min protocol version for the delegate to understand + DMax uint8 // Max protocol version for the delegate to understand + DCur uint8 // Current version delegate is speaking } // Address returns the host:port form of a node's address, suitable for use @@ -40,6 +43,15 @@ func (n *Node) Address() string { return joinHostPort(n.Addr.String(), n.Port) } +// FullAddress returns the node name and host:port form of a node's address, +// suitable for use with a transport. +func (n *Node) FullAddress() Address { + return Address{ + Addr: joinHostPort(n.Addr.String(), n.Port), + Name: n.Name, + } +} + // String returns the node name func (n *Node) String() string { return n.Name @@ -49,7 +61,7 @@ func (n *Node) String() string { type nodeState struct { Node Incarnation uint32 // Last known incarnation number - State nodeStateType // Current state + State NodeStateType // Current state StateChange time.Time // Time last state change happened } @@ -59,6 +71,16 @@ func (n *nodeState) Address() string { return n.Node.Address() } +// FullAddress returns the node name and host:port form of a node's address, +// suitable for use with a transport. +func (n *nodeState) FullAddress() Address { + return n.Node.FullAddress() +} + +func (n *nodeState) DeadOrLeft() bool { + return n.State == StateDead || n.State == StateLeft +} + // ackHandler is used to register handlers for incoming acks and nacks. type ackHandler struct { ackFn func([]byte, time.Time) @@ -217,7 +239,7 @@ START: node = *m.nodes[m.probeIndex] if node.Name == m.config.Name { skip = true - } else if node.State == stateDead { + } else if node.DeadOrLeft() { skip = true } @@ -233,20 +255,56 @@ START: m.probeNode(&node) } +// probeNodeByAddr just safely calls probeNode given only the address of the node (for tests) +func (m *Memberlist) probeNodeByAddr(addr string) { + m.nodeLock.RLock() + n := m.nodeMap[addr] + m.nodeLock.RUnlock() + + m.probeNode(n) +} + +// failedRemote checks the error and decides if it indicates a failure on the +// other end. +func failedRemote(err error) bool { + switch t := err.(type) { + case *net.OpError: + if strings.HasPrefix(t.Net, "tcp") { + switch t.Op { + case "dial", "read", "write": + return true + } + } else if strings.HasPrefix(t.Net, "udp") { + switch t.Op { + case "write": + return true + } + } + } + return false +} + // probeNode handles a single round of failure checking on a node. func (m *Memberlist) probeNode(node *nodeState) { - defer metrics.MeasureSince([]string{"memberlist", "probeNode"}, time.Now()) + defer metrics.MeasureSinceWithLabels([]string{"memberlist", "probeNode"}, time.Now(), m.metricLabels) // We use our health awareness to scale the overall probe interval, so we // slow down if we detect problems. The ticker that calls us can handle // us running over the base interval, and will skip missed ticks. probeInterval := m.awareness.ScaleTimeout(m.config.ProbeInterval) if probeInterval > m.config.ProbeInterval { - metrics.IncrCounter([]string{"memberlist", "degraded", "probe"}, 1) + metrics.IncrCounterWithLabels([]string{"memberlist", "degraded", "probe"}, 1, m.metricLabels) } // Prepare a ping message and setup an ack handler. - ping := ping{SeqNo: m.nextSeqNo(), Node: node.Name} + selfAddr, selfPort := m.getAdvertise() + ping := ping{ + SeqNo: m.nextSeqNo(), + Node: node.Name, + SourceAddr: selfAddr, + SourcePort: selfPort, + SourceNode: m.config.Name, + } ackCh := make(chan ackMessage, m.config.IndirectChecks+1) nackCh := make(chan struct{}, m.config.IndirectChecks+1) m.setProbeChannels(ping.SeqNo, ackCh, nackCh, probeInterval) @@ -263,15 +321,25 @@ func (m *Memberlist) probeNode(node *nodeState) { // soon as possible. deadline := sent.Add(probeInterval) addr := node.Address() - if node.State == stateAlive { - if err := m.encodeAndSendMsg(addr, pingMsg, &ping); err != nil { - m.logger.Printf("[ERR] memberlist: Failed to send ping: %s", err) - return + + // Arrange for our self-awareness to get updated. + var awarenessDelta int + defer func() { + m.awareness.ApplyDelta(awarenessDelta) + }() + if node.State == StateAlive { + if err := m.encodeAndSendMsg(node.FullAddress(), pingMsg, &ping); err != nil { + m.logger.Printf("[ERR] memberlist: Failed to send UDP ping: %s", err) + if failedRemote(err) { + goto HANDLE_REMOTE_FAILURE + } else { + return + } } } else { var msgs [][]byte if buf, err := encode(pingMsg, &ping); err != nil { - m.logger.Printf("[ERR] memberlist: Failed to encode ping message: %s", err) + m.logger.Printf("[ERR] memberlist: Failed to encode UDP ping message: %s", err) return } else { msgs = append(msgs, buf.Bytes()) @@ -285,9 +353,13 @@ func (m *Memberlist) probeNode(node *nodeState) { } compound := makeCompoundMessage(msgs) - if err := m.rawSendMsgPacket(addr, &node.Node, compound.Bytes()); err != nil { - m.logger.Printf("[ERR] memberlist: Failed to send compound ping and suspect message to %s: %s", addr, err) - return + if err := m.rawSendMsgPacket(node.FullAddress(), &node.Node, compound.Bytes()); err != nil { + m.logger.Printf("[ERR] memberlist: Failed to send UDP compound ping and suspect message to %s: %s", addr, err) + if failedRemote(err) { + goto HANDLE_REMOTE_FAILURE + } else { + return + } } } @@ -296,10 +368,7 @@ func (m *Memberlist) probeNode(node *nodeState) { // which will improve our health until we get to the failure scenarios // at the end of this function, which will alter this delta variable // accordingly. - awarenessDelta := -1 - defer func() { - m.awareness.ApplyDelta(awarenessDelta) - }() + awarenessDelta = -1 // Wait for response or round-trip-time. select { @@ -324,21 +393,31 @@ func (m *Memberlist) probeNode(node *nodeState) { // probe interval it will give the TCP fallback more time, which // is more active in dealing with lost packets, and it gives more // time to wait for indirect acks/nacks. - m.logger.Printf("[DEBUG] memberlist: Failed ping: %v (timeout reached)", node.Name) + m.logger.Printf("[DEBUG] memberlist: Failed UDP ping: %s (timeout reached)", node.Name) } +HANDLE_REMOTE_FAILURE: // Get some random live nodes. m.nodeLock.RLock() kNodes := kRandomNodes(m.config.IndirectChecks, m.nodes, func(n *nodeState) bool { return n.Name == m.config.Name || n.Name == node.Name || - n.State != stateAlive + n.State != StateAlive }) m.nodeLock.RUnlock() // Attempt an indirect ping. expectedNacks := 0 - ind := indirectPingReq{SeqNo: ping.SeqNo, Target: node.Addr, Port: node.Port, Node: node.Name} + selfAddr, selfPort = m.getAdvertise() + ind := indirectPingReq{ + SeqNo: ping.SeqNo, + Target: node.Addr, + Port: node.Port, + Node: node.Name, + SourceAddr: selfAddr, + SourcePort: selfPort, + SourceNode: m.config.Name, + } for _, peer := range kNodes { // We only expect nack to be sent from peers who understand // version 4 of the protocol. @@ -346,8 +425,8 @@ func (m *Memberlist) probeNode(node *nodeState) { expectedNacks++ } - if err := m.encodeAndSendMsg(peer.Address(), indirectPingMsg, &ind); err != nil { - m.logger.Printf("[ERR] memberlist: Failed to send indirect ping: %s", err) + if err := m.encodeAndSendMsg(peer.FullAddress(), indirectPingMsg, &ind); err != nil { + m.logger.Printf("[ERR] memberlist: Failed to send indirect UDP ping: %s", err) } } @@ -362,12 +441,19 @@ func (m *Memberlist) probeNode(node *nodeState) { // which protocol version we are speaking. That's why we've included a // config option to turn this off if desired. fallbackCh := make(chan bool, 1) - if (!m.config.DisableTcpPings) && (node.PMax >= 3) { + + disableTcpPings := m.config.DisableTcpPings || + (m.config.DisableTcpPingsForNode != nil && m.config.DisableTcpPingsForNode(node.Name)) + if (!disableTcpPings) && (node.PMax >= 3) { go func() { defer close(fallbackCh) - didContact, err := m.sendPingAndWaitForAck(node.Address(), ping, deadline) + didContact, err := m.sendPingAndWaitForAck(node.FullAddress(), ping, deadline) if err != nil { - m.logger.Printf("[ERR] memberlist: Failed fallback ping: %s", err) + var to string + if ne, ok := err.(net.Error); ok && ne.Timeout() { + to = fmt.Sprintf("timeout %s: ", probeInterval) + } + m.logger.Printf("[ERR] memberlist: Failed fallback TCP ping: %s%s", to, err) } else { fallbackCh <- didContact } @@ -392,7 +478,7 @@ func (m *Memberlist) probeNode(node *nodeState) { // any additional time here. for didContact := range fallbackCh { if didContact { - m.logger.Printf("[WARN] memberlist: Was able to connect to %s but other probes failed, network may be misconfigured", node.Name) + m.logger.Printf("[WARN] memberlist: Was able to connect to %s over TCP but UDP probes failed, network may be misconfigured", node.Name) return } } @@ -422,12 +508,21 @@ func (m *Memberlist) probeNode(node *nodeState) { // Ping initiates a ping to the node with the specified name. func (m *Memberlist) Ping(node string, addr net.Addr) (time.Duration, error) { // Prepare a ping message and setup an ack handler. - ping := ping{SeqNo: m.nextSeqNo(), Node: node} + selfAddr, selfPort := m.getAdvertise() + ping := ping{ + SeqNo: m.nextSeqNo(), + Node: node, + SourceAddr: selfAddr, + SourcePort: selfPort, + SourceNode: m.config.Name, + } ackCh := make(chan ackMessage, m.config.IndirectChecks+1) m.setProbeChannels(ping.SeqNo, ackCh, nil, m.config.ProbeInterval) + a := Address{Addr: addr.String(), Name: node} + // Send a ping to the node. - if err := m.encodeAndSendMsg(addr.String(), pingMsg, &ping); err != nil { + if err := m.encodeAndSendMsg(a, pingMsg, &ping); err != nil { return 0, err } @@ -478,7 +573,7 @@ func (m *Memberlist) resetNodes() { // gossip is invoked every GossipInterval period to broadcast our gossip // messages to a few random nodes. func (m *Memberlist) gossip() { - defer metrics.MeasureSince([]string{"memberlist", "gossip"}, time.Now()) + defer metrics.MeasureSinceWithLabels([]string{"memberlist", "gossip"}, time.Now(), m.metricLabels) // Get some random live, suspect, or recently dead nodes m.nodeLock.RLock() @@ -488,10 +583,10 @@ func (m *Memberlist) gossip() { } switch n.State { - case stateAlive, stateSuspect: + case StateAlive, StateSuspect: return false - case stateDead: + case StateDead: return time.Since(n.StateChange) > m.config.GossipToTheDeadTime default: @@ -501,7 +596,7 @@ func (m *Memberlist) gossip() { m.nodeLock.RUnlock() // Compute the bytes available - bytesAvail := m.config.UDPBufferSize - compoundHeaderOverhead + bytesAvail := m.config.UDPBufferSize - compoundHeaderOverhead - labelOverhead(m.config.Label) if m.config.EncryptionEnabled() { bytesAvail -= encryptOverhead(m.encryptionVersion()) } @@ -516,14 +611,16 @@ func (m *Memberlist) gossip() { addr := node.Address() if len(msgs) == 1 { // Send single message as is - if err := m.rawSendMsgPacket(addr, &node.Node, msgs[0]); err != nil { + if err := m.rawSendMsgPacket(node.FullAddress(), &node, msgs[0]); err != nil { m.logger.Printf("[ERR] memberlist: Failed to send gossip to %s: %s", addr, err) } } else { - // Otherwise create and send a compound message - compound := makeCompoundMessage(msgs) - if err := m.rawSendMsgPacket(addr, &node.Node, compound.Bytes()); err != nil { - m.logger.Printf("[ERR] memberlist: Failed to send gossip to %s: %s", addr, err) + // Otherwise create and send one or more compound messages + compounds := makeCompoundMessages(msgs) + for _, compound := range compounds { + if err := m.rawSendMsgPacket(node.FullAddress(), &node, compound.Bytes()); err != nil { + m.logger.Printf("[ERR] memberlist: Failed to send gossip to %s: %s", addr, err) + } } } } @@ -538,7 +635,7 @@ func (m *Memberlist) pushPull() { m.nodeLock.RLock() nodes := kRandomNodes(1, m.nodes, func(n *nodeState) bool { return n.Name == m.config.Name || - n.State != stateAlive + n.State != StateAlive }) m.nodeLock.RUnlock() @@ -549,17 +646,17 @@ func (m *Memberlist) pushPull() { node := nodes[0] // Attempt a push pull - if err := m.pushPullNode(node.Address(), false); err != nil { + if err := m.pushPullNode(node.FullAddress(), false); err != nil { m.logger.Printf("[ERR] memberlist: Push/Pull with %s failed: %s", node.Name, err) } } // pushPullNode does a complete state exchange with a specific node. -func (m *Memberlist) pushPullNode(addr string, join bool) error { - defer metrics.MeasureSince([]string{"memberlist", "pushPullNode"}, time.Now()) +func (m *Memberlist) pushPullNode(a Address, join bool) error { + defer metrics.MeasureSinceWithLabels([]string{"memberlist", "pushPullNode"}, time.Now(), m.metricLabels) // Attempt to send and receive with the node - remote, userState, err := m.sendAndReceiveState(addr, join) + remote, userState, err := m.sendAndReceiveState(a, join) if err != nil { return err } @@ -596,7 +693,7 @@ func (m *Memberlist) verifyProtocol(remote []pushNodeState) error { for _, rn := range remote { // If the node isn't alive, then skip it - if rn.State != stateAlive { + if rn.State != StateAlive { continue } @@ -625,7 +722,7 @@ func (m *Memberlist) verifyProtocol(remote []pushNodeState) error { for _, n := range m.nodes { // Ignore non-alive nodes - if n.State != stateAlive { + if n.State != StateAlive { continue } @@ -841,11 +938,26 @@ func (m *Memberlist) aliveNode(a *alive, notify chan struct{}, bootstrap bool) { return } + if len(a.Vsn) >= 3 { + pMin := a.Vsn[0] + pMax := a.Vsn[1] + pCur := a.Vsn[2] + if pMin == 0 || pMax == 0 || pMin > pMax { + m.logger.Printf("[WARN] memberlist: Ignoring an alive message for '%s' (%v:%d) because protocol version(s) are wrong: %d <= %d <= %d should be >0", a.Node, net.IP(a.Addr), a.Port, pMin, pCur, pMax) + return + } + } + // Invoke the Alive delegate if any. This can be used to filter out // alive messages based on custom logic. For example, using a cluster name. // Using a merge delegate is not enough, as it is possible for passive // cluster merging to still occur. if m.config.Alive != nil { + if len(a.Vsn) < 6 { + m.logger.Printf("[WARN] memberlist: ignoring alive message for '%s' (%v:%d) because Vsn is not present", + a.Node, net.IP(a.Addr), a.Port) + return + } node := &Node{ Name: a.Node, Addr: a.Addr, @@ -867,7 +979,13 @@ func (m *Memberlist) aliveNode(a *alive, notify chan struct{}, bootstrap bool) { // Check if we've never seen this node before, and if not, then // store this node in our node map. + var updatesNode bool if !ok { + errCon := m.config.IPAllowed(a.Addr) + if errCon != nil { + m.logger.Printf("[WARN] memberlist: Rejected node %s (%v): %s", a.Node, net.IP(a.Addr), errCon) + return + } state = &nodeState{ Node: Node{ Name: a.Node, @@ -875,7 +993,15 @@ func (m *Memberlist) aliveNode(a *alive, notify chan struct{}, bootstrap bool) { Port: a.Port, Meta: a.Meta, }, - State: stateDead, + State: StateDead, + } + if len(a.Vsn) > 5 { + state.PMin = a.Vsn[0] + state.PMax = a.Vsn[1] + state.PCur = a.Vsn[2] + state.DMin = a.Vsn[3] + state.DMax = a.Vsn[4] + state.DCur = a.Vsn[5] } // Add to map @@ -894,29 +1020,45 @@ func (m *Memberlist) aliveNode(a *alive, notify chan struct{}, bootstrap bool) { // Update numNodes after we've added a new node atomic.AddUint32(&m.numNodes, 1) - } - - // Check if this address is different than the existing node - if !bytes.Equal([]byte(state.Addr), a.Addr) || state.Port != a.Port { - m.logger.Printf("[ERR] memberlist: Conflicting address for %s. Mine: %v:%d Theirs: %v:%d", - state.Name, state.Addr, state.Port, net.IP(a.Addr), a.Port) - - // Inform the conflict delegate if provided - if m.config.Conflict != nil { - other := Node{ - Name: a.Node, - Addr: a.Addr, - Port: a.Port, - Meta: a.Meta, + } else { + // Check if this address is different than the existing node unless the old node is dead. + if !bytes.Equal([]byte(state.Addr), a.Addr) || state.Port != a.Port { + errCon := m.config.IPAllowed(a.Addr) + if errCon != nil { + m.logger.Printf("[WARN] memberlist: Rejected IP update from %v to %v for node %s: %s", a.Node, state.Addr, net.IP(a.Addr), errCon) + return + } + // If DeadNodeReclaimTime is configured, check if enough time has elapsed since the node died. + canReclaim := (m.config.DeadNodeReclaimTime > 0 && + time.Since(state.StateChange) > m.config.DeadNodeReclaimTime) + + // Allow the address to be updated if a dead node is being replaced. + if state.State == StateLeft || (state.State == StateDead && canReclaim) { + m.logger.Printf("[INFO] memberlist: Updating address for left or failed node %s from %v:%d to %v:%d", + state.Name, state.Addr, state.Port, net.IP(a.Addr), a.Port) + updatesNode = true + } else { + m.logger.Printf("[ERR] memberlist: Conflicting address for %s. Mine: %v:%d Theirs: %v:%d Old state: %v", + state.Name, state.Addr, state.Port, net.IP(a.Addr), a.Port, state.State) + + // Inform the conflict delegate if provided + if m.config.Conflict != nil { + other := Node{ + Name: a.Node, + Addr: a.Addr, + Port: a.Port, + Meta: a.Meta, + } + m.config.Conflict.NotifyConflict(&state.Node, &other) + } + return } - m.config.Conflict.NotifyConflict(&state.Node, &other) } - return } // Bail if the incarnation number is older, and this is not about us isLocalNode := state.Name == m.config.Name - if a.Incarnation <= state.Incarnation && !isLocalNode { + if a.Incarnation <= state.Incarnation && !isLocalNode && !updatesNode { return } @@ -956,9 +1098,8 @@ func (m *Memberlist) aliveNode(a *alive, notify chan struct{}, bootstrap bool) { bytes.Equal(a.Vsn, versions) { return } - m.refute(state, a.Incarnation) - m.logger.Printf("[WARN] memberlist: Refuting an alive message") + m.logger.Printf("[WARN] memberlist: Refuting an alive message for '%s' (%v:%d) meta:(%v VS %v), vsn:(%v VS %v)", a.Node, net.IP(a.Addr), a.Port, a.Meta, state.Meta, a.Vsn, versions) } else { m.encodeBroadcastNotify(a.Node, aliveMsg, a, notify) @@ -975,19 +1116,21 @@ func (m *Memberlist) aliveNode(a *alive, notify chan struct{}, bootstrap bool) { // Update the state and incarnation number state.Incarnation = a.Incarnation state.Meta = a.Meta - if state.State != stateAlive { - state.State = stateAlive + state.Addr = a.Addr + state.Port = a.Port + if state.State != StateAlive { + state.State = StateAlive state.StateChange = time.Now() } } // Update metrics - metrics.IncrCounter([]string{"memberlist", "msg", "alive"}, 1) + metrics.IncrCounterWithLabels([]string{"memberlist", "msg", "alive"}, 1, m.metricLabels) // Notify the delegate of any relevant updates if m.config.Events != nil { - if oldState == stateDead { - // if Dead -> Alive, notify of join + if oldState == StateDead || oldState == StateLeft { + // if Dead/Left -> Alive, notify of join m.config.Events.NotifyJoin(&state.Node) } else if !bytes.Equal(oldMeta, state.Meta) { @@ -1026,7 +1169,7 @@ func (m *Memberlist) suspectNode(s *suspect) { } // Ignore non-alive nodes - if state.State != stateAlive { + if state.State != StateAlive { return } @@ -1040,11 +1183,11 @@ func (m *Memberlist) suspectNode(s *suspect) { } // Update metrics - metrics.IncrCounter([]string{"memberlist", "msg", "suspect"}, 1) + metrics.IncrCounterWithLabels([]string{"memberlist", "msg", "suspect"}, 1, m.metricLabels) // Update the state state.Incarnation = s.Incarnation - state.State = stateSuspect + state.State = StateSuspect changeTime := time.Now() state.StateChange = changeTime @@ -1066,20 +1209,25 @@ func (m *Memberlist) suspectNode(s *suspect) { min := suspicionTimeout(m.config.SuspicionMult, n, m.config.ProbeInterval) max := time.Duration(m.config.SuspicionMaxTimeoutMult) * min fn := func(numConfirmations int) { + var d *dead + m.nodeLock.Lock() state, ok := m.nodeMap[s.Node] - timeout := ok && state.State == stateSuspect && state.StateChange == changeTime + timeout := ok && state.State == StateSuspect && state.StateChange == changeTime + if timeout { + d = &dead{Incarnation: state.Incarnation, Node: state.Name, From: m.config.Name} + } m.nodeLock.Unlock() if timeout { if k > 0 && numConfirmations < k { - metrics.IncrCounter([]string{"memberlist", "degraded", "timeout"}, 1) + metrics.IncrCounterWithLabels([]string{"memberlist", "degraded", "timeout"}, 1, m.metricLabels) } m.logger.Printf("[INFO] memberlist: Marking %s as failed, suspect timeout reached (%d peer confirmations)", state.Name, numConfirmations) - d := dead{Incarnation: state.Incarnation, Node: state.Name, From: m.config.Name} - m.deadNode(&d) + + m.deadNode(d) } } m.nodeTimers[s.Node] = newSuspicion(s.From, k, min, max, fn) @@ -1106,7 +1254,7 @@ func (m *Memberlist) deadNode(d *dead) { delete(m.nodeTimers, d.Node) // Ignore if node is already dead - if state.State == stateDead { + if state.DeadOrLeft() { return } @@ -1126,11 +1274,18 @@ func (m *Memberlist) deadNode(d *dead) { } // Update metrics - metrics.IncrCounter([]string{"memberlist", "msg", "dead"}, 1) + metrics.IncrCounterWithLabels([]string{"memberlist", "msg", "dead"}, 1, m.metricLabels) // Update the state state.Incarnation = d.Incarnation - state.State = stateDead + + // If the dead message was send by the node itself, mark it is left + // instead of dead. + if d.Node == d.From { + state.State = StateLeft + } else { + state.State = StateDead + } state.StateChange = time.Now() // Notify of death @@ -1144,7 +1299,7 @@ func (m *Memberlist) deadNode(d *dead) { func (m *Memberlist) mergeState(remote []pushNodeState) { for _, r := range remote { switch r.State { - case stateAlive: + case StateAlive: a := alive{ Incarnation: r.Incarnation, Node: r.Name, @@ -1155,11 +1310,14 @@ func (m *Memberlist) mergeState(remote []pushNodeState) { } m.aliveNode(&a, nil, false) - case stateDead: + case StateLeft: + d := dead{Incarnation: r.Incarnation, Node: r.Name, From: r.Name} + m.deadNode(&d) + case StateDead: // If the remote node believes a node is dead, we prefer to // suspect that node instead of declaring it dead instantly fallthrough - case stateSuspect: + case StateSuspect: s := suspect{Incarnation: r.Incarnation, Node: r.Name, From: m.config.Name} m.suspectNode(&s) } diff --git a/vendor/github.com/hashicorp/memberlist/transport.go b/vendor/github.com/hashicorp/memberlist/transport.go index 6ce55ea47f..f3d05364d7 100644 --- a/vendor/github.com/hashicorp/memberlist/transport.go +++ b/vendor/github.com/hashicorp/memberlist/transport.go @@ -1,6 +1,7 @@ package memberlist import ( + "fmt" "net" "time" ) @@ -63,3 +64,97 @@ type Transport interface { // transport a chance to clean up any listeners. Shutdown() error } + +type Address struct { + // Addr is a network address as a string, similar to Dial. This usually is + // in the form of "host:port". This is required. + Addr string + + // Name is the name of the node being addressed. This is optional but + // transports may require it. + Name string +} + +func (a *Address) String() string { + if a.Name != "" { + return fmt.Sprintf("%s (%s)", a.Name, a.Addr) + } + return a.Addr +} + +// IngestionAwareTransport is not used. +// +// Deprecated: IngestionAwareTransport is not used and may be removed in a future +// version. Define the interface locally instead of referencing this exported +// interface. +type IngestionAwareTransport interface { + IngestPacket(conn net.Conn, addr net.Addr, now time.Time, shouldClose bool) error + IngestStream(conn net.Conn) error +} + +type NodeAwareTransport interface { + Transport + WriteToAddress(b []byte, addr Address) (time.Time, error) + DialAddressTimeout(addr Address, timeout time.Duration) (net.Conn, error) +} + +type shimNodeAwareTransport struct { + Transport +} + +var _ NodeAwareTransport = (*shimNodeAwareTransport)(nil) + +func (t *shimNodeAwareTransport) WriteToAddress(b []byte, addr Address) (time.Time, error) { + return t.WriteTo(b, addr.Addr) +} + +func (t *shimNodeAwareTransport) DialAddressTimeout(addr Address, timeout time.Duration) (net.Conn, error) { + return t.DialTimeout(addr.Addr, timeout) +} + +type labelWrappedTransport struct { + label string + NodeAwareTransport +} + +var _ NodeAwareTransport = (*labelWrappedTransport)(nil) + +func (t *labelWrappedTransport) WriteToAddress(buf []byte, addr Address) (time.Time, error) { + var err error + buf, err = AddLabelHeaderToPacket(buf, t.label) + if err != nil { + return time.Time{}, fmt.Errorf("failed to add label header to packet: %w", err) + } + return t.NodeAwareTransport.WriteToAddress(buf, addr) +} + +func (t *labelWrappedTransport) WriteTo(buf []byte, addr string) (time.Time, error) { + var err error + buf, err = AddLabelHeaderToPacket(buf, t.label) + if err != nil { + return time.Time{}, err + } + return t.NodeAwareTransport.WriteTo(buf, addr) +} + +func (t *labelWrappedTransport) DialAddressTimeout(addr Address, timeout time.Duration) (net.Conn, error) { + conn, err := t.NodeAwareTransport.DialAddressTimeout(addr, timeout) + if err != nil { + return nil, err + } + if err := AddLabelHeaderToStream(conn, t.label); err != nil { + return nil, fmt.Errorf("failed to add label header to stream: %w", err) + } + return conn, nil +} + +func (t *labelWrappedTransport) DialTimeout(addr string, timeout time.Duration) (net.Conn, error) { + conn, err := t.NodeAwareTransport.DialTimeout(addr, timeout) + if err != nil { + return nil, err + } + if err := AddLabelHeaderToStream(conn, t.label); err != nil { + return nil, fmt.Errorf("failed to add label header to stream: %w", err) + } + return conn, nil +} diff --git a/vendor/github.com/hashicorp/memberlist/util.go b/vendor/github.com/hashicorp/memberlist/util.go index e2381a6986..24112210df 100644 --- a/vendor/github.com/hashicorp/memberlist/util.go +++ b/vendor/github.com/hashicorp/memberlist/util.go @@ -78,10 +78,9 @@ func retransmitLimit(retransmitMult, n int) int { // shuffleNodes randomly shuffles the input nodes using the Fisher-Yates shuffle func shuffleNodes(nodes []*nodeState) { n := len(nodes) - for i := n - 1; i > 0; i-- { - j := rand.Intn(i + 1) + rand.Shuffle(n, func(i, j int) { nodes[i], nodes[j] = nodes[j], nodes[i] - } + }) } // pushPushScale is used to scale the time interval at which push/pull @@ -97,13 +96,13 @@ func pushPullScale(interval time.Duration, n int) time.Duration { return time.Duration(multiplier) * interval } -// moveDeadNodes moves nodes that are dead and beyond the gossip to the dead interval +// moveDeadNodes moves dead and left nodes that that have not changed during the gossipToTheDeadTime interval // to the end of the slice and returns the index of the first moved node. func moveDeadNodes(nodes []*nodeState, gossipToTheDeadTime time.Duration) int { numDead := 0 n := len(nodes) for i := 0; i < n-numDead; i++ { - if nodes[i].State != stateDead { + if !nodes[i].DeadOrLeft() { continue } @@ -120,39 +119,56 @@ func moveDeadNodes(nodes []*nodeState, gossipToTheDeadTime time.Duration) int { return n - numDead } -// kRandomNodes is used to select up to k random nodes, excluding any nodes where -// the filter function returns true. It is possible that less than k nodes are +// kRandomNodes is used to select up to k random Nodes, excluding any nodes where +// the exclude function returns true. It is possible that less than k nodes are // returned. -func kRandomNodes(k int, nodes []*nodeState, filterFn func(*nodeState) bool) []*nodeState { +func kRandomNodes(k int, nodes []*nodeState, exclude func(*nodeState) bool) []Node { n := len(nodes) - kNodes := make([]*nodeState, 0, k) + kNodes := make([]Node, 0, k) OUTER: // Probe up to 3*n times, with large n this is not necessary // since k << n, but with small n we want search to be // exhaustive for i := 0; i < 3*n && len(kNodes) < k; i++ { - // Get random node + // Get random nodeState idx := randomOffset(n) - node := nodes[idx] + state := nodes[idx] // Give the filter a shot at it. - if filterFn != nil && filterFn(node) { + if exclude != nil && exclude(state) { continue OUTER } // Check if we have this node already for j := 0; j < len(kNodes); j++ { - if node == kNodes[j] { + if state.Node.Name == kNodes[j].Name { continue OUTER } } // Append the node - kNodes = append(kNodes, node) + kNodes = append(kNodes, state.Node) } return kNodes } +// makeCompoundMessages takes a list of messages and packs +// them into one or multiple messages based on the limitations +// of compound messages (255 messages each). +func makeCompoundMessages(msgs [][]byte) []*bytes.Buffer { + const maxMsgs = 255 + bufs := make([]*bytes.Buffer, 0, (len(msgs)+(maxMsgs-1))/maxMsgs) + + for ; len(msgs) > maxMsgs; msgs = msgs[maxMsgs:] { + bufs = append(bufs, makeCompoundMessage(msgs[:maxMsgs])) + } + if len(msgs) > 0 { + bufs = append(bufs, makeCompoundMessage(msgs)) + } + + return bufs +} + // makeCompoundMessage takes a list of messages and generates // a single compound message containing all of them func makeCompoundMessage(msgs [][]byte) *bytes.Buffer { @@ -186,18 +202,18 @@ func decodeCompoundMessage(buf []byte) (trunc int, parts [][]byte, err error) { err = fmt.Errorf("missing compound length byte") return } - numParts := uint8(buf[0]) + numParts := int(buf[0]) buf = buf[1:] // Check we have enough bytes - if len(buf) < int(numParts*2) { + if len(buf) < numParts*2 { err = fmt.Errorf("truncated len slice") return } // Decode the lengths lengths := make([]uint16, numParts) - for i := 0; i < int(numParts); i++ { + for i := 0; i < numParts; i++ { lengths[i] = binary.BigEndian.Uint16(buf[i*2 : i*2+2]) } buf = buf[numParts*2:] @@ -205,7 +221,7 @@ func decodeCompoundMessage(buf []byte) (trunc int, parts [][]byte, err error) { // Split each message for idx, msgLen := range lengths { if len(buf) < int(msgLen) { - trunc = int(numParts) - idx + trunc = numParts - idx return }