Merge pull request #14071 from mavenugo/mhnet

Vendoring libnetwork 83743db8ceb2bdbfa0960d9a54ed2f98df4ea846
This commit is contained in:
David Calavera 2015-06-22 08:35:18 -07:00
commit f39b9a0b0f
102 changed files with 17526 additions and 209 deletions

View file

@ -569,7 +569,6 @@ func (b *Builder) create() (*daemon.Container, error) {
CgroupParent: b.cgroupParent,
Memory: b.memory,
MemorySwap: b.memorySwap,
NetworkMode: "bridge",
}
config := *b.Config

View file

@ -33,6 +33,7 @@ type CommonConfig struct {
Root string
TrustKeyPath string
DefaultNetwork string
NetworkKVStore string
}
// InstallCommonFlags adds command-line options to the top-level flag parser for
@ -51,7 +52,6 @@ func (config *Config) InstallCommonFlags() {
flag.IntVar(&config.Mtu, []string{"#mtu", "-mtu"}, 0, "Set the containers network MTU")
flag.BoolVar(&config.EnableCors, []string{"#api-enable-cors", "#-api-enable-cors"}, false, "Enable CORS headers in the remote API, this is deprecated by --api-cors-header")
flag.StringVar(&config.CorsHeaders, []string{"-api-cors-header"}, "", "Set CORS headers in the remote API")
flag.StringVar(&config.DefaultNetwork, []string{"-default-network"}, "", "Set default network")
// FIXME: why the inconsistency between "hosts" and "sockets"?
opts.IPListVar(&config.Dns, []string{"#dns", "-dns"}, "DNS server to use")
opts.DnsSearchListVar(&config.DnsSearch, []string{"-dns-search"}, "DNS search domains to use")

View file

@ -0,0 +1,10 @@
// +build experimental
package daemon
import flag "github.com/docker/docker/pkg/mflag"
// attachExperimentalFlags registers the daemon flags that exist only in
// builds made with the "experimental" build tag: "-default-network" (stored
// in config.DefaultNetwork) and "-kv-store" (stored in
// config.NetworkKVStore). Both default to the empty string.
func (config *Config) attachExperimentalFlags() {
flag.StringVar(&config.DefaultNetwork, []string{"-default-network"}, "", "Set default network")
flag.StringVar(&config.NetworkKVStore, []string{"-kv-store"}, "", "Set KV Store configuration")
}

View file

@ -73,4 +73,5 @@ func (config *Config) InstallFlags() {
flag.BoolVar(&config.Bridge.InterContainerCommunication, []string{"#icc", "-icc"}, true, "Enable inter-container communication")
opts.IPVar(&config.Bridge.DefaultIP, []string{"#ip", "-ip"}, "0.0.0.0", "Default IP when binding container ports")
flag.BoolVar(&config.Bridge.EnableUserlandProxy, []string{"-userland-proxy"}, true, "Use userland proxy for loopback traffic")
config.attachExperimentalFlags()
}

6
daemon/config_stub.go Normal file
View file

@ -0,0 +1,6 @@
// +build !experimental
package daemon
// attachExperimentalFlags is the no-op variant compiled when the
// "experimental" build tag is absent; it registers no flags.
func (config *Config) attachExperimentalFlags() {
}

View file

@ -768,6 +768,20 @@ func createNetwork(controller libnetwork.NetworkController, dnet string, driver
return controller.NewNetwork(driver, dnet, createOptions...)
}
// secondaryNetworkRequired reports whether the container needs a secondary
// network in addition to its primary one. The answer is always false for the
// primary network types "bridge", "none", "host" and "container"; for any
// other type it is true when the container has exposed ports or port
// bindings.
func (container *Container) secondaryNetworkRequired(primaryNetworkType string) bool {
	builtin := primaryNetworkType == "bridge" ||
		primaryNetworkType == "none" ||
		primaryNetworkType == "host" ||
		primaryNetworkType == "container"
	if builtin {
		return false
	}

	// len of a nil map or slice is zero, so no separate nil check is needed.
	hasExposedPorts := len(container.Config.ExposedPorts) > 0
	hasPortBindings := len(container.hostConfig.PortBindings) > 0
	return hasExposedPorts || hasPortBindings
}
func (container *Container) AllocateNetwork() error {
mode := container.hostConfig.NetworkMode
controller := container.daemon.netController
@ -775,7 +789,7 @@ func (container *Container) AllocateNetwork() error {
return nil
}
var networkDriver string
networkDriver := string(mode)
service := container.Config.PublishService
networkName := mode.NetworkName()
if mode.IsDefault() {
@ -790,15 +804,32 @@ func (container *Container) AllocateNetwork() error {
}
if service == "" {
// dot character "." has a special meaning to support SERVICE[.NETWORK] format.
// For backward compatibility, replacing "." with "-", instead of failing
service = strings.Replace(container.Name, ".", "-", -1)
// Service names cannot contain "/"; remove it instead of failing, for backward compatibility
service = strings.Replace(service, "/", "", -1)
}
var err error
if container.secondaryNetworkRequired(networkDriver) {
// Configure Bridge as secondary network for port binding purposes
if err := container.configureNetwork("bridge", service, "bridge", false); err != nil {
return err
}
}
if err := container.configureNetwork(networkName, service, networkDriver, mode.IsDefault()); err != nil {
return err
}
return container.WriteHostConfig()
}
func (container *Container) configureNetwork(networkName, service, networkDriver string, canCreateNetwork bool) error {
controller := container.daemon.netController
n, err := controller.NetworkByName(networkName)
if err != nil {
// Create Network automatically only in default mode
if _, ok := err.(libnetwork.ErrNoSuchNetwork); !ok || !mode.IsDefault() {
if _, ok := err.(libnetwork.ErrNoSuchNetwork); !ok || !canCreateNetwork {
return err
}
@ -841,10 +872,6 @@ func (container *Container) AllocateNetwork() error {
return fmt.Errorf("Updating join info failed: %v", err)
}
if err := container.WriteHostConfig(); err != nil {
return err
}
return nil
}
@ -976,33 +1003,33 @@ func (container *Container) ReleaseNetwork() {
return
}
// If the container is not attached to any network do not try
// to release network and generate spurious error messages.
if container.NetworkSettings.NetworkID == "" {
return
}
n, err := container.daemon.netController.NetworkByID(container.NetworkSettings.NetworkID)
err := container.daemon.netController.LeaveAll(container.ID)
if err != nil {
logrus.Errorf("error locating network id %s: %v", container.NetworkSettings.NetworkID, err)
logrus.Errorf("Leave all failed for %s: %v", container.ID, err)
return
}
ep, err := n.EndpointByID(container.NetworkSettings.EndpointID)
if err != nil {
logrus.Errorf("error locating endpoint id %s: %v", container.NetworkSettings.EndpointID, err)
return
}
if err := ep.Leave(container.ID); err != nil {
logrus.Errorf("leaving endpoint failed: %v", err)
}
if err := ep.Delete(); err != nil {
logrus.Errorf("deleting endpoint failed: %v", err)
}
eid := container.NetworkSettings.EndpointID
nid := container.NetworkSettings.NetworkID
container.NetworkSettings = &network.Settings{}
// In addition to leaving all endpoints, delete implicitly created endpoint
if container.Config.PublishService == "" && eid != "" && nid != "" {
n, err := container.daemon.netController.NetworkByID(nid)
if err != nil {
logrus.Errorf("error locating network id %s: %v", nid, err)
return
}
ep, err := n.EndpointByID(eid)
if err != nil {
logrus.Errorf("error locating endpoint id %s: %v", eid, err)
return
}
if err := ep.Delete(); err != nil {
logrus.Errorf("deleting endpoint failed: %v", err)
}
}
}
func disableAllActiveLinks(container *Container) {

View file

@ -282,6 +282,16 @@ func networkOptions(dconfig *Config) ([]nwconfig.Option, error) {
options = append(options, nwconfig.OptionDefaultDriver(string(dd)))
options = append(options, nwconfig.OptionDefaultNetwork(dn))
}
if strings.TrimSpace(dconfig.NetworkKVStore) != "" {
kv := strings.Split(dconfig.NetworkKVStore, ":")
if len(kv) < 2 {
return nil, fmt.Errorf("kv store daemon config must be of the form KV-PROVIDER:KV-URL")
}
options = append(options, nwconfig.OptionKVProvider(kv[0]))
options = append(options, nwconfig.OptionKVProviderURL(strings.Join(kv[1:], ":")))
}
options = append(options, nwconfig.OptionLabels(dconfig.Labels))
return options, nil
}

View file

@ -45,7 +45,8 @@ Unlike the regular Docker binary, the experimental channels is built and updated
## Current experimental features
* [Support for Docker plugins](plugins.md)
* [Volume plugins](plugins_volume.md)
* [Volume plugins](plugins_volume.md)
* [Native Multi-host networking](networking.md)
## How to comment on an experimental feature

View file

@ -3,9 +3,13 @@
In this feature:
- `network` and `service` become first-class objects in the Docker UI
- You can create networks and attach containers to them
- We introduce the concept of `services`
- This is an entry-point in to a given network that is also published via Service Discovery
- one can now create networks, publish services on that network and attach containers to the services
- Native multi-host networking
- `network` and `service` objects are globally significant and provide multi-host container connectivity natively
- Inbuilt simple Service Discovery
- With multi-host networking and top-level `service` object, Docker now provides out of the box simple Service Discovery for containers running in a network
- Batteries included but removable
- Docker provides inbuilt native multi-host networking by default & can be swapped by any remote driver provided by external plugins.
This is an experimental feature. For information on installing and using experimental features, see [the experimental feature overview](experimental.md).
@ -62,11 +66,14 @@ If you no longer have need of a network, you can delete it with `docker network
bd61375b6993 host host
cc455abccfeb bridge bridge
Docker daemon supports a configuration flag `--default-network` which takes configuration value of format `NETWORK:DRIVER`, where,
`NETWORK` is the name of the network created using the `docker network create` command and
## User-Defined default network
Docker daemon supports a configuration flag `--default-network` which takes configuration value of format `DRIVER:NETWORK`, where,
`DRIVER` is one of the built-in drivers (bridge, overlay, container, host or none) or a remote driver provided via a network plugin, and
`NETWORK` is the name of the network created using the `docker network create` command
When a container is created and if the network mode (`--net`) is not specified, then this default network will be used to connect
the container. If `--default-network` is not specified, the default network will be the `bridge` driver.
Example : `docker -d --default-network=overlay:multihost`
## Using Services
@ -109,6 +116,20 @@ To remove the a service:
$ docker service detach a0ebc12d3e48 my-service.foo
$ docker service unpublish my-service.foo
Send us feedback and comments on [#](https://github.com/docker/docker/issues/?),
## Native Multi-host networking
There is a lot to talk about the native multi-host networking and the `overlay` driver that makes it happen. The technical details are documented under https://github.com/docker/libnetwork/blob/master/docs/overlay.md.
Using the above experimental UI `docker network`, `docker service` and `--publish-service`, the user can exercise the power of multi-host networking.
Since `network` and `service` objects are globally significant, this feature requires distributed states provided by the `libkv` project.
Using `libkv`, the user can plug in any of the supported Key-Value stores (such as consul, etcd or zookeeper).
User can specify the Key-Value store of choice using the `--kv-store` daemon flag, which takes configuration value of format `PROVIDER:URL`, where
`PROVIDER` is the name of the Key-Value store (such as consul, etcd or zookeeper) and
`URL` is the url to reach the Key-Value store.
Example : `docker -d --kv-store=consul:localhost:8500`
Send us feedback and comments on [#14083](https://github.com/docker/docker/issues/14083)
or on the usual Google Groups (docker-user, docker-dev) and IRC channels.

View file

@ -18,7 +18,11 @@ clone git golang.org/x/net 3cffabab72adf04f8e3b01c5baf775361837b5fe https://gith
clone hg code.google.com/p/gosqlite 74691fb6f837
#get libnetwork packages
clone git github.com/docker/libnetwork fc7abaa93fd33a77cc37845adbbc4adf03676dd5
clone git github.com/docker/libnetwork 1aaf1047fd48345619a875184538a0eb6c6cfb2a
clone git github.com/armon/go-metrics eb0af217e5e9747e41dd5303755356b62d28e3ec
clone git github.com/hashicorp/go-msgpack 71c2886f5a673a35f909803f38ece5810165097b
clone git github.com/hashicorp/memberlist 9a1e242e454d2443df330bdd51a436d5a9058fc4
clone git github.com/hashicorp/serf 7151adcef72687bf95f451a2e0ba15cb19412bf2
clone git github.com/docker/libkv e8cde779d58273d240c1eff065352a6cd67027dd
clone git github.com/vishvananda/netns 5478c060110032f972e86a1f844fdb9a2f008f2c
clone git github.com/vishvananda/netlink 8eb64238879fed52fd51c5b30ad20b928fb4c36c

View file

@ -0,0 +1,39 @@
// +build daemon,experimental
package main
import (
"os/exec"
"strings"
"github.com/go-check/check"
)
// assertNetwork fails the test unless some row of the `network ls` output of
// daemon d contains name. The first line of the output and the final element
// of the newline split are not inspected.
func assertNetwork(c *check.C, d *Daemon, name string) {
	out, err := d.Cmd("network", "ls")
	c.Assert(err, check.IsNil)

	rows := strings.Split(out, "\n")
	found := false
	for idx := 1; idx < len(rows)-1 && !found; idx++ {
		found = strings.Contains(rows[idx], name)
	}
	if !found {
		c.Fatalf("Network %s not found in network ls o/p", name)
	}
}
// TestDaemonDefaultNetwork starts the daemon with --default-network set to
// "bridge:testdefault", runs a container without an explicit network mode,
// and then checks that the network shows up in `network ls` and that
// `ifconfig testdefault` succeeds on the host.
func (s *DockerDaemonSuite) TestDaemonDefaultNetwork(c *check.C) {
	const defaultNet = "testdefault"
	testDaemon := s.d

	err := testDaemon.StartWithBusybox("--default-network", "bridge:"+defaultNet)
	c.Assert(err, check.IsNil)

	_, err = testDaemon.Cmd("run", "busybox", "true")
	c.Assert(err, check.IsNil)

	assertNetwork(c, testDaemon, defaultNet)

	_, _, _, err = runCommandWithStdoutStderr(exec.Command("ifconfig", defaultNet))
	c.Assert(err, check.IsNil)
}

View file

@ -0,0 +1,22 @@
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so
# Folders
_obj
_test
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe

View file

@ -0,0 +1,20 @@
The MIT License (MIT)
Copyright (c) 2013 Armon Dadgar
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View file

@ -0,0 +1,68 @@
go-metrics
==========
This library provides a `metrics` package which can be used to instrument code,
expose application metrics, and profile runtime performance in a flexible manner.
Sinks
=====
The `metrics` package makes use of a `MetricSink` interface to support delivery
to any type of backend. Currently the following sinks are provided:
* StatsiteSink : Sinks to a statsite instance (TCP)
* StatsdSink: Sinks to a statsd / statsite instance (UDP)
* InmemSink : Provides in-memory aggregation, can be used to export stats
* FanoutSink : Sinks to multiple sinks. Enables writing to multiple statsite instances for example.
* BlackholeSink : Sinks to nowhere
In addition to the sinks, the `InmemSignal` can be used to catch a signal,
and dump a formatted output of recent metrics. For example, when a process gets
a SIGUSR1, it can dump to stderr recent performance metrics for debugging.
Examples
========
Here is an example of using the package:
func SlowMethod() {
// Profiling the runtime of a method
defer metrics.MeasureSince([]string{"SlowMethod"}, time.Now())
}
// Configure a statsite sink as the global metrics sink
sink, _ := metrics.NewStatsiteSink("statsite:8125")
metrics.NewGlobal(metrics.DefaultConfig("service-name"), sink)
// Emit a Key/Value pair
metrics.EmitKey([]string{"questions", "meaning of life"}, 42)
Here is an example of setting up a signal handler:
// Setup the inmem sink and signal handler
inm := NewInmemSink(10*time.Second, time.Minute)
sig := DefaultInmemSignal(inm)
metrics.NewGlobal(metrics.DefaultConfig("service-name"), inm)
// Run some code
inm.SetGauge([]string{"foo"}, 42)
inm.EmitKey([]string{"bar"}, 30)
inm.IncrCounter([]string{"baz"}, 42)
inm.IncrCounter([]string{"baz"}, 1)
inm.IncrCounter([]string{"baz"}, 80)
inm.AddSample([]string{"method", "wow"}, 42)
inm.AddSample([]string{"method", "wow"}, 100)
inm.AddSample([]string{"method", "wow"}, 22)
....
When a signal comes in, output like the following will be dumped to stderr:
[2014-01-28 14:57:33.04 -0800 PST][G] 'foo': 42.000
[2014-01-28 14:57:33.04 -0800 PST][P] 'bar': 30.000
[2014-01-28 14:57:33.04 -0800 PST][C] 'baz': Count: 3 Min: 1.000 Mean: 41.000 Max: 80.000 Stddev: 39.509
[2014-01-28 14:57:33.04 -0800 PST][S] 'method.wow': Count: 3 Min: 22.000 Mean: 54.667 Max: 100.000 Stddev: 40.513

View file

@ -0,0 +1,12 @@
// +build !windows
package metrics
import (
"syscall"
)
const (
// DefaultSignal is the signal that DefaultInmemSignal listens for.
// On non-Windows builds this is SIGUSR1.
DefaultSignal = syscall.SIGUSR1
)

View file

@ -0,0 +1,13 @@
// +build windows
package metrics
import (
"syscall"
)
const (
// DefaultSignal is the signal that DefaultInmemSignal listens for.
// Windows has no SIGUSR1, so signal number 21 (SIGBREAK) is used instead.
DefaultSignal = syscall.Signal(21)
)

View file

@ -0,0 +1,239 @@
package metrics
import (
"fmt"
"math"
"strings"
"sync"
"time"
)
// InmemSink provides a MetricSink that does in-memory aggregation
// without sending metrics over a network. It can be embedded within
// an application to provide profiling information.
type InmemSink struct {
// interval is the length of each aggregation interval; the current
// interval's start time is time.Now() truncated to this duration.
interval time.Duration
// retain is the retention period; it is only used to derive maxIntervals.
retain time.Duration
// maxIntervals is the maximum length of intervals.
// It is retain / interval.
maxIntervals int
// intervals is a slice of the retained intervals, oldest first.
// It is guarded by intervalLock.
intervals []*IntervalMetrics
intervalLock sync.RWMutex
}
// IntervalMetrics stores the aggregated metrics
// for a specific interval. The embedded RWMutex guards the maps: InmemSink
// takes the write lock while recording and InmemSignal takes the read lock
// while dumping.
type IntervalMetrics struct {
sync.RWMutex
// The start time of the interval
Interval time.Time
// Gauges maps the key to the last set value
Gauges map[string]float32
// Points maps the string to the list of emitted values
// from EmitKey
Points map[string][]float32
// Counters maps the string key to an AggregateSample that
// accumulates every increment ingested for that counter
Counters map[string]*AggregateSample
// Samples maps the key to an AggregateSample,
// which has the rolled up view of a sample
Samples map[string]*AggregateSample
}
// NewIntervalMetrics returns an IntervalMetrics whose Interval is intv and
// whose four metric maps are allocated and empty.
func NewIntervalMetrics(intv time.Time) *IntervalMetrics {
	im := new(IntervalMetrics)
	im.Interval = intv
	im.Gauges = map[string]float32{}
	im.Points = map[string][]float32{}
	im.Counters = map[string]*AggregateSample{}
	im.Samples = map[string]*AggregateSample{}
	return im
}
// AggregateSample holds running aggregate statistics about a stream of
// values. The zero value is an empty sample ready to be used with Ingest.
type AggregateSample struct {
	Count int     // The count of emitted pairs
	Sum   float64 // The sum of values
	SumSq float64 // The sum of squared values
	Min   float64 // Minimum value
	Max   float64 // Maximum value
}

// Stddev returns the sample standard deviation (n-1 denominator) derived
// from Count, Sum and SumSq. It returns 0 when Count is 0 or 1.
func (a *AggregateSample) Stddev() float64 {
	div := float64(a.Count * (a.Count - 1))
	if div == 0 {
		return 0
	}
	num := (float64(a.Count) * a.SumSq) - math.Pow(a.Sum, 2)
	// Rounding in the running sums can leave num marginally below zero when
	// the ingested values are (nearly) identical. The square root of a
	// negative number is NaN, so report "no spread" instead.
	if num < 0 {
		return 0
	}
	return math.Sqrt(num / div)
}

// Mean returns the arithmetic mean of the ingested values, or 0 when no
// value has been ingested.
func (a *AggregateSample) Mean() float64 {
	if a.Count == 0 {
		return 0
	}
	return a.Sum / float64(a.Count)
}

// Ingest folds v into the sample, updating Count, Sum, SumSq, Min and Max.
// The first ingested value always initialises both Min and Max.
func (a *AggregateSample) Ingest(v float64) {
	a.Count++
	a.Sum += v
	a.SumSq += (v * v)
	if v < a.Min || a.Count == 1 {
		a.Min = v
	}
	if v > a.Max || a.Count == 1 {
		a.Max = v
	}
}

// String formats the sample for display. An empty sample prints only its
// count, a sample with zero standard deviation prints count and sum, and any
// other sample prints the full set of statistics.
func (a *AggregateSample) String() string {
	if a.Count == 0 {
		return "Count: 0"
	}
	stddev := a.Stddev()
	if stddev == 0 {
		return fmt.Sprintf("Count: %d Sum: %0.3f", a.Count, a.Sum)
	}
	return fmt.Sprintf("Count: %d Min: %0.3f Mean: %0.3f Max: %0.3f Stddev: %0.3f Sum: %0.3f",
		a.Count, a.Min, a.Mean(), a.Max, stddev, a.Sum)
}
// NewInmemSink builds an in-memory sink that aggregates over windows of
// length interval and keeps at most retain/interval of them.
func NewInmemSink(interval, retain time.Duration) *InmemSink {
	maxIntervals := int(retain / interval)
	return &InmemSink{
		interval:     interval,
		retain:       retain,
		maxIntervals: maxIntervals,
		intervals:    make([]*IntervalMetrics, 0, maxIntervals),
	}
}
// SetGauge stores val as the latest value of the gauge named by key in the
// current interval.
func (i *InmemSink) SetGauge(key []string, val float32) {
	name := i.flattenKey(key)
	current := i.getInterval()

	current.Lock()
	defer current.Unlock()
	current.Gauges[name] = val
}
// EmitKey appends val to the list of points recorded under key in the
// current interval.
func (i *InmemSink) EmitKey(key []string, val float32) {
	name := i.flattenKey(key)
	current := i.getInterval()

	current.Lock()
	defer current.Unlock()
	current.Points[name] = append(current.Points[name], val)
}
// IncrCounter ingests val into the aggregate kept for the counter named by
// key in the current interval, creating the aggregate on first use.
func (i *InmemSink) IncrCounter(key []string, val float32) {
	name := i.flattenKey(key)
	current := i.getInterval()

	current.Lock()
	defer current.Unlock()

	counter := current.Counters[name]
	if counter == nil {
		counter = new(AggregateSample)
		current.Counters[name] = counter
	}
	counter.Ingest(float64(val))
}
// AddSample ingests val into the aggregate kept for the sample named by key
// in the current interval, creating the aggregate on first use.
func (i *InmemSink) AddSample(key []string, val float32) {
	name := i.flattenKey(key)
	current := i.getInterval()

	current.Lock()
	defer current.Unlock()

	sample := current.Samples[name]
	if sample == nil {
		sample = new(AggregateSample)
		current.Samples[name] = sample
	}
	sample.Ingest(float64(val))
}
// Data returns a copy of the slice of retained intervals. The
// IntervalMetrics it points at are shared and may still be written to, so
// callers should take each interval's read lock before reading its maps.
func (i *InmemSink) Data() []*IntervalMetrics {
	// Asking for the current interval forces it to be created first.
	i.getInterval()

	i.intervalLock.RLock()
	defer i.intervalLock.RUnlock()

	snapshot := make([]*IntervalMetrics, len(i.intervals))
	copy(snapshot, i.intervals)
	return snapshot
}
// getExistingInterval returns the newest retained interval if its start time
// equals intv, and nil otherwise. Only the read lock is taken.
func (i *InmemSink) getExistingInterval(intv time.Time) *IntervalMetrics {
	i.intervalLock.RLock()
	defer i.intervalLock.RUnlock()

	if last := len(i.intervals) - 1; last >= 0 && i.intervals[last].Interval == intv {
		return i.intervals[last]
	}
	return nil
}
// createInterval returns the interval starting at intv, appending a new one
// under the write lock unless the newest retained interval already matches.
// After appending, the slice is cut back to its newest maxIntervals entries.
func (i *InmemSink) createInterval(intv time.Time) *IntervalMetrics {
	i.intervalLock.Lock()
	defer i.intervalLock.Unlock()

	// Re-check under the write lock: the interval may have been added since
	// the caller's read-locked lookup.
	if last := len(i.intervals) - 1; last >= 0 && i.intervals[last].Interval == intv {
		return i.intervals[last]
	}

	fresh := NewIntervalMetrics(intv)
	i.intervals = append(i.intervals, fresh)

	// Shift the newest maxIntervals entries to the front and drop the rest.
	if count := len(i.intervals); count >= i.maxIntervals {
		copy(i.intervals, i.intervals[count-i.maxIntervals:])
		i.intervals = i.intervals[:i.maxIntervals]
	}
	return fresh
}
// getInterval returns the interval that covers the current time, creating it
// if it is not already the newest retained interval.
func (i *InmemSink) getInterval() *IntervalMetrics {
	start := time.Now().Truncate(i.interval)
	existing := i.getExistingInterval(start)
	if existing == nil {
		return i.createInterval(start)
	}
	return existing
}
// flattenKey joins the key parts with "." and replaces every space in the
// result with "_".
func (i *InmemSink) flattenKey(parts []string) string {
	return strings.Replace(strings.Join(parts, "."), " ", "_", -1)
}

View file

@ -0,0 +1,100 @@
package metrics
import (
"bytes"
"fmt"
"io"
"os"
"os/signal"
"sync"
"syscall"
)
// InmemSignal is used to listen for a given signal, and when received,
// to dump the current metrics from the InmemSink to an io.Writer
type InmemSignal struct {
// signal is the signal this instance was created for.
signal syscall.Signal
// inm is the sink whose data is dumped.
inm *InmemSink
// w receives the formatted dump.
w io.Writer
// sigCh is the channel registered with signal.Notify.
sigCh chan os.Signal
// stop records that Stop has already run; it is guarded by stopLock.
stop bool
// stopCh is closed by Stop to end the run goroutine.
stopCh chan struct{}
stopLock sync.Mutex
}
// NewInmemSignal returns an InmemSignal that is already subscribed to sig
// and whose handler goroutine is already running; each delivery of sig dumps
// the metrics held by inmem to w.
func NewInmemSignal(inmem *InmemSink, sig syscall.Signal, w io.Writer) *InmemSignal {
	handler := new(InmemSignal)
	handler.signal = sig
	handler.inm = inmem
	handler.w = w
	handler.sigCh = make(chan os.Signal, 1)
	handler.stopCh = make(chan struct{})

	signal.Notify(handler.sigCh, sig)
	go handler.run()
	return handler
}
// DefaultInmemSignal returns an InmemSignal that listens for DefaultSignal
// (SIGUSR1, or SIGBREAK on Windows) and writes its dumps to stderr.
func DefaultInmemSignal(inmem *InmemSink) *InmemSignal {
	handler := NewInmemSignal(inmem, DefaultSignal, os.Stderr)
	return handler
}
// Stop ends signal handling: it closes stopCh so the run goroutine returns
// and unregisters sigCh from signal delivery. Calls after the first are
// no-ops.
func (i *InmemSignal) Stop() {
	i.stopLock.Lock()
	defer i.stopLock.Unlock()

	if !i.stop {
		i.stop = true
		close(i.stopCh)
		signal.Stop(i.sigCh)
	}
}
// run loops until stopCh is closed, dumping the metrics each time a signal
// arrives on sigCh.
func (i *InmemSignal) run() {
	for {
		select {
		case <-i.stopCh:
			return
		case <-i.sigCh:
			i.dumpStats()
		}
	}
}
// dumpStats formats every retained interval except the newest one and writes
// the result to the configured writer in a single Write call.
func (i *InmemSignal) dumpStats() {
	var out bytes.Buffer

	snapshot := i.inm.Data()
	// The last period is still being aggregated, so it is left out.
	for idx := 0; idx+1 < len(snapshot); idx++ {
		period := snapshot[idx]
		period.RLock()
		for name, val := range period.Gauges {
			fmt.Fprintf(&out, "[%v][G] '%s': %0.3f\n", period.Interval, name, val)
		}
		for name, vals := range period.Points {
			for _, val := range vals {
				fmt.Fprintf(&out, "[%v][P] '%s': %0.3f\n", period.Interval, name, val)
			}
		}
		for name, agg := range period.Counters {
			fmt.Fprintf(&out, "[%v][C] '%s': %s\n", period.Interval, name, agg)
		}
		for name, agg := range period.Samples {
			fmt.Fprintf(&out, "[%v][S] '%s': %s\n", period.Interval, name, agg)
		}
		period.RUnlock()
	}

	// Write out the bytes
	i.w.Write(out.Bytes())
}

View file

@ -0,0 +1,115 @@
package metrics
import (
"runtime"
"time"
)
// SetGauge forwards a gauge value to the sink after prefixing key. Each
// prefix is inserted at the front in turn, so the final key reads:
// ServiceName (if non-empty), "gauge" (if EnableTypePrefix), HostName (if
// non-empty and EnableHostname), followed by the caller's parts.
func (m *Metrics) SetGauge(key []string, val float32) {
if m.HostName != "" && m.EnableHostname {
key = insert(0, m.HostName, key)
}
if m.EnableTypePrefix {
key = insert(0, "gauge", key)
}
if m.ServiceName != "" {
key = insert(0, m.ServiceName, key)
}
m.sink.SetGauge(key, val)
}
// EmitKey forwards a key/value pair to the sink. The key is prefixed with
// ServiceName (if non-empty) and then "kv" (if EnableTypePrefix), ahead of
// the caller's parts. Unlike SetGauge, no hostname is added.
func (m *Metrics) EmitKey(key []string, val float32) {
if m.EnableTypePrefix {
key = insert(0, "kv", key)
}
if m.ServiceName != "" {
key = insert(0, m.ServiceName, key)
}
m.sink.EmitKey(key, val)
}
// IncrCounter forwards a counter increment to the sink. The key is prefixed
// with ServiceName (if non-empty) and then "counter" (if EnableTypePrefix),
// ahead of the caller's parts.
func (m *Metrics) IncrCounter(key []string, val float32) {
if m.EnableTypePrefix {
key = insert(0, "counter", key)
}
if m.ServiceName != "" {
key = insert(0, m.ServiceName, key)
}
m.sink.IncrCounter(key, val)
}
// AddSample forwards a sample to the sink. The key is prefixed with
// ServiceName (if non-empty) and then "sample" (if EnableTypePrefix), ahead
// of the caller's parts.
func (m *Metrics) AddSample(key []string, val float32) {
if m.EnableTypePrefix {
key = insert(0, "sample", key)
}
if m.ServiceName != "" {
key = insert(0, m.ServiceName, key)
}
m.sink.AddSample(key, val)
}
// MeasureSince records the time elapsed since start as a sample on the sink.
// The key is prefixed with ServiceName (if non-empty) and then "timer" (if
// EnableTypePrefix), ahead of the caller's parts.
func (m *Metrics) MeasureSince(key []string, start time.Time) {
if m.EnableTypePrefix {
key = insert(0, "timer", key)
}
if m.ServiceName != "" {
key = insert(0, m.ServiceName, key)
}
now := time.Now()
elapsed := now.Sub(start)
// Despite the name, the unit is whatever TimerGranularity is (DefaultConfig
// uses milliseconds).
// NOTE(review): a zero TimerGranularity makes this a floating-point
// division by zero, so the sample would be non-finite — confirm callers
// always set it.
msec := float32(elapsed.Nanoseconds()) / float32(m.TimerGranularity)
m.sink.AddSample(key, msec)
}
// collectStats loops forever, sleeping for ProfileInterval and then emitting
// the runtime statistics. New starts it in its own goroutine when
// EnableRuntimeMetrics is set; the loop has no exit condition.
func (m *Metrics) collectStats() {
for {
time.Sleep(m.ProfileInterval)
m.emitRuntimeStats()
}
}
// emitRuntimeStats publishes runtime statistics: the goroutine count and a
// set of memory/GC figures as gauges under the "runtime" key prefix, plus
// one "gc_pause_ns" sample for each GC run completed since the previous
// call.
func (m *Metrics) emitRuntimeStats() {
// Export number of Goroutines
numRoutines := runtime.NumGoroutine()
m.SetGauge([]string{"runtime", "num_goroutines"}, float32(numRoutines))
// Export memory stats
var stats runtime.MemStats
runtime.ReadMemStats(&stats)
m.SetGauge([]string{"runtime", "alloc_bytes"}, float32(stats.Alloc))
m.SetGauge([]string{"runtime", "sys_bytes"}, float32(stats.Sys))
m.SetGauge([]string{"runtime", "malloc_count"}, float32(stats.Mallocs))
m.SetGauge([]string{"runtime", "free_count"}, float32(stats.Frees))
m.SetGauge([]string{"runtime", "heap_objects"}, float32(stats.HeapObjects))
m.SetGauge([]string{"runtime", "total_gc_pause_ns"}, float32(stats.PauseTotalNs))
m.SetGauge([]string{"runtime", "total_gc_runs"}, float32(stats.NumGC))
// Export info about the last few GC runs
num := stats.NumGC
// Handle wrap around: NumGC went backwards, so restart from zero
if num < m.lastNumGC {
m.lastNumGC = 0
}
// Ensure we don't scan more than 256: PauseNs is indexed modulo 256 below,
// so older entries have been overwritten. This caps the loop at 255 samples.
if num-m.lastNumGC >= 256 {
m.lastNumGC = num - 255
}
for i := m.lastNumGC; i < num; i++ {
pause := stats.PauseNs[i%256]
m.AddSample([]string{"runtime", "gc_pause_ns"}, float32(pause))
}
// Remember how far we got so the next call only reports newer GC runs
m.lastNumGC = num
}
// insert returns a new slice holding the elements of s with v inserted at
// index i. i may be anywhere from 0 to len(s); i == len(s) appends.
//
// The result never shares a backing array with s. Appending to s and
// shifting in place would overwrite the caller's elements whenever s has
// spare capacity, silently changing the key slice the caller passed in.
func insert(i int, v string, s []string) []string {
	out := make([]string, len(s)+1)
	copy(out, s[:i])
	out[i] = v
	copy(out[i+1:], s[i:])
	return out
}

View file

@ -0,0 +1,52 @@
package metrics
// The MetricSink interface is used to transmit metrics information
// to an external system. In every method, key is the metric name split into
// parts and val is the value to record.
type MetricSink interface {
// A Gauge should retain the last value it is set to
SetGauge(key []string, val float32)
// Should emit a Key/Value pair for each call
EmitKey(key []string, val float32)
// Counters should accumulate values
IncrCounter(key []string, val float32)
// Samples are for timing information, where quantiles are used
AddSample(key []string, val float32)
}
// BlackholeSink is a sink that discards everything it is given.
type BlackholeSink struct{}

// SetGauge discards the gauge value.
func (*BlackholeSink) SetGauge([]string, float32) {}

// EmitKey discards the key/value pair.
func (*BlackholeSink) EmitKey([]string, float32) {}

// IncrCounter discards the counter increment.
func (*BlackholeSink) IncrCounter([]string, float32) {}

// AddSample discards the sample.
func (*BlackholeSink) AddSample([]string, float32) {}
// FanoutSink forwards every metric to each of the sinks it contains, in
// slice order.
type FanoutSink []MetricSink

// SetGauge calls SetGauge on every member sink.
func (fh FanoutSink) SetGauge(key []string, val float32) {
	for idx := 0; idx < len(fh); idx++ {
		fh[idx].SetGauge(key, val)
	}
}

// EmitKey calls EmitKey on every member sink.
func (fh FanoutSink) EmitKey(key []string, val float32) {
	for idx := 0; idx < len(fh); idx++ {
		fh[idx].EmitKey(key, val)
	}
}

// IncrCounter calls IncrCounter on every member sink.
func (fh FanoutSink) IncrCounter(key []string, val float32) {
	for idx := 0; idx < len(fh); idx++ {
		fh[idx].IncrCounter(key, val)
	}
}

// AddSample calls AddSample on every member sink.
func (fh FanoutSink) AddSample(key []string, val float32) {
	for idx := 0; idx < len(fh); idx++ {
		fh[idx].AddSample(key, val)
	}
}

View file

@ -0,0 +1,95 @@
package metrics
import (
"os"
"time"
)
// Config is used to configure metrics settings
type Config struct {
ServiceName string // Prefixed to keys to separate services
HostName string // Hostname to use; DefaultConfig fills it from os.Hostname
EnableHostname bool // Enable prefixing gauge values with hostname
EnableRuntimeMetrics bool // Enables profiling of runtime metrics (GC, Goroutines, Memory)
EnableTypePrefix bool // Prefixes key with a type ("counter", "gauge", "timer")
TimerGranularity time.Duration // Granularity of timers; MeasureSince divides elapsed time by it
ProfileInterval time.Duration // Interval to profile runtime metrics
}
// Metrics represents an instance of a metrics sink that can
// be used to emit. It embeds the Config it was created with.
type Metrics struct {
Config
// lastNumGC is the GC run count seen by the previous emitRuntimeStats call.
lastNumGC uint32
// sink receives every metric after key prefixing.
sink MetricSink
}
// globalMetrics is the shared instance used by the package-level SetGauge,
// EmitKey, IncrCounter, AddSample and MeasureSince functions. NewGlobal
// replaces it.
var globalMetrics *Metrics
func init() {
// Initialize to a blackhole sink so the package-level functions are safe
// to call before NewGlobal has been used
globalMetrics = &Metrics{sink: &BlackholeSink{}}
}
// DefaultConfig returns the default configuration for serviceName: hostname
// prefixing and runtime profiling enabled, type prefixes disabled,
// millisecond timers and a one-second profiling interval. HostName is taken
// from os.Hostname; a lookup error is ignored and whatever name was returned
// is used.
func DefaultConfig(serviceName string) *Config {
	// Try to get the hostname
	hostName, _ := os.Hostname()

	return &Config{
		ServiceName:          serviceName, // Use client provided service
		HostName:             hostName,
		EnableHostname:       true,             // Enable hostname prefix
		EnableRuntimeMetrics: true,             // Enable runtime profiling
		EnableTypePrefix:     false,            // Disable type prefix
		TimerGranularity:     time.Millisecond, // Timers are in milliseconds
		ProfileInterval:      time.Second,      // Poll runtime every second
	}
}
// New returns a Metrics that holds a copy of conf and emits to sink. When
// conf.EnableRuntimeMetrics is set, the runtime collector goroutine is
// started. The returned error is always nil.
func New(conf *Config, sink MetricSink) (*Metrics, error) {
	met := &Metrics{Config: *conf, sink: sink}

	// Start the runtime collector
	if conf.EnableRuntimeMetrics {
		go met.collectStats()
	}
	return met, nil
}
// NewGlobal behaves like New and, when New succeeds, also installs the
// result as the instance used by the package-level functions.
func NewGlobal(conf *Config, sink MetricSink) (*Metrics, error) {
	metrics, err := New(conf, sink)
	if err != nil {
		return metrics, err
	}
	globalMetrics = metrics
	return metrics, nil
}
// The functions below proxy to the globalMetrics instance.

// SetGauge calls SetGauge on the global Metrics instance.
func SetGauge(key []string, val float32) {
globalMetrics.SetGauge(key, val)
}
// EmitKey calls EmitKey on the global Metrics instance.
func EmitKey(key []string, val float32) {
globalMetrics.EmitKey(key, val)
}
// IncrCounter calls IncrCounter on the global Metrics instance.
func IncrCounter(key []string, val float32) {
globalMetrics.IncrCounter(key, val)
}
// AddSample calls AddSample on the global Metrics instance.
func AddSample(key []string, val float32) {
globalMetrics.AddSample(key, val)
}
// MeasureSince calls MeasureSince on the global Metrics instance.
func MeasureSince(key []string, start time.Time) {
globalMetrics.MeasureSince(key, start)
}

View file

@ -0,0 +1,154 @@
package metrics
import (
"bytes"
"fmt"
"log"
"net"
"strings"
"time"
)
const (
// statsdMaxLen is the maximum size of a packet
// to send to statsd. flushMetrics writes out its buffer before
// appending a metric that would take it past this length.
statsdMaxLen = 1400
)
// StatsdSink provides a MetricSink that can be used
// with a statsite or statsd metrics server. It uses
// only UDP packets, while StatsiteSink uses TCP.
type StatsdSink struct {
// addr is the address dialed over UDP by flushMetrics.
addr string
// metricQueue carries formatted metric lines to the flushMetrics goroutine.
metricQueue chan string
}
// NewStatsdSink returns a StatsdSink for addr with a 4096-entry metric queue
// and starts the goroutine that flushes the queue. The returned error is
// always nil.
func NewStatsdSink(addr string) (*StatsdSink, error) {
	sink := new(StatsdSink)
	sink.addr = addr
	sink.metricQueue = make(chan string, 4096)

	go sink.flushMetrics()
	return sink, nil
}
// Shutdown is used to stop flushing to statsd. It closes the metric queue.
// NOTE(review): pushMetric sends on the same channel, so emitting a metric
// after Shutdown would presumably panic (send on closed channel), as would a
// second Shutdown call — confirm callers never do either.
func (s *StatsdSink) Shutdown() {
close(s.metricQueue)
}
// SetGauge queues a gauge ("|g") line for the flattened key.
func (s *StatsdSink) SetGauge(key []string, val float32) {
	line := fmt.Sprintf("%s:%f|g\n", s.flattenKey(key), val)
	s.pushMetric(line)
}
// EmitKey queues a key/value ("|kv") line for the flattened key.
func (s *StatsdSink) EmitKey(key []string, val float32) {
	line := fmt.Sprintf("%s:%f|kv\n", s.flattenKey(key), val)
	s.pushMetric(line)
}
// IncrCounter queues a counter ("|c") line for the flattened key.
func (s *StatsdSink) IncrCounter(key []string, val float32) {
	line := fmt.Sprintf("%s:%f|c\n", s.flattenKey(key), val)
	s.pushMetric(line)
}
// AddSample queues a timing sample ("|ms") line for the flattened key.
func (s *StatsdSink) AddSample(key []string, val float32) {
	line := fmt.Sprintf("%s:%f|ms\n", s.flattenKey(key), val)
	s.pushMetric(line)
}
// flattenKey joins parts with "." and replaces every ':' and ' ' in the
// joined string with '_'.
func (s *StatsdSink) flattenKey(parts []string) string {
	sanitize := func(r rune) rune {
		if r == ':' || r == ' ' {
			return '_'
		}
		return r
	}
	return strings.Map(sanitize, strings.Join(parts, "."))
}
// pushMetric does a non-blocking push to the metrics queue: when the queue
// cannot accept m immediately, the default branch runs and m is dropped.
func (s *StatsdSink) pushMetric(m string) {
select {
case s.metricQueue <- m:
default:
}
}
// flushMetrics drains the metric queue and sends the metrics to statsd over
// UDP. Metrics are packed into a buffer that is written out when the next
// metric would push it past statsdMaxLen, or on every flushInterval tick.
//
// On a dial or write error the current socket is closed, queued metrics are
// discarded for five seconds and a new connection is attempted. The loop ends
// when the queue is closed (see Shutdown); the socket is closed on exit.
func (s *StatsdSink) flushMetrics() {
	var sock net.Conn
	var err error
	var wait <-chan time.Time
	ticker := time.NewTicker(flushInterval)
	defer ticker.Stop()

CONNECT:
	// Create a buffer
	buf := bytes.NewBuffer(nil)

	// Attempt to connect
	sock, err = net.Dial("udp", s.addr)
	if err != nil {
		log.Printf("[ERR] Error connecting to statsd! Err: %s", err)
		goto WAIT
	}

	for {
		select {
		case metric, ok := <-s.metricQueue:
			// Get a metric from the queue
			if !ok {
				goto QUIT
			}

			// Check if this would overflow the packet size
			if len(metric)+buf.Len() > statsdMaxLen {
				_, err := sock.Write(buf.Bytes())
				buf.Reset()
				if err != nil {
					log.Printf("[ERR] Error writing to statsd! Err: %s", err)
					goto WAIT
				}
			}

			// Append to the buffer
			buf.WriteString(metric)

		case <-ticker.C:
			if buf.Len() == 0 {
				continue
			}

			_, err := sock.Write(buf.Bytes())
			buf.Reset()
			if err != nil {
				log.Printf("[ERR] Error flushing to statsd! Err: %s", err)
				goto WAIT
			}
		}
	}

WAIT:
	// Release the failed socket before reconnecting; otherwise every
	// reconnect would leak a file descriptor. sock is nil when the dial
	// itself failed.
	if sock != nil {
		sock.Close()
		sock = nil
	}

	// Wait for a while
	wait = time.After(time.Duration(5) * time.Second)
	for {
		select {
		// Dequeue the messages to avoid backlog
		case _, ok := <-s.metricQueue:
			if !ok {
				goto QUIT
			}
		case <-wait:
			goto CONNECT
		}
	}

QUIT:
	// Release the socket that was still open when the queue was closed.
	if sock != nil {
		sock.Close()
	}
	s.metricQueue = nil
}

View file

@ -0,0 +1,142 @@
package metrics
import (
"bufio"
"fmt"
"log"
"net"
"strings"
"time"
)
const (
// We force flush the statsite metrics after this period of
// inactivity. Prevents stats from getting stuck in a buffer
// forever.
flushInterval = 100 * time.Millisecond
)
// StatsiteSink provides a MetricSink that can be used with a
// statsite metrics server
type StatsiteSink struct {
// addr is the address flushMetrics dials over TCP.
addr string
// metricQueue carries formatted metric lines from pushMetric to flushMetrics.
metricQueue chan string
}
// NewStatsiteSink creates a StatsiteSink for addr with a 4096-entry metric
// queue and starts its flushMetrics goroutine. The returned error is always
// nil.
func NewStatsiteSink(addr string) (*StatsiteSink, error) {
	sink := new(StatsiteSink)
	sink.addr = addr
	sink.metricQueue = make(chan string, 4096)
	go sink.flushMetrics()
	return sink, nil
}
// Shutdown is used to stop flushing to statsite. It closes the metric queue,
// which flushMetrics treats as its signal to quit.
func (s *StatsiteSink) Shutdown() {
close(s.metricQueue)
}
// SetGauge queues a gauge ("|g") line for the flattened key.
func (s *StatsiteSink) SetGauge(key []string, val float32) {
	line := fmt.Sprintf("%s:%f|g\n", s.flattenKey(key), val)
	s.pushMetric(line)
}
// EmitKey queues a key/value ("|kv") line for the flattened key.
func (s *StatsiteSink) EmitKey(key []string, val float32) {
	line := fmt.Sprintf("%s:%f|kv\n", s.flattenKey(key), val)
	s.pushMetric(line)
}
// IncrCounter queues a counter ("|c") line for the flattened key.
func (s *StatsiteSink) IncrCounter(key []string, val float32) {
	line := fmt.Sprintf("%s:%f|c\n", s.flattenKey(key), val)
	s.pushMetric(line)
}
// AddSample queues a timing sample ("|ms") line for the flattened key.
func (s *StatsiteSink) AddSample(key []string, val float32) {
	line := fmt.Sprintf("%s:%f|ms\n", s.flattenKey(key), val)
	s.pushMetric(line)
}
// flattenKey joins parts with "." and replaces every ':' and ' ' in the
// joined string with '_'.
func (s *StatsiteSink) flattenKey(parts []string) string {
	sanitize := func(r rune) rune {
		if r == ':' || r == ' ' {
			return '_'
		}
		return r
	}
	return strings.Map(sanitize, strings.Join(parts, "."))
}
// pushMetric does a non-blocking push to the metrics queue: when the queue
// cannot accept m immediately, the default branch runs and m is dropped.
func (s *StatsiteSink) pushMetric(m string) {
select {
case s.metricQueue <- m:
default:
}
}
// flushMetrics drains the metric queue and streams the metrics to statsite
// over TCP through a buffered writer that is flushed on every flushInterval
// tick.
//
// On a dial, write or flush error the current connection is closed, queued
// metrics are discarded for five seconds and a new connection is attempted.
// The loop ends when the queue is closed (see Shutdown); the connection is
// closed on exit.
func (s *StatsiteSink) flushMetrics() {
	var sock net.Conn
	var err error
	var wait <-chan time.Time
	var buffered *bufio.Writer
	ticker := time.NewTicker(flushInterval)
	defer ticker.Stop()

CONNECT:
	// Attempt to connect
	sock, err = net.Dial("tcp", s.addr)
	if err != nil {
		log.Printf("[ERR] Error connecting to statsite! Err: %s", err)
		goto WAIT
	}

	// Create a buffered writer
	buffered = bufio.NewWriter(sock)

	for {
		select {
		case metric, ok := <-s.metricQueue:
			// Get a metric from the queue
			if !ok {
				goto QUIT
			}

			// Try to send to statsite
			_, err := buffered.Write([]byte(metric))
			if err != nil {
				log.Printf("[ERR] Error writing to statsite! Err: %s", err)
				goto WAIT
			}
		case <-ticker.C:
			if err := buffered.Flush(); err != nil {
				log.Printf("[ERR] Error flushing to statsite! Err: %s", err)
				goto WAIT
			}
		}
	}

WAIT:
	// Release the failed connection before reconnecting; otherwise every
	// reconnect would leak a TCP connection. sock is nil when the dial
	// itself failed.
	if sock != nil {
		sock.Close()
		sock = nil
	}

	// Wait for a while
	wait = time.After(time.Duration(5) * time.Second)
	for {
		select {
		// Dequeue the messages to avoid backlog
		case _, ok := <-s.metricQueue:
			if !ok {
				goto QUIT
			}
		case <-wait:
			goto CONNECT
		}
	}

QUIT:
	// Release the connection that was still open when the queue was closed.
	if sock != nil {
		sock.Close()
	}
	s.metricQueue = nil
}

View file

@ -0,0 +1,459 @@
// Package bitseq provides a structure and utilities for representing a long bitmask
// as a sequence of run-length encoded blocks. It operates directly on the encoded
// representation; it does not decode/encode.
package bitseq
import (
"fmt"
"sync"
"github.com/docker/libnetwork/datastore"
"github.com/docker/libnetwork/netutils"
)
// Block Sequence constants
// If needed we can think of making these configurable
const (
// blockLen is the number of bits in one block.
blockLen = 32
// blockBytes is the number of bytes in one block.
blockBytes = blockLen / 8
// blockMAX is the value of a block with every bit set.
blockMAX = 1<<blockLen - 1
// blockFirstBit is the mask selecting the most significant bit of a block.
blockFirstBit = 1 << (blockLen - 1)
)
// Handle contains the sequence representing the bitmask and its identifier
type Handle struct {
// bits is the length of the bitmask, set from numElements in NewHandle.
bits uint32
// unselected starts at numElements and is adjusted by PushReservation.
unselected uint32
// head is the first node of the run-length encoded block list.
head *Sequence
// app and id together form the datastore key (see Key).
app string
id string
// dbIndex is the latest datastore index seen by this handle.
dbIndex uint64
// store is the backing datastore; nil means the handle is local only.
store datastore.DataStore
sync.Mutex
}
// NewHandle returns a thread-safe instance of the bitmask handler covering
// numElements bits, all unset. When ds is nil the handle is purely local.
// Otherwise the handle registers a watch on its datastore key and loads the
// state stored there, if any. The error returned by watchForChanges is
// discarded.
func NewHandle(app string, ds datastore.DataStore, id string, numElements uint32) (*Handle, error) {
	h := &Handle{
		app:        app,
		id:         id,
		store:      ds,
		bits:       numElements,
		unselected: numElements,
		head:       &Sequence{Block: 0x0, Count: getNumBlocks(numElements)},
	}
	if h.store == nil {
		return h, nil
	}

	// Register for status changes
	h.watchForChanges()

	// Get the initial status from the ds if present.
	// We will be getting an instance without a dbIndex
	// (GetObject() does not set it): It is ok for now,
	// it will only cause the first allocation on this
	// node to go through a retry.
	var stored []byte
	err := h.store.GetObject(datastore.Key(h.Key()...), &stored)
	switch err {
	case nil:
		return h, h.FromByteArray(stored)
	case datastore.ErrKeyNotFound:
		return h, nil
	default:
		return nil, err
	}
}
// Sequence represents a recurring sequence of 32 bits long bitmasks
type Sequence struct {
Block uint32 // block representing 4 byte long allocation bitmask
Count uint32 // number of consecutive blocks
Next *Sequence // next sequence
}
// NewSequence returns a single-node sequence representing a bitmask of
// numElements bits, all unset.
func NewSequence(numElements uint32) *Sequence {
	seq := new(Sequence)
	seq.Count = getNumBlocks(numElements)
	return seq
}
// String renders the list starting at this node as
// "(0x<block>, <count>)->...->end".
func (s *Sequence) String() string {
	rest := "end"
	if s.Next != nil {
		rest = s.Next.String()
	}
	return fmt.Sprintf("(0x%x, %d)->%s", s.Block, s.Count, rest)
}
// GetAvailableBit returns the byte and bit position, counted from the most
// significant bit, of the first unset bit in this node's block. It returns
// (-1, -1) when the block is full or the node's Count is zero.
func (s *Sequence) GetAvailableBit() (bytePos, bitPos int) {
	if s.Count == 0 || s.Block == blockMAX {
		return -1, -1
	}
	pos := 0
	for mask := uint32(blockFirstBit); mask != 0 && s.Block&mask != 0; mask >>= 1 {
		pos++
	}
	return pos / 8, pos % 8
}
// GetCopy returns a node-by-node copy of the linked list rooted at this node.
func (s *Sequence) GetCopy() *Sequence {
	head := &Sequence{Block: s.Block, Count: s.Count}
	tail := head
	for src := s.Next; src != nil; src = src.Next {
		tail.Next = &Sequence{Block: src.Block, Count: src.Count}
		tail = tail.Next
	}
	return head
}
// Equal reports whether the list rooted at this node and the list rooted at
// o have the same length and the same Block and Count in every position.
func (s *Sequence) Equal(o *Sequence) bool {
	a, b := s, o
	for ; a != nil; a, b = a.Next, b.Next {
		if b == nil || a.Block != b.Block || a.Count != b.Count {
			return false
		}
	}
	// Equal only when the other list is not longer than this one.
	return b == nil
}
// ToByteArray serializes the list rooted at this node: for every node it
// appends the netutils.U32ToA encoding of Block followed by that of Count.
// The returned error is always nil.
// TODO (aboch): manage network/host order stuff
func (s *Sequence) ToByteArray() ([]byte, error) {
	var out []byte
	for node := s; node != nil; node = node.Next {
		out = append(out, netutils.U32ToA(node.Block)...)
		out = append(out, netutils.U32ToA(node.Count)...)
	}
	return out, nil
}
// FromByteArray rebuilds the list rooted at this node from data, which must
// be a non-empty multiple of 8 bytes: each 8-byte group is decoded with
// netutils.ATo32 into one node's Block and Count. Any nodes previously linked
// after the last decoded node are detached.
//
// An error is returned, and the receiver is left untouched, when data is
// empty or its length is not a multiple of 8.
// TODO (aboch): manage network/host order stuff
func (s *Sequence) FromByteArray(data []byte) error {
	l := len(data)
	// An empty slice passes the modulo check but holds no node to decode;
	// reject it instead of slicing out of range.
	if l == 0 || l%8 != 0 {
		return fmt.Errorf("cannot deserialize byte sequence of lenght %d (%v)", l, data)
	}

	node := s
	for off := 0; off < l; off += 8 {
		node.Block = netutils.ATo32(data[off : off+4])
		node.Count = netutils.ATo32(data[off+4 : off+8])
		if off+8 < l {
			node.Next = &Sequence{}
			node = node.Next
		}
	}
	// Drop any stale tail left over from a previous use of the receiver.
	node.Next = nil

	return nil
}
// GetFirstAvailable returns the byte and bit position of the first unset bit.
// It holds the handle lock while delegating to the package-level
// GetFirstAvailable on the handle's head.
func (h *Handle) GetFirstAvailable() (int, int, error) {
h.Lock()
defer h.Unlock()
return GetFirstAvailable(h.head)
}
// CheckIfAvailable checks if the bit correspondent to the specified ordinal is unset
// If the ordinal is beyond the Sequence limits, a negative response is returned
// It holds the handle lock while delegating to the package-level
// CheckIfAvailable on the handle's head.
func (h *Handle) CheckIfAvailable(ordinal int) (int, int, error) {
h.Lock()
defer h.Unlock()
return CheckIfAvailable(h.head, ordinal)
}
// PushReservation pushes the bit reservation inside the bitmask.
//
// The change is applied to a private copy of the handle, the copy is written
// to the datastore, and only when that write succeeds is the new state
// committed to h. release selects between clearing (true) and setting (false)
// the bit at bytePos/bitPos.
//
// The copy carries bits and the updated unselected counter, because the value
// written to the store is serialized from the copy; leaving them out would
// persist zeros for both fields.
func (h *Handle) PushReservation(bytePos, bitPos int, release bool) error {
	// Create a copy of the current handler
	h.Lock()
	nh := &Handle{
		app:        h.app,
		id:         h.id,
		store:      h.store,
		dbIndex:    h.dbIndex,
		bits:       h.bits,
		unselected: h.unselected,
		head:       h.head.GetCopy(),
	}
	h.Unlock()

	nh.head = PushReservation(bytePos, bitPos, nh.head, release)
	if release {
		nh.unselected++
	} else {
		nh.unselected--
	}

	err := nh.writeToStore()
	if err != nil {
		return err
	}

	// Commit went through, save locally
	h.Lock()
	h.head = nh.head
	h.unselected = nh.unselected
	h.dbIndex = nh.dbIndex
	h.Unlock()

	return nil
}
// Destroy removes from the datastore the data belonging to this handle
// The error returned by deleteFromStore is discarded.
func (h *Handle) Destroy() {
h.deleteFromStore()
}
// ToByteArray serializes the handle under its lock: an 8-byte header holding
// bits and unselected (each encoded with netutils.U32ToA), followed by the
// serialized block list.
func (h *Handle) ToByteArray() ([]byte, error) {
	h.Lock()
	defer h.Unlock()

	header := make([]byte, 8)
	copy(header[:4], netutils.U32ToA(h.bits))
	copy(header[4:], netutils.U32ToA(h.unselected))

	body, err := h.head.ToByteArray()
	if err != nil {
		return nil, fmt.Errorf("failed to serialize head: %s", err.Error())
	}
	return append(header, body...), nil
}
// FromByteArray reads this handle's data from a byte array laid out as
// produced by ToByteArray: an 8-byte header (bits, unselected) followed by
// the serialized block list.
//
// An error is returned, and the handle is left untouched, when ba is nil, is
// shorter than the 8-byte header, or the block list cannot be deserialized.
func (h *Handle) FromByteArray(ba []byte) error {
	if ba == nil {
		return fmt.Errorf("nil byte array")
	}
	// Slicing ba[8:] on a shorter input would panic.
	if len(ba) < 8 {
		return fmt.Errorf("byte array too short: %d bytes, need at least 8", len(ba))
	}

	nh := &Sequence{}
	err := nh.FromByteArray(ba[8:])
	if err != nil {
		return fmt.Errorf("failed to deserialize head: %s", err.Error())
	}

	h.Lock()
	h.head = nh
	h.bits = netutils.ATo32(ba[0:4])
	h.unselected = netutils.ATo32(ba[4:8])
	h.Unlock()

	return nil
}
// Bits returns the length of the bit sequence. The handle lock is taken
// because FromByteArray rewrites the field under that lock.
func (h *Handle) Bits() uint32 {
	h.Lock()
	defer h.Unlock()
	return h.bits
}
// Unselected returns the number of bits which are not selected
// The value is read while holding the handle lock.
func (h *Handle) Unselected() uint32 {
h.Lock()
defer h.Unlock()
return h.unselected
}
// getDBIndex returns the handle's dbIndex, read while holding the handle lock.
func (h *Handle) getDBIndex() uint64 {
h.Lock()
defer h.Unlock()
return h.dbIndex
}
// GetFirstAvailable walks the list from head and, at the first node whose
// block is not completely set, returns the byte position (offset by the
// bytes covered by the preceding nodes) and bit position reported by that
// node's GetAvailableBit. It returns (-1, -1, error) when every node is full.
func GetFirstAvailable(head *Sequence) (int, int, error) {
	offset := 0
	for seq := head; seq != nil; seq = seq.Next {
		if seq.Block != blockMAX {
			bytePos, bitPos := seq.GetAvailableBit()
			return offset + bytePos, bitPos, nil
		}
		offset += int(seq.Count * blockBytes)
	}
	return -1, -1, fmt.Errorf("no bit available")
}
// CheckIfAvailable checks if the bit correspondent to the specified ordinal is unset.
// It returns the byte and bit position of the ordinal when the bit is unset.
// If the ordinal is negative, beyond the Sequence limits, or already set, it
// returns (-1, -1) and an error.
func CheckIfAvailable(head *Sequence, ordinal int) (int, int, error) {
	// A negative ordinal would produce a negative bit position and an
	// all-zero selection mask, which would wrongly read as "available".
	if ordinal >= 0 {
		bytePos := ordinal / 8
		bitPos := ordinal % 8

		// Find the Sequence containing this byte
		current, _, _, inBlockBytePos := findSequence(head, bytePos)

		if current != nil {
			// Check whether the bit corresponding to the ordinal address is unset
			bitSel := uint32(blockFirstBit >> uint(inBlockBytePos*8+bitPos))
			if current.Block&bitSel == 0 {
				return bytePos, bitPos, nil
			}
		}
	}

	return -1, -1, fmt.Errorf("requested bit is not available")
}
// findSequence locates the sequence that holds the byte at bytePos.
//
// Given the byte position and the sequences list head, it returns the pointer
// to the sequence containing the byte (current), the pointer to the previous
// sequence (head itself when current is the head), the number of blocks
// preceding the block containing the byte inside the current sequence, and
// the position of the byte inside that block.
// If head is nil, or bytePos is negative or outside of the list, the function
// returns (nil, nil, 0, -1).
func findSequence(head *Sequence, bytePos int) (*Sequence, *Sequence, uint32, int) {
	// Nothing to search, or a position that cannot belong to any sequence
	if head == nil || bytePos < 0 {
		return nil, nil, 0, -1
	}

	// Find the Sequence containing this byte
	previous := head
	current := head
	n := bytePos
	for current.Next != nil && n >= int(current.Count*blockBytes) { // Nil check for less than 32 addresses masks
		n -= int(current.Count * blockBytes)
		previous = current
		current = current.Next
	}

	// If byte is outside of the list, let caller know
	if n >= int(current.Count*blockBytes) {
		return nil, nil, 0, -1
	}

	// Find the byte position inside the block and the number of blocks
	// preceding the block containing the byte inside this sequence
	precBlocks := uint32(n / blockBytes)
	inBlockBytePos := bytePos % blockBytes

	return current, previous, precBlocks, inBlockBytePos
}
// PushReservation pushes the bit reservation inside the bitmask.
// Given byte and bit positions, identify the sequence (current) which holds the block containing the affected bit.
// Create a new block with the modified bit according to the operation (allocate/release).
// Create a new Sequence containing the new Block and insert it in the proper position.
// Remove current sequence if empty.
// Check if new Sequence can be merged with neighbour (previous/Next) sequences.
//
// The list is modified in place and the (possibly new) head is returned. The
// list is returned unchanged when the position is outside of it or when the
// bit already has the requested value.
//
// Identify "current" Sequence containing block:
//                                      [prev seq] [current seq] [Next seq]
//
// Based on block position, resulting list of sequences can be any of three forms:
//
//        Block position                        Resulting list of sequences
// A) Block is first in current:         [prev seq] [new] [modified current seq] [Next seq]
// B) Block is last in current:          [prev seq] [modified current seq] [new] [Next seq]
// C) Block is in the middle of current: [prev seq] [curr pre] [new] [curr post] [Next seq]
func PushReservation(bytePos, bitPos int, head *Sequence, release bool) *Sequence {
	// Store list's head
	newHead := head

	// Find the Sequence containing this byte
	current, previous, precBlocks, inBlockBytePos := findSequence(head, bytePos)
	if current == nil {
		return newHead
	}

	// Construct updated block
	bitSel := uint32(blockFirstBit >> uint(inBlockBytePos*8+bitPos))
	newBlock := current.Block
	if release {
		newBlock &^= bitSel
	} else {
		newBlock |= bitSel
	}

	// Quit if it was a redundant request
	if current.Block == newBlock {
		return newHead
	}

	// Current Sequence inevitably loses one block, update Count
	current.Count--

	// Create new sequence
	newSequence := &Sequence{Block: newBlock, Count: 1}

	// Insert the new sequence in the list based on block position
	if precBlocks == 0 { // First in sequence (A)
		newSequence.Next = current
		if current == head {
			newHead = newSequence
			previous = newHead
		} else {
			previous.Next = newSequence
		}
		removeCurrentIfEmpty(&newHead, newSequence, current)
		mergeSequences(previous)
	} else if precBlocks == current.Count { // Last in sequence (B)
		// Count was already decremented, so it now equals the index of the
		// last block of the original run.
		newSequence.Next = current.Next
		current.Next = newSequence
		mergeSequences(current)
	} else { // In between the sequence (C)
		currPre := &Sequence{Block: current.Block, Count: precBlocks, Next: newSequence}
		currPost := current
		currPost.Count -= precBlocks
		newSequence.Next = currPost
		if currPost == head {
			newHead = currPre
		} else {
			previous.Next = currPre
		}
		// No merging or empty current possible here
	}

	return newHead
}
// removeCurrentIfEmpty unlinks current from the list when its Count is zero:
// the head pointer is advanced when current is the head, otherwise previous
// is linked to current's successor. A non-empty current is left in place.
func removeCurrentIfEmpty(head **Sequence, previous, current *Sequence) {
	if current.Count != 0 {
		return
	}
	if current == *head {
		*head = current.Next
		return
	}
	previous.Next = current.Next
}
// mergeSequences walks the list from seq and folds every run of adjacent
// nodes holding the same Block into the first node of the run, summing their
// counts. A nil seq is a no-op.
// TODO: Optimization: only attempt merge from start to end sequence, no need to scan till the end of the list
func mergeSequences(seq *Sequence) {
	for node := seq; node != nil; node = node.Next {
		// Absorb every following node with the same block value
		for node.Next != nil && node.Next.Block == node.Block {
			node.Count += node.Next.Count
			node.Next = node.Next.Next
		}
	}
}
// getNumBlocks returns how many blockLen-bit blocks are needed to hold
// numBits bits, rounding up.
func getNumBlocks(numBits uint32) uint32 {
	whole, rest := numBits/blockLen, numBits%blockLen
	if rest == 0 {
		return whole
	}
	return whole + 1
}

View file

@ -0,0 +1,122 @@
package bitseq
import (
"encoding/json"
"fmt"
log "github.com/Sirupsen/logrus"
"github.com/docker/libnetwork/datastore"
"github.com/docker/libnetwork/types"
)
// Key provides the Key to be used in KV Store
// The key is the pair {app, id}, read while holding the handle lock.
func (h *Handle) Key() []string {
h.Lock()
defer h.Unlock()
return []string{h.app, h.id}
}
// KeyPrefix returns the immediate parent key that can be used for tree walk
// The prefix is {app}, read while holding the handle lock.
func (h *Handle) KeyPrefix() []string {
h.Lock()
defer h.Unlock()
return []string{h.app}
}
// Value returns the JSON encoding of the handle's byte-array form, for
// storage in the KV store. Failures are logged: a serialization failure
// results in an empty byte array being encoded, and a JSON encoding failure
// results in an empty slice being returned.
func (h *Handle) Value() []byte {
	raw, err := h.ToByteArray()
	if err != nil {
		log.Warnf("Failed to serialize Handle: %v", err)
		raw = []byte{}
	}
	encoded, err := json.Marshal(raw)
	if err == nil {
		return encoded
	}
	log.Warnf("Failed to json encode bitseq handler byte array: %v", err)
	return []byte{}
}
// Index returns the latest DB Index as seen by this object
// The value is read while holding the handle lock.
func (h *Handle) Index() uint64 {
h.Lock()
defer h.Unlock()
return h.dbIndex
}
// SetIndex records index as the latest DB index seen by this handle; it is
// the hook through which the datastore reports the index of a write.
func (h *Handle) SetIndex(index uint64) {
	h.Lock()
	defer h.Unlock()
	h.dbIndex = index
}
// watchForChanges registers a watch on the handle's datastore key and starts
// a goroutine that applies remote updates to the handle. It is a no-op when
// the handle has no store, and returns the error from Watch when the
// registration fails.
//
// The goroutine ends when the watch channel is closed. An update is applied
// only when its index differs from the one the handle last saw, which is how
// the handle's own writes are filtered out.
func (h *Handle) watchForChanges() error {
	h.Lock()
	store := h.store
	h.Unlock()

	if store == nil {
		return nil
	}

	kvpChan, err := store.KVStore().Watch(datastore.Key(h.Key()...), nil)
	if err != nil {
		return err
	}

	go func() {
		// Ranging over the channel stops the goroutine once the watch is
		// closed instead of spinning on nil receives.
		for kvPair := range kvpChan {
			// Only process remote update
			if kvPair == nil || kvPair.LastIndex == h.getDBIndex() {
				continue
			}
			if err := h.fromDsValue(kvPair.Value); err != nil {
				log.Warnf("Failed to reconstruct bitseq handle from ds watch: %s", err.Error())
				continue
			}
			h.Lock()
			h.dbIndex = kvPair.LastIndex
			h.Unlock()
		}
	}()
	return nil
}
// fromDsValue loads the handle from a datastore value: value is JSON-decoded
// into a byte array, which is then handed to FromByteArray.
func (h *Handle) fromDsValue(value []byte) error {
	var decoded []byte
	err := json.Unmarshal(value, &decoded)
	if err != nil {
		return fmt.Errorf("failed to decode json: %s", err.Error())
	}
	err = h.FromByteArray(decoded)
	if err != nil {
		return fmt.Errorf("failed to decode handle: %s", err.Error())
	}
	return nil
}
// writeToStore atomically writes the handle to its datastore. It is a no-op
// when the handle has no store. datastore.ErrKeyModified is translated into a
// retry error; any other result of the write is returned as is.
func (h *Handle) writeToStore() error {
	h.Lock()
	store := h.store
	h.Unlock()
	if store == nil {
		return nil
	}

	switch err := store.PutObjectAtomic(h); err {
	case datastore.ErrKeyModified:
		return types.RetryErrorf("failed to perform atomic write (%v). retry might fix the error", err)
	default:
		return err
	}
}
// deleteFromStore atomically deletes the handle from its datastore and
// returns the result of that call. It is a no-op when the handle has no store.
func (h *Handle) deleteFromStore() error {
h.Lock()
store := h.store
h.Unlock()
if store == nil {
return nil
}
return store.DeleteObjectAtomic(h)
}

View file

@ -2,51 +2,52 @@
Package libnetwork provides the basic functionality and extension points to
create network namespaces and allocate interfaces for containers to use.
// Create a new controller instance
controller, _err := libnetwork.New(nil)
// Create a new controller instance
controller, _err := libnetwork.New(nil)
// Select and configure the network driver
networkType := "bridge"
// Select and configure the network driver
networkType := "bridge"
driverOptions := options.Generic{}
genericOption := make(map[string]interface{})
genericOption[netlabel.GenericData] = driverOptions
err := controller.ConfigureNetworkDriver(networkType, genericOption)
if err != nil {
return
}
driverOptions := options.Generic{}
genericOption := make(map[string]interface{})
genericOption[netlabel.GenericData] = driverOptions
err := controller.ConfigureNetworkDriver(networkType, genericOption)
if err != nil {
return
}
// Create a network for containers to join.
// NewNetwork accepts Variadic optional arguments that libnetwork and Drivers can make use of
network, err := controller.NewNetwork(networkType, "network1")
if err != nil {
return
}
// Create a network for containers to join.
// NewNetwork accepts Variadic optional arguments that libnetwork and Drivers can make use of
network, err := controller.NewNetwork(networkType, "network1")
if err != nil {
return
}
// For each new container: allocate IP and interfaces. The returned network
// settings will be used for container infos (inspect and such), as well as
// iptables rules for port publishing. This info is contained or accessible
// from the returned endpoint.
ep, err := network.CreateEndpoint("Endpoint1")
if err != nil {
return
}
// For each new container: allocate IP and interfaces. The returned network
// settings will be used for container infos (inspect and such), as well as
// iptables rules for port publishing. This info is contained or accessible
// from the returned endpoint.
ep, err := network.CreateEndpoint("Endpoint1")
if err != nil {
return
}
// A container can join the endpoint by providing the container ID to the join
// api.
// Join accepts Variadic arguments which will be made use of by libnetwork and Drivers
err = ep.Join("container1",
libnetwork.JoinOptionHostname("test"),
libnetwork.JoinOptionDomainname("docker.io"))
if err != nil {
return
}
// A container can join the endpoint by providing the container ID to the join
// api.
// Join accepts Variadic arguments which will be made use of by libnetwork and Drivers
err = ep.Join("container1",
libnetwork.JoinOptionHostname("test"),
libnetwork.JoinOptionDomainname("docker.io"))
if err != nil {
return
}
*/
package libnetwork
import (
"fmt"
"net"
"strings"
"sync"
log "github.com/Sirupsen/logrus"
@ -56,6 +57,7 @@ import (
"github.com/docker/libnetwork/datastore"
"github.com/docker/libnetwork/driverapi"
"github.com/docker/libnetwork/hostdiscovery"
"github.com/docker/libnetwork/netlabel"
"github.com/docker/libnetwork/sandbox"
"github.com/docker/libnetwork/types"
)
@ -85,6 +87,9 @@ type NetworkController interface {
// NetworkByID returns the Network which has the passed id. If not found, the error ErrNoSuchNetwork is returned.
NetworkByID(id string) (Network, error)
// LeaveAll accepts a container id and attempts to leave all endpoints that the container has joined
LeaveAll(id string) error
// GC triggers immediate garbage collection of resources which are garbage collected.
GC()
}
@ -187,14 +192,41 @@ func (c *controller) ConfigureNetworkDriver(networkType string, options map[stri
// RegisterDriver records driver and its capability under networkType in
// c.drivers. It returns ErrInvalidName when config.IsValidName rejects the
// name and driverapi.ErrActiveRegistration when the type is already
// registered. When c.cfg is set, the daemon labels prefixed with
// netlabel.DriverPrefix+"."+networkType are collected, the KV provider and
// address are added for GlobalScope drivers when validateDatastoreConfig
// reports true, and a non-empty option map is passed to driver.Config.
// NOTE(review): this listing shows both a deferred c.Unlock() and explicit
// c.Unlock() calls; it looks like removed and added diff lines interleaved.
// Confirm against the real file before relying on the locking shown here.
func (c *controller) RegisterDriver(networkType string, driver driverapi.Driver, capability driverapi.Capability) error {
c.Lock()
defer c.Unlock()
if !config.IsValidName(networkType) {
c.Unlock()
return ErrInvalidName(networkType)
}
if _, ok := c.drivers[networkType]; ok {
c.Unlock()
return driverapi.ErrActiveRegistration(networkType)
}
c.drivers[networkType] = &driverData{driver, capability}
if c.cfg == nil {
c.Unlock()
return nil
}
// Collect the daemon labels addressed to this driver type.
opt := make(map[string]interface{})
for _, label := range c.cfg.Daemon.Labels {
if strings.HasPrefix(label, netlabel.DriverPrefix+"."+networkType) {
opt[netlabel.Key(label)] = netlabel.Value(label)
}
}
if capability.Scope == driverapi.GlobalScope && c.validateDatastoreConfig() {
opt[netlabel.KVProvider] = c.cfg.Datastore.Client.Provider
opt[netlabel.KVProviderURL] = c.cfg.Datastore.Client.Address
}
c.Unlock()
// The driver is only configured when there is at least one option to pass.
if len(opt) != 0 {
if err := driver.Config(opt); err != nil {
return err
}
}
return nil
}
@ -255,6 +287,7 @@ func (c *controller) addNetwork(n *network) error {
}
n.Lock()
n.svcRecords = svcMap{}
n.driver = dd.driver
d := n.driver
n.Unlock()
@ -263,6 +296,9 @@ func (c *controller) addNetwork(n *network) error {
if err := d.CreateNetwork(n.id, n.generic); err != nil {
return err
}
if err := n.watchEndpoints(); err != nil {
return err
}
c.Lock()
c.networks[n.id] = n
c.Unlock()

View file

@ -15,6 +15,8 @@ import (
type DataStore interface {
// GetObject gets data from datastore and unmarshals to the specified object
GetObject(key string, o interface{}) error
// GetUpdatedObject gets data from datastore along with its index and unmarshals to the specified object
GetUpdatedObject(key string, o interface{}) (uint64, error)
// PutObject adds a new Record based on an object into the datastore
PutObject(kvObject KV) error
// PutObjectAtomic provides an atomic add and update operation for a Record
@ -30,7 +32,10 @@ type DataStore interface {
}
// ErrKeyModified is raised for an atomic update when the update is working on a stale state
var ErrKeyModified = store.ErrKeyModified
var (
ErrKeyModified = store.ErrKeyModified
ErrKeyNotFound = store.ErrKeyNotFound
)
type datastore struct {
store store.Store
@ -152,6 +157,18 @@ func (ds *datastore) GetObject(key string, o interface{}) error {
return json.Unmarshal(kvPair.Value, o)
}
// GetUpdatedObject fetches the record stored under key, JSON-decodes its
// value into o and returns the record's LastIndex. On a fetch or decode
// failure it returns index 0 and the error.
func (ds *datastore) GetUpdatedObject(key string, o interface{}) (uint64, error) {
	pair, err := ds.store.Get(key)
	if err != nil {
		return 0, err
	}
	if err = json.Unmarshal(pair.Value, o); err != nil {
		return 0, err
	}
	return pair.LastIndex, nil
}
// DeleteObject unconditionally deletes a record from the store
func (ds *datastore) DeleteObject(kvObject KV) error {
return ds.store.Delete(Key(kvObject.Key()...))

View file

@ -10,6 +10,7 @@ import (
"github.com/Sirupsen/logrus"
"github.com/docker/libnetwork/driverapi"
"github.com/docker/libnetwork/ipallocator"
"github.com/docker/libnetwork/iptables"
"github.com/docker/libnetwork/netlabel"
"github.com/docker/libnetwork/netutils"
"github.com/docker/libnetwork/options"
@ -110,6 +111,10 @@ func Init(dc driverapi.DriverCallback) error {
if out, err := exec.Command("modprobe", "-va", "bridge", "nf_nat", "br_netfilter").Output(); err != nil {
logrus.Warnf("Running modprobe bridge nf_nat failed with message: %s, error: %v", out, err)
}
if err := iptables.RemoveExistingChain(DockerChain, iptables.Nat); err != nil {
logrus.Warnf("Failed to remove existing iptables entries in %s : %v", DockerChain, err)
}
c := driverapi.Capability{
Scope: driverapi.LocalScope,
}

View file

@ -0,0 +1,99 @@
package overlay
import (
"fmt"
"github.com/docker/libnetwork/driverapi"
"github.com/docker/libnetwork/types"
"github.com/vishvananda/netlink"
)
// Join method is invoked when a Sandbox is attached to an endpoint.
// It joins the network sandbox, creates a veth pair, attaches one end to
// "bridge1" inside the network sandbox, assigns the endpoint MAC to the other
// end and reports that end and the gateway through jinfo. Finally it records
// the endpoint in the peer database and posts a "join" notification.
// None of the error paths after joinSandbox undo the join or remove the veth
// pair.
func (d *driver) Join(nid, eid types.UUID, sboxKey string, jinfo driverapi.JoinInfo, options map[string]interface{}) error {
if err := validateID(nid, eid); err != nil {
return err
}
n := d.network(nid)
if n == nil {
return fmt.Errorf("could not find network with id %s", nid)
}
ep := n.endpoint(eid)
if ep == nil {
return fmt.Errorf("could not find endpoint with id %s", eid)
}
if err := n.joinSandbox(); err != nil {
return fmt.Errorf("network sandbox join failed: %v",
err)
}
sbox := n.sandbox()
name1, name2, err := createVethPair()
if err != nil {
return err
}
// name1 goes into the network sandbox, enslaved to the bridge.
if err := sbox.AddInterface(name1, "veth",
sbox.InterfaceOptions().Master("bridge1")); err != nil {
return fmt.Errorf("could not add veth pair inside the network sandbox: %v", err)
}
// name2 is the end handed to the container; it gets the endpoint's MAC.
veth, err := netlink.LinkByName(name2)
if err != nil {
return fmt.Errorf("could not find link by name %s: %v", name2, err)
}
if err := netlink.LinkSetHardwareAddr(veth, ep.mac); err != nil {
return fmt.Errorf("could not set mac address to the container interface: %v", err)
}
for _, iNames := range jinfo.InterfaceNames() {
// Make sure to set names on the correct interface ID.
if iNames.ID() == 1 {
err = iNames.SetNames(name2, "eth")
if err != nil {
return err
}
}
}
err = jinfo.SetGateway(bridgeIP.IP)
if err != nil {
return err
}
// Record the endpoint as a peer reachable at the local serf member address.
d.peerDbAdd(nid, eid, ep.addr.IP, ep.mac,
d.serfInstance.LocalMember().Addr, true)
d.notifyCh <- ovNotify{
action: "join",
nid: nid,
eid: eid,
}
return nil
}
// Leave method is invoked when a Sandbox detaches from an endpoint.
// It posts a "leave" notification on d.notifyCh and then leaves the network
// sandbox. An error is returned only for invalid ids or an unknown network.
func (d *driver) Leave(nid, eid types.UUID) error {
if err := validateID(nid, eid); err != nil {
return err
}
n := d.network(nid)
if n == nil {
return fmt.Errorf("could not find network with id %s", nid)
}
d.notifyCh <- ovNotify{
action: "leave",
nid: nid,
eid: eid,
}
n.leaveSandbox()
return nil
}

View file

@ -0,0 +1,110 @@
package overlay
import (
"encoding/binary"
"fmt"
"net"
"github.com/docker/libnetwork/driverapi"
"github.com/docker/libnetwork/netutils"
"github.com/docker/libnetwork/types"
)
// endpointTable maps an endpoint id to its endpoint.
type endpointTable map[types.UUID]*endpoint
// endpoint holds the id, MAC address and IP address of an overlay endpoint.
type endpoint struct {
id types.UUID
mac net.HardwareAddr
addr *net.IPNet
}
// endpoint returns the entry stored under eid in n.endpoints, read while
// holding the network lock.
func (n *network) endpoint(eid types.UUID) *endpoint {
n.Lock()
defer n.Unlock()
return n.endpoints[eid]
}
// addEndpoint stores ep in n.endpoints under its id, holding the network lock.
func (n *network) addEndpoint(ep *endpoint) {
	n.Lock()
	defer n.Unlock()
	n.endpoints[ep.id] = ep
}
// deleteEndpoint removes the entry for eid from n.endpoints, holding the
// network lock.
func (n *network) deleteEndpoint(eid types.UUID) {
	n.Lock()
	defer n.Unlock()
	delete(n.endpoints, eid)
}
// CreateEndpoint creates the endpoint eid on network nid.
//
// When epInfo already carries at least one interface, the endpoint takes the
// address and MAC of the first one. Otherwise an id is taken from the
// driver's IP allocator, the address bridgeSubnetInt+id is assigned together
// with a random MAC, and the resulting interface is reported through
// epInfo.AddInterface.
//
// An error is returned for invalid ids, an unknown network, a nil epInfo, an
// exhausted allocator or a failing AddInterface. In the last case the
// allocated id is released again so that it is not leaked.
func (d *driver) CreateEndpoint(nid, eid types.UUID, epInfo driverapi.EndpointInfo,
	epOptions map[string]interface{}) error {
	if err := validateID(nid, eid); err != nil {
		return err
	}

	n := d.network(nid)
	if n == nil {
		return fmt.Errorf("network id %q not found", nid)
	}

	// Without endpoint info there is nowhere to read or report the
	// interface; fail here rather than dereferencing nil further down.
	if epInfo == nil {
		return fmt.Errorf("invalid endpoint info for endpoint id %q", eid)
	}

	ep := &endpoint{
		id: eid,
	}

	if len(epInfo.Interfaces()) > 0 {
		addr := epInfo.Interfaces()[0].Address()
		ep.addr = &addr
		ep.mac = epInfo.Interfaces()[0].MacAddress()
		n.addEndpoint(ep)
		return nil
	}

	ipID, err := d.ipAllocator.GetID()
	if err != nil {
		return fmt.Errorf("could not allocate ip from subnet %s: %v",
			bridgeSubnet.String(), err)
	}

	ep.addr = &net.IPNet{
		Mask: bridgeSubnet.Mask,
	}
	ep.addr.IP = make([]byte, 4)

	binary.BigEndian.PutUint32(ep.addr.IP, bridgeSubnetInt+ipID)

	ep.mac = netutils.GenerateRandomMAC()

	err = epInfo.AddInterface(1, ep.mac, *ep.addr, net.IPNet{})
	if err != nil {
		// Give the address back; the endpoint is not going to be created.
		d.ipAllocator.Release(ipID)
		return fmt.Errorf("could not add interface to endpoint info: %v", err)
	}

	n.addEndpoint(ep)

	return nil
}
// DeleteEndpoint removes endpoint eid from network nid. The allocator id
// derived from the endpoint address (its big-endian value minus
// bridgeSubnetInt) is released before the endpoint is dropped from the
// network. An error is returned for invalid ids or an unknown network or
// endpoint.
func (d *driver) DeleteEndpoint(nid, eid types.UUID) error {
	err := validateID(nid, eid)
	if err != nil {
		return err
	}

	nw := d.network(nid)
	if nw == nil {
		return fmt.Errorf("network id %q not found", nid)
	}

	target := nw.endpoint(eid)
	if target == nil {
		return fmt.Errorf("endpoint id %q not found", eid)
	}

	ipID := binary.BigEndian.Uint32(target.addr.IP) - bridgeSubnetInt
	d.ipAllocator.Release(ipID)
	nw.deleteEndpoint(eid)
	return nil
}
// EndpointOperInfo returns an empty map and a nil error for every endpoint.
func (d *driver) EndpointOperInfo(nid, eid types.UUID) (map[string]interface{}, error) {
	return map[string]interface{}{}, nil
}

View file

@ -0,0 +1,325 @@
package overlay
import (
"encoding/json"
"fmt"
"net"
"sync"
"syscall"
"github.com/Sirupsen/logrus"
"github.com/docker/libnetwork/datastore"
"github.com/docker/libnetwork/ipallocator"
"github.com/docker/libnetwork/sandbox"
"github.com/docker/libnetwork/types"
"github.com/vishvananda/netlink"
"github.com/vishvananda/netlink/nl"
)
// networkTable maps a network id to its network.
type networkTable map[types.UUID]*network
// network is the overlay driver's per-network state.
type network struct {
id types.UUID
vni uint32
dbIndex uint64
// sbox is the network sandbox created by initSandbox.
sbox sandbox.Sandbox
endpoints endpointTable
ipAllocator *ipallocator.IPAllocator
// gw is set to bridgeIP.IP by CreateNetwork.
gw net.IP
// vxlanName is the vxlan interface created by initSandbox.
vxlanName string
driver *driver
// joinCnt counts joinSandbox calls not yet matched by leaveSandbox.
joinCnt int
sync.Mutex
}
// CreateNetwork registers a new overlay network under id with the driver and
// then obtains a vxlan id for it, returning the result of that step. An empty
// id is rejected. The network stays registered even when obtaining the vxlan
// id fails.
func (d *driver) CreateNetwork(id types.UUID, option map[string]interface{}) error {
	if id == "" {
		return fmt.Errorf("invalid network id")
	}

	n := &network{
		id:        id,
		driver:    d,
		endpoints: endpointTable{},
		gw:        bridgeIP.IP,
	}
	d.addNetwork(n)

	return n.obtainVxlanID()
}
// DeleteNetwork removes network nid from the driver and releases its vxlan
// id, returning the result of the release. An empty or unknown id is
// rejected.
func (d *driver) DeleteNetwork(nid types.UUID) error {
	if nid == "" {
		return fmt.Errorf("invalid network id")
	}

	if n := d.network(nid); n != nil {
		d.deleteNetwork(nid)
		return n.releaseVxlanID()
	}
	return fmt.Errorf("could not find network with id %s", nid)
}
// joinSandbox records one more join on the network. The first join creates
// the network sandbox; later joins only bump the counter.
// NOTE(review): the counter is not rolled back when initSandbox fails.
func (n *network) joinSandbox() error {
	n.Lock()
	firstJoin := n.joinCnt == 0
	n.joinCnt++
	n.Unlock()

	if !firstJoin {
		return nil
	}
	return n.initSandbox()
}
// leaveSandbox records one fewer join on the network and tears the sandbox
// down once the counter reaches zero.
func (n *network) leaveSandbox() {
	n.Lock()
	n.joinCnt--
	lastLeave := n.joinCnt == 0
	n.Unlock()

	if lastLeave {
		n.destroySandbox()
	}
}
// destroySandbox removes every interface from the network sandbox, deletes
// the vxlan link and destroys the sandbox. It is a no-op when the network has
// no sandbox. A failure to delete the vxlan link is only logged.
func (n *network) destroySandbox() {
	sb := n.sandbox()
	if sb == nil {
		return
	}

	ifaces := sb.Info().Interfaces()
	for _, i := range ifaces {
		i.Remove()
	}

	err := deleteVxlan(n.vxlanName)
	if err != nil {
		logrus.Warnf("could not cleanup sandbox properly: %v", err)
	}

	sb.Destroy()
}
// initSandbox creates the network sandbox, adds a bridge ("bridge1") and a
// vxlan interface enslaved to it, publishes the sandbox on the network,
// replays the known peers into it and finally starts a goroutine watching
// neighbor-miss notifications from inside the sandbox.
//
// It returns an error when any of these steps fails. In particular, the
// watcher is only started when the netlink subscription succeeded; a failed
// subscription is reported to the caller instead of starting a watcher on a
// nil socket.
//
// NOTE(review): the sandbox and vxlan link created before a failing step are
// not cleaned up here.
func (n *network) initSandbox() error {
	sbox, err := sandbox.NewSandbox(sandbox.GenerateKey(string(n.id)), true)
	if err != nil {
		return fmt.Errorf("could not create network sandbox: %v", err)
	}

	// Add a bridge inside the namespace
	if err := sbox.AddInterface("bridge1", "br",
		sbox.InterfaceOptions().Address(bridgeIP),
		sbox.InterfaceOptions().Bridge(true)); err != nil {
		return fmt.Errorf("could not create bridge inside the network sandbox: %v", err)
	}

	vxlanName, err := createVxlan(n.vxlanID())
	if err != nil {
		return err
	}

	if err := sbox.AddInterface(vxlanName, "vxlan",
		sbox.InterfaceOptions().Master("bridge1")); err != nil {
		return fmt.Errorf("could not add vxlan interface inside the network sandbox: %v",
			err)
	}

	n.vxlanName = vxlanName
	n.setSandbox(sbox)

	n.driver.peerDbUpdateSandbox(n.id)

	// The subscription must be made from inside the sandbox. Its outcome is
	// kept in dedicated variables so it cannot be lost or confused with the
	// errors handled above.
	var (
		nlSock *nl.NetlinkSocket
		subErr error
	)
	sbox.InvokeFunc(func() {
		nlSock, subErr = nl.Subscribe(syscall.NETLINK_ROUTE, syscall.RTNLGRP_NEIGH)
	})
	if subErr != nil {
		return fmt.Errorf("failed to subscribe to neighbor group netlink messages: %v", subErr)
	}
	if nlSock == nil {
		return fmt.Errorf("failed to subscribe to neighbor group netlink messages")
	}

	go n.watchMiss(nlSock)
	return nil
}
// watchMiss receives neighbor notifications from nlSock and, for every stale
// or incomplete IPv4 neighbor entry, resolves the peer through the cluster
// and installs the resulting neighbor/fdb entries via the driver.
//
// It returns immediately when nlSock is nil; otherwise it loops forever.
// NOTE(review): a persistent receive error makes the loop retry without delay.
func (n *network) watchMiss(nlSock *nl.NetlinkSocket) {
	if nlSock == nil {
		logrus.Errorf("Cannot watch for neighbor misses without a netlink socket")
		return
	}

	for {
		msgs, err := nlSock.Recieve()
		if err != nil {
			logrus.Errorf("Failed to receive from netlink: %v ", err)
			continue
		}

		for _, msg := range msgs {
			if msg.Header.Type != syscall.RTM_GETNEIGH && msg.Header.Type != syscall.RTM_NEWNEIGH {
				continue
			}

			neigh, err := netlink.NeighDeserialize(msg.Data)
			if err != nil {
				logrus.Errorf("Failed to deserialize netlink ndmsg: %v", err)
				continue
			}

			// Only IPv4 peers can be resolved. To16() is non-nil for every
			// valid address (IPv4 included), so testing it would skip all
			// misses; To4() is nil exactly for missing or non-IPv4 addresses.
			if neigh.IP.To4() == nil {
				continue
			}

			if neigh.State&(netlink.NUD_STALE|netlink.NUD_INCOMPLETE) == 0 {
				continue
			}

			mac, vtep, err := n.driver.resolvePeer(n.id, neigh.IP)
			if err != nil {
				logrus.Errorf("could not resolve peer %q: %v", neigh.IP, err)
				continue
			}

			if err := n.driver.peerAdd(n.id, types.UUID("dummy"), neigh.IP, mac, vtep, true); err != nil {
				logrus.Errorf("could not add neighbor entry for missed peer: %v", err)
			}
		}
	}
}
// addNetwork stores n in the driver's network table under its id.
func (d *driver) addNetwork(n *network) {
	d.Lock()
	defer d.Unlock()
	d.networks[n.id] = n
}

// deleteNetwork removes the entry for nid from the driver's network table.
func (d *driver) deleteNetwork(nid types.UUID) {
	d.Lock()
	defer d.Unlock()
	delete(d.networks, nid)
}

// network returns the state registered for nid, or nil when there is none.
func (d *driver) network(nid types.UUID) *network {
	d.Lock()
	found := d.networks[nid]
	d.Unlock()
	return found
}
// sandbox returns the network's sandbox under the network lock.
func (n *network) sandbox() sandbox.Sandbox {
	n.Lock()
	current := n.sbox
	n.Unlock()
	return current
}

// setSandbox replaces the network's sandbox under the network lock.
func (n *network) setSandbox(sbox sandbox.Sandbox) {
	n.Lock()
	defer n.Unlock()
	n.sbox = sbox
}

// vxlanID returns the network's vxlan id under the network lock.
func (n *network) vxlanID() uint32 {
	n.Lock()
	current := n.vni
	n.Unlock()
	return current
}

// setVxlanID replaces the network's vxlan id under the network lock.
func (n *network) setVxlanID(vni uint32) {
	n.Lock()
	defer n.Unlock()
	n.vni = vni
}
// Key returns the datastore key of this network: the key prefix followed by
// the network id.
func (n *network) Key() []string {
	key := []string{"overlay", "network", string(n.id)}
	return key
}

// KeyPrefix returns the datastore key prefix shared by all overlay networks.
func (n *network) KeyPrefix() []string {
	prefix := []string{"overlay", "network"}
	return prefix
}

// Value returns the JSON encoding of the network's vxlan id, or an empty
// slice when encoding fails.
func (n *network) Value() []byte {
	encoded, err := json.Marshal(n.vxlanID())
	if err == nil {
		return encoded
	}
	return []byte{}
}

// Index returns the stored datastore index of the network.
func (n *network) Index() uint64 {
	idx := n.dbIndex
	return idx
}

// SetIndex records the datastore index of the network.
func (n *network) SetIndex(index uint64) {
	n.dbIndex = index
}

// writeToStore atomically writes the network object to the driver's datastore.
func (n *network) writeToStore() error {
	store := n.driver.store
	return store.PutObjectAtomic(n)
}
// releaseVxlanID deletes the network's vxlan id mapping from the datastore
// and hands the id back to the driver's vxlan id manager.
//
// It is a no-op when the network holds no vxlan id. It fails when no
// datastore is configured or the delete fails for a reason other than the key
// being modified or already gone.
func (n *network) releaseVxlanID() error {
	if n.driver.store == nil {
		return fmt.Errorf("no datastore configured. cannot release vxlan id")
	}

	if n.vxlanID() == 0 {
		return nil
	}

	err := n.driver.store.DeleteObjectAtomic(n)
	switch {
	case err == nil:
		n.driver.vxlanIdm.Release(n.vxlanID())
		n.setVxlanID(0)
		return nil
	case err == datastore.ErrKeyModified || err == datastore.ErrKeyNotFound:
		// In both the above cases we can safely assume that the key has been
		// removed by some other instance and so simply get out of here
		return nil
	default:
		return fmt.Errorf("failed to delete network to vxlan id map: %v", err)
	}
}
// obtainVxlanID assigns a vxlan id to the network.
//
// When the datastore already maps this network to an id, that id is adopted.
// Otherwise a fresh id is allocated and written to the datastore; if the
// write reports that the key was modified, the id is released and the whole
// lookup is retried.
//
// It fails when no datastore is configured, or when the lookup, allocation or
// write fails for any other reason.
func (n *network) obtainVxlanID() error {
	if n.driver.store == nil {
		return fmt.Errorf("no datastore configured. cannot obtain vxlan id")
	}

	for {
		var stored uint32
		err := n.driver.store.GetObject(datastore.Key(n.Key()...), &stored)
		if err == nil {
			// The datastore already knows the id of this network.
			n.setVxlanID(stored)
			return nil
		}
		if err != datastore.ErrKeyNotFound {
			return fmt.Errorf("failed to obtain vxlan id from data store: %v", err)
		}

		fresh, err := n.driver.vxlanIdm.GetID()
		if err != nil {
			return fmt.Errorf("failed to allocate vxlan id: %v", err)
		}
		n.setVxlanID(fresh)

		writeErr := n.writeToStore()
		if writeErr == nil {
			return nil
		}

		// The write did not go through: give the id back before deciding
		// whether to retry.
		n.driver.vxlanIdm.Release(n.vxlanID())
		n.setVxlanID(0)
		if writeErr == datastore.ErrKeyModified {
			continue
		}
		return fmt.Errorf("failed to update data store with vxlan id: %v", writeErr)
	}
}

View file

@ -0,0 +1,249 @@
package overlay
import (
"fmt"
"net"
"strings"
"time"
"github.com/Sirupsen/logrus"
"github.com/docker/libnetwork/types"
"github.com/hashicorp/serf/serf"
)
// ovNotify describes an endpoint event to announce to the cluster: the
// action string and the endpoint/network it applies to.
type ovNotify struct {
action string
eid types.UUID
nid types.UUID
}
// logWriter is a writer that forwards level-tagged log lines to logrus.
type logWriter struct{}
// Write forwards p to logrus at the level matching the first tag found in it
// ([WARN], [DEBUG], [INFO], [ERR], checked in that order). Untagged input is
// dropped. It always reports the full length of p as written.
func (l *logWriter) Write(p []byte) (int, error) {
	line := string(p)

	if strings.Contains(line, "[WARN]") {
		logrus.Warn(line)
	} else if strings.Contains(line, "[DEBUG]") {
		logrus.Debug(line)
	} else if strings.Contains(line, "[INFO]") {
		logrus.Info(line)
	} else if strings.Contains(line, "[ERR]") {
		logrus.Error(line)
	}

	return len(p), nil
}
// getBindAddr returns the first address of interface ifaceName that is not a
// link-local unicast address, formatted as a string.
//
// It fails when the interface does not exist, its addresses cannot be listed,
// or none of them qualifies.
func getBindAddr(ifaceName string) (string, error) {
	nic, err := net.InterfaceByName(ifaceName)
	if err != nil {
		return "", fmt.Errorf("failed to find interface %s: %v", ifaceName, err)
	}

	candidates, err := nic.Addrs()
	if err != nil {
		return "", fmt.Errorf("failed to get interface addresses: %v", err)
	}

	for _, candidate := range candidates {
		// Only IP network addresses are usable; link-local ones are skipped.
		if ipNet, isIPNet := candidate.(*net.IPNet); isIPNet && !ipNet.IP.IsLinkLocalUnicast() {
			return ipNet.IP.String(), nil
		}
	}

	return "", fmt.Errorf("failed to get bind address")
}
// serfInit creates the driver's serf cluster node, optionally binds it to the
// configured interface, optionally joins the configured neighbor and starts
// the serf event loop goroutine.
//
// It fails when the bind address cannot be determined, the node cannot be
// created, or the join fails; on a failed join the freshly created node is
// shut down again.
func (d *driver) serfInit() error {
	cfg := serf.DefaultConfig()
	cfg.Init()

	if d.ifaceName != "" {
		bindAddr, err := getBindAddr(d.ifaceName)
		if err != nil {
			return fmt.Errorf("getBindAddr error: %v", err)
		}
		cfg.MemberlistConfig.BindAddr = bindAddr
	}

	d.eventCh = make(chan serf.Event, 4)
	cfg.EventCh = d.eventCh
	cfg.UserCoalescePeriod = 1 * time.Second
	cfg.UserQuiescentPeriod = 50 * time.Millisecond
	cfg.LogOutput = logrus.StandardLogger().Out

	node, err := serf.Create(cfg)
	if err != nil {
		return fmt.Errorf("failed to create cluster node: %v", err)
	}

	if d.neighIP != "" {
		if _, err := node.Join([]string{d.neighIP}, false); err != nil {
			joinErr := fmt.Errorf("Failed to join the cluster at neigh IP %s: %v",
				d.neighIP, err)
			// Joining is the only step that can fail once the node exists.
			node.Shutdown()
			return joinErr
		}
	}

	d.serfInstance = node
	d.notifyCh = make(chan ovNotify)
	d.exitCh = make(chan chan struct{})

	go d.startSerfLoop(d.eventCh, d.notifyCh, d.exitCh)
	return nil
}
// notifyEvent announces an endpoint event to the cluster as a serf user
// event. The event name carries the local member address, network id and
// endpoint id; the payload carries the action plus the endpoint IP and MAC.
//
// Events for a network or endpoint that is no longer known to the driver are
// reported and dropped instead of dereferencing a nil network or endpoint.
// A failure to send is only reported.
func (d *driver) notifyEvent(event ovNotify) {
	n := d.network(event.nid)
	if n == nil {
		fmt.Printf("Dropping %s notification: network %s not found\n", event.action, event.nid)
		return
	}

	ep := n.endpoint(event.eid)
	if ep == nil {
		fmt.Printf("Dropping %s notification: endpoint %s not found\n", event.action, event.eid)
		return
	}

	ePayload := fmt.Sprintf("%s %s %s", event.action, ep.addr.IP.String(), ep.mac.String())
	eName := fmt.Sprintf("jl %s %s %s", d.serfInstance.LocalMember().Addr.String(),
		event.nid, event.eid)

	if err := d.serfInstance.UserEvent(eName, []byte(ePayload), true); err != nil {
		fmt.Printf("Sending user event failed: %v\n", err)
	}
}
// processEvent handles a serf user event announcing that a peer endpoint
// joined or left a network.
//
// The event name is expected to hold four space separated fields (a tag, the
// sender's vtep address, the network id and the endpoint id); the payload is
// expected to hold the action, the peer IP and the peer MAC. Events sent by
// the local member are ignored. Malformed events are reported and dropped
// rather than being applied with empty or nil values.
func (d *driver) processEvent(u serf.UserEvent) {
	fmt.Printf("Received user event name:%s, payload:%s\n", u.Name,
		string(u.Payload))

	var dummy, action, vtepStr, nid, eid, ipStr, macStr string
	if _, err := fmt.Sscan(u.Name, &dummy, &vtepStr, &nid, &eid); err != nil {
		fmt.Printf("Failed to scan name string: %v\n", err)
		return
	}

	if _, err := fmt.Sscan(string(u.Payload), &action,
		&ipStr, &macStr); err != nil {
		fmt.Printf("Failed to scan value string: %v\n", err)
		return
	}

	fmt.Printf("Parsed data = %s/%s/%s/%s/%s\n", nid, eid, vtepStr, ipStr, macStr)

	// Nothing to do for events this node sent itself.
	if d.serfInstance.LocalMember().Addr.String() == vtepStr {
		return
	}

	mac, err := net.ParseMAC(macStr)
	if err != nil {
		fmt.Printf("Failed to parse mac: %v\n", err)
		return
	}

	peerIP := net.ParseIP(ipStr)
	vtep := net.ParseIP(vtepStr)
	if peerIP == nil || vtep == nil {
		fmt.Printf("Failed to parse peer ip %q or vtep %q\n", ipStr, vtepStr)
		return
	}

	switch action {
	case "join":
		if err := d.peerAdd(types.UUID(nid), types.UUID(eid), peerIP, mac,
			vtep, true); err != nil {
			fmt.Printf("Peer add failed in the driver: %v\n", err)
		}
	case "leave":
		if err := d.peerDelete(types.UUID(nid), types.UUID(eid), peerIP, mac,
			vtep, true); err != nil {
			fmt.Printf("Peer delete failed in the driver: %v\n", err)
		}
	}
}
// processQuery answers a serf query whose payload is "<network id> <peer ip>".
// When the peer is found in the local peer database, the response is
// "<peer mac> <vtep ip>"; otherwise no response is sent.
func (d *driver) processQuery(q *serf.Query) {
	fmt.Printf("Received query name:%s, payload:%s\n", q.Name,
		string(q.Payload))

	var nid, ipStr string
	_, scanErr := fmt.Sscan(string(q.Payload), &nid, &ipStr)
	if scanErr != nil {
		fmt.Printf("Failed to scan query payload string: %v\n", scanErr)
	}

	peerMac, vtep, err := d.peerDbSearch(types.UUID(nid), net.ParseIP(ipStr))
	if err != nil {
		return
	}

	reply := peerMac.String() + " " + vtep.String()
	q.Respond([]byte(reply))
}
// resolvePeer asks the cluster for the MAC and vtep of peerIP on network nid
// through a "peerlookup" serf query and returns the first answer.
//
// It fails when the query cannot be sent, the answer cannot be parsed, or no
// answer arrives within one second.
func (d *driver) resolvePeer(nid types.UUID, peerIP net.IP) (net.HardwareAddr, net.IP, error) {
	request := string(nid) + " " + peerIP.String()

	resp, err := d.serfInstance.Query("peerlookup", []byte(request), nil)
	if err != nil {
		return nil, nil, fmt.Errorf("resolving peer by querying the cluster failed: %v", err)
	}

	answers := resp.ResponseCh()
	timeout := time.After(time.Second)

	select {
	case <-timeout:
		return nil, nil, fmt.Errorf("timed out resolving peer by querying the cluster")
	case answer := <-answers:
		var macStr, vtepStr string
		if _, err := fmt.Sscan(string(answer.Payload), &macStr, &vtepStr); err != nil {
			return nil, nil, fmt.Errorf("bad response %q for the resolve query: %v", string(answer.Payload), err)
		}

		hwAddr, err := net.ParseMAC(macStr)
		if err != nil {
			return nil, nil, fmt.Errorf("failed to parse mac: %v", err)
		}

		return hwAddr, net.ParseIP(vtepStr), nil
	}
}
// startSerfLoop is the driver's serf event loop. It
//   - announces local endpoint events received on notifyCh,
//   - answers queries and applies user events received on eventCh,
//   - on a request received on exitCh, leaves the cluster, shuts the serf
//     instance down, closes the request's channel and returns.
//
// A closed channel is always ready to receive, so a closed input is replaced
// by nil to take it out of the select instead of spinning on it. The loop
// returns once every input has been closed.
func (d *driver) startSerfLoop(eventCh chan serf.Event, notifyCh chan ovNotify,
	exitCh chan chan struct{}) {

	for {
		if eventCh == nil && notifyCh == nil && exitCh == nil {
			return
		}

		select {
		case notify, ok := <-notifyCh:
			if !ok {
				notifyCh = nil
				continue
			}

			d.notifyEvent(notify)
		case ch, ok := <-exitCh:
			if !ok {
				exitCh = nil
				continue
			}

			if err := d.serfInstance.Leave(); err != nil {
				fmt.Printf("failed leaving the cluster: %v\n", err)
			}

			d.serfInstance.Shutdown()
			close(ch)
			return
		case e, ok := <-eventCh:
			if !ok {
				eventCh = nil
				continue
			}

			if e.EventType() == serf.EventQuery {
				d.processQuery(e.(*serf.Query))
				continue
			}

			// Only user events are of interest besides queries.
			if u, isUser := e.(serf.UserEvent); isUser {
				d.processEvent(u)
			}
		}
	}
}

View file

@ -0,0 +1,80 @@
package overlay
import (
"fmt"
"github.com/docker/libnetwork/netutils"
"github.com/docker/libnetwork/types"
"github.com/vishvananda/netlink"
)
// validateID rejects empty network or endpoint ids. The network id is
// checked first.
func validateID(nid, eid types.UUID) error {
	switch {
	case nid == "":
		return fmt.Errorf("invalid network id")
	case eid == "":
		return fmt.Errorf("invalid endpoint id")
	}
	return nil
}
// createVethPair creates a veth pair with two generated interface names and
// returns them (host side first, sandbox side second).
// It fails when a name cannot be generated or the link cannot be added.
func createVethPair() (string, string, error) {
	// Name of what will be the host side pipe interface
	hostSide, err := netutils.GenerateIfaceName(vethPrefix, vethLen)
	if err != nil {
		return "", "", fmt.Errorf("error generating veth name1: %v", err)
	}

	// Name of what will be the sandbox side pipe interface
	sboxSide, err := netutils.GenerateIfaceName(vethPrefix, vethLen)
	if err != nil {
		return "", "", fmt.Errorf("error generating veth name2: %v", err)
	}

	// Create the interface pipe host <-> sandbox
	attrs := netlink.LinkAttrs{Name: hostSide, TxQLen: 0}
	pair := &netlink.Veth{LinkAttrs: attrs, PeerName: sboxSide}
	if err := netlink.LinkAdd(pair); err != nil {
		return "", "", fmt.Errorf("error creating veth pair: %v", err)
	}

	return hostSide, sboxSide, nil
}
// createVxlan creates a vxlan link for vxlan id vni under a generated name
// and returns that name. The link is created with learning, proxy and the
// L2/L3 miss notifications turned on.
func createVxlan(vni uint32) (string, error) {
	ifName, err := netutils.GenerateIfaceName("vxlan", 7)
	if err != nil {
		return "", fmt.Errorf("error generating vxlan name: %v", err)
	}

	link := &netlink.Vxlan{
		LinkAttrs: netlink.LinkAttrs{Name: ifName},
		VxlanId:   int(vni),
		Learning:  true,
		Proxy:     true,
		L3miss:    true,
		L2miss:    true,
	}

	addErr := netlink.LinkAdd(link)
	if addErr != nil {
		return "", fmt.Errorf("error creating vxlan interface: %v", addErr)
	}

	return ifName, nil
}
// deleteVxlan looks up the link called name and deletes it.
// It fails when the link cannot be found or cannot be deleted.
func deleteVxlan(name string) error {
	vxlanLink, lookupErr := netlink.LinkByName(name)
	if lookupErr != nil {
		return fmt.Errorf("failed to find vxlan interface with name %s: %v", name, lookupErr)
	}

	delErr := netlink.LinkDel(vxlanLink)
	if delErr != nil {
		return fmt.Errorf("error deleting vxlan interface: %v", delErr)
	}

	return nil
}

View file

@ -0,0 +1,157 @@
package overlay
import (
"encoding/binary"
"fmt"
"net"
"sync"
"github.com/docker/libnetwork/config"
"github.com/docker/libnetwork/datastore"
"github.com/docker/libnetwork/driverapi"
"github.com/docker/libnetwork/idm"
"github.com/docker/libnetwork/netlabel"
"github.com/docker/libnetwork/types"
"github.com/hashicorp/serf/serf"
)
const (
// networkType is the name under which the driver is registered.
networkType = "overlay"
// vethPrefix and vethLen are used when generating veth interface names.
vethPrefix = "veth"
vethLen = 7
// vxlanIDStart and vxlanIDEnd bound the id set handed to the vxlan id manager.
vxlanIDStart = 256
vxlanIDEnd = 1000
)
// driver is the overlay network driver state.
type driver struct {
// eventCh, notifyCh and exitCh are the inputs of the serf event loop.
eventCh chan serf.Event
notifyCh chan ovNotify
exitCh chan chan struct{}
// ifaceName and neighIP come from the driver configuration options.
ifaceName string
neighIP string
peerDb peerNetworkMap
serfInstance *serf.Serf
networks networkTable
// store is only set when both KV provider options are configured.
store datastore.DataStore
ipAllocator *idm.Idm
vxlanIdm *idm.Idm
// The embedded Once makes Config apply at most once.
sync.Once
// The embedded Mutex guards the networks table.
sync.Mutex
}
// Package-level bridge addressing, filled in once by onceInit.
var (
bridgeSubnet, bridgeIP *net.IPNet
// once guards the one-time call to onceInit made from Init.
once sync.Once
// bridgeSubnetInt is the bridge subnet base address as a big-endian uint32.
bridgeSubnetInt uint32
)
// onceInit fills in the package-level bridge subnet (172.21.0.0/16), its
// integer form, and the bridge address (172.21.255.254/16).
// It panics when either hard-coded CIDR fails to parse.
func onceInit() {
	_, parsedSubnet, err := net.ParseCIDR("172.21.0.0/16")
	bridgeSubnet = parsedSubnet
	if err != nil {
		panic("could not parse cid 172.21.0.0/16")
	}
	bridgeSubnetInt = binary.BigEndian.Uint32(bridgeSubnet.IP.To4())

	gwIP, gwNet, err := net.ParseCIDR("172.21.255.254/16")
	if err != nil {
		panic("could not parse cid 172.21.255.254/16")
	}

	// Keep the host address itself, paired with the mask of its network.
	bridgeIP = &net.IPNet{IP: gwIP, Mask: gwNet.Mask}
}
// Init runs the one-time package initialization and registers a new overlay
// driver instance, with global scope capability, through dc.
func Init(dc driverapi.DriverCallback) error {
	once.Do(onceInit)

	instance := &driver{
		networks: networkTable{},
		peerDb: peerNetworkMap{
			mp: map[types.UUID]peerMap{},
		},
	}
	capability := driverapi.Capability{Scope: driverapi.GlobalScope}

	return dc.RegisterDriver(networkType, instance, capability)
}
// Fini asks the driver's serf loop to exit and waits until it has done so.
// It does nothing when the driver has no exit channel.
// drv must be an overlay driver instance.
func Fini(drv driverapi.Driver) {
	d := drv.(*driver)
	if d.exitCh == nil {
		return
	}

	done := make(chan struct{})
	d.exitCh <- done
	<-done
}
// Config applies the driver options, at most once per driver.
//
// On the first call it reads the bind interface and neighbor IP options,
// creates the datastore when both KV provider options are present, creates
// the vxlan and IP id managers and initializes serf. The first failing step
// aborts the sequence and its error is returned.
//
// Every later call fails with "config already applied to driver".
func (d *driver) Config(option map[string]interface{}) error {
	var (
		applied bool
		cfgErr  error
	)

	d.Do(func() {
		applied = true

		if v, found := option[netlabel.OverlayBindInterface]; found {
			d.ifaceName = v.(string)
		}
		if v, found := option[netlabel.OverlayNeighborIP]; found {
			d.neighIP = v.(string)
		}

		provider, hasProvider := option[netlabel.KVProvider]
		provURL, hasURL := option[netlabel.KVProviderURL]
		if hasProvider && hasURL {
			storeCfg := &config.DatastoreCfg{
				Client: config.DatastoreClientCfg{
					Provider: provider.(string),
					Address:  provURL.(string),
				},
			}
			if d.store, cfgErr = datastore.NewDataStore(storeCfg); cfgErr != nil {
				cfgErr = fmt.Errorf("failed to initialize data store: %v", cfgErr)
				return
			}
		}

		if d.vxlanIdm, cfgErr = idm.New(d.store, "vxlan-id", vxlanIDStart, vxlanIDEnd); cfgErr != nil {
			cfgErr = fmt.Errorf("failed to initialize vxlan id manager: %v", cfgErr)
			return
		}

		if d.ipAllocator, cfgErr = idm.New(d.store, "ipam-id", 1, 0xFFFF-2); cfgErr != nil {
			cfgErr = fmt.Errorf("failed to initalize ipam id manager: %v", cfgErr)
			return
		}

		if cfgErr = d.serfInit(); cfgErr != nil {
			cfgErr = fmt.Errorf("initializing serf instance failed: %v", cfgErr)
		}
	})

	if !applied {
		return fmt.Errorf("config already applied to driver")
	}
	return cfgErr
}
// Type returns the driver's network type name.
func (d *driver) Type() string {
return networkType
}

View file

@ -0,0 +1,280 @@
package overlay
import (
"fmt"
"net"
"sync"
"syscall"
"github.com/docker/libnetwork/types"
)
// peerKey identifies a peer by IP and MAC; its String form is used as the
// key of a peerMap.
type peerKey struct {
peerIP net.IP
peerMac net.HardwareAddr
}
// peerEntry is the data stored for a peer.
type peerEntry struct {
eid types.UUID
vtep net.IP
inSandbox bool
// isLocal marks entries that are skipped when peers are replayed into a sandbox.
isLocal bool
}
// peerMap holds the peers of one network, keyed by peerKey.String().
type peerMap struct {
mp map[string]peerEntry
sync.Mutex
}
// peerNetworkMap holds one peerMap per network id.
// NOTE(review): peerMap is stored by value, so each lookup copies its mutex —
// confirm whether the per-network lock provides the intended exclusion.
type peerNetworkMap struct {
mp map[types.UUID]peerMap
sync.Mutex
}
// String renders the key as "<ip> <mac>", the form used as peer map key.
func (pKey peerKey) String() string {
	ipPart := pKey.peerIP.String()
	macPart := pKey.peerMac.String()
	return ipPart + " " + macPart
}
// Scan implements fmt.Scanner: it reads two space separated tokens, an IP
// followed by a MAC, into the key. It fails when a token cannot be read or
// the MAC does not parse; an unparsable IP leaves peerIP nil without error.
func (pKey *peerKey) Scan(state fmt.ScanState, verb rune) error {
	tok, err := state.Token(true, nil)
	if err != nil {
		return err
	}
	pKey.peerIP = net.ParseIP(string(tok))

	if tok, err = state.Token(true, nil); err != nil {
		return err
	}

	mac, err := net.ParseMAC(string(tok))
	pKey.peerMac = mac
	return err
}
// peerDbWg makes peer database adds and deletes wait while a sandbox peer
// replay (peerDbUpdateSandbox) is in progress.
var peerDbWg sync.WaitGroup
// peerDbWalk calls f for every peer recorded for network nid until f returns
// true. It does nothing when the network has no peer map. The returned error
// is always nil.
//
// A key that fails to parse is reported and f is still called with whatever
// was parsed.
func (d *driver) peerDbWalk(nid types.UUID, f func(*peerKey, *peerEntry) bool) error {
	d.peerDb.Lock()
	pMap, known := d.peerDb.mp[nid]
	d.peerDb.Unlock()
	if !known {
		return nil
	}

	pMap.Lock()
	defer pMap.Unlock()

	for keyStr, entry := range pMap.mp {
		var key peerKey
		if _, err := fmt.Sscan(keyStr, &key); err != nil {
			fmt.Printf("peer key scan failed: %v", err)
		}

		if stop := f(&key, &entry); stop {
			break
		}
	}

	return nil
}
// peerDbSearch looks up peerIP among the peers of network nid and returns the
// MAC and vtep of the first match. It fails when the walk fails or no peer
// has that IP.
func (d *driver) peerDbSearch(nid types.UUID, peerIP net.IP) (net.HardwareAddr, net.IP, error) {
	var (
		matchMac  net.HardwareAddr
		matchVtep net.IP
		matched   bool
	)

	walkErr := d.peerDbWalk(nid, func(key *peerKey, entry *peerEntry) bool {
		if !key.peerIP.Equal(peerIP) {
			return false
		}

		// First hit wins; returning true stops the walk.
		matchMac, matchVtep = key.peerMac, entry.vtep
		matched = true
		return true
	})

	if walkErr != nil {
		return nil, nil, fmt.Errorf("peerdb search for peer ip %q failed: %v", peerIP, walkErr)
	}
	if !matched {
		return nil, nil, fmt.Errorf("peer ip %q not found in peerdb", peerIP)
	}

	return matchMac, matchVtep, nil
}
// peerDbAdd records a peer (IP, MAC, vtep, endpoint id, locality) for network
// nid, creating the network's peer map on first use. It waits for any sandbox
// peer replay in progress before touching the database.
func (d *driver) peerDbAdd(nid, eid types.UUID, peerIP net.IP,
	peerMac net.HardwareAddr, vtep net.IP, isLocal bool) {

	peerDbWg.Wait()

	d.peerDb.Lock()
	pMap, known := d.peerDb.mp[nid]
	if !known {
		d.peerDb.mp[nid] = peerMap{mp: make(map[string]peerEntry)}
		pMap = d.peerDb.mp[nid]
	}
	d.peerDb.Unlock()

	key := peerKey{peerIP: peerIP, peerMac: peerMac}.String()
	entry := peerEntry{eid: eid, vtep: vtep, isLocal: isLocal}

	pMap.Lock()
	pMap.mp[key] = entry
	pMap.Unlock()
}
// peerDbDelete forgets the peer identified by (peerIP, peerMac) on network
// nid. It does nothing when the network has no peer map. It waits for any
// sandbox peer replay in progress before touching the database.
func (d *driver) peerDbDelete(nid, eid types.UUID, peerIP net.IP,
	peerMac net.HardwareAddr, vtep net.IP) {

	peerDbWg.Wait()

	d.peerDb.Lock()
	pMap, known := d.peerDb.mp[nid]
	d.peerDb.Unlock()
	if !known {
		return
	}

	key := peerKey{peerIP: peerIP, peerMac: peerMac}.String()

	pMap.Lock()
	delete(pMap.mp, key)
	pMap.Unlock()
}
// peerDbUpdateSandbox replays every non-local peer recorded for network nid
// into the network's sandbox by calling peerAdd without updating the
// database. It does nothing when the network has no peer map.
//
// The operations are collected while the peer map is locked and executed
// after it is unlocked. While the replay runs, peerDbWg holds back concurrent
// peer database adds and deletes.
func (d *driver) peerDbUpdateSandbox(nid types.UUID) {
	d.peerDb.Lock()
	pMap, ok := d.peerDb.mp[nid]
	if !ok {
		d.peerDb.Unlock()
		return
	}
	d.peerDb.Unlock()

	peerDbWg.Add(1)

	var peerOps []func()
	pMap.Lock()
	for pKeyStr, pEntry := range pMap.mp {
		var pKey peerKey
		if _, err := fmt.Sscan(pKeyStr, &pKey); err != nil {
			fmt.Printf("peer key scan failed: %v", err)
		}

		if pEntry.isLocal {
			continue
		}

		// The closure below runs after the loop has finished. Capture
		// per-iteration copies so that every operation uses its own entry
		// rather than whatever the shared range variable holds last.
		entry := pEntry
		key := pKey

		op := func() {
			if err := d.peerAdd(nid, entry.eid, key.peerIP,
				key.peerMac, entry.vtep,
				false); err != nil {
				fmt.Printf("peerdbupdate in sandbox failed for ip %s and mac %s: %v",
					key.peerIP, key.peerMac, err)
			}
		}

		peerOps = append(peerOps, op)
	}
	pMap.Unlock()

	for _, op := range peerOps {
		op()
	}

	peerDbWg.Done()
}
// peerAdd makes a peer reachable on network nid: it optionally records the
// peer in the peer database (as non-local) and, when the network and its
// sandbox exist, installs a neighbor entry for the peer IP and a bridge fdb
// entry pointing the peer MAC at its vtep, both on the vxlan link.
//
// It fails on empty ids or when either entry cannot be installed; a missing
// network or sandbox is not an error.
func (d *driver) peerAdd(nid, eid types.UUID, peerIP net.IP,
	peerMac net.HardwareAddr, vtep net.IP, updateDb bool) error {

	if err := validateID(nid, eid); err != nil {
		return err
	}

	if updateDb {
		d.peerDbAdd(nid, eid, peerIP, peerMac, vtep, false)
	}

	nw := d.network(nid)
	if nw == nil {
		return nil
	}

	sb := nw.sandbox()
	if sb == nil {
		return nil
	}

	// Neighbor entry for the peer IP
	onVxlan := sb.NeighborOptions().LinkName(nw.vxlanName)
	if err := sb.AddNeighbor(peerIP, peerMac, onVxlan); err != nil {
		return fmt.Errorf("could not add neigbor entry into the sandbox: %v", err)
	}

	// Bridge fdb entry for the peer mac
	fdbLink := sb.NeighborOptions().LinkName(nw.vxlanName)
	fdbFamily := sb.NeighborOptions().Family(syscall.AF_BRIDGE)
	if err := sb.AddNeighbor(vtep, peerMac, fdbLink, fdbFamily); err != nil {
		return fmt.Errorf("could not add fdb entry into the sandbox: %v", err)
	}

	return nil
}
// peerDelete undoes peerAdd on network nid: it optionally removes the peer
// from the peer database and, when the network and its sandbox exist, deletes
// the fdb entry for the peer MAC and then the neighbor entry for the peer IP.
//
// It fails on empty ids or when either entry cannot be deleted; a missing
// network or sandbox is not an error.
func (d *driver) peerDelete(nid, eid types.UUID, peerIP net.IP,
	peerMac net.HardwareAddr, vtep net.IP, updateDb bool) error {

	if err := validateID(nid, eid); err != nil {
		return err
	}

	if updateDb {
		d.peerDbDelete(nid, eid, peerIP, peerMac, vtep)
	}

	nw := d.network(nid)
	if nw == nil {
		return nil
	}

	sb := nw.sandbox()
	if sb == nil {
		return nil
	}

	// Bridge fdb entry for the peer mac goes first
	fdbErr := sb.DeleteNeighbor(vtep, peerMac)
	if fdbErr != nil {
		return fmt.Errorf("could not delete fdb entry into the sandbox: %v", fdbErr)
	}

	// Then the neighbor entry for the peer IP
	neighErr := sb.DeleteNeighbor(peerIP, peerMac)
	if neighErr != nil {
		return fmt.Errorf("could not delete neigbor entry into the sandbox: %v", neighErr)
	}

	return nil
}

View file

@ -5,6 +5,7 @@ import (
"github.com/docker/libnetwork/drivers/bridge"
"github.com/docker/libnetwork/drivers/host"
"github.com/docker/libnetwork/drivers/null"
o "github.com/docker/libnetwork/drivers/overlay"
"github.com/docker/libnetwork/drivers/remote"
)
@ -14,6 +15,7 @@ func initDrivers(dc driverapi.DriverCallback) error {
host.Init,
null.Init,
remote.Init,
o.Init,
} {
if err := fn(dc); err != nil {
return err

View file

@ -349,11 +349,6 @@ func (ep *endpoint) Join(containerID string, options ...EndpointOption) error {
ep.joinLeaveStart()
defer func() {
ep.joinLeaveEnd()
if err != nil {
if e := ep.Leave(containerID, options...); e != nil {
log.Warnf("couldnt leave endpoint : %v", ep.name, err)
}
}
}()
ep.Lock()
@ -403,6 +398,13 @@ func (ep *endpoint) Join(containerID string, options ...EndpointOption) error {
if err != nil {
return err
}
defer func() {
if err != nil {
if err = driver.Leave(nid, epid); err != nil {
log.Warnf("driver leave failed while rolling back join: %v", err)
}
}
}()
err = ep.buildHostsFiles()
if err != nil {
@ -421,7 +423,7 @@ func (ep *endpoint) Join(containerID string, options ...EndpointOption) error {
sb, err := ctrlr.sandboxAdd(sboxKey, !container.config.useDefaultSandBox, ep)
if err != nil {
return err
return fmt.Errorf("failed sandbox add: %v", err)
}
defer func() {
if err != nil {
@ -554,7 +556,7 @@ func (ep *endpoint) deleteEndpoint() error {
_, ok := n.endpoints[epid]
if !ok {
n.Unlock()
return &UnknownEndpointError{name: name, id: string(epid)}
return nil
}
nid := n.id
@ -571,9 +573,39 @@ func (ep *endpoint) deleteEndpoint() error {
}
log.Warnf("driver error deleting endpoint %s : %v", name, err)
}
n.updateSvcRecord(ep, false)
return nil
}
// addHostEntries appends recs to the hosts file of the container attached to
// the endpoint. It does nothing when no container is attached; a failure to
// update the file is only logged.
func (ep *endpoint) addHostEntries(recs []etchosts.Record) {
	ep.Lock()
	attached := ep.container
	ep.Unlock()

	if attached == nil {
		return
	}

	hostsPath := attached.config.hostsPath
	err := etchosts.Add(hostsPath, recs)
	if err != nil {
		log.Warnf("Failed adding service host entries to the running container: %v", err)
	}
}
// deleteHostEntries removes recs from the hosts file of the container
// attached to the endpoint. It does nothing when no container is attached; a
// failure to update the file is only logged.
func (ep *endpoint) deleteHostEntries(recs []etchosts.Record) {
	ep.Lock()
	attached := ep.container
	ep.Unlock()

	if attached == nil {
		return
	}

	hostsPath := attached.config.hostsPath
	err := etchosts.Delete(hostsPath, recs)
	if err != nil {
		log.Warnf("Failed deleting service host entries to the running container: %v", err)
	}
}
func (ep *endpoint) buildHostsFiles() error {
var extraContent []etchosts.Record
@ -581,6 +613,7 @@ func (ep *endpoint) buildHostsFiles() error {
container := ep.container
joinInfo := ep.joinInfo
ifaces := ep.iFaces
n := ep.network
ep.Unlock()
if container == nil {
@ -613,6 +646,8 @@ func (ep *endpoint) buildHostsFiles() error {
etchosts.Record{Hosts: extraHost.name, IP: extraHost.IP})
}
extraContent = append(extraContent, n.getSvcRecords()...)
IP := ""
if len(ifaces) != 0 && ifaces[0] != nil {
IP = ifaces[0].addr.IP.String()

View file

@ -84,13 +84,15 @@ func (epi *endpointInterface) UnmarshalJSON(b []byte) (err error) {
mac, _ := net.ParseMAC(epMap["mac"].(string))
epi.mac = mac
_, ipnet, _ := net.ParseCIDR(epMap["addr"].(string))
ip, ipnet, _ := net.ParseCIDR(epMap["addr"].(string))
if ipnet != nil {
ipnet.IP = ip
epi.addr = *ipnet
}
_, ipnet, _ = net.ParseCIDR(epMap["addrv6"].(string))
ip, ipnet, _ = net.ParseCIDR(epMap["addrv6"].(string))
if ipnet != nil {
ipnet.IP = ip
epi.addrv6 = *ipnet
}
@ -102,8 +104,9 @@ func (epi *endpointInterface) UnmarshalJSON(b []byte) (err error) {
json.Unmarshal(rb, &routes)
epi.routes = make([]*net.IPNet, 0)
for _, route := range routes {
_, ipr, err := net.ParseCIDR(route)
ip, ipr, err := net.ParseCIDR(route)
if err == nil {
ipr.IP = ip
epi.routes = append(epi.routes, ipr)
}
}

View file

@ -51,7 +51,7 @@ func (ij ErrInvalidJoin) BadRequest() {}
type ErrNoContainer struct{}
func (nc ErrNoContainer) Error() string {
return "a container has already joined the endpoint"
return "no container is attached to the endpoint"
}
// Maskable denotes the type of this error

View file

@ -5,6 +5,7 @@ import (
"fmt"
"io"
"io/ioutil"
"os"
"regexp"
)
@ -65,6 +66,45 @@ func Build(path, IP, hostname, domainname string, extraContent []Record) error {
return ioutil.WriteFile(path, content.Bytes(), 0644)
}
// Add adds an arbitrary number of Records to an already existing /etc/hosts file.
//
// The current content of the file at path is read, the records are appended
// and the result is written back with mode 0644. The file handle used for
// reading is closed before the file is rewritten.
//
// It returns an error when the file cannot be opened, read or written, or
// when a record cannot be serialized.
func Add(path string, recs []Record) error {
	f, err := os.Open(path)
	if err != nil {
		return err
	}

	content := bytes.NewBuffer(nil)
	_, err = content.ReadFrom(f)
	// The handle is only needed for reading; release it on every path.
	f.Close()
	if err != nil {
		return err
	}

	for _, r := range recs {
		if _, err := r.WriteTo(content); err != nil {
			return err
		}
	}

	return ioutil.WriteFile(path, content.Bytes(), 0644)
}
// Delete deletes an arbitrary number of Records already existing in /etc/hosts file.
//
// Every "<non-space>*<TAB><host><newline>" line whose host equals the Hosts
// field of one of recs is removed from the file at path, which is then
// written back with mode 0644. With no records there is nothing to delete and
// the file is left untouched.
//
// It returns an error when the file cannot be read or written.
func Delete(path string, recs []Record) error {
	if len(recs) == 0 {
		return nil
	}

	old, err := ioutil.ReadFile(path)
	if err != nil {
		return err
	}

	// One alternative per record; host names are quoted so they match literally.
	regexpStr := fmt.Sprintf("\\S*\\t%s\\n", regexp.QuoteMeta(recs[0].Hosts))
	for _, r := range recs[1:] {
		regexpStr = regexpStr + "|" + fmt.Sprintf("\\S*\\t%s\\n", regexp.QuoteMeta(r.Hosts))
	}

	var re = regexp.MustCompile(regexpStr)
	return ioutil.WriteFile(path, re.ReplaceAll(old, []byte("")), 0644)
}
// Update all IP addresses where hostname matches.
// path is path to host file
// IP is new IP address

View file

@ -0,0 +1,94 @@
// Package idm manages reservation/release of numerical ids from a configured set of contiguous ids
package idm
import (
"fmt"
"github.com/docker/libnetwork/bitseq"
"github.com/docker/libnetwork/datastore"
"github.com/docker/libnetwork/types"
)
// Idm manages the reservation/release of numerical ids from a contiguous set
type Idm struct {
// start and end are the inclusive bounds of the id set.
start uint32
end uint32
// handle is the bit sequence that tracks which ids are reserved.
handle *bitseq.Handle
}
// New returns an instance of id manager for a set of [start-end] numerical ids.
// It fails when id is empty, when end is not greater than start, or when the
// backing bit sequence cannot be created.
func New(ds datastore.DataStore, id string, start, end uint32) (*Idm, error) {
	switch {
	case id == "":
		return nil, fmt.Errorf("Invalid id")
	case end <= start:
		return nil, fmt.Errorf("Invalid set range: [%d, %d]", start, end)
	}

	// Both bounds are part of the set, hence the extra slot.
	size := uint32(1 + end - start)
	handle, err := bitseq.NewHandle("idm", ds, id, size)
	if err != nil {
		return nil, fmt.Errorf("failed to initialize bit sequence handler: %s", err.Error())
	}

	manager := &Idm{start: start, end: end, handle: handle}
	return manager, nil
}
// GetID returns the first available id in the set.
// It fails when the manager is not initialized, when no id is left, or when
// the reservation fails with an error that is not retryable. Retryable
// reservation failures restart the search.
func (i *Idm) GetID() (uint32, error) {
	if i.handle == nil {
		return 0, fmt.Errorf("ID set is not initialized")
	}

	for {
		bytePos, bitPos, err := i.handle.GetFirstAvailable()
		if err != nil {
			return 0, fmt.Errorf("no available ids")
		}

		candidate := i.start + uint32(bitPos+bytePos*8)

		// The free position can lie past the end of the set when the set
		// length is not a multiple of 32.
		if candidate > i.end {
			return 0, fmt.Errorf("no available ids")
		}

		err = i.handle.PushReservation(bytePos, bitPos, false)
		if err == nil {
			return candidate, nil
		}
		if _, retry := err.(types.RetryError); !retry {
			return 0, fmt.Errorf("internal failure while reserving the id: %s", err.Error())
		}
	}
}
// GetSpecificID tries to reserve the specified id.
// It fails when the manager is not initialized, when id lies outside the set,
// when id is already taken, or when the reservation fails with an error that
// is not retryable. Retryable reservation failures repeat the attempt.
func (i *Idm) GetSpecificID(id uint32) error {
	if i.handle == nil {
		return fmt.Errorf("ID set is not initialized")
	}

	if inRange := i.start <= id && id <= i.end; !inRange {
		return fmt.Errorf("Requested id does not belong to the set")
	}

	for {
		bytePos, bitPos, err := i.handle.CheckIfAvailable(int(id - i.start))
		if err != nil {
			return fmt.Errorf("requested id is not available")
		}

		err = i.handle.PushReservation(bytePos, bitPos, false)
		if err == nil {
			return nil
		}
		if _, retry := err.(types.RetryError); !retry {
			return fmt.Errorf("internal failure while reserving the id: %s", err.Error())
		}
	}
}
// Release releases the specified id.
//
// Like the reservation methods, it refuses to work on an uninitialized
// manager and on ids outside the managed set; in both cases it does nothing.
// Without the range check an id below start would wrap around and release an
// unrelated position. A failure to push the release is not reported, since
// the method has no way to return it.
func (i *Idm) Release(id uint32) {
	if i.handle == nil {
		return
	}
	if id < i.start || id > i.end {
		return
	}

	ordinal := id - i.start
	i.handle.PushReservation(int(ordinal/8), int(ordinal%8), true)
}

View file

@ -99,7 +99,8 @@ func NewChain(name, bridge string, table Table, hairpinMode bool) (*Chain, error
case Nat:
preroute := []string{
"-m", "addrtype",
"--dst-type", "LOCAL"}
"--dst-type", "LOCAL",
"-j", c.Name}
if !Exists(Nat, "PREROUTING", preroute...) {
if err := c.Prerouting(Append, preroute...); err != nil {
return nil, fmt.Errorf("Failed to inject docker in PREROUTING chain: %s", err)
@ -107,7 +108,8 @@ func NewChain(name, bridge string, table Table, hairpinMode bool) (*Chain, error
}
output := []string{
"-m", "addrtype",
"--dst-type", "LOCAL"}
"--dst-type", "LOCAL",
"-j", c.Name}
if !hairpinMode {
output = append(output, "!", "--dst", "127.0.0.0/8")
}
@ -228,7 +230,7 @@ func (c *Chain) Prerouting(action Action, args ...string) error {
if len(args) > 0 {
a = append(a, args...)
}
if output, err := Raw(append(a, "-j", c.Name)...); err != nil {
if output, err := Raw(a...); err != nil {
return err
} else if len(output) != 0 {
return ChainError{Chain: "PREROUTING", Output: output}
@ -242,7 +244,7 @@ func (c *Chain) Output(action Action, args ...string) error {
if len(args) > 0 {
a = append(a, args...)
}
if output, err := Raw(append(a, "-j", c.Name)...); err != nil {
if output, err := Raw(a...); err != nil {
return err
} else if len(output) != 0 {
return ChainError{Chain: "OUTPUT", Output: output}
@ -254,9 +256,9 @@ func (c *Chain) Output(action Action, args ...string) error {
func (c *Chain) Remove() error {
// Ignore errors - This could mean the chains were never set up
if c.Table == Nat {
c.Prerouting(Delete, "-m", "addrtype", "--dst-type", "LOCAL")
c.Output(Delete, "-m", "addrtype", "--dst-type", "LOCAL", "!", "--dst", "127.0.0.0/8")
c.Output(Delete, "-m", "addrtype", "--dst-type", "LOCAL") // Created in versions <= 0.1.6
c.Prerouting(Delete, "-m", "addrtype", "--dst-type", "LOCAL", "-j", c.Name)
c.Output(Delete, "-m", "addrtype", "--dst-type", "LOCAL", "!", "--dst", "127.0.0.0/8", "-j", c.Name)
c.Output(Delete, "-m", "addrtype", "--dst-type", "LOCAL", "-j", c.Name) // Created in versions <= 0.1.6
c.Prerouting(Delete)
c.Output(Delete)

View file

@ -1,5 +1,7 @@
package netlabel
import "strings"
const (
// Prefix constant marks the reserved label space for libnetwork
Prefix = "com.docker.network"
@ -21,4 +23,30 @@ const (
//EnableIPv6 constant represents enabling IPV6 at network level
EnableIPv6 = Prefix + ".enable_ipv6"
// KVProvider constant represents the KV provider backend
KVProvider = DriverPrefix + ".kv_provider"
// KVProviderURL constant represents the KV provider URL
KVProviderURL = DriverPrefix + ".kv_provider_url"
// OverlayBindInterface constant represents overlay driver bind interface
OverlayBindInterface = DriverPrefix + ".overlay.bind_interface"
// OverlayNeighborIP constant represents overlay driver neighbor IP
OverlayNeighborIP = DriverPrefix + ".overlay.neighbor_ip"
)
// Key extracts the key portion of the label: everything before the first "=",
// or the whole label when it contains none.
func Key(label string) string {
	if sep := strings.Index(label, "="); sep >= 0 {
		return label[:sep]
	}
	return label
}
// Value extracts the value portion of the label: everything after the first
// "=". A label without "=" has no value and yields the empty string instead
// of panicking on a missing second element.
func Value(label string) string {
	kv := strings.SplitN(label, "=", 2)
	if len(kv) < 2 {
		return ""
	}
	return kv[1]
}

View file

@ -168,3 +168,53 @@ func GenerateIfaceName(prefix string, len int) (string, error) {
}
return "", types.InternalErrorf("could not generate interface name")
}
// byteArrayToInt interprets array as a big-endian unsigned integer and
// returns it as a uint64. numBytes must be in [1, 8], otherwise the function
// panics; it is only validated, every byte of array contributes to the result.
//
// The value is accumulated in a uint64 so the result does not depend on the
// platform's int width.
func byteArrayToInt(array []byte, numBytes int) uint64 {
	if numBytes <= 0 || numBytes > 8 {
		panic("Invalid argument")
	}

	var num uint64
	last := len(array) - 1
	for i := 0; i <= last; i++ {
		// array[last] is the least significant byte.
		num += uint64(array[last-i]) << uint(i*8)
	}
	return num
}
// ATo64 converts a byte array into a uint64
func ATo64(array []byte) uint64 {
return byteArrayToInt(array, 8)
}
// ATo32 converts a byte array into a uint32
func ATo32(array []byte) uint32 {
return uint32(byteArrayToInt(array, 4))
}
// ATo16 converts a byte array into a uint16
func ATo16(array []byte) uint16 {
return uint16(byteArrayToInt(array, 2))
}
// intToByteArray returns the numBytes least significant bytes of val in
// big-endian order. Positions beyond the eighth byte are zero.
func intToByteArray(val uint64, numBytes int) []byte {
	out := make([]byte, numBytes)
	for pos := range out {
		// out[0] is the most significant of the requested bytes.
		shift := uint(8 * (numBytes - 1 - pos))
		out[pos] = byte(val >> shift)
	}
	return out
}
// U64ToA converts a uint64 to a byte array
func U64ToA(val uint64) []byte {
return intToByteArray(uint64(val), 8)
}
// U32ToA converts a uint64 to a byte array
func U32ToA(val uint32) []byte {
return intToByteArray(uint64(val), 4)
}
// U16ToA converts a uint16 to a 2-byte array.
func U16ToA(val uint16) []byte {
return intToByteArray(uint64(val), 2)
}

View file

@ -2,6 +2,7 @@ package libnetwork
import (
"encoding/json"
"net"
"sync"
log "github.com/Sirupsen/logrus"
@ -9,6 +10,7 @@ import (
"github.com/docker/libnetwork/config"
"github.com/docker/libnetwork/datastore"
"github.com/docker/libnetwork/driverapi"
"github.com/docker/libnetwork/etchosts"
"github.com/docker/libnetwork/netlabel"
"github.com/docker/libnetwork/options"
"github.com/docker/libnetwork/types"
@ -51,6 +53,8 @@ type Network interface {
// When the function returns true, the walk will stop.
type EndpointWalker func(ep Endpoint) bool
type svcMap map[string]net.IP
type network struct {
ctrlr *controller
name string
@ -62,6 +66,8 @@ type network struct {
endpoints endpointTable
generic options.Generic
dbIndex uint64
svcRecords svcMap
stopWatchCh chan struct{}
sync.Mutex
}
@ -248,6 +254,7 @@ func (n *network) deleteNetwork() error {
}
log.Warnf("driver error deleting network %s : %v", n.name, err)
}
n.stopWatch()
return nil
}
@ -270,6 +277,8 @@ func (n *network) addEndpoint(ep *endpoint) error {
if err != nil {
return err
}
n.updateSvcRecord(ep, true)
return nil
}
@ -382,3 +391,62 @@ func (n *network) isGlobalScoped() (bool, error) {
n.Unlock()
return c.isDriverGlobalScoped(n.networkType)
}
// updateSvcRecord adds (isAdd == true) or removes the service records of ep
// in the network's svcRecords map and then pushes the same change to the hosts
// entries of every endpoint of the network that currently has a container.
//
// Two records are handled per interface of ep: one keyed by the endpoint name
// and one keyed by "<endpoint name>.<network name>".
func (n *network) updateSvcRecord(ep *endpoint, isAdd bool) {
	var records []etchosts.Record

	n.Lock()
	for _, iface := range ep.InterfaceList() {
		shortName := ep.Name()
		qualifiedName := ep.Name() + "." + n.name
		ip := iface.Address().IP

		if !isAdd {
			delete(n.svcRecords, shortName)
			delete(n.svcRecords, qualifiedName)
		} else {
			n.svcRecords[shortName] = ip
			n.svcRecords[qualifiedName] = ip
		}

		records = append(records,
			etchosts.Record{Hosts: shortName, IP: ip.String()},
			etchosts.Record{Hosts: qualifiedName, IP: ip.String()},
		)
	}
	n.Unlock()

	// Collect the endpoints that are attached to a container; the walker
	// always returns false so that the walk visits every endpoint.
	var targets []*endpoint
	n.WalkEndpoints(func(e Endpoint) bool {
		candidate := e.(*endpoint)
		candidate.Lock()
		hasContainer := candidate.container != nil
		if hasContainer {
			targets = append(targets, candidate)
		}
		candidate.Unlock()
		return false
	})

	for _, target := range targets {
		if !isAdd {
			target.deleteHostEntries(records)
			continue
		}
		target.addHostEntries(records)
	}
}
// getSvcRecords returns the content of the network's svcRecords map as a
// list of etchosts records. The map is read while holding the network lock.
func (n *network) getSvcRecords() []etchosts.Record {
	var out []etchosts.Record

	n.Lock()
	for host, addr := range n.svcRecords {
		rec := etchosts.Record{Hosts: host, IP: addr.String()}
		out = append(out, rec)
	}
	n.Unlock()

	return out
}

View file

@ -153,14 +153,14 @@ func (i *nwIface) Remove() error {
})
}
func (n *networkNamespace) findDstMaster(srcName string) string {
func (n *networkNamespace) findDst(srcName string, isBridge bool) string {
n.Lock()
defer n.Unlock()
for _, i := range n.iFaces {
// The master should match the srcname of the interface and the
// master interface should be of type bridge.
if i.SrcName() == srcName && i.Bridge() {
// master interface should be of type bridge, if searching for a bridge type
if i.SrcName() == srcName && (!isBridge || i.Bridge()) {
return i.DstName()
}
}
@ -173,7 +173,7 @@ func (n *networkNamespace) AddInterface(srcName, dstPrefix string, options ...If
i.processInterfaceOptions(options...)
if i.master != "" {
i.dstMaster = n.findDstMaster(i.master)
i.dstMaster = n.findDst(i.master, true)
if i.dstMaster == "" {
return fmt.Errorf("could not find an appropriate master %q for %q",
i.master, i.srcName)

View file

@ -37,6 +37,7 @@ type networkNamespace struct {
gw net.IP
gwv6 net.IP
staticRoutes []*types.StaticRoute
neighbors []*neigh
nextIfIndex int
sync.Mutex
}
@ -150,6 +151,10 @@ func (n *networkNamespace) InterfaceOptions() IfaceOptionSetter {
return n
}
// NeighborOptions returns the namespace itself as the NeighborOptionSetter.
func (n *networkNamespace) NeighborOptions() NeighborOptionSetter {
return n
}
func reexecCreateNamespace() {
if len(os.Args) < 2 {
log.Fatal("no namespace path provided")
@ -230,6 +235,13 @@ func loopbackUp() error {
return netlink.LinkSetUp(iface)
}
// InvokeFunc hands f to nsInvoke together with this namespace's path.
// f is run from the post function, the pre function is a no-op, and the only
// error that can be returned is the one produced by nsInvoke itself.
func (n *networkNamespace) InvokeFunc(f func()) error {
	path := n.nsPath()
	noopPre := func(nsFD int) error { return nil }
	runPost := func(callerFD int) error {
		f()
		return nil
	}
	return nsInvoke(path, noopPre, runPost)
}
func nsInvoke(path string, prefunc func(nsFD int) error, postfunc func(callerFD int) error) error {
runtime.LockOSThread()
defer runtime.UnlockOSThread()

View file

@ -0,0 +1,138 @@
package sandbox
import (
"bytes"
"fmt"
"net"
"github.com/vishvananda/netlink"
)
// NeighOption is a function option type to set neighbor options
type NeighOption func(nh *neigh)
// neigh holds the parameters of one neighbor entry tracked by a
// networkNamespace.
type neigh struct {
// dstIP is the IP address of the neighbor.
dstIP net.IP
// dstMac is the hardware address associated with dstIP.
dstMac net.HardwareAddr
// linkName is the link name supplied through the LinkName option.
linkName string
// linkDst is the name AddNeighbor resolves from linkName via findDst; it is
// the name later looked up with netlink.LinkByName.
linkDst string
// family is the address family supplied through the Family option; it is
// copied into the Family field of the netlink neighbor entry.
family int
}
// findNeighbor returns the tracked neighbor whose IP and hardware address
// both match the arguments, or nil when there is none. The neighbor list is
// scanned while holding the namespace lock.
func (n *networkNamespace) findNeighbor(dstIP net.IP, dstMac net.HardwareAddr) *neigh {
	n.Lock()
	defer n.Unlock()

	for idx := range n.neighbors {
		candidate := n.neighbors[idx]
		if !candidate.dstIP.Equal(dstIP) {
			continue
		}
		if !bytes.Equal(candidate.dstMac, dstMac) {
			continue
		}
		return candidate
	}

	return nil
}
// DeleteNeighbor removes the neighbor entry identified by dstIP and dstMac.
//
// It returns an error when no such entry is tracked by the namespace, when
// the link recorded for the entry cannot be found, or when netlink fails to
// delete the entry. On success the entry is also dropped from the namespace's
// neighbor list.
func (n *networkNamespace) DeleteNeighbor(dstIP net.IP, dstMac net.HardwareAddr) error {
	nh := n.findNeighbor(dstIP, dstMac)
	if nh == nil {
		return fmt.Errorf("could not find the neighbor entry to delete")
	}

	return nsInvoke(n.nsPath(), func(nsFD int) error { return nil }, func(callerFD int) error {
		var iface netlink.Link

		if nh.linkDst != "" {
			var err error
			iface, err = netlink.LinkByName(nh.linkDst)
			if err != nil {
				return fmt.Errorf("could not find interface with destination name %s: %v",
					nh.linkDst, err)
			}
		}

		nlnh := &netlink.Neigh{
			IP:     dstIP,
			State:  netlink.NUD_PERMANENT,
			Family: nh.family,
		}

		// The hardware address and the NTF_SELF flag are only set when an
		// explicit family was configured for the entry.
		if nlnh.Family > 0 {
			nlnh.HardwareAddr = dstMac
			nlnh.Flags = netlink.NTF_SELF
		}

		if nh.linkDst != "" {
			nlnh.LinkIndex = iface.Attrs().Index
		}

		if err := netlink.NeighDel(nlnh); err != nil {
			return fmt.Errorf("could not delete neighbor entry: %v", err)
		}

		// The neighbor list is guarded by the namespace lock (findNeighbor
		// reads it under that lock), so take it for the update as well.
		// Matching entries are filtered out in place instead of splicing the
		// slice while ranging over it.
		n.Lock()
		defer n.Unlock()
		kept := n.neighbors[:0]
		for _, cur := range n.neighbors {
			if cur.dstIP.Equal(dstIP) && bytes.Equal(cur.dstMac, dstMac) {
				continue
			}
			kept = append(kept, cur)
		}
		n.neighbors = kept

		return nil
	})
}
// AddNeighbor programs a permanent neighbor entry for dstIP/dstMac and
// records it in the namespace's neighbor list.
//
// If an entry with the same IP and hardware address is already tracked the
// call returns nil without doing anything. When a link name was supplied
// through the options, it is resolved with findDst; an error is returned if
// it cannot be resolved, if the resolved link cannot be found, or if netlink
// fails to set the entry.
func (n *networkNamespace) AddNeighbor(dstIP net.IP, dstMac net.HardwareAddr, options ...NeighOption) error {
	nh := n.findNeighbor(dstIP, dstMac)
	if nh != nil {
		// If it exists silently return
		return nil
	}

	nh = &neigh{
		dstIP:  dstIP,
		dstMac: dstMac,
	}

	nh.processNeighOptions(options...)

	if nh.linkName != "" {
		nh.linkDst = n.findDst(nh.linkName, false)
		if nh.linkDst == "" {
			return fmt.Errorf("could not find the interface with name %s", nh.linkName)
		}
	}

	return nsInvoke(n.nsPath(), func(nsFD int) error { return nil }, func(callerFD int) error {
		var iface netlink.Link

		if nh.linkDst != "" {
			var err error
			iface, err = netlink.LinkByName(nh.linkDst)
			if err != nil {
				return fmt.Errorf("could not find interface with destination name %s: %v",
					nh.linkDst, err)
			}
		}

		nlnh := &netlink.Neigh{
			IP:           dstIP,
			HardwareAddr: dstMac,
			State:        netlink.NUD_PERMANENT,
			Family:       nh.family,
		}

		// NTF_SELF is only set when an explicit family was configured.
		if nlnh.Family > 0 {
			nlnh.Flags = netlink.NTF_SELF
		}

		if nh.linkDst != "" {
			nlnh.LinkIndex = iface.Attrs().Index
		}

		if err := netlink.NeighSet(nlnh); err != nil {
			return fmt.Errorf("could not add neighbor entry: %v", err)
		}

		// The neighbor list is guarded by the namespace lock (findNeighbor
		// reads it under that lock), so take it for the append as well.
		n.Lock()
		n.neighbors = append(n.neighbors, nh)
		n.Unlock()

		return nil
	})
}

View file

@ -0,0 +1,4 @@
package sandbox
// NeighOption is a function option type to set neighbor options
// NOTE(review): this variant takes no arguments; presumably it is a
// placeholder for builds that do not program neighbor entries — confirm.
type NeighOption func()

View file

@ -2,6 +2,26 @@ package sandbox
import "net"
// processNeighOptions applies every non-nil option to nh, in the order given.
func (nh *neigh) processNeighOptions(options ...NeighOption) {
	for _, apply := range options {
		if apply == nil {
			continue
		}
		apply(nh)
	}
}
// LinkName returns a NeighOption that stores name in the neighbor's linkName
// field.
func (n *networkNamespace) LinkName(name string) NeighOption {
return func(nh *neigh) {
nh.linkName = name
}
}
// Family returns a NeighOption that stores family in the neighbor's family
// field.
func (n *networkNamespace) Family(family int) NeighOption {
return func(nh *neigh) {
nh.family = family
}
}
func (i *nwIface) processInterfaceOptions(options ...IfaceOption) {
for _, opt := range options {
if opt != nil {

View file

@ -81,7 +81,7 @@ func programGateway(path string, gw net.IP, isAdd bool) error {
return nsInvoke(path, func(nsFD int) error { return nil }, func(callerFD int) error {
gwRoutes, err := netlink.RouteGet(gw)
if err != nil {
return fmt.Errorf("route for the gateway could not be found: %v", err)
return fmt.Errorf("route for the gateway %s could not be found: %v", gw, err)
}
if isAdd {
@ -105,7 +105,7 @@ func programRoute(path string, dest *net.IPNet, nh net.IP) error {
return nsInvoke(path, func(nsFD int) error { return nil }, func(callerFD int) error {
gwRoutes, err := netlink.RouteGet(nh)
if err != nil {
return fmt.Errorf("route for the next hop could not be found: %v", err)
return fmt.Errorf("route for the next hop %s could not be found: %v", nh, err)
}
return netlink.RouteAdd(&netlink.Route{

View file

@ -37,9 +37,21 @@ type Sandbox interface {
// Remove a static route from the sandbox.
RemoveStaticRoute(*types.StaticRoute) error
// AddNeighbor adds a neighbor entry into the sandbox.
AddNeighbor(dstIP net.IP, dstMac net.HardwareAddr, option ...NeighOption) error
// DeleteNeighbor deletes neighbor entry from the sandbox.
DeleteNeighbor(dstIP net.IP, dstMac net.HardwareAddr) error
// Returns an interface with methods to set neighbor options.
NeighborOptions() NeighborOptionSetter
// Returns an interface with methods to set interface options.
InterfaceOptions() IfaceOptionSetter
//Invoke
InvokeFunc(func()) error
// Returns an interface with methods to get sandbox state.
Info() Info
@ -47,6 +59,17 @@ type Sandbox interface {
Destroy() error
}
// NeighborOptionSetter interface defines the option setter methods for neighbor options.
type NeighborOptionSetter interface {
// LinkName returns an option setter to set the srcName of the link that should
// be used in the neighbor entry
LinkName(string) NeighOption
// Family returns an option setter to set the address family for the neighbor
// entry. eg. AF_BRIDGE
Family(int) NeighOption
}
// IfaceOptionSetter interface defines the option setter methods for interface options.
type IfaceOptionSetter interface {
// Bridge returns an option setter to set if the interface is a bridge.

View file

@ -2,6 +2,7 @@ package libnetwork
import (
"container/heap"
"fmt"
"sync"
"github.com/Sirupsen/logrus"
@ -48,13 +49,9 @@ func (eh *epHeap) Pop() interface{} {
func (s *sandboxData) updateGateway(ep *endpoint) error {
sb := s.sandbox()
if err := sb.UnsetGateway(); err != nil {
return err
}
if err := sb.UnsetGatewayIPv6(); err != nil {
return err
}
sb.UnsetGateway()
sb.UnsetGatewayIPv6()
if ep == nil {
return nil
@ -65,11 +62,11 @@ func (s *sandboxData) updateGateway(ep *endpoint) error {
ep.Unlock()
if err := sb.SetGateway(joinInfo.gw); err != nil {
return err
return fmt.Errorf("failed to set gateway while updating gateway: %v", err)
}
if err := sb.SetGatewayIPv6(joinInfo.gw6); err != nil {
return err
return fmt.Errorf("failed to set IPv6 gateway while updating gateway: %v", err)
}
return nil
@ -93,7 +90,7 @@ func (s *sandboxData) addEndpoint(ep *endpoint) error {
}
if err := sb.AddInterface(i.srcName, i.dstPrefix, ifaceOptions...); err != nil {
return err
return fmt.Errorf("failed to add interface %s to sandbox: %v", i.srcName, err)
}
}
@ -101,7 +98,7 @@ func (s *sandboxData) addEndpoint(ep *endpoint) error {
// Set up non-interface routes.
for _, r := range ep.joinInfo.StaticRoutes {
if err := sb.AddStaticRoute(r); err != nil {
return err
return fmt.Errorf("failed to add static route %s: %v", r.Destination.String(), err)
}
}
}
@ -117,14 +114,10 @@ func (s *sandboxData) addEndpoint(ep *endpoint) error {
}
}
s.Lock()
s.refCnt++
s.Unlock()
return nil
}
func (s *sandboxData) rmEndpoint(ep *endpoint) int {
func (s *sandboxData) rmEndpoint(ep *endpoint) {
ep.Lock()
joinInfo := ep.joinInfo
ep.Unlock()
@ -171,17 +164,6 @@ func (s *sandboxData) rmEndpoint(ep *endpoint) int {
if highEpBefore != highEpAfter {
s.updateGateway(highEpAfter)
}
s.Lock()
s.refCnt--
refCnt := s.refCnt
s.Unlock()
if refCnt == 0 {
s.sandbox().Destroy()
}
return refCnt
}
func (s *sandboxData) sandbox() sandbox.Sandbox {
@ -199,7 +181,7 @@ func (c *controller) sandboxAdd(key string, create bool, ep *endpoint) (sandbox.
if !ok {
sb, err := sandbox.NewSandbox(key, create)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to create new sandbox: %v", err)
}
sData = &sandboxData{
@ -225,11 +207,7 @@ func (c *controller) sandboxRm(key string, ep *endpoint) {
sData := c.sandboxes[key]
c.Unlock()
if sData.rmEndpoint(ep) == 0 {
c.Lock()
delete(c.sandboxes, key)
c.Unlock()
}
sData.rmEndpoint(ep)
}
func (c *controller) sandboxGet(key string) sandbox.Sandbox {
@ -243,3 +221,31 @@ func (c *controller) sandboxGet(key string) sandbox.Sandbox {
return sData.sandbox()
}
// LeaveAll makes the container identified by id leave every endpoint of its
// sandbox, destroys the sandbox and removes it from the controller.
//
// It returns an error only when no sandbox is registered for id. Failures
// while leaving an endpoint or destroying the sandbox are logged as warnings
// and do not abort the cleanup.
func (c *controller) LeaveAll(id string) error {
	key := sandbox.GenerateKey(id)

	c.Lock()
	sData, ok := c.sandboxes[key]
	c.Unlock()

	if !ok {
		return fmt.Errorf("could not find sandbox for container id %s", id)
	}

	// Snapshot the endpoints under the sandbox lock so that Leave can be
	// called without holding it.
	sData.Lock()
	eps := make([]*endpoint, len(sData.endpoints))
	for i, ep := range sData.endpoints {
		eps[i] = ep
	}
	sData.Unlock()

	for _, ep := range eps {
		if err := ep.Leave(id); err != nil {
			logrus.Warnf("Failed leaving endpoint id %s: %v\n", ep.ID(), err)
		}
	}

	if err := sData.sandbox().Destroy(); err != nil {
		logrus.Warnf("Failed destroying sandbox for container id %s: %v", id, err)
	}

	// The sandboxes map is read under the controller lock above, so the
	// removal must take the same lock.
	c.Lock()
	delete(c.sandboxes, key)
	c.Unlock()

	return nil
}

View file

@ -5,6 +5,7 @@ import (
"fmt"
log "github.com/Sirupsen/logrus"
"github.com/docker/libkv/store"
"github.com/docker/libnetwork/datastore"
"github.com/docker/libnetwork/types"
)
@ -31,7 +32,21 @@ func (c *controller) initDataStore() error {
c.Lock()
c.store = store
c.Unlock()
return c.watchStore()
nws, err := c.getNetworksFromStore()
if err == nil {
c.processNetworkUpdate(nws, nil)
} else if err != datastore.ErrKeyNotFound {
log.Warnf("failed to read networks from datastore during init : %v", err)
}
return c.watchNetworks()
}
// getNetworksFromStore lists the key/value pairs stored under the network key
// prefix. The store handle is read while holding the controller lock; the
// List call itself runs after the lock has been released.
func (c *controller) getNetworksFromStore() ([]*store.KVPair, error) {
c.Lock()
cs := c.store
c.Unlock()
return cs.KVStore().List(datastore.Key(datastore.NetworkKeyPrefix))
}
func (c *controller) newNetworkFromStore(n *network) error {
@ -92,22 +107,6 @@ func (c *controller) newEndpointFromStore(key string, ep *endpoint) error {
n := ep.network
id := ep.id
ep.Unlock()
if n == nil {
// Possibly the watch event for the network has not shown up yet
// Try to get network from the store
nid, err := networkIDFromEndpointKey(key, ep)
if err != nil {
return err
}
n, err = c.getNetworkFromStore(nid)
if err != nil {
return err
}
if err := c.newNetworkFromStore(n); err != nil {
return err
}
n = c.networks[nid]
}
_, err := n.EndpointByID(string(id))
if err != nil {
@ -170,7 +169,11 @@ func (c *controller) deleteEndpointFromStore(ep *endpoint) error {
return nil
}
func (c *controller) watchStore() error {
func (c *controller) watchNetworks() error {
if !c.validateDatastoreConfig() {
return nil
}
c.Lock()
cs := c.store
c.Unlock()
@ -179,64 +182,35 @@ func (c *controller) watchStore() error {
if err != nil {
return err
}
epPairs, err := cs.KVStore().WatchTree(datastore.Key(datastore.EndpointKeyPrefix), nil)
if err != nil {
return err
}
go func() {
for {
select {
case nws := <-nwPairs:
for _, kve := range nws {
var n network
err := json.Unmarshal(kve.Value, &n)
if err != nil {
log.Error(err)
continue
}
n.dbIndex = kve.LastIndex
c.Lock()
existing, ok := c.networks[n.id]
c.Unlock()
if ok {
existing.Lock()
// Skip existing network update
if existing.dbIndex != n.dbIndex {
existing.dbIndex = n.dbIndex
existing.endpointCnt = n.endpointCnt
}
existing.Unlock()
continue
}
if err = c.newNetworkFromStore(&n); err != nil {
log.Error(err)
c.Lock()
tmpview := networkTable{}
lview := c.networks
c.Unlock()
for k, v := range lview {
global, _ := v.isGlobalScoped()
if global {
tmpview[k] = v
}
}
case eps := <-epPairs:
for _, epe := range eps {
var ep endpoint
err := json.Unmarshal(epe.Value, &ep)
if err != nil {
log.Error(err)
c.processNetworkUpdate(nws, &tmpview)
// Delete processing
for k := range tmpview {
c.Lock()
existing, ok := c.networks[k]
c.Unlock()
if !ok {
continue
}
ep.dbIndex = epe.LastIndex
n, err := c.networkFromEndpointKey(epe.Key, &ep)
if err != nil {
if _, ok := err.(ErrNoSuchNetwork); !ok {
log.Error(err)
continue
}
tmp := network{}
if err := c.store.GetObject(datastore.Key(existing.Key()...), &tmp); err != datastore.ErrKeyNotFound {
continue
}
if n != nil {
ep.network = n.(*network)
}
if c.processEndpointUpdate(&ep) {
err = c.newEndpointFromStore(epe.Key, &ep)
if err != nil {
log.Error(err)
}
if err := existing.deleteNetwork(); err != nil {
log.Debugf("Delete failed %s: %s", existing.name, err)
}
}
}
@ -245,20 +219,116 @@ func (c *controller) watchStore() error {
return nil
}
func (c *controller) networkFromEndpointKey(key string, ep *endpoint) (Network, error) {
nid, err := networkIDFromEndpointKey(key, ep)
if err != nil {
return nil, err
func (n *network) watchEndpoints() error {
if !n.ctrlr.validateDatastoreConfig() {
return nil
}
return c.NetworkByID(string(nid))
n.Lock()
cs := n.ctrlr.store
tmp := endpoint{network: n}
n.stopWatchCh = make(chan struct{})
stopCh := n.stopWatchCh
n.Unlock()
epPairs, err := cs.KVStore().WatchTree(datastore.Key(tmp.KeyPrefix()...), stopCh)
if err != nil {
return err
}
go func() {
for {
select {
case <-stopCh:
return
case eps := <-epPairs:
n.Lock()
tmpview := endpointTable{}
lview := n.endpoints
n.Unlock()
for k, v := range lview {
global, _ := v.network.isGlobalScoped()
if global {
tmpview[k] = v
}
}
for _, epe := range eps {
var ep endpoint
err := json.Unmarshal(epe.Value, &ep)
if err != nil {
log.Error(err)
continue
}
delete(tmpview, ep.id)
ep.dbIndex = epe.LastIndex
ep.network = n
if n.ctrlr.processEndpointUpdate(&ep) {
err = n.ctrlr.newEndpointFromStore(epe.Key, &ep)
if err != nil {
log.Error(err)
}
}
}
// Delete processing
for k := range tmpview {
n.Lock()
existing, ok := n.endpoints[k]
n.Unlock()
if !ok {
continue
}
tmp := endpoint{}
if err := cs.GetObject(datastore.Key(existing.Key()...), &tmp); err != datastore.ErrKeyNotFound {
continue
}
if err := existing.deleteEndpoint(); err != nil {
log.Debugf("Delete failed %s: %s", existing.name, err)
}
}
}
}
}()
return nil
}
func networkIDFromEndpointKey(key string, ep *endpoint) (types.UUID, error) {
eKey, err := datastore.ParseKey(key)
if err != nil {
return types.UUID(""), err
// stopWatch closes the network's stop channel, if one is set, and clears the
// field so that a second call is a no-op. The network lock is held throughout.
func (n *network) stopWatch() {
	n.Lock()
	defer n.Unlock()

	if n.stopWatchCh == nil {
		return
	}
	close(n.stopWatchCh)
	n.stopWatchCh = nil
}
func (c *controller) processNetworkUpdate(nws []*store.KVPair, prune *networkTable) {
for _, kve := range nws {
var n network
err := json.Unmarshal(kve.Value, &n)
if err != nil {
log.Error(err)
continue
}
if prune != nil {
delete(*prune, n.id)
}
n.dbIndex = kve.LastIndex
c.Lock()
existing, ok := c.networks[n.id]
c.Unlock()
if ok {
existing.Lock()
// Skip existing network update
if existing.dbIndex != n.dbIndex {
existing.dbIndex = n.dbIndex
existing.endpointCnt = n.endpointCnt
}
existing.Unlock()
continue
}
if err = c.newNetworkFromStore(&n); err != nil {
log.Error(err)
}
}
return ep.networkIDFromKey(eKey)
}
func (c *controller) processEndpointUpdate(ep *endpoint) bool {

View file

@ -228,6 +228,12 @@ type MaskableError interface {
Maskable()
}
// RetryError is an interface for errors which might get resolved through retry.
// It is a marker interface: its only method takes no arguments and returns nothing.
type RetryError interface {
// Retry makes implementer into RetryError type
Retry()
}
// BadRequestError is an interface for errors originated by a bad request
type BadRequestError interface {
// BadRequest makes implementer into BadRequestError type
@ -271,7 +277,7 @@ type InternalError interface {
}
/******************************
* Weel-known Error Formatters
* Well-known Error Formatters
******************************/
// BadRequestErrorf creates an instance of BadRequestError
@ -314,6 +320,11 @@ func InternalMaskableErrorf(format string, params ...interface{}) error {
return maskInternal(fmt.Sprintf(format, params...))
}
// RetryErrorf creates an instance of RetryError.
// The message is built with fmt.Sprintf from format and params and converted
// to the retry error type.
func RetryErrorf(format string, params ...interface{}) error {
return retry(fmt.Sprintf(format, params...))
}
/***********************
* Internal Error Types
***********************/
@ -377,3 +388,10 @@ func (mnt maskInternal) Error() string {
}
func (mnt maskInternal) Internal() {}
func (mnt maskInternal) Maskable() {}
// retry is a string-backed error type that carries the Retry marker method.
type retry string

// Error returns the underlying string as the error message.
func (e retry) Error() string { return string(e) }

// Retry is a no-op marker method.
func (e retry) Retry() {}

View file

@ -0,0 +1,143 @@
// Copyright (c) 2012, 2013 Ugorji Nwoke. All rights reserved.
// Use of this source code is governed by a BSD-style license found in the LICENSE file.
/*
High Performance, Feature-Rich Idiomatic Go encoding library for msgpack and binc.
Supported Serialization formats are:
- msgpack: [https://github.com/msgpack/msgpack]
- binc: [http://github.com/ugorji/binc]
To install:
go get github.com/ugorji/go/codec
The idiomatic Go support is as seen in other encoding packages in
the standard library (ie json, xml, gob, etc).
Rich Feature Set includes:
- Simple but extremely powerful and feature-rich API
- Very High Performance.
Our extensive benchmarks show us outperforming Gob, Json and Bson by 2-4X.
This was achieved by taking extreme care on:
- managing allocation
- function frame size (important due to Go's use of split stacks),
- reflection use (and by-passing reflection for common types)
- recursion implications
- zero-copy mode (encoding/decoding to byte slice without using temp buffers)
- Correct.
Care was taken to precisely handle corner cases like:
overflows, nil maps and slices, nil value in stream, etc.
- Efficient zero-copying into temporary byte buffers
when encoding into or decoding from a byte slice.
- Standard field renaming via tags
- Encoding from any value
(struct, slice, map, primitives, pointers, interface{}, etc)
- Decoding into pointer to any non-nil typed value
(struct, slice, map, int, float32, bool, string, reflect.Value, etc)
- Supports extension functions to handle the encode/decode of custom types
- Supports Go 1.2 encoding.BinaryMarshaler/BinaryUnmarshaler
- Schema-less decoding
(decode into a pointer to a nil interface{} as opposed to a typed non-nil value).
Includes Options to configure what specific map or slice type to use
when decoding an encoded list or map into a nil interface{}
- Provides a RPC Server and Client Codec for net/rpc communication protocol.
- Msgpack Specific:
- Provides extension functions to handle spec-defined extensions (binary, timestamp)
- Options to resolve ambiguities in handling raw bytes (as string or []byte)
during schema-less decoding (decoding into a nil interface{})
- RPC Server/Client Codec for msgpack-rpc protocol defined at:
https://github.com/msgpack-rpc/msgpack-rpc/blob/master/spec.md
- Fast Paths for some container types:
For some container types, we circumvent reflection and its associated overhead
and allocation costs, and encode/decode directly. These types are:
[]interface{}
[]int
[]string
map[interface{}]interface{}
map[int]interface{}
map[string]interface{}
Extension Support
Users can register a function to handle the encoding or decoding of
their custom types.
There are no restrictions on what the custom type can be. Some examples:
type BitSet []int
type BitSet64 uint64
type UUID string
type MyStructWithUnexportedFields struct { a int; b bool; c []int; }
type GifImage struct { ... }
As an illustration, MyStructWithUnexportedFields would normally be
encoded as an empty map because it has no exported fields, while UUID
would be encoded as a string. However, with extension support, you can
encode any of these however you like.
RPC
RPC Client and Server Codecs are implemented, so the codecs can be used
with the standard net/rpc package.
Usage
Typical usage model:
// create and configure Handle
var (
bh codec.BincHandle
mh codec.MsgpackHandle
)
mh.MapType = reflect.TypeOf(map[string]interface{}(nil))
// configure extensions
// e.g. for msgpack, define functions and enable Time support for tag 1
// mh.AddExt(reflect.TypeOf(time.Time{}), 1, myMsgpackTimeEncodeExtFn, myMsgpackTimeDecodeExtFn)
// create and use decoder/encoder
var (
r io.Reader
w io.Writer
b []byte
h = &bh // or mh to use msgpack
)
dec = codec.NewDecoder(r, h)
dec = codec.NewDecoderBytes(b, h)
err = dec.Decode(&v)
enc = codec.NewEncoder(w, h)
enc = codec.NewEncoderBytes(&b, h)
err = enc.Encode(v)
//RPC Server
go func() {
for {
conn, err := listener.Accept()
rpcCodec := codec.GoRpc.ServerCodec(conn, h)
//OR rpcCodec := codec.MsgpackSpecRpc.ServerCodec(conn, h)
rpc.ServeCodec(rpcCodec)
}
}()
//RPC Communication (client side)
conn, err = net.Dial("tcp", "localhost:5555")
rpcCodec := codec.GoRpc.ClientCodec(conn, h)
//OR rpcCodec := codec.MsgpackSpecRpc.ClientCodec(conn, h)
client := rpc.NewClientWithCodec(rpcCodec)
Representative Benchmark Results
Run the benchmark suite using:
go test -bi -bench=. -benchmem
To run full benchmark suite (including against vmsgpack and bson),
see notes in ext_dep_test.go
*/
package codec

View file

@ -0,0 +1,174 @@
# Codec
High Performance and Feature-Rich Idiomatic Go Library providing
encode/decode support for different serialization formats.
Supported Serialization formats are:
- msgpack: [https://github.com/msgpack/msgpack]
- binc: [http://github.com/ugorji/binc]
To install:
go get github.com/ugorji/go/codec
Online documentation: [http://godoc.org/github.com/ugorji/go/codec]
The idiomatic Go support is as seen in other encoding packages in
the standard library (ie json, xml, gob, etc).
Rich Feature Set includes:
- Simple but extremely powerful and feature-rich API
- Very High Performance.
Our extensive benchmarks show us outperforming Gob, Json and Bson by 2-4X.
This was achieved by taking extreme care on:
- managing allocation
- function frame size (important due to Go's use of split stacks),
- reflection use (and by-passing reflection for common types)
- recursion implications
- zero-copy mode (encoding/decoding to byte slice without using temp buffers)
- Correct.
Care was taken to precisely handle corner cases like:
overflows, nil maps and slices, nil value in stream, etc.
- Efficient zero-copying into temporary byte buffers
when encoding into or decoding from a byte slice.
- Standard field renaming via tags
- Encoding from any value
(struct, slice, map, primitives, pointers, interface{}, etc)
- Decoding into pointer to any non-nil typed value
(struct, slice, map, int, float32, bool, string, reflect.Value, etc)
- Supports extension functions to handle the encode/decode of custom types
- Supports Go 1.2 encoding.BinaryMarshaler/BinaryUnmarshaler
- Schema-less decoding
(decode into a pointer to a nil interface{} as opposed to a typed non-nil value).
Includes Options to configure what specific map or slice type to use
when decoding an encoded list or map into a nil interface{}
- Provides a RPC Server and Client Codec for net/rpc communication protocol.
- Msgpack Specific:
- Provides extension functions to handle spec-defined extensions (binary, timestamp)
- Options to resolve ambiguities in handling raw bytes (as string or []byte)
during schema-less decoding (decoding into a nil interface{})
- RPC Server/Client Codec for msgpack-rpc protocol defined at:
https://github.com/msgpack-rpc/msgpack-rpc/blob/master/spec.md
- Fast Paths for some container types:
For some container types, we circumvent reflection and its associated overhead
and allocation costs, and encode/decode directly. These types are:
[]interface{}
[]int
[]string
map[interface{}]interface{}
map[int]interface{}
map[string]interface{}
## Extension Support
Users can register a function to handle the encoding or decoding of
their custom types.
There are no restrictions on what the custom type can be. Some examples:
type BitSet []int
type BitSet64 uint64
type UUID string
type MyStructWithUnexportedFields struct { a int; b bool; c []int; }
type GifImage struct { ... }
As an illustration, MyStructWithUnexportedFields would normally be
encoded as an empty map because it has no exported fields, while UUID
would be encoded as a string. However, with extension support, you can
encode any of these however you like.
## RPC
RPC Client and Server Codecs are implemented, so the codecs can be used
with the standard net/rpc package.
## Usage
Typical usage model:
// create and configure Handle
var (
bh codec.BincHandle
mh codec.MsgpackHandle
)
mh.MapType = reflect.TypeOf(map[string]interface{}(nil))
// configure extensions
// e.g. for msgpack, define functions and enable Time support for tag 1
// mh.AddExt(reflect.TypeOf(time.Time{}), 1, myMsgpackTimeEncodeExtFn, myMsgpackTimeDecodeExtFn)
// create and use decoder/encoder
var (
r io.Reader
w io.Writer
b []byte
h = &bh // or mh to use msgpack
)
dec = codec.NewDecoder(r, h)
dec = codec.NewDecoderBytes(b, h)
err = dec.Decode(&v)
enc = codec.NewEncoder(w, h)
enc = codec.NewEncoderBytes(&b, h)
err = enc.Encode(v)
//RPC Server
go func() {
for {
conn, err := listener.Accept()
rpcCodec := codec.GoRpc.ServerCodec(conn, h)
//OR rpcCodec := codec.MsgpackSpecRpc.ServerCodec(conn, h)
rpc.ServeCodec(rpcCodec)
}
}()
//RPC Communication (client side)
conn, err = net.Dial("tcp", "localhost:5555")
rpcCodec := codec.GoRpc.ClientCodec(conn, h)
//OR rpcCodec := codec.MsgpackSpecRpc.ClientCodec(conn, h)
client := rpc.NewClientWithCodec(rpcCodec)
## Representative Benchmark Results
A sample run of benchmark using "go test -bi -bench=. -benchmem":
/proc/cpuinfo: Intel(R) Core(TM) i7-2630QM CPU @ 2.00GHz (HT)
..............................................
BENCHMARK INIT: 2013-10-16 11:02:50.345970786 -0400 EDT
To run full benchmark comparing encodings (MsgPack, Binc, JSON, GOB, etc), use: "go test -bench=."
Benchmark:
Struct recursive Depth: 1
ApproxDeepSize Of benchmark Struct: 4694 bytes
Benchmark One-Pass Run:
v-msgpack: len: 1600 bytes
bson: len: 3025 bytes
msgpack: len: 1560 bytes
binc: len: 1187 bytes
gob: len: 1972 bytes
json: len: 2538 bytes
..............................................
PASS
Benchmark__Msgpack____Encode 50000 54359 ns/op 14953 B/op 83 allocs/op
Benchmark__Msgpack____Decode 10000 106531 ns/op 14990 B/op 410 allocs/op
Benchmark__Binc_NoSym_Encode 50000 53956 ns/op 14966 B/op 83 allocs/op
Benchmark__Binc_NoSym_Decode 10000 103751 ns/op 14529 B/op 386 allocs/op
Benchmark__Binc_Sym___Encode 50000 65961 ns/op 17130 B/op 88 allocs/op
Benchmark__Binc_Sym___Decode 10000 106310 ns/op 15857 B/op 287 allocs/op
Benchmark__Gob________Encode 10000 135944 ns/op 21189 B/op 237 allocs/op
Benchmark__Gob________Decode 5000 405390 ns/op 83460 B/op 1841 allocs/op
Benchmark__Json_______Encode 20000 79412 ns/op 13874 B/op 102 allocs/op
Benchmark__Json_______Decode 10000 247979 ns/op 14202 B/op 493 allocs/op
Benchmark__Bson_______Encode 10000 121762 ns/op 27814 B/op 514 allocs/op
Benchmark__Bson_______Decode 10000 162126 ns/op 16514 B/op 789 allocs/op
Benchmark__VMsgpack___Encode 50000 69155 ns/op 12370 B/op 344 allocs/op
Benchmark__VMsgpack___Decode 10000 151609 ns/op 20307 B/op 571 allocs/op
ok ugorji.net/codec 30.827s
To run full benchmark suite (including against vmsgpack and bson),
see notes in ext\_dep\_test.go

View file

@ -0,0 +1,786 @@
// Copyright (c) 2012, 2013 Ugorji Nwoke. All rights reserved.
// Use of this source code is governed by a BSD-style license found in the LICENSE file.
package codec
import (
"math"
// "reflect"
// "sync/atomic"
"time"
//"fmt"
)
// bincDoPrune gates the pruned encodings: encodeFloat64 and encIntegerPrune
// check it before dropping redundant bytes from the value they write.
const bincDoPrune = true // No longer needed. Needed before as C lib did not support pruning.
//var _ = fmt.Printf
// Value descriptor (vd) codes; there are 16 slots, numbered 0x00 through 0x0f.
// NOTE(review): the encoders in this file place these codes in the high
// nibble of the descriptor byte (e.g. bincVdSpecial<<4), which does not match
// the original "vd as low 4 bits" wording — confirm against the binc spec.
const (
bincVdSpecial byte = iota
bincVdPosInt
bincVdNegInt
bincVdFloat
bincVdString
bincVdByteArray
bincVdArray
bincVdMap
bincVdTimestamp
bincVdSmallInt
bincVdUnicodeOther
bincVdSymbol
bincVdDecimal
_ // open slot
_ // open slot
bincVdCustomExt = 0x0f
)
// Codes for special values. The encoders in this file combine them with
// bincVdSpecial<<4 to form a single descriptor byte.
const (
bincSpNil byte = iota
bincSpFalse
bincSpTrue
bincSpNan
bincSpPosInf
bincSpNegInf
bincSpZeroFloat
bincSpZero
bincSpNegOne
)
// Codes for float representations. The float encoders in this file combine
// them with bincVdFloat<<4 to form the descriptor byte.
const (
bincFlBin16 byte = iota
bincFlBin32
_ // bincFlBin32e
bincFlBin64
_ // bincFlBin64e
// others not currently supported
)
// bincEncDriver holds the state used by the binc encode methods in this file.
type bincEncDriver struct {
// w is the writer all encoded bytes are sent to.
w encWriter
m map[string]uint16 // symbols
s uint32 // symbols sequencer
// b is an 8-byte scratch buffer used by encodeFloat64 and encIntegerPrune.
b [8]byte
}
// isBuiltinType reports whether rt equals timeTypId, the only type id that
// encodeBuiltin handles.
func (e *bincEncDriver) isBuiltinType(rt uintptr) bool {
return rt == timeTypId
}
// encodeBuiltin writes v when rt is timeTypId: a descriptor byte made of the
// timestamp code and the payload length, followed by the bytes produced by
// encodeTime. Any other rt is ignored. v must hold a time.Time when rt is
// timeTypId, otherwise the type assertion panics.
func (e *bincEncDriver) encodeBuiltin(rt uintptr, v interface{}) {
	if rt != timeTypId {
		return
	}
	encoded := encodeTime(v.(time.Time))
	descriptor := bincVdTimestamp<<4 | uint8(len(encoded))
	e.w.writen1(descriptor)
	e.w.writeb(encoded)
}
// encodeNil writes the single descriptor byte for nil.
func (e *bincEncDriver) encodeNil() {
e.w.writen1(bincVdSpecial<<4 | bincSpNil)
}
// encodeBool writes b as a single-byte special descriptor
// (bincSpTrue or bincSpFalse).
func (e *bincEncDriver) encodeBool(b bool) {
	vs := bincSpFalse
	if b {
		vs = bincSpTrue
	}
	e.w.writen1(bincVdSpecial<<4 | vs)
}
// encodeFloat32 writes f as a binary32 float: a descriptor byte followed by
// the 4-byte bit pattern. A value equal to zero is written as the one-byte
// "special: zero float" descriptor instead.
func (e *bincEncDriver) encodeFloat32(f float32) {
	if f != 0 {
		e.w.writen1(bincVdFloat<<4 | bincFlBin32)
		e.w.writeUint32(math.Float32bits(f))
		return
	}
	e.w.writen1(bincVdSpecial<<4 | bincSpZeroFloat)
}
// encodeFloat64 writes f as a binary64 float.
//
// A value equal to zero is written as the one-byte "special: zero float"
// descriptor. Otherwise the big-endian bit pattern is placed in e.b and, when
// pruning is enabled and at least two trailing bytes are zero, only the
// leading non-zero-terminated prefix is written: descriptor with the 0x8
// flag, a length byte, then that prefix. In all other cases the full 8 bytes
// follow the plain descriptor.
func (e *bincEncDriver) encodeFloat64(f float64) {
	if f == 0 {
		e.w.writen1(bincVdSpecial<<4 | bincSpZeroFloat)
		return
	}
	bigen.PutUint64(e.b[:], math.Float64bits(f))
	if bincDoPrune {
		// n is the number of bytes up to and including the last non-zero byte.
		n := len(e.b)
		for n > 0 && e.b[n-1] == 0 {
			n--
		}
		if n < 7 {
			e.w.writen1(bincVdFloat<<4 | 0x8 | bincFlBin64)
			e.w.writen1(byte(n))
			e.w.writeb(e.b[:n])
			return
		}
	}
	e.w.writen1(bincVdFloat<<4 | bincFlBin64)
	e.w.writeb(e.b[:])
}
// encIntegerPrune writes the integer magnitude v using at most lim bytes
// (lim must be 4 or 8), preceded by a descriptor byte bd whose low bits hold
// (number of bytes written - 1).
//
// v is always an unsigned magnitude: encodeInt passes uint64(-v) for negative
// values and the decoder zero-extends what it reads before negating. Pruning
// therefore must only ever strip leading zero bytes. The previous code pruned
// with pos=false for negative values, which stripped leading 0xff bytes as if
// v were two's complement and corrupted magnitudes whose top byte is 0xff
// (e.g. -4294967295 was written as the single byte 0xff and decoded as -255).
//
// pos is retained for signature compatibility with existing callers; it no
// longer influences the bytes written.
func (e *bincEncDriver) encIntegerPrune(bd byte, pos bool, v uint64, lim uint8) {
	if lim == 4 {
		bigen.PutUint32(e.b[:lim], uint32(v))
	} else {
		bigen.PutUint64(e.b[:lim], v)
	}
	if bincDoPrune {
		// Always prune as a non-negative magnitude (leading zero bytes only).
		i := pruneSignExt(e.b[:lim], true)
		e.w.writen1(bd | (lim - 1 - byte(i)))
		e.w.writeb(e.b[i:lim])
	} else {
		e.w.writen1(bd | (lim - 1))
		e.w.writeb(e.b[:lim])
	}
}
// encodeInt writes a signed integer. Non-negative values use the positive
// integer encoding, -1 has a dedicated one-byte special descriptor, and every
// other negative value is written as a negative-int descriptor followed by
// its magnitude.
func (e *bincEncDriver) encodeInt(v int64) {
	if v >= 0 {
		e.encUint(bincVdPosInt<<4, true, uint64(v))
		return
	}
	if v == -1 {
		e.w.writen1(bincVdSpecial<<4 | bincSpNegOne)
		return
	}
	const nbd byte = bincVdNegInt << 4
	e.encUint(nbd, false, uint64(-v))
}
// encodeUint writes an unsigned integer using the positive integer encoding.
func (e *bincEncDriver) encodeUint(v uint64) {
	const pbd byte = bincVdPosInt << 4
	e.encUint(pbd, true, v)
}
// encUint writes the magnitude v under descriptor bd, choosing the most
// compact form available:
//   - 0 is the one-byte "special: zero" descriptor
//   - 1..16 (positive only) is packed into a small-int descriptor
//   - values fitting 8 or 16 bits are written at that fixed width
//   - larger values are handed to encIntegerPrune with a 4- or 8-byte limit
func (e *bincEncDriver) encUint(bd byte, pos bool, v uint64) {
	if v == 0 {
		e.w.writen1(bincVdSpecial<<4 | bincSpZero)
	} else if pos && v <= 16 {
		e.w.writen1(bincVdSmallInt<<4 | byte(v-1))
	} else if v <= math.MaxUint8 {
		e.w.writen2(bd|0x0, byte(v))
	} else if v <= math.MaxUint16 {
		e.w.writen1(bd | 0x01)
		e.w.writeUint16(uint16(v))
	} else if v <= math.MaxUint32 {
		e.encIntegerPrune(bd, pos, v, 4)
	} else {
		e.encIntegerPrune(bd, pos, v, 8)
	}
}
// encodeExtPreamble writes the header of a custom extension value: the
// length-carrying descriptor followed by the extension tag byte.
func (e *bincEncDriver) encodeExtPreamble(xtag byte, length int) {
	n := uint64(length)
	e.encLen(bincVdCustomExt<<4, n)
	e.w.writen1(xtag)
}
// encodeArrayPreamble writes the array descriptor carrying the element count.
func (e *bincEncDriver) encodeArrayPreamble(length int) {
	const bd byte = bincVdArray << 4
	e.encLen(bd, uint64(length))
}
// encodeMapPreamble writes the map descriptor carrying the entry count.
func (e *bincEncDriver) encodeMapPreamble(length int) {
	const bd byte = bincVdMap << 4
	e.encLen(bd, uint64(length))
}
// encodeString writes the length header chosen by encBytesLen for encoding c,
// followed by the bytes of v (nothing follows the header for an empty string).
func (e *bincEncDriver) encodeString(c charEncoding, v string) {
	n := len(v)
	e.encBytesLen(c, uint64(n))
	if n == 0 {
		return
	}
	e.w.writestr(v)
}
// encodeSymbol writes v as a Binc symbol: the first occurrence of a string is
// written together with a freshly assigned numeric id, and later occurrences
// are written as that id alone.
//
// Strings of length 0 or 1 are written as plain UTF-8 strings, because a
// symbol reference would not be shorter than the string itself.
//
// Symbol ids are 16 bits wide on the wire. Once all ids (1..65535) have been
// handed out, further new strings are written as plain UTF-8 strings instead
// of being assigned a wrapped-around id; previously uint16(e.s) silently
// wrapped and aliased earlier symbols, so the decoder would return the wrong
// string for later references.
func (e *bincEncDriver) encodeSymbol(v string) {
	l := len(v)
	switch l {
	case 0:
		e.encBytesLen(c_UTF8, 0)
		return
	case 1:
		e.encBytesLen(c_UTF8, 1)
		e.w.writen1(v[0])
		return
	}
	if e.m == nil {
		e.m = make(map[string]uint16, 16)
	}
	if ui, ok := e.m[v]; ok {
		// Known symbol: write only the id (1 byte, or 2 bytes with flag 0x8).
		if ui <= math.MaxUint8 {
			e.w.writen2(bincVdSymbol<<4, byte(ui))
		} else {
			e.w.writen1(bincVdSymbol<<4 | 0x8)
			e.w.writeUint16(ui)
		}
		return
	}
	if e.s >= math.MaxUint16 {
		// Id space exhausted: fall back to a plain string rather than reuse an id.
		e.encBytesLen(c_UTF8, uint64(l))
		e.w.writestr(v)
		return
	}
	e.s++
	ui := uint16(e.s)
	e.m[v] = ui

	// lenprec selects the width of the length field: 0,1,2,3 => 1,2,4,8 bytes.
	var lenprec uint8
	switch {
	case l <= math.MaxUint8:
		// lenprec = 0
	case l <= math.MaxUint16:
		lenprec = 1
	case int64(l) <= math.MaxUint32:
		lenprec = 2
	default:
		lenprec = 3
	}
	// Flag 0x4 marks that the string value follows; 0x8 marks a 2-byte id.
	if ui <= math.MaxUint8 {
		e.w.writen2(bincVdSymbol<<4|0x0|0x4|lenprec, byte(ui))
	} else {
		e.w.writen1(bincVdSymbol<<4 | 0x8 | 0x4 | lenprec)
		e.w.writeUint16(ui)
	}
	switch lenprec {
	case 0:
		e.w.writen1(byte(l))
	case 1:
		e.w.writeUint16(uint16(l))
	case 2:
		e.w.writeUint32(uint32(l))
	default:
		e.w.writeUint64(uint64(l))
	}
	e.w.writestr(v)
}
// encodeStringBytes writes the length header chosen by encBytesLen for
// encoding c, followed by the bytes of v (nothing follows for an empty slice).
func (e *bincEncDriver) encodeStringBytes(c charEncoding, v []byte) {
	n := len(v)
	e.encBytesLen(c, uint64(n))
	if n == 0 {
		return
	}
	e.w.writeb(v)
}
// encBytesLen writes a length header using the byte-array descriptor when c
// is c_RAW and the string descriptor for every other encoding.
func (e *bincEncDriver) encBytesLen(c charEncoding, length uint64) {
	//TODO: support bincUnicodeOther (for now, just use string or bytearray)
	bd := byte(bincVdString << 4)
	if c == c_RAW {
		bd = byte(bincVdByteArray << 4)
	}
	e.encLen(bd, length)
}
// encLen writes a length under descriptor bd. Lengths below 12 are packed
// into the descriptor's low 4 bits as l+4; longer lengths are written by
// encLenNumber.
func (e *bincEncDriver) encLen(bd byte, l uint64) {
	if l >= 12 {
		e.encLenNumber(bd, l)
		return
	}
	e.w.writen1(bd | uint8(l+4))
}
// encLenNumber writes v as an explicit length following the descriptor. The
// descriptor's low bits 0,1,2,3 select a 1-, 2-, 4- or 8-byte big-endian
// length respectively; the narrowest width that fits v is used.
func (e *bincEncDriver) encLenNumber(bd byte, v uint64) {
	if v <= math.MaxUint8 {
		e.w.writen2(bd, byte(v))
		return
	}
	if v <= math.MaxUint16 {
		e.w.writen1(bd | 0x01)
		e.w.writeUint16(uint16(v))
		return
	}
	if v <= math.MaxUint32 {
		e.w.writen1(bd | 0x02)
		e.w.writeUint32(uint32(v))
		return
	}
	e.w.writen1(bd | 0x03)
	e.w.writeUint64(uint64(v))
}
//------------------------------------
// bincDecDriver reads Binc-encoded values from r.
type bincDecDriver struct {
r decReader
bdRead bool // true once a descriptor byte has been read and not yet consumed
bdType valueType // cached result of currentEncodedType; valueTypeUnset until computed
bd byte // current descriptor byte
vd byte // high 4 bits of bd
vs byte // low 4 bits of bd
b [8]byte // scratch buffer for fixed-width reads
m map[uint32]string // symbols (use uint32 as key, as map optimizes for it)
}
// initReadNext reads the next descriptor byte and splits it into vd (high
// nibble) and vs (low nibble), unless a descriptor is already pending.
func (d *bincDecDriver) initReadNext() {
	if !d.bdRead {
		bd := d.r.readn1()
		d.bd, d.vd, d.vs = bd, bd>>4, bd&0x0f
		d.bdRead = true
		d.bdType = valueTypeUnset
	}
}
// currentEncodedType returns the valueType of the pending descriptor,
// computing it from vd/vs on first use and caching it in d.bdType.
func (d *bincDecDriver) currentEncodedType() valueType {
	if d.bdType != valueTypeUnset {
		return d.bdType
	}
	switch d.vd {
	case bincVdSpecial:
		switch d.vs {
		case bincSpNil:
			d.bdType = valueTypeNil
		case bincSpFalse, bincSpTrue:
			d.bdType = valueTypeBool
		case bincSpNan, bincSpPosInf, bincSpNegInf, bincSpZeroFloat:
			d.bdType = valueTypeFloat
		case bincSpZero:
			d.bdType = valueTypeUint
		case bincSpNegOne:
			d.bdType = valueTypeInt
		default:
			decErr("currentEncodedType: Unrecognized special value 0x%x", d.vs)
		}
	case bincVdSmallInt, bincVdPosInt:
		d.bdType = valueTypeUint
	case bincVdNegInt:
		d.bdType = valueTypeInt
	case bincVdFloat:
		d.bdType = valueTypeFloat
	case bincVdString:
		d.bdType = valueTypeString
	case bincVdSymbol:
		d.bdType = valueTypeSymbol
	case bincVdByteArray:
		d.bdType = valueTypeBytes
	case bincVdTimestamp:
		d.bdType = valueTypeTimestamp
	case bincVdCustomExt:
		d.bdType = valueTypeExt
	case bincVdArray:
		d.bdType = valueTypeArray
	case bincVdMap:
		d.bdType = valueTypeMap
	default:
		decErr("currentEncodedType: Unrecognized d.vd: 0x%x", d.vd)
	}
	return d.bdType
}
// tryDecodeAsNil reports whether the pending descriptor is the nil marker,
// consuming it when it is.
func (d *bincDecDriver) tryDecodeAsNil() bool {
	if d.bd != bincVdSpecial<<4|bincSpNil {
		return false
	}
	d.bdRead = false
	return true
}
// isBuiltinType reports whether rt is the type id of a type this driver
// decodes natively. Only timeTypId qualifies.
func (d *bincDecDriver) isBuiltinType(rt uintptr) bool {
	isTime := rt == timeTypId
	return isTime
}
// decodeBuiltin decodes a timestamp into v, which must be a *time.Time, when
// rt is timeTypId; any other rt is ignored. The timestamp occupies d.vs bytes
// after the descriptor. A decodeTime failure is raised as a panic.
func (d *bincDecDriver) decodeBuiltin(rt uintptr, v interface{}) {
	if rt != timeTypId {
		return
	}
	if d.vd != bincVdTimestamp {
		decErr("Invalid d.vd. Expecting 0x%x. Received: 0x%x", bincVdTimestamp, d.vd)
	}
	decoded, err := decodeTime(d.r.readn(int(d.vs)))
	if err != nil {
		panic(err)
	}
	*(v.(*time.Time)) = decoded
	d.bdRead = false
}
// decFloatPre loads the bytes of a float into the front of d.b.
//
// Without the 0x8 flag in vs, exactly defaultLen bytes are read. With it, a
// length byte (at most 8) is read first, that many bytes are read into the
// front of d.b, and the remainder of d.b is zeroed.
func (d *bincDecDriver) decFloatPre(vs, defaultLen byte) {
	if vs&0x8 == 0 {
		d.r.readb(d.b[0:defaultLen])
		return
	}
	n := d.r.readn1()
	if n > 8 {
		decErr("At most 8 bytes used to represent float. Received: %v bytes", n)
	}
	for idx := n; idx < 8; idx++ {
		d.b[idx] = 0
	}
	d.r.readb(d.b[0:n])
}
// decFloat decodes a binary32 or binary64 float (selected by the low 3 bits
// of d.vs) and returns it as a float64. Any other format is reported through
// decErr.
func (d *bincDecDriver) decFloat() (f float64) {
	vs := d.vs
	if format := vs & 0x7; format == bincFlBin32 {
		d.decFloatPre(vs, 4)
		bits := bigen.Uint32(d.b[0:4])
		f = float64(math.Float32frombits(bits))
	} else if format == bincFlBin64 {
		d.decFloatPre(vs, 8)
		bits := bigen.Uint64(d.b[0:8])
		f = math.Float64frombits(bits)
	} else {
		decErr("only float32 and float64 are supported. d.vd: 0x%x, d.vs: 0x%x", d.vd, d.vs)
	}
	return
}
// decUint reads an integer magnitude of d.vs+1 big-endian bytes (1 to 8) and
// returns it zero-extended to 64 bits. A d.vs above 7 is reported through
// decErr.
func (d *bincDecDriver) decUint() (v uint64) {
	// need to inline the code (interface conversion and type assertion expensive)
	vs := d.vs
	switch {
	case vs == 0:
		v = uint64(d.r.readn1())
	case vs == 1:
		buf := d.b[6:]
		d.r.readb(buf)
		v = uint64(bigen.Uint16(buf))
	case vs == 2:
		// 3 bytes: zero the leading byte of the 4-byte window.
		d.b[4] = 0
		d.r.readb(d.b[5:])
		v = uint64(bigen.Uint32(d.b[4:]))
	case vs == 3:
		buf := d.b[4:]
		d.r.readb(buf)
		v = uint64(bigen.Uint32(buf))
	case vs <= 6:
		// 5 to 7 bytes: right-align in d.b and zero the leading pad bytes.
		pad := int(7 - vs)
		d.r.readb(d.b[pad:])
		for idx := 0; idx < pad; idx++ {
			d.b[idx] = 0
		}
		v = bigen.Uint64(d.b[:])
	case vs == 7:
		d.r.readb(d.b[:])
		v = bigen.Uint64(d.b[:])
	default:
		decErr("unsigned integers with greater than 64 bits of precision not supported")
	}
	return
}
// decIntAny decodes any integer-valued descriptor and returns it in three
// forms: ui is the magnitude, i the signed value, and neg whether the value
// is negative. Non-integer descriptors are reported through decErr.
func (d *bincDecDriver) decIntAny() (ui uint64, i int64, neg bool) {
	switch d.vd {
	case bincVdPosInt:
		ui = d.decUint()
		i = int64(ui)
	case bincVdNegInt:
		ui, neg = d.decUint(), true
		i = -int64(ui)
	case bincVdSmallInt:
		// small ints store value-1 in vs
		ui = uint64(d.vs) + 1
		i = int64(ui)
	case bincVdSpecial:
		if d.vs == bincSpNegOne {
			ui, i, neg = 1, -1, true
		} else if d.vs != bincSpZero {
			decErr("numeric decode fails for special value: d.vs: 0x%x", d.vs)
		}
	default:
		decErr("number can only be decoded from uint or int values. d.bd: 0x%x, d.vd: 0x%x", d.bd, d.vd)
	}
	return
}
// decodeInt decodes a signed integer, checks it against bitsize via
// checkOverflow, and consumes the descriptor.
func (d *bincDecDriver) decodeInt(bitsize uint8) (i int64) {
	_, signed, _ := d.decIntAny()
	checkOverflow(0, signed, bitsize)
	d.bdRead = false
	return signed
}
// decodeUint decodes an unsigned integer, checks it against bitsize via
// checkOverflow, and consumes the descriptor. A negative encoded value is
// reported through decErr.
func (d *bincDecDriver) decodeUint(bitsize uint8) (ui uint64) {
	magnitude, signed, negative := d.decIntAny()
	if negative {
		decErr("Assigning negative signed value: %v, to unsigned type", signed)
	}
	checkOverflow(magnitude, 0, bitsize)
	d.bdRead = false
	return magnitude
}
// decodeFloat decodes a float from a special value (NaN, +/-Inf, zero), an
// encoded float, or an encoded integer, and consumes the descriptor.
// Special values are returned directly; the other forms are passed through
// checkOverflowFloat32 with chkOverflow32 first.
func (d *bincDecDriver) decodeFloat(chkOverflow32 bool) (f float64) {
	if d.vd == bincVdSpecial {
		d.bdRead = false
		switch d.vs {
		case bincSpNan:
			return math.NaN()
		case bincSpPosInf:
			return math.Inf(1)
		case bincSpNegInf:
			return math.Inf(-1)
		case bincSpZeroFloat, bincSpZero:
			return
		default:
			decErr("Invalid d.vs decoding float where d.vd=bincVdSpecial: %v", d.vs)
		}
	} else if d.vd == bincVdFloat {
		f = d.decFloat()
	} else {
		_, signed, _ := d.decIntAny()
		f = float64(signed)
	}
	checkOverflowFloat32(f, chkOverflow32)
	d.bdRead = false
	return
}
// decodeBool decodes a bool and consumes the descriptor. Only the two
// single-byte special descriptors (false, true) are accepted; anything else
// is reported through decErr.
func (d *bincDecDriver) decodeBool() (b bool) {
	if d.bd == bincVdSpecial<<4|bincSpTrue {
		b = true
	} else if d.bd != bincVdSpecial<<4|bincSpFalse {
		decErr("Invalid single-byte value for bool: %s: %x", msgBadDesc, d.bd)
	}
	d.bdRead = false
	return
}
// readMapLen returns the entry count of the pending map descriptor and
// consumes it. A non-map descriptor is reported through decErr.
func (d *bincDecDriver) readMapLen() (length int) {
	if vd := d.vd; vd != bincVdMap {
		decErr("Invalid d.vd for map. Expecting 0x%x. Got: 0x%x", bincVdMap, vd)
	}
	length = d.decLen()
	d.bdRead = false
	return
}
// readArrayLen returns the element count of the pending array descriptor and
// consumes it. A non-array descriptor is reported through decErr.
func (d *bincDecDriver) readArrayLen() (length int) {
	if vd := d.vd; vd != bincVdArray {
		decErr("Invalid d.vd for array. Expecting 0x%x. Got: 0x%x", bincVdArray, vd)
	}
	length = d.decLen()
	d.bdRead = false
	return
}
// decLen decodes the length carried by the pending descriptor, mirroring
// bincEncDriver.encLen / encLenNumber:
//   - vs in 4..15: the length is vs-4 (no further bytes)
//   - vs in 0..3:  the length follows as a 1-, 2-, 4- or 8-byte big-endian
//     unsigned integer respectively
//
// The previous implementation delegated the vs<=3 case to decUint, whose
// integer layout reads vs+1 bytes (3 bytes for vs=2, 4 bytes for vs=3). That
// disagrees with what encLenNumber writes, so any length of 65536 or more was
// decoded from the wrong number of bytes and the stream was desynchronised.
func (d *bincDecDriver) decLen() int {
	if d.vs > 3 {
		return int(d.vs - 4)
	}
	width := 1 << d.vs // 1, 2, 4 or 8 bytes
	buf := d.b[:width]
	d.r.readb(buf)
	var v uint64
	for _, c := range buf {
		v = v<<8 | uint64(c)
	}
	return int(v)
}
// decodeString decodes a string from a string, byte-array or symbol
// descriptor and consumes the descriptor.
//
// For a symbol, vs is a set of flags: 0x8 selects a 2-byte (rather than
// 1-byte) symbol id, 0x4 indicates the string value follows, and the low two
// bits select the width (1, 2, 4 or 8 bytes) of that value's length field.
// A symbol with its value attached is recorded in d.m; one without is looked
// up there.
func (d *bincDecDriver) decodeString() (s string) {
	switch d.vd {
	case bincVdString, bincVdByteArray:
		if n := d.decLen(); n > 0 {
			s = string(d.r.readn(n))
		}
	case bincVdSymbol:
		flags := d.vs
		var id uint32
		if flags&0x8 != 0 {
			id = uint32(d.r.readUint16())
		} else {
			id = uint32(d.r.readn1())
		}
		if d.m == nil {
			d.m = make(map[uint32]string, 16)
		}
		if flags&0x4 == 0 {
			// reference to a previously seen symbol
			s = d.m[id]
			break
		}
		var n int
		switch flags & 0x3 {
		case 0:
			n = int(d.r.readn1())
		case 1:
			n = int(d.r.readUint16())
		case 2:
			n = int(d.r.readUint32())
		case 3:
			n = int(d.r.readUint64())
		}
		s = string(d.r.readn(n))
		d.m[id] = s
	default:
		decErr("Invalid d.vd for string. Expecting string:0x%x, bytearray:0x%x or symbol: 0x%x. Got: 0x%x",
			bincVdString, bincVdByteArray, bincVdSymbol, d.vd)
	}
	d.bdRead = false
	return
}
// decodeBytes decodes a string or byte-array value into bs and consumes the
// descriptor.
//
// When the encoded length is zero, bs is left untouched and nothing is read.
// Otherwise the content is read into bs, which is first truncated (if longer
// than the content) or replaced by a new slice (if shorter). bsOut and
// changed are only set when bs had to be truncated or replaced.
func (d *bincDecDriver) decodeBytes(bs []byte) (bsOut []byte, changed bool) {
	var n int
	if d.vd == bincVdString || d.vd == bincVdByteArray {
		n = d.decLen()
	} else {
		decErr("Invalid d.vd for bytes. Expecting string:0x%x or bytearray:0x%x. Got: 0x%x",
			bincVdString, bincVdByteArray, d.vd)
	}
	if n > 0 {
		// if no contents in stream, don't update the passed byteslice
		switch {
		case len(bs) > n:
			bs = bs[:n]
			bsOut, changed = bs, true
		case len(bs) < n:
			bs = make([]byte, n)
			bsOut, changed = bs, true
		}
		d.r.readb(bs)
	}
	d.bdRead = false
	return
}
// decodeExt decodes an extension value and consumes the descriptor.
//
// For a custom-ext descriptor it returns the tag byte and the payload; when
// verifyTag is set, a tag different from tag is reported through decErr. A
// byte-array descriptor is also accepted, in which case only the payload is
// returned. Any other descriptor is reported through decErr.
func (d *bincDecDriver) decodeExt(verifyTag bool, tag byte) (xtag byte, xbs []byte) {
	if d.vd == bincVdCustomExt {
		n := d.decLen()
		xtag = d.r.readn1()
		if verifyTag && xtag != tag {
			decErr("Wrong extension tag. Got %b. Expecting: %v", xtag, tag)
		}
		xbs = d.r.readn(n)
	} else if d.vd == bincVdByteArray {
		xbs, _ = d.decodeBytes(nil)
	} else {
		decErr("Invalid d.vd for extensions (Expecting extensions or byte array). Got: 0x%x", d.vd)
	}
	d.bdRead = false
	return
}
// decodeNaked decodes the next value without any target type information.
//
// It returns the decoded value v, its stream type vt, and decodeFurther,
// which is true for arrays and maps: in that case v is nil, the descriptor is
// left pending, and the caller is expected to read the container itself.
// For every other type the descriptor is consumed.
//
// Value types returned in v: bool, float64, uint64 (for valueTypeUint),
// int64 (for valueTypeInt), string, []byte, time.Time (as returned by
// decodeTime) and *RawExt.
//
// The "special: zero" descriptor is reported as valueTypeUint and is now
// returned as uint64(0); it used to be returned as int64(0), which
// contradicted its reported type and every other unsigned value.
func (d *bincDecDriver) decodeNaked() (v interface{}, vt valueType, decodeFurther bool) {
	d.initReadNext()
	switch d.vd {
	case bincVdSpecial:
		switch d.vs {
		case bincSpNil:
			vt = valueTypeNil
		case bincSpFalse:
			vt = valueTypeBool
			v = false
		case bincSpTrue:
			vt = valueTypeBool
			v = true
		case bincSpNan:
			vt = valueTypeFloat
			v = math.NaN()
		case bincSpPosInf:
			vt = valueTypeFloat
			v = math.Inf(1)
		case bincSpNegInf:
			vt = valueTypeFloat
			v = math.Inf(-1)
		case bincSpZeroFloat:
			vt = valueTypeFloat
			v = float64(0)
		case bincSpZero:
			vt = valueTypeUint
			v = uint64(0) // keep the Go type consistent with valueTypeUint
		case bincSpNegOne:
			vt = valueTypeInt
			v = int64(-1)
		default:
			decErr("decodeNaked: Unrecognized special value 0x%x", d.vs)
		}
	case bincVdSmallInt:
		// small ints store value-1 in vs
		vt = valueTypeUint
		v = uint64(int8(d.vs)) + 1
	case bincVdPosInt:
		vt = valueTypeUint
		v = d.decUint()
	case bincVdNegInt:
		vt = valueTypeInt
		v = -(int64(d.decUint()))
	case bincVdFloat:
		vt = valueTypeFloat
		v = d.decFloat()
	case bincVdSymbol:
		vt = valueTypeSymbol
		v = d.decodeString()
	case bincVdString:
		vt = valueTypeString
		v = d.decodeString()
	case bincVdByteArray:
		vt = valueTypeBytes
		v, _ = d.decodeBytes(nil)
	case bincVdTimestamp:
		vt = valueTypeTimestamp
		tt, err := decodeTime(d.r.readn(int(d.vs)))
		if err != nil {
			panic(err)
		}
		v = tt
	case bincVdCustomExt:
		vt = valueTypeExt
		l := d.decLen()
		var re RawExt
		re.Tag = d.r.readn1()
		re.Data = d.r.readn(l)
		v = &re
	case bincVdArray:
		vt = valueTypeArray
		decodeFurther = true
	case bincVdMap:
		vt = valueTypeMap
		decodeFurther = true
	default:
		decErr("decodeNaked: Unrecognized d.vd: 0x%x", d.vd)
	}
	if !decodeFurther {
		d.bdRead = false
	}
	return
}
//------------------------------------
// BincHandle is a Handle for the Binc Schema-Free Encoding Format
// defined at https://github.com/ugorji/binc .
//
// BincHandle currently supports all Binc features with the following EXCEPTIONS:
//   - only integers up to 64 bits of precision are supported.
//     big integers are unsupported.
//   - Only IEEE 754 binary32 and binary64 floats are supported (ie Go float32 and float64 types).
//     extended precision and decimal IEEE 754 floats are unsupported.
//   - Only UTF-8 strings supported.
//     Unicode_Other Binc types (UTF16, UTF32) are currently unsupported.
//
// Note that these EXCEPTIONS are temporary and full support is possible and may happen soon.
type BincHandle struct {
BasicHandle
}
// newEncDriver returns a fresh Binc encoder driver that writes to w.
func (h *BincHandle) newEncDriver(w encWriter) encDriver {
	drv := new(bincEncDriver)
	drv.w = w
	return drv
}
// newDecDriver returns a fresh Binc decoder driver that reads from r.
func (h *BincHandle) newDecDriver(r decReader) decDriver {
	drv := new(bincDecDriver)
	drv.r = r
	return drv
}
// writeExt always returns true for the Binc handle.
func (h *BincHandle) writeExt() bool {
	const supported = true
	return supported
}
// getBasicHandle returns a pointer to the embedded BasicHandle.
func (h *BincHandle) getBasicHandle() *BasicHandle {
	basic := &h.BasicHandle
	return basic
}

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,589 @@
// Copyright (c) 2012, 2013 Ugorji Nwoke. All rights reserved.
// Use of this source code is governed by a BSD-style license found in the LICENSE file.
package codec
// Contains code shared by both encode and decode.
import (
"encoding/binary"
"fmt"
"math"
"reflect"
"sort"
"strings"
"sync"
"time"
"unicode"
"unicode/utf8"
)
// Package-wide compile-time switches.
const (
// structTagName is the struct tag key read when collecting field info.
structTagName = "codec"
// Support
// encoding.BinaryMarshaler: MarshalBinary() (data []byte, err error)
// encoding.BinaryUnmarshaler: UnmarshalBinary(data []byte) error
// This constant flag will enable or disable it.
supportBinaryMarshal = true
// Each Encoder or Decoder uses a cache of functions based on conditionals,
// so that the conditionals are not run every time.
//
// Either a map or a slice is used to keep track of the functions.
// The map is more natural, but has a higher cost than a slice/array.
// This flag (useMapForCodecCache) controls which is used.
useMapForCodecCache = false
// For some common container types, we can short-circuit an elaborate
// reflection dance and call encode/decode directly.
// The currently supported types are:
// - slices of strings, or id's (int64,uint64) or interfaces.
// - maps of str->str, str->intf, id(int64,uint64)->intf, intf->intf
shortCircuitReflectToFastPath = true
// for debugging, set this to false, to catch panic traces.
// Note that this will always cause rpc tests to fail, since they need io.EOF sent via panic.
recoverPanicToErr = true
)
// charEncoding identifies the character encoding of string data handed to an
// encoder driver. c_RAW selects the byte-array representation in the Binc
// driver; every other value is written as a string there.
type charEncoding uint8
const (
c_RAW charEncoding = iota
c_UTF8
c_UTF16LE
c_UTF16BE
c_UTF32LE
c_UTF32BE
)
// valueType is the stream type, i.e. the kind of value a decoder driver
// reports for the next item in the encoded stream.
type valueType uint8
const (
// valueTypeUnset is the zero value; drivers use it to mean "not yet computed".
valueTypeUnset valueType = iota
valueTypeNil
valueTypeInt
valueTypeUint
valueTypeFloat
valueTypeBool
valueTypeString
valueTypeSymbol
valueTypeBytes
valueTypeMap
valueTypeArray
valueTypeTimestamp
valueTypeExt
valueTypeInvalid = 0xff
)
// Package-level shared state: the byte order, the typeInfo cache, and
// precomputed reflect.Type values with their type ids. A type id is the
// pointer value of the reflect.Type, obtained via reflect.ValueOf(t).Pointer().
var (
// bigen is the byte order used for all multi-byte values.
bigen = binary.BigEndian
// structInfoFieldName is the name of the field whose tag carries struct-wide options.
structInfoFieldName = "_struct"
// cachedTypeInfo maps a type id to its typeInfo; guarded by cachedTypeInfoMutex.
cachedTypeInfo = make(map[uintptr]*typeInfo, 4)
cachedTypeInfoMutex sync.RWMutex
// reflect.Type values for common slice types.
intfSliceTyp = reflect.TypeOf([]interface{}(nil))
intfTyp = intfSliceTyp.Elem()
strSliceTyp = reflect.TypeOf([]string(nil))
boolSliceTyp = reflect.TypeOf([]bool(nil))
uintSliceTyp = reflect.TypeOf([]uint(nil))
uint8SliceTyp = reflect.TypeOf([]uint8(nil))
uint16SliceTyp = reflect.TypeOf([]uint16(nil))
uint32SliceTyp = reflect.TypeOf([]uint32(nil))
uint64SliceTyp = reflect.TypeOf([]uint64(nil))
intSliceTyp = reflect.TypeOf([]int(nil))
int8SliceTyp = reflect.TypeOf([]int8(nil))
int16SliceTyp = reflect.TypeOf([]int16(nil))
int32SliceTyp = reflect.TypeOf([]int32(nil))
int64SliceTyp = reflect.TypeOf([]int64(nil))
float32SliceTyp = reflect.TypeOf([]float32(nil))
float64SliceTyp = reflect.TypeOf([]float64(nil))
// reflect.Type values for common map types.
mapIntfIntfTyp = reflect.TypeOf(map[interface{}]interface{}(nil))
mapStrIntfTyp = reflect.TypeOf(map[string]interface{}(nil))
mapStrStrTyp = reflect.TypeOf(map[string]string(nil))
mapIntIntfTyp = reflect.TypeOf(map[int]interface{}(nil))
mapInt64IntfTyp = reflect.TypeOf(map[int64]interface{}(nil))
mapUintIntfTyp = reflect.TypeOf(map[uint]interface{}(nil))
mapUint64IntfTyp = reflect.TypeOf(map[uint64]interface{}(nil))
// reflect.Type values for other types and interfaces the package checks for.
stringTyp = reflect.TypeOf("")
timeTyp = reflect.TypeOf(time.Time{})
rawExtTyp = reflect.TypeOf(RawExt{})
mapBySliceTyp = reflect.TypeOf((*MapBySlice)(nil)).Elem()
binaryMarshalerTyp = reflect.TypeOf((*binaryMarshaler)(nil)).Elem()
binaryUnmarshalerTyp = reflect.TypeOf((*binaryUnmarshaler)(nil)).Elem()
// Type ids corresponding to the reflect.Type values above.
rawExtTypId = reflect.ValueOf(rawExtTyp).Pointer()
intfTypId = reflect.ValueOf(intfTyp).Pointer()
timeTypId = reflect.ValueOf(timeTyp).Pointer()
intfSliceTypId = reflect.ValueOf(intfSliceTyp).Pointer()
strSliceTypId = reflect.ValueOf(strSliceTyp).Pointer()
boolSliceTypId = reflect.ValueOf(boolSliceTyp).Pointer()
uintSliceTypId = reflect.ValueOf(uintSliceTyp).Pointer()
uint8SliceTypId = reflect.ValueOf(uint8SliceTyp).Pointer()
uint16SliceTypId = reflect.ValueOf(uint16SliceTyp).Pointer()
uint32SliceTypId = reflect.ValueOf(uint32SliceTyp).Pointer()
uint64SliceTypId = reflect.ValueOf(uint64SliceTyp).Pointer()
intSliceTypId = reflect.ValueOf(intSliceTyp).Pointer()
int8SliceTypId = reflect.ValueOf(int8SliceTyp).Pointer()
int16SliceTypId = reflect.ValueOf(int16SliceTyp).Pointer()
int32SliceTypId = reflect.ValueOf(int32SliceTyp).Pointer()
int64SliceTypId = reflect.ValueOf(int64SliceTyp).Pointer()
float32SliceTypId = reflect.ValueOf(float32SliceTyp).Pointer()
float64SliceTypId = reflect.ValueOf(float64SliceTyp).Pointer()
mapStrStrTypId = reflect.ValueOf(mapStrStrTyp).Pointer()
mapIntfIntfTypId = reflect.ValueOf(mapIntfIntfTyp).Pointer()
mapStrIntfTypId = reflect.ValueOf(mapStrIntfTyp).Pointer()
mapIntIntfTypId = reflect.ValueOf(mapIntIntfTyp).Pointer()
mapInt64IntfTypId = reflect.ValueOf(mapInt64IntfTyp).Pointer()
mapUintIntfTypId = reflect.ValueOf(mapUintIntfTyp).Pointer()
mapUint64IntfTypId = reflect.ValueOf(mapUint64IntfTyp).Pointer()
// Id = reflect.ValueOf().Pointer()
// mapBySliceTypId = reflect.ValueOf(mapBySliceTyp).Pointer()
binaryMarshalerTypId = reflect.ValueOf(binaryMarshalerTyp).Pointer()
binaryUnmarshalerTypId = reflect.ValueOf(binaryUnmarshalerTyp).Pointer()
// Bit widths of the platform's int and uint types.
intBitsize uint8 = uint8(reflect.TypeOf(int(0)).Bits())
uintBitsize uint8 = uint8(reflect.TypeOf(uint(0)).Bits())
// Eight-byte slices filled with 0x00 and 0xff respectively.
bsAll0x00 = []byte{0, 0, 0, 0, 0, 0, 0, 0}
bsAll0xff = []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}
)
// binaryUnmarshaler is a local declaration of the method set described above
// for encoding.BinaryUnmarshaler; getTypeInfo checks types against it.
type binaryUnmarshaler interface {
UnmarshalBinary(data []byte) error
}
// binaryMarshaler is a local declaration of the method set described above
// for encoding.BinaryMarshaler; getTypeInfo checks types against it.
type binaryMarshaler interface {
MarshalBinary() (data []byte, err error)
}
// MapBySlice represents a slice which should be encoded as a map in the stream.
// The slice contains a sequence of key-value pairs.
//
// A type opts in by declaring the marker method below; getTypeInfo records
// whether a type implements this interface.
type MapBySlice interface {
MapBySlice()
}
// WARNING: DO NOT USE DIRECTLY. EXPORTED FOR GODOC BENEFIT. WILL BE REMOVED.
//
// BasicHandle encapsulates the common options and extension functions.
// It embeds the extension registry (extHandle) together with the encode and
// decode option sets.
type BasicHandle struct {
extHandle
EncodeOptions
DecodeOptions
}
// Handle is the interface for a specific encoding format.
//
// Typically, a Handle is pre-configured before first time use,
// and not modified while in use. Such a pre-configured Handle
// is safe for concurrent access.
//
// All methods are unexported, so only types in this package can implement it.
type Handle interface {
writeExt() bool
// getBasicHandle returns the handle's shared options and extensions.
getBasicHandle() *BasicHandle
// newEncDriver returns a format-specific encoder driver writing to w.
newEncDriver(w encWriter) encDriver
// newDecDriver returns a format-specific decoder driver reading from r.
newDecDriver(r decReader) decDriver
}
// RawExt represents raw unprocessed extension data.
type RawExt struct {
Tag byte // extension tag read from the stream
Data []byte // extension payload bytes
}
// extTypeTagFn is one registered extension: a type, its tag, and the
// functions that convert values of that type to and from bytes.
type extTypeTagFn struct {
rtid uintptr // type id of rt
rt reflect.Type
tag byte
encFn func(reflect.Value) ([]byte, error)
decFn func(reflect.Value, []byte) error
}
// extHandle is the list of registered extensions; lookups scan it linearly.
type extHandle []*extTypeTagFn
// AddExt registers an encode and decode function for a reflect.Type.
// Note that the type must be a named type, and specifically not
// a pointer or Interface. An error is returned if that is not honored.
//
// If rt is already registered, its tag and functions are replaced in place;
// otherwise a new entry is appended.
//
// To Deregister an ext, call AddExt with 0 tag, nil encfn and nil decfn.
func (o *extHandle) AddExt(
	rt reflect.Type,
	tag byte,
	encfn func(reflect.Value) ([]byte, error),
	decfn func(reflect.Value, []byte) error,
) (err error) {
	// o is a pointer, because we may need to initialize it.
	// o cannot be nil, since it is always embedded in a Handle; if nil, let it panic.
	if rt.PkgPath() == "" || rt.Kind() == reflect.Interface {
		return fmt.Errorf("codec.Handle.AddExt: Takes named type, especially not a pointer or interface: %T",
			reflect.Zero(rt).Interface())
	}
	rtid := reflect.ValueOf(rt).Pointer()
	registered := *o
	for idx := range registered {
		if x := registered[idx]; x.rtid == rtid {
			x.tag, x.encFn, x.decFn = tag, encfn, decfn
			return nil
		}
	}
	*o = append(registered, &extTypeTagFn{rtid: rtid, rt: rt, tag: tag, encFn: encfn, decFn: decfn})
	return nil
}
// getExt returns the first registered extension whose type id is rtid, or
// nil when there is none.
func (o extHandle) getExt(rtid uintptr) *extTypeTagFn {
	for idx := 0; idx < len(o); idx++ {
		if x := o[idx]; x.rtid == rtid {
			return x
		}
	}
	return nil
}
// getExtForTag returns the first registered extension whose tag is tag, or
// nil when there is none.
func (o extHandle) getExtForTag(tag byte) *extTypeTagFn {
	for idx := 0; idx < len(o); idx++ {
		if x := o[idx]; x.tag == tag {
			return x
		}
	}
	return nil
}
// getDecodeExtForTag looks up the extension registered for tag and returns a
// new addressable zero value of its type together with its decode function.
// Both results are zero when no extension has that tag.
func (o extHandle) getDecodeExtForTag(tag byte) (
	rv reflect.Value, fn func(reflect.Value, []byte) error) {
	x := o.getExtForTag(tag)
	if x == nil {
		return
	}
	// ext is only registered for base
	rv = reflect.New(x.rt).Elem()
	fn = x.decFn
	return
}
// getDecodeExt returns the tag and decode function registered for the type
// id rtid; both are zero when the type is not registered.
func (o extHandle) getDecodeExt(rtid uintptr) (tag byte, fn func(reflect.Value, []byte) error) {
	x := o.getExt(rtid)
	if x == nil {
		return
	}
	return x.tag, x.decFn
}
// getEncodeExt returns the tag and encode function registered for the type
// id rtid; both are zero when the type is not registered.
func (o extHandle) getEncodeExt(rtid uintptr) (tag byte, fn func(reflect.Value) ([]byte, error)) {
	x := o.getExt(rtid)
	if x == nil {
		return
	}
	return x.tag, x.encFn
}
// structFieldInfo describes how one struct field is encoded: the name used
// in the stream, where the field lives in the struct, and its tag options.
type structFieldInfo struct {
encName string // encode name
// only one of 'i' or 'is' can be set. If 'i' is -1, then 'is' has been set.
is []int // (recursive/embedded) field index in struct
i int16 // field index in struct
omitEmpty bool // set by the "omitempty" tag option
toArray bool // if field is _struct, is the toArray set?
// tag string // tag
// name string // field name
// encNameBs []byte // encoded name as byte stream
// ikind int // kind of the field as an int i.e. int(reflect.Kind)
}
// parseStructFieldInfo builds the structFieldInfo for a field named fname
// from its struct tag value stag.
//
// stag is a comma-separated list: a non-empty first element overrides the
// encoded name (which otherwise is fname), and the remaining elements are
// options, of which "omitempty" and "toarray" are recognised. It panics if
// fname is empty.
func parseStructFieldInfo(fname string, stag string) *structFieldInfo {
	if fname == "" {
		panic("parseStructFieldInfo: No Field Name")
	}
	info := &structFieldInfo{encName: fname}
	if stag == "" {
		return info
	}
	parts := strings.Split(stag, ",")
	if parts[0] != "" {
		info.encName = parts[0]
	}
	for _, opt := range parts[1:] {
		if opt == "omitempty" {
			info.omitEmpty = true
		} else if opt == "toarray" {
			info.toArray = true
		}
	}
	return info
}
// sfiSortedByEncName implements sort.Interface, ordering field infos by
// their encoded name.
type sfiSortedByEncName []*structFieldInfo

// Len returns the number of field infos.
func (p sfiSortedByEncName) Len() int {
	n := len(p)
	return n
}

// Less reports whether the encoded name at index i sorts before the one at j.
func (p sfiSortedByEncName) Less(i, j int) bool {
	left, right := p[i], p[j]
	return left.encName < right.encName
}

// Swap exchanges the field infos at indexes i and j.
func (p sfiSortedByEncName) Swap(i, j int) {
	tmp := p[i]
	p[i] = p[j]
	p[j] = tmp
}
// typeInfo keeps information about each type referenced in the encode/decode sequence.
//
// During an encode/decode sequence, we work as below:
// - If base is a built in type, en/decode base value
// - If base is registered as an extension, en/decode base value
// - If type is binary(M/Unm)arshaler, call Binary(M/Unm)arshal method
// - Else decode appropriately based on the reflect.Kind
type typeInfo struct {
sfi []*structFieldInfo // sorted. Used when enc/dec struct to map.
sfip []*structFieldInfo // unsorted. Used when enc/dec struct to array.
rt reflect.Type
rtid uintptr
// baseId gives pointer to the base reflect.Type, after dereferencing
// the pointers. E.g. base type of ***time.Time is time.Time.
base reflect.Type
baseId uintptr
baseIndir int8 // number of indirections to get to base
mbs bool // base type (T or *T) is a MapBySlice
m bool // base type (T or *T) is a binaryMarshaler
unm bool // base type (T or *T) is a binaryUnmarshaler
mIndir int8 // number of indirections to get to binaryMarshaler type
unmIndir int8 // number of indirections to get to binaryUnmarshaler type
toArray bool // whether this (struct) type should be encoded as an array
}
// indexForEncName returns the index in ti.sfi of the field whose encoded
// name is name, or -1 when there is none.
//
// Fewer than 16 fields are scanned linearly; larger sets are binary-searched,
// which relies on ti.sfi being sorted by encoded name.
func (ti *typeInfo) indexForEncName(name string) int {
	const binarySearchThreshold = 16
	fields := ti.sfi
	n := len(fields)
	if n < binarySearchThreshold {
		// linear search. faster than binary search in my testing up to 16-field structs.
		for idx := 0; idx < n; idx++ {
			if fields[idx].encName == name {
				return idx
			}
		}
		return -1
	}
	// binary search. adapted from sort/search.go.
	lo, hi := 0, n
	for lo < hi {
		mid := lo + (hi-lo)/2
		if fields[mid].encName < name {
			lo = mid + 1
		} else {
			hi = mid
		}
	}
	if lo < n && fields[lo].encName == name {
		return lo
	}
	return -1
}
// getTypeInfo returns the typeInfo for the type rt with type id rtid,
// building it and storing it in cachedTypeInfo on first use.
//
// The cache is consulted under the read lock first. On a miss the write lock
// is taken, the cache is checked again, and only then is a new entry built.
func getTypeInfo(rtid uintptr, rt reflect.Type) (pti *typeInfo) {
var ok bool
cachedTypeInfoMutex.RLock()
pti, ok = cachedTypeInfo[rtid]
cachedTypeInfoMutex.RUnlock()
if ok {
return
}
cachedTypeInfoMutex.Lock()
defer cachedTypeInfoMutex.Unlock()
// re-check under the write lock: the entry may have been added since RUnlock
if pti, ok = cachedTypeInfo[rtid]; ok {
return
}
ti := typeInfo{rt: rt, rtid: rtid}
pti = &ti
// record which of the marshaler/unmarshaler/MapBySlice interfaces rt satisfies
var indir int8
if ok, indir = implementsIntf(rt, binaryMarshalerTyp); ok {
ti.m, ti.mIndir = true, indir
}
if ok, indir = implementsIntf(rt, binaryUnmarshalerTyp); ok {
ti.unm, ti.unmIndir = true, indir
}
if ok, _ = implementsIntf(rt, mapBySliceTyp); ok {
ti.mbs = true
}
// strip pointers to find the base type and count the indirections
pt := rt
var ptIndir int8
// for ; pt.Kind() == reflect.Ptr; pt, ptIndir = pt.Elem(), ptIndir+1 { }
for pt.Kind() == reflect.Ptr {
pt = pt.Elem()
ptIndir++
}
if ptIndir == 0 {
ti.base = rt
ti.baseId = rtid
} else {
ti.base = pt
ti.baseId = reflect.ValueOf(pt).Pointer()
ti.baseIndir = ptIndir
}
// field info is only collected when rt itself (not a pointer to it) is a struct
if rt.Kind() == reflect.Struct {
var siInfo *structFieldInfo
// struct-wide options come from the tag on the field named structInfoFieldName
if f, ok := rt.FieldByName(structInfoFieldName); ok {
siInfo = parseStructFieldInfo(structInfoFieldName, f.Tag.Get(structTagName))
ti.toArray = siInfo.toArray
}
sfip := make([]*structFieldInfo, 0, rt.NumField())
rgetTypeInfo(rt, nil, make(map[string]bool), &sfip, siInfo)
// // try to put all si close together
// const tryToPutAllStructFieldInfoTogether = true
// if tryToPutAllStructFieldInfoTogether {
// sfip2 := make([]structFieldInfo, len(sfip))
// for i, si := range sfip {
// sfip2[i] = *si
// }
// for i := range sfip {
// sfip[i] = &sfip2[i]
// }
// }
// ti.sfip keeps the collection order; ti.sfi gets the same entries sorted by encName
ti.sfip = make([]*structFieldInfo, len(sfip))
ti.sfi = make([]*structFieldInfo, len(sfip))
copy(ti.sfip, sfip)
sort.Sort(sfiSortedByEncName(sfip))
copy(ti.sfi, sfip)
}
// sfi = sfip
cachedTypeInfo[rtid] = pti
return
}
// rgetTypeInfo appends to *sfi a structFieldInfo for each encodable field of
// the struct type rt, recursing into untagged anonymous struct fields.
//
// indexstack is the field-index path from the outermost struct to rt (empty
// at the top level). fnameToHastag records the field names already collected
// and whether each had a tag. siInfo, when non-nil, carries struct-wide
// options; its omitEmpty setting is applied to every collected field.
func rgetTypeInfo(rt reflect.Type, indexstack []int, fnameToHastag map[string]bool,
sfi *[]*structFieldInfo, siInfo *structFieldInfo,
) {
// for rt.Kind() == reflect.Ptr {
// // indexstack = append(indexstack, 0)
// rt = rt.Elem()
// }
for j := 0; j < rt.NumField(); j++ {
f := rt.Field(j)
stag := f.Tag.Get(structTagName)
// a tag of "-" excludes the field
if stag == "-" {
continue
}
// skip fields whose name does not start with an upper-case letter
if r1, _ := utf8.DecodeRuneInString(f.Name); r1 == utf8.RuneError || !unicode.IsUpper(r1) {
continue
}
// if anonymous and there is no struct tag and its a struct (or pointer to struct), inline it.
if f.Anonymous && stag == "" {
ft := f.Type
for ft.Kind() == reflect.Ptr {
ft = ft.Elem()
}
if ft.Kind() == reflect.Struct {
// copy the path so the recursive call cannot alias this level's slice
indexstack2 := append(append(make([]int, 0, len(indexstack)+4), indexstack...), j)
rgetTypeInfo(ft, indexstack2, fnameToHastag, sfi, siInfo)
continue
}
}
// do not let fields with same name in embedded structs override field at higher level.
// this must be done after anonymous check, to allow anonymous field
// still include their child fields
if _, ok := fnameToHastag[f.Name]; ok {
continue
}
si := parseStructFieldInfo(f.Name, stag)
// si.ikind = int(f.Type.Kind())
// top-level fields use the single index i; nested ones use the full path is
if len(indexstack) == 0 {
si.i = int16(j)
} else {
si.i = -1
si.is = append(append(make([]int, 0, len(indexstack)+4), indexstack...), j)
}
if siInfo != nil {
if siInfo.omitEmpty {
si.omitEmpty = true
}
}
*sfi = append(*sfi, si)
fnameToHastag[f.Name] = stag != ""
}
}
// panicToErr recovers from a panic and stores it in *err via panicValToErr.
// It does nothing when recoverPanicToErr is false.
//
// recover is called directly in this function, so it only has an effect when
// panicToErr itself is the deferred call.
func panicToErr(err *error) {
	if !recoverPanicToErr {
		return
	}
	x := recover()
	if x == nil {
		return
	}
	//debug.PrintStack()
	panicValToErr(x, err)
}
// doPanic panics with an error whose message is tag, a colon and a space,
// followed by format expanded with params.
func doPanic(tag string, format string, params ...interface{}) {
	args := make([]interface{}, 0, len(params)+1)
	args = append(args, tag)
	args = append(args, params...)
	panic(fmt.Errorf("%s: "+format, args...))
}
// checkOverflowFloat32 reports through decErr when doCheck is set and the
// magnitude of f is larger than math.MaxFloat32 while still a finite float64.
func checkOverflowFloat32(f float64, doCheck bool) {
	if !doCheck {
		return
	}
	// check overflow (logic adapted from std pkg reflect/value.go OverflowFloat()
	magnitude := math.Abs(f)
	if magnitude > math.MaxFloat32 && magnitude <= math.MaxFloat64 {
		decErr("Overflow float32 value: %v", magnitude)
	}
}
// checkOverflow reports through decErr when the signed value i or the
// unsigned value ui does not fit in bitsize bits. A bitsize of 0 disables
// the check, and zero values are never reported.
func checkOverflow(ui uint64, i int64, bitsize uint8) {
	// check overflow (logic adapted from std pkg reflect/value.go OverflowUint()
	if bitsize == 0 {
		return
	}
	// Shifting up and back down discards the bits above bitsize (with sign
	// extension for i); a value that fits survives the round trip unchanged.
	shift := 64 - bitsize
	if i != 0 && i != (i<<shift)>>shift {
		decErr("Overflow int value: %v", i)
	}
	if ui != 0 && ui != (ui<<shift)>>shift {
		decErr("Overflow uint value: %v", ui)
	}
}

View file

@ -0,0 +1,127 @@
// Copyright (c) 2012, 2013 Ugorji Nwoke. All rights reserved.
// Use of this source code is governed by a BSD-style license found in the LICENSE file.
package codec
// All non-std package dependencies live in this file,
// so porting to different environment is easy (just update functions).
import (
"errors"
"fmt"
"math"
"reflect"
)
var (
// raisePanicAfterRecover makes panicValToErr re-panic with the original
// value after it has stored the error.
raisePanicAfterRecover = false
// debugging enables output from debugf.
debugging = true
)
// panicValToErr converts a recovered panic value into an error stored in
// *err: an error is stored as is, a string becomes errors.New of it, and any
// other value is formatted with %v. When raisePanicAfterRecover is set, the
// original value is then panicked again.
func panicValToErr(panicVal interface{}, err *error) {
	if asErr, ok := panicVal.(error); ok {
		*err = asErr
	} else if asStr, ok := panicVal.(string); ok {
		*err = errors.New(asStr)
	} else {
		*err = fmt.Errorf("%v", panicVal)
	}
	if raisePanicAfterRecover {
		panic(panicVal)
	}
}
// isEmptyValueDeref reports whether v is "empty": false, a zero number, a
// zero-length array/map/slice/string, a nil pointer or interface, or a
// struct whose fields are all empty.
//
// When deref is true, a non-nil pointer or interface is empty if the value
// it refers to is empty; when false, any non-nil pointer or interface is
// non-empty. Kinds not listed above are never empty.
func isEmptyValueDeref(v reflect.Value, deref bool) bool {
	switch v.Kind() {
	case reflect.Bool:
		return !v.Bool()
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return v.Int() == 0
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return v.Uint() == 0
	case reflect.Float32, reflect.Float64:
		return v.Float() == 0
	case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
		return v.Len() == 0
	case reflect.Interface, reflect.Ptr:
		if v.IsNil() {
			return true
		}
		if !deref {
			return false
		}
		return isEmptyValueDeref(v.Elem(), deref)
	case reflect.Struct:
		// Fields are checked one by one rather than comparing against the
		// zero struct, because structs holding maps/slices are not comparable.
		numFields := v.NumField()
		for idx := 0; idx < numFields; idx++ {
			if !isEmptyValueDeref(v.Field(idx), deref) {
				return false
			}
		}
		return true
	}
	return false
}
// isEmptyValue reports whether v is empty, following non-nil pointers and
// interfaces (it calls isEmptyValueDeref with deref=true).
func isEmptyValue(v reflect.Value) bool {
	return isEmptyValueDeref(v, true)
}
// debugf prints a formatted message to standard output when the package-level
// debugging flag is set. A trailing newline is appended to format if it does
// not already end with one.
func debugf(format string, args ...interface{}) {
	if !debugging {
		return
	}
	if n := len(format); n == 0 || format[n-1] != '\n' {
		format += "\n"
	}
	fmt.Printf(format, args...)
}
// pruneSignExt returns the number of leading sign-extension bytes of v that
// can be dropped without changing the value it represents. pos selects the
// sign: for a positive number the redundant bytes are 0x00 followed by a byte
// whose high bit is clear; for a negative number they are 0xff followed by a
// byte whose high bit is set. Slices shorter than two bytes yield 0.
func pruneSignExt(v []byte, pos bool) (n int) {
	if len(v) < 2 {
		return 0
	}
	fill, sign := byte(0xff), byte(0x80)
	if pos {
		fill, sign = 0x00, 0x00
	}
	// A byte is redundant only if it is the fill byte and the next byte
	// already carries the same sign bit; the last byte is always kept.
	last := len(v) - 1
	for n < last && v[n] == fill && v[n+1]&0x80 == sign {
		n++
	}
	return n
}
// implementsIntf reports whether typ, or a type reachable from it through
// pointer indirection, implements the interface type iTyp.
//
// On success indir is the number of pointer dereferences needed to reach the
// implementing type (0 if typ itself implements it), or -1 if typ is not a
// pointer but *typ implements it. A nil typ yields (false, 0).
func implementsIntf(typ, iTyp reflect.Type) (success bool, indir int8) {
	if typ == nil {
		return false, 0
	}
	// Walk down the pointer chain until an implementation or a non-pointer
	// base type is found.
	for cur := typ; ; cur = cur.Elem() {
		if cur.Implements(iTyp) {
			return true, indir
		}
		if cur.Kind() != reflect.Ptr {
			break
		}
		indir++
		if indir >= math.MaxInt8 { // insane number of indirections
			return false, 0
		}
	}
	// A non-pointer base type may still satisfy the interface through
	// methods declared on its pointer.
	if typ.Kind() != reflect.Ptr && reflect.PtrTo(typ).Implements(iTyp) {
		return true, -1
	}
	return false, 0
}

View file

@ -0,0 +1,816 @@
// Copyright (c) 2012, 2013 Ugorji Nwoke. All rights reserved.
// Use of this source code is governed by a BSD-style license found in the LICENSE file.
/*
MSGPACK
Msgpack-c implementation powers the c, c++, python, ruby, etc libraries.
We need to maintain compatibility with it and how it encodes integer values
without caring about the type.
For compatibility with behaviour of msgpack-c reference implementation:
- Go intX (>0) and uintX
IS ENCODED AS
msgpack +ve fixnum, unsigned
- Go intX (<0)
IS ENCODED AS
msgpack -ve fixnum, signed
*/
package codec
import (
"fmt"
"io"
"math"
"net/rpc"
)
// Msgpack descriptor bytes. Each encoded value starts with one of these; the
// "fix" ranges additionally carry a small value or length in the descriptor
// byte itself.
const (
	// positive fixnum: the descriptor byte is the value
	mpPosFixNumMin byte = 0x00
	mpPosFixNumMax = 0x7f
	// fix containers: the length is stored in the low bits of the descriptor
	mpFixMapMin = 0x80
	mpFixMapMax = 0x8f
	mpFixArrayMin = 0x90
	mpFixArrayMax = 0x9f
	mpFixStrMin = 0xa0
	mpFixStrMax = 0xbf
	mpNil = 0xc0
	_ = 0xc1
	mpFalse = 0xc2
	mpTrue = 0xc3
	mpFloat = 0xca
	mpDouble = 0xcb
	mpUint8 = 0xcc
	mpUint16 = 0xcd
	mpUint32 = 0xce
	mpUint64 = 0xcf
	mpInt8 = 0xd0
	mpInt16 = 0xd1
	mpInt32 = 0xd2
	mpInt64 = 0xd3
	// extensions below
	mpBin8 = 0xc4
	mpBin16 = 0xc5
	mpBin32 = 0xc6
	mpExt8 = 0xc7
	mpExt16 = 0xc8
	mpExt32 = 0xc9
	mpFixExt1 = 0xd4
	mpFixExt2 = 0xd5
	mpFixExt4 = 0xd6
	mpFixExt8 = 0xd7
	mpFixExt16 = 0xd8
	mpStr8 = 0xd9 // new
	mpStr16 = 0xda
	mpStr32 = 0xdb
	mpArray16 = 0xdc
	mpArray32 = 0xdd
	mpMap16 = 0xde
	mpMap32 = 0xdf
	// negative fixnum: the descriptor byte, read as int8, is the value
	mpNegFixNumMin = 0xe0
	mpNegFixNumMax = 0xff
)
// MsgpackSpecRpcMultiArgs is a special type which signifies to the MsgpackSpecRpcCodec
// that the backend RPC service takes multiple arguments, which have been arranged
// in sequence in the slice.
//
// The Codec then passes it AS-IS to the rpc service (without wrapping it in an
// array of 1 element). A body of any other type is wrapped in a one-element
// array by WriteRequest.
type MsgpackSpecRpcMultiArgs []interface{}
// msgpackContainerType describes how the length of one family of msgpack
// containers (str, bin, array, map) is encoded.
type msgpackContainerType struct {
	// fixCutoff is the exclusive upper bound on lengths that use the
	// single-byte "fix" form.
	fixCutoff int
	// bFixMin is the first descriptor of the fix range; b8, b16 and b32 are
	// the descriptors for 8-, 16- and 32-bit length prefixes.
	bFixMin, b8, b16, b32 byte
	// hasFixMin and has8 say whether the fix and 8-bit forms exist for this
	// family; has8Always makes the 8-bit form usable even when the handle's
	// WriteExt is false (see writeContainerLen).
	hasFixMin, has8, has8Always bool
}
// Length-encoding descriptions for the four container families. Fields are,
// in order: fixCutoff, bFixMin, b8, b16, b32, hasFixMin, has8, has8Always.
var (
	msgpackContainerStr = msgpackContainerType{32, mpFixStrMin, mpStr8, mpStr16, mpStr32, true, true, false}
	msgpackContainerBin = msgpackContainerType{0, 0, mpBin8, mpBin16, mpBin32, false, true, true}
	msgpackContainerList = msgpackContainerType{16, mpFixArrayMin, 0, mpArray16, mpArray32, true, false, false}
	msgpackContainerMap = msgpackContainerType{16, mpFixMapMin, 0, mpMap16, mpMap32, true, false, false}
)
//---------------------------------------------
// msgpackEncDriver writes values in msgpack format to w, consulting h for
// handle options such as WriteExt.
type msgpackEncDriver struct {
	w encWriter
	h *MsgpackHandle
}
// isBuiltinType always reports false: msgpack has no builtin types.
func (e *msgpackEncDriver) isBuiltinType(rt uintptr) bool {
	//no builtin types. All encodings are based on kinds. Types supported as extensions.
	return false
}
// encodeBuiltin is a no-op, since isBuiltinType never reports true.
func (e *msgpackEncDriver) encodeBuiltin(rt uintptr, v interface{}) {}
// encodeNil writes the single-byte nil descriptor.
func (e *msgpackEncDriver) encodeNil() {
	e.w.writen1(mpNil)
}
// encodeInt writes i using the smallest msgpack integer form that holds it.
// Non-negative values are delegated to encodeUint (so they are written as
// unsigned, matching msgpack-c); negatives use negative fixnum for -32..-1 and
// otherwise int8/int16/int32/int64.
func (e *msgpackEncDriver) encodeInt(i int64) {
	if i >= 0 {
		e.encodeUint(uint64(i))
		return
	}
	if i >= -32 {
		// negative fixnum: the value's low byte is the descriptor itself
		e.w.writen1(byte(i))
		return
	}
	if i >= math.MinInt8 {
		e.w.writen2(mpInt8, byte(i))
		return
	}
	if i >= math.MinInt16 {
		e.w.writen1(mpInt16)
		e.w.writeUint16(uint16(i))
		return
	}
	if i >= math.MinInt32 {
		e.w.writen1(mpInt32)
		e.w.writeUint32(uint32(i))
		return
	}
	e.w.writen1(mpInt64)
	e.w.writeUint64(uint64(i))
}
// encodeUint writes i using the smallest msgpack unsigned form that holds it:
// positive fixnum for 0..127, otherwise uint8/uint16/uint32/uint64.
func (e *msgpackEncDriver) encodeUint(i uint64) {
	if i <= math.MaxInt8 {
		// positive fixnum: the value is the descriptor byte
		e.w.writen1(byte(i))
		return
	}
	if i <= math.MaxUint8 {
		e.w.writen2(mpUint8, byte(i))
		return
	}
	if i <= math.MaxUint16 {
		e.w.writen1(mpUint16)
		e.w.writeUint16(uint16(i))
		return
	}
	if i <= math.MaxUint32 {
		e.w.writen1(mpUint32)
		e.w.writeUint32(uint32(i))
		return
	}
	e.w.writen1(mpUint64)
	e.w.writeUint64(i)
}
// encodeBool writes the single-byte true or false descriptor.
func (e *msgpackEncDriver) encodeBool(b bool) {
	if !b {
		e.w.writen1(mpFalse)
		return
	}
	e.w.writen1(mpTrue)
}
// encodeFloat32 writes the float descriptor followed by the IEEE 754 bit
// pattern of f as a 32-bit unsigned integer.
func (e *msgpackEncDriver) encodeFloat32(f float32) {
	e.w.writen1(mpFloat)
	e.w.writeUint32(math.Float32bits(f))
}
// encodeFloat64 writes the double descriptor followed by the IEEE 754 bit
// pattern of f as a 64-bit unsigned integer.
func (e *msgpackEncDriver) encodeFloat64(f float64) {
	e.w.writen1(mpDouble)
	e.w.writeUint64(math.Float64bits(f))
}
// encodeExtPreamble writes the header of an extension value with tag xtag and
// a payload of l bytes. Payloads of exactly 1, 2, 4, 8 or 16 bytes use the
// fixext descriptors (descriptor, tag); all others use ext8/ext16/ext32
// (descriptor, length, tag), chosen by the size of l.
func (e *msgpackEncDriver) encodeExtPreamble(xtag byte, l int) {
	// Fixed-size forms carry no explicit length.
	switch l {
	case 1:
		e.w.writen2(mpFixExt1, xtag)
		return
	case 2:
		e.w.writen2(mpFixExt2, xtag)
		return
	case 4:
		e.w.writen2(mpFixExt4, xtag)
		return
	case 8:
		e.w.writen2(mpFixExt8, xtag)
		return
	case 16:
		e.w.writen2(mpFixExt16, xtag)
		return
	}
	// Variable-size forms: descriptor, then length, then the tag.
	if l < 256 {
		e.w.writen2(mpExt8, byte(l))
	} else if l < 65536 {
		e.w.writen1(mpExt16)
		e.w.writeUint16(uint16(l))
	} else {
		e.w.writen1(mpExt32)
		e.w.writeUint32(uint32(l))
	}
	e.w.writen1(xtag)
}
// encodeArrayPreamble writes the array header for length elements.
func (e *msgpackEncDriver) encodeArrayPreamble(length int) {
	e.writeContainerLen(msgpackContainerList, length)
}
// encodeMapPreamble writes the map header for length entries.
func (e *msgpackEncDriver) encodeMapPreamble(length int) {
	e.writeContainerLen(msgpackContainerMap, length)
}
// encodeString writes s with a length header. Raw (c_RAW) strings use the bin
// family when the handle's WriteExt is set; everything else uses the str
// family. No payload is written for an empty string.
func (e *msgpackEncDriver) encodeString(c charEncoding, s string) {
	ct := msgpackContainerStr
	if c == c_RAW && e.h.WriteExt {
		ct = msgpackContainerBin
	}
	e.writeContainerLen(ct, len(s))
	if s != "" {
		e.w.writestr(s)
	}
}
// encodeSymbol writes v as an ordinary UTF-8 string; symbols get no special
// encoding here.
func (e *msgpackEncDriver) encodeSymbol(v string) {
	e.encodeString(c_UTF8, v)
}
// encodeStringBytes writes bs with a length header. Raw (c_RAW) bytes use the
// bin family when the handle's WriteExt is set; everything else uses the str
// family. No payload is written for an empty slice.
func (e *msgpackEncDriver) encodeStringBytes(c charEncoding, bs []byte) {
	ct := msgpackContainerStr
	if c == c_RAW && e.h.WriteExt {
		ct = msgpackContainerBin
	}
	e.writeContainerLen(ct, len(bs))
	if len(bs) != 0 {
		e.w.writeb(bs)
	}
}
// writeContainerLen writes the length header for a container of family ct
// holding l items, preferring the shortest available form: the fix form, then
// the 8-bit form (only if the family has one and either has8Always or the
// handle's WriteExt is set), then 16-bit, then 32-bit.
func (e *msgpackEncDriver) writeContainerLen(ct msgpackContainerType, l int) {
	if ct.hasFixMin && l < ct.fixCutoff {
		// fix form: length lives in the low bits of the descriptor
		e.w.writen1(ct.bFixMin | byte(l))
		return
	}
	if ct.has8 && l < 256 && (ct.has8Always || e.h.WriteExt) {
		e.w.writen2(ct.b8, uint8(l))
		return
	}
	if l < 65536 {
		e.w.writen1(ct.b16)
		e.w.writeUint16(uint16(l))
		return
	}
	e.w.writen1(ct.b32)
	e.w.writeUint32(uint32(l))
}
//---------------------------------------------
// msgpackDecDriver reads msgpack-encoded values from r, consulting h for
// handle options such as RawToString.
type msgpackDecDriver struct {
	r decReader
	h *MsgpackHandle
	// bd is the current descriptor byte, valid while bdRead is true.
	bd byte
	// bdRead is set by initReadNext and cleared once the value described by
	// bd has been consumed.
	bdRead bool
	// bdType caches the result of currentEncodedType for bd; initReadNext
	// resets it to valueTypeUnset.
	bdType valueType
}
// isBuiltinType always reports false: msgpack has no builtin types.
func (d *msgpackDecDriver) isBuiltinType(rt uintptr) bool {
	//no builtin types. All encodings are based on kinds. Types supported as extensions.
	return false
}
// decodeBuiltin is a no-op, since isBuiltinType never reports true.
func (d *msgpackDecDriver) decodeBuiltin(rt uintptr, v interface{}) {}
// decodeNaked decodes the next value without a target type to guide it.
//
// Note: This returns either a primitive (int, bool, etc) for non-containers,
// or a containerType, or a specific type denoting nil or extension.
// It is called when a nil interface{} is passed, leaving it up to the DecDriver
// to introspect the stream and decide how best to decode.
// It deciphers the value by looking at the stream first.
//
// Scalars and extensions are fully read here and returned in v. For strings,
// bytes, arrays and maps only vt is determined (strings/bytes also get a
// pointer to an empty destination in v) and decodeFurther is true, meaning the
// descriptor is left pending so the caller can decode the contents.
func (d *msgpackDecDriver) decodeNaked() (v interface{}, vt valueType, decodeFurther bool) {
	d.initReadNext()
	bd := d.bd
	switch bd {
	case mpNil:
		vt = valueTypeNil
		d.bdRead = false
	case mpFalse:
		vt = valueTypeBool
		v = false
	case mpTrue:
		vt = valueTypeBool
		v = true
	case mpFloat:
		// float32 values are widened to float64
		vt = valueTypeFloat
		v = float64(math.Float32frombits(d.r.readUint32()))
	case mpDouble:
		vt = valueTypeFloat
		v = math.Float64frombits(d.r.readUint64())
	case mpUint8:
		vt = valueTypeUint
		v = uint64(d.r.readn1())
	case mpUint16:
		vt = valueTypeUint
		v = uint64(d.r.readUint16())
	case mpUint32:
		vt = valueTypeUint
		v = uint64(d.r.readUint32())
	case mpUint64:
		vt = valueTypeUint
		v = uint64(d.r.readUint64())
	case mpInt8:
		// convert through the sized signed type first so the sign is extended
		vt = valueTypeInt
		v = int64(int8(d.r.readn1()))
	case mpInt16:
		vt = valueTypeInt
		v = int64(int16(d.r.readUint16()))
	case mpInt32:
		vt = valueTypeInt
		v = int64(int32(d.r.readUint32()))
	case mpInt64:
		vt = valueTypeInt
		v = int64(int64(d.r.readUint64()))
	default:
		// descriptors that are ranges rather than single values
		switch {
		case bd >= mpPosFixNumMin && bd <= mpPosFixNumMax:
			// positive fixnum (always signed)
			vt = valueTypeInt
			v = int64(int8(bd))
		case bd >= mpNegFixNumMin && bd <= mpNegFixNumMax:
			// negative fixnum
			vt = valueTypeInt
			v = int64(int8(bd))
		case bd == mpStr8, bd == mpStr16, bd == mpStr32, bd >= mpFixStrMin && bd <= mpFixStrMax:
			// str family: surfaced as string or []byte depending on RawToString
			if d.h.RawToString {
				var rvm string
				vt = valueTypeString
				v = &rvm
			} else {
				var rvm = []byte{}
				vt = valueTypeBytes
				v = &rvm
			}
			decodeFurther = true
		case bd == mpBin8, bd == mpBin16, bd == mpBin32:
			// bin family: always surfaced as []byte
			var rvm = []byte{}
			vt = valueTypeBytes
			v = &rvm
			decodeFurther = true
		case bd == mpArray16, bd == mpArray32, bd >= mpFixArrayMin && bd <= mpFixArrayMax:
			vt = valueTypeArray
			decodeFurther = true
		case bd == mpMap16, bd == mpMap32, bd >= mpFixMapMin && bd <= mpFixMapMax:
			vt = valueTypeMap
			decodeFurther = true
		case bd >= mpFixExt1 && bd <= mpFixExt16, bd >= mpExt8 && bd <= mpExt32:
			// extension: length, then tag, then payload, returned as *RawExt
			clen := d.readExtLen()
			var re RawExt
			re.Tag = d.r.readn1()
			re.Data = d.r.readn(clen)
			v = &re
			vt = valueTypeExt
		default:
			decErr("Nil-Deciphered DecodeValue: %s: hex: %x, dec: %d", msgBadDesc, bd, bd)
		}
	}
	// The value was fully consumed here, so the descriptor is no longer pending.
	if !decodeFurther {
		d.bdRead = false
	}
	return
}
// decodeInt decodes the pending value as a signed integer. Any msgpack integer
// form is accepted: intXXX, uintXXX, or a positive/negative fixnum. When
// bitsize is non-zero, a value that does not fit in bitsize bits is reported
// through decErr. The descriptor is marked consumed on return.
func (d *msgpackDecDriver) decodeInt(bitsize uint8) (i int64) {
	switch bd := d.bd; {
	case bd == mpUint8:
		i = int64(d.r.readn1())
	case bd == mpUint16:
		i = int64(d.r.readUint16())
	case bd == mpUint32:
		i = int64(d.r.readUint32())
	case bd == mpUint64, bd == mpInt64:
		i = int64(d.r.readUint64())
	case bd == mpInt8:
		// go through the sized signed type so the sign is extended
		i = int64(int8(d.r.readn1()))
	case bd == mpInt16:
		i = int64(int16(d.r.readUint16()))
	case bd == mpInt32:
		i = int64(int32(d.r.readUint32()))
	case bd <= mpPosFixNumMax, bd >= mpNegFixNumMin:
		// fixnum: the descriptor byte, read as int8, is the value
		i = int64(int8(bd))
	default:
		decErr("Unhandled single-byte unsigned integer value: %s: %x", msgBadDesc, d.bd)
	}
	// check overflow (logic adapted from std pkg reflect/value.go OverflowInt())
	if bitsize > 0 {
		shift := 64 - bitsize
		if (i<<shift)>>shift != i {
			decErr("Overflow int value: %v", i)
		}
	}
	d.bdRead = false
	return
}
// decodeUint decodes the pending value as an unsigned integer. Any msgpack
// integer form is accepted (uintXXX, intXXX, or a fixnum), but a negative
// value is reported through decErr. When bitsize is non-zero, a value that
// does not fit in bitsize bits is also reported through decErr. The
// descriptor is marked consumed on return.
func (d *msgpackDecDriver) decodeUint(bitsize uint8) (ui uint64) {
	// fromSigned converts a decoded signed value, rejecting negatives.
	fromSigned := func(i int64) uint64 {
		if i < 0 {
			decErr("Assigning negative signed value: %v, to unsigned type", i)
			return 0
		}
		return uint64(i)
	}
	switch d.bd {
	case mpUint8:
		ui = uint64(d.r.readn1())
	case mpUint16:
		ui = uint64(d.r.readUint16())
	case mpUint32:
		ui = uint64(d.r.readUint32())
	case mpUint64:
		ui = d.r.readUint64()
	case mpInt8:
		ui = fromSigned(int64(int8(d.r.readn1())))
	case mpInt16:
		ui = fromSigned(int64(int16(d.r.readUint16())))
	case mpInt32:
		ui = fromSigned(int64(int32(d.r.readUint32())))
	case mpInt64:
		ui = fromSigned(int64(d.r.readUint64()))
	default:
		switch {
		case d.bd >= mpPosFixNumMin && d.bd <= mpPosFixNumMax:
			ui = uint64(d.bd)
		case d.bd >= mpNegFixNumMin && d.bd <= mpNegFixNumMax:
			// Report the actual negative value (descriptor read as int8),
			// not the raw unsigned descriptor byte.
			decErr("Assigning negative signed value: %v, to unsigned type", int8(d.bd))
		default:
			decErr("Unhandled single-byte unsigned integer value: %s: %x", msgBadDesc, d.bd)
		}
	}
	// check overflow (logic adapted from std pkg reflect/value.go OverflowUint()
	if bitsize > 0 {
		if trunc := (ui << (64 - bitsize)) >> (64 - bitsize); ui != trunc {
			decErr("Overflow uint value: %v", ui)
		}
	}
	d.bdRead = false
	return
}
// decodeFloat decodes the pending value as a float64. It accepts the msgpack
// float and double forms; any other descriptor is decoded as an integer via
// decodeInt and converted. chkOverflow32 is forwarded to checkOverflowFloat32.
// The descriptor is marked consumed on return.
func (d *msgpackDecDriver) decodeFloat(chkOverflow32 bool) (f float64) {
	if d.bd == mpFloat {
		f = float64(math.Float32frombits(d.r.readUint32()))
	} else if d.bd == mpDouble {
		f = math.Float64frombits(d.r.readUint64())
	} else {
		f = float64(d.decodeInt(0))
	}
	checkOverflowFloat32(f, chkOverflow32)
	d.bdRead = false
	return f
}
// decodeBool decodes the pending value as a bool. It accepts the msgpack
// true/false descriptors as well as the fixnums 1 and 0; anything else is
// reported through decErr. The descriptor is marked consumed on return.
func (d *msgpackDecDriver) decodeBool() (b bool) {
	switch bd := d.bd; {
	case bd == mpTrue || bd == 1:
		b = true
	case bd == mpFalse || bd == 0:
		// b stays false
	default:
		decErr("Invalid single-byte value for bool: %s: %x", msgBadDesc, d.bd)
	}
	d.bdRead = false
	return b
}
// decodeString reads a str-family length header and then that many bytes,
// returning them as a string. A nil or zero-length value yields "". The
// descriptor is marked consumed on return.
func (d *msgpackDecDriver) decodeString() (s string) {
	if n := d.readContainerLen(msgpackContainerStr); n > 0 {
		s = string(d.r.readn(n))
	}
	d.bdRead = false
	return s
}
// decodeBytes reads a byte string (bin or str family) into bs.
//
// If the encoded length is positive and differs from len(bs), the data is read
// into a re-sliced (when bs is longer) or newly allocated (when bs is shorter)
// slice, which is returned as bsOut with changed=true. If the lengths match,
// the data is read into bs in place and (nil, false) is returned. If the
// stream holds nil or an empty value, bs is left untouched.
// Callers must check if changed=true (to decide whether to replace the one they have).
func (d *msgpackDecDriver) decodeBytes(bs []byte) (bsOut []byte, changed bool) {
	// bytes can be decoded from msgpackContainerStr or msgpackContainerBin
	family := msgpackContainerStr
	if d.bd == mpBin8 || d.bd == mpBin16 || d.bd == mpBin32 {
		family = msgpackContainerBin
	}
	if clen := d.readContainerLen(family); clen > 0 {
		switch {
		case len(bs) > clen:
			bs = bs[:clen]
			bsOut, changed = bs, true
		case len(bs) < clen:
			bs = make([]byte, clen)
			bsOut, changed = bs, true
		}
		d.r.readb(bs)
	}
	d.bdRead = false
	return bsOut, changed
}
// initReadNext loads the next descriptor byte into d.bd unless one is already
// pending, and resets the cached encoded type.
// Every top-level decode func (i.e. decodeValue, decode) must call this first.
func (d *msgpackDecDriver) initReadNext() {
	if !d.bdRead {
		d.bd = d.r.readn1()
		d.bdRead = true
		d.bdType = valueTypeUnset
	}
}
// currentEncodedType classifies the pending descriptor byte into a valueType,
// caching the answer in d.bdType until the next descriptor is read. The str
// family maps to string or bytes depending on the handle's RawToString.
// An unrecognised descriptor is reported through decErr.
func (d *msgpackDecDriver) currentEncodedType() valueType {
	if d.bdType != valueTypeUnset {
		return d.bdType
	}
	bd := d.bd
	switch {
	case bd == mpNil:
		d.bdType = valueTypeNil
	case bd == mpFalse || bd == mpTrue:
		d.bdType = valueTypeBool
	case bd == mpFloat || bd == mpDouble:
		d.bdType = valueTypeFloat
	case bd == mpUint8 || bd == mpUint16 || bd == mpUint32 || bd == mpUint64:
		d.bdType = valueTypeUint
	case bd == mpInt8 || bd == mpInt16 || bd == mpInt32 || bd == mpInt64:
		d.bdType = valueTypeInt
	case bd <= mpPosFixNumMax || bd >= mpNegFixNumMin:
		// positive and negative fixnums are both reported as signed ints
		d.bdType = valueTypeInt
	case bd == mpStr8 || bd == mpStr16 || bd == mpStr32 || (bd >= mpFixStrMin && bd <= mpFixStrMax):
		if d.h.RawToString {
			d.bdType = valueTypeString
		} else {
			d.bdType = valueTypeBytes
		}
	case bd == mpBin8 || bd == mpBin16 || bd == mpBin32:
		d.bdType = valueTypeBytes
	case bd == mpArray16 || bd == mpArray32 || (bd >= mpFixArrayMin && bd <= mpFixArrayMax):
		d.bdType = valueTypeArray
	case bd == mpMap16 || bd == mpMap32 || (bd >= mpFixMapMin && bd <= mpFixMapMax):
		d.bdType = valueTypeMap
	case (bd >= mpFixExt1 && bd <= mpFixExt16) || (bd >= mpExt8 && bd <= mpExt32):
		d.bdType = valueTypeExt
	default:
		decErr("currentEncodedType: Undeciphered descriptor: %s: hex: %x, dec: %d", msgBadDesc, bd, bd)
	}
	return d.bdType
}
// tryDecodeAsNil reports whether the pending descriptor is nil; if so it is
// marked consumed.
func (d *msgpackDecDriver) tryDecodeAsNil() bool {
	isNil := d.bd == mpNil
	if isNil {
		d.bdRead = false
	}
	return isNil
}
// readContainerLen decodes the length header of a container of family ct from
// the pending descriptor (and any length bytes following it). It returns -1
// when the descriptor is nil. A descriptor that does not belong to the family
// is reported through decErr. The descriptor is marked consumed on return.
//
// Only forms the family actually defines are accepted: the 8-bit form is
// matched only when ct.has8 is set (list/map leave b8 as 0, which must not
// match the fixnum 0x00), and the fix form is matched by range
// [bFixMin, bFixMin+fixCutoff) rather than by bit mask, so descriptors of
// unrelated types that merely share bits with bFixMin are rejected.
func (d *msgpackDecDriver) readContainerLen(ct msgpackContainerType) (clen int) {
	bd := d.bd
	switch {
	case bd == mpNil:
		clen = -1 // to represent nil
	case ct.has8 && bd == ct.b8:
		clen = int(d.r.readn1())
	case bd == ct.b16:
		clen = int(d.r.readUint16())
	case bd == ct.b32:
		clen = int(d.r.readUint32())
	case ct.hasFixMin && bd >= ct.bFixMin && int(bd-ct.bFixMin) < ct.fixCutoff:
		// fix form: the length is the offset from the start of the range
		clen = int(bd - ct.bFixMin)
	default:
		decErr("readContainerLen: %s: hex: %x, dec: %d", msgBadDesc, bd, bd)
	}
	d.bdRead = false
	return
}
// readMapLen reads a map header and returns its length (-1 for nil, per
// readContainerLen).
func (d *msgpackDecDriver) readMapLen() int {
	return d.readContainerLen(msgpackContainerMap)
}
// readArrayLen reads an array header and returns its length (-1 for nil, per
// readContainerLen).
func (d *msgpackDecDriver) readArrayLen() int {
	return d.readContainerLen(msgpackContainerList)
}
// readExtLen returns the payload length of the pending extension value:
// the implied size for fixext descriptors, the length read from the stream for
// ext8/ext16/ext32, and -1 for nil. Any other descriptor is reported through
// decErr. Unlike readContainerLen it does not mark the descriptor consumed.
func (d *msgpackDecDriver) readExtLen() (clen int) {
	switch d.bd {
	case mpNil:
		return -1 // to represent nil
	case mpFixExt1:
		return 1
	case mpFixExt2:
		return 2
	case mpFixExt4:
		return 4
	case mpFixExt8:
		return 8
	case mpFixExt16:
		return 16
	case mpExt8:
		return int(d.r.readn1())
	case mpExt16:
		return int(d.r.readUint16())
	case mpExt32:
		return int(d.r.readUint32())
	}
	decErr("decoding ext bytes: found unexpected byte: %x", d.bd)
	return
}
// decodeExt decodes an extension value and returns its tag and payload.
//
// If the pending value is a bin or str (as produced when extensions are
// serialized as raw bytes), only the payload is returned and xtag is left 0.
// Otherwise the extension header is read; when verifyTag is set and the tag in
// the stream differs from tag, the mismatch is reported through decErr.
// The descriptor is marked consumed on return.
func (d *msgpackDecDriver) decodeExt(verifyTag bool, tag byte) (xtag byte, xbs []byte) {
	xbd := d.bd
	switch {
	case xbd == mpBin8, xbd == mpBin16, xbd == mpBin32:
		xbs, _ = d.decodeBytes(nil)
	case xbd == mpStr8, xbd == mpStr16, xbd == mpStr32,
		xbd >= mpFixStrMin && xbd <= mpFixStrMax:
		xbs = []byte(d.decodeString())
	default:
		clen := d.readExtLen()
		xtag = d.r.readn1()
		if verifyTag && xtag != tag {
			// Print both tags with the same verb so they can be compared.
			decErr("Wrong extension tag. Got %v. Expecting: %v", xtag, tag)
		}
		xbs = d.r.readn(clen)
	}
	d.bdRead = false
	return
}
//--------------------------------------------------
// MsgpackHandle is a Handle for the Msgpack Schema-Free Encoding Format.
type MsgpackHandle struct {
	BasicHandle
	// RawToString controls how raw bytes are decoded into a nil interface{}:
	// as a string when true, as a []byte when false.
	RawToString bool
	// WriteExt flag supports encoding configured extensions with extension tags.
	// It also controls whether other elements of the new spec are encoded (ie Str8).
	//
	// With WriteExt=false, configured extensions are serialized as raw bytes
	// and Str8 is not encoded.
	//
	// A stream can still be decoded into a typed value, provided an appropriate value
	// is provided, but the type cannot be inferred from the stream. If no appropriate
	// type is provided (e.g. decoding into a nil interface{}), you get back
	// a []byte or string based on the setting of RawToString.
	WriteExt bool
}
// newEncDriver returns a msgpack encoding driver that writes to w.
func (h *MsgpackHandle) newEncDriver(w encWriter) encDriver {
	return &msgpackEncDriver{w: w, h: h}
}
// newDecDriver returns a msgpack decoding driver that reads from r.
func (h *MsgpackHandle) newDecDriver(r decReader) decDriver {
	return &msgpackDecDriver{r: r, h: h}
}
// writeExt reports the handle's WriteExt setting.
func (h *MsgpackHandle) writeExt() bool {
	return h.WriteExt
}
// getBasicHandle returns the embedded BasicHandle.
func (h *MsgpackHandle) getBasicHandle() *BasicHandle {
	return &h.BasicHandle
}
//--------------------------------------------------
// msgpackSpecRpcCodec is the client/server codec for the msgpack-rpc wire
// format; it builds on the shared rpcCodec plumbing.
type msgpackSpecRpcCodec struct {
	rpcCodec
}
// /////////////// Spec RPC Codec ///////////////////
// WriteRequest encodes a request as the four-element array
// [0, msgid, method, params] and flushes it.
//
// A body of type MsgpackSpecRpcMultiArgs is sent as the params array as-is;
// any other body is wrapped in a one-element array. The sequence number is
// truncated to 32 bits.
func (c *msgpackSpecRpcCodec) WriteRequest(r *rpc.Request, body interface{}) error {
	// Must protect for concurrent access as per API (same as goRpcCodec).
	c.mu.Lock()
	defer c.mu.Unlock()
	// WriteRequest can write to both a Go service, and other services that do
	// not abide by the 1 argument rule of a Go service.
	// We discriminate based on if the body is a MsgpackSpecRpcMultiArgs
	var bodyArr []interface{}
	if m, ok := body.(MsgpackSpecRpcMultiArgs); ok {
		bodyArr = ([]interface{})(m)
	} else {
		bodyArr = []interface{}{body}
	}
	r2 := []interface{}{0, uint32(r.Seq), r.ServiceMethod, bodyArr}
	return c.write(r2, nil, false, true)
}
// WriteResponse encodes a response as the four-element array
// [1, msgid, error, result] and flushes it.
//
// When r.Error is non-empty it is sent as the error element and the result is
// forced to nil; otherwise the error element is nil and body is the result.
// The sequence number is truncated to 32 bits.
func (c *msgpackSpecRpcCodec) WriteResponse(r *rpc.Response, body interface{}) error {
	// Must protect for concurrent access as per API (same as goRpcCodec).
	c.mu.Lock()
	defer c.mu.Unlock()
	var moe interface{}
	if r.Error != "" {
		moe = r.Error
		// an error response carries no result
		body = nil
	}
	r2 := []interface{}{1, uint32(r.Seq), moe, body}
	return c.write(r2, nil, false, true)
}
// ReadResponseHeader reads a response header (message type 1), filling in
// r.Seq and r.Error.
func (c *msgpackSpecRpcCodec) ReadResponseHeader(r *rpc.Response) error {
	return c.parseCustomHeader(1, &r.Seq, &r.Error)
}
// ReadRequestHeader reads a request header (message type 0), filling in
// r.Seq and r.ServiceMethod.
func (c *msgpackSpecRpcCodec) ReadRequestHeader(r *rpc.Request) error {
	return c.parseCustomHeader(0, &r.Seq, &r.ServiceMethod)
}
// ReadRequestBody decodes the request's params array. A non-nil body is
// decoded as the single element of that array; a nil body causes the params
// to be read and discarded.
func (c *msgpackSpecRpcCodec) ReadRequestBody(body interface{}) error {
	if body != nil {
		// params arrive as an array, so decode into a one-element slice
		// whose element is the caller's target
		wrapped := []interface{}{body}
		return c.read(&wrapped)
	}
	return c.read(nil) // read and discard
}
// parseCustomHeader reads the first three elements of a msgpack-rpc message:
// the four-item array descriptor, the message type (which must equal
// expectTypeByte), the message id (into msgid) and the method name or error
// string (into methodOrError). The fourth element, the body, is left in the
// stream so it can be decoded separately later.
//
// It returns io.EOF if the codec is closed, and otherwise the first read or
// validation error encountered.
func (c *msgpackSpecRpcCodec) parseCustomHeader(expectTypeByte byte, msgid *uint64, methodOrError *string) (err error) {
	if c.cls {
		return io.EOF
	}
	const fia byte = 0x94 //four item array descriptor value
	// The array descriptor is read straight from the buffered reader rather
	// than through the decoder, so that an EOF comes back as a plain error.
	b, rerr := c.br.ReadByte()
	if rerr != nil {
		return rerr
	}
	if b != fia {
		return fmt.Errorf("Unexpected value for array descriptor: Expecting %v. Received %v", fia, b)
	}
	// b is reused as the decode target for the message-type element.
	if err = c.read(&b); err != nil {
		return err
	}
	if b != expectTypeByte {
		return fmt.Errorf("Unexpected byte descriptor in header. Expecting %v. Received %v", expectTypeByte, b)
	}
	if err = c.read(msgid); err != nil {
		return err
	}
	return c.read(methodOrError)
}
//--------------------------------------------------
// msgpackSpecRpc is the implementation of Rpc that uses custom communication protocol
// as defined in the msgpack spec at https://github.com/msgpack-rpc/msgpack-rpc/blob/master/spec.md
type msgpackSpecRpc struct{}
// MsgpackSpecRpc implements Rpc using the communication protocol defined in
// the msgpack spec at https://github.com/msgpack-rpc/msgpack-rpc/blob/master/spec.md .
// Its methods (ServerCodec and ClientCodec) return values that implement RpcCodecBuffered.
var MsgpackSpecRpc msgpackSpecRpc
// ServerCodec returns a msgpack-rpc server codec over conn using handle h.
func (x msgpackSpecRpc) ServerCodec(conn io.ReadWriteCloser, h Handle) rpc.ServerCodec {
	return &msgpackSpecRpcCodec{newRPCCodec(conn, h)}
}
// ClientCodec returns a msgpack-rpc client codec over conn using handle h.
func (x msgpackSpecRpc) ClientCodec(conn io.ReadWriteCloser, h Handle) rpc.ClientCodec {
	return &msgpackSpecRpcCodec{newRPCCodec(conn, h)}
}
// Compile-time checks that the msgpack drivers satisfy the driver interfaces.
var _ decDriver = (*msgpackDecDriver)(nil)
var _ encDriver = (*msgpackEncDriver)(nil)

View file

@ -0,0 +1,110 @@
#!/usr/bin/env python
# This will create golden files in a directory passed to it.
# A Test calls this internally to create the golden files
# So it can process them (so we don't have to checkin the files).
import msgpack, msgpackrpc, sys, os, threading
def get_test_data_list():
    """Return the ordered list of values used to build the golden files.

    The result is every primitive sample, followed by the list of primitives
    itself as a single element, followed by the map samples. The order matters:
    build_test_data names each golden file after the value's index.
    """
    # all primitive types
    primitives = [
        -8,
        -1616,
        -32323232,
        -6464646464646464,
        192,
        1616,
        32323232,
        6464646464646464,
        192,
        -3232.0,
        -6464646464.0,
        3232.0,
        6464646464.0,
        False,
        True,
        None,
        "someday",
        "",
        "bytestring",
        1328176922000002000,
        -2206187877999998000,
        0,
        -6795364578871345152,
    ]
    # combo (map) types
    maps = [
        {"true": True,
         "false": False},
        {"true": "True",
         "false": False,
         "uint16(1616)": 1616},
        {"list": [1616, 32323232, True, -3232.0, {"TRUE": True, "FALSE": False}, [True, False]],
         "int32": 32323232, "bool": True,
         "LONG STRING": "123456789012345678901234567890123456789012345678901234567890",
         "SHORT STRING": "1234567890"},
        {True: "true", 8: False, "false": 0},
    ]
    return primitives + [primitives] + maps
def build_test_data(destdir):
    """Write one msgpack-encoded golden file per test value into destdir.

    The file for the value at index i of get_test_data_list() is named
    "<i>.golden". Each value is packed with a fresh msgpack.Packer.
    """
    for i, value in enumerate(get_test_data_list()):
        serialized = msgpack.Packer().pack(value)
        # "with" closes the file even if the write fails
        with open(os.path.join(destdir, str(i) + '.golden'), 'wb') as f:
            f.write(serialized)
def doRpcServer(port, stopTimeSec):
    """Run a msgpack-rpc echo server on localhost:port.

    If stopTimeSec is greater than 0, a timer stops the server after that many
    seconds; otherwise server.start() is left to run until stopped elsewhere.
    """
    class EchoHandler(object):
        def Echo123(self, msg1, msg2, msg3):
            return ("1:%s 2:%s 3:%s" % (msg1, msg2, msg3))

        def EchoStruct(self, msg):
            return ("%s" % msg)

    server = msgpackrpc.Server(EchoHandler())
    server.listen(msgpackrpc.Address('localhost', port))
    # schedule the shutdown before entering the serving loop
    if stopTimeSec > 0:
        threading.Timer(stopTimeSec, server.stop).start()
    server.start()
def doRpcClientToPythonSvc(port):
    """Call the Python echo service on localhost:port and print each reply.

    Echo123 is called with three separate arguments and EchoStruct with one
    map argument.
    """
    address = msgpackrpc.Address('localhost', port)
    client = msgpackrpc.Client(address, unpack_encoding='utf-8')
    # print(...) with a single argument behaves the same as the print
    # statement and matches the form already used in doMain
    print(client.call("Echo123", "A1", "B2", "C3"))
    print(client.call("EchoStruct", {"A" :"Aa", "B":"Bb", "C":"Cc"}))
def doRpcClientToGoSvc(port):
    """Call the Go TestRpcInt service on localhost:port and print each reply.

    Unlike the Python service, Echo123 receives its three values as one list
    argument.
    """
    address = msgpackrpc.Address('localhost', port)
    client = msgpackrpc.Client(address, unpack_encoding='utf-8')
    # print(...) with a single argument behaves the same as the print
    # statement and matches the form already used in doMain
    print(client.call("TestRpcInt.Echo123", ["A1", "B2", "C3"]))
    print(client.call("TestRpcInt.EchoStruct", {"A" :"Aa", "B":"Bb", "C":"Cc"}))
def doMain(args):
    """Dispatch on the sub-command in args[0]; print usage if nothing matches.

    Sub-commands: "testdata <dir>", "rpc-server <port> <stopTimeSec>",
    "rpc-client-python-service <port>", "rpc-client-go-service <port>".
    """
    argc = len(args)
    cmd = args[0] if args else None
    if argc == 2 and cmd == "testdata":
        build_test_data(args[1])
        return
    if argc == 3 and cmd == "rpc-server":
        doRpcServer(int(args[1]), int(args[2]))
        return
    if argc == 2 and cmd == "rpc-client-python-service":
        doRpcClientToPythonSvc(int(args[1]))
        return
    if argc == 2 and cmd == "rpc-client-go-service":
        doRpcClientToGoSvc(int(args[1]))
        return
    print("Usage: msgpack_test.py " +
          "[testdata|rpc-server|rpc-client-python-service|rpc-client-go-service] ...")
# Run the command-line dispatcher only when executed as a script.
if __name__ == "__main__":
    doMain(sys.argv[1:])

View file

@ -0,0 +1,152 @@
// Copyright (c) 2012, 2013 Ugorji Nwoke. All rights reserved.
// Use of this source code is governed by a BSD-style license found in the LICENSE file.
package codec
import (
"bufio"
"io"
"net/rpc"
"sync"
)
// Rpc provides a rpc Server or Client Codec for rpc communication.
// Both methods wrap conn and encode/decode using the given Handle.
type Rpc interface {
	ServerCodec(conn io.ReadWriteCloser, h Handle) rpc.ServerCodec
	ClientCodec(conn io.ReadWriteCloser, h Handle) rpc.ClientCodec
}
// RpcCodecBuffered allows access to the underlying bufio.Reader/Writer
// used by the rpc connection. It accommodates use-cases where the connection
// should be used by rpc and non-rpc functions, e.g. streaming a file after
// sending an rpc response.
type RpcCodecBuffered interface {
	BufferedReader() *bufio.Reader
	BufferedWriter() *bufio.Writer
}
// -------------------------------------
// rpcCodec defines the struct members and common methods.
type rpcCodec struct {
	// rwc is the underlying connection; Close closes it.
	rwc io.ReadWriteCloser
	// dec and enc decode from br and encode to bw respectively.
	dec *Decoder
	enc *Encoder
	// bw and br buffer the connection.
	bw *bufio.Writer
	br *bufio.Reader
	// mu is locked by goRpcCodec's WriteRequest/WriteResponse.
	mu sync.Mutex
	// cls is set once Close has been called; reads and writes then
	// return io.EOF.
	cls bool
}
// newRPCCodec wraps conn in a buffered reader and writer and builds an
// encoder/decoder pair on top of them using handle h.
func newRPCCodec(conn io.ReadWriteCloser, h Handle) rpcCodec {
	writer, reader := bufio.NewWriter(conn), bufio.NewReader(conn)
	return rpcCodec{
		rwc: conn,
		bw:  writer,
		br:  reader,
		enc: NewEncoder(writer, h),
		dec: NewDecoder(reader, h),
	}
}
// BufferedReader returns the buffered reader the decoder reads from.
func (c *rpcCodec) BufferedReader() *bufio.Reader {
	return c.br
}
// BufferedWriter returns the buffered writer the encoder writes to.
func (c *rpcCodec) BufferedWriter() *bufio.Writer {
	return c.bw
}
// write encodes obj1, then obj2 if writeObj2 is set, and finally flushes the
// buffered writer if doFlush is set. It returns io.EOF if the codec has been
// closed, and otherwise the first encode or flush error.
func (c *rpcCodec) write(obj1, obj2 interface{}, writeObj2, doFlush bool) (err error) {
	if c.cls {
		return io.EOF
	}
	if err = c.enc.Encode(obj1); err == nil && writeObj2 {
		err = c.enc.Encode(obj2)
	}
	if err != nil {
		return err
	}
	if doFlush && c.bw != nil {
		return c.bw.Flush()
	}
	return nil
}
// read decodes the next value from the stream into obj. If obj is nil the
// value is still decoded, into a throwaway interface{}, so that the stream
// advances. It returns io.EOF if the codec has been closed.
func (c *rpcCodec) read(obj interface{}) (err error) {
	if c.cls {
		return io.EOF
	}
	target := obj
	if target == nil {
		// nothing to decode into: consume the value and discard it
		var sink interface{}
		target = &sink
	}
	return c.dec.Decode(target)
}
// Close marks the codec closed and closes the underlying connection. A second
// call returns io.EOF without touching the connection.
func (c *rpcCodec) Close() error {
	if !c.cls {
		c.cls = true
		return c.rwc.Close()
	}
	return io.EOF
}
// ReadResponseBody decodes the response body into body; a nil body reads and
// discards it (see read).
func (c *rpcCodec) ReadResponseBody(body interface{}) error {
	return c.read(body)
}
// -------------------------------------
// goRpcCodec is the client/server codec for the net/rpc wire protocol: each
// message is a header value followed by a body value.
type goRpcCodec struct {
	rpcCodec
}
// WriteRequest encodes the request header r followed by body and flushes,
// holding the codec mutex for the duration.
func (c *goRpcCodec) WriteRequest(r *rpc.Request, body interface{}) error {
	// Must protect for concurrent access as per API
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.write(r, body, true, true)
}
// WriteResponse encodes the response header r followed by body and flushes,
// holding the codec mutex for the duration.
func (c *goRpcCodec) WriteResponse(r *rpc.Response, body interface{}) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.write(r, body, true, true)
}
// ReadResponseHeader decodes the next value in the stream into r.
func (c *goRpcCodec) ReadResponseHeader(r *rpc.Response) error {
	return c.read(r)
}
// ReadRequestHeader decodes the next value in the stream into r.
func (c *goRpcCodec) ReadRequestHeader(r *rpc.Request) error {
	return c.read(r)
}
// ReadRequestBody decodes the request body into body; a nil body reads and
// discards it (see read).
func (c *goRpcCodec) ReadRequestBody(body interface{}) error {
	return c.read(body)
}
// -------------------------------------
// goRpc is the implementation of Rpc that uses the communication protocol
// as defined in net/rpc package.
type goRpc struct{}
// GoRpc implements Rpc using the communication protocol defined in net/rpc package.
// Its methods (ServerCodec and ClientCodec) return values that implement RpcCodecBuffered.
var GoRpc goRpc
// ServerCodec returns a net/rpc-protocol server codec over conn using handle h.
func (x goRpc) ServerCodec(conn io.ReadWriteCloser, h Handle) rpc.ServerCodec {
	return &goRpcCodec{newRPCCodec(conn, h)}
}
// ClientCodec returns a net/rpc-protocol client codec over conn using handle h.
func (x goRpc) ClientCodec(conn io.ReadWriteCloser, h Handle) rpc.ClientCodec {
	return &goRpcCodec{newRPCCodec(conn, h)}
}
// Compile-time check that *rpcCodec implements RpcCodecBuffered.
var _ RpcCodecBuffered = (*rpcCodec)(nil) // ensure *rpcCodec implements RpcCodecBuffered

View file

@ -0,0 +1,461 @@
// Copyright (c) 2012, 2013 Ugorji Nwoke. All rights reserved.
// Use of this source code is governed by a BSD-style license found in the LICENSE file.
package codec
import "math"
// Descriptor bytes of the "simple" format. Each encoded value starts with one
// of these.
const (
	_ uint8 = iota
	simpleVdNil = 1
	simpleVdFalse = 2
	simpleVdTrue = 3
	simpleVdFloat32 = 4
	simpleVdFloat64 = 5
	// integers: each base covers 4 descriptors (n, n+1, n+2, n+3), selecting
	// a 1-, 2-, 4- or 8-byte value (see encUint)
	simpleVdPosInt = 8
	simpleVdNegInt = 12
	// containers: bases are spaced 8 apart (n ... n+7); encLen uses
	// n (zero length) and n+1 ... n+4 for 1-, 2-, 4- and 8-byte lengths
	simpleVdString = 216
	simpleVdByteArray = 224
	simpleVdArray = 232
	simpleVdMap = 240
	simpleVdExt = 248
)
// simpleEncDriver writes values in the "simple" format to w.
type simpleEncDriver struct {
	h *SimpleHandle
	w encWriter
	//b [8]byte
}
// isBuiltinType always reports false: the simple format has no builtin types.
func (e *simpleEncDriver) isBuiltinType(rt uintptr) bool {
	return false
}
// encodeBuiltin is a no-op, since isBuiltinType never reports true.
func (e *simpleEncDriver) encodeBuiltin(rt uintptr, v interface{}) {
}
// encodeNil writes the single-byte nil descriptor.
func (e *simpleEncDriver) encodeNil() {
	e.w.writen1(simpleVdNil)
}
// encodeBool writes the single-byte true or false descriptor.
func (e *simpleEncDriver) encodeBool(b bool) {
	if !b {
		e.w.writen1(simpleVdFalse)
		return
	}
	e.w.writen1(simpleVdTrue)
}
// encodeFloat32 writes the float32 descriptor followed by the IEEE 754 bit
// pattern of f as a 32-bit unsigned integer.
func (e *simpleEncDriver) encodeFloat32(f float32) {
	e.w.writen1(simpleVdFloat32)
	e.w.writeUint32(math.Float32bits(f))
}
// encodeFloat64 writes the float64 descriptor followed by the IEEE 754 bit
// pattern of f as a 64-bit unsigned integer.
func (e *simpleEncDriver) encodeFloat64(f float64) {
	e.w.writen1(simpleVdFloat64)
	e.w.writeUint64(math.Float64bits(f))
}
// encodeInt writes v as sign plus magnitude: negative values use the
// negative-int descriptor base with the magnitude -v, all others the
// positive-int base with v itself.
func (e *simpleEncDriver) encodeInt(v int64) {
	bd, magnitude := uint8(simpleVdPosInt), uint64(v)
	if v < 0 {
		bd, magnitude = simpleVdNegInt, uint64(-v)
	}
	e.encUint(magnitude, bd)
}
// encodeUint writes v using the positive-int descriptor base.
func (e *simpleEncDriver) encodeUint(v uint64) {
	e.encUint(v, simpleVdPosInt)
}
// encUint writes v in the narrowest of 1, 2, 4 or 8 bytes, preceded by the
// descriptor bd, bd+1, bd+2 or bd+3 respectively.
func (e *simpleEncDriver) encUint(v uint64, bd uint8) {
	switch {
	case v > math.MaxUint32:
		e.w.writen1(bd + 3)
		e.w.writeUint64(v)
	case v > math.MaxUint16:
		e.w.writen1(bd + 2)
		e.w.writeUint32(uint32(v))
	case v > math.MaxUint8:
		e.w.writen1(bd + 1)
		e.w.writeUint16(uint16(v))
	default:
		e.w.writen2(bd, uint8(v))
	}
}
// encLen writes a container length header. A zero length is the bare
// descriptor bd; otherwise the descriptor is bd+1, bd+2, bd+3 or bd+4 followed
// by the length in 1, 2, 4 or 8 bytes respectively, choosing the narrowest
// width that holds it.
func (e *simpleEncDriver) encLen(bd byte, length int) {
	if length == 0 {
		e.w.writen1(bd)
		return
	}
	if length <= math.MaxUint8 {
		e.w.writen1(bd + 1)
		e.w.writen1(uint8(length))
		return
	}
	if length <= math.MaxUint16 {
		e.w.writen1(bd + 2)
		e.w.writeUint16(uint16(length))
		return
	}
	// compare as int64 so the constant fits even where int is 32 bits
	if int64(length) <= math.MaxUint32 {
		e.w.writen1(bd + 3)
		e.w.writeUint32(uint32(length))
		return
	}
	e.w.writen1(bd + 4)
	e.w.writeUint64(uint64(length))
}
// encodeExtPreamble writes an extension header: the length header followed by
// the tag byte.
func (e *simpleEncDriver) encodeExtPreamble(xtag byte, length int) {
	e.encLen(simpleVdExt, length)
	e.w.writen1(xtag)
}
// encodeArrayPreamble writes the array header for length elements.
func (e *simpleEncDriver) encodeArrayPreamble(length int) {
	e.encLen(simpleVdArray, length)
}
// encodeMapPreamble writes the map header for length entries.
func (e *simpleEncDriver) encodeMapPreamble(length int) {
	e.encLen(simpleVdMap, length)
}
// encodeString writes a string length header followed by the bytes of v.
// The character encoding c is not used.
func (e *simpleEncDriver) encodeString(c charEncoding, v string) {
	e.encLen(simpleVdString, len(v))
	e.w.writestr(v)
}
// encodeSymbol writes v as an ordinary UTF-8 string.
func (e *simpleEncDriver) encodeSymbol(v string) {
	e.encodeString(c_UTF8, v)
}
// encodeStringBytes writes a byte-array length header followed by v.
// The character encoding c is not used.
func (e *simpleEncDriver) encodeStringBytes(c charEncoding, v []byte) {
	e.encLen(simpleVdByteArray, len(v))
	e.w.writeb(v)
}
//------------------------------------
// simpleDecDriver reads values in the "simple" format from r.
type simpleDecDriver struct {
	h *SimpleHandle
	r decReader
	// bdRead is set by initReadNext and cleared once the value described by
	// bd has been consumed.
	bdRead bool
	// bdType caches the result of currentEncodedType for bd; initReadNext
	// resets it to valueTypeUnset.
	bdType valueType
	// bd is the current descriptor byte, valid while bdRead is true.
	bd byte
	//b [8]byte
}
// initReadNext loads the next descriptor byte into d.bd unless one is already
// pending, and resets the cached encoded type.
func (d *simpleDecDriver) initReadNext() {
	if !d.bdRead {
		d.bd = d.r.readn1()
		d.bdRead = true
		d.bdType = valueTypeUnset
	}
}
// currentEncodedType classifies the pending descriptor byte into a valueType,
// caching the answer in d.bdType until the next descriptor is read. Integer
// bases span 4 descriptors (n..n+3) and container bases 5 (n..n+4). An
// unrecognised descriptor is reported through decErr.
func (d *simpleDecDriver) currentEncodedType() valueType {
	if d.bdType != valueTypeUnset {
		return d.bdType
	}
	bd := d.bd
	switch {
	case bd == simpleVdNil:
		d.bdType = valueTypeNil
	case bd == simpleVdTrue || bd == simpleVdFalse:
		d.bdType = valueTypeBool
	case bd >= simpleVdPosInt && bd <= simpleVdPosInt+3:
		d.bdType = valueTypeUint
	case bd >= simpleVdNegInt && bd <= simpleVdNegInt+3:
		d.bdType = valueTypeInt
	case bd == simpleVdFloat32 || bd == simpleVdFloat64:
		d.bdType = valueTypeFloat
	case bd >= simpleVdString && bd <= simpleVdString+4:
		d.bdType = valueTypeString
	case bd >= simpleVdByteArray && bd <= simpleVdByteArray+4:
		d.bdType = valueTypeBytes
	case bd >= simpleVdExt && bd <= simpleVdExt+4:
		d.bdType = valueTypeExt
	case bd >= simpleVdArray && bd <= simpleVdArray+4:
		d.bdType = valueTypeArray
	case bd >= simpleVdMap && bd <= simpleVdMap+4:
		d.bdType = valueTypeMap
	default:
		decErr("currentEncodedType: Unrecognized d.vd: 0x%x", d.bd)
	}
	return d.bdType
}
// tryDecodeAsNil reports whether the current descriptor byte is the nil
// descriptor. When it is, the descriptor is marked as consumed.
func (d *simpleDecDriver) tryDecodeAsNil() bool {
	isNil := d.bd == simpleVdNil
	if isNil {
		d.bdRead = false
	}
	return isNil
}
// isBuiltinType always returns false: this driver treats no type as builtin,
// regardless of rt.
func (d *simpleDecDriver) isBuiltinType(rt uintptr) bool {
return false
}
// decodeBuiltin is a no-op; isBuiltinType never reports a builtin type for
// this driver.
func (d *simpleDecDriver) decodeBuiltin(rt uintptr, v interface{}) {
}
// decIntAny reads the integer announced by the current descriptor byte.
//
// The magnitude is stored in 1, 2, 4 or 8 bytes depending on the descriptor.
// It returns the magnitude as ui, the signed value as i (negated for the
// negative-integer descriptors) and neg=true for negative-integer
// descriptors. Any other descriptor is reported through decErr.
//
// The descriptor is NOT marked as consumed here; callers do that.
func (d *simpleDecDriver) decIntAny() (ui uint64, i int64, neg bool) {
	// Step 1: read the magnitude; its width is the same for the positive
	// and negative variant of each descriptor.
	switch d.bd {
	case simpleVdPosInt, simpleVdNegInt:
		ui = uint64(d.r.readn1())
	case simpleVdPosInt + 1, simpleVdNegInt + 1:
		ui = uint64(d.r.readUint16())
	case simpleVdPosInt + 2, simpleVdNegInt + 2:
		ui = uint64(d.r.readUint32())
	case simpleVdPosInt + 3, simpleVdNegInt + 3:
		ui = uint64(d.r.readUint64())
	default:
		decErr("decIntAny: Integer only valid from pos/neg integer1..8. Invalid descriptor: %v", d.bd)
	}
	// Step 2: apply the sign.
	switch d.bd {
	case simpleVdNegInt, simpleVdNegInt + 1, simpleVdNegInt + 2, simpleVdNegInt + 3:
		neg = true
		i = -(int64(ui))
	default:
		i = int64(ui)
	}
	// don't do this check, because callers may only want the unsigned value.
	// if ui > math.MaxInt64 {
	// 	decErr("decIntAny: Integer out of range for signed int64: %v", ui)
	// }
	return ui, i, neg
}
// decodeInt decodes the current value as a signed integer using decIntAny,
// hands it to checkOverflow together with the requested bitsize, and marks
// the descriptor byte as consumed.
func (d *simpleDecDriver) decodeInt(bitsize uint8) (i int64) {
_, i, _ = d.decIntAny()
checkOverflow(0, i, bitsize)
d.bdRead = false
return
}
// decodeUint decodes the current value as an unsigned integer.
// A value encoded with a negative-integer descriptor is reported through
// decErr. The result is handed to checkOverflow with the requested bitsize,
// and the descriptor byte is marked as consumed.
func (d *simpleDecDriver) decodeUint(bitsize uint8) (ui uint64) {
	var (
		signed   int64
		negative bool
	)
	ui, signed, negative = d.decIntAny()
	if negative {
		decErr("Assigning negative signed value: %v, to unsigned type", signed)
	}
	checkOverflow(ui, 0, bitsize)
	d.bdRead = false
	return ui
}
// decodeFloat decodes the current value as a float64.
//
// float32 and float64 descriptors are read from their IEEE-754 bit patterns.
// Descriptors in the integer range (simpleVdPosInt .. simpleVdNegInt+3) are
// decoded with decIntAny and converted. Anything else is reported through
// decErr. The result is passed to checkOverflowFloat32 together with
// chkOverflow32, and the descriptor byte is marked as consumed.
func (d *simpleDecDriver) decodeFloat(chkOverflow32 bool) (f float64) {
	switch bd := d.bd; {
	case bd == simpleVdFloat32:
		f = float64(math.Float32frombits(d.r.readUint32()))
	case bd == simpleVdFloat64:
		f = math.Float64frombits(d.r.readUint64())
	case bd >= simpleVdPosInt && bd <= simpleVdNegInt+3:
		_, n, _ := d.decIntAny()
		f = float64(n)
	default:
		decErr("Float only valid from float32/64: Invalid descriptor: %v", d.bd)
	}
	checkOverflowFloat32(f, chkOverflow32)
	d.bdRead = false
	return f
}
// decodeBool decodes the current value as a bool.
// Only the single-byte true/false descriptors are accepted; any other
// descriptor is reported through decErr. The descriptor byte is marked as
// consumed afterwards.
func (d *simpleDecDriver) decodeBool() (b bool) {
	if d.bd == simpleVdTrue {
		b = true
	} else if d.bd != simpleVdFalse {
		decErr("Invalid single-byte value for bool: %s: %x", msgBadDesc, d.bd)
	}
	d.bdRead = false
	return b
}
// readMapLen marks the current descriptor byte as consumed and returns the
// length decoded by decLen.
func (d *simpleDecDriver) readMapLen() (length int) {
d.bdRead = false
return d.decLen()
}
// readArrayLen marks the current descriptor byte as consumed and returns the
// length decoded by decLen.
func (d *simpleDecDriver) readArrayLen() (length int) {
d.bdRead = false
return d.decLen()
}
// decLen reads the length that follows the current descriptor byte.
//
// The low three bits of the descriptor (d.bd % 8) select how the length is
// stored:
//
//	0: no length bytes, the length is 0
//	1: 1 byte
//	2: 2 bytes
//	3: 4 bytes
//	4: 8 bytes
//
// 4- and 8-byte lengths are passed through checkOverflow against intBitsize
// before being converted to int. Any other selector is reported through
// decErr, and -1 is returned if decErr returns.
func (d *simpleDecDriver) decLen() int {
	switch d.bd % 8 {
	case 0:
		return 0
	case 1:
		return int(d.r.readn1())
	case 2:
		return int(d.r.readUint16())
	case 3:
		ui := uint64(d.r.readUint32())
		checkOverflow(ui, 0, intBitsize)
		return int(ui)
	case 4:
		ui := d.r.readUint64()
		checkOverflow(ui, 0, intBitsize)
		return int(ui)
	}
	// The literal percent sign must be escaped as %%; unescaped, "%8 " is
	// parsed as a (bad) verb that swallows the argument meant for %d.
	decErr("decLen: Cannot read length: bd%%8 must be in range 0..4. Got: %d", d.bd%8)
	return -1
}
// decodeString reads the length with decLen, reads that many bytes from the
// reader, and returns them as a string. The descriptor byte is marked as
// consumed afterwards.
func (d *simpleDecDriver) decodeString() (s string) {
	n := d.decLen()
	s = string(d.r.readn(n))
	d.bdRead = false
	return s
}
// decodeBytes reads a byte array into bs, reusing bs where possible.
//
// If the encoded length is zero, nothing is read and (nil, false) is
// returned, so an empty stream value never modifies the caller's slice.
// If bs is longer than needed it is truncated; if it is shorter, a new slice
// is allocated. In both of those cases the slice that was filled is returned
// as bsOut with changed=true. When bs already has exactly the right length
// it is filled in place and (nil, false) is returned.
// The descriptor byte is marked as consumed afterwards.
func (d *simpleDecDriver) decodeBytes(bs []byte) (bsOut []byte, changed bool) {
	n := d.decLen()
	if n > 0 {
		switch {
		case len(bs) > n:
			bs = bs[:n]
			bsOut, changed = bs, true
		case len(bs) < n:
			bs = make([]byte, n)
			bsOut, changed = bs, true
		}
		d.r.readb(bs)
	}
	d.bdRead = false
	return bsOut, changed
}
// decodeExt decodes the current value as an extension.
//
// For an extension descriptor it reads the length, then the tag byte, then
// the payload. When verifyTag is set and the tag read differs from tag, the
// mismatch is reported through decErr.
// For a byte-array descriptor the bytes are returned as the payload and xtag
// is left at zero. Any other descriptor is reported through decErr.
// The descriptor byte is marked as consumed afterwards.
func (d *simpleDecDriver) decodeExt(verifyTag bool, tag byte) (xtag byte, xbs []byte) {
	switch d.bd {
	case simpleVdExt, simpleVdExt + 1, simpleVdExt + 2, simpleVdExt + 3, simpleVdExt + 4:
		l := d.decLen()
		xtag = d.r.readn1()
		if verifyTag && xtag != tag {
			// Print both tags with the same verb so they can be compared directly.
			decErr("Wrong extension tag. Got %v. Expecting: %v", xtag, tag)
		}
		xbs = d.r.readn(l)
	case simpleVdByteArray, simpleVdByteArray + 1, simpleVdByteArray + 2, simpleVdByteArray + 3, simpleVdByteArray + 4:
		xbs, _ = d.decodeBytes(nil)
	default:
		decErr("Invalid d.vd for extensions (Expecting extensions or byte array). Got: 0x%x", d.bd)
	}
	d.bdRead = false
	return
}
// decodeNaked decodes the next value in the stream without being told which
// Go type to decode into.
//
// It returns the decoded value v and its value type vt. Scalars, strings,
// byte arrays and extensions (as *RawExt) are fully decoded and the
// descriptor byte is marked as consumed. For arrays and maps only vt is set
// and decodeFurther is true: the descriptor stays pending so that the caller
// can decode the container's contents.
// An unrecognized descriptor is reported through decErr.
func (d *simpleDecDriver) decodeNaked() (v interface{}, vt valueType, decodeFurther bool) {
	d.initReadNext()
	switch d.bd {
	case simpleVdNil:
		vt = valueTypeNil
	case simpleVdFalse, simpleVdTrue:
		vt = valueTypeBool
		v = d.bd == simpleVdTrue
	case simpleVdPosInt, simpleVdPosInt + 1, simpleVdPosInt + 2, simpleVdPosInt + 3:
		vt = valueTypeUint
		v, _, _ = d.decIntAny()
	case simpleVdNegInt, simpleVdNegInt + 1, simpleVdNegInt + 2, simpleVdNegInt + 3:
		vt = valueTypeInt
		_, v, _ = d.decIntAny()
	case simpleVdFloat32, simpleVdFloat64:
		vt = valueTypeFloat
		// the float32 overflow check is only requested for float32 values
		v = d.decodeFloat(d.bd == simpleVdFloat32)
	case simpleVdString, simpleVdString + 1, simpleVdString + 2, simpleVdString + 3, simpleVdString + 4:
		vt = valueTypeString
		v = d.decodeString()
	case simpleVdByteArray, simpleVdByteArray + 1, simpleVdByteArray + 2, simpleVdByteArray + 3, simpleVdByteArray + 4:
		vt = valueTypeBytes
		v, _ = d.decodeBytes(nil)
	case simpleVdExt, simpleVdExt + 1, simpleVdExt + 2, simpleVdExt + 3, simpleVdExt + 4:
		vt = valueTypeExt
		size := d.decLen()
		ext := new(RawExt)
		ext.Tag = d.r.readn1()
		ext.Data = d.r.readn(size)
		v = ext
	case simpleVdArray, simpleVdArray + 1, simpleVdArray + 2, simpleVdArray + 3, simpleVdArray + 4:
		vt, decodeFurther = valueTypeArray, true
	case simpleVdMap, simpleVdMap + 1, simpleVdMap + 2, simpleVdMap + 3, simpleVdMap + 4:
		vt, decodeFurther = valueTypeMap, true
	default:
		decErr("decodeNaked: Unrecognized d.vd: 0x%x", d.bd)
	}
	if !decodeFurther {
		d.bdRead = false
	}
	return v, vt, decodeFurther
}
//------------------------------------
// SimpleHandle is a Handle for a very simple encoding format.
//
// simple is a simplistic codec similar to binc, but not as compact.
// - Encoding of a value is always preceded by the descriptor byte (bd)
// - True, false, nil are encoded fully in 1 byte (the descriptor)
// - Integers (intXXX, uintXXX) are encoded in 1, 2, 4 or 8 bytes (plus a descriptor byte).
// There are positive (uintXXX and intXXX >= 0) and negative (intXXX < 0) integers.
// - Floats are encoded in 4 or 8 bytes (plus a descriptor byte)
// - Length of containers (strings, bytes, array, map, extensions)
// are encoded in 0, 1, 2, 4 or 8 bytes.
// Zero-length containers have no length encoded (bd%8 == 0).
// For others, bd%8 is in the range 1..4 and the number of length bytes
// is given by pow(2, bd%8 - 1), i.e. 1, 2, 4 or 8.
// - maps are encoded as [bd] [length] [[key][value]]...
// - arrays are encoded as [bd] [length] [value]...
// - extensions are encoded as [bd] [length] [tag] [byte]...
// - strings/bytearrays are encoded as [bd] [length] [byte]...
//
// The full spec will be published soon.
type SimpleHandle struct {
BasicHandle
}
// newEncDriver returns a simpleEncDriver that writes to w and refers back
// to this handle.
func (h *SimpleHandle) newEncDriver(w encWriter) encDriver {
return &simpleEncDriver{w: w, h: h}
}
// newDecDriver returns a simpleDecDriver that reads from r and refers back
// to this handle.
func (h *SimpleHandle) newDecDriver(r decReader) decDriver {
return &simpleDecDriver{r: r, h: h}
}
// writeExt always reports true for the simple format.
func (h *SimpleHandle) writeExt() bool {
	return true
}
// getBasicHandle returns a pointer to the BasicHandle embedded in h.
func (h *SimpleHandle) getBasicHandle() *BasicHandle {
return &h.BasicHandle
}
// Compile-time checks that the simple drivers satisfy the driver interfaces.
var _ decDriver = (*simpleDecDriver)(nil)
var _ encDriver = (*simpleEncDriver)(nil)

View file

@ -0,0 +1,193 @@
// Copyright (c) 2012, 2013 Ugorji Nwoke. All rights reserved.
// Use of this source code is governed by a BSD-style license found in the LICENSE file.
package codec
import (
"time"
)
var (
// timeDigits maps a value 0..9 to its ASCII digit; timeLocUTCName uses it
// to format zone names without fmt.
timeDigits = [...]byte{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9'}
)
// encodeTime encodes a time.Time as a []byte, including
// information on the instant in time and UTC offset.
//
// Format Description
//
// A timestamp is composed of 3 components:
//
//   - secs: signed integer representing seconds since unix epoch
//   - nsecs: unsigned integer representing fractional seconds as a
//     nanosecond offset within secs, in the range 0 <= nsecs < 1e9
//   - tz: signed integer representing timezone offset in minutes east of UTC,
//     and a dst (daylight savings time) flag
//
// When encoding a timestamp, the first byte is the descriptor, which
// defines which components are encoded and how many bytes are used to
// encode secs and nsecs components. *If secs/nsecs is 0 or tz is UTC, it
// is not encoded in the byte array explicitly*.
//
//	Descriptor 8 bits are of the form `A B C DDD EE`:
//	    A:   Is secs component encoded? 1 = true
//	    B:   Is nsecs component encoded? 1 = true
//	    C:   Is tz component encoded? 1 = true
//	    DDD: Number of extra bytes for secs (range 0-7).
//	         If A = 1, secs encoded in DDD+1 bytes.
//	             If A = 0, secs is not encoded, and is assumed to be 0.
//	             If A = 1, then we need at least 1 byte to encode secs.
//	             DDD says the number of extra bytes beyond that 1.
//	             E.g. if DDD=0, then secs is represented in 1 byte.
//	                  if DDD=2, then secs is represented in 3 bytes.
//	    EE:  Number of extra bytes for nsecs (range 0-3).
//	         If B = 1, nsecs encoded in EE+1 bytes (similar to secs/DDD above)
//
// Following the descriptor bytes, subsequent bytes are:
//
//	secs component encoded in `DDD + 1` bytes (if A == 1)
//	nsecs component encoded in `EE + 1` bytes (if B == 1)
//	tz component encoded in 2 bytes (if C == 1)
//
// secs and nsecs components are integers encoded in a BigEndian
// 2-complement encoding format.
//
// tz component is encoded as 2 bytes (16 bits). Most significant bit 15 to
// Least significant bit 0 are described below:
//
//	Timezone offset has a range of -12:00 to +14:00 (ie -720 to +840 minutes).
//	Bit 15 = have_dst: set to 1 if we set the dst flag.
//	Bit 14 = dst_on: set to 1 if dst is in effect at the time, or 0 if not.
//	Bits 13..0 = timezone offset in minutes. It is a signed integer in Big Endian format.
func encodeTime(t time.Time) []byte {
	var (
		desc    byte     // descriptor byte, built up as components are added
		scratch [8]byte  // big-endian staging area for each component
		out     [16]byte // descriptor + up to 8 (secs) + 4 (nsecs) + 2 (tz) bytes
		pos     = 1      // next free index in out; out[0] is the descriptor
	)
	secs := t.Unix()
	nsecs := t.Nanosecond()

	if secs != 0 {
		desc |= 0x80
		bigen.PutUint64(scratch[:], uint64(secs))
		// drop redundant leading sign-extension bytes
		skip := pruneSignExt(scratch[:], secs >= 0)
		desc |= byte(7-skip) << 2
		pos += copy(out[pos:], scratch[skip:])
	}
	if nsecs != 0 {
		desc |= 0x40
		bigen.PutUint32(scratch[:4], uint32(nsecs))
		skip := pruneSignExt(scratch[:4], true)
		desc |= byte(3 - skip)
		pos += copy(out[pos:], scratch[skip:4])
	}
	if t.Location() != time.UTC {
		desc |= 0x20
		// Note that Go Libs do not give access to dst flag.
		_, offsetSecs := t.Zone()
		bigen.PutUint16(scratch[:2], uint16(offsetSecs/60))
		// clear dst flags
		out[pos] = scratch[0] & 0x3f
		out[pos+1] = scratch[1]
		pos += 2
	}
	out[0] = desc
	return out[:pos]
}
// timeDecodeError describes a malformed encoded timestamp passed to decodeTime.
type timeDecodeError string

// Error implements the error interface.
func (e timeDecodeError) Error() string { return string(e) }

// decodeTime decodes a []byte, laid out as described on encodeTime, into a
// time.Time.
//
// The first byte is the descriptor; it says which of the secs, nsecs and tz
// components follow and how many bytes secs and nsecs occupy. Components
// that are absent default to zero / UTC. A zero timezone offset yields a
// time in UTC; any other offset yields a time in an unnamed fixed zone.
// Bytes following the last announced component are ignored.
//
// An error is returned (and tt is the zero Time) if bs is empty or shorter
// than its descriptor byte announces.
func decodeTime(bs []byte) (tt time.Time, err error) {
	if len(bs) == 0 {
		return tt, timeDecodeError("decodeTime: empty input")
	}
	bd := bs[0]
	var (
		tsec  int64
		tnsec uint32
		i     = 1 // index of the next unread byte in bs
	)
	if bd&(1<<7) != 0 {
		var btmp [8]byte
		n := int((bd>>2)&0x7) + 1
		i2 := i + n
		if i2 > len(bs) {
			return tt, timeDecodeError("decodeTime: input too short for secs component")
		}
		copy(btmp[8-n:], bs[i:i2])
		//if first bit of bs[i] is set, then fill btmp[0..8-n] with 0xff (ie sign extend it)
		if bs[i]&(1<<7) != 0 {
			copy(btmp[0:8-n], bsAll0xff)
		}
		i = i2
		tsec = int64(bigen.Uint64(btmp[:]))
	}
	if bd&(1<<6) != 0 {
		var btmp [4]byte
		n := int(bd&0x3) + 1
		i2 := i + n
		if i2 > len(bs) {
			return tt, timeDecodeError("decodeTime: input too short for nsecs component")
		}
		copy(btmp[4-n:], bs[i:i2])
		i = i2
		tnsec = bigen.Uint32(btmp[:])
	}
	if bd&(1<<5) == 0 {
		return time.Unix(tsec, int64(tnsec)).UTC(), nil
	}
	// In stdlib time.Parse, when a date is parsed without a zone name, it uses "" as zone name.
	// However, we need name here, so it can be shown when time is printed.
	// Zone name is in form: UTC-08:00.
	// Note that Go Libs do not give access to dst flag, so we ignore dst bits
	if i+2 > len(bs) {
		return tt, timeDecodeError("decodeTime: input too short for tz component")
	}
	tz := bigen.Uint16(bs[i : i+2])
	// sign extend sign bit into top 2 MSB (which were dst bits):
	if tz&(1<<13) == 0 { // positive
		tz &= 0x3fff //clear 2 MSBs: dst bits
	} else { // negative
		tz |= 0xc000 //set 2 MSBs: dst bits
	}
	tzint := int16(tz)
	if tzint == 0 {
		return time.Unix(tsec, int64(tnsec)).UTC(), nil
	}
	// For Go Time, do not use a descriptive timezone.
	// It's unnecessary, and makes it harder to do a reflect.DeepEqual.
	// The Offset already tells what the offset should be, if not on UTC and unknown zone name.
	return time.Unix(tsec, int64(tnsec)).In(time.FixedZone("", int(tzint)*60)), nil
}
// timeLocUTCName formats a timezone offset, given in minutes east of UTC, as
// a zone name of the form "UTC+hh:mm" or "UTC-hh:mm". An offset of zero
// yields plain "UTC".
//
// The name is assembled byte by byte (digits looked up in timeDigits)
// instead of via fmt.Sprintf, which was a performance issue here.
func timeLocUTCName(tzint int16) string {
	if tzint == 0 {
		return "UTC"
	}
	sign, mins := byte('+'), tzint
	if tzint < 0 {
		sign, mins = '-', -tzint
	}
	hh, mm := mins/60, mins%60
	name := [...]byte{
		'U', 'T', 'C', sign,
		timeDigits[hh/10], timeDigits[hh%10],
		':',
		timeDigits[mm/10], timeDigits[mm%10],
	}
	return string(name[:])
}

View file

@ -0,0 +1,25 @@
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so
# Folders
_obj
_test
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe
*.test
.vagrant/

View file

@ -0,0 +1,354 @@
Mozilla Public License, version 2.0
1. Definitions
1.1. “Contributor”
means each individual or legal entity that creates, contributes to the
creation of, or owns Covered Software.
1.2. “Contributor Version”
means the combination of the Contributions of others (if any) used by a
Contributor and that particular Contributor’s Contribution.
1.3. “Contribution”
means Covered Software of a particular Contributor.
1.4. “Covered Software”
means Source Code Form to which the initial Contributor has attached the
notice in Exhibit A, the Executable Form of such Source Code Form, and
Modifications of such Source Code Form, in each case including portions
thereof.
1.5. “Incompatible With Secondary Licenses”
means
a. that the initial Contributor has attached the notice described in
Exhibit B to the Covered Software; or
b. that the Covered Software was made available under the terms of version
1.1 or earlier of the License, but not also under the terms of a
Secondary License.
1.6. “Executable Form”
means any form of the work other than Source Code Form.
1.7. “Larger Work”
means a work that combines Covered Software with other material, in a separate
file or files, that is not Covered Software.
1.8. “License”
means this document.
1.9. “Licensable”
means having the right to grant, to the maximum extent possible, whether at the
time of the initial grant or subsequently, any and all of the rights conveyed by
this License.
1.10. “Modifications”
means any of the following:
a. any file in Source Code Form that results from an addition to, deletion
from, or modification of the contents of Covered Software; or
b. any new file in Source Code Form that contains any Covered Software.
1.11. “Patent Claims” of a Contributor
means any patent claim(s), including without limitation, method, process,
and apparatus claims, in any patent Licensable by such Contributor that
would be infringed, but for the grant of the License, by the making,
using, selling, offering for sale, having made, import, or transfer of
either its Contributions or its Contributor Version.
1.12. “Secondary License”
means either the GNU General Public License, Version 2.0, the GNU Lesser
General Public License, Version 2.1, the GNU Affero General Public
License, Version 3.0, or any later versions of those licenses.
1.13. “Source Code Form”
means the form of the work preferred for making modifications.
1.14. “You” (or “Your”)
means an individual or a legal entity exercising rights under this
License. For legal entities, “You” includes any entity that controls, is
controlled by, or is under common control with You. For purposes of this
definition, “control” means (a) the power, direct or indirect, to cause
the direction or management of such entity, whether by contract or
otherwise, or (b) ownership of more than fifty percent (50%) of the
outstanding shares or beneficial ownership of such entity.
2. License Grants and Conditions
2.1. Grants
Each Contributor hereby grants You a world-wide, royalty-free,
non-exclusive license:
a. under intellectual property rights (other than patent or trademark)
Licensable by such Contributor to use, reproduce, make available,
modify, display, perform, distribute, and otherwise exploit its
Contributions, either on an unmodified basis, with Modifications, or as
part of a Larger Work; and
b. under Patent Claims of such Contributor to make, use, sell, offer for
sale, have made, import, and otherwise transfer either its Contributions
or its Contributor Version.
2.2. Effective Date
The licenses granted in Section 2.1 with respect to any Contribution become
effective for each Contribution on the date the Contributor first distributes
such Contribution.
2.3. Limitations on Grant Scope
The licenses granted in this Section 2 are the only rights granted under this
License. No additional rights or licenses will be implied from the distribution
or licensing of Covered Software under this License. Notwithstanding Section
2.1(b) above, no patent license is granted by a Contributor:
a. for any code that a Contributor has removed from Covered Software; or
b. for infringements caused by: (i) Your and any other third party’s
modifications of Covered Software, or (ii) the combination of its
Contributions with other software (except as part of its Contributor
Version); or
c. under Patent Claims infringed by Covered Software in the absence of its
Contributions.
This License does not grant any rights in the trademarks, service marks, or
logos of any Contributor (except as may be necessary to comply with the
notice requirements in Section 3.4).
2.4. Subsequent Licenses
No Contributor makes additional grants as a result of Your choice to
distribute the Covered Software under a subsequent version of this License
(see Section 10.2) or under the terms of a Secondary License (if permitted
under the terms of Section 3.3).
2.5. Representation
Each Contributor represents that the Contributor believes its Contributions
are its original creation(s) or it has sufficient rights to grant the
rights to its Contributions conveyed by this License.
2.6. Fair Use
This License is not intended to limit any rights You have under applicable
copyright doctrines of fair use, fair dealing, or other equivalents.
2.7. Conditions
Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted in
Section 2.1.
3. Responsibilities
3.1. Distribution of Source Form
All distribution of Covered Software in Source Code Form, including any
Modifications that You create or to which You contribute, must be under the
terms of this License. You must inform recipients that the Source Code Form
of the Covered Software is governed by the terms of this License, and how
they can obtain a copy of this License. You may not attempt to alter or
restrict the recipients’ rights in the Source Code Form.
3.2. Distribution of Executable Form
If You distribute Covered Software in Executable Form then:
a. such Covered Software must also be made available in Source Code Form,
as described in Section 3.1, and You must inform recipients of the
Executable Form how they can obtain a copy of such Source Code Form by
reasonable means in a timely manner, at a charge no more than the cost
of distribution to the recipient; and
b. You may distribute such Executable Form under the terms of this License,
or sublicense it under different terms, provided that the license for
the Executable Form does not attempt to limit or alter the recipients’
rights in the Source Code Form under this License.
3.3. Distribution of a Larger Work
You may create and distribute a Larger Work under terms of Your choice,
provided that You also comply with the requirements of this License for the
Covered Software. If the Larger Work is a combination of Covered Software
with a work governed by one or more Secondary Licenses, and the Covered
Software is not Incompatible With Secondary Licenses, this License permits
You to additionally distribute such Covered Software under the terms of
such Secondary License(s), so that the recipient of the Larger Work may, at
their option, further distribute the Covered Software under the terms of
either this License or such Secondary License(s).
3.4. Notices
You may not remove or alter the substance of any license notices (including
copyright notices, patent notices, disclaimers of warranty, or limitations
of liability) contained within the Source Code Form of the Covered
Software, except that You may alter any license notices to the extent
required to remedy known factual inaccuracies.
3.5. Application of Additional Terms
You may choose to offer, and to charge a fee for, warranty, support,
indemnity or liability obligations to one or more recipients of Covered
Software. However, You may do so only on Your own behalf, and not on behalf
of any Contributor. You must make it absolutely clear that any such
warranty, support, indemnity, or liability obligation is offered by You
alone, and You hereby agree to indemnify every Contributor for any
liability incurred by such Contributor as a result of warranty, support,
indemnity or liability terms You offer. You may include additional
disclaimers of warranty and limitations of liability specific to any
jurisdiction.
4. Inability to Comply Due to Statute or Regulation
If it is impossible for You to comply with any of the terms of this License
with respect to some or all of the Covered Software due to statute, judicial
order, or regulation then You must: (a) comply with the terms of this License
to the maximum extent possible; and (b) describe the limitations and the code
they affect. Such description must be placed in a text file included with all
distributions of the Covered Software under this License. Except to the
extent prohibited by statute or regulation, such description must be
sufficiently detailed for a recipient of ordinary skill to be able to
understand it.
5. Termination
5.1. The rights granted under this License will terminate automatically if You
fail to comply with any of its terms. However, if You become compliant,
then the rights granted under this License from a particular Contributor
are reinstated (a) provisionally, unless and until such Contributor
explicitly and finally terminates Your grants, and (b) on an ongoing basis,
if such Contributor fails to notify You of the non-compliance by some
reasonable means prior to 60 days after You have come back into compliance.
Moreover, Your grants from a particular Contributor are reinstated on an
ongoing basis if such Contributor notifies You of the non-compliance by
some reasonable means, this is the first time You have received notice of
non-compliance with this License from such Contributor, and You become
compliant prior to 30 days after Your receipt of the notice.
5.2. If You initiate litigation against any entity by asserting a patent
infringement claim (excluding declaratory judgment actions, counter-claims,
and cross-claims) alleging that a Contributor Version directly or
indirectly infringes any patent, then the rights granted to You by any and
all Contributors for the Covered Software under Section 2.1 of this License
shall terminate.
5.3. In the event of termination under Sections 5.1 or 5.2 above, all end user
license agreements (excluding distributors and resellers) which have been
validly granted by You or Your distributors under this License prior to
termination shall survive termination.
6. Disclaimer of Warranty
Covered Software is provided under this License on an “as is” basis, without
warranty of any kind, either expressed, implied, or statutory, including,
without limitation, warranties that the Covered Software is free of defects,
merchantable, fit for a particular purpose or non-infringing. The entire
risk as to the quality and performance of the Covered Software is with You.
Should any Covered Software prove defective in any respect, You (not any
Contributor) assume the cost of any necessary servicing, repair, or
correction. This disclaimer of warranty constitutes an essential part of this
License. No use of any Covered Software is authorized under this License
except under this disclaimer.
7. Limitation of Liability
Under no circumstances and under no legal theory, whether tort (including
negligence), contract, or otherwise, shall any Contributor, or anyone who
distributes Covered Software as permitted above, be liable to You for any
direct, indirect, special, incidental, or consequential damages of any
character including, without limitation, damages for lost profits, loss of
goodwill, work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses, even if such party shall have been
informed of the possibility of such damages. This limitation of liability
shall not apply to liability for death or personal injury resulting from such
party’s negligence to the extent applicable law prohibits such limitation.
Some jurisdictions do not allow the exclusion or limitation of incidental or
consequential damages, so this exclusion and limitation may not apply to You.
8. Litigation
Any litigation relating to this License may be brought only in the courts of
a jurisdiction where the defendant maintains its principal place of business
and such litigation shall be governed by laws of that jurisdiction, without
reference to its conflict-of-law provisions. Nothing in this Section shall
prevent a party’s ability to bring cross-claims or counter-claims.
9. Miscellaneous
This License represents the complete agreement concerning the subject matter
hereof. If any provision of this License is held to be unenforceable, such
provision shall be reformed only to the extent necessary to make it
enforceable. Any law or regulation which provides that the language of a
contract shall be construed against the drafter shall not be used to construe
this License against a Contributor.
10. Versions of the License
10.1. New Versions
Mozilla Foundation is the license steward. Except as provided in Section
10.3, no one other than the license steward has the right to modify or
publish new versions of this License. Each version will be given a
distinguishing version number.
10.2. Effect of New Versions
You may distribute the Covered Software under the terms of the version of
the License under which You originally received the Covered Software, or
under the terms of any subsequent version published by the license
steward.
10.3. Modified Versions
If you create software not governed by this License, and you want to
create a new license for such software, you may create and use a modified
version of this License if you rename the license and remove any
references to the name of the license steward (except to note that such
modified license differs from this License).
10.4. Distributing Source Code Form that is Incompatible With Secondary Licenses
If You choose to distribute Source Code Form that is Incompatible With
Secondary Licenses under the terms of this version of the License, the
notice described in Exhibit B of this License must be attached.
Exhibit A - Source Code Form License Notice
This Source Code Form is subject to the
terms of the Mozilla Public License, v.
2.0. If a copy of the MPL was not
distributed with this file, You can
obtain one at
http://mozilla.org/MPL/2.0/.
If it is not possible or desirable to put the notice in a particular file, then
You may include the notice in a location (such as a LICENSE file in a relevant
directory) where a recipient would be likely to look for such a notice.
You may add additional accurate notices of copyright ownership.
Exhibit B - “Incompatible With Secondary Licenses” Notice
This Source Code Form is “Incompatible
With Secondary Licenses”, as defined by
the Mozilla Public License, v. 2.0.

View file

@ -0,0 +1,14 @@
# Run the unit tests (the test subnet must be set up first).
test: subnet
	go test ./...

# Run the tests with integration tests enabled.
integ: subnet
	INTEG_TESTS=yes go test ./...

# Set up the subnet the tests rely on.
subnet:
	./test/setup_subnet.sh

# Generate an HTML coverage report and open it.
cov:
	gocov test github.com/hashicorp/memberlist | gocov-html > /tmp/coverage.html
	open /tmp/coverage.html

# None of these targets produce a file of the same name; declare them phony
# so an existing file or directory (e.g. ./test) never masks them.
.PHONY: test cov integ subnet

View file

@ -0,0 +1,137 @@
# memberlist
memberlist is a [Go](http://www.golang.org) library that manages cluster
membership and member failure detection using a gossip based protocol.
The use cases for such a library are far-reaching: all distributed systems
require membership, and memberlist is a re-usable solution to managing
cluster membership and node failure detection.
memberlist is eventually consistent but converges quickly on average.
The speed at which it converges can be heavily tuned via various knobs
on the protocol. Node failures are detected and network partitions are partially
tolerated by attempting to communicate to potentially dead nodes through
multiple routes.
## Building
If you wish to build memberlist you'll need Go version 1.2+ installed.
Please check your installation with:
```
go version
```
## Usage
Memberlist is surprisingly simple to use. An example is shown below:
```go
/* Create the initial memberlist from a safe configuration.
Please reference the godoc for other default config types.
http://godoc.org/github.com/hashicorp/memberlist#Config
*/
list, err := memberlist.Create(memberlist.DefaultLocalConfig())
if err != nil {
panic("Failed to create memberlist: " + err.Error())
}
// Join an existing cluster by specifying at least one known member.
n, err := list.Join([]string{"1.2.3.4"})
if err != nil {
panic("Failed to join cluster: " + err.Error())
}
// Ask for members of the cluster
for _, member := range list.Members() {
fmt.Printf("Member: %s %s\n", member.Name, member.Addr)
}
// Continue doing whatever you need, memberlist will maintain membership
// information in the background. Delegates can be used for receiving
// events when members join or leave.
```
The most difficult part of memberlist is configuring it since it has many
available knobs in order to tune state propagation delay and convergence times.
Memberlist provides a default configuration that offers a good starting point,
but errs on the side of caution, choosing values that are optimized for
higher convergence at the cost of higher bandwidth usage.
For complete documentation, see the associated [Godoc](http://godoc.org/github.com/hashicorp/memberlist).
## Protocol
memberlist is based on ["SWIM: Scalable Weakly-consistent Infection-style Process Group Membership Protocol"](http://www.cs.cornell.edu/~asdas/research/dsn02-swim.pdf),
with a few minor adaptations, mostly to increase propagation speed and
convergence rate.
A high level overview of the memberlist protocol (based on SWIM) is
described below, but for details please read the full
[SWIM paper](http://www.cs.cornell.edu/~asdas/research/dsn02-swim.pdf)
followed by the memberlist source. We welcome any questions related
to the protocol on our issue tracker.
### Protocol Description
memberlist begins by joining an existing cluster or starting a new
cluster. If starting a new cluster, additional nodes are expected to join
it. New nodes in an existing cluster must be given the address of at
least one existing member in order to join the cluster. The new member
does a full state sync with the existing member over TCP and begins gossiping its
existence to the cluster.
Gossip is done over UDP with a configurable but fixed fanout and interval.
This ensures that network usage is constant with regards to number of nodes, as opposed to
exponential growth that can occur with traditional heartbeat mechanisms.
Complete state exchanges with a random node are done periodically over
TCP, but much less often than gossip messages. This increases the likelihood
that the membership list converges properly since the full state is exchanged
and merged. The interval between full state exchanges is configurable or can
be disabled entirely.
Failure detection is done by periodic random probing using a configurable interval.
If the node fails to ack within a reasonable time (typically some multiple
of RTT), then an indirect probe is attempted. An indirect probe asks a
configurable number of random nodes to probe the same node, in case there
are network issues causing our own node to fail the probe. If both our
probe and the indirect probes fail within a reasonable time, then the
node is marked "suspicious" and this knowledge is gossiped to the cluster.
A suspicious node is still considered a member of the cluster. If the suspect member
of the cluster does not dispute the suspicion within a configurable period of
time, the node is finally considered dead, and this state is then gossiped
to the cluster.
This is a brief and incomplete description of the protocol. For a better idea,
please read the
[SWIM paper](http://www.cs.cornell.edu/~asdas/research/dsn02-swim.pdf)
in its entirety, along with the memberlist source code.
### Changes from SWIM
As mentioned earlier, the memberlist protocol is based on SWIM but includes
minor changes, mostly to increase propagation speed and convergence rates.
The changes from SWIM are noted here:
* memberlist does a full state sync over TCP periodically. SWIM only propagates
changes over gossip. While both eventually reach convergence, the full state
sync increases the likelihood that nodes are fully converged more quickly,
at the expense of more bandwidth usage. This feature can be totally disabled
if you wish.
* memberlist has a dedicated gossip layer separate from the failure detection
protocol. SWIM only piggybacks gossip messages on top of probe/ack messages.
memberlist also piggybacks gossip messages on top of probe/ack messages, but
also will periodically send out dedicated gossip messages on their own. This
feature lets you have a higher gossip rate (for example once per 200ms)
and a slower failure detection rate (such as once per second), resulting
in overall faster convergence rates and data propagation speeds. This feature
can be totally disabled as well, if you wish.
* memberlist keeps the state of dead nodes around for a set amount of time,
so that when full syncs are requested, the requester also receives information
about dead nodes. Because SWIM doesn't do full syncs, SWIM deletes dead node
state immediately upon learning that the node is dead. This change again helps
the cluster converge more quickly.

View file

@ -0,0 +1,100 @@
package memberlist
/*
The broadcast mechanism works by maintaining a sorted list of messages to be
sent out. When a message is to be broadcast, the retransmit count
is set to zero and appended to the queue. The retransmit count serves
as the "priority", ensuring that newer messages get sent first. Once
a message hits the retransmit limit, it is removed from the queue.
Additionally, older entries can be invalidated by new messages that
are contradictory. For example, if we send {suspect M1 inc: 1},
then a following {alive M1 inc: 2} will invalidate that message.
*/
// memberlistBroadcast is a queued broadcast about a single node. It carries
// the already-encoded message and an optional channel that is signalled once
// the broadcast is finished.
type memberlistBroadcast struct {
	node   string        // name of the node the message is about
	msg    []byte        // encoded message bytes
	notify chan struct{} // signalled by Finished; may be nil
}

// Invalidates reports whether this broadcast supersedes other. Only another
// memberlistBroadcast about the same node is invalidated.
func (b *memberlistBroadcast) Invalidates(other Broadcast) bool {
	if peer, isMemberlist := other.(*memberlistBroadcast); isMemberlist {
		return peer.node == b.node
	}
	// Broadcasts of any other concrete type are never invalidated.
	return false
}

// Message returns the encoded bytes to transmit.
func (b *memberlistBroadcast) Message() []byte {
	return b.msg
}

// Finished signals the notify channel without blocking: if nobody is ready to
// receive (or the channel is nil) the notification is dropped.
func (b *memberlistBroadcast) Finished() {
	select {
	case b.notify <- struct{}{}:
	default:
	}
}
// encodeAndBroadcast encodes a message and enqueues it for broadcast. It is
// a convenience wrapper around encodeBroadcastNotify with a nil notify
// channel, so no completion signal is delivered. Fails silently (the error is
// only logged by encodeBroadcastNotify) if there is an encoding error.
func (m *Memberlist) encodeAndBroadcast(node string, msgType messageType, msg interface{}) {
	m.encodeBroadcastNotify(node, msgType, msg, nil)
}
// encodeBroadcastNotify encodes msg as msgType and queues the result for
// broadcast on behalf of node. The notify channel, if non-nil, is handed to
// the queued broadcast so it can signal when transmission is finished.
// An encoding error is logged and the message is dropped; nothing is
// returned to the caller.
func (m *Memberlist) encodeBroadcastNotify(node string, msgType messageType, msg interface{}, notify chan struct{}) {
	encoded, err := encode(msgType, msg)
	if err != nil {
		m.logger.Printf("[ERR] memberlist: Failed to encode message for broadcast: %s", err)
		return
	}
	m.queueBroadcast(node, encoded.Bytes(), notify)
}
// queueBroadcast starts dissemination of an already-encoded message about
// node by placing it on the broadcast queue. The message may later be
// invalidated by a newer broadcast about the same node.
func (m *Memberlist) queueBroadcast(node string, msg []byte, notify chan struct{}) {
	m.broadcasts.QueueBroadcast(&memberlistBroadcast{
		node:   node,
		msg:    msg,
		notify: notify,
	})
}
// getBroadcasts returns the messages to piggyback on an outgoing packet,
// bounded by limit bytes in total with overhead bytes charged per message.
// Memberlist's own queued broadcasts are taken first; whatever space remains
// is offered to the configured Delegate, whose messages are framed with a
// leading userMsg type byte.
func (m *Memberlist) getBroadcasts(overhead, limit int) [][]byte {
	out := m.broadcasts.GetBroadcasts(overhead, limit)

	delegate := m.config.Delegate
	if delegate == nil {
		// No user delegate: only memberlist messages are sent.
		return out
	}

	// Work out how much of the budget the memberlist messages consumed.
	used := 0
	for i := range out {
		used += len(out[i]) + overhead
	}

	// Only ask the delegate if at least one framed user message could fit.
	remaining := limit - used
	if remaining <= overhead+userMsgOverhead {
		return out
	}

	for _, payload := range delegate.GetBroadcasts(overhead+userMsgOverhead, remaining) {
		// Prefix each user payload with the userMsg type byte.
		framed := make([]byte, 1, len(payload)+1)
		framed[0] = byte(userMsg)
		out = append(out, append(framed, payload...))
	}
	return out
}

View file

@ -0,0 +1,209 @@
package memberlist
import (
"io"
"os"
"time"
)
// Config holds the settings used to create a Memberlist (see Create). The
// DefaultLANConfig, DefaultWANConfig and DefaultLocalConfig helpers return
// pre-populated values suitable as a starting point.
type Config struct {
	// The name of this node. This must be unique in the cluster.
	Name string

	// Configuration related to what address to bind to and ports to
	// listen on. The port is used for both UDP and TCP gossip.
	// It is assumed other nodes are running on this port, but they
	// do not need to.
	BindAddr string
	BindPort int

	// Configuration related to what address to advertise to other
	// cluster members. Used for NAT traversal.
	AdvertiseAddr string
	AdvertisePort int

	// ProtocolVersion is the configured protocol version that we
	// will _speak_. This must be between ProtocolVersionMin and
	// ProtocolVersionMax.
	ProtocolVersion uint8

	// TCPTimeout is the timeout for establishing a TCP connection with
	// a remote node for a full state sync.
	TCPTimeout time.Duration

	// IndirectChecks is the number of nodes that will be asked to perform
	// an indirect probe of a node in the case a direct probe fails. Memberlist
	// waits for an ack from any single indirect node, so increasing this
	// number will increase the likelihood that an indirect probe will succeed
	// at the expense of bandwidth.
	IndirectChecks int

	// RetransmitMult is the multiplier for the number of retransmissions
	// that are attempted for messages broadcasted over gossip. The actual
	// count of retransmissions is calculated using the formula:
	//
	//   Retransmits = RetransmitMult * log(N+1)
	//
	// This allows the retransmits to scale properly with cluster size. The
	// higher the multiplier, the more likely a failed broadcast is to converge
	// at the expense of increased bandwidth.
	RetransmitMult int

	// SuspicionMult is the multiplier for determining the time an
	// inaccessible node is considered suspect before declaring it dead.
	// The actual timeout is calculated using the formula:
	//
	//   SuspicionTimeout = SuspicionMult * log(N+1) * ProbeInterval
	//
	// This allows the timeout to scale properly with expected propagation
	// delay with a larger cluster size. The higher the multiplier, the longer
	// an inaccessible node is considered part of the cluster before declaring
	// it dead, giving that suspect node more time to refute if it is indeed
	// still alive.
	SuspicionMult int

	// PushPullInterval is the interval between complete state syncs.
	// Complete state syncs are done with a single node over TCP and are
	// quite expensive relative to standard gossiped messages. Setting this
	// to zero will disable state push/pull syncs completely.
	//
	// Setting this interval lower (more frequent) will increase convergence
	// speeds across larger clusters at the expense of increased bandwidth
	// usage.
	PushPullInterval time.Duration

	// ProbeInterval and ProbeTimeout are used to configure probing
	// behavior for memberlist.
	//
	// ProbeInterval is the interval between random node probes. Setting
	// this lower (more frequent) will cause the memberlist cluster to detect
	// failed nodes more quickly at the expense of increased bandwidth usage.
	//
	// ProbeTimeout is the timeout to wait for an ack from a probed node
	// before assuming it is unhealthy. This should be set to 99-percentile
	// of RTT (round-trip time) on your network.
	ProbeInterval time.Duration
	ProbeTimeout  time.Duration

	// GossipInterval and GossipNodes are used to configure the gossip
	// behavior of memberlist.
	//
	// GossipInterval is the interval between sending messages that need
	// to be gossiped that haven't been able to piggyback on probing messages.
	// If this is set to zero, non-piggyback gossip is disabled. By lowering
	// this value (more frequent) gossip messages are propagated across
	// the cluster more quickly at the expense of increased bandwidth.
	//
	// GossipNodes is the number of random nodes to send gossip messages to
	// per GossipInterval. Increasing this number causes the gossip messages
	// to propagate across the cluster more quickly at the expense of
	// increased bandwidth.
	GossipInterval time.Duration
	GossipNodes    int

	// EnableCompression is used to control message compression. This can
	// be used to reduce bandwidth usage at the cost of slightly more CPU
	// utilization. This is only available starting at protocol version 1.
	EnableCompression bool

	// SecretKey is used to initialize the primary encryption key in a keyring.
	// The primary encryption key is the only key used to encrypt messages and
	// the first key used while attempting to decrypt messages. Providing a
	// value for this primary key will enable message-level encryption and
	// verification, and automatically install the key onto the keyring.
	SecretKey []byte

	// The keyring holds all of the encryption keys used internally. If it is
	// nil and SecretKey is set, newMemberlist creates a keyring from
	// SecretKey; if both are set, SecretKey is added to this keyring and
	// made its primary key.
	Keyring *Keyring

	// Delegate and Events are delegates for receiving and providing
	// data to memberlist via callback mechanisms. For Delegate, see
	// the Delegate interface. For Events, see the EventDelegate interface.
	//
	// The DelegateProtocolMin/Max are used to guarantee protocol-compatibility
	// for any custom messages that the delegate might do (broadcasts,
	// local/remote state, etc.). If you don't set these, then the protocol
	// versions will just be zero, and version compliance won't be done.
	Delegate                Delegate
	DelegateProtocolVersion uint8
	DelegateProtocolMin     uint8
	DelegateProtocolMax     uint8
	Events                  EventDelegate
	// Conflict is notified about node name conflicts; see ConflictDelegate.
	Conflict ConflictDelegate
	// Merge is a MergeDelegate, which is declared elsewhere in the package.
	Merge MergeDelegate

	// LogOutput is the writer where logs should be sent. If this is not
	// set, logging will go to stderr by default.
	LogOutput io.Writer
}
// DefaultLANConfig returns a sane set of configurations for Memberlist.
// It uses the hostname as the node name, and otherwise sets very conservative
// values that are sane for most LAN environments. The default configuration
// errs on the side of caution, choosing values that are optimized for higher
// convergence at the cost of higher bandwidth usage. Regardless, these values
// are a good starting point when getting started with memberlist.
func DefaultLANConfig() *Config {
	// The lookup error is intentionally discarded; Name is set to whatever
	// os.Hostname returned.
	hostname, _ := os.Hostname()

	conf := new(Config)

	// Identity and addressing.
	conf.Name = hostname
	conf.BindAddr = "0.0.0.0"
	conf.BindPort = 7946
	conf.AdvertiseAddr = ""
	conf.AdvertisePort = 7946
	conf.ProtocolVersion = ProtocolVersionMax

	// Failure detection and dissemination tuning.
	conf.TCPTimeout = 10 * time.Second         // Timeout after 10 seconds
	conf.IndirectChecks = 3                    // Use 3 nodes for the indirect ping
	conf.RetransmitMult = 4                    // Retransmit a message 4 * log(N+1) nodes
	conf.SuspicionMult = 5                     // Suspect a node for 5 * log(N+1) * Interval
	conf.PushPullInterval = 30 * time.Second   // Low frequency
	conf.ProbeTimeout = 500 * time.Millisecond // Reasonable RTT time for LAN
	conf.ProbeInterval = 1 * time.Second       // Failure check every second
	conf.GossipNodes = 3                       // Gossip to 3 nodes
	conf.GossipInterval = 200 * time.Millisecond // Gossip more rapidly

	// Compression on, encryption off unless the caller supplies keys.
	conf.EnableCompression = true
	conf.SecretKey = nil
	conf.Keyring = nil

	return conf
}
// DefaultWANConfig works like DefaultLANConfig, however it returns a
// configuration that is optimized for most WAN environments. It starts from
// the LAN defaults and overrides the timing-related fields. The result is
// still very conservative and errs on the side of caution.
func DefaultWANConfig() *Config {
	wan := DefaultLANConfig()

	// Connection and suspicion timing.
	wan.TCPTimeout = 30 * time.Second
	wan.SuspicionMult = 6
	wan.PushPullInterval = 60 * time.Second

	// Probing.
	wan.ProbeTimeout = 3 * time.Second
	wan.ProbeInterval = 5 * time.Second

	// Gossip less frequently, but to an additional node.
	wan.GossipNodes = 4
	wan.GossipInterval = 500 * time.Millisecond

	return wan
}
// DefaultLocalConfig works like DefaultLANConfig, however it returns a
// configuration that is optimized for a local loopback environment. It starts
// from the LAN defaults and overrides the fields below. The result is still
// very conservative and errs on the side of caution.
func DefaultLocalConfig() *Config {
	local := DefaultLANConfig()

	// Connection, retransmission and suspicion tuning.
	local.TCPTimeout = time.Second
	local.IndirectChecks = 1
	local.RetransmitMult = 2
	local.SuspicionMult = 3
	local.PushPullInterval = 15 * time.Second

	// Probing and gossip.
	local.ProbeTimeout = 200 * time.Millisecond
	local.ProbeInterval = time.Second
	local.GossipInterval = 100 * time.Millisecond

	return local
}
// EncryptionEnabled reports whether encryption is enabled, i.e. whether a
// keyring is configured and holds at least one key.
func (c *Config) EncryptionEnabled() bool {
	if c.Keyring == nil {
		return false
	}
	return len(c.Keyring.GetKeys()) > 0
}

View file

@ -0,0 +1,10 @@
package memberlist
// ConflictDelegate is used to inform a client that
// a node has attempted to join which would result in a
// name conflict. This happens if two clients are configured
// with the same name but different addresses.
type ConflictDelegate interface {
	// NotifyConflict is invoked when a name conflict is detected.
	// existing is the node already known under the name and other is the
	// node that attempted to join with it.
	NotifyConflict(existing, other *Node)
}

View file

@ -0,0 +1,36 @@
package memberlist
// Delegate is the interface that clients must implement if they want to hook
// into the gossip layer of Memberlist. All the methods must be thread-safe,
// as they can and generally will be called concurrently.
type Delegate interface {
	// NodeMeta is used to retrieve meta-data about the current node
	// when broadcasting an alive message. Its length is limited to
	// the given byte size. This metadata is available in the Node structure.
	NodeMeta(limit int) []byte

	// NotifyMsg is called when a user-data message is received.
	// Care should be taken that this method does not block, since doing
	// so would block the entire UDP packet receive loop. Additionally, the byte
	// slice may be modified after the call returns, so it should be copied if needed.
	NotifyMsg([]byte)

	// GetBroadcasts is called when user data messages can be broadcast.
	// It can return a list of buffers to send. Each buffer should assume an
	// overhead as provided with a limit on the total byte size allowed.
	// The total byte size of the resulting data to send must not exceed
	// the limit.
	GetBroadcasts(overhead, limit int) [][]byte

	// LocalState is used for a TCP Push/Pull. This is sent to
	// the remote side in addition to the membership information. Any
	// data can be sent here. See MergeRemoteState as well. The `join`
	// boolean indicates this is for a join instead of a push/pull.
	LocalState(join bool) []byte

	// MergeRemoteState is invoked after a TCP Push/Pull. This is the
	// state received from the remote side and is the result of the
	// remote side's LocalState call. The `join`
	// boolean indicates this is for a join instead of a push/pull.
	MergeRemoteState(buf []byte, join bool)
}

View file

@ -0,0 +1,61 @@
package memberlist
// EventDelegate is a simpler delegate that is used only to receive
// notifications about members joining, leaving and being updated. The
// methods in this delegate may be called by multiple goroutines, but never
// concurrently. This allows you to reason about ordering.
type EventDelegate interface {
	// NotifyJoin is invoked when a node is detected to have joined.
	// The Node argument must not be modified.
	NotifyJoin(*Node)

	// NotifyLeave is invoked when a node is detected to have left.
	// The Node argument must not be modified.
	NotifyLeave(*Node)

	// NotifyUpdate is invoked when a node is detected to have
	// updated, usually involving the meta data. The Node argument
	// must not be modified.
	NotifyUpdate(*Node)
}
// ChannelEventDelegate is used to enable an application to receive
// events about joins, leaves and updates over a channel instead of a direct
// function call.
//
// Care must be taken that events are processed in a timely manner from
// the channel, since this delegate will block until an event can be sent.
type ChannelEventDelegate struct {
	// Ch is the send-only channel every NodeEvent is written to.
	Ch chan<- NodeEvent
}
// NodeEventType are the types of events that can be sent from the
// ChannelEventDelegate.
type NodeEventType int

const (
	// NodeJoin is sent by ChannelEventDelegate.NotifyJoin.
	NodeJoin NodeEventType = iota
	// NodeLeave is sent by ChannelEventDelegate.NotifyLeave.
	NodeLeave
	// NodeUpdate is sent by ChannelEventDelegate.NotifyUpdate.
	NodeUpdate
)
// NodeEvent is a single event related to node activity in the memberlist.
// The Node member of this struct must not be directly modified. It is passed
// as a pointer to avoid unnecessary copies. If you wish to modify the node,
// make a copy first.
type NodeEvent struct {
	// Event identifies what happened (join, leave or update).
	Event NodeEventType
	// Node is the node the event is about; treat it as read-only.
	Node *Node
}
// NotifyJoin sends a NodeJoin event for n on the delegate's channel,
// blocking until the event can be sent.
func (c *ChannelEventDelegate) NotifyJoin(n *Node) {
	event := NodeEvent{Event: NodeJoin, Node: n}
	c.Ch <- event
}
// NotifyLeave sends a NodeLeave event for n on the delegate's channel,
// blocking until the event can be sent.
func (c *ChannelEventDelegate) NotifyLeave(n *Node) {
	event := NodeEvent{Event: NodeLeave, Node: n}
	c.Ch <- event
}
// NotifyUpdate sends a NodeUpdate event for n on the delegate's channel,
// blocking until the event can be sent.
func (c *ChannelEventDelegate) NotifyUpdate(n *Node) {
	event := NodeEvent{Event: NodeUpdate, Node: n}
	c.Ch <- event
}

View file

@ -0,0 +1,144 @@
package memberlist
import (
"bytes"
"fmt"
"sync"
)
// Keyring is an ordered set of encryption keys guarded by a mutex. The key at
// index 0 is the primary key.
type Keyring struct {
	// Keys stores the key data used during encryption and decryption. It is
	// ordered in such a way where the first key (index 0) is the primary key,
	// which is used for encrypting messages, and is the first key tried during
	// message decryption.
	//
	// The slice is never modified in place: every change installs a freshly
	// allocated slice, so a slice previously returned by GetKeys stays intact.
	keys [][]byte

	// The keyring lock is held for every read or write of keys.
	l sync.Mutex
}

// init allocates substructures.
func (k *Keyring) init() {
	k.keys = make([][]byte, 0)
}

// NewKeyring constructs a new container for a set of encryption keys. The
// keyring contains all key data used internally by memberlist.
//
// While creating a new keyring, you must do one of:
//   - Omit keys and primary key, effectively disabling encryption
//   - Pass a set of keys plus the primary key
//   - Pass only a primary key
//
// If only a primary key is passed, then it will be automatically added to the
// keyring. If creating a keyring with multiple keys, one key must be designated
// primary by passing it as the primaryKey. If the primaryKey does not exist in
// the list of secondary keys, it will be automatically added at position 0.
//
// An error is returned if keys are supplied without a primary key, or if any
// key is rejected by AddKey.
func NewKeyring(keys [][]byte, primaryKey []byte) (*Keyring, error) {
	keyring := &Keyring{}
	keyring.init()

	if len(keys) > 0 || len(primaryKey) > 0 {
		if len(primaryKey) == 0 {
			return nil, fmt.Errorf("Empty primary key not allowed")
		}
		// Adding the primary key first puts it at position 0.
		if err := keyring.AddKey(primaryKey); err != nil {
			return nil, err
		}
		for _, key := range keys {
			if err := keyring.AddKey(key); err != nil {
				return nil, err
			}
		}
	}

	return keyring, nil
}

// AddKey will install a new key on the ring. Adding a key to the ring will make
// it available for use in decryption. If the key already exists on the ring,
// this function is a no-op. It returns an error if the key is not exactly
// 16 bytes long.
func (k *Keyring) AddKey(key []byte) error {
	// Enforce 16-byte key size
	if len(key) != 16 {
		return fmt.Errorf("key size must be 16 bytes")
	}

	k.l.Lock()
	defer k.l.Unlock()

	// No-op if key is already installed
	for _, installedKey := range k.keys {
		if bytes.Equal(installedKey, key) {
			return nil
		}
	}

	// Keep the current primary key; the first key added becomes primary.
	var primaryKey []byte
	if len(k.keys) > 0 {
		primaryKey = k.keys[0]
	}
	if primaryKey == nil {
		primaryKey = key
	}

	// The three-index slice caps capacity so append always copies instead of
	// writing into the backing array shared with earlier GetKeys results.
	current := k.keys[:len(k.keys):len(k.keys)]
	k.installKeysLocked(append(current, key), primaryKey)
	return nil
}

// UseKey changes the key used to encrypt messages. This is the only key used to
// encrypt messages, so peers should know this key before this method is called.
// It returns an error if the key is not already on the ring.
func (k *Keyring) UseKey(key []byte) error {
	k.l.Lock()
	defer k.l.Unlock()

	for _, installedKey := range k.keys {
		if bytes.Equal(key, installedKey) {
			k.installKeysLocked(k.keys, key)
			return nil
		}
	}
	return fmt.Errorf("Requested key is not in the keyring")
}

// RemoveKey drops a key from the keyring. This will return an error if the key
// requested for removal is currently at position 0 (primary key). Removing a
// key that is not on the ring, including from an empty ring, is a no-op.
func (k *Keyring) RemoveKey(key []byte) error {
	k.l.Lock()
	defer k.l.Unlock()

	// Nothing to remove; also avoids indexing keys[0] on an empty ring.
	if len(k.keys) == 0 {
		return nil
	}
	if bytes.Equal(key, k.keys[0]) {
		return fmt.Errorf("Removing the primary key is not allowed")
	}

	// Build the surviving set in a new slice rather than splicing k.keys in
	// place, which would corrupt slices handed out earlier by GetKeys.
	remaining := make([][]byte, 0, len(k.keys))
	for _, installedKey := range k.keys {
		if !bytes.Equal(key, installedKey) {
			remaining = append(remaining, installedKey)
		}
	}
	if len(remaining) != len(k.keys) {
		k.installKeysLocked(remaining, k.keys[0])
	}
	return nil
}

// installKeys will take out a lock on the keyring, and replace the keys with a
// new set of keys. The key indicated by primaryKey will be installed as the new
// primary key.
func (k *Keyring) installKeys(keys [][]byte, primaryKey []byte) {
	k.l.Lock()
	defer k.l.Unlock()
	k.installKeysLocked(keys, primaryKey)
}

// installKeysLocked replaces the key set with primaryKey at position 0
// followed by every entry of keys that differs from it. The caller must hold
// k.l.
func (k *Keyring) installKeysLocked(keys [][]byte, primaryKey []byte) {
	newKeys := make([][]byte, 0, len(keys)+1)
	newKeys = append(newKeys, primaryKey)
	for _, key := range keys {
		if !bytes.Equal(key, primaryKey) {
			newKeys = append(newKeys, key)
		}
	}
	k.keys = newKeys
}

// GetKeys returns the current set of keys on the ring. The returned slice is
// the ring's internal slice and must not be modified by the caller.
func (k *Keyring) GetKeys() [][]byte {
	k.l.Lock()
	defer k.l.Unlock()

	return k.keys
}

// GetPrimaryKey returns the key on the ring at position 0. This is the key used
// for encrypting messages, and is the first key tried for decrypting messages.
// It returns nil if the ring is empty.
func (k *Keyring) GetPrimaryKey() (key []byte) {
	k.l.Lock()
	defer k.l.Unlock()

	if len(k.keys) > 0 {
		key = k.keys[0]
	}
	return
}

View file

@ -0,0 +1,525 @@
/*
memberlist is a library that manages cluster
membership and member failure detection using a gossip based protocol.
The use cases for such a library are far-reaching: all distributed systems
require membership, and memberlist is a re-usable solution to managing
cluster membership and node failure detection.
memberlist is eventually consistent but converges quickly on average.
The speed at which it converges can be heavily tuned via various knobs
on the protocol. Node failures are detected and network partitions are partially
tolerated by attempting to communicate to potentially dead nodes through
multiple routes.
*/
package memberlist
import (
"fmt"
"log"
"net"
"os"
"strconv"
"sync"
"time"
)
// Memberlist is the runtime state of one cluster member: its configuration,
// network listeners, known nodes and broadcast queue. Instances are built by
// newMemberlist and normally obtained through Create.
type Memberlist struct {
	config         *Config
	shutdown       bool
	shutdownCh     chan struct{}
	leave          bool
	leaveBroadcast chan struct{}

	// Listeners bound to config.BindAddr:config.BindPort.
	udpListener *net.UDPConn
	tcpListener *net.TCPListener
	handoff     chan msgHandoff

	sequenceNum uint32 // Local sequence number
	incarnation uint32 // Local incarnation number

	nodeLock sync.RWMutex
	nodes    []*nodeState          // Known nodes
	nodeMap  map[string]*nodeState // Maps node name -> nodeState (indexed by config.Name for the local node)

	tickerLock sync.Mutex
	tickers    []*time.Ticker
	stopTick   chan struct{}
	probeIndex int

	ackLock     sync.Mutex
	ackHandlers map[uint32]*ackHandler

	// broadcasts is the queue of pending gossip broadcasts.
	broadcasts *TransmitLimitedQueue

	startStopLock sync.Mutex

	// logger writes to config.LogOutput (os.Stderr when that was nil).
	logger *log.Logger
}
// newMemberlist validates the configuration, prepares the keyring, opens the
// TCP and UDP listeners and starts the goroutines that service them.
// It does not schedule execution of background maintenance.
//
// Note that conf is modified: Keyring and LogOutput may be filled in.
func newMemberlist(conf *Config) (*Memberlist, error) {
	// The configured protocol version must lie within the supported range.
	switch {
	case conf.ProtocolVersion < ProtocolVersionMin:
		return nil, fmt.Errorf("Protocol version '%d' too low. Must be in range: [%d, %d]",
			conf.ProtocolVersion, ProtocolVersionMin, ProtocolVersionMax)
	case conf.ProtocolVersion > ProtocolVersionMax:
		return nil, fmt.Errorf("Protocol version '%d' too high. Must be in range: [%d, %d]",
			conf.ProtocolVersion, ProtocolVersionMin, ProtocolVersionMax)
	}

	// A secret key either seeds a new keyring or becomes the primary key of
	// the keyring the caller supplied.
	if secret := conf.SecretKey; len(secret) > 0 {
		if conf.Keyring != nil {
			if err := conf.Keyring.AddKey(secret); err != nil {
				return nil, err
			}
			if err := conf.Keyring.UseKey(secret); err != nil {
				return nil, err
			}
		} else {
			ring, err := NewKeyring(nil, secret)
			if err != nil {
				return nil, err
			}
			conf.Keyring = ring
		}
	}

	tcpLn, err := net.ListenTCP("tcp", &net.TCPAddr{
		IP:   net.ParseIP(conf.BindAddr),
		Port: conf.BindPort,
	})
	if err != nil {
		return nil, fmt.Errorf("Failed to start TCP listener. Err: %s", err)
	}

	udpLn, err := net.ListenUDP("udp", &net.UDPAddr{
		IP:   net.ParseIP(conf.BindAddr),
		Port: conf.BindPort,
	})
	if err != nil {
		// Do not leak the TCP listener opened above.
		tcpLn.Close()
		return nil, fmt.Errorf("Failed to start UDP listener. Err: %s", err)
	}

	// Set the UDP receive window size
	setUDPRecvBuf(udpLn)

	// Default to stderr when no log destination was configured.
	if conf.LogOutput == nil {
		conf.LogOutput = os.Stderr
	}

	m := &Memberlist{
		config:         conf,
		shutdownCh:     make(chan struct{}),
		leaveBroadcast: make(chan struct{}, 1),
		udpListener:    udpLn,
		tcpListener:    tcpLn,
		handoff:        make(chan msgHandoff, 1024),
		nodeMap:        make(map[string]*nodeState),
		ackHandlers:    make(map[uint32]*ackHandler),
		broadcasts:     &TransmitLimitedQueue{RetransmitMult: conf.RetransmitMult},
		logger:         log.New(conf.LogOutput, "", log.LstdFlags),
	}
	// The queue reads the current node count through this callback.
	m.broadcasts.NumNodes = func() int { return len(m.nodes) }

	go m.tcpListen()
	go m.udpListen()
	go m.udpHandler()
	return m, nil
}
// Create builds a new Memberlist from the given configuration. It does not
// connect to any other node yet (see Join), but it starts all the listeners
// so that other nodes can join this memberlist, marks the local node alive
// and schedules background work. After creating a Memberlist, the
// configuration given should not be modified by the user anymore.
func Create(conf *Config) (*Memberlist, error) {
	list, err := newMemberlist(conf)
	if err != nil {
		return nil, err
	}

	if aliveErr := list.setAlive(); aliveErr != nil {
		// Tear down what newMemberlist started before reporting failure.
		list.Shutdown()
		return nil, aliveErr
	}

	list.schedule()
	return list, nil
}
// Join is used to take an existing Memberlist and attempt to join a cluster
// by contacting all the given hosts and performing a state sync. Initially,
// the Memberlist only contains our own state, so doing this will cause
// remote nodes to become aware of the existence of this node, effectively
// joining the cluster.
//
// This returns the number of addresses successfully contacted. The error is
// nil if at least one succeeded; otherwise it is the last error encountered
// while resolving or contacting a host.
func (m *Memberlist) Join(existing []string) (int, error) {
	var (
		joined  int
		lastErr error
	)

	for _, target := range existing {
		addrs, port, err := m.resolveAddr(target)
		if err != nil {
			m.logger.Printf("[WARN] memberlist: Failed to resolve %s: %v", target, err)
			lastErr = err
			continue
		}

		// A host may resolve to several addresses; try each of them.
		for _, addr := range addrs {
			if err := m.pushPullNode(addr, port, true); err != nil {
				lastErr = err
			} else {
				joined++
			}
		}
	}

	// Any success counts as having joined; earlier failures are dropped.
	if joined > 0 {
		return joined, nil
	}
	return joined, lastErr
}
// resolveAddr turns a "host" or "host:port" string into the list of IP
// addresses it resolves to plus a port. If no port is given, the configured
// BindPort is used. On failure the error is returned along with whatever
// was determined up to that point.
func (m *Memberlist) resolveAddr(hostStr string) ([][]byte, uint16, error) {
	resolved := make([][]byte, 0)
	var port uint16

	host, portStr, splitErr := net.SplitHostPort(hostStr)
	addrErr, isAddrErr := splitErr.(*net.AddrError)
	switch {
	case isAddrErr && addrErr.Err == "missing port in address":
		// error, port missing - we can solve this
		port = uint16(m.config.BindPort)
		host = hostStr
	case splitErr != nil:
		// error, but not missing port
		return resolved, port, splitErr
	default:
		parsed, err := strconv.ParseUint(portStr, 10, 16)
		if err != nil {
			// error, when parsing port
			return resolved, port, err
		}
		port = uint16(parsed)
	}

	// Get the addresses that hostPort might resolve to.
	// ResolveTCPAddr requires IPv6 brackets to separate
	// port numbers whereas ParseIP doesn't, but luckily
	// SplitHostPort takes care of the brackets.
	if ip := net.ParseIP(host); ip != nil {
		// Already a literal IP address: no lookup needed.
		return append(resolved, ip), port, nil
	}

	lookedUp, err := net.LookupIP(host)
	if err != nil {
		return resolved, port, err
	}
	for _, ip := range lookedUp {
		resolved = append(resolved, ip)
	}
	return resolved, port, nil
}
// setAlive is used to mark this node as being alive. This is the same
// as if we received an alive notification over our own network channel for
// ourself.
//
// The advertised address is chosen as follows:
//   - config.AdvertiseAddr / AdvertisePort when AdvertiseAddr is set;
//   - otherwise, when bound to "0.0.0.0", the first private IPv4 address
//     found on a local interface, with config.BindPort;
//   - otherwise the address the TCP listener is bound to, with
//     config.BindPort.
//
// It returns an error if the advertise address cannot be parsed, the
// interface addresses cannot be listed, or no private IPv4 address is found.
// It panics if the delegate returns more than MetaMaxSize bytes of meta data.
func (m *Memberlist) setAlive() error {
	var advertiseAddr []byte
	var advertisePort int
	if m.config.AdvertiseAddr != "" {
		// If AdvertiseAddr is not empty, then advertise
		// the given address and port.
		ip := net.ParseIP(m.config.AdvertiseAddr)
		if ip == nil {
			return fmt.Errorf("Failed to parse advertise address!")
		}

		// Ensure IPv4 conversion if necessary
		if ip4 := ip.To4(); ip4 != nil {
			ip = ip4
		}

		advertiseAddr = ip
		advertisePort = m.config.AdvertisePort
	} else {
		if m.config.BindAddr == "0.0.0.0" {
			// Otherwise, if we're not bound to a specific IP,
			// let's list the interfaces on this machine and use
			// the first private IP we find.
			addresses, err := net.InterfaceAddrs()
			if err != nil {
				return fmt.Errorf("Failed to get interface addresses! Err: %v", err)
			}

			// Find private IPv4 address
			for _, rawAddr := range addresses {
				var ip net.IP
				switch addr := rawAddr.(type) {
				case *net.IPAddr:
					ip = addr.IP
				case *net.IPNet:
					ip = addr.IP
				default:
					// Unknown address type: skip it.
					continue
				}

				// Skip anything that is not IPv4.
				if ip.To4() == nil {
					continue
				}
				if !isPrivateIP(ip.String()) {
					continue
				}

				advertiseAddr = ip
				break
			}

			// Failed to find private IP, error
			if advertiseAddr == nil {
				return fmt.Errorf("No private IP address found, and explicit IP not provided")
			}

		} else {
			// Use the IP that we're bound to.
			addr := m.tcpListener.Addr().(*net.TCPAddr)
			advertiseAddr = addr.IP
		}
		// Without an explicit advertise address the bind port is advertised.
		advertisePort = m.config.BindPort
	}

	// Check if this is a public address without encryption
	addrStr := net.IP(advertiseAddr).String()
	if !isPrivateIP(addrStr) && !isLoopbackIP(addrStr) && !m.config.EncryptionEnabled() {
		m.logger.Printf("[WARN] memberlist: Binding to public address without encryption!")
	}

	// Get the node meta data
	var meta []byte
	if m.config.Delegate != nil {
		meta = m.config.Delegate.NodeMeta(MetaMaxSize)
		if len(meta) > MetaMaxSize {
			panic("Node meta data provided is longer than the limit")
		}
	}

	// Build the alive message describing the local node.
	a := alive{
		Incarnation: m.nextIncarnation(),
		Node:        m.config.Name,
		Addr:        advertiseAddr,
		Port:        uint16(advertisePort),
		Meta:        meta,
		Vsn: []uint8{
			ProtocolVersionMin, ProtocolVersionMax, m.config.ProtocolVersion,
			m.config.DelegateProtocolMin, m.config.DelegateProtocolMax,
			m.config.DelegateProtocolVersion,
		},
	}
	m.aliveNode(&a, nil, true)

	return nil
}
// LocalNode returns a pointer to the Node held in the local member's
// entry of the node map. The entry is looked up by the configured node
// name while holding the node read lock.
func (m *Memberlist) LocalNode() *Node {
	m.nodeLock.RLock()
	self := m.nodeMap[m.config.Name]
	m.nodeLock.RUnlock()
	return &self.Node
}
// UpdateNode re-advertises the local node with a new incarnation number.
// It is primarily useful together with a Delegate, whose NodeMeta is
// queried again so that changed meta data gets published.
//
// When at least one other live member is known, the call blocks until the
// update has been broadcast or until timeout elapses, in which case an
// error is returned. A timeout <= 0 waits without a deadline.
//
// Panics if the delegate returns more than MetaMaxSize bytes of meta data.
func (m *Memberlist) UpdateNode(timeout time.Duration) error {
	// Fetch fresh meta data from the delegate, enforcing the size cap.
	var meta []byte
	if delegate := m.config.Delegate; delegate != nil {
		meta = delegate.NodeMeta(MetaMaxSize)
		if len(meta) > MetaMaxSize {
			panic("Node meta data provided is longer than the limit")
		}
	}

	// Look up our own entry; the lock is released before aliveNode runs.
	m.nodeLock.RLock()
	self := m.nodeMap[m.config.Name]
	m.nodeLock.RUnlock()

	// Re-announce ourselves with the address and port we already advertise.
	update := alive{
		Incarnation: m.nextIncarnation(),
		Node:        m.config.Name,
		Addr:        self.Addr,
		Port:        self.Port,
		Meta:        meta,
		Vsn: []uint8{
			ProtocolVersionMin, ProtocolVersionMax, m.config.ProtocolVersion,
			m.config.DelegateProtocolMin, m.config.DelegateProtocolMax,
			m.config.DelegateProtocolVersion,
		},
	}
	broadcastDone := make(chan struct{})
	m.aliveNode(&update, broadcastDone, true)

	// Nobody to tell: nothing to wait for.
	if !m.anyAlive() {
		return nil
	}

	// A nil channel never fires, so a non-positive timeout waits forever.
	var expired <-chan time.Time
	if timeout > 0 {
		expired = time.After(timeout)
	}
	select {
	case <-broadcastDone:
		return nil
	case <-expired:
		return fmt.Errorf("timeout waiting for update broadcast")
	}
}
// SendTo sends msg straight to another node, bypassing the gossip
// mechanism. The payload is framed as a user-data message (a single
// userMsg type byte followed by msg), which the receiving delegate gets
// through NotifyMsg. Delivery is over UDP and therefore best-effort; the
// message must fit into a single UDP datagram after compression.
func (m *Memberlist) SendTo(to net.Addr, msg []byte) error {
	// Type byte first, then the caller's payload.
	packet := append(make([]byte, 0, 1+len(msg)), byte(userMsg))
	packet = append(packet, msg...)

	return m.rawSendMsg(to, packet)
}
// Members returns every known node whose state is not stateDead. The
// returned Node structures are shared with the memberlist and must not be
// modified; copy a Node before changing it.
func (m *Memberlist) Members() []*Node {
	m.nodeLock.RLock()
	defer m.nodeLock.RUnlock()

	live := make([]*Node, 0, len(m.nodes))
	for _, n := range m.nodes {
		if n.State == stateDead {
			continue
		}
		live = append(live, &n.Node)
	}
	return live
}
// NumMembers counts the known nodes whose state is not stateDead. The
// membership may change between this call and a later call to Members, so
// the result must not be used to predict how many entries Members returns.
func (m *Memberlist) NumMembers() (alive int) {
	m.nodeLock.RLock()
	defer m.nodeLock.RUnlock()

	for _, n := range m.nodes {
		if n.State == stateDead {
			continue
		}
		alive++
	}
	return alive
}
// Leave will broadcast a leave message but will not shutdown the background
// listeners, meaning the node will continue participating in gossip and state
// updates.
//
// This will block until the leave message is successfully broadcasted to
// a member of the cluster, if any exist or until a specified timeout
// is reached. A timeout <= 0 waits without a deadline.
//
// This method is safe to call multiple times, but must not be called
// after the cluster is already shut down (it panics in that case).
func (m *Memberlist) Leave(timeout time.Duration) error {
	m.startStopLock.Lock()
	defer m.startStopLock.Unlock()

	if m.shutdown {
		panic("leave after shutdown")
	}

	// Only the first call does any work.
	if m.leave {
		return nil
	}
	m.leave = true

	// nodeMap is read under nodeLock everywhere else in this file
	// (LocalNode, UpdateNode), so take the read lock here as well. It is
	// released before deadNode is called, mirroring how UpdateNode
	// releases it before aliveNode.
	m.nodeLock.RLock()
	state, ok := m.nodeMap[m.config.Name]
	m.nodeLock.RUnlock()
	if !ok {
		m.logger.Printf("[WARN] memberlist: Leave but we're not in the node map.")
		return nil
	}

	// Declare ourselves dead at our current incarnation.
	d := dead{
		Incarnation: state.Incarnation,
		Node:        state.Name,
	}
	m.deadNode(&d)

	// Without any other live member there is nobody to notify.
	if !m.anyAlive() {
		return nil
	}

	// Block until the broadcast goes out. A nil timeout channel never
	// fires, so a non-positive timeout waits indefinitely.
	var timeoutCh <-chan time.Time
	if timeout > 0 {
		timeoutCh = time.After(timeout)
	}
	select {
	case <-m.leaveBroadcast:
		return nil
	case <-timeoutCh:
		return fmt.Errorf("timeout waiting for leave broadcast")
	}
}
// anyAlive reports whether at least one node other than the local one is
// known and not in stateDead.
func (m *Memberlist) anyAlive() bool {
	m.nodeLock.RLock()
	defer m.nodeLock.RUnlock()

	for _, n := range m.nodes {
		if n.Name == m.config.Name {
			continue
		}
		if n.State != stateDead {
			return true
		}
	}
	return false
}
// ProtocolVersion returns the protocol version currently in use by
// this memberlist, as taken from the configuration.
func (m *Memberlist) ProtocolVersion() uint8 {
	// NOTE: Going through a method leaves room to add locking later,
	// should the protocol version ever become changeable at runtime.
	version := m.config.ProtocolVersion
	return version
}
// Shutdown stops the background maintenance and network activity of this
// memberlist, which makes the node appear "dead" to its peers. No leave
// message is broadcast first, so the rest of the cluster has to detect the
// shutdown through probing; call Leave beforehand for a graceful exit.
//
// This method is safe to call multiple times; only the first call has an
// effect. It always returns nil.
func (m *Memberlist) Shutdown() error {
	m.startStopLock.Lock()
	defer m.startStopLock.Unlock()

	if !m.shutdown {
		m.shutdown = true

		// Signal goroutines selecting on shutdownCh, stop the
		// scheduled work, then close both listeners.
		close(m.shutdownCh)
		m.deschedule()
		m.udpListener.Close()
		m.tcpListener.Close()
	}
	return nil
}

View file

@ -0,0 +1,13 @@
package memberlist
// MergeDelegate is used to involve a client in
// a potential cluster merge operation. Namely, when
// a node does a TCP push/pull (as part of a join),
// the delegate is involved and allowed to cancel the join
// based on custom logic. The merge delegate is NOT invoked
// as part of the push-pull anti-entropy.
type MergeDelegate interface {
// NotifyMerge is invoked when a merge could take place.
// Provides a list of the nodes known by the peer.
// Returning true cancels the merge; returning false lets it proceed.
NotifyMerge(peers []*Node) (cancel bool)
}

View file

@ -0,0 +1,852 @@
package memberlist
import (
"bufio"
"bytes"
"encoding/binary"
"fmt"
"io"
"net"
"time"
"github.com/armon/go-metrics"
"github.com/hashicorp/go-msgpack/codec"
)
// This is the minimum and maximum protocol version that we can
// _understand_. We're allowed to speak at any version within this
// range. This range is inclusive.
//
// Note that only ProtocolVersionMin carries an explicit uint8 type;
// ProtocolVersionMax is an untyped constant.
const (
ProtocolVersionMin uint8 = 1
ProtocolVersionMax = 2
)
// messageType is an integer ID of a type of message that can be received
// on network channels from other members. It is sent as the first byte
// of every message.
type messageType uint8
// The list of available message types. The numeric values come from iota
// and are part of the wire format, so the order must not change.
const (
pingMsg messageType = iota
indirectPingMsg
ackRespMsg
suspectMsg
aliveMsg
deadMsg
pushPullMsg
compoundMsg
userMsg // User message, not handled by us
compressMsg
encryptMsg
)
// compressionType is used to specify the compression algorithm
type compressionType uint8
// The list of available compression algorithms.
const (
lzwAlgo compressionType = iota
)
// Size limits and tuning values used by the network layer.
const (
MetaMaxSize = 512 // Maximum size for node meta data
compoundHeaderOverhead = 2 // Assumed header overhead
compoundOverhead = 2 // Assumed overhead per entry in compoundHeader
udpBufSize = 65536 // Size of the buffer allocated for each incoming UDP packet
udpRecvBuf = 2 * 1024 * 1024 // Initial size tried for the UDP socket read buffer
udpSendBuf = 1400 // Byte budget used when piggybacking broadcasts onto an outgoing UDP message
userMsgOverhead = 1 // presumably the messageType byte prefixed to user messages; not used in this part of the file
blockingWarning = 10 * time.Millisecond // Warn if a UDP packet takes this long to process
maxPushStateBytes = 10 * 1024 * 1024 // Upper bound accepted for an encrypted remote push/pull state
)
// ping request sent directly to node
type ping struct {
// SeqNo is echoed back in the ackResp sent for this ping.
SeqNo uint32
// Node is sent so the target can verify they are
// the intended recipient. This is to protect against an agent
// restart with a new name.
Node string
}
// indirect ping sent to an indirect node, asking it to ping
// Target:Port on the sender's behalf
type indirectPingReq struct {
// SeqNo is echoed back to the requester in the relayed ack.
SeqNo uint32
// Target is the IP address of the node to be pinged.
Target []byte
// Port is the port of the node to be pinged; when zero the receiver
// falls back to its own configured bind port.
Port uint16
// Node is the name forwarded in the ping sent to the target.
Node string
}
// ack response is sent for a ping; SeqNo carries the sequence
// number of the ping being acknowledged
type ackResp struct {
SeqNo uint32
}
// suspect is broadcast when we suspect a node is dead
type suspect struct {
// Incarnation is the incarnation number the suspicion refers to.
Incarnation uint32
// Node is the name of the suspected node.
Node string
From string // Include who is suspecting
}
// alive is broadcast when we know a node is alive.
// Overloaded for nodes joining
type alive struct {
// Incarnation is the incarnation number of this announcement.
Incarnation uint32
// Node is the name of the node being announced.
Node string
// Addr and Port are the address the node advertises.
Addr []byte
Port uint16
// Meta is the node meta data (at most MetaMaxSize bytes for the local node).
Meta []byte
// The versions of the protocol/delegate that are being spoken, order:
// pmin, pmax, pcur, dmin, dmax, dcur
Vsn []uint8
}
// dead is broadcast when we confirm a node is dead
// Overloaded for nodes leaving
type dead struct {
// Incarnation is the incarnation number the declaration refers to.
Incarnation uint32
// Node is the name of the node declared dead (or leaving).
Node string
From string // Include who is suspecting
}
// pushPullHeader is used to inform the
// other side how many states we are transferring
type pushPullHeader struct {
// Nodes is the number of pushNodeState entries that follow the header.
Nodes int
UserStateLen int // Encodes the byte length of user state
Join bool // Is this a join request or an anti-entropy run
}
// pushNodeState is used for pushPullReq when we are
// transferring out node states
type pushNodeState struct {
Name string
Addr []byte
Port uint16
Meta []byte
Incarnation uint32
State nodeStateType
// Vsn order is: pmin, pmax, pcur, dmin, dmax, dcur
Vsn []uint8 // Protocol versions
}
// compress is used to wrap an underlying payload
// using a specified compression algorithm
type compress struct {
// Algo identifies the algorithm that produced Buf.
Algo compressionType
// Buf holds the compressed payload.
Buf []byte
}
// msgHandoff is used to transfer a message between goroutines:
// handleCommand queues it on the handoff channel and udpHandler
// consumes it.
type msgHandoff struct {
msgType messageType // Type decoded from the first byte of the message
buf []byte // Message body with the type byte already stripped
from net.Addr // Address the message was received from
}
// encryptionVersion maps the protocol version in use to the encryption
// version to use: protocol version 1 selects encryption version 0, every
// other protocol version selects encryption version 1.
func (m *Memberlist) encryptionVersion() encryptionVersion {
	if m.ProtocolVersion() == 1 {
		return 0
	}
	return 1
}
// setUDPRecvBuf is used to resize the UDP receive window. The function
// attempts to set the read buffer to `udpRecvBuf` and halves the requested
// size after every failure until a request succeeds.
//
// If no positive size is accepted (for example because the connection is
// not usable), the function gives up and leaves the buffer untouched
// instead of retrying forever with a size of zero.
func setUDPRecvBuf(c *net.UDPConn) {
	for size := udpRecvBuf; size > 0; size /= 2 {
		if err := c.SetReadBuffer(size); err == nil {
			return
		}
	}
}
// tcpListen accepts incoming TCP connections in a loop and hands each one
// to handleConn on its own goroutine. Accept errors are logged and the
// loop continues, unless the memberlist has been shut down, in which case
// the loop ends.
//
// NOTE(review): m.shutdown is read here without holding startStopLock,
// which Shutdown holds while setting it — confirm this is intended.
func (m *Memberlist) tcpListen() {
	for {
		conn, err := m.tcpListener.AcceptTCP()
		if err == nil {
			go m.handleConn(conn)
			continue
		}
		if m.shutdown {
			return
		}
		m.logger.Printf("[ERR] memberlist: Error accepting TCP connection: %s", err)
	}
}
// handleConn handles a single incoming TCP connection, i.e. the passive
// side of a push/pull sync: it reads the remote state, answers with the
// local state, verifies the protocol versions, optionally consults the
// merge delegate (joins only), merges the membership state and finally
// hands the remote user state to the delegate. The connection is closed
// on return.
func (m *Memberlist) handleConn(conn *net.TCPConn) {
	m.logger.Printf("[DEBUG] memberlist: Responding to push/pull sync with: %s", conn.RemoteAddr())
	defer conn.Close()
	metrics.IncrCounter([]string{"memberlist", "tcp", "accept"}, 1)

	join, remoteNodes, userState, err := m.readRemoteState(conn)
	if err != nil {
		m.logger.Printf("[ERR] memberlist: Failed to receive remote state: %s", err)
		return
	}

	// A failure to push our own state is only logged; the remote state
	// that was already received is still processed below.
	if err := m.sendLocalState(conn, join); err != nil {
		m.logger.Printf("[ERR] memberlist: Failed to push local state: %s", err)
	}

	if err := m.verifyProtocol(remoteNodes); err != nil {
		m.logger.Printf("[ERR] memberlist: Push/pull verification failed: %s", err)
		return
	}

	// Invoke the merge delegate if any
	if join && m.config.Merge != nil {
		nodes := make([]*Node, len(remoteNodes))
		for idx, n := range remoteNodes {
			peer := &Node{
				Name: n.Name,
				Addr: n.Addr,
				Port: n.Port,
				Meta: n.Meta,
			}
			// Vsn is decoded from the wire, so its length is controlled
			// by the peer. Only index it when all six entries are there;
			// otherwise the version fields stay zero instead of the
			// handler panicking.
			// NOTE(review): verifyProtocol is not visible here and may
			// already reject short vectors — confirm.
			if len(n.Vsn) >= 6 {
				peer.PMin, peer.PMax, peer.PCur = n.Vsn[0], n.Vsn[1], n.Vsn[2]
				peer.DMin, peer.DMax, peer.DCur = n.Vsn[3], n.Vsn[4], n.Vsn[5]
			}
			nodes[idx] = peer
		}
		if m.config.Merge.NotifyMerge(nodes) {
			m.logger.Printf("[WARN] memberlist: Cluster merge canceled")
			return
		}
	}

	// Merge the membership state
	m.mergeState(remoteNodes)

	// Invoke the delegate for user state
	if m.config.Delegate != nil {
		m.config.Delegate.MergeRemoteState(userState, join)
	}
}
// udpListen reads incoming UDP packets in a loop and passes each one to
// ingestPacket. Read errors are logged and the loop continues, unless the
// memberlist has been shut down, in which case the loop ends.
//
// A debug message is logged when more than blockingWarning has elapsed
// between receiving a packet and coming back around to read the next one,
// which indicates that handling the previous packet was slow.
//
// NOTE(review): m.shutdown is read here without holding startStopLock,
// which Shutdown holds while setting it — confirm this is intended.
func (m *Memberlist) udpListen() {
	var lastPacket time.Time
	for {
		// Do a check for potentially blocking operations
		if !lastPacket.IsZero() {
			if diff := time.Since(lastPacket); diff > blockingWarning {
				m.logger.Printf(
					"[DEBUG] memberlist: Potential blocking operation. Last command took %v",
					diff)
			}
		}

		// A fresh buffer per packet is required because slices of it are
		// handed off to other goroutines by the handlers.
		// TODO: Use Sync.Pool eventually
		buf := make([]byte, udpBufSize)

		// Read a packet
		n, addr, err := m.udpListener.ReadFrom(buf)
		if err != nil {
			if m.shutdown {
				return
			}
			m.logger.Printf("[ERR] memberlist: Error reading UDP packet: %s", err)
			continue
		}

		// Check the length. Report the number of bytes actually received,
		// not the capacity of the read buffer.
		if n < 1 {
			m.logger.Printf("[ERR] memberlist: UDP packet too short (%d bytes). From: %s",
				n, addr)
			continue
		}

		// Capture the current time
		lastPacket = time.Now()

		// Ingest this packet
		metrics.IncrCounter([]string{"memberlist", "udp", "received"}, float32(n))
		m.ingestPacket(buf[:n], addr)
	}
}
// ingestPacket processes one raw UDP packet received from the given
// address. When encryption is enabled the packet is first decrypted with
// the keys of the keyring; packets that fail to decrypt are logged and
// dropped. The (plaintext) buffer is then passed to handleCommand.
func (m *Memberlist) ingestPacket(buf []byte, from net.Addr) {
	if m.config.EncryptionEnabled() {
		decrypted, err := decryptPayload(m.config.Keyring.GetKeys(), buf, nil)
		if err != nil {
			m.logger.Printf("[ERR] memberlist: Decrypt packet failed: %v", err)
			return
		}
		buf = decrypted
	}

	m.handleCommand(buf, from)
}
// handleCommand dispatches a single message received over UDP. The first
// byte of buf is the message type, the remainder is the message body.
//
// Compound and compressed messages are unpacked (which calls back into
// handleCommand), pings and acks are handled inline, and suspect, alive,
// dead and user messages are queued on the handoff channel for udpHandler.
// If that queue is full the message is dropped with a warning. Messages
// with no type byte or with an unknown type are logged and ignored.
func (m *Memberlist) handleCommand(buf []byte, from net.Addr) {
	// An empty buffer can reach this point (e.g. an empty part of a
	// compound message or an empty decompressed/decrypted payload);
	// indexing it would panic.
	if len(buf) < 1 {
		m.logger.Printf("[ERR] memberlist: Missing message type byte. From: %s", from)
		return
	}

	// Decode the message type
	msgType := messageType(buf[0])
	buf = buf[1:]

	// Switch on the msgType
	switch msgType {
	case compoundMsg:
		m.handleCompound(buf, from)
	case compressMsg:
		m.handleCompressed(buf, from)

	case pingMsg:
		m.handlePing(buf, from)
	case indirectPingMsg:
		m.handleIndirectPing(buf, from)
	case ackRespMsg:
		m.handleAck(buf, from)

	case suspectMsg, aliveMsg, deadMsg, userMsg:
		// Non-blocking send so a slow consumer cannot stall the caller.
		select {
		case m.handoff <- msgHandoff{msgType, buf, from}:
		default:
			m.logger.Printf("[WARN] memberlist: UDP handler queue full, dropping message (%d)", msgType)
		}

	default:
		m.logger.Printf("[ERR] memberlist: UDP msg type (%d) not supported. From: %s", msgType, from)
	}
}
// udpHandler consumes messages from the handoff channel and runs the
// matching handler for suspect, alive, dead and user messages. It is
// decoupled from the UDP listener so that slow handlers do not block the
// listener, which could delay ping/ack messages. The loop ends when
// shutdownCh is closed.
func (m *Memberlist) udpHandler() {
	for {
		select {
		case <-m.shutdownCh:
			return

		case item := <-m.handoff:
			switch item.msgType {
			case suspectMsg:
				m.handleSuspect(item.buf, item.from)
			case aliveMsg:
				m.handleAlive(item.buf, item.from)
			case deadMsg:
				m.handleDead(item.buf, item.from)
			case userMsg:
				m.handleUser(item.buf, item.from)
			default:
				m.logger.Printf("[ERR] memberlist: UDP msg type (%d) not supported. From: %s (handler)", item.msgType, item.from)
			}
		}
	}
}
// handleCompound unpacks a compound message and feeds every contained
// part to handleCommand. A decode failure is logged and the message is
// dropped; truncated parts are reported with a warning and the parts that
// were decoded are still processed.
func (m *Memberlist) handleCompound(buf []byte, from net.Addr) {
	truncated, parts, err := decodeCompoundMessage(buf)
	if err != nil {
		m.logger.Printf("[ERR] memberlist: Failed to decode compound request: %s", err)
		return
	}
	if truncated > 0 {
		m.logger.Printf("[WARN] memberlist: Compound request had %d truncated messages", truncated)
	}

	for i := range parts {
		m.handleCommand(parts[i], from)
	}
}
// handlePing answers a direct ping with an ack carrying the same sequence
// number. If the ping names a node and that name is not ours, the ping is
// ignored with a warning.
func (m *Memberlist) handlePing(buf []byte, from net.Addr) {
	var req ping
	err := decode(buf, &req)
	if err != nil {
		m.logger.Printf("[ERR] memberlist: Failed to decode ping request: %s", err)
		return
	}

	// If node is provided, verify that it is for us
	addressedToOther := req.Node != "" && req.Node != m.config.Name
	if addressedToOther {
		m.logger.Printf("[WARN] memberlist: Got ping for unexpected node '%s'", req.Node)
		return
	}

	reply := ackResp{req.SeqNo}
	if err := m.encodeAndSendMsg(from, ackRespMsg, &reply); err != nil {
		m.logger.Printf("[ERR] memberlist: Failed to send ack: %s", err)
	}
}
// handleIndirectPing pings a target node on behalf of the requester. A
// fresh local sequence number is used for the outgoing ping, and an ack
// handler is registered (with the probe timeout) that relays an ack with
// the requester's original sequence number back to the requester.
func (m *Memberlist) handleIndirectPing(buf []byte, from net.Addr) {
	var req indirectPingReq
	err := decode(buf, &req)
	if err != nil {
		m.logger.Printf("[ERR] memberlist: Failed to decode indirect ping request: %s", err)
		return
	}

	// For proto versions < 2, there is no port provided. Mask old
	// behavior by using the configured port
	if m.ProtocolVersion() < 2 || req.Port == 0 {
		req.Port = uint16(m.config.BindPort)
	}

	// Build the ping for the real target
	seq := m.nextSeqNo()
	probe := ping{SeqNo: seq, Node: req.Node}
	target := &net.UDPAddr{IP: req.Target, Port: int(req.Port)}

	// When the target acks, forward an ack for the original request
	relayAck := func() {
		reply := ackResp{req.SeqNo}
		if err := m.encodeAndSendMsg(from, ackRespMsg, &reply); err != nil {
			m.logger.Printf("[ERR] memberlist: Failed to forward ack: %s", err)
		}
	}
	m.setAckHandler(seq, relayAck, m.config.ProbeTimeout)

	// Send the ping
	if err := m.encodeAndSendMsg(target, pingMsg, &probe); err != nil {
		m.logger.Printf("[ERR] memberlist: Failed to send ping: %s", err)
	}
}
// handleAck decodes an ack response and invokes the ack handler
// registered for its sequence number. Undecodable acks are logged and
// dropped.
func (m *Memberlist) handleAck(buf []byte, from net.Addr) {
	var reply ackResp
	err := decode(buf, &reply)
	if err != nil {
		m.logger.Printf("[ERR] memberlist: Failed to decode ack response: %s", err)
		return
	}
	m.invokeAckHandler(reply.SeqNo)
}
// handleSuspect decodes a suspect message and passes it to suspectNode.
// Undecodable messages are logged and dropped.
func (m *Memberlist) handleSuspect(buf []byte, from net.Addr) {
	var msg suspect
	err := decode(buf, &msg)
	if err != nil {
		m.logger.Printf("[ERR] memberlist: Failed to decode suspect message: %s", err)
		return
	}
	m.suspectNode(&msg)
}
// handleAlive decodes an alive message and passes it to aliveNode.
// Undecodable messages are logged and dropped.
func (m *Memberlist) handleAlive(buf []byte, from net.Addr) {
	var msg alive
	err := decode(buf, &msg)
	if err != nil {
		m.logger.Printf("[ERR] memberlist: Failed to decode alive message: %s", err)
		return
	}

	// For proto versions < 2, there is no port provided. Mask old
	// behavior by using the configured port
	portMissing := m.ProtocolVersion() < 2 || msg.Port == 0
	if portMissing {
		msg.Port = uint16(m.config.BindPort)
	}

	m.aliveNode(&msg, nil, false)
}
// handleDead decodes a dead message and passes it to deadNode.
// Undecodable messages are logged and dropped.
func (m *Memberlist) handleDead(buf []byte, from net.Addr) {
	var msg dead
	err := decode(buf, &msg)
	if err != nil {
		m.logger.Printf("[ERR] memberlist: Failed to decode dead message: %s", err)
		return
	}
	m.deadNode(&msg)
}
// handleUser forwards incoming user data to the configured delegate via
// NotifyMsg. Without a delegate the data is discarded.
func (m *Memberlist) handleUser(buf []byte, from net.Addr) {
	if delegate := m.config.Delegate; delegate != nil {
		delegate.NotifyMsg(buf)
	}
}
// handleCompressed decompresses a compressed message and processes the
// resulting payload through handleCommand. Payloads that fail to
// decompress are logged and dropped.
func (m *Memberlist) handleCompressed(buf []byte, from net.Addr) {
	inner, err := decompressPayload(buf)
	if err != nil {
		m.logger.Printf("[ERR] memberlist: Failed to decompress payload: %v", err)
		return
	}

	// The payload is itself a complete message, type byte included
	m.handleCommand(inner, from)
}
// encodeAndSendMsg encodes msg with the given message type and sends the
// result to the destination through sendMsg. The first error from either
// step is returned.
func (m *Memberlist) encodeAndSendMsg(to net.Addr, msgType messageType, msg interface{}) error {
	encoded, err := encode(msgType, msg)
	if err != nil {
		return err
	}
	return m.sendMsg(to, encoded.Bytes())
}
// sendMsg sends a UDP message to another host. Pending broadcasts are
// piggybacked opportunistically: whatever fits into the remaining byte
// budget (udpSendBuf minus the message, the compound header and, when
// encryption is enabled, the encryption overhead) is bundled with msg
// into one compound message.
func (m *Memberlist) sendMsg(to net.Addr, msg []byte) error {
	// Work out how many bytes are left for piggybacked broadcasts
	budget := udpSendBuf - len(msg) - compoundHeaderOverhead
	if m.config.EncryptionEnabled() {
		budget -= encryptOverhead(m.encryptionVersion())
	}
	piggyback := m.getBroadcasts(compoundOverhead, budget)

	// Nothing to add: send the message unchanged
	if len(piggyback) == 0 {
		return m.rawSendMsg(to, msg)
	}

	// Our message goes first, followed by the broadcasts
	bundle := append(make([][]byte, 0, 1+len(piggyback)), msg)
	bundle = append(bundle, piggyback...)

	compound := makeCompoundMessage(bundle)
	return m.rawSendMsg(to, compound.Bytes())
}
// rawSendMsg sends a UDP message to another host without adding any
// other messages to it. If configured, the message is compressed (the
// compressed form is used only when it is smaller) and then encrypted
// with the primary key before it is written to the UDP socket.
//
// A compression failure is logged and the uncompressed message is sent;
// an encryption failure is logged and returned.
func (m *Memberlist) rawSendMsg(to net.Addr, msg []byte) error {
	if m.config.EnableCompression {
		compressed, err := compressPayload(msg)
		switch {
		case err != nil:
			m.logger.Printf("[WARN] memberlist: Failed to compress payload: %v", err)
		case compressed.Len() < len(msg):
			// Only use compression if it reduced the size
			msg = compressed.Bytes()
		}
	}

	if m.config.EncryptionEnabled() {
		var sealed bytes.Buffer
		key := m.config.Keyring.GetPrimaryKey()
		if err := encryptPayload(m.encryptionVersion(), key, msg, nil, &sealed); err != nil {
			m.logger.Printf("[ERR] memberlist: Encryption of message failed: %v", err)
			return err
		}
		msg = sealed.Bytes()
	}

	metrics.IncrCounter([]string{"memberlist", "udp", "sent"}, float32(len(msg)))
	_, err := m.udpListener.WriteTo(msg, to)
	return err
}
// sendAndReceiveState initiates a push/pull over TCP with a remote node:
// it dials addr:port (bounded by the configured TCP timeout), sends the
// local state and reads the remote state in return. It returns the remote
// node states and the remote user state. The connection is closed on
// return.
func (m *Memberlist) sendAndReceiveState(addr []byte, port uint16, join bool) ([]pushNodeState, []byte, error) {
	// Attempt to connect
	target := net.TCPAddr{IP: addr, Port: int(port)}
	dialer := net.Dialer{Timeout: m.config.TCPTimeout}
	conn, err := dialer.Dial("tcp", target.String())
	if err != nil {
		return nil, nil, err
	}
	defer conn.Close()

	m.logger.Printf("[DEBUG] memberlist: Initiating push/pull sync with: %s", conn.RemoteAddr())
	metrics.IncrCounter([]string{"memberlist", "tcp", "connect"}, 1)

	// Push first, then pull
	if err := m.sendLocalState(conn, join); err != nil {
		return nil, nil, err
	}
	_, remoteNodes, userState, err := m.readRemoteState(conn)
	if err != nil {
		return nil, nil, fmt.Errorf("Reading remote state failed: %v", err)
	}

	return remoteNodes, userState, nil
}
// sendLocalState sends our local state over a TCP connection. The
// payload is the pushPullMsg type byte, a msgpack-encoded pushPullHeader,
// one msgpack-encoded pushNodeState per known node and finally the raw
// user state obtained from the delegate. The whole payload is compressed
// and/or encrypted when configured, then written in a single call.
//
// A compression failure is logged and the uncompressed payload is sent;
// an encryption failure is logged and returned.
func (m *Memberlist) sendLocalState(conn net.Conn, join bool) error {
	// Setup a deadline
	conn.SetDeadline(time.Now().Add(m.config.TCPTimeout))

	// Snapshot the node states under the read lock
	m.nodeLock.RLock()
	localNodes := make([]pushNodeState, len(m.nodes))
	for idx, n := range m.nodes {
		localNodes[idx] = pushNodeState{
			Name:        n.Name,
			Addr:        n.Addr,
			Port:        n.Port,
			Meta:        n.Meta,
			Incarnation: n.Incarnation,
			State:       n.State,
			Vsn: []uint8{
				n.PMin, n.PMax, n.PCur,
				n.DMin, n.DMax, n.DCur,
			},
		}
	}
	m.nodeLock.RUnlock()

	// Get the delegate state
	var userData []byte
	if m.config.Delegate != nil {
		userData = m.config.Delegate.LocalState(join)
	}

	// Everything is assembled in memory first
	var out bytes.Buffer
	header := pushPullHeader{Nodes: len(localNodes), UserStateLen: len(userData), Join: join}
	handle := codec.MsgpackHandle{}
	enc := codec.NewEncoder(&out, &handle)

	// Type byte, header, then one entry per node
	if err := out.WriteByte(byte(pushPullMsg)); err != nil {
		return err
	}
	if err := enc.Encode(&header); err != nil {
		return err
	}
	for i := range localNodes {
		if err := enc.Encode(&localNodes[i]); err != nil {
			return err
		}
	}

	// The user state follows unencoded
	if userData != nil {
		if _, err := out.Write(userData); err != nil {
			return err
		}
	}

	payload := out.Bytes()

	// Check if compression is enabled
	if m.config.EnableCompression {
		if compressed, err := compressPayload(out.Bytes()); err != nil {
			m.logger.Printf("[ERROR] memberlist: Failed to compress local state: %v", err)
		} else {
			payload = compressed.Bytes()
		}
	}

	// Check if encryption is enabled
	if m.config.EncryptionEnabled() {
		sealed, err := m.encryptLocalState(payload)
		if err != nil {
			m.logger.Printf("[ERROR] memberlist: Failed to encrypt local state: %v", err)
			return err
		}
		payload = sealed
	}

	// Write out the entire send buffer
	metrics.IncrCounter([]string{"memberlist", "tcp", "sent"}, float32(len(payload)))
	_, err := conn.Write(payload)
	return err
}
// encryptLocalState encrypts the local state before it is sent. The
// result is framed as: the encryptMsg type byte, the encrypted length as a
// 4-byte big-endian integer, then the encrypted payload. The 5-byte frame
// header is passed to encryptPayload as additional data.
func (m *Memberlist) encryptLocalState(sendBuf []byte) ([]byte, error) {
	vsn := m.encryptionVersion()

	// Length prefix describing the encrypted payload that follows
	var lenPrefix [4]byte
	binary.BigEndian.PutUint32(lenPrefix[:], uint32(encryptedLength(vsn, len(sendBuf))))

	var out bytes.Buffer
	out.WriteByte(byte(encryptMsg))
	out.Write(lenPrefix[:])

	// Append the cipher text right after the frame header
	key := m.config.Keyring.GetPrimaryKey()
	if err := encryptPayload(vsn, key, sendBuf, out.Bytes()[:5], &out); err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}
// decryptRemoteState reads an encrypted remote state from bufConn and
// decrypts it with the keys of the keyring. The encryptMsg type byte is
// expected to have been consumed already; the next 4 bytes give the
// big-endian length of the cipher text. Lengths above maxPushStateBytes
// are rejected before anything further is read.
func (m *Memberlist) decryptRemoteState(bufConn io.Reader) ([]byte, error) {
	// Rebuild the 5-byte frame header: type byte plus length prefix
	var frame bytes.Buffer
	frame.WriteByte(byte(encryptMsg))
	if _, err := io.CopyN(&frame, bufConn, 4); err != nil {
		return nil, err
	}

	// Ensure we aren't asked to download too much. This is to guard against
	// an attack vector where a huge amount of state is sent
	payloadLen := binary.BigEndian.Uint32(frame.Bytes()[1:5])
	if payloadLen > maxPushStateBytes {
		return nil, fmt.Errorf("Remote node state is larger than limit (%d)", payloadLen)
	}

	// Read in the rest of the payload
	if _, err := io.CopyN(&frame, bufConn, int64(payloadLen)); err != nil {
		return nil, err
	}

	// The header serves as additional data, the rest is cipher text
	raw := frame.Bytes()
	header, sealed := raw[:5], raw[5:]

	keys := m.config.Keyring.GetKeys()
	return decryptPayload(keys, sealed, header)
}
// readRemoteState reads the remote state from a connection. It returns
// whether the remote side flagged the exchange as a join, the remote node
// states and the remote user state.
//
// The payload may be encrypted and/or compressed; both layers are removed
// here. Encryption must match the local configuration: an encrypted
// payload without local encryption, or a plaintext payload with local
// encryption enabled, is rejected. Anything other than a push/pull
// message, an empty decrypted or decompressed payload, or a header with
// negative counts is rejected with an error as well.
//
// On a decode error after the header was read, the partially filled node
// slice is returned together with the error.
func (m *Memberlist) readRemoteState(conn net.Conn) (bool, []pushNodeState, []byte, error) {
	// Setup a deadline
	conn.SetDeadline(time.Now().Add(m.config.TCPTimeout))

	// Created a buffered reader
	var bufConn io.Reader = bufio.NewReader(conn)

	// Read the message type
	buf := [1]byte{0}
	if _, err := bufConn.Read(buf[:]); err != nil {
		return false, nil, nil, err
	}
	msgType := messageType(buf[0])

	// Check if the message is encrypted
	if msgType == encryptMsg {
		if !m.config.EncryptionEnabled() {
			return false, nil, nil,
				fmt.Errorf("Remote state is encrypted and encryption is not configured")
		}

		plain, err := m.decryptRemoteState(bufConn)
		if err != nil {
			return false, nil, nil, err
		}
		// The plaintext must at least carry its own type byte
		if len(plain) == 0 {
			return false, nil, nil, fmt.Errorf("Decrypted remote state is empty")
		}

		// Reset message type and bufConn
		msgType = messageType(plain[0])
		bufConn = bytes.NewReader(plain[1:])
	} else if m.config.EncryptionEnabled() {
		return false, nil, nil,
			fmt.Errorf("Encryption is configured but remote state is not encrypted")
	}

	// Get the msgPack decoders
	hd := codec.MsgpackHandle{}
	dec := codec.NewDecoder(bufConn, &hd)

	// Check if we have a compressed message
	if msgType == compressMsg {
		var c compress
		if err := dec.Decode(&c); err != nil {
			return false, nil, nil, err
		}
		decomp, err := decompressBuffer(&c)
		if err != nil {
			return false, nil, nil, err
		}
		// The decompressed data must at least carry its own type byte
		if len(decomp) == 0 {
			return false, nil, nil, fmt.Errorf("Decompressed remote state is empty")
		}

		// Reset the message type
		msgType = messageType(decomp[0])

		// Create a new bufConn
		bufConn = bytes.NewReader(decomp[1:])

		// Create a new decoder
		dec = codec.NewDecoder(bufConn, &hd)
	}

	// Quit if not push/pull
	if msgType != pushPullMsg {
		err := fmt.Errorf("received invalid msgType (%d)", msgType)
		return false, nil, nil, err
	}

	// Read the push/pull header
	var header pushPullHeader
	if err := dec.Decode(&header); err != nil {
		return false, nil, nil, err
	}

	// Both counts come from the peer. A negative node count would make
	// the allocation below panic, and a negative user state length would
	// silently skip the user state.
	if header.Nodes < 0 || header.UserStateLen < 0 {
		return false, nil, nil, fmt.Errorf(
			"Invalid push/pull header (nodes: %d, user state length: %d)",
			header.Nodes, header.UserStateLen)
	}

	// Allocate space for the transfer
	remoteNodes := make([]pushNodeState, header.Nodes)

	// Try to decode all the states
	for i := 0; i < header.Nodes; i++ {
		if err := dec.Decode(&remoteNodes[i]); err != nil {
			return false, remoteNodes, nil, err
		}
	}

	// Read the remote user state into a buffer
	var userBuf []byte
	if header.UserStateLen > 0 {
		userBuf = make([]byte, header.UserStateLen)
		read, err := io.ReadFull(bufConn, userBuf)
		if err == nil && read != header.UserStateLen {
			err = fmt.Errorf(
				"Failed to read full user state (%d / %d)",
				read, header.UserStateLen)
		}
		if err != nil {
			return false, remoteNodes, nil, err
		}
	}

	// For proto versions < 2, there is no port provided. Mask old
	// behavior by using the configured port
	for idx := range remoteNodes {
		if m.ProtocolVersion() < 2 || remoteNodes[idx].Port == 0 {
			remoteNodes[idx].Port = uint16(m.config.BindPort)
		}
	}

	return header.Join, remoteNodes, userBuf, nil
}

View file

@ -0,0 +1,167 @@
package memberlist
import (
"sort"
"sync"
)
// TransmitLimitedQueue is used to queue messages to broadcast to
// the cluster (via gossip) but limits the number of transmits per
// message. It also prioritizes messages with lower transmit counts
// (hence newer messages).
type TransmitLimitedQueue struct {
	// NumNodes returns the number of nodes in the cluster. This is
	// used to determine the retransmit count, which is calculated
	// based on the log of this.
	NumNodes func() int

	// RetransmitMult is the multiplier used to determine the maximum
	// number of retransmissions attempted.
	RetransmitMult int

	// The embedded mutex guards bcQueue.
	sync.Mutex
	bcQueue limitedBroadcasts
}

// limitedBroadcast is a queued Broadcast together with the number of
// times it has been handed out for transmission.
type limitedBroadcast struct {
	transmits int // Number of transmissions attempted.
	b         Broadcast
}

// limitedBroadcasts is the queue representation used by
// TransmitLimitedQueue.
type limitedBroadcasts []*limitedBroadcast

// Broadcast is something that can be broadcasted via gossip to
// the memberlist cluster.
type Broadcast interface {
	// Invalidates checks if enqueuing the current broadcast
	// invalidates a previous broadcast
	Invalidates(b Broadcast) bool

	// Returns a byte form of the message
	Message() []byte

	// Finished is invoked when the message will no longer
	// be broadcast, either due to invalidation or to the
	// transmit limit being reached
	Finished()
}

// QueueBroadcast is used to enqueue a broadcast. Every broadcast already
// in the queue that b invalidates is removed first, and Finished is
// called on each removed broadcast. b is then appended with a transmit
// count of zero.
func (q *TransmitLimitedQueue) QueueBroadcast(b Broadcast) {
	q.Lock()
	defer q.Unlock()

	// Filter the queue in place. Every entry is examined exactly once,
	// so an invalidated entry that directly follows another invalidated
	// entry is not skipped.
	kept := q.bcQueue[:0]
	for _, queued := range q.bcQueue {
		if b.Invalidates(queued.b) {
			queued.b.Finished()
			continue
		}
		kept = append(kept, queued)
	}

	// Clear the vacated tail so removed entries are not kept reachable
	// through the backing array.
	for i := len(kept); i < len(q.bcQueue); i++ {
		q.bcQueue[i] = nil
	}

	// Append to the queue
	q.bcQueue = append(kept, &limitedBroadcast{transmits: 0, b: b})
}
// GetBroadcasts returns queued broadcast messages whose combined size,
// counting `overhead` extra bytes per message, does not exceed `limit`.
// The queue is walked from its end towards its front; messages that do
// not fit are skipped rather than ending the walk.
//
// Every returned message has its transmit count incremented. A message
// that reaches the retransmit limit is finished and removed from the
// queue. If anything is returned, the queue is re-sorted afterwards.
func (q *TransmitLimitedQueue) GetBroadcasts(overhead, limit int) [][]byte {
	q.Lock()
	defer q.Unlock()

	// Fast path the default case
	if len(q.bcQueue) == 0 {
		return nil
	}

	maxTransmits := retransmitLimit(q.RetransmitMult, q.NumNodes())
	var (
		used   int
		toSend [][]byte
	)
	for i := len(q.bcQueue) - 1; i >= 0; i-- {
		entry := q.bcQueue[i]
		payload := entry.b.Message()

		// Skip anything that would blow the byte budget
		cost := overhead + len(payload)
		if used+cost > limit {
			continue
		}
		used += cost
		toSend = append(toSend, payload)

		// Retire the entry once it has been sent often enough
		entry.transmits++
		if entry.transmits < maxTransmits {
			continue
		}
		entry.b.Finished()

		// Fill the hole with the last element, which the backwards walk
		// has already visited, then shrink the queue by one.
		last := len(q.bcQueue) - 1
		q.bcQueue[i] = q.bcQueue[last]
		q.bcQueue[last] = nil
		q.bcQueue = q.bcQueue[:last]
	}

	// If we are sending anything, we need to re-sort to deal
	// with adjusted transmit counts
	if len(toSend) > 0 {
		q.bcQueue.Sort()
	}
	return toSend
}
// NumQueued reports how many messages are currently queued.
func (q *TransmitLimitedQueue) NumQueued() int {
	q.Lock()
	queued := len(q.bcQueue)
	q.Unlock()
	return queued
}
// Reset drops every queued message, calling Finished on each of them
// before the queue is emptied.
func (q *TransmitLimitedQueue) Reset() {
	q.Lock()
	defer q.Unlock()

	for i := range q.bcQueue {
		q.bcQueue[i].b.Finished()
	}
	q.bcQueue = nil
}
// Prune keeps the last maxRetain entries of the queue and discards the
// ones before them, calling Finished on every discarded message. This can
// be used to prevent unbounded queue sizes. Nothing happens when the queue
// holds fewer than maxRetain entries.
func (q *TransmitLimitedQueue) Prune(maxRetain int) {
	q.Lock()
	defer q.Unlock()

	// Do nothing if queue size is less than the limit
	total := len(q.bcQueue)
	if total < maxRetain {
		return
	}

	// The first `excess` entries are the ones being dropped
	excess := total - maxRetain
	for i := 0; i < excess; i++ {
		q.bcQueue[i].b.Finished()
	}

	// Slide the survivors to the front and cut the queue down
	copy(q.bcQueue[0:], q.bcQueue[excess:])
	q.bcQueue = q.bcQueue[:maxRetain]
}
// Len returns the number of queued broadcasts (sort.Interface).
func (b limitedBroadcasts) Len() int {
	count := len(b)
	return count
}
// Less orders entries by ascending transmit count (sort.Interface).
func (b limitedBroadcasts) Less(i, j int) bool {
	return b[j].transmits > b[i].transmits
}
// Swap exchanges the entries at positions i and j (sort.Interface).
func (b limitedBroadcasts) Swap(i, j int) {
	held := b[i]
	b[i] = b[j]
	b[j] = held
}
// Sort orders the queue by descending transmit count, so the entries
// with the fewest transmits end up at the end of the slice.
func (b limitedBroadcasts) Sort() {
	descending := sort.Reverse(b)
	sort.Sort(descending)
}

View file

@ -0,0 +1,198 @@
package memberlist
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"fmt"
"io"
)
/*
Encrypted messages are prefixed with an encryptionVersion byte
that is used for us to be able to properly encode/decode. We
currently support the following versions:
0 - AES-GCM 128, using PKCS7 padding
1 - AES-GCM 128, no padding. Padding not needed, caused bloat.
*/
type encryptionVersion uint8
// Lowest and highest encryption versions listed above.
const (
minEncryptionVersion encryptionVersion = 0
maxEncryptionVersion encryptionVersion = 1
)
// Sizes, in bytes, of the parts that make up an encrypted message.
const (
versionSize = 1 // Leading encryption version byte
nonceSize = 12 // Random nonce written after the version byte
tagSize = 16 // Authentication tag
maxPadOverhead = 16 // Largest padding added for version 0
blockSize = aes.BlockSize // Block size used for version 0 padding
)
// pkcs7encode appends PKCS7 padding to buf so that the buffer length,
// not counting its first `ignore` bytes, becomes a multiple of blockSize.
// Each padding byte holds the padding length. Data that is already
// aligned receives a full extra block of padding.
func pkcs7encode(buf *bytes.Buffer, ignore, blockSize int) {
	payloadLen := buf.Len() - ignore
	padLen := blockSize - (payloadLen % blockSize)
	padByte := byte(padLen)
	for remaining := padLen; remaining > 0; remaining-- {
		buf.WriteByte(padByte)
	}
}
// pkcs7decode strips the padding from a padded buffer: the value of the
// last byte is taken as the number of trailing bytes to remove. The
// returned slice shares its backing array with buf. Panics when buf is
// empty. The blockSize argument is not consulted.
func pkcs7decode(buf []byte, blockSize int) []byte {
	size := len(buf)
	if size == 0 {
		panic("Cannot decode a PKCS7 buffer of zero length")
	}
	padLen := int(buf[size-1])
	return buf[:size-padLen]
}
// encryptOverhead reports the largest number of bytes that encryption can
// add to a message for the given version. It panics on an unknown version.
func encryptOverhead(vsn encryptionVersion) int {
	if vsn == 0 {
		// Version: 1, IV: 12, Padding: 16, Tag: 16
		return 45
	}
	if vsn == 1 {
		// Version: 1, IV: 12, Tag: 16
		return 29
	}
	panic("unsupported version")
}
// encryptedLength returns the number of bytes needed to hold the encrypted
// form of a message of inp bytes: version byte, nonce, message and tag,
// plus PKCS7 padding when vsn is 0 (versions >= 1 are not padded).
func encryptedLength(vsn encryptionVersion, inp int) int {
	total := versionSize + nonceSize + inp + tagSize
	if vsn < 1 {
		// Only version 0 pads the plaintext to the block size
		total += blockSize - (inp % blockSize)
	}
	return total
}
// encryptPayload encrypts msg with key using AES in GCM mode and appends
// the result to dst as: version byte, random nonce, ciphertext and tag.
// data is passed to GCM as additional authenticated data.
//
// For version 0 the plaintext is PKCS7 padded before sealing; version 1
// seals msg as is.
//
// An error is returned if the cipher cannot be created from key or if the
// random nonce cannot be read. In the latter case dst is restored to the
// length it had on entry, so no partial output is left behind.
func encryptPayload(vsn encryptionVersion, key []byte, msg []byte, data []byte, dst *bytes.Buffer) error {
	// Get the AES block cipher
	aesBlock, err := aes.NewCipher(key)
	if err != nil {
		return err
	}

	// Get the GCM cipher mode
	gcm, err := cipher.NewGCM(aesBlock)
	if err != nil {
		return err
	}

	// Grow the buffer to make room for everything
	offset := dst.Len()
	dst.Grow(encryptedLength(vsn, len(msg)))

	// Write the encryption version
	dst.WriteByte(byte(vsn))

	// Add a random nonce. A short read must not go unnoticed: sealing with
	// an incomplete nonce would silently weaken or break the encryption.
	if _, err := io.CopyN(dst, rand.Reader, nonceSize); err != nil {
		dst.Truncate(offset)
		return err
	}
	afterNonce := dst.Len()

	// Ensure we are correctly padded (only version 0). The padded plaintext
	// is staged in dst after the nonce and replaced by the ciphertext below.
	if vsn == 0 {
		dst.Write(msg)
		pkcs7encode(dst, offset+versionSize+nonceSize, aes.BlockSize)
	}

	// Encrypt message using GCM
	slice := dst.Bytes()[offset:]
	nonce := slice[versionSize : versionSize+nonceSize]

	// Message source depends on the encryption version.
	// Version 0 uses padding, version 1 does not
	var src []byte
	if vsn == 0 {
		src = slice[versionSize+nonceSize:]
	} else {
		src = msg
	}
	out := gcm.Seal(nil, nonce, src, data)

	// Truncate the plaintext, and write the cipher text
	dst.Truncate(afterNonce)
	dst.Write(out)
	return nil
}
// decryptMessage decrypts msg with a single key using AES-GCM. msg is
// expected to start with the version byte followed by the nonce; the rest
// is handed to GCM as ciphertext, with data as the additional
// authenticated data. Kept as its own function so the caller can try each
// installed key in turn. The caller must ensure msg is long enough to
// contain the version byte and nonce.
func decryptMessage(key, msg []byte, data []byte) ([]byte, error) {
	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}

	aead, err := cipher.NewGCM(block)
	if err != nil {
		return nil, err
	}

	// Split the message into nonce and sealed payload
	const headerLen = versionSize + nonceSize
	nonce, sealed := msg[versionSize:headerLen], msg[headerLen:]

	plaintext, err := aead.Open(nil, nonce, sealed, data)
	if err != nil {
		return nil, err
	}
	return plaintext, nil
}
// decryptPayload decrypts and authenticates msg, trying every key in keys
// until one succeeds. The first byte of msg selects the encryption
// version; for version 0 the PKCS7 padding is stripped from the result.
//
// It returns an error when msg is empty, names an unsupported version, is
// shorter than the smallest valid payload for its version, or cannot be
// decrypted with any of the keys.
func decryptPayload(keys [][]byte, msg []byte, data []byte) ([]byte, error) {
	// Need at least the version byte
	if len(msg) == 0 {
		return nil, fmt.Errorf("Cannot decrypt empty payload")
	}

	// Check the version prefix
	vsn := encryptionVersion(msg[0])
	if vsn > maxEncryptionVersion {
		return nil, fmt.Errorf("Unsupported encryption version %d", msg[0])
	}

	// Reject anything shorter than an encrypted empty message
	if minLen := encryptedLength(vsn, 0); len(msg) < minLen {
		return nil, fmt.Errorf("Payload is too small to decrypt: %d", len(msg))
	}

	for i := range keys {
		plain, err := decryptMessage(keys[i], msg, data)
		if err != nil {
			// Wrong key, try the next one
			continue
		}
		if vsn == 0 {
			// Version 0 payloads carry PKCS7 padding
			return pkcs7decode(plain, aes.BlockSize), nil
		}
		return plain, nil
	}

	return nil, fmt.Errorf("No installed keys could decrypt the message")
}

View file

@ -0,0 +1,916 @@
package memberlist
import (
"bytes"
"fmt"
"math"
"math/rand"
"net"
"sync/atomic"
"time"
"github.com/armon/go-metrics"
)
// nodeStateType is the liveness state this member records for a node.
type nodeStateType int
const (
// stateAlive: the node is considered alive.
stateAlive nodeStateType = iota
// stateSuspect: the node is suspected of having failed; set by
// suspectNode.
stateSuspect
// stateDead: the node is considered dead; set by deadNode. Dead nodes
// are skipped by probe and removed by resetNodes.
stateDead
)
// Node represents a node in the cluster.
type Node struct {
// Name identifies the node; it is the key used in the node map.
Name string
// Addr and Port form the UDP address used to reach the node.
Addr net.IP
Port uint16
Meta []byte // Metadata from the delegate for this node.
PMin uint8 // Minimum protocol version this understands
PMax uint8 // Maximum protocol version this understands
PCur uint8 // Current version node is speaking
DMin uint8 // Min protocol version for the delegate to understand
DMax uint8 // Max protocol version for the delegate to understand
DCur uint8 // Current version delegate is speaking
}
// nodeState is used to manage our state view of another node. It embeds
// the public Node description and adds the bookkeeping needed by the
// failure detector.
type nodeState struct {
Node
Incarnation uint32 // Last known incarnation number
State nodeStateType // Current state
StateChange time.Time // Time last state change happened
}
// ackHandler is used to register handlers for incoming acks. Instances are
// stored per sequence number by setAckChannel and setAckHandler.
type ackHandler struct {
// handler is called by invokeAckHandler when the matching ack arrives.
handler func()
// timer removes the handler again once the ack timeout expires; it is
// stopped when the ack arrives first.
timer *time.Timer
}
// schedule starts the periodic background work (probing, push/pull and
// gossip) according to the configured intervals. It is safe to call more
// than once: if tickers already exist the call is a no-op.
//
// NOTE(review): the stop channel is only recorded when at least one
// ticker was created, so with only PushPullInterval enabled the push/pull
// goroutine appears to have no way of being stopped - confirm.
func (m *Memberlist) schedule() {
	m.tickerLock.Lock()
	defer m.tickerLock.Unlock()

	// Already scheduled
	if len(m.tickers) > 0 {
		return
	}

	// Unbuffered channel that is closed to stop every background goroutine
	stopCh := make(chan struct{})

	// startTicker runs fn on every tick of a new ticker and remembers the
	// ticker so deschedule can stop it
	startTicker := func(every time.Duration, fn func()) {
		ticker := time.NewTicker(every)
		go m.triggerFunc(every, ticker.C, stopCh, fn)
		m.tickers = append(m.tickers, ticker)
	}

	// Failure detection
	if m.config.ProbeInterval > 0 {
		startTicker(m.config.ProbeInterval, m.probe)
	}

	// Full state sync; uses its own dynamically scaled timer
	if m.config.PushPullInterval > 0 {
		go m.pushPullTrigger(stopCh)
	}

	// Gossip
	if m.config.GossipInterval > 0 && m.config.GossipNodes > 0 {
		startTicker(m.config.GossipInterval, m.gossip)
	}

	// Remember the stop channel if any ticker was created
	if len(m.tickers) > 0 {
		m.stopTick = stopCh
	}
}
// triggerFunc calls f every time a value arrives on C, until stop is
// closed or receives a value. Before entering the loop it waits for a
// random delay in [0, stagger) so that nodes do not all fire at the same
// moment. stagger must be non-zero.
func (m *Memberlist) triggerFunc(stagger time.Duration, C <-chan time.Time, stop <-chan struct{}, f func()) {
	// Random initial delay to avoid synchronizing with other nodes
	jitter := time.Duration(uint64(rand.Int63()) % uint64(stagger))
	initial := time.After(jitter)
	select {
	case <-initial:
	case <-stop:
		return
	}

	for {
		select {
		case <-stop:
			return
		case <-C:
			f()
		}
	}
}
// pushPullTrigger runs pushPull periodically until stop fires. It does not
// use triggerFunc because the wait between rounds is recomputed every
// time from the current number of known nodes (see pushPullScale), which
// keeps the network from being saturated as the cluster grows.
func (m *Memberlist) pushPullTrigger(stop <-chan struct{}) {
	base := m.config.PushPullInterval

	// Random initial delay to avoid synchronizing with other nodes
	jitter := time.Duration(uint64(rand.Int63()) % uint64(base))
	select {
	case <-time.After(jitter):
	case <-stop:
		return
	}

	for {
		// Scale the wait by the current cluster size
		m.nodeLock.RLock()
		clusterSize := len(m.nodes)
		m.nodeLock.RUnlock()
		wait := pushPullScale(base, clusterSize)

		select {
		case <-stop:
			return
		case <-time.After(wait):
			m.pushPull()
		}
	}
}
// deschedule stops the background maintenance started by schedule. It is
// safe to call more than once: without tickers it does nothing.
func (m *Memberlist) deschedule() {
	m.tickerLock.Lock()
	defer m.tickerLock.Unlock()

	// Nothing scheduled
	if len(m.tickers) == 0 {
		return
	}

	// Closing the stop channel ends every goroutine listening on it
	close(m.stopTick)

	// Release the tickers themselves and forget them
	for i := range m.tickers {
		m.tickers[i].Stop()
	}
	m.tickers = nil
}
// probe performs one round of failure detection: it advances probeIndex
// through the node list until it finds a node that is neither ourselves
// nor dead, and probes that node. When the index runs off the end the list
// is reset (dead nodes reaped, live nodes shuffled) and scanning restarts
// at the front. At most len(m.nodes) candidates are considered per call.
//
// NOTE(review): probeIndex is updated without holding nodeLock; this
// presumably relies on probe never running concurrently - confirm.
func (m *Memberlist) probe() {
	// Number of candidates considered so far
	considered := 0

	for {
		m.nodeLock.RLock()

		// Stop once every node has been considered
		if considered >= len(m.nodes) {
			m.nodeLock.RUnlock()
			return
		}

		// Wrapped around: reap and shuffle, then start over
		if m.probeIndex >= len(m.nodes) {
			m.nodeLock.RUnlock()
			m.resetNodes()
			m.probeIndex = 0
			considered++
			continue
		}

		// Copy the node so it can be used after the lock is released
		node := *m.nodes[m.probeIndex]
		skip := node.Name == m.config.Name || node.State == stateDead

		m.nodeLock.RUnlock()
		m.probeIndex++

		if skip {
			considered++
			continue
		}

		m.probeNode(&node)
		return
	}
}
// probeNode runs a single failure check against node. It pings the node
// directly and waits up to ProbeTimeout for an ack. If none arrives it
// asks up to IndirectChecks random live peers to ping the node on our
// behalf and waits for an ack or for the ack channel timeout
// (ProbeInterval). When no ack is seen at all the node is marked suspect.
func (m *Memberlist) probeNode(node *nodeState) {
	defer metrics.MeasureSince([]string{"memberlist", "probeNode"}, time.Now())

	// Build the direct ping
	req := ping{SeqNo: m.nextSeqNo(), Node: node.Name}
	target := &net.UDPAddr{IP: node.Addr, Port: int(node.Port)}

	// Register for acks; room for the direct ack plus all indirect ones
	ackCh := make(chan bool, m.config.IndirectChecks+1)
	m.setAckChannel(req.SeqNo, ackCh, m.config.ProbeInterval)

	if err := m.encodeAndSendMsg(target, pingMsg, &req); err != nil {
		m.logger.Printf("[ERR] memberlist: Failed to send ping: %s", err)
		return
	}

	// Wait for the direct ack or the probe timeout
	select {
	case acked := <-ackCh:
		if acked {
			return
		}
		// The ack channel already timed out. Put the value back so the
		// final receive below does not block forever.
		ackCh <- acked
	case <-time.After(m.config.ProbeTimeout):
	}

	// Pick some random live peers to relay the ping
	m.nodeLock.RLock()
	relays := kRandomNodes(m.config.IndirectChecks, []string{m.config.Name, node.Name}, m.nodes)
	m.nodeLock.RUnlock()

	// Ask each of them to ping the target for us
	indirect := indirectPingReq{SeqNo: req.SeqNo, Target: node.Addr, Port: node.Port, Node: node.Name}
	for _, relay := range relays {
		relayAddr := &net.UDPAddr{IP: relay.Addr, Port: int(relay.Port)}
		if err := m.encodeAndSendMsg(relayAddr, indirectPingMsg, &indirect); err != nil {
			m.logger.Printf("[ERR] memberlist: Failed to send indirect ping: %s", err)
		}
	}

	// Block until an ack or the timeout notification arrives
	if acked := <-ackCh; acked {
		return
	}

	// Nobody could reach the node: suspect it
	m.logger.Printf("[INFO] memberlist: Suspect %s has failed, no acks received", node.Name)
	s := suspect{Incarnation: node.Incarnation, Node: node.Name, From: m.config.Name}
	m.suspectNode(&s)
}
// resetNodes is called when probing wraps around the node list. Under the
// write lock it removes all dead nodes from both the node map and the node
// list, then shuffles the remaining nodes.
func (m *Memberlist) resetNodes() {
	m.nodeLock.Lock()
	defer m.nodeLock.Unlock()

	// Partition: live nodes first, dead nodes from firstDead onwards
	firstDead := moveDeadNodes(m.nodes)

	// Forget the dead nodes and clear their slots
	for i, dead := range m.nodes[firstDead:] {
		delete(m.nodeMap, dead.Name)
		m.nodes[firstDead+i] = nil
	}

	// Keep only the live part, in random order
	m.nodes = m.nodes[:firstDead]
	shuffleNodes(m.nodes)
}
// gossip runs every GossipInterval. It picks up to GossipNodes random live
// peers and sends each of them a compound message built from the pending
// broadcasts. It stops as soon as there are no broadcasts left to send.
func (m *Memberlist) gossip() {
	defer metrics.MeasureSince([]string{"memberlist", "gossip"}, time.Now())

	// Choose the peers to gossip to
	m.nodeLock.RLock()
	targets := kRandomNodes(m.config.GossipNodes, []string{m.config.Name}, m.nodes)
	m.nodeLock.RUnlock()

	// Payload budget per packet, leaving room for headers and encryption
	budget := udpSendBuf - compoundHeaderOverhead
	if m.config.EncryptionEnabled() {
		budget -= encryptOverhead(m.encryptionVersion())
	}

	for _, target := range targets {
		// Collect broadcasts that fit into one packet
		pending := m.getBroadcasts(compoundOverhead, budget)
		if len(pending) == 0 {
			return
		}

		// Bundle them and send
		packet := makeCompoundMessage(pending)
		addr := &net.UDPAddr{IP: target.Addr, Port: int(target.Port)}
		if err := m.rawSendMsg(addr, packet.Bytes()); err != nil {
			m.logger.Printf("[ERR] memberlist: Failed to send gossip to %s: %s", addr, err)
		}
	}
}
// pushPull performs a complete state exchange with one randomly chosen
// live peer. It is called periodically to help the cluster converge; the
// exchange is comparatively expensive because the whole node state is
// transferred. Failures are logged, not returned.
func (m *Memberlist) pushPull() {
	// Pick one random live peer other than ourselves
	m.nodeLock.RLock()
	candidates := kRandomNodes(1, []string{m.config.Name}, m.nodes)
	m.nodeLock.RUnlock()

	// Nobody to talk to
	if len(candidates) == 0 {
		return
	}
	peer := candidates[0]

	err := m.pushPullNode(peer.Addr, peer.Port, false)
	if err != nil {
		m.logger.Printf("[ERR] memberlist: Push/Pull with %s failed: %s", peer.Name, err)
	}
}
// pushPullNode does a complete state exchange with the node at addr:port.
//
// The remote state is first checked for protocol compatibility. When join
// is true and a merge delegate is configured, the delegate is shown the
// remote nodes and may cancel the merge. Otherwise the remote node states
// are merged into the local view and the user state is handed to the
// delegate, if one is configured.
//
// It returns an error when the exchange fails, when the protocol versions
// are incompatible, or when the merge delegate cancels the merge.
func (m *Memberlist) pushPullNode(addr []byte, port uint16, join bool) error {
	defer metrics.MeasureSince([]string{"memberlist", "pushPullNode"}, time.Now())

	// Attempt to send and receive with the node
	remote, userState, err := m.sendAndReceiveState(addr, port, join)
	if err != nil {
		return err
	}

	if err := m.verifyProtocol(remote); err != nil {
		return err
	}

	// Invoke the merge delegate if any
	if join && m.config.Merge != nil {
		nodes := make([]*Node, len(remote))
		for idx, n := range remote {
			node := &Node{
				Name: n.Name,
				Addr: n.Addr,
				Port: n.Port,
				Meta: n.Meta,
			}
			// A remote node may come without a version vector, which
			// verifyProtocol treats as all versions being zero. Only read
			// the vector when it is complete instead of indexing blindly,
			// which would panic on such nodes.
			if len(n.Vsn) >= 6 {
				node.PMin = n.Vsn[0]
				node.PMax = n.Vsn[1]
				node.PCur = n.Vsn[2]
				node.DMin = n.Vsn[3]
				node.DMax = n.Vsn[4]
				node.DCur = n.Vsn[5]
			}
			nodes[idx] = node
		}
		if m.config.Merge.NotifyMerge(nodes) {
			m.logger.Printf("[WARN] memberlist: Cluster merge canceled")
			return fmt.Errorf("Merge canceled")
		}
	}

	// Merge the state
	m.mergeState(remote)

	// Invoke the delegate
	if m.config.Delegate != nil {
		m.config.Delegate.MergeRemoteState(userState, join)
	}
	return nil
}
// verifyProtocol checks that the remote nodes and the locally known nodes
// can all talk to each other, both at the core protocol level and at the
// delegate protocol level.
//
// It first computes, over all alive nodes, the highest minimum and the
// lowest maximum version understood for the protocol and for the delegate.
// That gives the range every alive node understands. It then verifies that
// the version currently spoken by each node (remote and local) lies inside
// that range, returning an error naming the first node that does not.
//
// Remote nodes without a version vector are skipped when computing the
// range and are treated as speaking version zero when checked.
func (m *Memberlist) verifyProtocol(remote []pushNodeState) error {
	m.nodeLock.RLock()
	defer m.nodeLock.RUnlock()

	// Highest minimum and lowest maximum seen so far
	var maxpmin, maxdmin uint8
	minpmax, mindmax := uint8(math.MaxUint8), uint8(math.MaxUint8)

	// narrow tightens the common range with one node's understood versions
	narrow := func(pmin, pmax, dmin, dmax uint8) {
		if pmin > maxpmin {
			maxpmin = pmin
		}
		if pmax < minpmax {
			minpmax = pmax
		}
		if dmin > maxdmin {
			maxdmin = dmin
		}
		if dmax < mindmax {
			mindmax = dmax
		}
	}

	for _, rn := range remote {
		// Only alive nodes that advertise versions take part
		if rn.State != stateAlive || len(rn.Vsn) == 0 {
			continue
		}
		narrow(rn.Vsn[0], rn.Vsn[1], rn.Vsn[3], rn.Vsn[4])
	}

	for _, n := range m.nodes {
		if n.State != stateAlive {
			continue
		}
		narrow(n.PMin, n.PMax, n.DMin, n.DMax)
	}

	// check verifies that one node's current versions fall in the range
	check := func(name string, pcur, dcur uint8) error {
		if pcur < maxpmin || pcur > minpmax {
			return fmt.Errorf(
				"Node '%s' protocol version (%d) is incompatible: [%d, %d]",
				name, pcur, maxpmin, minpmax)
		}
		if dcur < maxdmin || dcur > mindmax {
			return fmt.Errorf(
				"Node '%s' delegate protocol version (%d) is incompatible: [%d, %d]",
				name, dcur, maxdmin, mindmax)
		}
		return nil
	}

	for _, rn := range remote {
		// Missing versions count as zero
		var pcur, dcur uint8
		if len(rn.Vsn) > 0 {
			pcur = rn.Vsn[2]
			dcur = rn.Vsn[5]
		}
		if err := check(rn.Name, pcur, dcur); err != nil {
			return err
		}
	}

	for _, n := range m.nodes {
		if err := check(n.Name, n.PCur, n.DCur); err != nil {
			return err
		}
	}

	return nil
}
// nextSeqNo atomically increments the sequence counter and returns the
// new value, so it can be called from several goroutines at once.
func (m *Memberlist) nextSeqNo() uint32 {
	next := atomic.AddUint32(&m.sequenceNum, 1)
	return next
}
// nextIncarnation atomically increments the local incarnation counter and
// returns the new value, so it can be called from several goroutines at
// once.
func (m *Memberlist) nextIncarnation() uint32 {
	next := atomic.AddUint32(&m.incarnation, 1)
	return next
}
// setAckChannel registers ch for the ack with sequence number seqNo. When
// the ack arrives, true is sent on ch. When timeout elapses first, the
// registration is removed and false is sent instead. Both sends are
// non-blocking: the value is dropped if ch cannot accept it.
func (m *Memberlist) setAckChannel(seqNo uint32, ch chan bool, timeout time.Duration) {
	// notify delivers the outcome without ever blocking
	notify := func(acked bool) {
		select {
		case ch <- acked:
		default:
		}
	}

	// Register the handler for the incoming ack
	ah := &ackHandler{
		handler: func() { notify(true) },
		timer:   nil,
	}
	m.ackLock.Lock()
	m.ackHandlers[seqNo] = ah
	m.ackLock.Unlock()

	// Reap the handler and report the timeout
	ah.timer = time.AfterFunc(timeout, func() {
		m.ackLock.Lock()
		delete(m.ackHandlers, seqNo)
		m.ackLock.Unlock()
		notify(false)
	})
}
// setAckHandler registers handler to be called when the ack with sequence
// number seqNo arrives. If no ack arrives within timeout the registration
// is removed without calling the handler.
func (m *Memberlist) setAckHandler(seqNo uint32, handler func(), timeout time.Duration) {
	ah := &ackHandler{handler: handler, timer: nil}

	m.ackLock.Lock()
	m.ackHandlers[seqNo] = ah
	m.ackLock.Unlock()

	// Reap the handler once the timeout expires
	reap := func() {
		m.ackLock.Lock()
		delete(m.ackHandlers, seqNo)
		m.ackLock.Unlock()
	}
	ah.timer = time.AfterFunc(timeout, reap)
}
// invokeAckHandler runs the handler registered for seqNo, if there is one.
// The registration is removed first, so a handler runs at most once, and
// its timeout timer is stopped before the handler is called.
func (m *Memberlist) invokeAckHandler(seqNo uint32) {
	m.ackLock.Lock()
	ah, found := m.ackHandlers[seqNo]
	delete(m.ackHandlers, seqNo)
	m.ackLock.Unlock()

	if found {
		ah.timer.Stop()
		ah.handler()
	}
}
// aliveNode is invoked by the network layer when we get a message about a
// live node.
//
// a is the received alive message. notify is passed on to
// encodeBroadcastNotify together with the broadcast that is queued.
// bootstrap disables the refutation path: when it is true, an alive
// message about the local node is applied like one about any other node.
//
// The whole function runs under the node write lock.
func (m *Memberlist) aliveNode(a *alive, notify chan struct{}, bootstrap bool) {
m.nodeLock.Lock()
defer m.nodeLock.Unlock()
state, ok := m.nodeMap[a.Node]
// It is possible that during a Leave(), there is already an aliveMsg
// in-queue to be processed but blocked by the locks above. If we let
// that aliveMsg process, it'll cause us to re-join the cluster. This
// ensures that we don't.
if m.leave && a.Node == m.config.Name {
return
}
// Check if we've never seen this node before, and if not, then
// store this node in our node map.
if !ok {
// A new node starts out dead so that the code below treats its first
// alive message as a join.
state = &nodeState{
Node: Node{
Name: a.Node,
Addr: a.Addr,
Port: a.Port,
Meta: a.Meta,
},
State: stateDead,
}
// Add to map
m.nodeMap[a.Node] = state
// Get a random offset. This is important to ensure
// the failure detection bound is low on average. If all
// nodes did an append, failure detection bound would be
// very high.
n := len(m.nodes)
offset := randomOffset(n)
// Add at the end and swap with the node at the offset
m.nodes = append(m.nodes, state)
m.nodes[offset], m.nodes[n] = m.nodes[n], m.nodes[offset]
}
// Check if this address is different than the existing node
if !bytes.Equal([]byte(state.Addr), a.Addr) || state.Port != a.Port {
m.logger.Printf("[ERR] memberlist: Conflicting address for %s. Mine: %v:%d Theirs: %v:%d",
state.Name, state.Addr, state.Port, net.IP(a.Addr), a.Port)
// Inform the conflict delegate if provided
if m.config.Conflict != nil {
other := Node{
Name: a.Node,
Addr: a.Addr,
Port: a.Port,
Meta: a.Meta,
}
m.config.Conflict.NotifyConflict(&state.Node, &other)
}
// The conflicting message is dropped; the stored state is unchanged.
return
}
// Bail if the incarnation number is older, and this is not about us
isLocalNode := state.Name == m.config.Name
if a.Incarnation <= state.Incarnation && !isLocalNode {
return
}
// Bail if strictly less and this is about us
if a.Incarnation < state.Incarnation && isLocalNode {
return
}
// Update metrics
metrics.IncrCounter([]string{"memberlist", "msg", "alive"}, 1)
// Store the old state and meta data
oldState := state.State
oldMeta := state.Meta
// If this is us we need to refute, otherwise re-broadcast
if !bootstrap && isLocalNode {
// Compute the version vector
versions := []uint8{
state.PMin, state.PMax, state.PCur,
state.DMin, state.DMax, state.DCur,
}
// If the Incarnation is the same, we need special handling, since it
// possible for the following situation to happen:
// 1) Start with configuration C, join cluster
// 2) Hard fail / Kill / Shutdown
// 3) Restart with configuration C', join cluster
//
// In this case, other nodes and the local node see the same incarnation,
// but the values may not be the same. For this reason, we always
// need to do an equality check for this Incarnation. In most cases,
// we just ignore, but we may need to refute.
//
if a.Incarnation == state.Incarnation &&
bytes.Equal(a.Meta, state.Meta) &&
bytes.Equal(a.Vsn, versions) {
return
}
// Pick an incarnation number strictly greater than the one in the
// received message so that our refutation wins.
inc := m.nextIncarnation()
for a.Incarnation >= inc {
inc = m.nextIncarnation()
}
state.Incarnation = inc
// Note: this a shadows the parameter for the rest of this branch; it
// is the refuting alive message built from our own stored state.
a := alive{
Incarnation: inc,
Node: state.Name,
Addr: state.Addr,
Port: state.Port,
Meta: state.Meta,
Vsn: versions,
}
m.encodeBroadcastNotify(a.Node, aliveMsg, a, notify)
m.logger.Printf("[WARN] memberlist: Refuting an alive message")
} else {
// Re-broadcast the message as received, then apply it locally
m.encodeBroadcastNotify(a.Node, aliveMsg, a, notify)
// Update protocol versions if it arrived
if len(a.Vsn) > 0 {
state.PMin = a.Vsn[0]
state.PMax = a.Vsn[1]
state.PCur = a.Vsn[2]
state.DMin = a.Vsn[3]
state.DMax = a.Vsn[4]
state.DCur = a.Vsn[5]
}
// Update the state and incarnation number
state.Incarnation = a.Incarnation
state.Meta = a.Meta
if state.State != stateAlive {
state.State = stateAlive
state.StateChange = time.Now()
}
}
// Notify the delegate of any relevant updates
if m.config.Events != nil {
if oldState == stateDead {
// if Dead -> Alive, notify of join
m.config.Events.NotifyJoin(&state.Node)
} else if !bytes.Equal(oldMeta, state.Meta) {
// if Meta changed, trigger an update notification
m.config.Events.NotifyUpdate(&state.Node)
}
}
}
// suspectNode is invoked by the network layer when we get a message about
// a suspect node. Messages about unknown nodes, with an older incarnation,
// or about nodes that are not currently alive are ignored.
//
// If the suspected node is ourselves, the suspicion is refuted by
// broadcasting an alive message with a higher incarnation number.
// Otherwise the suspicion is re-broadcast, the node is marked suspect and
// a timer is started that declares the node dead unless its state changes
// before the suspicion timeout expires.
func (m *Memberlist) suspectNode(s *suspect) {
	m.nodeLock.Lock()
	defer m.nodeLock.Unlock()

	state, known := m.nodeMap[s.Node]
	switch {
	case !known:
		// Never heard of this node
		return
	case s.Incarnation < state.Incarnation:
		// Stale message
		return
	case state.State != stateAlive:
		// Already suspect or dead
		return
	}

	if state.Name == m.config.Name {
		// This is about us: refute with an incarnation number that beats
		// the one in the suspicion
		inc := m.nextIncarnation()
		for inc <= s.Incarnation {
			inc = m.nextIncarnation()
		}
		state.Incarnation = inc

		refutation := alive{
			Incarnation: inc,
			Node:        state.Name,
			Addr:        state.Addr,
			Port:        state.Port,
			Meta:        state.Meta,
			Vsn: []uint8{
				state.PMin, state.PMax, state.PCur,
				state.DMin, state.DMax, state.DCur,
			},
		}
		m.encodeAndBroadcast(s.Node, aliveMsg, refutation)
		m.logger.Printf("[WARN] memberlist: Refuting a suspect message (from: %s)", s.From)
		return // Do not mark ourself suspect
	}

	// Pass the suspicion on
	m.encodeAndBroadcast(s.Node, suspectMsg, s)

	metrics.IncrCounter([]string{"memberlist", "msg", "suspect"}, 1)

	// Record the suspicion
	markedAt := time.Now()
	state.Incarnation = s.Incarnation
	state.State = stateSuspect
	state.StateChange = markedAt

	// Declare the node dead if nothing changes before the timeout
	wait := suspicionTimeout(m.config.SuspicionMult, len(m.nodes), m.config.ProbeInterval)
	time.AfterFunc(wait, func() {
		m.nodeLock.Lock()
		current, present := m.nodeMap[s.Node]
		// Only act if this very suspicion is still in effect
		expired := present && current.State == stateSuspect && current.StateChange == markedAt
		m.nodeLock.Unlock()

		if expired {
			m.suspectTimeout(current)
		}
	})
}
// suspectTimeout is invoked when the suspicion timer of node n has run
// out. It logs the event and processes a dead message for the node,
// issued in our own name.
func (m *Memberlist) suspectTimeout(n *nodeState) {
	m.logger.Printf("[INFO] memberlist: Marking %s as failed, suspect timeout reached", n.Name)

	msg := dead{
		Incarnation: n.Incarnation,
		Node:        n.Name,
		From:        m.config.Name,
	}
	m.deadNode(&msg)
}
// deadNode is invoked by the network layer when we get a message about a
// dead node. Messages about unknown nodes, with an older incarnation, or
// about nodes already marked dead are ignored.
//
// If the message is about ourselves and we are not leaving, it is refuted
// by broadcasting an alive message with a higher incarnation number. If we
// are leaving, the dead message is broadcast with the leave notification
// channel attached. For any other node the message is re-broadcast. In the
// last two cases the node is then marked dead and the event delegate, if
// any, is told that the node left.
func (m *Memberlist) deadNode(d *dead) {
	m.nodeLock.Lock()
	defer m.nodeLock.Unlock()

	state, known := m.nodeMap[d.Node]
	switch {
	case !known:
		// Never heard of this node
		return
	case d.Incarnation < state.Incarnation:
		// Stale message
		return
	case state.State == stateDead:
		// Nothing new
		return
	}

	isLocal := state.Name == m.config.Name
	switch {
	case isLocal && !m.leave:
		// We are alive and staying: refute with an incarnation number
		// that beats the one in the dead message
		inc := m.nextIncarnation()
		for inc <= d.Incarnation {
			inc = m.nextIncarnation()
		}
		state.Incarnation = inc

		refutation := alive{
			Incarnation: inc,
			Node:        state.Name,
			Addr:        state.Addr,
			Port:        state.Port,
			Meta:        state.Meta,
			Vsn: []uint8{
				state.PMin, state.PMax, state.PCur,
				state.DMin, state.DMax, state.DCur,
			},
		}
		m.encodeAndBroadcast(d.Node, aliveMsg, refutation)
		m.logger.Printf("[WARN] memberlist: Refuting a dead message (from: %s)", d.From)
		return // Do not mark ourself dead

	case isLocal:
		// We are leaving: broadcast and let the leave path wait on it
		m.encodeBroadcastNotify(d.Node, deadMsg, d, m.leaveBroadcast)

	default:
		m.encodeAndBroadcast(d.Node, deadMsg, d)
	}

	metrics.IncrCounter([]string{"memberlist", "msg", "dead"}, 1)

	// Record the death
	state.Incarnation = d.Incarnation
	state.State = stateDead
	state.StateChange = time.Now()

	// Tell the delegate the node is gone
	if m.config.Events != nil {
		m.config.Events.NotifyLeave(&state.Node)
	}
}
// mergeState is invoked by the network layer when we get a Push/Pull state
// transfer. Each remote node state is replayed locally: alive nodes as an
// alive message, suspect and dead nodes as a suspect message issued in our
// own name. A node the remote side considers dead is only suspected here
// rather than being declared dead right away.
func (m *Memberlist) mergeState(remote []pushNodeState) {
	for _, r := range remote {
		switch r.State {
		case stateAlive:
			msg := alive{
				Incarnation: r.Incarnation,
				Node:        r.Name,
				Addr:        r.Addr,
				Port:        r.Port,
				Meta:        r.Meta,
				Vsn:         r.Vsn,
			}
			m.aliveNode(&msg, nil, false)

		case stateDead, stateSuspect:
			msg := suspect{Incarnation: r.Incarnation, Node: r.Name, From: m.config.Name}
			m.suspectNode(&msg)
		}
	}
}

View file

@ -0,0 +1,6 @@
# TODO
* Dynamic RTT discovery
* Compute 99th percentile for ping/ack
* Better lower bound for ping/ack, faster failure detection
* Dynamic MTU discovery
* Prevent lost updates, increases efficiency

View file

@ -0,0 +1,325 @@
package memberlist
import (
"bytes"
"compress/lzw"
"encoding/binary"
"fmt"
"github.com/hashicorp/go-msgpack/codec"
"io"
"math"
"math/rand"
"net"
"time"
)
// pushPullScaleThreshold is the minimum number of nodes
// before we start scaling the push/pull timing. The scale
// effect is the log2(Nodes) - log2(pushPullScaleThreshold). This means
// that the 33rd node will cause us to double the interval,
// while the 65th will triple it.
const pushPullScaleThreshold = 32
/*
 * privateBlocks contains an entry for each private block:
 * 10.0.0.0/8
 * 172.16.0.0/12
 * 192.168.0.0/16
 * It is filled in by init.
 */
var privateBlocks []*net.IPNet
// loopbackBlock is the loopback network 127.0.0.0/8, set by init.
var loopbackBlock *net.IPNet
const (
// Constant litWidth 2-8
// lzwLitWidth is the literal width passed to the LZW writer and reader.
lzwLitWidth = 8
)
// init seeds the math/rand generator from the current time and parses the
// private and loopback CIDR blocks used by isPrivateIP and isLoopbackIP.
// It panics if one of the built-in CIDR strings fails to parse.
func init() {
	// Seed the random number generator
	rand.Seed(time.Now().UnixNano())

	// mustParseCIDR parses a CIDR literal or panics
	mustParseCIDR := func(cidr string) *net.IPNet {
		_, block, err := net.ParseCIDR(cidr)
		if err != nil {
			panic(fmt.Sprintf("Bad cidr. Got %v", err))
		}
		return block
	}

	// The three private address blocks
	privateBlocks = []*net.IPNet{
		mustParseCIDR("10.0.0.0/8"),
		mustParseCIDR("172.16.0.0/12"),
		mustParseCIDR("192.168.0.0/16"),
	}

	// The loopback block
	loopbackBlock = mustParseCIDR("127.0.0.0/8")
}
// decode unpacks the msgpack-encoded bytes in buf into out. It is the
// counterpart of encode, applied to the bytes after the message type.
func decode(buf []byte, out interface{}) error {
	var handle codec.MsgpackHandle
	decoder := codec.NewDecoder(bytes.NewReader(buf), &handle)
	return decoder.Decode(out)
}
// encode returns a new buffer holding the message type byte followed by
// the msgpack encoding of in. The buffer is returned even when encoding
// fails, together with the encoder's error.
func encode(msgType messageType, in interface{}) (*bytes.Buffer, error) {
	out := bytes.NewBuffer(nil)
	out.WriteByte(uint8(msgType))

	var handle codec.MsgpackHandle
	encoder := codec.NewEncoder(out, &handle)
	if err := encoder.Encode(in); err != nil {
		return out, err
	}
	return out, nil
}
// randomOffset returns a random offset in the range [0, n). For n == 0 it
// returns 0.
func randomOffset(n int) int {
	if n != 0 {
		return int(rand.Uint32() % uint32(n))
	}
	return 0
}
// suspicionTimeout computes how long a node stays suspect before it is
// declared dead: suspicionMult * ceil(log10(n+1)) * interval, where n is
// the number of nodes.
func suspicionTimeout(suspicionMult, n int, interval time.Duration) time.Duration {
	scale := time.Duration(math.Ceil(math.Log10(float64(n + 1))))
	return time.Duration(suspicionMult) * scale * interval
}
// retransmitLimit computes how many times a message is retransmitted:
// retransmitMult * ceil(log10(n+1)), where n is the number of nodes.
func retransmitLimit(retransmitMult, n int) int {
	scale := int(math.Ceil(math.Log10(float64(n + 1))))
	return retransmitMult * scale
}
// shuffleNodes randomly permutes nodes in place: each position i is
// swapped with a random position in [0, i].
func shuffleNodes(nodes []*nodeState) {
	for i := 0; i < len(nodes); i++ {
		pick := rand.Intn(i + 1)
		nodes[pick], nodes[i] = nodes[i], nodes[pick]
	}
}
// pushPullScale scales the interval between push/pull syncs with the
// cluster size n, to prevent network saturation as the cluster grows. Up
// to pushPullScaleThreshold nodes the interval is returned unchanged;
// beyond that it is multiplied by
// ceil(log2(n) - log2(pushPullScaleThreshold)) + 1.
func pushPullScale(interval time.Duration, n int) time.Duration {
	if n > pushPullScaleThreshold {
		steps := math.Ceil(math.Log2(float64(n)) - math.Log2(pushPullScaleThreshold))
		return time.Duration(steps+1.0) * interval
	}
	// Below the threshold there is no scaling
	return interval
}
// moveDeadNodes reorders nodes in place so that every node in the dead
// state sits at the end of the slice, and returns the index of the first
// dead node (len(nodes) when there is none).
func moveDeadNodes(nodes []*nodeState) int {
	total := len(nodes)
	dead := 0

	for i := 0; i < total-dead; {
		if nodes[i].State != stateDead {
			i++
			continue
		}
		// Swap the dead node with the last unexamined one. The index is
		// not advanced because the node swapped in still needs a look.
		last := total - dead - 1
		nodes[i], nodes[last] = nodes[last], nodes[i]
		dead++
	}
	return total - dead
}
// kRandomNodes selects up to k distinct random nodes from nodes, skipping
// nodes whose name is listed in excludes and nodes that are not alive.
// Fewer than k nodes may be returned.
//
// Up to 3*len(nodes) random draws are made. For large clusters that many
// are not needed since k << n, but for small ones it makes the search
// close to exhaustive.
func kRandomNodes(k int, excludes []string, nodes []*nodeState) []*nodeState {
	total := len(nodes)
	picked := make([]*nodeState, 0, k)

	// isExcluded reports whether name is on the exclusion list
	isExcluded := func(name string) bool {
		for _, skip := range excludes {
			if skip == name {
				return true
			}
		}
		return false
	}

	// alreadyPicked reports whether candidate was selected before
	alreadyPicked := func(candidate *nodeState) bool {
		for _, have := range picked {
			if have == candidate {
				return true
			}
		}
		return false
	}

	for attempt := 0; attempt < 3*total && len(picked) < k; attempt++ {
		candidate := nodes[randomOffset(total)]

		if isExcluded(candidate.Name) ||
			candidate.State != stateAlive ||
			alreadyPicked(candidate) {
			continue
		}
		picked = append(picked, candidate)
	}
	return picked
}
// makeCompoundMessage bundles msgs into a single compound message with
// the layout: compound type byte, number of messages (one byte), the
// length of each message as a big-endian uint16, then the message bodies
// back to back.
//
// The count and the lengths are converted to uint8 and uint16, so more
// than 255 messages or a message longer than 65535 bytes cannot be
// represented.
func makeCompoundMessage(msgs [][]byte) *bytes.Buffer {
	out := bytes.NewBuffer(nil)

	// Header: type and message count
	out.WriteByte(uint8(compoundMsg))
	out.WriteByte(uint8(len(msgs)))

	// Length table
	var lenField [2]byte
	for _, msg := range msgs {
		binary.BigEndian.PutUint16(lenField[:], uint16(len(msg)))
		out.Write(lenField[:])
	}

	// Message bodies
	for _, msg := range msgs {
		out.Write(msg)
	}
	return out
}
// decodeCompoundMessage splits a compound message into its individual
// messages. buf must start with the message count byte (the compound type
// byte already removed), followed by one big-endian uint16 length per
// message and then the message bodies.
//
// It returns the messages as sub-slices of buf. If the buffer ends before
// all announced messages are complete, the messages decoded so far are
// returned and trunc holds the number of missing messages. err is set
// when the count byte or the length table itself is incomplete.
func decodeCompoundMessage(buf []byte) (trunc int, parts [][]byte, err error) {
	if len(buf) < 1 {
		err = fmt.Errorf("missing compound length byte")
		return
	}

	// Widen the count to int right away: doubling it as a uint8 would
	// wrap around for 128 or more parts and defeat the bounds check below.
	numParts := int(buf[0])
	buf = buf[1:]

	// Check we have enough bytes for the length table
	tableLen := numParts * 2
	if len(buf) < tableLen {
		err = fmt.Errorf("truncated len slice")
		return
	}

	// Decode the lengths
	lengths := make([]uint16, numParts)
	for i := range lengths {
		lengths[i] = binary.BigEndian.Uint16(buf[i*2 : i*2+2])
	}
	buf = buf[tableLen:]

	// Split each message
	for idx, msgLen := range lengths {
		if len(buf) < int(msgLen) {
			trunc = numParts - idx
			return
		}

		// Extract the slice, seek past on the buffer
		parts = append(parts, buf[:msgLen])
		buf = buf[msgLen:]
	}
	return
}
// isPrivateIP reports whether ip_str, parsed as an IP address, is contained
// in any of the blocks listed in privateBlocks.
func isPrivateIP(ip_str string) bool {
	addr := net.ParseIP(ip_str)
	for _, block := range privateBlocks {
		if !block.Contains(addr) {
			continue
		}
		return true
	}
	return false
}
// isLoopbackIP reports whether ip_str, parsed as an IP address, is contained
// in loopbackBlock.
func isLoopbackIP(ip_str string) bool {
	return loopbackBlock.Contains(net.ParseIP(ip_str))
}
// compressPayload LZW-compresses the opaque input buffer, wraps the result
// in a compress{} message tagged with lzwAlgo, and returns that message in
// encoded form.
func compressPayload(inp []byte) (*bytes.Buffer, error) {
	var packed bytes.Buffer
	w := lzw.NewWriter(&packed, lzw.LSB, lzwLitWidth)
	if _, err := w.Write(inp); err != nil {
		return nil, err
	}

	// Close the writer so that everything is flushed into packed before
	// its bytes are read.
	if err := w.Close(); err != nil {
		return nil, err
	}

	// Wrap the compressed bytes and encode the wrapper.
	return encode(compressMsg, &compress{
		Algo: lzwAlgo,
		Buf:  packed.Bytes(),
	})
}
// decompressPayload decodes msg as a compress{} message and returns the
// uncompressed form of its payload.
func decompressPayload(msg []byte) ([]byte, error) {
	wrapper := new(compress)
	if err := decode(msg, wrapper); err != nil {
		return nil, err
	}
	return decompressBuffer(wrapper)
}
// decompressBuffer returns the uncompressed contents of a single compress
// message. Only lzwAlgo is accepted; any other algorithm value yields an
// error.
func decompressBuffer(c *compress) ([]byte, error) {
	// Reject algorithms we do not know how to reverse.
	if c.Algo != lzwAlgo {
		return nil, fmt.Errorf("Cannot decompress unknown algorithm %d", c.Algo)
	}

	reader := lzw.NewReader(bytes.NewReader(c.Buf), lzw.LSB, lzwLitWidth)
	defer reader.Close()

	// Drain the decompressor into memory.
	var plain bytes.Buffer
	if _, err := io.Copy(&plain, reader); err != nil {
		return nil, err
	}
	return plain.Bytes(), nil
}

View file

@ -0,0 +1,27 @@
package serf
import (
"github.com/hashicorp/memberlist"
)
// broadcast is Serf's implementation of memberlist.Broadcast. It carries a
// single Serf message across the memberlist channel.
type broadcast struct {
	// msg is the payload returned by Message.
	msg []byte
	// notify, when non-nil, is closed by Finished.
	notify chan<- struct{}
}

// Invalidates always reports false.
func (b *broadcast) Invalidates(other memberlist.Broadcast) bool {
	return false
}

// Message returns the stored payload.
func (b *broadcast) Message() []byte {
	return b.msg
}

// Finished closes the notify channel if one was supplied.
func (b *broadcast) Finished() {
	if b.notify == nil {
		return
	}
	close(b.notify)
}

View file

@ -0,0 +1,80 @@
package serf
import (
"time"
)
// coalescer is the interface a type must implement to be driven by
// coalesceLoop.
type coalescer interface {
	// Handle reports whether the coalescer can handle this event. Events
	// it cannot handle are passed directly through to the destination
	// channel.
	Handle(Event) bool

	// Coalesce is invoked to coalesce the given event.
	Coalesce(Event)

	// Flush is invoked to flush the coalesced events to outChan.
	Flush(outChan chan<- Event)
}
// coalescedEventCh creates a buffered input channel (capacity 1024), starts
// a coalesceLoop goroutine that reads from it and writes to outCh using the
// given coalescer and periods, and returns the input channel.
func coalescedEventCh(outCh chan<- Event, shutdownCh <-chan struct{},
	cPeriod time.Duration, qPeriod time.Duration, c coalescer) chan<- Event {
	input := make(chan Event, 1024)
	go coalesceLoop(input, outCh, shutdownCh, cPeriod, qPeriod, c)
	return input
}
// coalesceLoop is a long-running routine that manages the high-level flow
// of coalescing based on quiescence and a maximum quantum period.
//
// Events read from inCh that the coalescer does not handle are forwarded to
// outCh unchanged. Handled events are passed to c.Coalesce and are flushed
// with c.Flush when either the quantum timer (started by the first handled
// event of a batch, lasting coalescePeriod) or the quiescent timer
// (restarted by every handled event, lasting quiescentPeriod) fires. A
// receive on shutdownCh triggers one final flush, after which the function
// returns.
func coalesceLoop(inCh <-chan Event, outCh chan<- Event, shutdownCh <-chan struct{},
	coalescePeriod time.Duration, quiescentPeriod time.Duration, c coalescer) {
	var quiescent <-chan time.Time
	var quantum <-chan time.Time
	shutdown := false

INGEST:
	// Reset the timers. A nil channel is never ready in a select, so both
	// timer cases stay disabled until a handled event arms them.
	quantum = nil
	quiescent = nil

	for {
		select {
		case e := <-inCh:
			// Ignore any non handled events
			if !c.Handle(e) {
				outCh <- e
				continue
			}

			// Start a new quantum if we need to
			// and restart the quiescent timer
			if quantum == nil {
				quantum = time.After(coalescePeriod)
			}
			quiescent = time.After(quiescentPeriod)

			// Coalesce the event
			c.Coalesce(e)

		case <-quantum:
			// The batch has been open for the full coalesce period.
			goto FLUSH
		case <-quiescent:
			// No handled event arrived within the quiescent period.
			goto FLUSH
		case <-shutdownCh:
			// Flush once more, then stop.
			shutdown = true
			goto FLUSH
		}
	}

FLUSH:
	// Flush the coalesced events
	c.Flush(outCh)

	// Restart ingestion if we are not done
	if !shutdown {
		goto INGEST
	}
}

View file

@ -0,0 +1,68 @@
package serf
// coalesceEvent is the most recent event recorded for a single member: the
// event type together with a pointer to the member it concerns.
type coalesceEvent struct {
	Type   EventType
	Member *Member
}

// memberEventCoalescer coalesces member events, tracking state per member
// name.
type memberEventCoalescer struct {
	// lastEvents holds, per member name, the event type recorded by the
	// most recent Flush that emitted an event for that member.
	lastEvents map[string]EventType
	// latestEvents holds, per member name, the most recent event handed
	// to Coalesce.
	latestEvents map[string]coalesceEvent
}
// Handle reports whether e is one of the member event types (join, leave,
// failed, update, reap). Every other event type is not handled.
func (c *memberEventCoalescer) Handle(e Event) bool {
	switch e.EventType() {
	case EventMemberJoin, EventMemberLeave, EventMemberFailed,
		EventMemberUpdate, EventMemberReap:
		return true
	}
	return false
}
// Coalesce records, for every member carried by the event, that event's
// type as the latest one seen for that member name. raw must be a
// MemberEvent; any other dynamic type makes the type assertion panic.
func (c *memberEventCoalescer) Coalesce(raw Event) {
	e := raw.(MemberEvent)
	for i := range e.Members {
		// Copy the member into a variable that is fresh on every
		// iteration. Taking the address of a shared range variable would,
		// on toolchains before Go 1.22, make every stored pointer refer
		// to the last member of the event.
		member := e.Members[i]
		c.latestEvents[member.Name] = coalesceEvent{
			Type:   e.Type,
			Member: &member,
		}
	}
}
// Flush groups the latest per-member events by event type and sends one
// MemberEvent per type on outCh. A member whose latest event type equals
// the type recorded for it in lastEvents is skipped, except for
// EventMemberUpdate, which is always sent.
func (c *memberEventCoalescer) Flush(outCh chan<- Event) {
	grouped := make(map[EventType]*MemberEvent)
	for name, latest := range c.latestEvents {
		// Skip repeats of what was already sent, unless it is an update.
		sent, seen := c.lastEvents[name]
		if seen && sent == latest.Type && latest.Type != EventMemberUpdate {
			continue
		}

		// Remember what we are about to send for this member.
		c.lastEvents[name] = latest.Type

		// Append the member to the batch for its event type, creating
		// the batch on first use.
		batch := grouped[latest.Type]
		if batch == nil {
			batch = &MemberEvent{Type: latest.Type}
			grouped[latest.Type] = batch
		}
		batch.Members = append(batch.Members, *latest.Member)
	}

	// Emit one event per type.
	for _, batch := range grouped {
		outCh <- *batch
	}
}

View file

@ -0,0 +1,52 @@
package serf
// latestUserEvents holds the user events kept for one event name: the
// Lamport time they share and the events themselves.
type latestUserEvents struct {
	LTime  LamportTime
	Events []Event
}

// userEventCoalescer coalesces user events by event name.
type userEventCoalescer struct {
	// Maps an event name into the latest versions
	events map[string]*latestUserEvents
}
// Handle reports whether e is a user event that has coalescing enabled.
// Events of any other type are not handled.
func (c *userEventCoalescer) Handle(e Event) bool {
	if e.EventType() == EventUser {
		// The event itself says whether it wants to be coalesced.
		return e.(UserEvent).Coalesce
	}
	return false
}
// Coalesce records the user event e under its name. An event with a newer
// Lamport time than the stored entry (or the first event for a name)
// replaces the entry; an event with the same Lamport time is appended to
// it; an older event is dropped.
func (c *userEventCoalescer) Coalesce(e Event) {
	user := e.(UserEvent)
	current, found := c.events[user.Name]

	switch {
	case !found || current.LTime < user.LTime:
		// Nothing stored yet, or this event is newer: start a new entry.
		c.events[user.Name] = &latestUserEvents{
			LTime:  user.LTime,
			Events: []Event{e},
		}
	case current.LTime == user.LTime:
		// Same age as the stored entry: keep it alongside the others.
		current.Events = append(current.Events, e)
	}
}
// Flush sends every stored user event on outChan and then replaces the
// store with an empty map.
func (c *userEventCoalescer) Flush(outChan chan<- Event) {
	for _, entry := range c.events {
		for i := range entry.Events {
			outChan <- entry.Events[i]
		}
	}

	// Start the next round with nothing stored.
	c.events = make(map[string]*latestUserEvents)
}

View file

@ -0,0 +1,234 @@
package serf
import (
"io"
"os"
"time"
"github.com/hashicorp/memberlist"
)
// ProtocolVersionMap is the mapping of Serf delegate protocol versions
// to memberlist protocol versions. We mask the memberlist protocols using
// our own protocol version.
//
// The map is built by the variable initializer rather than in an init
// function, so it is already populated for any package-level initializer
// that depends on it.
var ProtocolVersionMap = map[uint8]uint8{
	4: 2,
	3: 2,
	2: 2,
}
// Config is the configuration for creating a Serf instance.
type Config struct {
	// The name of this node. This must be unique in the cluster. If this
	// is not set, Serf will set it to the hostname of the running machine.
	NodeName string

	// The tags for this role, if any. This is used to provide arbitrary
	// key/value metadata per-node. For example, a "role" tag may be used to
	// differentiate "load-balancer" from a "web" role as parts of the same cluster.
	// Tags are deprecating 'Role', and instead it acts as a special key in this
	// map.
	Tags map[string]string

	// EventCh is a channel that receives all the Serf events. The events
	// are sent on this channel in proper ordering. Care must be taken that
	// this channel doesn't block, either by processing the events quickly
	// enough or buffering the channel, otherwise it can block state updates
	// within Serf itself. If no EventCh is specified, no events will be fired,
	// but point-in-time snapshots of members can still be retrieved by
	// calling Members on Serf.
	EventCh chan<- Event

	// ProtocolVersion is the protocol version to speak. This must be between
	// ProtocolVersionMin and ProtocolVersionMax.
	ProtocolVersion uint8

	// BroadcastTimeout is the amount of time to wait for a broadcast
	// message to be sent to the cluster. Broadcast messages are used for
	// things like leave messages and force remove messages. If this is not
	// set, a timeout of 5 seconds will be set.
	BroadcastTimeout time.Duration

	// The settings below relate to Serf's event coalescence feature. Serf
	// is able to coalesce multiple events into single events in order to
	// reduce the amount of noise that is sent along the EventCh. For example
	// if five nodes quickly join, the EventCh will be sent one EventMemberJoin
	// containing the five nodes rather than five individual EventMemberJoin
	// events. Coalescence can mitigate potential flapping behavior.
	//
	// Coalescence is disabled by default and can be enabled by setting
	// CoalescePeriod.
	//
	// CoalescePeriod specifies the time duration to coalesce events.
	// For example, if this is set to 5 seconds, then all events received
	// within 5 seconds that can be coalesced will be.
	//
	// QuiescentPeriod specifies the duration of time where if no events
	// are received, coalescence immediately happens. For example, if
	// CoalescePeriod is set to 10 seconds but QuiescentPeriod is set to 2
	// seconds, then the events will be coalesced and dispatched if no
	// new events are received within 2 seconds of the last event. Otherwise,
	// every event will always be delayed by at least 10 seconds.
	CoalescePeriod  time.Duration
	QuiescentPeriod time.Duration

	// The settings below relate to Serf's user event coalescing feature.
	// The settings operate like above but only affect user messages and
	// not the Member* messages that Serf generates.
	UserCoalescePeriod  time.Duration
	UserQuiescentPeriod time.Duration

	// The settings below relate to Serf keeping track of recently
	// failed/left nodes and attempting reconnects.
	//
	// ReapInterval is the interval when the reaper runs. If this is not
	// set (it is zero), it will be set to a reasonable default.
	//
	// ReconnectInterval is the interval when we attempt to reconnect
	// to failed nodes. If this is not set (it is zero), it will be set
	// to a reasonable default.
	//
	// ReconnectTimeout is the amount of time to attempt to reconnect to
	// a failed node before giving up and considering it completely gone.
	//
	// TombstoneTimeout is the amount of time to keep around nodes
	// that gracefully left as tombstones for syncing state with other
	// Serf nodes.
	ReapInterval      time.Duration
	ReconnectInterval time.Duration
	ReconnectTimeout  time.Duration
	TombstoneTimeout  time.Duration

	// QueueDepthWarning is used to generate a warning message if the
	// number of queued messages to broadcast exceeds this number. This
	// is to provide the user feedback if events are being triggered
	// faster than they can be disseminated.
	QueueDepthWarning int

	// MaxQueueDepth is used to start dropping messages if the number
	// of queued messages to broadcast exceeds this number. This is to
	// prevent an unbounded growth of memory utilization.
	MaxQueueDepth int

	// RecentIntentBuffer is used to set the size of recent join and leave intent
	// messages that will be buffered. This is used to guard against
	// the case where Serf broadcasts an intent that arrives before the
	// Memberlist event. It is important that this not be too small to avoid
	// continuous rebroadcasting of dead events.
	RecentIntentBuffer int

	// EventBuffer is used to control how many events are buffered.
	// This is used to prevent re-delivery of events to a client. The buffer
	// must be large enough to handle all "recent" events, since Serf will
	// not deliver messages that are older than the oldest entry in the buffer.
	// Thus if a client is generating too many events, it's possible that the
	// buffer gets overrun and messages are not delivered.
	EventBuffer int

	// QueryBuffer is used to control how many queries are buffered.
	// This is used to prevent re-delivery of queries to a client. The buffer
	// must be large enough to handle all "recent" events, since Serf will not
	// deliver queries older than the oldest entry in the buffer.
	// Thus if a client is generating too many queries, it's possible that the
	// buffer gets overrun and messages are not delivered.
	QueryBuffer int

	// QueryTimeoutMult configures the default timeout multiplier for a query to run if no
	// specific value is provided. Queries are real-time by nature, where the
	// reply is time sensitive. As a result, results are collected in an async
	// fashion, however the query must have a bounded duration. We want the timeout
	// to be long enough that all nodes have time to receive the message, run a handler,
	// and generate a reply. Once the timeout is exceeded, any further replies are ignored.
	// The default value is
	//
	// Timeout = GossipInterval * QueryTimeoutMult * log(N+1)
	//
	QueryTimeoutMult int

	// MemberlistConfig is the memberlist configuration that Serf will
	// use to do the underlying membership management and gossip. Some
	// fields in the MemberlistConfig will be overwritten by Serf no
	// matter what:
	//
	// * Name - This will always be set to the same as the NodeName
	// in this configuration.
	//
	// * Events - Serf uses a custom event delegate.
	//
	// * Delegate - Serf uses a custom delegate.
	//
	MemberlistConfig *memberlist.Config

	// LogOutput is the location to write logs to. If this is not set,
	// logs will go to stderr.
	LogOutput io.Writer

	// SnapshotPath if provided is used to snapshot live nodes as well
	// as lamport clock values. When Serf is started with a snapshot,
	// it will attempt to join all the previously known nodes until one
	// succeeds and will also avoid replaying old user events.
	SnapshotPath string

	// RejoinAfterLeave controls our interaction with the snapshot file.
	// When set to false (default), a leave causes a Serf to not rejoin
	// the cluster until an explicit join is received. If this is set to
	// true, we ignore the leave, and rejoin the cluster on start.
	RejoinAfterLeave bool

	// EnableNameConflictResolution controls if Serf will actively attempt
	// to resolve a name conflict. Since each Serf member must have a unique
	// name, a cluster can run into issues if multiple nodes claim the same
	// name. Without automatic resolution, Serf merely logs some warnings, but
	// otherwise does not take any action. Automatic resolution detects the
	// conflict and issues a special query which asks the cluster for the
	// Name -> IP:Port mapping. If there is a simple majority of votes, that
	// node stays while the other node will leave the cluster and exit.
	EnableNameConflictResolution bool

	// KeyringFile provides the location of a writable file where Serf can
	// persist changes to the encryption keyring.
	KeyringFile string

	// Merge can be optionally provided to intercept a cluster merge
	// and conditionally abort the merge.
	Merge MergeDelegate
}
// Init allocates the sub-structures of the configuration that must not be
// nil. Currently that is the Tags map, which is left alone when it is
// already set.
func (c *Config) Init() {
	if c.Tags != nil {
		return
	}
	c.Tags = make(map[string]string)
}
// DefaultConfig returns a Config populated with reasonable defaults for
// most settings. The node name defaults to the machine's hostname; the
// function panics if the hostname cannot be determined.
func DefaultConfig() *Config {
	name, err := os.Hostname()
	if err != nil {
		panic(err)
	}

	cfg := new(Config)

	// Identity, protocol and logging.
	cfg.NodeName = name
	cfg.ProtocolVersion = ProtocolVersionMax
	cfg.LogOutput = os.Stderr
	cfg.EnableNameConflictResolution = true

	// Timers.
	cfg.BroadcastTimeout = 5 * time.Second
	cfg.ReapInterval = 15 * time.Second
	cfg.ReconnectInterval = 30 * time.Second
	cfg.ReconnectTimeout = 24 * time.Hour
	cfg.TombstoneTimeout = 24 * time.Hour

	// Buffer and queue sizes.
	cfg.EventBuffer = 512
	cfg.QueryBuffer = 512
	cfg.RecentIntentBuffer = 128
	cfg.QueueDepthWarning = 128
	cfg.MaxQueueDepth = 4096
	cfg.QueryTimeoutMult = 16

	// Underlying memberlist settings.
	cfg.MemberlistConfig = memberlist.DefaultLANConfig()

	return cfg
}

View file

@ -0,0 +1,13 @@
package serf
import (
"github.com/hashicorp/memberlist"
)
// conflictDelegate forwards node conflict notifications to the Serf
// instance it wraps.
type conflictDelegate struct {
	serf *Serf
}

// NotifyConflict passes the two conflicting nodes to
// Serf.handleNodeConflict.
func (c *conflictDelegate) NotifyConflict(existing, other *memberlist.Node) {
	c.serf.handleNodeConflict(existing, other)
}

View file

@ -0,0 +1,247 @@
package serf
import (
"fmt"
"github.com/armon/go-metrics"
)
// delegate is the memberlist.Delegate implementation that Serf uses. It
// holds a reference to the owning Serf instance, which all of its methods
// operate on.
type delegate struct {
	serf *Serf
}
// NodeMeta returns the configured tags in encoded form. It panics when the
// encoded tags are longer than limit bytes.
func (d *delegate) NodeMeta(limit int) []byte {
	tags := d.serf.config.Tags
	encoded := d.serf.encodeTags(tags)
	if len(encoded) <= limit {
		return encoded
	}
	panic(fmt.Errorf("Node tags '%v' exceeds length limit of %d bytes", tags, limit))
}
// NotifyMsg processes a raw message buffer. The first byte selects the
// message type; the remainder is decoded into the matching message struct
// and passed to the corresponding Serf handler. When the handler asks for
// it, a copy of the original buffer is queued for rebroadcast. Decode
// failures and unknown message types are logged and otherwise ignored.
func (d *delegate) NotifyMsg(buf []byte) {
	// If we didn't actually receive any data, then ignore it.
	if len(buf) == 0 {
		return
	}
	metrics.AddSample([]string{"serf", "msgs", "received"}, float32(len(buf)))

	// rebroadcast is set by the handlers below; rebroadcastQueue starts as
	// the general broadcast queue and is switched for user events and
	// queries, which have queues of their own.
	rebroadcast := false
	rebroadcastQueue := d.serf.broadcasts
	t := messageType(buf[0])
	switch t {
	case messageLeaveType:
		var leave messageLeave
		if err := decodeMessage(buf[1:], &leave); err != nil {
			d.serf.logger.Printf("[ERR] serf: Error decoding leave message: %s", err)
			break
		}

		d.serf.logger.Printf("[DEBUG] serf: messageLeaveType: %s", leave.Node)
		rebroadcast = d.serf.handleNodeLeaveIntent(&leave)

	case messageJoinType:
		var join messageJoin
		if err := decodeMessage(buf[1:], &join); err != nil {
			d.serf.logger.Printf("[ERR] serf: Error decoding join message: %s", err)
			break
		}

		d.serf.logger.Printf("[DEBUG] serf: messageJoinType: %s", join.Node)
		rebroadcast = d.serf.handleNodeJoinIntent(&join)

	case messageUserEventType:
		var event messageUserEvent
		if err := decodeMessage(buf[1:], &event); err != nil {
			d.serf.logger.Printf("[ERR] serf: Error decoding user event message: %s", err)
			break
		}

		d.serf.logger.Printf("[DEBUG] serf: messageUserEventType: %s", event.Name)
		rebroadcast = d.serf.handleUserEvent(&event)
		rebroadcastQueue = d.serf.eventBroadcasts

	case messageQueryType:
		var query messageQuery
		if err := decodeMessage(buf[1:], &query); err != nil {
			d.serf.logger.Printf("[ERR] serf: Error decoding query message: %s", err)
			break
		}

		d.serf.logger.Printf("[DEBUG] serf: messageQueryType: %s", query.Name)
		rebroadcast = d.serf.handleQuery(&query)
		rebroadcastQueue = d.serf.queryBroadcasts

	case messageQueryResponseType:
		// Query responses are handled but never rebroadcast.
		var resp messageQueryResponse
		if err := decodeMessage(buf[1:], &resp); err != nil {
			d.serf.logger.Printf("[ERR] serf: Error decoding query response message: %s", err)
			break
		}

		d.serf.logger.Printf("[DEBUG] serf: messageQueryResponseType: %v", resp.From)
		d.serf.handleQueryResponse(&resp)

	default:
		d.serf.logger.Printf("[WARN] serf: Received message of unknown type: %d", t)
	}

	if rebroadcast {
		// Copy the buffer since we cannot rely on the slice not changing
		newBuf := make([]byte, len(buf))
		copy(newBuf, buf)

		rebroadcastQueue.QueueBroadcast(&broadcast{
			msg:    newBuf,
			notify: nil,
		})
	}
}
// GetBroadcasts collects pending messages from the general broadcast queue,
// then the query queue, then the user event queue. Each later queue is
// offered only the bytes of limit left over by the queues before it, with
// overhead charged once per message. A "sent" size sample is emitted for
// every message returned.
func (d *delegate) GetBroadcasts(overhead, limit int) [][]byte {
	used := 0

	// record emits the size sample for each message in batch and charges
	// its size plus the per-message overhead against the byte budget.
	record := func(batch [][]byte) {
		for _, m := range batch {
			size := len(m)
			used += size + overhead
			metrics.AddSample([]string{"serf", "msgs", "sent"}, float32(size))
		}
	}

	// General broadcasts get first pick of the full limit.
	out := d.serf.broadcasts.GetBroadcasts(overhead, limit)
	record(out)

	// Get any additional query broadcasts
	if queries := d.serf.queryBroadcasts.GetBroadcasts(overhead, limit-used); queries != nil {
		record(queries)
		out = append(out, queries...)
	}

	// Get any additional event broadcasts
	if events := d.serf.eventBroadcasts.GetBroadcasts(overhead, limit-used); events != nil {
		record(events)
		out = append(out, events...)
	}

	return out
}
// LocalState encodes this node's view of the cluster as a push/pull
// message: the three Lamport clocks, the status time of every known member,
// the names of the members that have left, and the event buffer. The member
// and event read locks are held while the message is built and encoded. It
// returns nil if encoding fails. The join argument is not used.
func (d *delegate) LocalState(join bool) []byte {
	s := d.serf

	s.memberLock.RLock()
	defer s.memberLock.RUnlock()
	s.eventLock.RLock()
	defer s.eventLock.RUnlock()

	// Start with the clocks and the event buffer.
	state := messagePushPull{
		LTime:        s.clock.Time(),
		StatusLTimes: make(map[string]LamportTime, len(s.members)),
		LeftMembers:  make([]string, 0, len(s.leftMembers)),
		EventLTime:   s.eventClock.Time(),
		Events:       s.eventBuffer,
		QueryLTime:   s.queryClock.Time(),
	}

	// Status time of every known member.
	for name, m := range s.members {
		state.StatusLTimes[name] = m.statusLTime
	}

	// Names of the members that have left.
	for _, m := range s.leftMembers {
		state.LeftMembers = append(state.LeftMembers, m.Name)
	}

	// Encode the push pull state
	encoded, err := encodeMessage(messagePushPullType, &state)
	if err != nil {
		s.logger.Printf("[ERR] serf: Failed to encode local state: %v", err)
		return nil
	}
	return encoded
}
// MergeRemoteState applies a push/pull state message received from another
// node. buf must start with the push/pull message type byte; an empty
// buffer, a wrong type byte, or a payload that fails to decode is logged
// and ignored.
//
// The remote Lamport clocks are witnessed first, then leave intents are
// replayed for the remote's left members, join intents for all its other
// members, and finally its buffered user events. When isJoin is true and
// eventJoinIgnore is set, eventMinTime is raised to the remote event clock
// before the events are replayed.
func (d *delegate) MergeRemoteState(buf []byte, isJoin bool) {
	// Ensure we have a message: indexing the type byte of an empty buffer
	// would panic.
	if len(buf) == 0 {
		d.serf.logger.Printf("[ERR] serf: Remote state is zero bytes")
		return
	}

	// Check the message type
	if messageType(buf[0]) != messagePushPullType {
		d.serf.logger.Printf("[ERR] serf: Remote state has bad type prefix: %v", buf[0])
		return
	}

	// Attempt a decode
	pp := messagePushPull{}
	if err := decodeMessage(buf[1:], &pp); err != nil {
		d.serf.logger.Printf("[ERR] serf: Failed to decode remote state: %v", err)
		return
	}

	// Witness the Lamport clocks first.
	// We subtract 1 since no message with that clock has been sent yet
	if pp.LTime > 0 {
		d.serf.clock.Witness(pp.LTime - 1)
	}
	if pp.EventLTime > 0 {
		d.serf.eventClock.Witness(pp.EventLTime - 1)
	}
	if pp.QueryLTime > 0 {
		d.serf.queryClock.Witness(pp.QueryLTime - 1)
	}

	// Process the left nodes first to avoid the LTimes from being
	// incremented in the wrong order
	leftMap := make(map[string]struct{}, len(pp.LeftMembers))
	leave := messageLeave{}
	for _, name := range pp.LeftMembers {
		leftMap[name] = struct{}{}
		leave.LTime = pp.StatusLTimes[name]
		leave.Node = name
		d.serf.handleNodeLeaveIntent(&leave)
	}

	// Update any other LTimes
	join := messageJoin{}
	for name, statusLTime := range pp.StatusLTimes {
		// Skip the left nodes
		if _, ok := leftMap[name]; ok {
			continue
		}

		// Create an artificial join message
		join.LTime = statusLTime
		join.Node = name
		d.serf.handleNodeJoinIntent(&join)
	}

	// If we are doing a join, and eventJoinIgnore is set
	// then we set the eventMinTime to the EventLTime. This
	// prevents any of the incoming events from being processed
	if isJoin && d.serf.eventJoinIgnore {
		d.serf.eventLock.Lock()
		if pp.EventLTime > d.serf.eventMinTime {
			d.serf.eventMinTime = pp.EventLTime
		}
		d.serf.eventLock.Unlock()
	}

	// Process all the events
	userEvent := messageUserEvent{}
	for _, events := range pp.Events {
		if events == nil {
			continue
		}
		userEvent.LTime = events.LTime
		for _, e := range events.Events {
			userEvent.Name = e.Name
			userEvent.Payload = e.Payload
			d.serf.handleUserEvent(&userEvent)
		}
	}
}

View file

@ -0,0 +1,168 @@
package serf
import (
"fmt"
"net"
"sync"
"time"
)
// EventType enumerates the kinds of events that may occur and be sent
// along the Serf channel.
type EventType int

const (
	EventMemberJoin EventType = iota
	EventMemberLeave
	EventMemberFailed
	EventMemberUpdate
	EventMemberReap
	EventUser
	EventQuery
)

// String returns the lower-case, dash-separated name of the event type. It
// panics for a value that is not one of the declared constants.
func (t EventType) String() string {
	// The constants are contiguous starting at zero, so they can index a
	// name table directly.
	names := [...]string{
		EventMemberJoin:   "member-join",
		EventMemberLeave:  "member-leave",
		EventMemberFailed: "member-failed",
		EventMemberUpdate: "member-update",
		EventMemberReap:   "member-reap",
		EventUser:         "user",
		EventQuery:        "query",
	}
	if t < 0 || int(t) >= len(names) {
		panic(fmt.Sprintf("unknown event type: %d", t))
	}
	return names[t]
}
// Event is a generic interface for exposing Serf events.
// Clients will usually need to use a type switch to get
// to a more useful type.
type Event interface {
	// EventType returns the kind of event this is.
	EventType() EventType
	// String returns a textual description of the event.
	String() string
}
// MemberEvent is the event used for member related changes. Because Serf
// coalesces events, a single event may carry several members.
type MemberEvent struct {
	Type    EventType
	Members []Member
}

// EventType returns the type stored in the event.
func (m MemberEvent) EventType() EventType {
	return m.Type
}

// String returns the name of the member event type. It panics when Type is
// not one of the member event types.
func (m MemberEvent) String() string {
	var label string
	switch m.Type {
	case EventMemberJoin:
		label = "member-join"
	case EventMemberLeave:
		label = "member-leave"
	case EventMemberFailed:
		label = "member-failed"
	case EventMemberUpdate:
		label = "member-update"
	case EventMemberReap:
		label = "member-reap"
	default:
		panic(fmt.Sprintf("unknown event type: %d", m.Type))
	}
	return label
}
// UserEvent is the event used for events that are triggered by the user
// and are not related to members.
type UserEvent struct {
	LTime    LamportTime
	Name     string
	Payload  []byte
	Coalesce bool
}

// EventType always returns EventUser.
func (u UserEvent) EventType() EventType {
	return EventUser
}

// String returns the event name prefixed with "user-event: ".
func (u UserEvent) String() string {
	return "user-event: " + u.Name
}
// Query is the event used for EventQuery type events.
type Query struct {
	LTime   LamportTime
	Name    string
	Payload []byte

	serf     *Serf
	id       uint32    // ID is not exported, since it may change
	addr     []byte    // Address to respond to
	port     uint16    // Port to respond to
	deadline time.Time // Must respond by this deadline
	respLock sync.Mutex
}

// EventType always returns EventQuery.
func (q *Query) EventType() EventType {
	return EventQuery
}

// String returns the query name prefixed with "query: ".
func (q *Query) String() string {
	return "query: " + q.Name
}

// Deadline returns the time by which a response must be sent.
func (q *Query) Deadline() time.Time {
	return q.deadline
}
// Respond sends buf as the response to the query. It fails if a response
// has already been sent, if the deadline has passed, if the response cannot
// be encoded, if the encoded response is larger than QueryResponseSizeLimit,
// or if sending fails. After a successful send the deadline is cleared,
// which is how a later call detects that a response was already sent.
func (q *Query) Respond(buf []byte) error {
	q.respLock.Lock()
	defer q.respLock.Unlock()

	switch {
	case q.deadline.IsZero():
		// A zero deadline marks a query that has been answered.
		return fmt.Errorf("Response already sent")
	case time.Now().After(q.deadline):
		return fmt.Errorf("Response is past the deadline")
	}

	// Build and encode the response message.
	raw, err := encodeMessage(messageQueryResponseType, &messageQueryResponse{
		LTime:   q.LTime,
		ID:      q.id,
		From:    q.serf.config.NodeName,
		Payload: buf,
	})
	if err != nil {
		return fmt.Errorf("Failed to format response: %v", err)
	}

	// Enforce the size limit on the encoded form.
	if len(raw) > QueryResponseSizeLimit {
		return fmt.Errorf("response exceeds limit of %d bytes", QueryResponseSizeLimit)
	}

	// Send to the address and port recorded on the query.
	dest := &net.UDPAddr{IP: q.addr, Port: int(q.port)}
	if err := q.serf.memberlist.SendTo(dest, raw); err != nil {
		return err
	}

	// Clear the deadline, response sent
	q.deadline = time.Time{}
	return nil
}

View file

@ -0,0 +1,21 @@
package serf
import (
"github.com/hashicorp/memberlist"
)
// eventDelegate forwards node join, leave and update notifications to the
// Serf instance it wraps.
type eventDelegate struct {
	serf *Serf
}

// NotifyJoin passes the node to Serf.handleNodeJoin.
func (e *eventDelegate) NotifyJoin(n *memberlist.Node) {
	e.serf.handleNodeJoin(n)
}

// NotifyLeave passes the node to Serf.handleNodeLeave.
func (e *eventDelegate) NotifyLeave(n *memberlist.Node) {
	e.serf.handleNodeLeave(n)
}

// NotifyUpdate passes the node to Serf.handleNodeUpdate.
func (e *eventDelegate) NotifyUpdate(n *memberlist.Node) {
	e.serf.handleNodeUpdate(n)
}

View file

@ -0,0 +1,312 @@
package serf
import (
"encoding/base64"
"log"
"strings"
)
const (
	// InternalQueryPrefix is the prefix we use for queries that are
	// internal to Serf. They are handled internally, and not forwarded to
	// a client.
	InternalQueryPrefix = "_serf_"

	// pingQuery is run to check for reachability
	pingQuery = "ping"

	// conflictQuery is run to resolve a name conflict
	conflictQuery = "conflict"

	// installKeyQuery is used to install a new key
	installKeyQuery = "install-key"

	// useKeyQuery is used to change the primary encryption key
	useKeyQuery = "use-key"

	// removeKeyQuery is used to remove a key from the keyring
	removeKeyQuery = "remove-key"

	// listKeysQuery is used to list all known keys in the cluster
	listKeysQuery = "list-keys"
)

// internalQueryName is used to generate a query name for an internal query.
// It returns name with InternalQueryPrefix prepended.
func internalQueryName(name string) string {
	return InternalQueryPrefix + name
}
// serfQueries is used to listen for queries that start with
// _serf and respond to them as appropriate.
type serfQueries struct {
	// inCh is the channel events are read from.
	inCh chan Event
	// logger receives diagnostic output.
	logger *log.Logger
	// outCh receives every event that is not an internal query; it may be
	// nil, in which case such events are dropped.
	outCh chan<- Event
	// serf is the instance the internal queries act on.
	serf *Serf
	// shutdownCh stops the stream routine once a receive on it succeeds.
	shutdownCh <-chan struct{}
}

// nodeKeyResponse is used to store the result from an individual node while
// replying to key modification queries
type nodeKeyResponse struct {
	// Result indicates true/false if there were errors or not
	Result bool

	// Message contains error messages or other information
	Message string

	// Keys is used in listing queries to relay a list of installed keys
	Keys []string
}
// newSerfQueries creates a serfQueries, starts its stream goroutine, and
// returns the buffered channel (capacity 1024) that feeds it. Events sent
// on the returned channel are forwarded to outCh, except queries whose name
// starts with InternalQueryPrefix, which are handled internally. The
// returned error is always nil.
func newSerfQueries(serf *Serf, logger *log.Logger, outCh chan<- Event, shutdownCh <-chan struct{}) (chan<- Event, error) {
	events := make(chan Event, 1024)
	handler := &serfQueries{
		inCh:       events,
		logger:     logger,
		outCh:      outCh,
		serf:       serf,
		shutdownCh: shutdownCh,
	}
	go handler.stream()
	return events, nil
}
// stream is a long running routine that ingests the event stream until
// shutdownCh fires. Internal queries are handled in their own goroutine;
// every other event is forwarded to outCh when one is configured.
func (s *serfQueries) stream() {
	for {
		select {
		case <-s.shutdownCh:
			return

		case ev := <-s.inCh:
			query, isQuery := ev.(*Query)
			switch {
			case isQuery && strings.HasPrefix(query.Name, InternalQueryPrefix):
				// An internal query: process it ourselves.
				go s.handleQuery(query)
			case s.outCh != nil:
				s.outCh <- ev
			}
		}
	}
}
// handleQuery dispatches an internal query to the handler that matches its
// name once InternalQueryPrefix has been removed. Ping queries need no
// handling here; unknown names are logged as a warning.
func (s *serfQueries) handleQuery(q *Query) {
	// Strip the prefix to get the bare query name.
	name := q.Name[len(InternalQueryPrefix):]

	var handle func(*Query)
	switch name {
	case pingQuery:
		// Nothing to do, we will ack the query
		return
	case conflictQuery:
		handle = s.handleConflict
	case installKeyQuery:
		handle = s.handleInstallKey
	case useKeyQuery:
		handle = s.handleUseKey
	case removeKeyQuery:
		handle = s.handleRemoveKey
	case listKeysQuery:
		handle = s.handleListKeys
	default:
		s.logger.Printf("[WARN] serf: Unhandled internal query '%s'", name)
		return
	}
	handle(q)
}
// handleConflict is invoked for a query that is attempting to disambiguate
// a name conflict. The payload is a node name, and the response is the
// member information we hold for that name, if any. No response is sent
// when the query is about this node itself.
func (s *serfQueries) handleConflict(q *Query) {
	// The payload names the node in question.
	target := string(q.Payload)

	// Stay silent about ourselves.
	if target == s.serf.config.NodeName {
		return
	}
	s.logger.Printf("[DEBUG] serf: Got conflict resolution query for '%s'", target)

	// Look the member up under the member lock; known stays nil when the
	// name is not in the member table.
	var known *Member
	s.serf.memberLock.Lock()
	if state, found := s.serf.members[target]; found {
		known = &state.Member
	}
	s.serf.memberLock.Unlock()

	// Encode the response
	resp, err := encodeMessage(messageConflictResponseType, known)
	if err != nil {
		s.logger.Printf("[ERR] serf: Failed to encode conflict query response: %v", err)
		return
	}

	// Send our answer
	if err := q.Respond(resp); err != nil {
		s.logger.Printf("[ERR] serf: Failed to respond to conflict query: %v", err)
	}
}
// sendKeyResponse encodes resp as a key response message and sends it as
// the answer to q. Encoding and send failures are logged, not returned.
func (s *serfQueries) sendKeyResponse(q *Query, resp *nodeKeyResponse) {
	encoded, err := encodeMessage(messageKeyResponseType, resp)
	if err != nil {
		s.logger.Printf("[ERR] serf: Failed to encode key response: %v", err)
		return
	}

	if err = q.Respond(encoded); err != nil {
		s.logger.Printf("[ERR] serf: Failed to respond to key query: %v", err)
	}
}
// handleInstallKey is invoked whenever a new encryption key is received from
// another member in the cluster, and handles the process of installing it onto
// the memberlist keyring. This type of query may fail if the provided key does
// not fit the constraints that memberlist enforces. If the query fails, the
// response will contain the error message so that it may be relayed.
//
// A response is always sent: Result is true only when the key was added to
// the keyring and the keyring file was written.
func (s *serfQueries) handleInstallKey(q *Query) {
	response := nodeKeyResponse{Result: false}
	keyring := s.serf.config.MemberlistConfig.Keyring
	req := keyRequest{}

	// The first payload byte is skipped before decoding, so an empty
	// payload must be rejected here; slicing it would panic.
	if len(q.Payload) == 0 {
		s.logger.Printf("[ERR] serf: Failed to decode key request: empty payload")
		goto SEND
	}

	if err := decodeMessage(q.Payload[1:], &req); err != nil {
		s.logger.Printf("[ERR] serf: Failed to decode key request: %v", err)
		goto SEND
	}

	if !s.serf.EncryptionEnabled() {
		response.Message = "No keyring to modify (encryption not enabled)"
		s.logger.Printf("[ERR] serf: No keyring to modify (encryption not enabled)")
		goto SEND
	}

	s.logger.Printf("[INFO] serf: Received install-key query")
	if err := keyring.AddKey(req.Key); err != nil {
		response.Message = err.Error()
		s.logger.Printf("[ERR] serf: Failed to install key: %s", err)
		goto SEND
	}

	// Persist the updated keyring.
	if err := s.serf.writeKeyringFile(); err != nil {
		response.Message = err.Error()
		s.logger.Printf("[ERR] serf: Failed to write keyring file: %s", err)
		goto SEND
	}

	response.Result = true

SEND:
	s.sendKeyResponse(q, &response)
}
// handleUseKey is invoked whenever a query is received to mark a different key
// in the internal keyring as the primary key. This type of query may fail due
// to operator error (requested key not in ring), and thus sends error messages
// back in the response.
//
// A nodeKeyResponse is always sent back for q; Result is true only when the
// primary key was changed and the keyring file was written.
func (s *serfQueries) handleUseKey(q *Query) {
	response := nodeKeyResponse{Result: false}
	// Every exit path below answers the query with whatever was recorded.
	defer s.sendKeyResponse(q, &response)

	// The first payload byte is the message type tag. Payloads come from
	// other nodes, so reject one too short to carry the tag instead of
	// panicking on the slice expression below.
	if len(q.Payload) < 1 {
		response.Message = "Failed to decode key request: empty payload"
		s.logger.Printf("[ERR] serf: Failed to decode key request: empty payload")
		return
	}

	req := keyRequest{}
	if err := decodeMessage(q.Payload[1:], &req); err != nil {
		// Relay the reason to the requester rather than an empty message.
		response.Message = "Failed to decode key request: " + err.Error()
		s.logger.Printf("[ERR] serf: Failed to decode key request: %v", err)
		return
	}

	if !s.serf.EncryptionEnabled() {
		response.Message = "No keyring to modify (encryption not enabled)"
		s.logger.Printf("[ERR] serf: No keyring to modify (encryption not enabled)")
		return
	}

	s.logger.Printf("[INFO] serf: Received use-key query")
	keyring := s.serf.config.MemberlistConfig.Keyring
	if err := keyring.UseKey(req.Key); err != nil {
		response.Message = err.Error()
		s.logger.Printf("[ERR] serf: Failed to change primary key: %s", err)
		return
	}

	// Persist the change so the new primary key survives a restart.
	if err := s.serf.writeKeyringFile(); err != nil {
		response.Message = err.Error()
		s.logger.Printf("[ERR] serf: Failed to write keyring file: %s", err)
		return
	}

	response.Result = true
}
// handleRemoveKey is invoked when a query is received to remove a particular
// key from the keyring. This type of query can fail if the key requested for
// deletion is currently the primary key in the keyring, so therefore it will
// reply to the query with any relevant errors from the operation.
//
// A nodeKeyResponse is always sent back for q; Result is true only when the
// key was removed and the keyring file was written.
func (s *serfQueries) handleRemoveKey(q *Query) {
	response := nodeKeyResponse{Result: false}
	// Every exit path below answers the query with whatever was recorded.
	defer s.sendKeyResponse(q, &response)

	// The first payload byte is the message type tag. Payloads come from
	// other nodes, so reject one too short to carry the tag instead of
	// panicking on the slice expression below.
	if len(q.Payload) < 1 {
		response.Message = "Failed to decode key request: empty payload"
		s.logger.Printf("[ERR] serf: Failed to decode key request: empty payload")
		return
	}

	req := keyRequest{}
	if err := decodeMessage(q.Payload[1:], &req); err != nil {
		// Relay the reason to the requester rather than an empty message.
		response.Message = "Failed to decode key request: " + err.Error()
		s.logger.Printf("[ERR] serf: Failed to decode key request: %v", err)
		return
	}

	if !s.serf.EncryptionEnabled() {
		response.Message = "No keyring to modify (encryption not enabled)"
		s.logger.Printf("[ERR] serf: No keyring to modify (encryption not enabled)")
		return
	}

	s.logger.Printf("[INFO] serf: Received remove-key query")
	keyring := s.serf.config.MemberlistConfig.Keyring
	if err := keyring.RemoveKey(req.Key); err != nil {
		response.Message = err.Error()
		s.logger.Printf("[ERR] serf: Failed to remove key: %s", err)
		return
	}

	// Persist the change so the removal survives a restart.
	if err := s.serf.writeKeyringFile(); err != nil {
		response.Message = err.Error()
		s.logger.Printf("[ERR] serf: Failed to write keyring file: %s", err)
		return
	}

	response.Result = true
}
// handleListKeys answers a list-keys query with every key in the local
// keyring. Keys are base64-encoded here, on the responding member, so the
// node that asked does not have to do that work for every response.
// When encryption is not enabled the response carries an error message.
func (s *serfQueries) handleListKeys(q *Query) {
	keyring := s.serf.config.MemberlistConfig.Keyring
	response := nodeKeyResponse{Result: false}

	if s.serf.EncryptionEnabled() {
		s.logger.Printf("[INFO] serf: Received list-keys query")
		for _, raw := range keyring.GetKeys() {
			encoded := base64.StdEncoding.EncodeToString(raw)
			response.Keys = append(response.Keys, encoded)
		}
		response.Result = true
	} else {
		response.Message = "Keyring is empty (encryption not enabled)"
		s.logger.Printf("[ERR] serf: Keyring is empty (encryption not enabled)")
	}

	s.sendKeyResponse(q, &response)
}

View file

@ -0,0 +1,166 @@
package serf
import (
"encoding/base64"
"fmt"
"sync"
)
// KeyManager encapsulates all functionality within Serf for handling
// encryption keyring changes across a cluster.
type KeyManager struct {
// serf is the instance used to broadcast key queries and count members.
serf *Serf
// l serializes key operations: InstallKey, UseKey and RemoveKey take the
// write lock, ListKeys takes the read lock.
l sync.RWMutex
}
// keyRequest contains the input parameters which get broadcast to all
// nodes as part of a key query operation.
type keyRequest struct {
// Key holds the raw key bytes (handleKeyRequest base64-decodes the
// caller's string before filling this in).
Key []byte
}
// KeyResponse aggregates the per-node answers to a key query (install, use,
// remove or list). It is filled in by handleKeyRequest and streamKeyResp.
type KeyResponse struct {
Messages map[string]string // Map of node name to response message
NumNodes int // Total nodes memberlist knows of
NumResp int // Total responses received
NumErr int // Total errors from request
// Keys is a mapping of the base64-encoded value of the key bytes to the
// number of nodes that have the key installed.
Keys map[string]int
}
// streamKeyResp drains node responses from ch and folds them into resp,
// which is updated in place (hence no return value). Each response bumps
// NumResp; malformed, undecodable or unsuccessful responses also bump NumErr
// and record a message under the sender's name. The loop ends when ch is
// closed or as soon as NumResp equals NumNodes.
func (k *KeyManager) streamKeyResp(resp *KeyResponse, ch <-chan NodeResponse) {
	for r := range ch {
		var decoded nodeKeyResponse
		resp.NumResp++

		// Cases are evaluated top to bottom, so the decode in the second
		// case only runs once the payload is known to carry a type byte.
		switch {
		case len(r.Payload) < 1 || messageType(r.Payload[0]) != messageKeyResponseType:
			resp.Messages[r.From] = fmt.Sprintf(
				"Invalid key query response type: %v", r.Payload)
			resp.NumErr++

		case decodeMessage(r.Payload[1:], &decoded) != nil:
			resp.Messages[r.From] = fmt.Sprintf(
				"Failed to decode key query response: %v", r.Payload)
			resp.NumErr++

		default:
			if !decoded.Result {
				resp.Messages[r.From] = decoded.Message
				resp.NumErr++
			}

			// Only list queries return keys; count how many nodes
			// reported each one. A missing entry starts from zero.
			for _, key := range decoded.Keys {
				resp.Keys[key]++
			}
		}

		// Stop as soon as every node has answered instead of waiting for
		// the channel to be closed at the query timeout.
		if resp.NumResp == resp.NumNodes {
			return
		}
	}
}
// handleKeyRequest broadcasts a key query of the given type to all members,
// collects their answers and packs them into a KeyResponse so every key
// operation reports results the same way.
//
// key is the base64-encoded key (empty for list queries); query is the query
// name before internalQueryName is applied. The returned KeyResponse is
// never nil, even when an error is returned. An error is reported when any
// node failed or when fewer responses than members were received.
func (k *KeyManager) handleKeyRequest(key, query string) (*KeyResponse, error) {
	resp := &KeyResponse{
		Messages: map[string]string{},
		Keys:     map[string]int{},
	}
	qName := internalQueryName(query)

	// Turn the caller's base64 text into the raw key bytes sent on the wire.
	rawKey, decodeErr := base64.StdEncoding.DecodeString(key)
	if decodeErr != nil {
		return resp, decodeErr
	}

	req, encodeErr := encodeMessage(messageKeyRequestType, keyRequest{Key: rawKey})
	if encodeErr != nil {
		return resp, encodeErr
	}

	queryResp, queryErr := k.serf.Query(qName, req, k.serf.DefaultQueryParams())
	if queryErr != nil {
		return resp, queryErr
	}

	// Gather answers until every known member has responded or the
	// response channel is closed.
	resp.NumNodes = k.serf.memberlist.NumMembers()
	k.streamKeyResp(resp, queryResp.respCh)

	switch {
	case resp.NumErr != 0:
		return resp, fmt.Errorf("%d/%d nodes reported failure", resp.NumErr, resp.NumNodes)
	case resp.NumResp != resp.NumNodes:
		return resp, fmt.Errorf("%d/%d nodes reported success", resp.NumResp, resp.NumNodes)
	}
	return resp, nil
}
// InstallKey handles broadcasting a query to all members and gathering
// responses from each of them, returning a list of messages from each node
// and any applicable error conditions.
//
// key is the base64-encoded key to install. The KeyManager write lock is
// held for the duration of the call.
func (k *KeyManager) InstallKey(key string) (*KeyResponse, error) {
k.l.Lock()
defer k.l.Unlock()
return k.handleKeyRequest(key, installKeyQuery)
}
// UseKey handles broadcasting a primary key change to all members in the
// cluster, and gathering any response messages. If successful, the returned
// KeyResponse should carry no node messages.
//
// key is the base64-encoded key to make primary. The KeyManager write lock
// is held for the duration of the call.
func (k *KeyManager) UseKey(key string) (*KeyResponse, error) {
k.l.Lock()
defer k.l.Unlock()
return k.handleKeyRequest(key, useKeyQuery)
}
// RemoveKey handles broadcasting a key to the cluster for removal. Each member
// will receive this event, and if they have the key in their keyring, remove
// it. If any errors are encountered, RemoveKey will collect and relay them.
//
// key is the base64-encoded key to remove. The KeyManager write lock is
// held for the duration of the call.
func (k *KeyManager) RemoveKey(key string) (*KeyResponse, error) {
k.l.Lock()
defer k.l.Unlock()
return k.handleKeyRequest(key, removeKeyQuery)
}
// ListKeys is used to collect installed keys from members in a Serf cluster
// and return an aggregated list of all installed keys. This is useful to
// operators to ensure that there are no lingering keys installed on any agents.
// Since having multiple keys installed can cause performance penalties in some
// cases, it's important to verify this information and remove unneeded keys.
//
// Only the read lock is taken, and an empty key is sent with the query.
func (k *KeyManager) ListKeys() (*KeyResponse, error) {
k.l.RLock()
defer k.l.RUnlock()
return k.handleKeyRequest("", listKeysQuery)
}

View file

@ -0,0 +1,45 @@
package serf
import (
"sync/atomic"
)
// LamportClock is a thread safe implementation of a lamport clock. All of
// its methods operate on the counter through sync/atomic; no mutex is
// involved, and Witness simply retries its compare-and-swap on contention.
type LamportClock struct {
// counter is the current clock value; access it only via sync/atomic.
counter uint64
}
// LamportTime is the value of a LamportClock, as returned by Time and
// Increment and accepted by Witness.
type LamportTime uint64
// Time returns the current value of the lamport clock without changing it.
func (l *LamportClock) Time() LamportTime {
	now := atomic.LoadUint64(&l.counter)
	return LamportTime(now)
}
// Increment atomically advances the lamport clock by one and returns the
// new value.
func (l *LamportClock) Increment() LamportTime {
	next := atomic.AddUint64(&l.counter, 1)
	return LamportTime(next)
}
// Witness updates the local clock after observing the clock value v received
// from another process. If v is strictly behind the local clock nothing
// changes; otherwise the local clock is moved to v+1.
func (l *LamportClock) Witness(v LamportTime) {
	seen := uint64(v)
	for {
		local := atomic.LoadUint64(&l.counter)

		// A value older than ours carries no new information.
		if seen < local {
			return
		}

		// Move one past the witnessed value. If another goroutine changed
		// the counter in the meantime the swap fails and we re-evaluate:
		// either it succeeds next time or the clock has already passed v.
		if atomic.CompareAndSwapUint64(&l.counter, local, seen+1) {
			return
		}
	}
}

View file

@ -0,0 +1,35 @@
package serf
import (
"net"
"github.com/hashicorp/memberlist"
)
// MergeDelegate is the hook Serf calls, through its internal mergeDelegate
// adapter, when memberlist reports a set of nodes being merged.
type MergeDelegate interface {
// NotifyMerge receives the merging nodes converted to Members and
// returns the cancel flag that is handed back to memberlist.
// NOTE(review): presumably returning true aborts the merge — confirm
// against the memberlist documentation.
NotifyMerge([]*Member) (cancel bool)
}
// mergeDelegate adapts memberlist's node-based merge notification to the
// Member-based MergeDelegate configured on Serf.
type mergeDelegate struct {
// serf supplies tag decoding and the configured Merge delegate.
serf *Serf
}
// NotifyMerge converts each memberlist node into a Serf Member (status
// StatusNone, tags decoded from the node metadata) and forwards the list to
// the MergeDelegate configured on Serf, returning that delegate's decision.
func (m *mergeDelegate) NotifyMerge(nodes []*memberlist.Node) (cancel bool) {
	members := make([]*Member, 0, len(nodes))
	for _, node := range nodes {
		member := &Member{
			Name:        node.Name,
			Addr:        net.IP(node.Addr),
			Port:        node.Port,
			Tags:        m.serf.decodeTags(node.Meta),
			Status:      StatusNone,
			ProtocolMin: node.PMin,
			ProtocolMax: node.PMax,
			ProtocolCur: node.PCur,
			DelegateMin: node.DMin,
			DelegateMax: node.DMax,
			DelegateCur: node.DCur,
		}
		members = append(members, member)
	}
	return m.serf.config.Merge.NotifyMerge(members)
}

View file

@ -0,0 +1,147 @@
package serf
import (
"bytes"
"github.com/hashicorp/go-msgpack/codec"
"time"
)
// messageType identifies the kind of gossip message Serf sends along
// memberlist. encodeMessage writes it as the first byte of every message.
type messageType uint8
// Values are assigned by iota starting at 0 in declaration order. Because
// the value goes on the wire, reordering these changes the wire format.
const (
messageLeaveType messageType = iota
messageJoinType
messagePushPullType
messageUserEventType
messageQueryType
messageQueryResponseType
messageConflictResponseType
messageKeyRequestType
messageKeyResponseType
)
// Bit flags carried in the Flags field of query messages (1 << iota, so
// queryFlagAck is 1 and queryFlagNoBroadcast is 2).
const (
// Ack flag is used to force receiver to send an ack back
queryFlagAck uint32 = 1 << iota
// NoBroadcast is used to prevent re-broadcast of a query.
// This can be used to selectively send queries to individual members.
queryFlagNoBroadcast
)
// filterType is used with a query filter to specify the type of
// filter we are sending. encodeFilter writes it as the first byte of
// each encoded filter.
type filterType uint8
// filterNodeType (0) selects by node name, filterTagType (1) by tag regex.
const (
filterNodeType filterType = iota
filterTagType
)
// messageJoin is the message broadcast after we join, to
// associate the node with a lamport clock value.
type messageJoin struct {
LTime LamportTime
Node string
}
// messageLeave is the message broadcast to signal the intention to
// leave.
type messageLeave struct {
LTime LamportTime
Node string
}
// messagePushPull is used when doing a state exchange. This
// is a relatively large message, but is sent infrequently.
type messagePushPull struct {
LTime LamportTime // Current node lamport time
StatusLTimes map[string]LamportTime // Maps the node to its status time
LeftMembers []string // List of left nodes
EventLTime LamportTime // Lamport time for event clock
Events []*userEvents // Recent events
QueryLTime LamportTime // Lamport time for query clock
}
// messageUserEvent is used for user-generated events. It carries the
// event's lamport time, its name and an opaque payload.
type messageUserEvent struct {
LTime LamportTime
Name string
Payload []byte
CC bool // "Can Coalesce". Zero value is compatible with Serf 0.1
}
// messageQuery is used for query events. Flags holds the queryFlag* bits
// tested by the Ack and NoBroadcast methods.
type messageQuery struct {
LTime LamportTime // Event lamport time
ID uint32 // Query ID, randomly generated
Addr []byte // Source address, used for a direct reply
Port uint16 // Source port, used for a direct reply
Filters [][]byte // Potential query filters
Flags uint32 // Used to provide various flags
Timeout time.Duration // Maximum time between delivery and response
Name string // Query name
Payload []byte // Query payload
}
// Ack reports whether the query asks receivers to acknowledge delivery,
// i.e. whether the queryFlagAck bit is set in Flags.
func (m *messageQuery) Ack() bool {
	if m.Flags&queryFlagAck == 0 {
		return false
	}
	return true
}
// NoBroadcast reports whether the query must not be re-broadcast, i.e.
// whether the queryFlagNoBroadcast bit is set in Flags.
func (m *messageQuery) NoBroadcast() bool {
	if m.Flags&queryFlagNoBroadcast == 0 {
		return false
	}
	return true
}
// filterNode is used with the filterNodeType, and is a list
// of node names; a node processes the query only if its own name is listed.
type filterNode []string
// filterTag is used with the filterTagType and is a regular
// expression to apply to a tag.
type filterTag struct {
Tag string // Name of the tag whose value is matched
Expr string // Regular expression applied to that tag's value
}
// messageQueryResponse is used to respond to a query. Flags holds the
// queryFlag* bits; the Ack method tests queryFlagAck.
type messageQueryResponse struct {
LTime LamportTime // Event lamport time
ID uint32 // Query ID
From string // Node name
Flags uint32 // Used to provide various flags
Payload []byte // Optional response payload
}
// Ack reports whether this response has the queryFlagAck bit set in Flags.
func (m *messageQueryResponse) Ack() bool {
	if m.Flags&queryFlagAck == 0 {
		return false
	}
	return true
}
// decodeMessage msgpack-decodes buf into out, which must be a pointer.
// buf must not include the leading type byte written by encodeMessage or
// encodeFilter; callers strip it first.
func decodeMessage(buf []byte, out interface{}) error {
	handle := codec.MsgpackHandle{}
	decoder := codec.NewDecoder(bytes.NewReader(buf), &handle)
	return decoder.Decode(out)
}
// encodeMessage produces the wire form of msg: one byte holding the message
// type t followed by the msgpack encoding of msg. The bytes written so far
// are returned together with any encoder error.
func encodeMessage(t messageType, msg interface{}) ([]byte, error) {
	var out bytes.Buffer
	out.WriteByte(uint8(t))

	var handle codec.MsgpackHandle
	err := codec.NewEncoder(&out, &handle).Encode(msg)
	return out.Bytes(), err
}
// encodeFilter produces the wire form of a query filter: one byte holding
// the filter type f followed by the msgpack encoding of filt. The bytes
// written so far are returned together with any encoder error.
func encodeFilter(f filterType, filt interface{}) ([]byte, error) {
	var out bytes.Buffer
	out.WriteByte(uint8(f))

	var handle codec.MsgpackHandle
	err := codec.NewEncoder(&out, &handle).Encode(filt)
	return out.Bytes(), err
}

View file

@ -0,0 +1,210 @@
package serf
import (
"math"
"regexp"
"sync"
"time"
)
// QueryParam is provided to Query() to configure the parameters of the
// query. If not provided, sane defaults will be used.
type QueryParam struct {
// If provided, we restrict the nodes that should respond to those
// with names in this list
FilterNodes []string
// FilterTags maps a tag name to a regular expression that is applied
// to restrict the nodes that should respond
FilterTags map[string]string
// If true, we are requesting a delivery acknowledgement from
// every node that meets the filter requirement. This means nodes
// that receive the message but do not pass the filters will not
// send an ack.
RequestAck bool
// The timeout limits how long the query is left open. If not provided,
// then a default timeout is used based on the configuration of Serf
Timeout time.Duration
}
// DefaultQueryTimeout returns the default timeout value for a query.
// It is computed as GossipInterval * QueryTimeoutMult * ceil(log10(N+1)),
// where N is the member count reported by memberlist.
// NOTE: log10(1) is 0, so a member count of 0 yields a zero timeout.
func (s *Serf) DefaultQueryTimeout() time.Duration {
	members := s.memberlist.NumMembers()
	scale := math.Ceil(math.Log10(float64(members + 1)))

	base := s.config.MemberlistConfig.GossipInterval
	return base * time.Duration(s.config.QueryTimeoutMult) * time.Duration(scale)
}
// DefaultQueryParams returns the default query parameters: no node or tag
// filters, no acknowledgement request, and the timeout from
// DefaultQueryTimeout.
func (s *Serf) DefaultQueryParams() *QueryParam {
	// The zero value already means "no filters, no ack"; only the timeout
	// needs to be filled in.
	params := new(QueryParam)
	params.Timeout = s.DefaultQueryTimeout()
	return params
}
// encodeFilters converts the configured filters into their wire format.
// The node filter, when FilterNodes is non-empty, comes first as a single
// entry; each FilterTags entry then becomes its own filter (in map iteration
// order). The result is nil when no filters are set. The first encoding
// error aborts and is returned with a nil slice.
func (q *QueryParam) encodeFilters() ([][]byte, error) {
	var encoded [][]byte

	if len(q.FilterNodes) != 0 {
		nodeBuf, err := encodeFilter(filterNodeType, q.FilterNodes)
		if err != nil {
			return nil, err
		}
		encoded = append(encoded, nodeBuf)
	}

	for name, pattern := range q.FilterTags {
		tagFilter := filterTag{Tag: name, Expr: pattern}
		tagBuf, err := encodeFilter(filterTagType, &tagFilter)
		if err != nil {
			return nil, err
		}
		encoded = append(encoded, tagBuf)
	}

	return encoded, nil
}
// QueryResponse is returned for each new Query. It is used to collect
// Ack's as well as responses and to provide those back to a client.
type QueryResponse struct {
// ackCh is used to send the name of a node for which we've received an ack.
// It is nil unless the query requested acks (see newQueryResponse).
ackCh chan string
// deadline is the query end time (start + query timeout)
deadline time.Time
// Query ID
id uint32
// Stores the LTime of the query
lTime LamportTime
// respCh is used to send a response from a node
respCh chan NodeResponse
// closed records that Close has run; closeLock guards it in Close.
closed bool
closeLock sync.Mutex
}
// newQueryResponse builds the QueryResponse for query q. n is the buffer
// size of the response channel and, when q requests acks, of the ack
// channel; without an ack request the ack channel stays nil. The deadline is
// the current time plus q.Timeout.
func newQueryResponse(n int, q *messageQuery) *QueryResponse {
	resp := new(QueryResponse)
	resp.deadline = time.Now().Add(q.Timeout)
	resp.id = q.ID
	resp.lTime = q.LTime
	resp.respCh = make(chan NodeResponse, n)

	// Only allocate the ack channel when somebody will listen on it.
	if !q.Ack() {
		return resp
	}
	resp.ackCh = make(chan string, n)
	return resp
}
// Close marks the query as closed and closes the ack and response channels
// that were allocated, so no further deliveries are possible. Calling it
// again after the first call is a no-op.
func (r *QueryResponse) Close() {
	r.closeLock.Lock()
	defer r.closeLock.Unlock()

	if !r.closed {
		r.closed = true

		// ackCh is nil when acks were not requested.
		if r.ackCh != nil {
			close(r.ackCh)
		}
		if r.respCh != nil {
			close(r.respCh)
		}
	}
}
// Deadline returns the ending deadline of the query, i.e. the time set
// when the QueryResponse was created (creation time plus the query timeout).
func (r *QueryResponse) Deadline() time.Time {
return r.deadline
}
// Finished reports whether the query is finished running: either Close has
// been called or the deadline has passed.
//
// The closed flag is written by Close under closeLock, so it is read under
// the same lock here to avoid a data race with a concurrent Close.
func (r *QueryResponse) Finished() bool {
	r.closeLock.Lock()
	defer r.closeLock.Unlock()
	return r.closed || time.Now().After(r.deadline)
}
// AckCh returns a channel that can be used to listen for acks; each value
// received is the name of a node that acknowledged the query.
// Channel will be closed when the query is finished. This is nil
// if the query did not specify RequestAck.
func (r *QueryResponse) AckCh() <-chan string {
return r.ackCh
}
// ResponseCh returns a channel that can be used to listen for responses;
// each value received is one node's NodeResponse.
// Channel will be closed when the query is finished.
func (r *QueryResponse) ResponseCh() <-chan NodeResponse {
return r.respCh
}
// NodeResponse is used to represent a single response from a node.
type NodeResponse struct {
From string // Name of the responding node
Payload []byte // Raw response payload as sent by that node
}
// shouldProcessQuery checks if a query should be processed given
// a set of filters. Every filter must pass for the query to be processed:
//
//   - a node filter passes when this node's name is in the list;
//   - a tag filter passes when the regular expression matches the value of
//     the named tag in this node's configuration (a missing tag is matched
//     as the empty string).
//
// Filters arrive from other nodes, so anything that cannot be understood
// (empty filter, undecodable body, invalid regex, unknown type) is logged
// and makes the query be skipped. An empty filter list always passes.
func (s *Serf) shouldProcessQuery(filters [][]byte) bool {
	for _, filter := range filters {
		// The first byte is the filter type; a filter without it is
		// malformed. Reject it rather than panic on filter[0] below.
		if len(filter) == 0 {
			s.logger.Printf("[WARN] serf: query has an empty filter")
			return false
		}

		switch filterType(filter[0]) {
		case filterNodeType:
			// Decode the filter
			var nodes filterNode
			if err := decodeMessage(filter[1:], &nodes); err != nil {
				s.logger.Printf("[WARN] serf: failed to decode filterNodeType: %v", err)
				return false
			}

			// Check if we are being targeted
			found := false
			for _, n := range nodes {
				if n == s.config.NodeName {
					found = true
					break
				}
			}
			if !found {
				return false
			}

		case filterTagType:
			// Decode the filter
			var filt filterTag
			if err := decodeMessage(filter[1:], &filt); err != nil {
				s.logger.Printf("[WARN] serf: failed to decode filterTagType: %v", err)
				return false
			}

			// Check if we match this regex
			tags := s.config.Tags
			matched, err := regexp.MatchString(filt.Expr, tags[filt.Tag])
			if err != nil {
				s.logger.Printf("[WARN] serf: failed to compile filter regex (%s): %v", filt.Expr, err)
				return false
			}
			if !matched {
				return false
			}

		default:
			s.logger.Printf("[WARN] serf: query has unrecognized filter type: %d", filter[0])
			return false
		}
	}
	return true
}

Some files were not shown because too many files have changed in this diff Show more