vendor: github.com/hashicorp/memberlist v0.4.0

Signed-off-by: Bjorn Neergaard <bneergaard@mirantis.com>
This commit is contained in:
Bjorn Neergaard 2022-09-14 12:15:12 -06:00
parent cc3aa33f54
commit c2755f40cd
20 changed files with 1787 additions and 373 deletions

View file

@ -52,7 +52,7 @@ github.com/docker/go-events e31b211e4f1cd09aa76fe4ac2445
github.com/armon/go-radix e39d623f12e8e41c7b5529e9a9dd67a1e2261f80
github.com/armon/go-metrics eb0af217e5e9747e41dd5303755356b62d28e3ec
github.com/hashicorp/go-msgpack 71c2886f5a673a35f909803f38ece5810165097b
github.com/hashicorp/memberlist 3d8438da9589e7b608a83ffac1ef8211486bcb7c
github.com/hashicorp/memberlist e6ff9b2d87a3f0f3f04abb5672ada3ac2a640223 # v0.4.0
github.com/sean-/seed e2103e2c35297fb7e17febb81e49b312087a2372
github.com/hashicorp/errwrap 8a6fb523712970c966eefc6b39ed2c5e74880354 # v1.0.0
github.com/hashicorp/go-sockaddr c7188e74f6acae5a989bdc959aa779f8b9f42faf # v1.0.2

View file

@ -1,4 +1,4 @@
# memberlist [![GoDoc](https://godoc.org/github.com/hashicorp/memberlist?status.png)](https://godoc.org/github.com/hashicorp/memberlist)
# memberlist [![GoDoc](https://godoc.org/github.com/hashicorp/memberlist?status.png)](https://godoc.org/github.com/hashicorp/memberlist) [![CircleCI](https://circleci.com/gh/hashicorp/memberlist.svg?style=svg)](https://circleci.com/gh/hashicorp/memberlist)
memberlist is a [Go](http://www.golang.org) library that manages cluster
membership and member failure detection using a gossip based protocol.
@ -23,8 +23,6 @@ Please check your installation with:
go version
```
Run `make deps` to fetch dependencies before building
## Usage
Memberlist is surprisingly simple to use. An example is shown below:
@ -65,7 +63,7 @@ For complete documentation, see the associated [Godoc](http://godoc.org/github.c
## Protocol
memberlist is based on ["SWIM: Scalable Weakly-consistent Infection-style Process Group Membership Protocol"](http://www.cs.cornell.edu/~asdas/research/dsn02-swim.pdf). However, we extend the protocol in a number of ways:
memberlist is based on ["SWIM: Scalable Weakly-consistent Infection-style Process Group Membership Protocol"](http://ieeexplore.ieee.org/document/1028914/). However, we extend the protocol in a number of ways:
* Several extensions are made to increase propagation speed and
convergence rate.

View file

@ -7,8 +7,8 @@ package memberlist
// a node out and prevent it from being considered a peer
// using application specific logic.
type AliveDelegate interface {
// NotifyMerge is invoked when a merge could take place.
// Provides a list of the nodes known by the peer. If
// the return value is non-nil, the merge is canceled.
// NotifyAlive is invoked when a message about a live
// node is received from the network. Returning a non-nil
// error prevents the node from being considered a peer.
NotifyAlive(peer *Node) error
}

View file

@ -21,13 +21,17 @@ type awareness struct {
// score is the current awareness score. Lower values are healthier and
// zero is the minimum value.
score int
// metricLabels is the slice of labels to put on all emitted metrics
metricLabels []metrics.Label
}
// newAwareness returns a new awareness object.
func newAwareness(max int) *awareness {
func newAwareness(max int, metricLabels []metrics.Label) *awareness {
return &awareness{
max: max,
score: 0,
max: max,
score: 0,
metricLabels: metricLabels,
}
}
@ -47,7 +51,7 @@ func (a *awareness) ApplyDelta(delta int) {
a.Unlock()
if initial != final {
metrics.SetGauge([]string{"memberlist", "health", "score"}, float32(final))
metrics.SetGaugeWithLabels([]string{"memberlist", "health", "score"}, float32(final), a.metricLabels)
}
}

View file

@ -29,6 +29,11 @@ func (b *memberlistBroadcast) Invalidates(other Broadcast) bool {
return b.node == mb.node
}
// Name returns the broadcast's node field. Providing this method makes
// memberlistBroadcast satisfy the optional memberlist.NamedBroadcast
// interface.
func (b *memberlistBroadcast) Name() string {
return b.node
}
func (b *memberlistBroadcast) Message() []byte {
return b.msg
}

View file

@ -1,10 +1,16 @@
package memberlist
import (
"fmt"
"io"
"log"
"net"
"os"
"strings"
"time"
"github.com/armon/go-metrics"
multierror "github.com/hashicorp/go-multierror"
)
type Config struct {
@ -16,6 +22,17 @@ type Config struct {
// make a NetTransport using BindAddr and BindPort from this structure.
Transport Transport
// Label is an optional set of bytes to include on the outside of each
// packet and stream.
//
// If gossip encryption is enabled and this is set it is treated as GCM
// authenticated data.
Label string
// SkipInboundLabelCheck skips the check that inbound packets and gossip
// streams need to be label prefixed.
SkipInboundLabelCheck bool
// Configuration related to what address to bind to and ports to
// listen on. The port is used for both UDP and TCP gossip. It is
// assumed other nodes are running on this port, but they do not need
@ -116,6 +133,10 @@ type Config struct {
// indirect UDP pings.
DisableTcpPings bool
// DisableTcpPingsForNode is like DisableTcpPings, but lets you control
// whether to perform TCP pings on a node-by-node basis.
DisableTcpPingsForNode func(nodeName string) bool
// AwarenessMaxMultiplier will increase the probe interval if the node
// becomes aware that it might be degraded and not meeting the soft real
// time requirements to reliably probe other nodes.
@ -215,6 +236,48 @@ type Config struct {
// This is a legacy name for backward compatibility but should really be
// called PacketBufferSize now that we have generalized the transport.
UDPBufferSize int
// DeadNodeReclaimTime controls the time before a dead node's name can be
// reclaimed by one with a different address or port. By default, this is 0,
// meaning nodes cannot be reclaimed this way.
DeadNodeReclaimTime time.Duration
// RequireNodeNames controls if the name of a node is required when sending
// a message to that node.
RequireNodeNames bool
// CIDRsAllowed If nil, allow any connection (default), otherwise specify all networks
// allowed to connect (you must specify IPv6/IPv4 separately)
// Using [] will block all connections.
CIDRsAllowed []net.IPNet
// MetricLabels is an optional slice of labels to apply to all metrics emitted.
MetricLabels []metrics.Label
}
// ParseCIDRs parses every entry of v as a CIDR and returns the networks that
// parsed successfully. The returned slice is never nil, even for nil input.
// Surrounding whitespace in an entry is ignored. Every entry that fails to
// parse contributes an "invalid cidr" error; the errors are accumulated with
// multierror.Append and returned alongside the successfully parsed networks.
// The error is nil when every entry parsed.
func ParseCIDRs(v []string) ([]net.IPNet, error) {
	var errs error
	parsed := make([]net.IPNet, 0)
	// Ranging over a nil slice is a no-op, so nil input needs no special case.
	for _, entry := range v {
		_, ipNet, parseErr := net.ParseCIDR(strings.TrimSpace(entry))
		if parseErr != nil {
			errs = multierror.Append(errs, fmt.Errorf("invalid cidr: %s", entry))
			continue
		}
		parsed = append(parsed, *ipNet)
	}
	return parsed, errs
}
// DefaultLANConfig returns a sane set of configurations for Memberlist.
@ -258,6 +321,7 @@ func DefaultLANConfig() *Config {
HandoffQueueDepth: 1024,
UDPBufferSize: 1400,
CIDRsAllowed: nil, // same as allow all
}
}
@ -277,6 +341,24 @@ func DefaultWANConfig() *Config {
return conf
}
// IPMustBeChecked reports whether IPAllowed needs to be consulted, i.e.
// whether CIDRsAllowed contains at least one network.
//
// NOTE(review): an empty but non-nil CIDRsAllowed also returns false here, so
// IPAllowed accepts every address in that case. The comment on the
// CIDRsAllowed field says an empty list blocks all connections; confirm which
// behavior is intended.
func (c *Config) IPMustBeChecked() bool {
return len(c.CIDRsAllowed) > 0
}
// IPAllowed returns nil when ip may access memberlist and an error otherwise.
// When no check is required (see IPMustBeChecked) every address is accepted;
// otherwise ip must be contained in at least one network of CIDRsAllowed.
func (c *Config) IPAllowed(ip net.IP) error {
	if c.IPMustBeChecked() {
		for i := range c.CIDRsAllowed {
			if c.CIDRsAllowed[i].Contains(ip) {
				return nil
			}
		}
		return fmt.Errorf("%s is not allowed", ip)
	}
	return nil
}
// DefaultLocalConfig works like DefaultConfig, however it returns a configuration
// that is optimized for a local loopback environments. The default configuration is
// still very conservative and errs on the side of caution.

View file

@ -49,13 +49,16 @@ type NodeEvent struct {
}
func (c *ChannelEventDelegate) NotifyJoin(n *Node) {
c.Ch <- NodeEvent{NodeJoin, n}
node := *n
c.Ch <- NodeEvent{NodeJoin, &node}
}
func (c *ChannelEventDelegate) NotifyLeave(n *Node) {
c.Ch <- NodeEvent{NodeLeave, n}
node := *n
c.Ch <- NodeEvent{NodeLeave, &node}
}
func (c *ChannelEventDelegate) NotifyUpdate(n *Node) {
c.Ch <- NodeEvent{NodeUpdate, n}
node := *n
c.Ch <- NodeEvent{NodeUpdate, &node}
}

19
vendor/github.com/hashicorp/memberlist/go.mod generated vendored Normal file
View file

@ -0,0 +1,19 @@
module github.com/hashicorp/memberlist
go 1.12
require (
github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c
github.com/hashicorp/go-immutable-radix v1.0.0 // indirect
github.com/hashicorp/go-msgpack v0.5.3
github.com/hashicorp/go-multierror v1.0.0
github.com/hashicorp/go-sockaddr v1.0.0
github.com/miekg/dns v1.1.26
github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529
github.com/stretchr/testify v1.2.2
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 // indirect
)

178
vendor/github.com/hashicorp/memberlist/label.go generated vendored Normal file
View file

@ -0,0 +1,178 @@
package memberlist
import (
"bufio"
"fmt"
"io"
"net"
)
// General approach is to prefix all packets and streams with the same structure:
//
// magic type byte (244): uint8
// length of label name: uint8 (because labels can't be longer than 255 bytes)
// label name: []uint8
// LabelMaxSize is the maximum length of a packet or stream label.
const LabelMaxSize = 255
// AddLabelHeaderToPacket returns buf prefixed with a label header. An empty
// label leaves the packet untouched and returns buf itself; a label longer
// than LabelMaxSize is rejected with an error.
func AddLabelHeaderToPacket(buf []byte, label string) ([]byte, error) {
	switch {
	case label == "":
		return buf, nil
	case len(label) > LabelMaxSize:
		return nil, fmt.Errorf("label %q is too long", label)
	default:
		return makeLabelHeader(label, buf), nil
	}
}
// RemoveLabelHeaderFromPacket strips a leading label header from buf, if one
// is present, and returns the remaining packet bytes together with the label.
//
// The header layout is [type:byte] [size:byte] [size bytes of label]. A packet
// that is empty or does not start with the label marker is returned unchanged
// with an empty label. A packet that starts with the marker but is truncated,
// or that declares a zero-length label, yields an error.
func RemoveLabelHeaderFromPacket(buf []byte) (newBuf []byte, label string, err error) {
	// Empty packets and packets without the marker byte cannot be labeled.
	if len(buf) == 0 || messageType(buf[0]) != hasLabelMsg {
		return buf, "", nil
	}

	if len(buf) < 2 {
		return nil, "", fmt.Errorf("cannot decode label; packet has been truncated")
	}

	labelLen := int(buf[1])
	if labelLen < 1 {
		return nil, "", fmt.Errorf("label header cannot be empty when present")
	}

	// end is the index just past the label, i.e. where the payload begins.
	end := 2 + labelLen
	if len(buf) < end {
		return nil, "", fmt.Errorf("cannot decode label; packet has been truncated")
	}

	return buf[end:], string(buf[2:end]), nil
}
// AddLabelHeaderToStream writes a label header to conn before any other
// stream data. Nothing is written for an empty label, and a label longer than
// LabelMaxSize is rejected before anything is written.
func AddLabelHeaderToStream(conn net.Conn, label string) error {
	switch {
	case label == "":
		return nil
	case len(label) > LabelMaxSize:
		return fmt.Errorf("label %q is too long", label)
	}

	// Only the error from the write matters; the byte count is discarded.
	_, writeErr := conn.Write(makeLabelHeader(label, nil))
	return writeErr
}
// RemoveLabelHeaderFromStream removes any label header from the beginning of
// the stream if present and returns it along with an updated conn with that
// header removed.
//
// Return values:
//   - an empty stream (EOF before the first byte) returns the original conn
//     and an empty label;
//   - an unlabeled stream returns a conn that replays the bytes already
//     buffered, and an empty label;
//   - a labeled stream returns a conn positioned just past the header, and
//     the label;
//   - on any error the returned conn is a nil interface value.
//
// Note that on error it is the caller's responsibility to close the
// connection.
func RemoveLabelHeaderFromStream(conn net.Conn) (net.Conn, string, error) {
	br := bufio.NewReader(conn)

	// First check for the type byte.
	peeked, err := br.Peek(1)
	if err != nil {
		if err == io.EOF {
			// It is safe to return the original net.Conn at this point because
			// it never contained any data in the first place so we don't have
			// to splice the buffer into the conn because both are empty.
			return conn, "", nil
		}
		return nil, "", err
	}

	if messageType(peeked[0]) != hasLabelMsg {
		// Unlabeled stream: hand back everything that was read ahead.
		// The *peekedConn is checked before being converted to net.Conn so
		// that a failure yields a true nil interface rather than an
		// interface wrapping a nil pointer.
		pc, err := newPeekedConnFromBufferedReader(conn, br, 0)
		if err != nil {
			return nil, "", err
		}
		return pc, "", nil
	}

	// A labeled stream must carry a size byte right after the type byte.
	peeked, err = br.Peek(2)
	if err != nil {
		if err == io.EOF {
			return nil, "", fmt.Errorf("cannot decode label; stream has been truncated")
		}
		return nil, "", err
	}

	size := int(peeked[1])
	if size < 1 {
		return nil, "", fmt.Errorf("label header cannot be empty when present")
	}
	// NOTE: we don't have to check this against LabelMaxSize because a byte
	// already has a max value of 255.

	// Once we know the size we can peek the label as well. Note that since we
	// are using the default bufio.Reader size of 4096, the entire label header
	// fits in the initial buffer fill so this should be free.
	peeked, err = br.Peek(2 + size)
	if err != nil {
		if err == io.EOF {
			return nil, "", fmt.Errorf("cannot decode label; stream has been truncated")
		}
		return nil, "", err
	}

	label := string(peeked[2 : 2+size])

	pc, err := newPeekedConnFromBufferedReader(conn, br, 2+size)
	if err != nil {
		return nil, "", err
	}

	return pc, label, nil
}
// newPeekedConnFromBufferedReader will splice the buffer contents after the
// offset into the provided net.Conn and return the result so that the rest of
// the buffer contents are returned first when reading from the returned
// peekedConn before moving on to the unbuffered conn contents.
//
// offset is the number of already-buffered bytes to drop from the front (the
// callers in this file pass 0 or the label header length). It must not exceed
// br.Buffered(), otherwise the slice expression below panics.
//
// On a Peek error the returned *peekedConn is nil.
func newPeekedConnFromBufferedReader(conn net.Conn, br *bufio.Reader, offset int) (*peekedConn, error) {
// Extract any of the readahead buffer.
// Peeking exactly br.Buffered() bytes returns what is already in memory.
peeked, err := br.Peek(br.Buffered())
if err != nil {
return nil, err
}
// NOTE(review): peeked comes straight from br.Peek and is not copied here, so
// Peeked presumably aliases the reader's internal buffer; br must not be read
// from again after this call. Confirm against the bufio documentation.
return &peekedConn{
Peeked: peeked[offset:],
Conn: conn,
}, nil
}
// makeLabelHeader builds a new buffer holding the label header followed by
// rest: the hasLabelMsg marker byte, one byte with the label length, the label
// bytes, and finally the contents of rest (which may be nil or empty). The
// label length is not validated here; the callers in this file reject labels
// longer than LabelMaxSize before calling.
func makeLabelHeader(label string, rest []byte) []byte {
	header := make([]byte, 0, 2+len(label)+len(rest))
	header = append(header, byte(hasLabelMsg), byte(len(label)))
	header = append(header, label...)
	header = append(header, rest...)
	return header
}
// labelOverhead returns the number of bytes a label header adds to a packet
// or stream: zero for an empty label, otherwise two bytes (type and size)
// plus the label itself.
func labelOverhead(label string) int {
	if n := len(label); n > 0 {
		return 2 + n
	}
	return 0
}

View file

@ -13,6 +13,14 @@ func LogAddress(addr net.Addr) string {
return fmt.Sprintf("from=%s", addr.String())
}
// LogStringAddress formats addr as a "from=<addr>" fragment for log messages.
// An empty addr is rendered as "<unknown address>".
func LogStringAddress(addr string) string {
	shown := addr
	if shown == "" {
		shown = "<unknown address>"
	}
	return fmt.Sprintf("from=%s", shown)
}
func LogConn(conn net.Conn) string {
if conn == nil {
return LogAddress(nil)

View file

@ -15,6 +15,8 @@ multiple routes.
package memberlist
import (
"container/list"
"errors"
"fmt"
"log"
"net"
@ -25,15 +27,23 @@ import (
"sync/atomic"
"time"
"github.com/armon/go-metrics"
multierror "github.com/hashicorp/go-multierror"
sockaddr "github.com/hashicorp/go-sockaddr"
"github.com/miekg/dns"
)
var errNodeNamesAreRequired = errors.New("memberlist: node names are required by configuration but one was not provided")
type Memberlist struct {
sequenceNum uint32 // Local sequence number
incarnation uint32 // Local incarnation number
numNodes uint32 // Number of known nodes (estimate)
pushPullReq uint32 // Number of push/pull requests
advertiseLock sync.RWMutex
advertiseAddr net.IP
advertisePort uint16
config *Config
shutdown int32 // Used as an atomic boolean value
@ -44,13 +54,17 @@ type Memberlist struct {
shutdownLock sync.Mutex // Serializes calls to Shutdown
leaveLock sync.Mutex // Serializes calls to Leave
transport Transport
handoff chan msgHandoff
transport NodeAwareTransport
handoffCh chan struct{}
highPriorityMsgQueue *list.List
lowPriorityMsgQueue *list.List
msgQueueLock sync.Mutex
nodeLock sync.RWMutex
nodes []*nodeState // Known nodes
nodeMap map[string]*nodeState // Maps Addr.String() -> NodeState
nodeTimers map[string]*suspicion // Maps Addr.String() -> suspicion timer
nodeMap map[string]*nodeState // Maps Node.Name -> NodeState
nodeTimers map[string]*suspicion // Maps Node.Name -> suspicion timer
awareness *awareness
tickerLock sync.Mutex
@ -64,6 +78,18 @@ type Memberlist struct {
broadcasts *TransmitLimitedQueue
logger *log.Logger
// metricLabels is the slice of labels to put on all emitted metrics
metricLabels []metrics.Label
}
// BuildVsnArray returns the six-element version vector built from this
// configuration, in this order: ProtocolVersionMin, ProtocolVersionMax, the
// configured ProtocolVersion, DelegateProtocolMin, DelegateProtocolMax and
// DelegateProtocolVersion.
func (c *Config) BuildVsnArray() []uint8 {
	vsn := make([]uint8, 0, 6)
	vsn = append(vsn, ProtocolVersionMin, ProtocolVersionMax, c.ProtocolVersion)
	vsn = append(vsn, c.DelegateProtocolMin, c.DelegateProtocolMax, c.DelegateProtocolVersion)
	return vsn
}
// newMemberlist creates the network listeners.
@ -113,9 +139,10 @@ func newMemberlist(conf *Config) (*Memberlist, error) {
transport := conf.Transport
if transport == nil {
nc := &NetTransportConfig{
BindAddrs: []string{conf.BindAddr},
BindPort: conf.BindPort,
Logger: logger,
BindAddrs: []string{conf.BindAddr},
BindPort: conf.BindPort,
Logger: logger,
MetricLabels: conf.MetricLabels,
}
// See comment below for details about the retry in here.
@ -159,22 +186,50 @@ func newMemberlist(conf *Config) (*Memberlist, error) {
transport = nt
}
nodeAwareTransport, ok := transport.(NodeAwareTransport)
if !ok {
logger.Printf("[DEBUG] memberlist: configured Transport is not a NodeAwareTransport and some features may not work as desired")
nodeAwareTransport = &shimNodeAwareTransport{transport}
}
if len(conf.Label) > LabelMaxSize {
return nil, fmt.Errorf("could not use %q as a label: too long", conf.Label)
}
if conf.Label != "" {
nodeAwareTransport = &labelWrappedTransport{
label: conf.Label,
NodeAwareTransport: nodeAwareTransport,
}
}
m := &Memberlist{
config: conf,
shutdownCh: make(chan struct{}),
leaveBroadcast: make(chan struct{}, 1),
transport: transport,
handoff: make(chan msgHandoff, conf.HandoffQueueDepth),
nodeMap: make(map[string]*nodeState),
nodeTimers: make(map[string]*suspicion),
awareness: newAwareness(conf.AwarenessMaxMultiplier),
ackHandlers: make(map[uint32]*ackHandler),
broadcasts: &TransmitLimitedQueue{RetransmitMult: conf.RetransmitMult},
logger: logger,
config: conf,
shutdownCh: make(chan struct{}),
leaveBroadcast: make(chan struct{}, 1),
transport: nodeAwareTransport,
handoffCh: make(chan struct{}, 1),
highPriorityMsgQueue: list.New(),
lowPriorityMsgQueue: list.New(),
nodeMap: make(map[string]*nodeState),
nodeTimers: make(map[string]*suspicion),
awareness: newAwareness(conf.AwarenessMaxMultiplier, conf.MetricLabels),
ackHandlers: make(map[uint32]*ackHandler),
broadcasts: &TransmitLimitedQueue{RetransmitMult: conf.RetransmitMult},
logger: logger,
metricLabels: conf.MetricLabels,
}
m.broadcasts.NumNodes = func() int {
return m.estNumNodes()
}
// Get the final advertise address from the transport, which may need
// to see which address we bound to. We'll refresh this each time we
// send out an alive message.
if _, _, err := m.refreshAdvertise(); err != nil {
return nil, err
}
go m.streamListen()
go m.packetListen()
go m.packetHandler()
@ -222,8 +277,9 @@ func (m *Memberlist) Join(existing []string) (int, error) {
for _, addr := range addrs {
hp := joinHostPort(addr.ip.String(), addr.port)
if err := m.pushPullNode(hp, true); err != nil {
err = fmt.Errorf("Failed to join %s: %v", addr.ip, err)
a := Address{Addr: hp, Name: addr.nodeName}
if err := m.pushPullNode(a, true); err != nil {
err = fmt.Errorf("Failed to join %s: %v", a.Addr, err)
errs = multierror.Append(errs, err)
m.logger.Printf("[DEBUG] memberlist: %v", err)
continue
@ -240,8 +296,9 @@ func (m *Memberlist) Join(existing []string) (int, error) {
// ipPort holds information about a node we want to try to join.
type ipPort struct {
ip net.IP
port uint16
ip net.IP
port uint16
nodeName string // optional
}
// tcpLookupIP is a helper to initiate a TCP-based DNS lookup for the given host.
@ -250,7 +307,7 @@ type ipPort struct {
// Consul's. By doing the TCP lookup directly, we get the best chance for the
// largest list of hosts to join. Since joins are relatively rare events, it's ok
// to do this rather expensive operation.
func (m *Memberlist) tcpLookupIP(host string, defaultPort uint16) ([]ipPort, error) {
func (m *Memberlist) tcpLookupIP(host string, defaultPort uint16, nodeName string) ([]ipPort, error) {
// Don't attempt any TCP lookups against non-fully qualified domain
// names, since those will likely come from the resolv.conf file.
if !strings.Contains(host, ".") {
@ -292,9 +349,9 @@ func (m *Memberlist) tcpLookupIP(host string, defaultPort uint16) ([]ipPort, err
for _, r := range in.Answer {
switch rr := r.(type) {
case (*dns.A):
ips = append(ips, ipPort{rr.A, defaultPort})
ips = append(ips, ipPort{ip: rr.A, port: defaultPort, nodeName: nodeName})
case (*dns.AAAA):
ips = append(ips, ipPort{rr.AAAA, defaultPort})
ips = append(ips, ipPort{ip: rr.AAAA, port: defaultPort, nodeName: nodeName})
case (*dns.CNAME):
m.logger.Printf("[DEBUG] memberlist: Ignoring CNAME RR in TCP-first answer for '%s'", host)
}
@ -308,6 +365,16 @@ func (m *Memberlist) tcpLookupIP(host string, defaultPort uint16) ([]ipPort, err
// resolveAddr is used to resolve the address into an address,
// port, and error. If no port is given, use the default
func (m *Memberlist) resolveAddr(hostStr string) ([]ipPort, error) {
// First peel off any leading node name. This is optional.
nodeName := ""
if slashIdx := strings.Index(hostStr, "/"); slashIdx >= 0 {
if slashIdx == 0 {
return nil, fmt.Errorf("empty node name provided")
}
nodeName = hostStr[0:slashIdx]
hostStr = hostStr[slashIdx+1:]
}
// This captures the supplied port, or the default one.
hostStr = ensurePort(hostStr, m.config.BindPort)
host, sport, err := net.SplitHostPort(hostStr)
@ -324,13 +391,15 @@ func (m *Memberlist) resolveAddr(hostStr string) ([]ipPort, error) {
// will make sure the host part is in good shape for parsing, even for
// IPv6 addresses.
if ip := net.ParseIP(host); ip != nil {
return []ipPort{ipPort{ip, port}}, nil
return []ipPort{
ipPort{ip: ip, port: port, nodeName: nodeName},
}, nil
}
// First try TCP so we have the best chance for the largest list of
// hosts to join. If this fails it's not fatal since this isn't a standard
// way to query DNS, and we have a fallback below.
ips, err := m.tcpLookupIP(host, port)
ips, err := m.tcpLookupIP(host, port, nodeName)
if err != nil {
m.logger.Printf("[DEBUG] memberlist: TCP-first lookup failed for '%s', falling back to UDP: %s", hostStr, err)
}
@ -347,7 +416,7 @@ func (m *Memberlist) resolveAddr(hostStr string) ([]ipPort, error) {
}
ips = make([]ipPort, 0, len(ans))
for _, ip := range ans {
ips = append(ips, ipPort{ip, port})
ips = append(ips, ipPort{ip: ip, port: port, nodeName: nodeName})
}
return ips, nil
}
@ -358,10 +427,9 @@ func (m *Memberlist) resolveAddr(hostStr string) ([]ipPort, error) {
func (m *Memberlist) setAlive() error {
// Get the final advertise address from the transport, which may need
// to see which address we bound to.
addr, port, err := m.transport.FinalAdvertiseAddr(
m.config.AdvertiseAddr, m.config.AdvertisePort)
addr, port, err := m.refreshAdvertise()
if err != nil {
return fmt.Errorf("Failed to get final advertise address: %v", err)
return err
}
// Check if this is a public address without encryption
@ -394,16 +462,36 @@ func (m *Memberlist) setAlive() error {
Addr: addr,
Port: uint16(port),
Meta: meta,
Vsn: []uint8{
ProtocolVersionMin, ProtocolVersionMax, m.config.ProtocolVersion,
m.config.DelegateProtocolMin, m.config.DelegateProtocolMax,
m.config.DelegateProtocolVersion,
},
Vsn: m.config.BuildVsnArray(),
}
m.aliveNode(&a, nil, true)
return nil
}
// getAdvertise returns the currently stored advertise address and port,
// reading both under the advertise read lock.
func (m *Memberlist) getAdvertise() (net.IP, uint16) {
	m.advertiseLock.RLock()
	addr, port := m.advertiseAddr, m.advertisePort
	m.advertiseLock.RUnlock()
	return addr, port
}
// setAdvertise stores addr and port as the advertise address under the
// advertise write lock. port is converted to uint16, so values outside the
// 0-65535 range are truncated.
func (m *Memberlist) setAdvertise(addr net.IP, port int) {
	m.advertiseLock.Lock()
	m.advertiseAddr, m.advertisePort = addr, uint16(port)
	m.advertiseLock.Unlock()
}
// refreshAdvertise asks the transport for the final advertise address, using
// the configured AdvertiseAddr and AdvertisePort, stores the result via
// setAdvertise and returns it. If the transport reports an error, nothing is
// stored and a wrapped error is returned.
func (m *Memberlist) refreshAdvertise() (net.IP, int, error) {
	cfg := m.config
	ip, port, err := m.transport.FinalAdvertiseAddr(cfg.AdvertiseAddr, cfg.AdvertisePort)
	if err == nil {
		m.setAdvertise(ip, port)
		return ip, port, nil
	}
	return nil, 0, fmt.Errorf("Failed to get final advertise address: %v", err)
}
// LocalNode is used to return the local Node
func (m *Memberlist) LocalNode() *Node {
m.nodeLock.RLock()
@ -439,11 +527,7 @@ func (m *Memberlist) UpdateNode(timeout time.Duration) error {
Addr: state.Addr,
Port: state.Port,
Meta: meta,
Vsn: []uint8{
ProtocolVersionMin, ProtocolVersionMax, m.config.ProtocolVersion,
m.config.DelegateProtocolMin, m.config.DelegateProtocolMax,
m.config.DelegateProtocolVersion,
},
Vsn: m.config.BuildVsnArray(),
}
notifyCh := make(chan struct{})
m.aliveNode(&a, notifyCh, true)
@ -463,24 +547,29 @@ func (m *Memberlist) UpdateNode(timeout time.Duration) error {
return nil
}
// SendTo is deprecated in favor of SendBestEffort, which requires a node to
// target.
// Deprecated: SendTo is deprecated in favor of SendBestEffort, which requires a node to
// target. If you don't have a node then use SendToAddress.
func (m *Memberlist) SendTo(to net.Addr, msg []byte) error {
a := Address{Addr: to.String(), Name: ""}
return m.SendToAddress(a, msg)
}
func (m *Memberlist) SendToAddress(a Address, msg []byte) error {
// Encode as a user message
buf := make([]byte, 1, len(msg)+1)
buf[0] = byte(userMsg)
buf = append(buf, msg...)
// Send the message
return m.rawSendMsgPacket(to.String(), nil, buf)
return m.rawSendMsgPacket(a, nil, buf)
}
// SendToUDP is deprecated in favor of SendBestEffort.
// Deprecated: SendToUDP is deprecated in favor of SendBestEffort.
func (m *Memberlist) SendToUDP(to *Node, msg []byte) error {
return m.SendBestEffort(to, msg)
}
// SendToTCP is deprecated in favor of SendReliable.
// Deprecated: SendToTCP is deprecated in favor of SendReliable.
func (m *Memberlist) SendToTCP(to *Node, msg []byte) error {
return m.SendReliable(to, msg)
}
@ -496,7 +585,8 @@ func (m *Memberlist) SendBestEffort(to *Node, msg []byte) error {
buf = append(buf, msg...)
// Send the message
return m.rawSendMsgPacket(to.Address(), to, buf)
a := Address{Addr: to.Address(), Name: to.Name}
return m.rawSendMsgPacket(a, to, buf)
}
// SendReliable uses the reliable stream-oriented interface of the transport to
@ -504,7 +594,7 @@ func (m *Memberlist) SendBestEffort(to *Node, msg []byte) error {
// mechanism). Delivery is guaranteed if no error is returned, and there is no
// limit on the size of the message.
func (m *Memberlist) SendReliable(to *Node, msg []byte) error {
return m.sendUserMsg(to.Address(), msg)
return m.sendUserMsg(to.FullAddress(), msg)
}
// Members returns a list of all known live nodes. The node structures
@ -516,7 +606,7 @@ func (m *Memberlist) Members() []*Node {
nodes := make([]*Node, 0, len(m.nodes))
for _, n := range m.nodes {
if n.State != stateDead {
if !n.DeadOrLeft() {
nodes = append(nodes, &n.Node)
}
}
@ -533,7 +623,7 @@ func (m *Memberlist) NumMembers() (alive int) {
defer m.nodeLock.RUnlock()
for _, n := range m.nodes {
if n.State != stateDead {
if !n.DeadOrLeft() {
alive++
}
}
@ -570,9 +660,14 @@ func (m *Memberlist) Leave(timeout time.Duration) error {
return nil
}
// This dead message is special, because Node and From are the
// same. This helps other nodes figure out that a node left
// intentionally. When Node equals From, other nodes know for
// sure this node is gone.
d := dead{
Incarnation: state.Incarnation,
Node: state.Name,
From: state.Name,
}
m.deadNode(&d)
@ -598,7 +693,7 @@ func (m *Memberlist) anyAlive() bool {
m.nodeLock.RLock()
defer m.nodeLock.RUnlock()
for _, n := range m.nodes {
if n.State != stateDead && n.Name != m.config.Name {
if !n.DeadOrLeft() && n.Name != m.config.Name {
return true
}
}
@ -621,7 +716,7 @@ func (m *Memberlist) ProtocolVersion() uint8 {
return m.config.ProtocolVersion
}
// Shutdown will stop any background maintanence of network activity
// Shutdown will stop any background maintenance of network activity
// for this memberlist, causing it to appear "dead". A leave message
// will not be broadcasted prior, so the cluster being left will have
// to detect this node's shutdown using probing. If you wish to more
@ -657,3 +752,27 @@ func (m *Memberlist) hasShutdown() bool {
func (m *Memberlist) hasLeft() bool {
return atomic.LoadInt32(&m.leave) == 1
}
// getNodeState returns the State of the nodeMap entry stored under addr,
// holding the node read lock for the lookup.
//
// NOTE(review): the parameter is named addr, but the Memberlist struct
// documents nodeMap as keyed by node name; confirm what callers pass.
func (m *Memberlist) getNodeState(addr string) NodeStateType {
m.nodeLock.RLock()
defer m.nodeLock.RUnlock()
// A missing key yields a nil *nodeState, so n.State below panics for an
// unknown key.
n := m.nodeMap[addr]
return n.State
}
// getNodeStateChange returns the StateChange time of the nodeMap entry stored
// under addr, holding the node read lock for the lookup.
//
// NOTE(review): the parameter is named addr, but the Memberlist struct
// documents nodeMap as keyed by node name; confirm what callers pass.
func (m *Memberlist) getNodeStateChange(addr string) time.Time {
m.nodeLock.RLock()
defer m.nodeLock.RUnlock()
// A missing key yields a nil *nodeState, so n.StateChange below panics for an
// unknown key.
n := m.nodeMap[addr]
return n.StateChange
}
// changeNode looks up the nodeMap entry stored under addr and passes it to f
// while holding the node write lock, so f can mutate the node state.
//
// f runs with nodeLock held. It receives nil when no entry exists for addr,
// because a missing map key yields a nil *nodeState.
//
// NOTE(review): the parameter is named addr, but the Memberlist struct
// documents nodeMap as keyed by node name; confirm what callers pass.
func (m *Memberlist) changeNode(addr string, f func(*nodeState)) {
m.nodeLock.Lock()
defer m.nodeLock.Unlock()
n := m.nodeMap[addr]
f(n)
}

View file

@ -1,7 +1,9 @@
package memberlist
import (
"bytes"
"fmt"
"io"
"net"
"strconv"
"time"
@ -10,26 +12,33 @@ import (
// MockNetwork is used as a factory that produces MockTransport instances which
// are uniquely addressed and wired up to talk to each other.
type MockNetwork struct {
transports map[string]*MockTransport
port int
transportsByAddr map[string]*MockTransport
transportsByName map[string]*MockTransport
port int
}
// NewTransport returns a new MockTransport with a unique address, wired up to
// talk to the other transports in the MockNetwork.
func (n *MockNetwork) NewTransport() *MockTransport {
func (n *MockNetwork) NewTransport(name string) *MockTransport {
n.port += 1
addr := fmt.Sprintf("127.0.0.1:%d", n.port)
transport := &MockTransport{
net: n,
addr: &MockAddress{addr},
addr: &MockAddress{addr, name},
packetCh: make(chan *Packet),
streamCh: make(chan net.Conn),
}
if n.transports == nil {
n.transports = make(map[string]*MockTransport)
if n.transportsByAddr == nil {
n.transportsByAddr = make(map[string]*MockTransport)
}
n.transports[addr] = transport
n.transportsByAddr[addr] = transport
if n.transportsByName == nil {
n.transportsByName = make(map[string]*MockTransport)
}
n.transportsByName[name] = transport
return transport
}
@ -37,6 +46,7 @@ func (n *MockNetwork) NewTransport() *MockTransport {
// address scheme.
type MockAddress struct {
addr string
name string
}
// See net.Addr.
@ -57,6 +67,8 @@ type MockTransport struct {
streamCh chan net.Conn
}
var _ NodeAwareTransport = (*MockTransport)(nil)
// See Transport.
func (t *MockTransport) FinalAdvertiseAddr(string, int) (net.IP, int, error) {
host, portStr, err := net.SplitHostPort(t.addr.String())
@ -79,9 +91,15 @@ func (t *MockTransport) FinalAdvertiseAddr(string, int) (net.IP, int, error) {
// See Transport.
func (t *MockTransport) WriteTo(b []byte, addr string) (time.Time, error) {
dest, ok := t.net.transports[addr]
if !ok {
return time.Time{}, fmt.Errorf("No route to %q", addr)
a := Address{Addr: addr, Name: ""}
return t.WriteToAddress(b, a)
}
// See NodeAwareTransport.
func (t *MockTransport) WriteToAddress(b []byte, a Address) (time.Time, error) {
dest, err := t.getPeer(a)
if err != nil {
return time.Time{}, err
}
now := time.Now()
@ -98,11 +116,45 @@ func (t *MockTransport) PacketCh() <-chan *Packet {
return t.packetCh
}
// IngestPacket reads conn until EOF and delivers everything read as a single
// Packet on the transport's packet channel, stamped with addr and now. When
// shouldClose is true, conn is closed before returning. An error is returned
// if reading fails or if no bytes were read. See NodeAwareTransport.
func (t *MockTransport) IngestPacket(conn net.Conn, addr net.Addr, now time.Time, shouldClose bool) error {
	if shouldClose {
		defer conn.Close()
	}

	// Drain the whole stream into memory; it becomes the packet payload.
	var payload bytes.Buffer
	_, err := io.Copy(&payload, conn)
	if err != nil {
		return fmt.Errorf("failed to read packet: %v", err)
	}

	// A proper message needs at least one byte. This is checked elsewhere
	// for writes coming in directly from the UDP socket.
	size := payload.Len()
	if size < 1 {
		return fmt.Errorf("packet too short (%d bytes) %s", size, LogAddress(addr))
	}

	// Inject the packet; this blocks until a receiver takes it.
	pkt := &Packet{
		Buf:       payload.Bytes(),
		From:      addr,
		Timestamp: now,
	}
	t.packetCh <- pkt
	return nil
}
// See Transport.
func (t *MockTransport) DialTimeout(addr string, timeout time.Duration) (net.Conn, error) {
dest, ok := t.net.transports[addr]
if !ok {
return nil, fmt.Errorf("No route to %q", addr)
a := Address{Addr: addr, Name: ""}
return t.DialAddressTimeout(a, timeout)
}
// See NodeAwareTransport.
func (t *MockTransport) DialAddressTimeout(a Address, timeout time.Duration) (net.Conn, error) {
dest, err := t.getPeer(a)
if err != nil {
return nil, err
}
p1, p2 := net.Pipe()
@ -115,7 +167,29 @@ func (t *MockTransport) StreamCh() <-chan net.Conn {
return t.streamCh
}
// IngestStream hands conn to the transport's stream channel and always
// returns nil. The send blocks until a receiver takes the connection.
// See NodeAwareTransport.
func (t *MockTransport) IngestStream(conn net.Conn) error {
t.streamCh <- conn
return nil
}
// See Transport.
func (t *MockTransport) Shutdown() error {
return nil
}
// getPeer resolves a to the MockTransport registered for it in the mock
// network. When a.Name is set the lookup is by name only; otherwise it is by
// a.Addr. An error is returned when no transport is registered under the
// chosen key.
func (t *MockTransport) getPeer(a Address) (*MockTransport, error) {
	// Choose the registry and key: name wins over address when present.
	registry, key := t.net.transportsByAddr, a.Addr
	if a.Name != "" {
		registry, key = t.net.transportsByName, a.Name
	}

	peer, found := registry[key]
	if !found {
		return nil, fmt.Errorf("No route to %s", a)
	}
	return peer, nil
}

View file

@ -7,10 +7,12 @@ import (
"fmt"
"hash/crc32"
"io"
"math"
"net"
"sync/atomic"
"time"
"github.com/armon/go-metrics"
metrics "github.com/armon/go-metrics"
"github.com/hashicorp/go-msgpack/codec"
)
@ -41,6 +43,9 @@ const (
type messageType uint8
// The list of available message types.
//
// WARNING: ONLY APPEND TO THIS LIST! The numeric values are part of the
// protocol itself.
const (
pingMsg messageType = iota
indirectPingMsg
@ -58,6 +63,13 @@ const (
errMsg
)
const (
// hasLabelMsg is deliberately given a high value so that it can be told
// apart from the encryptionVersion header (which is currently either 0 or
// 1) and from every one of the existing messageTypes.
hasLabelMsg messageType = 244
)
// compressionType is used to specify the compression algorithm
type compressionType uint8
@ -71,7 +83,8 @@ const (
compoundOverhead = 2 // Assumed overhead per entry in compoundHeader
userMsgOverhead = 1
blockingWarning = 10 * time.Millisecond // Warn if a UDP packet takes this long to process
maxPushStateBytes = 10 * 1024 * 1024
maxPushStateBytes = 20 * 1024 * 1024
maxPushPullRequests = 128 // Maximum number of concurrent push/pull requests
)
// ping request sent directly to node
@ -82,15 +95,28 @@ type ping struct {
// the intended recipient. This is to protect against an agent
// restart with a new name.
Node string
SourceAddr []byte `codec:",omitempty"` // Source address, used for a direct reply
SourcePort uint16 `codec:",omitempty"` // Source port, used for a direct reply
SourceNode string `codec:",omitempty"` // Source name, used for a direct reply
}
// indirect ping sent to an indirect ndoe
// indirect ping sent to an indirect node
type indirectPingReq struct {
SeqNo uint32
Target []byte
Port uint16
Node string
Nack bool // true if we'd like a nack back
// Node is sent so the target can verify they are
// the intended recipient. This is to protect against an agent
// restart with a new name.
Node string
Nack bool // true if we'd like a nack back
SourceAddr []byte `codec:",omitempty"` // Source address, used for a direct reply
SourcePort uint16 `codec:",omitempty"` // Source port, used for a direct reply
SourceNode string `codec:",omitempty"` // Source name, used for a direct reply
}
// ack response is sent for a ping
@ -161,7 +187,7 @@ type pushNodeState struct {
Port uint16
Meta []byte
Incarnation uint32
State nodeStateType
State NodeStateType
Vsn []uint8 // Protocol versions
}
@ -205,13 +231,38 @@ func (m *Memberlist) streamListen() {
// handleConn handles a single incoming stream connection from the transport.
func (m *Memberlist) handleConn(conn net.Conn) {
defer conn.Close()
m.logger.Printf("[DEBUG] memberlist: Stream connection %s", LogConn(conn))
defer conn.Close()
metrics.IncrCounter([]string{"memberlist", "tcp", "accept"}, 1)
metrics.IncrCounterWithLabels([]string{"memberlist", "tcp", "accept"}, 1, m.metricLabels)
conn.SetDeadline(time.Now().Add(m.config.TCPTimeout))
msgType, bufConn, dec, err := m.readStream(conn)
var (
streamLabel string
err error
)
conn, streamLabel, err = RemoveLabelHeaderFromStream(conn)
if err != nil {
m.logger.Printf("[ERR] memberlist: failed to receive and remove the stream label header: %s %s", err, LogConn(conn))
return
}
if m.config.SkipInboundLabelCheck {
if streamLabel != "" {
m.logger.Printf("[ERR] memberlist: unexpected double stream label header: %s", LogConn(conn))
return
}
// Set this from config so that the auth data assertions work below.
streamLabel = m.config.Label
}
if m.config.Label != streamLabel {
m.logger.Printf("[ERR] memberlist: discarding stream with unacceptable label %q: %s", streamLabel, LogConn(conn))
return
}
msgType, bufConn, dec, err := m.readStream(conn, streamLabel)
if err != nil {
if err != io.EOF {
m.logger.Printf("[ERR] memberlist: failed to receive: %s %s", err, LogConn(conn))
@ -223,7 +274,7 @@ func (m *Memberlist) handleConn(conn net.Conn) {
return
}
err = m.rawSendMsgStream(conn, out.Bytes())
err = m.rawSendMsgStream(conn, out.Bytes(), streamLabel)
if err != nil {
m.logger.Printf("[ERR] memberlist: Failed to send error: %s %s", err, LogConn(conn))
return
@ -238,13 +289,23 @@ func (m *Memberlist) handleConn(conn net.Conn) {
m.logger.Printf("[ERR] memberlist: Failed to receive user message: %s %s", err, LogConn(conn))
}
case pushPullMsg:
// Increment counter of pending push/pulls
numConcurrent := atomic.AddUint32(&m.pushPullReq, 1)
defer atomic.AddUint32(&m.pushPullReq, ^uint32(0))
// Check if we have too many open push/pull requests
if numConcurrent >= maxPushPullRequests {
m.logger.Printf("[ERR] memberlist: Too many pending push/pull requests")
return
}
join, remoteNodes, userState, err := m.readRemoteState(bufConn, dec)
if err != nil {
m.logger.Printf("[ERR] memberlist: Failed to read remote state: %s %s", err, LogConn(conn))
return
}
if err := m.sendLocalState(conn, join); err != nil {
if err := m.sendLocalState(conn, join, streamLabel); err != nil {
m.logger.Printf("[ERR] memberlist: Failed to push local state: %s %s", err, LogConn(conn))
return
}
@ -272,7 +333,7 @@ func (m *Memberlist) handleConn(conn net.Conn) {
return
}
err = m.rawSendMsgStream(conn, out.Bytes())
err = m.rawSendMsgStream(conn, out.Bytes(), streamLabel)
if err != nil {
m.logger.Printf("[ERR] memberlist: Failed to send ack: %s %s", err, LogConn(conn))
return
@ -297,10 +358,35 @@ func (m *Memberlist) packetListen() {
}
func (m *Memberlist) ingestPacket(buf []byte, from net.Addr, timestamp time.Time) {
var (
packetLabel string
err error
)
buf, packetLabel, err = RemoveLabelHeaderFromPacket(buf)
if err != nil {
m.logger.Printf("[ERR] memberlist: %v %s", err, LogAddress(from))
return
}
if m.config.SkipInboundLabelCheck {
if packetLabel != "" {
m.logger.Printf("[ERR] memberlist: unexpected double packet label header: %s", LogAddress(from))
return
}
// Set this from config so that the auth data assertions work below.
packetLabel = m.config.Label
}
if m.config.Label != packetLabel {
m.logger.Printf("[ERR] memberlist: discarding packet with unacceptable label %q: %s", packetLabel, LogAddress(from))
return
}
// Check if encryption is enabled
if m.config.EncryptionEnabled() {
// Decrypt the payload
plain, err := decryptPayload(m.config.Keyring.GetKeys(), buf, nil)
authData := []byte(packetLabel)
plain, err := decryptPayload(m.config.Keyring.GetKeys(), buf, authData)
if err != nil {
if !m.config.GossipVerifyIncoming {
// Treat the message as plaintext
@ -330,6 +416,10 @@ func (m *Memberlist) ingestPacket(buf []byte, from net.Addr, timestamp time.Time
}
func (m *Memberlist) handleCommand(buf []byte, from net.Addr, timestamp time.Time) {
if len(buf) < 1 {
m.logger.Printf("[ERR] memberlist: missing message type byte %s", LogAddress(from))
return
}
// Decode the message type
msgType := messageType(buf[0])
buf = buf[1:]
@ -357,10 +447,25 @@ func (m *Memberlist) handleCommand(buf []byte, from net.Addr, timestamp time.Tim
case deadMsg:
fallthrough
case userMsg:
select {
case m.handoff <- msgHandoff{msgType, buf, from}:
default:
// Determine the message queue, prioritize alive
queue := m.lowPriorityMsgQueue
if msgType == aliveMsg {
queue = m.highPriorityMsgQueue
}
// Check for overflow and append if not full
m.msgQueueLock.Lock()
if queue.Len() >= m.config.HandoffQueueDepth {
m.logger.Printf("[WARN] memberlist: handler queue full, dropping message (%d) %s", msgType, LogAddress(from))
} else {
queue.PushBack(msgHandoff{msgType, buf, from})
}
m.msgQueueLock.Unlock()
// Notify of pending message
select {
case m.handoffCh <- struct{}{}:
default:
}
default:
@ -368,28 +473,51 @@ func (m *Memberlist) handleCommand(buf []byte, from net.Addr, timestamp time.Tim
}
}
// getNextMessage pops the next queued message to process. The high priority
// queue is always drained before the low priority one, and within a queue the
// most recently appended message is taken first (LIFO). The bool result is
// false when both queues are empty.
func (m *Memberlist) getNextMessage() (msgHandoff, bool) {
	m.msgQueueLock.Lock()
	defer m.msgQueueLock.Unlock()

	// High priority messages win whenever any are pending.
	if newest := m.highPriorityMsgQueue.Back(); newest != nil {
		m.highPriorityMsgQueue.Remove(newest)
		return newest.Value.(msgHandoff), true
	}

	// Otherwise fall back to the low priority queue.
	if newest := m.lowPriorityMsgQueue.Back(); newest != nil {
		m.lowPriorityMsgQueue.Remove(newest)
		return newest.Value.(msgHandoff), true
	}

	return msgHandoff{}, false
}
// packetHandler is a long running goroutine that processes messages received
// over the packet interface, but is decoupled from the listener to avoid
// blocking the listener which may cause ping/ack messages to be delayed.
func (m *Memberlist) packetHandler() {
for {
select {
case msg := <-m.handoff:
msgType := msg.msgType
buf := msg.buf
from := msg.from
case <-m.handoffCh:
for {
msg, ok := m.getNextMessage()
if !ok {
break
}
msgType := msg.msgType
buf := msg.buf
from := msg.from
switch msgType {
case suspectMsg:
m.handleSuspect(buf, from)
case aliveMsg:
m.handleAlive(buf, from)
case deadMsg:
m.handleDead(buf, from)
case userMsg:
m.handleUser(buf, from)
default:
m.logger.Printf("[ERR] memberlist: Message type (%d) not supported %s (packet handler)", msgType, LogAddress(from))
switch msgType {
case suspectMsg:
m.handleSuspect(buf, from)
case aliveMsg:
m.handleAlive(buf, from)
case deadMsg:
m.handleDead(buf, from)
case userMsg:
m.handleUser(buf, from)
default:
m.logger.Printf("[ERR] memberlist: Message type (%d) not supported %s (packet handler)", msgType, LogAddress(from))
}
}
case <-m.shutdownCh:
@ -433,7 +561,19 @@ func (m *Memberlist) handlePing(buf []byte, from net.Addr) {
if m.config.Ping != nil {
ack.Payload = m.config.Ping.AckPayload()
}
if err := m.encodeAndSendMsg(from.String(), ackRespMsg, &ack); err != nil {
addr := ""
if len(p.SourceAddr) > 0 && p.SourcePort > 0 {
addr = joinHostPort(net.IP(p.SourceAddr).String(), p.SourcePort)
} else {
addr = from.String()
}
a := Address{
Addr: addr,
Name: p.SourceNode,
}
if err := m.encodeAndSendMsg(a, ackRespMsg, &ack); err != nil {
m.logger.Printf("[ERR] memberlist: Failed to send ack: %s %s", err, LogAddress(from))
}
}
@ -453,7 +593,25 @@ func (m *Memberlist) handleIndirectPing(buf []byte, from net.Addr) {
// Send a ping to the correct host.
localSeqNo := m.nextSeqNo()
ping := ping{SeqNo: localSeqNo, Node: ind.Node}
selfAddr, selfPort := m.getAdvertise()
ping := ping{
SeqNo: localSeqNo,
Node: ind.Node,
// The outbound message is addressed FROM us.
SourceAddr: selfAddr,
SourcePort: selfPort,
SourceNode: m.config.Name,
}
// Forward the ack back to the requestor. If the request encodes an origin
// use that otherwise assume that the other end of the UDP socket is
// usable.
indAddr := ""
if len(ind.SourceAddr) > 0 && ind.SourcePort > 0 {
indAddr = joinHostPort(net.IP(ind.SourceAddr).String(), ind.SourcePort)
} else {
indAddr = from.String()
}
// Setup a response handler to relay the ack
cancelCh := make(chan struct{})
@ -461,18 +619,25 @@ func (m *Memberlist) handleIndirectPing(buf []byte, from net.Addr) {
// Try to prevent the nack if we've caught it in time.
close(cancelCh)
// Forward the ack back to the requestor.
ack := ackResp{ind.SeqNo, nil}
if err := m.encodeAndSendMsg(from.String(), ackRespMsg, &ack); err != nil {
m.logger.Printf("[ERR] memberlist: Failed to forward ack: %s %s", err, LogAddress(from))
a := Address{
Addr: indAddr,
Name: ind.SourceNode,
}
if err := m.encodeAndSendMsg(a, ackRespMsg, &ack); err != nil {
m.logger.Printf("[ERR] memberlist: Failed to forward ack: %s %s", err, LogStringAddress(indAddr))
}
}
m.setAckHandler(localSeqNo, respHandler, m.config.ProbeTimeout)
// Send the ping.
addr := joinHostPort(net.IP(ind.Target).String(), ind.Port)
if err := m.encodeAndSendMsg(addr, pingMsg, &ping); err != nil {
m.logger.Printf("[ERR] memberlist: Failed to send ping: %s %s", err, LogAddress(from))
a := Address{
Addr: addr,
Name: ind.Node,
}
if err := m.encodeAndSendMsg(a, pingMsg, &ping); err != nil {
m.logger.Printf("[ERR] memberlist: Failed to send indirect ping: %s %s", err, LogStringAddress(indAddr))
}
// Setup a timer to fire off a nack if no ack is seen in time.
@ -483,8 +648,12 @@ func (m *Memberlist) handleIndirectPing(buf []byte, from net.Addr) {
return
case <-time.After(m.config.ProbeTimeout):
nack := nackResp{ind.SeqNo}
if err := m.encodeAndSendMsg(from.String(), nackRespMsg, &nack); err != nil {
m.logger.Printf("[ERR] memberlist: Failed to send nack: %s %s", err, LogAddress(from))
a := Address{
Addr: indAddr,
Name: ind.SourceNode,
}
if err := m.encodeAndSendMsg(a, nackRespMsg, &nack); err != nil {
m.logger.Printf("[ERR] memberlist: Failed to send nack: %s %s", err, LogStringAddress(indAddr))
}
}
}()
@ -518,12 +687,47 @@ func (m *Memberlist) handleSuspect(buf []byte, from net.Addr) {
m.suspectNode(&sus)
}
// ensureCanConnect checks whether the remote address from is allowed to talk
// to this node. It returns nil when the configuration says IPs do not need to
// be checked, or when the address string is the literal "pipe" (the address
// reported by in-memory net.Pipe connections). Otherwise the host part of the
// address is parsed as an IP and the verdict of m.config.IPAllowed is
// returned. An address that cannot be split or parsed yields an error.
func (m *Memberlist) ensureCanConnect(from net.Addr) error {
	if !m.config.IPMustBeChecked() {
		return nil
	}

	remote := from.String()
	if remote == "pipe" {
		return nil
	}

	host, _, splitErr := net.SplitHostPort(remote)
	if splitErr != nil {
		return splitErr
	}

	if parsed := net.ParseIP(host); parsed != nil {
		return m.config.IPAllowed(parsed)
	}
	return fmt.Errorf("Cannot parse IP from %s", host)
}
func (m *Memberlist) handleAlive(buf []byte, from net.Addr) {
if err := m.ensureCanConnect(from); err != nil {
m.logger.Printf("[DEBUG] memberlist: Blocked alive message: %s %s", err, LogAddress(from))
return
}
var live alive
if err := decode(buf, &live); err != nil {
m.logger.Printf("[ERR] memberlist: Failed to decode alive message: %s %s", err, LogAddress(from))
return
}
if m.config.IPMustBeChecked() {
innerIP := net.IP(live.Addr)
if innerIP != nil {
if err := m.config.IPAllowed(innerIP); err != nil {
m.logger.Printf("[DEBUG] memberlist: Blocked alive.Addr=%s message from: %s %s", innerIP.String(), err, LogAddress(from))
return
}
}
}
// For proto versions < 2, there is no port provided. Mask old
// behavior by using the configured port
@ -565,12 +769,12 @@ func (m *Memberlist) handleCompressed(buf []byte, from net.Addr, timestamp time.
}
// encodeAndSendMsg is used to combine the encoding and sending steps
func (m *Memberlist) encodeAndSendMsg(addr string, msgType messageType, msg interface{}) error {
func (m *Memberlist) encodeAndSendMsg(a Address, msgType messageType, msg interface{}) error {
out, err := encode(msgType, msg)
if err != nil {
return err
}
if err := m.sendMsg(addr, out.Bytes()); err != nil {
if err := m.sendMsg(a, out.Bytes()); err != nil {
return err
}
return nil
@ -578,9 +782,9 @@ func (m *Memberlist) encodeAndSendMsg(addr string, msgType messageType, msg inte
// sendMsg is used to send a message via packet to another host. It will
// opportunistically create a compoundMsg and piggy back other broadcasts.
func (m *Memberlist) sendMsg(addr string, msg []byte) error {
func (m *Memberlist) sendMsg(a Address, msg []byte) error {
// Check if we can piggy back any messages
bytesAvail := m.config.UDPBufferSize - len(msg) - compoundHeaderOverhead
bytesAvail := m.config.UDPBufferSize - len(msg) - compoundHeaderOverhead - labelOverhead(m.config.Label)
if m.config.EncryptionEnabled() && m.config.GossipVerifyOutgoing {
bytesAvail -= encryptOverhead(m.encryptionVersion())
}
@ -588,7 +792,7 @@ func (m *Memberlist) sendMsg(addr string, msg []byte) error {
// Fast path if nothing to piggypack
if len(extra) == 0 {
return m.rawSendMsgPacket(addr, nil, msg)
return m.rawSendMsgPacket(a, nil, msg)
}
// Join all the messages
@ -600,12 +804,16 @@ func (m *Memberlist) sendMsg(addr string, msg []byte) error {
compound := makeCompoundMessage(msgs)
// Send the message
return m.rawSendMsgPacket(addr, nil, compound.Bytes())
return m.rawSendMsgPacket(a, nil, compound.Bytes())
}
// rawSendMsgPacket is used to send message via packet to another host without
// modification, other than compression or encryption if enabled.
func (m *Memberlist) rawSendMsgPacket(addr string, node *Node, msg []byte) error {
func (m *Memberlist) rawSendMsgPacket(a Address, node *Node, msg []byte) error {
if a.Name == "" && m.config.RequireNodeNames {
return errNodeNamesAreRequired
}
// Check if we have compression enabled
if m.config.EnableCompression {
buf, err := compressPayload(msg)
@ -619,11 +827,12 @@ func (m *Memberlist) rawSendMsgPacket(addr string, node *Node, msg []byte) error
}
}
// Try to look up the destination node
// Try to look up the destination node. Note this will only work if the
// bare ip address is used as the node name, which is not guaranteed.
if node == nil {
toAddr, _, err := net.SplitHostPort(addr)
toAddr, _, err := net.SplitHostPort(a.Addr)
if err != nil {
m.logger.Printf("[ERR] memberlist: Failed to parse address %q: %v", addr, err)
m.logger.Printf("[ERR] memberlist: Failed to parse address %q: %v", a.Addr, err)
return err
}
m.nodeLock.RLock()
@ -647,9 +856,12 @@ func (m *Memberlist) rawSendMsgPacket(addr string, node *Node, msg []byte) error
// Check if we have encryption enabled
if m.config.EncryptionEnabled() && m.config.GossipVerifyOutgoing {
// Encrypt the payload
var buf bytes.Buffer
primaryKey := m.config.Keyring.GetPrimaryKey()
err := encryptPayload(m.encryptionVersion(), primaryKey, msg, nil, &buf)
var (
primaryKey = m.config.Keyring.GetPrimaryKey()
packetLabel = []byte(m.config.Label)
buf bytes.Buffer
)
err := encryptPayload(m.encryptionVersion(), primaryKey, msg, packetLabel, &buf)
if err != nil {
m.logger.Printf("[ERR] memberlist: Encryption of message failed: %v", err)
return err
@ -657,15 +869,15 @@ func (m *Memberlist) rawSendMsgPacket(addr string, node *Node, msg []byte) error
msg = buf.Bytes()
}
metrics.IncrCounter([]string{"memberlist", "udp", "sent"}, float32(len(msg)))
_, err := m.transport.WriteTo(msg, addr)
metrics.IncrCounterWithLabels([]string{"memberlist", "udp", "sent"}, float32(len(msg)), m.metricLabels)
_, err := m.transport.WriteToAddress(msg, a)
return err
}
// rawSendMsgStream is used to stream a message to another host without
// modification, other than applying compression and encryption if enabled.
func (m *Memberlist) rawSendMsgStream(conn net.Conn, sendBuf []byte) error {
// Check if compresion is enabled
func (m *Memberlist) rawSendMsgStream(conn net.Conn, sendBuf []byte, streamLabel string) error {
// Check if compression is enabled
if m.config.EnableCompression {
compBuf, err := compressPayload(sendBuf)
if err != nil {
@ -677,7 +889,7 @@ func (m *Memberlist) rawSendMsgStream(conn net.Conn, sendBuf []byte) error {
// Check if encryption is enabled
if m.config.EncryptionEnabled() && m.config.GossipVerifyOutgoing {
crypt, err := m.encryptLocalState(sendBuf)
crypt, err := m.encryptLocalState(sendBuf, streamLabel)
if err != nil {
m.logger.Printf("[ERROR] memberlist: Failed to encrypt local state: %v", err)
return err
@ -686,7 +898,7 @@ func (m *Memberlist) rawSendMsgStream(conn net.Conn, sendBuf []byte) error {
}
// Write out the entire send buffer
metrics.IncrCounter([]string{"memberlist", "tcp", "sent"}, float32(len(sendBuf)))
metrics.IncrCounterWithLabels([]string{"memberlist", "tcp", "sent"}, float32(len(sendBuf)), m.metricLabels)
if n, err := conn.Write(sendBuf); err != nil {
return err
@ -698,8 +910,12 @@ func (m *Memberlist) rawSendMsgStream(conn net.Conn, sendBuf []byte) error {
}
// sendUserMsg is used to stream a user message to another host.
func (m *Memberlist) sendUserMsg(addr string, sendBuf []byte) error {
conn, err := m.transport.DialTimeout(addr, m.config.TCPTimeout)
func (m *Memberlist) sendUserMsg(a Address, sendBuf []byte) error {
if a.Name == "" && m.config.RequireNodeNames {
return errNodeNamesAreRequired
}
conn, err := m.transport.DialAddressTimeout(a, m.config.TCPTimeout)
if err != nil {
return err
}
@ -719,28 +935,33 @@ func (m *Memberlist) sendUserMsg(addr string, sendBuf []byte) error {
if _, err := bufConn.Write(sendBuf); err != nil {
return err
}
return m.rawSendMsgStream(conn, bufConn.Bytes())
return m.rawSendMsgStream(conn, bufConn.Bytes(), m.config.Label)
}
// sendAndReceiveState is used to initiate a push/pull over a stream with a
// remote host.
func (m *Memberlist) sendAndReceiveState(addr string, join bool) ([]pushNodeState, []byte, error) {
func (m *Memberlist) sendAndReceiveState(a Address, join bool) ([]pushNodeState, []byte, error) {
if a.Name == "" && m.config.RequireNodeNames {
return nil, nil, errNodeNamesAreRequired
}
// Attempt to connect
conn, err := m.transport.DialTimeout(addr, m.config.TCPTimeout)
conn, err := m.transport.DialAddressTimeout(a, m.config.TCPTimeout)
if err != nil {
return nil, nil, err
}
defer conn.Close()
m.logger.Printf("[DEBUG] memberlist: Initiating push/pull sync with: %s", conn.RemoteAddr())
metrics.IncrCounter([]string{"memberlist", "tcp", "connect"}, 1)
m.logger.Printf("[DEBUG] memberlist: Initiating push/pull sync with: %s %s", a.Name, conn.RemoteAddr())
metrics.IncrCounterWithLabels([]string{"memberlist", "tcp", "connect"}, 1, m.metricLabels)
// Send our state
if err := m.sendLocalState(conn, join); err != nil {
if err := m.sendLocalState(conn, join, m.config.Label); err != nil {
return nil, nil, err
}
conn.SetDeadline(time.Now().Add(m.config.TCPTimeout))
msgType, bufConn, dec, err := m.readStream(conn)
msgType, bufConn, dec, err := m.readStream(conn, m.config.Label)
if err != nil {
return nil, nil, err
}
@ -765,7 +986,7 @@ func (m *Memberlist) sendAndReceiveState(addr string, join bool) ([]pushNodeStat
}
// sendLocalState is invoked to send our local state over a stream connection.
func (m *Memberlist) sendLocalState(conn net.Conn, join bool) error {
func (m *Memberlist) sendLocalState(conn net.Conn, join bool, streamLabel string) error {
// Setup a deadline
conn.SetDeadline(time.Now().Add(m.config.TCPTimeout))
@ -822,11 +1043,11 @@ func (m *Memberlist) sendLocalState(conn net.Conn, join bool) error {
}
// Get the send buffer
return m.rawSendMsgStream(conn, bufConn.Bytes())
return m.rawSendMsgStream(conn, bufConn.Bytes(), streamLabel)
}
// encryptLocalState is used to help encrypt local state before sending
func (m *Memberlist) encryptLocalState(sendBuf []byte) ([]byte, error) {
func (m *Memberlist) encryptLocalState(sendBuf []byte, streamLabel string) ([]byte, error) {
var buf bytes.Buffer
// Write the encryptMsg byte
@ -839,9 +1060,15 @@ func (m *Memberlist) encryptLocalState(sendBuf []byte) ([]byte, error) {
binary.BigEndian.PutUint32(sizeBuf, uint32(encLen))
buf.Write(sizeBuf)
// Authenticated Data is:
//
// [messageType; byte] [messageLength; uint32] [stream_label; optional]
//
dataBytes := appendBytes(buf.Bytes()[:5], []byte(streamLabel))
// Write the encrypted cipher text to the buffer
key := m.config.Keyring.GetPrimaryKey()
err := encryptPayload(encVsn, key, sendBuf, buf.Bytes()[:5], &buf)
err := encryptPayload(encVsn, key, sendBuf, dataBytes, &buf)
if err != nil {
return nil, err
}
@ -849,7 +1076,7 @@ func (m *Memberlist) encryptLocalState(sendBuf []byte) ([]byte, error) {
}
// decryptRemoteState is used to help decrypt the remote state
func (m *Memberlist) decryptRemoteState(bufConn io.Reader) ([]byte, error) {
func (m *Memberlist) decryptRemoteState(bufConn io.Reader, streamLabel string) ([]byte, error) {
// Read in enough to determine message length
cipherText := bytes.NewBuffer(nil)
cipherText.WriteByte(byte(encryptMsg))
@ -863,6 +1090,12 @@ func (m *Memberlist) decryptRemoteState(bufConn io.Reader) ([]byte, error) {
moreBytes := binary.BigEndian.Uint32(cipherText.Bytes()[1:5])
if moreBytes > maxPushStateBytes {
return nil, fmt.Errorf("Remote node state is larger than limit (%d)", moreBytes)
}
//Start reporting the size before you cross the limit
if moreBytes > uint32(math.Floor(.6*maxPushStateBytes)) {
m.logger.Printf("[WARN] memberlist: Remote node state size is (%d) limit is (%d)", moreBytes, maxPushStateBytes)
}
// Read in the rest of the payload
@ -871,8 +1104,13 @@ func (m *Memberlist) decryptRemoteState(bufConn io.Reader) ([]byte, error) {
return nil, err
}
// Decrypt the cipherText
dataBytes := cipherText.Bytes()[:5]
// Decrypt the cipherText with some authenticated data
//
// Authenticated Data is:
//
// [messageType; byte] [messageLength; uint32] [label_data; optional]
//
dataBytes := appendBytes(cipherText.Bytes()[:5], []byte(streamLabel))
cipherBytes := cipherText.Bytes()[5:]
// Decrypt the payload
@ -880,15 +1118,18 @@ func (m *Memberlist) decryptRemoteState(bufConn io.Reader) ([]byte, error) {
return decryptPayload(keys, cipherBytes, dataBytes)
}
// readStream is used to read from a stream connection, decrypting and
// readStream is used to read messages from a stream connection, decrypting and
// decompressing the stream if necessary.
func (m *Memberlist) readStream(conn net.Conn) (messageType, io.Reader, *codec.Decoder, error) {
//
// The provided streamLabel if present will be authenticated during decryption
// of each message.
func (m *Memberlist) readStream(conn net.Conn, streamLabel string) (messageType, io.Reader, *codec.Decoder, error) {
// Created a buffered reader
var bufConn io.Reader = bufio.NewReader(conn)
// Read the message type
buf := [1]byte{0}
if _, err := bufConn.Read(buf[:]); err != nil {
if _, err := io.ReadFull(bufConn, buf[:]); err != nil {
return 0, nil, nil, err
}
msgType := messageType(buf[0])
@ -900,7 +1141,7 @@ func (m *Memberlist) readStream(conn net.Conn) (messageType, io.Reader, *codec.D
fmt.Errorf("Remote state is encrypted and encryption is not configured")
}
plain, err := m.decryptRemoteState(bufConn)
plain, err := m.decryptRemoteState(bufConn, streamLabel)
if err != nil {
return 0, nil, nil, err
}
@ -996,16 +1237,17 @@ func (m *Memberlist) mergeRemoteState(join bool, remoteNodes []pushNodeState, us
nodes := make([]*Node, len(remoteNodes))
for idx, n := range remoteNodes {
nodes[idx] = &Node{
Name: n.Name,
Addr: n.Addr,
Port: n.Port,
Meta: n.Meta,
PMin: n.Vsn[0],
PMax: n.Vsn[1],
PCur: n.Vsn[2],
DMin: n.Vsn[3],
DMax: n.Vsn[4],
DCur: n.Vsn[5],
Name: n.Name,
Addr: n.Addr,
Port: n.Port,
Meta: n.Meta,
State: n.State,
PMin: n.Vsn[0],
PMax: n.Vsn[1],
PCur: n.Vsn[2],
DMin: n.Vsn[3],
DMax: n.Vsn[4],
DCur: n.Vsn[5],
}
}
if err := m.config.Merge.NotifyMerge(nodes); err != nil {
@ -1058,8 +1300,12 @@ func (m *Memberlist) readUserMsg(bufConn io.Reader, dec *codec.Decoder) error {
// a ping, and waits for an ack. All of this is done as a series of blocking
// operations, given the deadline. The bool return parameter is true if we
// were able to round trip a ping to the other node.
func (m *Memberlist) sendPingAndWaitForAck(addr string, ping ping, deadline time.Time) (bool, error) {
conn, err := m.transport.DialTimeout(addr, deadline.Sub(time.Now()))
func (m *Memberlist) sendPingAndWaitForAck(a Address, ping ping, deadline time.Time) (bool, error) {
if a.Name == "" && m.config.RequireNodeNames {
return false, errNodeNamesAreRequired
}
conn, err := m.transport.DialAddressTimeout(a, deadline.Sub(time.Now()))
if err != nil {
// If the node is actually dead we expect this to fail, so we
// shouldn't spam the logs with it. After this point, errors
@ -1075,11 +1321,11 @@ func (m *Memberlist) sendPingAndWaitForAck(addr string, ping ping, deadline time
return false, err
}
if err = m.rawSendMsgStream(conn, out.Bytes()); err != nil {
if err = m.rawSendMsgStream(conn, out.Bytes(), m.config.Label); err != nil {
return false, err
}
msgType, _, dec, err := m.readStream(conn)
msgType, _, dec, err := m.readStream(conn, m.config.Label)
if err != nil {
return false, err
}
@ -1094,7 +1340,7 @@ func (m *Memberlist) sendPingAndWaitForAck(addr string, ping ping, deadline time
}
if ack.SeqNo != ping.SeqNo {
return false, fmt.Errorf("Sequence number from ack (%d) doesn't match ping (%d)", ack.SeqNo, ping.SeqNo, LogConn(conn))
return false, fmt.Errorf("Sequence number from ack (%d) doesn't match ping (%d)", ack.SeqNo, ping.SeqNo)
}
return true, nil

View file

@ -1,7 +1,9 @@
package memberlist
import (
"bytes"
"fmt"
"io"
"log"
"net"
"sync"
@ -33,6 +35,10 @@ type NetTransportConfig struct {
// Logger is a logger for operator messages.
Logger *log.Logger
// MetricLabels is a map of optional labels to apply to all metrics
// emitted by this transport.
MetricLabels []metrics.Label
}
// NetTransport is a Transport implementation that uses connectionless UDP for
@ -46,8 +52,12 @@ type NetTransport struct {
tcpListeners []*net.TCPListener
udpListeners []*net.UDPConn
shutdown int32
metricLabels []metrics.Label
}
var _ NodeAwareTransport = (*NetTransport)(nil)
// NewNetTransport returns a net transport with the given configuration. On
// success all the network listeners will be created and listening.
func NewNetTransport(config *NetTransportConfig) (*NetTransport, error) {
@ -60,10 +70,11 @@ func NewNetTransport(config *NetTransportConfig) (*NetTransport, error) {
// Build out the new transport.
var ok bool
t := NetTransport{
config: config,
packetCh: make(chan *Packet),
streamCh: make(chan net.Conn),
logger: config.Logger,
config: config,
packetCh: make(chan *Packet),
streamCh: make(chan net.Conn),
logger: config.Logger,
metricLabels: config.MetricLabels,
}
// Clean up listeners if there's an error.
@ -170,6 +181,14 @@ func (t *NetTransport) FinalAdvertiseAddr(ip string, port int) (net.IP, int, err
// WriteTo sends b to addr. It is a convenience wrapper that addresses the
// destination by address only (no node name) and delegates to WriteToAddress.
//
// See Transport.
func (t *NetTransport) WriteTo(b []byte, addr string) (time.Time, error) {
	return t.WriteToAddress(b, Address{Addr: addr})
}
// See NodeAwareTransport.
func (t *NetTransport) WriteToAddress(b []byte, a Address) (time.Time, error) {
addr := a.Addr
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return time.Time{}, err
@ -188,8 +207,44 @@ func (t *NetTransport) PacketCh() <-chan *Packet {
return t.packetCh
}
// IngestPacket drains conn and delivers everything it read to the packet
// channel as a single Packet attributed to addr and stamped with now. When
// shouldClose is true, conn is closed before IngestPacket returns. An error
// is returned if reading fails or if nothing at all was read.
//
// See IngestionAwareTransport.
func (t *NetTransport) IngestPacket(conn net.Conn, addr net.Addr, now time.Time, shouldClose bool) error {
	if shouldClose {
		defer conn.Close()
	}

	// Slurp the whole stream; one stream carries exactly one packet.
	payload := new(bytes.Buffer)
	_, copyErr := io.Copy(payload, conn)
	if copyErr != nil {
		return fmt.Errorf("failed to read packet: %v", copyErr)
	}

	// A proper message needs at least one byte. Packets arriving directly
	// from the UDP socket have this checked elsewhere.
	size := payload.Len()
	if size < 1 {
		return fmt.Errorf("packet too short (%d bytes) %s", size, LogAddress(addr))
	}

	// Hand the packet over to whoever is consuming the packet channel.
	pkt := &Packet{
		Buf:       payload.Bytes(),
		From:      addr,
		Timestamp: now,
	}
	t.packetCh <- pkt
	return nil
}
// DialTimeout opens a stream connection to addr. It is a convenience wrapper
// that addresses the destination by address only (no node name) and
// delegates to DialAddressTimeout.
//
// See Transport.
func (t *NetTransport) DialTimeout(addr string, timeout time.Duration) (net.Conn, error) {
	return t.DialAddressTimeout(Address{Addr: addr}, timeout)
}
// DialAddressTimeout opens a TCP connection to a.Addr, giving up after
// timeout. Only the address part of a is used; a.Name plays no role in how
// this transport dials.
//
// See NodeAwareTransport.
func (t *NetTransport) DialAddressTimeout(a Address, timeout time.Duration) (net.Conn, error) {
	// net.DialTimeout is shorthand for a net.Dialer with only Timeout set.
	return net.DialTimeout("tcp", a.Addr, timeout)
}
@ -199,6 +254,12 @@ func (t *NetTransport) StreamCh() <-chan net.Conn {
return t.streamCh
}
// IngestStream hands conn to the channel exposed by StreamCh and always
// returns nil. The hand-off is a plain channel send, so the call waits until
// a consumer of the stream channel receives the connection.
//
// See IngestionAwareTransport.
func (t *NetTransport) IngestStream(conn net.Conn) error {
t.streamCh <- conn
return nil
}
// See Transport.
func (t *NetTransport) Shutdown() error {
// This will avoid log spam about errors when we shut down.
@ -221,6 +282,16 @@ func (t *NetTransport) Shutdown() error {
// and hands them off to the stream channel.
func (t *NetTransport) tcpListen(tcpLn *net.TCPListener) {
defer t.wg.Done()
// baseDelay is the initial delay after an AcceptTCP() error before attempting again
const baseDelay = 5 * time.Millisecond
// maxDelay is the maximum delay after an AcceptTCP() error before attempting again.
// In the case that tcpListen() is error-looping, it will delay the shutdown check.
// Therefore, changes to maxDelay may have an effect on the latency of shutdown.
const maxDelay = 1 * time.Second
var loopDelay time.Duration
for {
conn, err := tcpLn.AcceptTCP()
if err != nil {
@ -228,9 +299,22 @@ func (t *NetTransport) tcpListen(tcpLn *net.TCPListener) {
break
}
if loopDelay == 0 {
loopDelay = baseDelay
} else {
loopDelay *= 2
}
if loopDelay > maxDelay {
loopDelay = maxDelay
}
t.logger.Printf("[ERR] memberlist: Error accepting TCP connection: %v", err)
time.Sleep(loopDelay)
continue
}
// No error, reset loop delay
loopDelay = 0
t.streamCh <- conn
}
@ -264,7 +348,7 @@ func (t *NetTransport) udpListen(udpLn *net.UDPConn) {
}
// Ingest the packet.
metrics.IncrCounter([]string{"memberlist", "udp", "received"}, float32(n))
metrics.IncrCounterWithLabels([]string{"memberlist", "udp", "received"}, float32(n), t.metricLabels)
t.packetCh <- &Packet{
Buf: buf[:n],
From: addr,

48
vendor/github.com/hashicorp/memberlist/peeked_conn.go generated vendored Normal file
View file

@ -0,0 +1,48 @@
// Copyright 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Originally from: https://github.com/google/tcpproxy/blob/master/tcpproxy.go
// at f5c09fbedceb69e4b238dec52cdf9f2fe9a815e2
package memberlist
import "net"
// peekedConn wraps an incoming connection from which some leading bytes have
// already been read in order to decide how to route it. Its Read method
// replays those peeked bytes first and then continues with the unread bytes
// of the underlying connection, so callers see the original byte stream.
type peekedConn struct {
	// Peeked holds the bytes that were read from Conn for routing purposes
	// but have not yet been handed out by Read. It is set to nil by Read
	// once fully consumed.
	Peeked []byte

	// Conn is the underlying connection. It can be type asserted against
	// *net.TCPConn or other types as needed. It should not be read from
	// directly unless Peeked is nil.
	net.Conn
}

// Read serves buffered peeked bytes while any remain, and only falls through
// to the underlying connection once they are exhausted. A single call never
// mixes peeked bytes with bytes from the underlying connection.
func (c *peekedConn) Read(p []byte) (n int, err error) {
	if len(c.Peeked) == 0 {
		return c.Conn.Read(p)
	}

	n = copy(p, c.Peeked)
	if rest := c.Peeked[n:]; len(rest) > 0 {
		c.Peeked = rest
	} else {
		// Everything has been replayed; drop the buffer.
		c.Peeked = nil
	}
	return n, nil
}

View file

@ -1,8 +1,10 @@
package memberlist
import (
"sort"
"math"
"sync"
"github.com/google/btree"
)
// TransmitLimitedQueue is used to queue messages to broadcast to
@ -19,15 +21,93 @@ type TransmitLimitedQueue struct {
// number of retransmissions attempted.
RetransmitMult int
sync.Mutex
bcQueue limitedBroadcasts
mu sync.Mutex
tq *btree.BTree // stores *limitedBroadcast as btree.Item
tm map[string]*limitedBroadcast
idGen int64
}
type limitedBroadcast struct {
transmits int // Number of transmissions attempted.
transmits int // btree-key[0]: Number of transmissions attempted.
msgLen int64 // btree-key[1]: copied from len(b.Message())
id int64 // btree-key[2]: unique incrementing id stamped at submission time
b Broadcast
name string // set if Broadcast is a NamedBroadcast
}
// Less reports whether b sorts before than in the queue's btree.
//
// This must provide a strict weak ordering. If !a.Less(b) && !b.Less(a), the
// two items are treated as equal, i.e. only one of them can be held in the
// tree.
//
// The key is compared field by field:
//   - transmits, ascending   (transmits=0 first, ..., transmits=inf last)
//   - msgLen, descending     (within a tier: len=999 first, ..., len=2 later)
//   - id, descending         (within equal length: id=999 first, ..., id=1 later)
func (b *limitedBroadcast) Less(than btree.Item) bool {
	o := than.(*limitedBroadcast)
	switch {
	case b.transmits != o.transmits:
		// Fewer transmissions sorts first.
		return b.transmits < o.transmits
	case b.msgLen != o.msgLen:
		// Within a tier, longer messages sort first.
		return b.msgLen > o.msgLen
	default:
		// Final tie-break: higher id sorts first.
		return b.id > o.id
	}
}
// orderedView is a test helper that returns a snapshot slice of every queued
// item. Items appear in transmit order when reverse is false and in the
// opposite order when reverse is true. It takes the queue mutex itself.
func (q *TransmitLimitedQueue) orderedView(reverse bool) []*limitedBroadcast {
	q.mu.Lock()
	defer q.mu.Unlock()

	view := make([]*limitedBroadcast, 0, q.lenLocked())
	collect := func(lb *limitedBroadcast) bool {
		view = append(view, lb)
		return true // never stop early; we want every item
	}
	q.walkReadOnlyLocked(reverse, collect)
	return view
}
// walkReadOnlyLocked invokes f on each queued item, visiting them in natural
// order (as defined by Less) when reverse is false and in the opposite order
// when reverse is true. Traversal stops as soon as f returns false. The
// caller must hold the mutex.
//
// f must not change an item's btree key (transmits, msgLen, id); doing so
// makes this method panic. The underlying btree should also not be mutated
// during traversal.
func (q *TransmitLimitedQueue) walkReadOnlyLocked(reverse bool, f func(*limitedBroadcast) bool) {
	if q.lenLocked() == 0 {
		return
	}

	visit := func(item btree.Item) bool {
		lb := item.(*limitedBroadcast)

		// Snapshot the key fields so a mutation by f can be detected.
		transmits, msgLen, id := lb.transmits, lb.msgLen, lb.id

		cont := f(lb)

		if lb.transmits != transmits || lb.msgLen != msgLen || lb.id != id {
			panic("edited queue while walking read only")
		}
		return cont
	}

	if !reverse {
		q.tq.Ascend(visit) // start with transmit 0
		return
	}
	q.tq.Descend(visit) // end with transmit 0
}
type limitedBroadcasts []*limitedBroadcast
// Broadcast is something that can be broadcasted via gossip to
// the memberlist cluster.
@ -45,123 +125,298 @@ type Broadcast interface {
Finished()
}
// NamedBroadcast is an optional extension of the Broadcast interface that
// gives each message a unique string name. TransmitLimitedQueue uses that
// name to optimize invalidation: a newly queued broadcast replaces any
// queued broadcast with the same name via a direct lookup instead of a scan
// of the whole queue.
//
// You should ensure that Invalidates() checks the same uniqueness as the
// example below:
//
//	func (b *foo) Invalidates(other Broadcast) bool {
//		nb, ok := other.(NamedBroadcast)
//		if !ok {
//			return false
//		}
//		return b.Name() == nb.Name()
//	}
//
// Invalidates() isn't currently used for NamedBroadcasts, but that may change
// in the future.
type NamedBroadcast interface {
	Broadcast
	// Name returns the unique identity of this broadcast message.
	Name() string
}
// UniqueBroadcast is an optional interface that indicates that each message is
// intrinsically unique and there is no need to scan the broadcast queue for
// duplicates. TransmitLimitedQueue skips its invalidation scan entirely when
// a queued broadcast implements this interface (and is not a NamedBroadcast).
//
// You should ensure that Invalidates() always returns false if implementing
// this interface. Invalidates() isn't currently used for UniqueBroadcasts, but
// that may change in the future.
type UniqueBroadcast interface {
	Broadcast
	// UniqueBroadcast is just a marker method for this interface; it takes
	// no arguments and returns nothing.
	UniqueBroadcast()
}
// QueueBroadcast is used to enqueue a broadcast
func (q *TransmitLimitedQueue) QueueBroadcast(b Broadcast) {
q.Lock()
defer q.Unlock()
q.queueBroadcast(b, 0)
}
// Check if this message invalidates another
n := len(q.bcQueue)
for i := 0; i < n; i++ {
if b.Invalidates(q.bcQueue[i].b) {
q.bcQueue[i].b.Finished()
copy(q.bcQueue[i:], q.bcQueue[i+1:])
q.bcQueue[n-1] = nil
q.bcQueue = q.bcQueue[:n-1]
n--
// lazyInit allocates the name index and the btree on first use, leaving any
// structure that already exists untouched. The caller must already hold the
// mutex.
func (q *TransmitLimitedQueue) lazyInit() {
	if q.tm == nil {
		q.tm = make(map[string]*limitedBroadcast)
	}
	if q.tq == nil {
		q.tq = btree.New(32)
	}
}
// queueBroadcast is like QueueBroadcast but you can use a nonzero value for
// the initial transmit tier assigned to the message. This is meant to be used
// for unit testing.
func (q *TransmitLimitedQueue) queueBroadcast(b Broadcast, initialTransmits int) {
q.mu.Lock()
defer q.mu.Unlock()
q.lazyInit()
if q.idGen == math.MaxInt64 {
// it's super duper unlikely to wrap around within the retransmit limit
q.idGen = 1
} else {
q.idGen++
}
id := q.idGen
lb := &limitedBroadcast{
transmits: initialTransmits,
msgLen: int64(len(b.Message())),
id: id,
b: b,
}
unique := false
if nb, ok := b.(NamedBroadcast); ok {
lb.name = nb.Name()
} else if _, ok := b.(UniqueBroadcast); ok {
unique = true
}
// Check if this message invalidates another.
if lb.name != "" {
if old, ok := q.tm[lb.name]; ok {
old.b.Finished()
q.deleteItem(old)
}
} else if !unique {
// Slow path, hopefully nothing hot hits this.
var remove []*limitedBroadcast
q.tq.Ascend(func(item btree.Item) bool {
cur := item.(*limitedBroadcast)
// Special Broadcasts can only invalidate each other.
switch cur.b.(type) {
case NamedBroadcast:
// noop
case UniqueBroadcast:
// noop
default:
if b.Invalidates(cur.b) {
cur.b.Finished()
remove = append(remove, cur)
}
}
return true
})
for _, cur := range remove {
q.deleteItem(cur)
}
}
// Append to the queue
q.bcQueue = append(q.bcQueue, &limitedBroadcast{0, b})
// Append to the relevant queue.
q.addItem(lb)
}
// deleteItem drops cur from both the btree and, when it carries a name, the
// name index. The caller must already hold the mutex.
func (q *TransmitLimitedQueue) deleteItem(cur *limitedBroadcast) {
	// The removed item returned by Delete is not needed.
	q.tq.Delete(cur)

	if name := cur.name; name != "" {
		delete(q.tm, name)
	}

	if q.tq.Len() > 0 {
		return
	}
	// At idle there's no reason to let the id generator keep going
	// indefinitely, so restart it.
	q.idGen = 0
}
// addItem stores cur in the btree and, when it carries a name, records it in
// the name index as well. The caller must already hold the mutex.
func (q *TransmitLimitedQueue) addItem(cur *limitedBroadcast) {
	// Any item displaced by ReplaceOrInsert is not needed.
	q.tq.ReplaceOrInsert(cur)

	if cur.name == "" {
		return
	}
	q.tm[cur.name] = cur
}
// getTransmitRange reports the transmit counts of the first and last items in
// the queue's ordering, which are the smallest and largest transmit counts
// currently queued. It returns (0, 0) when the queue is empty. The caller must
// already hold the mutex.
func (q *TransmitLimitedQueue) getTransmitRange() (minTransmit, maxTransmit int) {
	if q.lenLocked() == 0 {
		return 0, 0
	}

	lowest := q.tq.Min()
	highest := q.tq.Max()
	if lowest == nil || highest == nil {
		return 0, 0
	}

	minTransmit = lowest.(*limitedBroadcast).transmits
	maxTransmit = highest.(*limitedBroadcast).transmits
	return minTransmit, maxTransmit
}
// GetBroadcasts is used to get a number of broadcasts, up to a byte limit
// and applying a per-message overhead as provided.
func (q *TransmitLimitedQueue) GetBroadcasts(overhead, limit int) [][]byte {
q.Lock()
defer q.Unlock()
q.mu.Lock()
defer q.mu.Unlock()
// Fast path the default case
if len(q.bcQueue) == 0 {
if q.lenLocked() == 0 {
return nil
}
transmitLimit := retransmitLimit(q.RetransmitMult, q.NumNodes())
bytesUsed := 0
var toSend [][]byte
for i := len(q.bcQueue) - 1; i >= 0; i-- {
// Check if this is within our limits
b := q.bcQueue[i]
msg := b.b.Message()
if bytesUsed+overhead+len(msg) > limit {
var (
bytesUsed int
toSend [][]byte
reinsert []*limitedBroadcast
)
// Visit fresher items first, but only look at stuff that will fit.
// We'll go tier by tier, grabbing the largest items first.
minTr, maxTr := q.getTransmitRange()
for transmits := minTr; transmits <= maxTr; /*do not advance automatically*/ {
free := int64(limit - bytesUsed - overhead)
if free <= 0 {
break // bail out early
}
// Search for the least element on a given tier (by transmit count) as
// defined in the limitedBroadcast.Less function that will fit into our
// remaining space.
greaterOrEqual := &limitedBroadcast{
transmits: transmits,
msgLen: free,
id: math.MaxInt64,
}
lessThan := &limitedBroadcast{
transmits: transmits + 1,
msgLen: math.MaxInt64,
id: math.MaxInt64,
}
var keep *limitedBroadcast
q.tq.AscendRange(greaterOrEqual, lessThan, func(item btree.Item) bool {
cur := item.(*limitedBroadcast)
// Check if this is within our limits
if int64(len(cur.b.Message())) > free {
// If this happens it's a bug in the datastructure or
// surrounding use doing something like having len(Message())
// change over time. There's enough going on here that it's
// probably sane to just skip it and move on for now.
return true
}
keep = cur
return false
})
if keep == nil {
// No more items of an appropriate size in the tier.
transmits++
continue
}
msg := keep.b.Message()
// Add to slice to send
bytesUsed += overhead + len(msg)
toSend = append(toSend, msg)
// Check if we should stop transmission
b.transmits++
if b.transmits >= transmitLimit {
b.b.Finished()
n := len(q.bcQueue)
q.bcQueue[i], q.bcQueue[n-1] = q.bcQueue[n-1], nil
q.bcQueue = q.bcQueue[:n-1]
q.deleteItem(keep)
if keep.transmits+1 >= transmitLimit {
keep.b.Finished()
} else {
// We need to bump this item down to another transmit tier, but
// because it would be in the same direction that we're walking the
// tiers, we will have to delay the reinsertion until we are
// finished our search. Otherwise we'll possibly re-add the message
// when we ascend to the next tier.
keep.transmits++
reinsert = append(reinsert, keep)
}
}
// If we are sending anything, we need to re-sort to deal
// with adjusted transmit counts
if len(toSend) > 0 {
q.bcQueue.Sort()
for _, cur := range reinsert {
q.addItem(cur)
}
return toSend
}
// NumQueued returns the number of queued messages
func (q *TransmitLimitedQueue) NumQueued() int {
q.Lock()
defer q.Unlock()
return len(q.bcQueue)
q.mu.Lock()
defer q.mu.Unlock()
return q.lenLocked()
}
// Reset clears all the queued messages
func (q *TransmitLimitedQueue) Reset() {
q.Lock()
defer q.Unlock()
for _, b := range q.bcQueue {
b.b.Finished()
// lenLocked returns the length of the overall queue datastructure. You must
// hold the mutex.
func (q *TransmitLimitedQueue) lenLocked() int {
if q.tq == nil {
return 0
}
q.bcQueue = nil
return q.tq.Len()
}
// Reset discards every queued message after calling Finished on each one, and
// returns the queue to its uninitialized state (no btree, no name index, id
// generator at zero). Should only be used for tests.
func (q *TransmitLimitedQueue) Reset() {
	q.mu.Lock()
	defer q.mu.Unlock()

	finish := func(lb *limitedBroadcast) bool {
		lb.b.Finished()
		return true
	}
	q.walkReadOnlyLocked(false, finish)

	q.tq, q.tm = nil, nil
	q.idGen = 0
}
// Prune will retain the maxRetain latest messages, and the rest
// will be discarded. This can be used to prevent unbounded queue sizes
func (q *TransmitLimitedQueue) Prune(maxRetain int) {
q.Lock()
defer q.Unlock()
q.mu.Lock()
defer q.mu.Unlock()
// Do nothing if queue size is less than the limit
n := len(q.bcQueue)
if n < maxRetain {
return
for q.tq.Len() > maxRetain {
item := q.tq.Max()
if item == nil {
break
}
cur := item.(*limitedBroadcast)
cur.b.Finished()
q.deleteItem(cur)
}
// Invalidate the messages we will be removing
for i := 0; i < n-maxRetain; i++ {
q.bcQueue[i].b.Finished()
}
// Move the messages, and retain only the last maxRetain
copy(q.bcQueue[0:], q.bcQueue[n-maxRetain:])
q.bcQueue = q.bcQueue[:maxRetain]
}
func (b limitedBroadcasts) Len() int {
return len(b)
}
func (b limitedBroadcasts) Less(i, j int) bool {
return b[i].transmits < b[j].transmits
}
func (b limitedBroadcasts) Swap(i, j int) {
b[i], b[j] = b[j], b[i]
}
func (b limitedBroadcasts) Sort() {
sort.Sort(sort.Reverse(b))
}

View file

@ -106,7 +106,10 @@ func encryptPayload(vsn encryptionVersion, key []byte, msg []byte, data []byte,
dst.WriteByte(byte(vsn))
// Add a random nonce
io.CopyN(dst, rand.Reader, nonceSize)
_, err = io.CopyN(dst, rand.Reader, nonceSize)
if err != nil {
return err
}
afterNonce := dst.Len()
// Ensure we are correctly padded (only version 0)
@ -196,3 +199,22 @@ func decryptPayload(keys [][]byte, msg []byte, data []byte) ([]byte, error) {
return nil, fmt.Errorf("No installed keys could decrypt the message")
}
// appendBytes returns the concatenation of first and second.
//
// When both inputs are non-empty a freshly allocated slice is returned, so
// neither input's backing array is written to. When only one input is
// non-empty that input is returned as-is (no copy). When both are empty the
// result is nil.
func appendBytes(first []byte, second []byte) []byte {
	if len(first) == 0 {
		if len(second) == 0 {
			return nil
		}
		return second
	}
	if len(second) == 0 {
		return first
	}

	out := make([]byte, 0, len(first)+len(second))
	out = append(append(out, first...), second...)
	return out
}

View file

@ -6,32 +6,35 @@ import (
"math"
"math/rand"
"net"
"strings"
"sync/atomic"
"time"
"github.com/armon/go-metrics"
metrics "github.com/armon/go-metrics"
)
type nodeStateType int
type NodeStateType int
const (
stateAlive nodeStateType = iota
stateSuspect
stateDead
StateAlive NodeStateType = iota
StateSuspect
StateDead
StateLeft
)
// Node represents a node in the cluster.
type Node struct {
Name string
Addr net.IP
Port uint16
Meta []byte // Metadata from the delegate for this node.
PMin uint8 // Minimum protocol version this understands
PMax uint8 // Maximum protocol version this understands
PCur uint8 // Current version node is speaking
DMin uint8 // Min protocol version for the delegate to understand
DMax uint8 // Max protocol version for the delegate to understand
DCur uint8 // Current version delegate is speaking
Name string
Addr net.IP
Port uint16
Meta []byte // Metadata from the delegate for this node.
State NodeStateType // State of the node.
PMin uint8 // Minimum protocol version this understands
PMax uint8 // Maximum protocol version this understands
PCur uint8 // Current version node is speaking
DMin uint8 // Min protocol version for the delegate to understand
DMax uint8 // Max protocol version for the delegate to understand
DCur uint8 // Current version delegate is speaking
}
// Address returns the host:port form of a node's address, suitable for use
@ -40,6 +43,15 @@ func (n *Node) Address() string {
return joinHostPort(n.Addr.String(), n.Port)
}
// FullAddress bundles the node's name with its host:port address into an
// Address value, suitable for use with a transport.
func (n *Node) FullAddress() Address {
	hostPort := joinHostPort(n.Addr.String(), n.Port)
	return Address{
		Name: n.Name,
		Addr: hostPort,
	}
}
// String returns the node name
func (n *Node) String() string {
return n.Name
@ -49,7 +61,7 @@ func (n *Node) String() string {
type nodeState struct {
Node
Incarnation uint32 // Last known incarnation number
State nodeStateType // Current state
State NodeStateType // Current state
StateChange time.Time // Time last state change happened
}
@ -59,6 +71,16 @@ func (n *nodeState) Address() string {
return n.Node.Address()
}
// FullAddress returns the node name and host:port form of a node's address,
// suitable for use with a transport. It simply forwards to the embedded
// Node's FullAddress.
func (n *nodeState) FullAddress() Address {
	return n.Node.FullAddress()
}
// DeadOrLeft reports whether the node's current state is StateDead or
// StateLeft.
func (n *nodeState) DeadOrLeft() bool {
	switch n.State {
	case StateDead, StateLeft:
		return true
	}
	return false
}
// ackHandler is used to register handlers for incoming acks and nacks.
type ackHandler struct {
ackFn func([]byte, time.Time)
@ -217,7 +239,7 @@ START:
node = *m.nodes[m.probeIndex]
if node.Name == m.config.Name {
skip = true
} else if node.State == stateDead {
} else if node.DeadOrLeft() {
skip = true
}
@ -233,20 +255,56 @@ START:
m.probeNode(&node)
}
// probeNodeByAddr looks up a node in nodeMap under the read lock and then
// probes it with the lock released (for tests).
// NOTE(review): a key that is absent from nodeMap yields a nil *nodeState,
// which is passed straight to probeNode — confirm callers only use known keys.
func (m *Memberlist) probeNodeByAddr(addr string) {
	target := func() *nodeState {
		m.nodeLock.RLock()
		defer m.nodeLock.RUnlock()
		return m.nodeMap[addr]
	}()

	m.probeNode(target)
}
// failedRemote inspects err and reports whether it points at a failure on the
// remote end of a connection.
//
// Only a *net.OpError can qualify (the error is checked by direct type
// assertion, so wrapped errors do not match):
//   - on a "tcp*" network, the ops "dial", "read" and "write" count;
//   - on a "udp*" network, only "write" counts.
//
// Everything else, including a nil error, yields false.
func failedRemote(err error) bool {
	opErr, ok := err.(*net.OpError)
	if !ok {
		return false
	}

	switch {
	case strings.HasPrefix(opErr.Net, "tcp"):
		return opErr.Op == "dial" || opErr.Op == "read" || opErr.Op == "write"
	case strings.HasPrefix(opErr.Net, "udp"):
		return opErr.Op == "write"
	}
	return false
}
// probeNode handles a single round of failure checking on a node.
func (m *Memberlist) probeNode(node *nodeState) {
defer metrics.MeasureSince([]string{"memberlist", "probeNode"}, time.Now())
defer metrics.MeasureSinceWithLabels([]string{"memberlist", "probeNode"}, time.Now(), m.metricLabels)
// We use our health awareness to scale the overall probe interval, so we
// slow down if we detect problems. The ticker that calls us can handle
// us running over the base interval, and will skip missed ticks.
probeInterval := m.awareness.ScaleTimeout(m.config.ProbeInterval)
if probeInterval > m.config.ProbeInterval {
metrics.IncrCounter([]string{"memberlist", "degraded", "probe"}, 1)
metrics.IncrCounterWithLabels([]string{"memberlist", "degraded", "probe"}, 1, m.metricLabels)
}
// Prepare a ping message and setup an ack handler.
ping := ping{SeqNo: m.nextSeqNo(), Node: node.Name}
selfAddr, selfPort := m.getAdvertise()
ping := ping{
SeqNo: m.nextSeqNo(),
Node: node.Name,
SourceAddr: selfAddr,
SourcePort: selfPort,
SourceNode: m.config.Name,
}
ackCh := make(chan ackMessage, m.config.IndirectChecks+1)
nackCh := make(chan struct{}, m.config.IndirectChecks+1)
m.setProbeChannels(ping.SeqNo, ackCh, nackCh, probeInterval)
@ -263,15 +321,25 @@ func (m *Memberlist) probeNode(node *nodeState) {
// soon as possible.
deadline := sent.Add(probeInterval)
addr := node.Address()
if node.State == stateAlive {
if err := m.encodeAndSendMsg(addr, pingMsg, &ping); err != nil {
m.logger.Printf("[ERR] memberlist: Failed to send ping: %s", err)
return
// Arrange for our self-awareness to get updated.
var awarenessDelta int
defer func() {
m.awareness.ApplyDelta(awarenessDelta)
}()
if node.State == StateAlive {
if err := m.encodeAndSendMsg(node.FullAddress(), pingMsg, &ping); err != nil {
m.logger.Printf("[ERR] memberlist: Failed to send UDP ping: %s", err)
if failedRemote(err) {
goto HANDLE_REMOTE_FAILURE
} else {
return
}
}
} else {
var msgs [][]byte
if buf, err := encode(pingMsg, &ping); err != nil {
m.logger.Printf("[ERR] memberlist: Failed to encode ping message: %s", err)
m.logger.Printf("[ERR] memberlist: Failed to encode UDP ping message: %s", err)
return
} else {
msgs = append(msgs, buf.Bytes())
@ -285,9 +353,13 @@ func (m *Memberlist) probeNode(node *nodeState) {
}
compound := makeCompoundMessage(msgs)
if err := m.rawSendMsgPacket(addr, &node.Node, compound.Bytes()); err != nil {
m.logger.Printf("[ERR] memberlist: Failed to send compound ping and suspect message to %s: %s", addr, err)
return
if err := m.rawSendMsgPacket(node.FullAddress(), &node.Node, compound.Bytes()); err != nil {
m.logger.Printf("[ERR] memberlist: Failed to send UDP compound ping and suspect message to %s: %s", addr, err)
if failedRemote(err) {
goto HANDLE_REMOTE_FAILURE
} else {
return
}
}
}
@ -296,10 +368,7 @@ func (m *Memberlist) probeNode(node *nodeState) {
// which will improve our health until we get to the failure scenarios
// at the end of this function, which will alter this delta variable
// accordingly.
awarenessDelta := -1
defer func() {
m.awareness.ApplyDelta(awarenessDelta)
}()
awarenessDelta = -1
// Wait for response or round-trip-time.
select {
@ -324,21 +393,31 @@ func (m *Memberlist) probeNode(node *nodeState) {
// probe interval it will give the TCP fallback more time, which
// is more active in dealing with lost packets, and it gives more
// time to wait for indirect acks/nacks.
m.logger.Printf("[DEBUG] memberlist: Failed ping: %v (timeout reached)", node.Name)
m.logger.Printf("[DEBUG] memberlist: Failed UDP ping: %s (timeout reached)", node.Name)
}
HANDLE_REMOTE_FAILURE:
// Get some random live nodes.
m.nodeLock.RLock()
kNodes := kRandomNodes(m.config.IndirectChecks, m.nodes, func(n *nodeState) bool {
return n.Name == m.config.Name ||
n.Name == node.Name ||
n.State != stateAlive
n.State != StateAlive
})
m.nodeLock.RUnlock()
// Attempt an indirect ping.
expectedNacks := 0
ind := indirectPingReq{SeqNo: ping.SeqNo, Target: node.Addr, Port: node.Port, Node: node.Name}
selfAddr, selfPort = m.getAdvertise()
ind := indirectPingReq{
SeqNo: ping.SeqNo,
Target: node.Addr,
Port: node.Port,
Node: node.Name,
SourceAddr: selfAddr,
SourcePort: selfPort,
SourceNode: m.config.Name,
}
for _, peer := range kNodes {
// We only expect nack to be sent from peers who understand
// version 4 of the protocol.
@ -346,8 +425,8 @@ func (m *Memberlist) probeNode(node *nodeState) {
expectedNacks++
}
if err := m.encodeAndSendMsg(peer.Address(), indirectPingMsg, &ind); err != nil {
m.logger.Printf("[ERR] memberlist: Failed to send indirect ping: %s", err)
if err := m.encodeAndSendMsg(peer.FullAddress(), indirectPingMsg, &ind); err != nil {
m.logger.Printf("[ERR] memberlist: Failed to send indirect UDP ping: %s", err)
}
}
@ -362,12 +441,19 @@ func (m *Memberlist) probeNode(node *nodeState) {
// which protocol version we are speaking. That's why we've included a
// config option to turn this off if desired.
fallbackCh := make(chan bool, 1)
if (!m.config.DisableTcpPings) && (node.PMax >= 3) {
disableTcpPings := m.config.DisableTcpPings ||
(m.config.DisableTcpPingsForNode != nil && m.config.DisableTcpPingsForNode(node.Name))
if (!disableTcpPings) && (node.PMax >= 3) {
go func() {
defer close(fallbackCh)
didContact, err := m.sendPingAndWaitForAck(node.Address(), ping, deadline)
didContact, err := m.sendPingAndWaitForAck(node.FullAddress(), ping, deadline)
if err != nil {
m.logger.Printf("[ERR] memberlist: Failed fallback ping: %s", err)
var to string
if ne, ok := err.(net.Error); ok && ne.Timeout() {
to = fmt.Sprintf("timeout %s: ", probeInterval)
}
m.logger.Printf("[ERR] memberlist: Failed fallback TCP ping: %s%s", to, err)
} else {
fallbackCh <- didContact
}
@ -392,7 +478,7 @@ func (m *Memberlist) probeNode(node *nodeState) {
// any additional time here.
for didContact := range fallbackCh {
if didContact {
m.logger.Printf("[WARN] memberlist: Was able to connect to %s but other probes failed, network may be misconfigured", node.Name)
m.logger.Printf("[WARN] memberlist: Was able to connect to %s over TCP but UDP probes failed, network may be misconfigured", node.Name)
return
}
}
@ -422,12 +508,21 @@ func (m *Memberlist) probeNode(node *nodeState) {
// Ping initiates a ping to the node with the specified name.
func (m *Memberlist) Ping(node string, addr net.Addr) (time.Duration, error) {
// Prepare a ping message and setup an ack handler.
ping := ping{SeqNo: m.nextSeqNo(), Node: node}
selfAddr, selfPort := m.getAdvertise()
ping := ping{
SeqNo: m.nextSeqNo(),
Node: node,
SourceAddr: selfAddr,
SourcePort: selfPort,
SourceNode: m.config.Name,
}
ackCh := make(chan ackMessage, m.config.IndirectChecks+1)
m.setProbeChannels(ping.SeqNo, ackCh, nil, m.config.ProbeInterval)
a := Address{Addr: addr.String(), Name: node}
// Send a ping to the node.
if err := m.encodeAndSendMsg(addr.String(), pingMsg, &ping); err != nil {
if err := m.encodeAndSendMsg(a, pingMsg, &ping); err != nil {
return 0, err
}
@ -478,7 +573,7 @@ func (m *Memberlist) resetNodes() {
// gossip is invoked every GossipInterval period to broadcast our gossip
// messages to a few random nodes.
func (m *Memberlist) gossip() {
defer metrics.MeasureSince([]string{"memberlist", "gossip"}, time.Now())
defer metrics.MeasureSinceWithLabels([]string{"memberlist", "gossip"}, time.Now(), m.metricLabels)
// Get some random live, suspect, or recently dead nodes
m.nodeLock.RLock()
@ -488,10 +583,10 @@ func (m *Memberlist) gossip() {
}
switch n.State {
case stateAlive, stateSuspect:
case StateAlive, StateSuspect:
return false
case stateDead:
case StateDead:
return time.Since(n.StateChange) > m.config.GossipToTheDeadTime
default:
@ -501,7 +596,7 @@ func (m *Memberlist) gossip() {
m.nodeLock.RUnlock()
// Compute the bytes available
bytesAvail := m.config.UDPBufferSize - compoundHeaderOverhead
bytesAvail := m.config.UDPBufferSize - compoundHeaderOverhead - labelOverhead(m.config.Label)
if m.config.EncryptionEnabled() {
bytesAvail -= encryptOverhead(m.encryptionVersion())
}
@ -516,14 +611,16 @@ func (m *Memberlist) gossip() {
addr := node.Address()
if len(msgs) == 1 {
// Send single message as is
if err := m.rawSendMsgPacket(addr, &node.Node, msgs[0]); err != nil {
if err := m.rawSendMsgPacket(node.FullAddress(), &node, msgs[0]); err != nil {
m.logger.Printf("[ERR] memberlist: Failed to send gossip to %s: %s", addr, err)
}
} else {
// Otherwise create and send a compound message
compound := makeCompoundMessage(msgs)
if err := m.rawSendMsgPacket(addr, &node.Node, compound.Bytes()); err != nil {
m.logger.Printf("[ERR] memberlist: Failed to send gossip to %s: %s", addr, err)
// Otherwise create and send one or more compound messages
compounds := makeCompoundMessages(msgs)
for _, compound := range compounds {
if err := m.rawSendMsgPacket(node.FullAddress(), &node, compound.Bytes()); err != nil {
m.logger.Printf("[ERR] memberlist: Failed to send gossip to %s: %s", addr, err)
}
}
}
}
@ -538,7 +635,7 @@ func (m *Memberlist) pushPull() {
m.nodeLock.RLock()
nodes := kRandomNodes(1, m.nodes, func(n *nodeState) bool {
return n.Name == m.config.Name ||
n.State != stateAlive
n.State != StateAlive
})
m.nodeLock.RUnlock()
@ -549,17 +646,17 @@ func (m *Memberlist) pushPull() {
node := nodes[0]
// Attempt a push pull
if err := m.pushPullNode(node.Address(), false); err != nil {
if err := m.pushPullNode(node.FullAddress(), false); err != nil {
m.logger.Printf("[ERR] memberlist: Push/Pull with %s failed: %s", node.Name, err)
}
}
// pushPullNode does a complete state exchange with a specific node.
func (m *Memberlist) pushPullNode(addr string, join bool) error {
defer metrics.MeasureSince([]string{"memberlist", "pushPullNode"}, time.Now())
func (m *Memberlist) pushPullNode(a Address, join bool) error {
defer metrics.MeasureSinceWithLabels([]string{"memberlist", "pushPullNode"}, time.Now(), m.metricLabels)
// Attempt to send and receive with the node
remote, userState, err := m.sendAndReceiveState(addr, join)
remote, userState, err := m.sendAndReceiveState(a, join)
if err != nil {
return err
}
@ -596,7 +693,7 @@ func (m *Memberlist) verifyProtocol(remote []pushNodeState) error {
for _, rn := range remote {
// If the node isn't alive, then skip it
if rn.State != stateAlive {
if rn.State != StateAlive {
continue
}
@ -625,7 +722,7 @@ func (m *Memberlist) verifyProtocol(remote []pushNodeState) error {
for _, n := range m.nodes {
// Ignore non-alive nodes
if n.State != stateAlive {
if n.State != StateAlive {
continue
}
@ -841,11 +938,26 @@ func (m *Memberlist) aliveNode(a *alive, notify chan struct{}, bootstrap bool) {
return
}
if len(a.Vsn) >= 3 {
pMin := a.Vsn[0]
pMax := a.Vsn[1]
pCur := a.Vsn[2]
if pMin == 0 || pMax == 0 || pMin > pMax {
m.logger.Printf("[WARN] memberlist: Ignoring an alive message for '%s' (%v:%d) because protocol version(s) are wrong: %d <= %d <= %d should be >0", a.Node, net.IP(a.Addr), a.Port, pMin, pCur, pMax)
return
}
}
// Invoke the Alive delegate if any. This can be used to filter out
// alive messages based on custom logic. For example, using a cluster name.
// Using a merge delegate is not enough, as it is possible for passive
// cluster merging to still occur.
if m.config.Alive != nil {
if len(a.Vsn) < 6 {
m.logger.Printf("[WARN] memberlist: ignoring alive message for '%s' (%v:%d) because Vsn is not present",
a.Node, net.IP(a.Addr), a.Port)
return
}
node := &Node{
Name: a.Node,
Addr: a.Addr,
@ -867,7 +979,13 @@ func (m *Memberlist) aliveNode(a *alive, notify chan struct{}, bootstrap bool) {
// Check if we've never seen this node before, and if not, then
// store this node in our node map.
var updatesNode bool
if !ok {
errCon := m.config.IPAllowed(a.Addr)
if errCon != nil {
m.logger.Printf("[WARN] memberlist: Rejected node %s (%v): %s", a.Node, net.IP(a.Addr), errCon)
return
}
state = &nodeState{
Node: Node{
Name: a.Node,
@ -875,7 +993,15 @@ func (m *Memberlist) aliveNode(a *alive, notify chan struct{}, bootstrap bool) {
Port: a.Port,
Meta: a.Meta,
},
State: stateDead,
State: StateDead,
}
if len(a.Vsn) > 5 {
state.PMin = a.Vsn[0]
state.PMax = a.Vsn[1]
state.PCur = a.Vsn[2]
state.DMin = a.Vsn[3]
state.DMax = a.Vsn[4]
state.DCur = a.Vsn[5]
}
// Add to map
@ -894,29 +1020,45 @@ func (m *Memberlist) aliveNode(a *alive, notify chan struct{}, bootstrap bool) {
// Update numNodes after we've added a new node
atomic.AddUint32(&m.numNodes, 1)
}
// Check if this address is different than the existing node
if !bytes.Equal([]byte(state.Addr), a.Addr) || state.Port != a.Port {
m.logger.Printf("[ERR] memberlist: Conflicting address for %s. Mine: %v:%d Theirs: %v:%d",
state.Name, state.Addr, state.Port, net.IP(a.Addr), a.Port)
// Inform the conflict delegate if provided
if m.config.Conflict != nil {
other := Node{
Name: a.Node,
Addr: a.Addr,
Port: a.Port,
Meta: a.Meta,
} else {
// Check if this address is different than the existing node unless the old node is dead.
if !bytes.Equal([]byte(state.Addr), a.Addr) || state.Port != a.Port {
errCon := m.config.IPAllowed(a.Addr)
if errCon != nil {
m.logger.Printf("[WARN] memberlist: Rejected IP update from %v to %v for node %s: %s", a.Node, state.Addr, net.IP(a.Addr), errCon)
return
}
// If DeadNodeReclaimTime is configured, check if enough time has elapsed since the node died.
canReclaim := (m.config.DeadNodeReclaimTime > 0 &&
time.Since(state.StateChange) > m.config.DeadNodeReclaimTime)
// Allow the address to be updated if a dead node is being replaced.
if state.State == StateLeft || (state.State == StateDead && canReclaim) {
m.logger.Printf("[INFO] memberlist: Updating address for left or failed node %s from %v:%d to %v:%d",
state.Name, state.Addr, state.Port, net.IP(a.Addr), a.Port)
updatesNode = true
} else {
m.logger.Printf("[ERR] memberlist: Conflicting address for %s. Mine: %v:%d Theirs: %v:%d Old state: %v",
state.Name, state.Addr, state.Port, net.IP(a.Addr), a.Port, state.State)
// Inform the conflict delegate if provided
if m.config.Conflict != nil {
other := Node{
Name: a.Node,
Addr: a.Addr,
Port: a.Port,
Meta: a.Meta,
}
m.config.Conflict.NotifyConflict(&state.Node, &other)
}
return
}
m.config.Conflict.NotifyConflict(&state.Node, &other)
}
return
}
// Bail if the incarnation number is older, and this is not about us
isLocalNode := state.Name == m.config.Name
if a.Incarnation <= state.Incarnation && !isLocalNode {
if a.Incarnation <= state.Incarnation && !isLocalNode && !updatesNode {
return
}
@ -956,9 +1098,8 @@ func (m *Memberlist) aliveNode(a *alive, notify chan struct{}, bootstrap bool) {
bytes.Equal(a.Vsn, versions) {
return
}
m.refute(state, a.Incarnation)
m.logger.Printf("[WARN] memberlist: Refuting an alive message")
m.logger.Printf("[WARN] memberlist: Refuting an alive message for '%s' (%v:%d) meta:(%v VS %v), vsn:(%v VS %v)", a.Node, net.IP(a.Addr), a.Port, a.Meta, state.Meta, a.Vsn, versions)
} else {
m.encodeBroadcastNotify(a.Node, aliveMsg, a, notify)
@ -975,19 +1116,21 @@ func (m *Memberlist) aliveNode(a *alive, notify chan struct{}, bootstrap bool) {
// Update the state and incarnation number
state.Incarnation = a.Incarnation
state.Meta = a.Meta
if state.State != stateAlive {
state.State = stateAlive
state.Addr = a.Addr
state.Port = a.Port
if state.State != StateAlive {
state.State = StateAlive
state.StateChange = time.Now()
}
}
// Update metrics
metrics.IncrCounter([]string{"memberlist", "msg", "alive"}, 1)
metrics.IncrCounterWithLabels([]string{"memberlist", "msg", "alive"}, 1, m.metricLabels)
// Notify the delegate of any relevant updates
if m.config.Events != nil {
if oldState == stateDead {
// if Dead -> Alive, notify of join
if oldState == StateDead || oldState == StateLeft {
// if Dead/Left -> Alive, notify of join
m.config.Events.NotifyJoin(&state.Node)
} else if !bytes.Equal(oldMeta, state.Meta) {
@ -1026,7 +1169,7 @@ func (m *Memberlist) suspectNode(s *suspect) {
}
// Ignore non-alive nodes
if state.State != stateAlive {
if state.State != StateAlive {
return
}
@ -1040,11 +1183,11 @@ func (m *Memberlist) suspectNode(s *suspect) {
}
// Update metrics
metrics.IncrCounter([]string{"memberlist", "msg", "suspect"}, 1)
metrics.IncrCounterWithLabels([]string{"memberlist", "msg", "suspect"}, 1, m.metricLabels)
// Update the state
state.Incarnation = s.Incarnation
state.State = stateSuspect
state.State = StateSuspect
changeTime := time.Now()
state.StateChange = changeTime
@ -1066,20 +1209,25 @@ func (m *Memberlist) suspectNode(s *suspect) {
min := suspicionTimeout(m.config.SuspicionMult, n, m.config.ProbeInterval)
max := time.Duration(m.config.SuspicionMaxTimeoutMult) * min
fn := func(numConfirmations int) {
var d *dead
m.nodeLock.Lock()
state, ok := m.nodeMap[s.Node]
timeout := ok && state.State == stateSuspect && state.StateChange == changeTime
timeout := ok && state.State == StateSuspect && state.StateChange == changeTime
if timeout {
d = &dead{Incarnation: state.Incarnation, Node: state.Name, From: m.config.Name}
}
m.nodeLock.Unlock()
if timeout {
if k > 0 && numConfirmations < k {
metrics.IncrCounter([]string{"memberlist", "degraded", "timeout"}, 1)
metrics.IncrCounterWithLabels([]string{"memberlist", "degraded", "timeout"}, 1, m.metricLabels)
}
m.logger.Printf("[INFO] memberlist: Marking %s as failed, suspect timeout reached (%d peer confirmations)",
state.Name, numConfirmations)
d := dead{Incarnation: state.Incarnation, Node: state.Name, From: m.config.Name}
m.deadNode(&d)
m.deadNode(d)
}
}
m.nodeTimers[s.Node] = newSuspicion(s.From, k, min, max, fn)
@ -1106,7 +1254,7 @@ func (m *Memberlist) deadNode(d *dead) {
delete(m.nodeTimers, d.Node)
// Ignore if node is already dead
if state.State == stateDead {
if state.DeadOrLeft() {
return
}
@ -1126,11 +1274,18 @@ func (m *Memberlist) deadNode(d *dead) {
}
// Update metrics
metrics.IncrCounter([]string{"memberlist", "msg", "dead"}, 1)
metrics.IncrCounterWithLabels([]string{"memberlist", "msg", "dead"}, 1, m.metricLabels)
// Update the state
state.Incarnation = d.Incarnation
state.State = stateDead
// If the dead message was sent by the node itself, mark it as left
// instead of dead.
if d.Node == d.From {
state.State = StateLeft
} else {
state.State = StateDead
}
state.StateChange = time.Now()
// Notify of death
@ -1144,7 +1299,7 @@ func (m *Memberlist) deadNode(d *dead) {
func (m *Memberlist) mergeState(remote []pushNodeState) {
for _, r := range remote {
switch r.State {
case stateAlive:
case StateAlive:
a := alive{
Incarnation: r.Incarnation,
Node: r.Name,
@ -1155,11 +1310,14 @@ func (m *Memberlist) mergeState(remote []pushNodeState) {
}
m.aliveNode(&a, nil, false)
case stateDead:
case StateLeft:
d := dead{Incarnation: r.Incarnation, Node: r.Name, From: r.Name}
m.deadNode(&d)
case StateDead:
// If the remote node believes a node is dead, we prefer to
// suspect that node instead of declaring it dead instantly
fallthrough
case stateSuspect:
case StateSuspect:
s := suspect{Incarnation: r.Incarnation, Node: r.Name, From: m.config.Name}
m.suspectNode(&s)
}

View file

@ -1,6 +1,7 @@
package memberlist
import (
"fmt"
"net"
"time"
)
@ -63,3 +64,97 @@ type Transport interface {
// transport a chance to clean up any listeners.
Shutdown() error
}
// Address identifies the target of a transport operation: a network address
// plus, optionally, the name of the node being addressed.
type Address struct {
	// Addr is the network address in string form, in the style accepted by
	// Dial; usually "host:port". It must be set.
	Addr string

	// Name is the name of the addressed node. It may be left empty, although
	// some transports may require it.
	Name string
}

// String renders the address for display: "name (addr)" when a node name is
// set, otherwise just the bare network address.
func (a *Address) String() string {
	if a.Name == "" {
		return a.Addr
	}
	return fmt.Sprintf("%s (%s)", a.Name, a.Addr)
}
// IngestionAwareTransport is not used.
//
// Deprecated: IngestionAwareTransport is not used and may be removed in a future
// version. Define the interface locally instead of referencing this exported
// interface.
type IngestionAwareTransport interface {
	// IngestPacket hands a connection, a peer address and a timestamp to the
	// transport. NOTE(review): presumably now is the receive time and
	// shouldClose asks the transport to close conn when done — confirm
	// against an implementation.
	IngestPacket(conn net.Conn, addr net.Addr, now time.Time, shouldClose bool) error
	// IngestStream hands a stream connection to the transport.
	// NOTE(review): ownership of conn after the call is not visible here.
	IngestStream(conn net.Conn) error
}
// NodeAwareTransport extends Transport with variants of the write and dial
// operations that take an Address (network address plus optional node name)
// instead of a bare address string.
type NodeAwareTransport interface {
	Transport
	// WriteToAddress sends b to addr. NOTE(review): the returned time.Time
	// presumably carries the same meaning as the one from Transport.WriteTo —
	// confirm against the Transport documentation.
	WriteToAddress(b []byte, addr Address) (time.Time, error)
	// DialAddressTimeout opens a connection to addr, given a timeout.
	DialAddressTimeout(addr Address, timeout time.Duration) (net.Conn, error)
}
// shimNodeAwareTransport adapts a plain Transport to the NodeAwareTransport
// interface by forwarding only the network address and ignoring the node name.
type shimNodeAwareTransport struct {
	Transport
}

// Compile-time check that the shim satisfies NodeAwareTransport.
var _ NodeAwareTransport = (*shimNodeAwareTransport)(nil)

// WriteToAddress forwards to the embedded Transport's WriteTo using addr.Addr;
// addr.Name is not used.
func (t *shimNodeAwareTransport) WriteToAddress(b []byte, addr Address) (time.Time, error) {
	return t.WriteTo(b, addr.Addr)
}

// DialAddressTimeout forwards to the embedded Transport's DialTimeout using
// addr.Addr; addr.Name is not used.
func (t *shimNodeAwareTransport) DialAddressTimeout(addr Address, timeout time.Duration) (net.Conn, error) {
	return t.DialTimeout(addr.Addr, timeout)
}
// labelWrappedTransport decorates a NodeAwareTransport so that its write and
// dial methods add a label header (built from label) to outgoing packets and
// to newly dialed streams before handing them on.
type labelWrappedTransport struct {
	// label is the value passed to the label-header helpers on every write
	// and dial.
	label string
	NodeAwareTransport
}

// Compile-time check that the wrapper satisfies NodeAwareTransport.
var _ NodeAwareTransport = (*labelWrappedTransport)(nil)
// WriteToAddress adds the label header to buf and sends the result to addr
// through the wrapped transport. If the header cannot be added, nothing is
// sent and a wrapped error is returned together with the zero time.
func (t *labelWrappedTransport) WriteToAddress(buf []byte, addr Address) (time.Time, error) {
	labeled, err := AddLabelHeaderToPacket(buf, t.label)
	if err != nil {
		return time.Time{}, fmt.Errorf("failed to add label header to packet: %w", err)
	}
	return t.NodeAwareTransport.WriteToAddress(labeled, addr)
}
// WriteTo adds the label header to buf and sends the result to the given
// address string through the wrapped transport. If the header cannot be added,
// nothing is sent and the zero time is returned. The error is wrapped with the
// same context as WriteToAddress so both write paths report the failure
// identically; the underlying error remains reachable via errors.Is/As.
func (t *labelWrappedTransport) WriteTo(buf []byte, addr string) (time.Time, error) {
	var err error
	buf, err = AddLabelHeaderToPacket(buf, t.label)
	if err != nil {
		return time.Time{}, fmt.Errorf("failed to add label header to packet: %w", err)
	}
	return t.NodeAwareTransport.WriteTo(buf, addr)
}
// DialAddressTimeout dials addr through the wrapped transport and then adds
// the label header to the new stream. It returns the connection only after the
// header has been added successfully; on any failure it returns a nil
// connection and an error.
func (t *labelWrappedTransport) DialAddressTimeout(addr Address, timeout time.Duration) (net.Conn, error) {
	conn, err := t.NodeAwareTransport.DialAddressTimeout(addr, timeout)
	if err != nil {
		return nil, err
	}
	if err := AddLabelHeaderToStream(conn, t.label); err != nil {
		// The caller never receives conn on this path, so it must be closed
		// here or the connection leaks. The close error is dropped on purpose:
		// the header failure is the error worth reporting.
		_ = conn.Close()
		return nil, fmt.Errorf("failed to add label header to stream: %w", err)
	}
	return conn, nil
}
// DialTimeout dials the given address string through the wrapped transport
// and then adds the label header to the new stream. It returns the connection
// only after the header has been added successfully; on any failure it returns
// a nil connection and an error.
func (t *labelWrappedTransport) DialTimeout(addr string, timeout time.Duration) (net.Conn, error) {
	conn, err := t.NodeAwareTransport.DialTimeout(addr, timeout)
	if err != nil {
		return nil, err
	}
	if err := AddLabelHeaderToStream(conn, t.label); err != nil {
		// The caller never receives conn on this path, so it must be closed
		// here or the connection leaks. The close error is dropped on purpose:
		// the header failure is the error worth reporting.
		_ = conn.Close()
		return nil, fmt.Errorf("failed to add label header to stream: %w", err)
	}
	return conn, nil
}

View file

@ -78,10 +78,9 @@ func retransmitLimit(retransmitMult, n int) int {
// shuffleNodes randomly shuffles the input nodes using the Fisher-Yates shuffle
func shuffleNodes(nodes []*nodeState) {
n := len(nodes)
for i := n - 1; i > 0; i-- {
j := rand.Intn(i + 1)
rand.Shuffle(n, func(i, j int) {
nodes[i], nodes[j] = nodes[j], nodes[i]
}
})
}
// pushPullScale is used to scale the time interval at which push/pull
@ -97,13 +96,13 @@ func pushPullScale(interval time.Duration, n int) time.Duration {
return time.Duration(multiplier) * interval
}
// moveDeadNodes moves nodes that are dead and beyond the gossip to the dead interval
// moveDeadNodes moves dead and left nodes that have not changed during the gossipToTheDeadTime interval
// to the end of the slice and returns the index of the first moved node.
func moveDeadNodes(nodes []*nodeState, gossipToTheDeadTime time.Duration) int {
numDead := 0
n := len(nodes)
for i := 0; i < n-numDead; i++ {
if nodes[i].State != stateDead {
if !nodes[i].DeadOrLeft() {
continue
}
@ -120,39 +119,56 @@ func moveDeadNodes(nodes []*nodeState, gossipToTheDeadTime time.Duration) int {
return n - numDead
}
// kRandomNodes is used to select up to k random nodes, excluding any nodes where
// the filter function returns true. It is possible that less than k nodes are
// kRandomNodes is used to select up to k random Nodes, excluding any nodes where
// the exclude function returns true. It is possible that less than k nodes are
// returned.
func kRandomNodes(k int, nodes []*nodeState, filterFn func(*nodeState) bool) []*nodeState {
func kRandomNodes(k int, nodes []*nodeState, exclude func(*nodeState) bool) []Node {
n := len(nodes)
kNodes := make([]*nodeState, 0, k)
kNodes := make([]Node, 0, k)
OUTER:
// Probe up to 3*n times, with large n this is not necessary
// since k << n, but with small n we want search to be
// exhaustive
for i := 0; i < 3*n && len(kNodes) < k; i++ {
// Get random node
// Get random nodeState
idx := randomOffset(n)
node := nodes[idx]
state := nodes[idx]
// Give the filter a shot at it.
if filterFn != nil && filterFn(node) {
if exclude != nil && exclude(state) {
continue OUTER
}
// Check if we have this node already
for j := 0; j < len(kNodes); j++ {
if node == kNodes[j] {
if state.Node.Name == kNodes[j].Name {
continue OUTER
}
}
// Append the node
kNodes = append(kNodes, node)
kNodes = append(kNodes, state.Node)
}
return kNodes
}
// makeCompoundMessages splits msgs into consecutive groups of at most 255
// messages (the limit of a single compound message) and packs each group with
// makeCompoundMessage. The returned buffers preserve the input order; an empty
// input yields an empty result.
func makeCompoundMessages(msgs [][]byte) []*bytes.Buffer {
	const maxMsgs = 255

	// Ceiling division gives the number of compound messages required.
	numBufs := (len(msgs) + (maxMsgs - 1)) / maxMsgs
	bufs := make([]*bytes.Buffer, 0, numBufs)

	for len(msgs) > 0 {
		size := len(msgs)
		if size > maxMsgs {
			size = maxMsgs
		}
		bufs = append(bufs, makeCompoundMessage(msgs[:size]))
		msgs = msgs[size:]
	}
	return bufs
}
// makeCompoundMessage takes a list of messages and generates
// a single compound message containing all of them
func makeCompoundMessage(msgs [][]byte) *bytes.Buffer {
@ -186,18 +202,18 @@ func decodeCompoundMessage(buf []byte) (trunc int, parts [][]byte, err error) {
err = fmt.Errorf("missing compound length byte")
return
}
numParts := uint8(buf[0])
numParts := int(buf[0])
buf = buf[1:]
// Check we have enough bytes
if len(buf) < int(numParts*2) {
if len(buf) < numParts*2 {
err = fmt.Errorf("truncated len slice")
return
}
// Decode the lengths
lengths := make([]uint16, numParts)
for i := 0; i < int(numParts); i++ {
for i := 0; i < numParts; i++ {
lengths[i] = binary.BigEndian.Uint16(buf[i*2 : i*2+2])
}
buf = buf[numParts*2:]
@ -205,7 +221,7 @@ func decodeCompoundMessage(buf []byte) (trunc int, parts [][]byte, err error) {
// Split each message
for idx, msgLen := range lengths {
if len(buf) < int(msgLen) {
trunc = int(numParts) - idx
trunc = numParts - idx
return
}