Vendoring in libnetwork to 3be488927db8d719568917203deddd630a194564

This PR brings the vendored libnetwork code to
3be488927db8d719568917203deddd630a194564, which pulls in quite a few
fixes to support kvstore, windows daemon compilation fixes,
multi-network support for Bridge driver, etc...

Signed-off-by: Madhu Venugopal <madhu@docker.com>
This commit is contained in:
Madhu Venugopal 2015-06-12 13:51:57 -07:00
parent 5bddafe169
commit 083300168f
118 changed files with 15213 additions and 736 deletions

View file

@ -652,7 +652,7 @@ func (container *Container) UpdateNetwork() error {
return fmt.Errorf("Update network failed: %v", err)
}
if _, err := ep.Join(container.ID, joinOptions...); err != nil {
if err := ep.Join(container.ID, joinOptions...); err != nil {
return fmt.Errorf("endpoint join failed: %v", err)
}
@ -769,7 +769,7 @@ func (container *Container) AllocateNetwork() error {
return err
}
if _, err := ep.Join(container.ID, joinOptions...); err != nil {
if err := ep.Join(container.ID, joinOptions...); err != nil {
return err
}

View file

@ -17,9 +17,14 @@ clone hg code.google.com/p/go.net 84a4013f96e0
clone hg code.google.com/p/gosqlite 74691fb6f837
#get libnetwork packages
clone git github.com/docker/libnetwork e578e95aa101441481411ff1d620f343895f24fe
clone git github.com/docker/libnetwork 3be488927db8d719568917203deddd630a194564
clone git github.com/docker/libkv e8cde779d58273d240c1eff065352a6cd67027dd
clone git github.com/vishvananda/netns 5478c060110032f972e86a1f844fdb9a2f008f2c
clone git github.com/vishvananda/netlink 8eb64238879fed52fd51c5b30ad20b928fb4c36c
clone git github.com/BurntSushi/toml f706d00e3de6abe700c994cdd545a1a4915af060
clone git github.com/samuel/go-zookeeper d0e0d8e11f318e000a8cc434616d69e329edc374
clone git github.com/coreos/go-etcd v2.0.0
clone git github.com/hashicorp/consul v0.5.2
# get distribution packages
clone git github.com/docker/distribution b9eeb328080d367dbde850ec6e94f1e4ac2b5efe

View file

@ -0,0 +1,5 @@
TAGS
tags
.*.swp
tomlcheck/tomlcheck
toml.test

View file

@ -0,0 +1,12 @@
language: go
go:
- 1.1
- 1.2
- tip
install:
- go install ./...
- go get github.com/BurntSushi/toml-test
script:
- export PATH="$PATH:$HOME/gopath/bin"
- make test

View file

@ -0,0 +1,3 @@
Compatible with TOML version
[v0.2.0](https://github.com/mojombo/toml/blob/master/versions/toml-v0.2.0.md)

View file

@ -0,0 +1,14 @@
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
Version 2, December 2004
Copyright (C) 2004 Sam Hocevar <sam@hocevar.net>
Everyone is permitted to copy and distribute verbatim or modified
copies of this license document, and changing it is allowed as long
as the name is changed.
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
0. You just DO WHAT THE FUCK YOU WANT TO.

View file

@ -0,0 +1,19 @@
install:
go install ./...
test: install
go test -v
toml-test toml-test-decoder
toml-test -encoder toml-test-encoder
fmt:
gofmt -w *.go */*.go
colcheck *.go */*.go
tags:
find ./ -name '*.go' -print0 | xargs -0 gotags > TAGS
push:
git push origin master
git push github master

View file

@ -0,0 +1,220 @@
## TOML parser and encoder for Go with reflection
TOML stands for Tom's Obvious, Minimal Language. This Go package provides a
reflection interface similar to Go's standard library `json` and `xml`
packages. This package also supports the `encoding.TextUnmarshaler` and
`encoding.TextMarshaler` interfaces so that you can define custom data
representations. (There is an example of this below.)
Spec: https://github.com/mojombo/toml
Compatible with TOML version
[v0.2.0](https://github.com/mojombo/toml/blob/master/versions/toml-v0.2.0.md)
Documentation: http://godoc.org/github.com/BurntSushi/toml
Installation:
```bash
go get github.com/BurntSushi/toml
```
Try the toml validator:
```bash
go get github.com/BurntSushi/toml/cmd/tomlv
tomlv some-toml-file.toml
```
[![Build status](https://api.travis-ci.org/BurntSushi/toml.png)](https://travis-ci.org/BurntSushi/toml)
### Testing
This package passes all tests in
[toml-test](https://github.com/BurntSushi/toml-test) for both the decoder
and the encoder.
### Examples
This package works similarly to how the Go standard library handles `XML`
and `JSON`. Namely, data is loaded into Go values via reflection.
For the simplest example, consider some TOML file as just a list of keys
and values:
```toml
Age = 25
Cats = [ "Cauchy", "Plato" ]
Pi = 3.14
Perfection = [ 6, 28, 496, 8128 ]
DOB = 1987-07-05T05:45:00Z
```
Which could be defined in Go as:
```go
type Config struct {
Age int
Cats []string
Pi float64
Perfection []int
DOB time.Time // requires `import time`
}
```
And then decoded with:
```go
var conf Config
if _, err := toml.Decode(tomlData, &conf); err != nil {
// handle error
}
```
You can also use struct tags if your struct field name doesn't map to a TOML
key value directly:
```toml
some_key_NAME = "wat"
```
```go
type TOML struct {
ObscureKey string `toml:"some_key_NAME"`
}
```
### Using the `encoding.TextUnmarshaler` interface
Here's an example that automatically parses duration strings into
`time.Duration` values:
```toml
[[song]]
name = "Thunder Road"
duration = "4m49s"
[[song]]
name = "Stairway to Heaven"
duration = "8m03s"
```
Which can be decoded with:
```go
type song struct {
Name string
Duration duration
}
type songs struct {
Song []song
}
var favorites songs
if _, err := toml.Decode(blob, &favorites); err != nil {
log.Fatal(err)
}
for _, s := range favorites.Song {
fmt.Printf("%s (%s)\n", s.Name, s.Duration)
}
```
And you'll also need a `duration` type that satisfies the
`encoding.TextUnmarshaler` interface:
```go
type duration struct {
time.Duration
}
func (d *duration) UnmarshalText(text []byte) error {
var err error
d.Duration, err = time.ParseDuration(string(text))
return err
}
```
### More complex usage
Here's an example of how to load the example from the official spec page:
```toml
# This is a TOML document. Boom.
title = "TOML Example"
[owner]
name = "Tom Preston-Werner"
organization = "GitHub"
bio = "GitHub Cofounder & CEO\nLikes tater tots and beer."
dob = 1979-05-27T07:32:00Z # First class dates? Why not?
[database]
server = "192.168.1.1"
ports = [ 8001, 8001, 8002 ]
connection_max = 5000
enabled = true
[servers]
# You can indent as you please. Tabs or spaces. TOML don't care.
[servers.alpha]
ip = "10.0.0.1"
dc = "eqdc10"
[servers.beta]
ip = "10.0.0.2"
dc = "eqdc10"
[clients]
data = [ ["gamma", "delta"], [1, 2] ] # just an update to make sure parsers support it
# Line breaks are OK when inside arrays
hosts = [
"alpha",
"omega"
]
```
And the corresponding Go types are:
```go
type tomlConfig struct {
Title string
Owner ownerInfo
DB database `toml:"database"`
Servers map[string]server
Clients clients
}
type ownerInfo struct {
Name string
Org string `toml:"organization"`
Bio string
DOB time.Time
}
type database struct {
Server string
Ports []int
ConnMax int `toml:"connection_max"`
Enabled bool
}
type server struct {
IP string
DC string
}
type clients struct {
Data [][]interface{}
Hosts []string
}
```
Note that a case insensitive match will be tried if an exact match can't be
found.
A working example of the above can be found in `_examples/example.{go,toml}`.

View file

@ -0,0 +1,492 @@
package toml
import (
"fmt"
"io"
"io/ioutil"
"math"
"reflect"
"strings"
"time"
)
var e = fmt.Errorf
// Unmarshaler is the interface implemented by objects that can unmarshal a
// TOML description of themselves.
type Unmarshaler interface {
UnmarshalTOML(interface{}) error
}
// Unmarshal decodes the contents of `p` in TOML format into a pointer `v`.
func Unmarshal(p []byte, v interface{}) error {
_, err := Decode(string(p), v)
return err
}
// Primitive is a TOML value that hasn't been decoded into a Go value.
// When using the various `Decode*` functions, the type `Primitive` may
// be given to any value, and its decoding will be delayed.
//
// A `Primitive` value can be decoded using the `PrimitiveDecode` function.
//
// The underlying representation of a `Primitive` value is subject to change.
// Do not rely on it.
//
// N.B. Primitive values are still parsed, so using them will only avoid
// the overhead of reflection. They can be useful when you don't know the
// exact type of TOML data until run time.
type Primitive struct {
undecoded interface{}
context Key
}
// DEPRECATED!
//
// Use MetaData.PrimitiveDecode instead.
func PrimitiveDecode(primValue Primitive, v interface{}) error {
md := MetaData{decoded: make(map[string]bool)}
return md.unify(primValue.undecoded, rvalue(v))
}
// PrimitiveDecode is just like the other `Decode*` functions, except it
// decodes a TOML value that has already been parsed. Valid primitive values
// can *only* be obtained from values filled by the decoder functions,
// including this method. (i.e., `v` may contain more `Primitive`
// values.)
//
// Meta data for primitive values is included in the meta data returned by
// the `Decode*` functions with one exception: keys returned by the Undecoded
// method will only reflect keys that were decoded. Namely, any keys hidden
// behind a Primitive will be considered undecoded. Executing this method will
// update the undecoded keys in the meta data. (See the example.)
func (md *MetaData) PrimitiveDecode(primValue Primitive, v interface{}) error {
md.context = primValue.context
defer func() { md.context = nil }()
return md.unify(primValue.undecoded, rvalue(v))
}
// Decode will decode the contents of `data` in TOML format into a pointer
// `v`.
//
// TOML hashes correspond to Go structs or maps. (Dealer's choice. They can be
// used interchangeably.)
//
// TOML arrays of tables correspond to either a slice of structs or a slice
// of maps.
//
// TOML datetimes correspond to Go `time.Time` values.
//
// All other TOML types (float, string, int, bool and array) correspond
// to the obvious Go types.
//
// An exception to the above rules is if a type implements the
// encoding.TextUnmarshaler interface. In this case, any primitive TOML value
// (floats, strings, integers, booleans and datetimes) will be converted to
// a byte string and given to the value's UnmarshalText method. See the
// Unmarshaler example for a demonstration with time duration strings.
//
// Key mapping
//
// TOML keys can map to either keys in a Go map or field names in a Go
// struct. The special `toml` struct tag may be used to map TOML keys to
// struct fields that don't match the key name exactly. (See the example.)
// A case insensitive match to struct names will be tried if an exact match
// can't be found.
//
// The mapping between TOML values and Go values is loose. That is, there
// may exist TOML values that cannot be placed into your representation, and
// there may be parts of your representation that do not correspond to
// TOML values. This loose mapping can be made stricter by using the IsDefined
// and/or Undecoded methods on the MetaData returned.
//
// This decoder will not handle cyclic types. If a cyclic type is passed,
// `Decode` will not terminate.
func Decode(data string, v interface{}) (MetaData, error) {
p, err := parse(data)
if err != nil {
return MetaData{}, err
}
md := MetaData{
p.mapping, p.types, p.ordered,
make(map[string]bool, len(p.ordered)), nil,
}
return md, md.unify(p.mapping, rvalue(v))
}
// DecodeFile is just like Decode, except it will automatically read the
// contents of the file at `fpath` and decode it for you.
func DecodeFile(fpath string, v interface{}) (MetaData, error) {
bs, err := ioutil.ReadFile(fpath)
if err != nil {
return MetaData{}, err
}
return Decode(string(bs), v)
}
// DecodeReader is just like Decode, except it will consume all bytes
// from the reader and decode it for you.
func DecodeReader(r io.Reader, v interface{}) (MetaData, error) {
bs, err := ioutil.ReadAll(r)
if err != nil {
return MetaData{}, err
}
return Decode(string(bs), v)
}
// unify performs a sort of type unification based on the structure of `rv`,
// which is the client representation.
//
// Any type mismatch produces an error. Finding a type that we don't know
// how to handle produces an unsupported type error.
func (md *MetaData) unify(data interface{}, rv reflect.Value) error {
// Special case. Look for a `Primitive` value.
if rv.Type() == reflect.TypeOf((*Primitive)(nil)).Elem() {
// Save the undecoded data and the key context into the primitive
// value.
context := make(Key, len(md.context))
copy(context, md.context)
rv.Set(reflect.ValueOf(Primitive{
undecoded: data,
context: context,
}))
return nil
}
// Special case. Unmarshaler Interface support.
if rv.CanAddr() {
if v, ok := rv.Addr().Interface().(Unmarshaler); ok {
return v.UnmarshalTOML(data)
}
}
// Special case. Handle time.Time values specifically.
// TODO: Remove this code when we decide to drop support for Go 1.1.
// This isn't necessary in Go 1.2 because time.Time satisfies the encoding
// interfaces.
if rv.Type().AssignableTo(rvalue(time.Time{}).Type()) {
return md.unifyDatetime(data, rv)
}
// Special case. Look for a value satisfying the TextUnmarshaler interface.
if v, ok := rv.Interface().(TextUnmarshaler); ok {
return md.unifyText(data, v)
}
// BUG(burntsushi)
// The behavior here is incorrect whenever a Go type satisfies the
// encoding.TextUnmarshaler interface but also corresponds to a TOML
// hash or array. In particular, the unmarshaler should only be applied
// to primitive TOML values. But at this point, it will be applied to
// all kinds of values and produce an incorrect error whenever those values
// are hashes or arrays (including arrays of tables).
k := rv.Kind()
// laziness
if k >= reflect.Int && k <= reflect.Uint64 {
return md.unifyInt(data, rv)
}
switch k {
case reflect.Ptr:
elem := reflect.New(rv.Type().Elem())
err := md.unify(data, reflect.Indirect(elem))
if err != nil {
return err
}
rv.Set(elem)
return nil
case reflect.Struct:
return md.unifyStruct(data, rv)
case reflect.Map:
return md.unifyMap(data, rv)
case reflect.Array:
return md.unifyArray(data, rv)
case reflect.Slice:
return md.unifySlice(data, rv)
case reflect.String:
return md.unifyString(data, rv)
case reflect.Bool:
return md.unifyBool(data, rv)
case reflect.Interface:
// we only support empty interfaces.
if rv.NumMethod() > 0 {
return e("Unsupported type '%s'.", rv.Kind())
}
return md.unifyAnything(data, rv)
case reflect.Float32:
fallthrough
case reflect.Float64:
return md.unifyFloat64(data, rv)
}
return e("Unsupported type '%s'.", rv.Kind())
}
func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error {
tmap, ok := mapping.(map[string]interface{})
if !ok {
return mismatch(rv, "map", mapping)
}
for key, datum := range tmap {
var f *field
fields := cachedTypeFields(rv.Type())
for i := range fields {
ff := &fields[i]
if ff.name == key {
f = ff
break
}
if f == nil && strings.EqualFold(ff.name, key) {
f = ff
}
}
if f != nil {
subv := rv
for _, i := range f.index {
subv = indirect(subv.Field(i))
}
if isUnifiable(subv) {
md.decoded[md.context.add(key).String()] = true
md.context = append(md.context, key)
if err := md.unify(datum, subv); err != nil {
return e("Type mismatch for '%s.%s': %s",
rv.Type().String(), f.name, err)
}
md.context = md.context[0 : len(md.context)-1]
} else if f.name != "" {
// Bad user! No soup for you!
return e("Field '%s.%s' is unexported, and therefore cannot "+
"be loaded with reflection.", rv.Type().String(), f.name)
}
}
}
return nil
}
func (md *MetaData) unifyMap(mapping interface{}, rv reflect.Value) error {
tmap, ok := mapping.(map[string]interface{})
if !ok {
return badtype("map", mapping)
}
if rv.IsNil() {
rv.Set(reflect.MakeMap(rv.Type()))
}
for k, v := range tmap {
md.decoded[md.context.add(k).String()] = true
md.context = append(md.context, k)
rvkey := indirect(reflect.New(rv.Type().Key()))
rvval := reflect.Indirect(reflect.New(rv.Type().Elem()))
if err := md.unify(v, rvval); err != nil {
return err
}
md.context = md.context[0 : len(md.context)-1]
rvkey.SetString(k)
rv.SetMapIndex(rvkey, rvval)
}
return nil
}
func (md *MetaData) unifyArray(data interface{}, rv reflect.Value) error {
datav := reflect.ValueOf(data)
if datav.Kind() != reflect.Slice {
return badtype("slice", data)
}
sliceLen := datav.Len()
if sliceLen != rv.Len() {
return e("expected array length %d; got TOML array of length %d",
rv.Len(), sliceLen)
}
return md.unifySliceArray(datav, rv)
}
func (md *MetaData) unifySlice(data interface{}, rv reflect.Value) error {
datav := reflect.ValueOf(data)
if datav.Kind() != reflect.Slice {
return badtype("slice", data)
}
sliceLen := datav.Len()
if rv.IsNil() {
rv.Set(reflect.MakeSlice(rv.Type(), sliceLen, sliceLen))
}
return md.unifySliceArray(datav, rv)
}
func (md *MetaData) unifySliceArray(data, rv reflect.Value) error {
sliceLen := data.Len()
for i := 0; i < sliceLen; i++ {
v := data.Index(i).Interface()
sliceval := indirect(rv.Index(i))
if err := md.unify(v, sliceval); err != nil {
return err
}
}
return nil
}
func (md *MetaData) unifyDatetime(data interface{}, rv reflect.Value) error {
if _, ok := data.(time.Time); ok {
rv.Set(reflect.ValueOf(data))
return nil
}
return badtype("time.Time", data)
}
func (md *MetaData) unifyString(data interface{}, rv reflect.Value) error {
if s, ok := data.(string); ok {
rv.SetString(s)
return nil
}
return badtype("string", data)
}
func (md *MetaData) unifyFloat64(data interface{}, rv reflect.Value) error {
if num, ok := data.(float64); ok {
switch rv.Kind() {
case reflect.Float32:
fallthrough
case reflect.Float64:
rv.SetFloat(num)
default:
panic("bug")
}
return nil
}
return badtype("float", data)
}
func (md *MetaData) unifyInt(data interface{}, rv reflect.Value) error {
if num, ok := data.(int64); ok {
if rv.Kind() >= reflect.Int && rv.Kind() <= reflect.Int64 {
switch rv.Kind() {
case reflect.Int, reflect.Int64:
// No bounds checking necessary.
case reflect.Int8:
if num < math.MinInt8 || num > math.MaxInt8 {
return e("Value '%d' is out of range for int8.", num)
}
case reflect.Int16:
if num < math.MinInt16 || num > math.MaxInt16 {
return e("Value '%d' is out of range for int16.", num)
}
case reflect.Int32:
if num < math.MinInt32 || num > math.MaxInt32 {
return e("Value '%d' is out of range for int32.", num)
}
}
rv.SetInt(num)
} else if rv.Kind() >= reflect.Uint && rv.Kind() <= reflect.Uint64 {
unum := uint64(num)
switch rv.Kind() {
case reflect.Uint, reflect.Uint64:
// No bounds checking necessary.
case reflect.Uint8:
if num < 0 || unum > math.MaxUint8 {
return e("Value '%d' is out of range for uint8.", num)
}
case reflect.Uint16:
if num < 0 || unum > math.MaxUint16 {
return e("Value '%d' is out of range for uint16.", num)
}
case reflect.Uint32:
if num < 0 || unum > math.MaxUint32 {
return e("Value '%d' is out of range for uint32.", num)
}
}
rv.SetUint(unum)
} else {
panic("unreachable")
}
return nil
}
return badtype("integer", data)
}
func (md *MetaData) unifyBool(data interface{}, rv reflect.Value) error {
if b, ok := data.(bool); ok {
rv.SetBool(b)
return nil
}
return badtype("boolean", data)
}
func (md *MetaData) unifyAnything(data interface{}, rv reflect.Value) error {
rv.Set(reflect.ValueOf(data))
return nil
}
func (md *MetaData) unifyText(data interface{}, v TextUnmarshaler) error {
var s string
switch sdata := data.(type) {
case TextMarshaler:
text, err := sdata.MarshalText()
if err != nil {
return err
}
s = string(text)
case fmt.Stringer:
s = sdata.String()
case string:
s = sdata
case bool:
s = fmt.Sprintf("%v", sdata)
case int64:
s = fmt.Sprintf("%d", sdata)
case float64:
s = fmt.Sprintf("%f", sdata)
default:
return badtype("primitive (string-like)", data)
}
if err := v.UnmarshalText([]byte(s)); err != nil {
return err
}
return nil
}
// rvalue returns a reflect.Value of `v`. All pointers are resolved.
func rvalue(v interface{}) reflect.Value {
return indirect(reflect.ValueOf(v))
}
// indirect returns the value pointed to by a pointer.
// Pointers are followed until the value is not a pointer.
// New values are allocated for each nil pointer.
//
// An exception to this rule is if the value satisfies an interface of
// interest to us (like encoding.TextUnmarshaler).
func indirect(v reflect.Value) reflect.Value {
if v.Kind() != reflect.Ptr {
if v.CanAddr() {
pv := v.Addr()
if _, ok := pv.Interface().(TextUnmarshaler); ok {
return pv
}
}
return v
}
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
return indirect(reflect.Indirect(v))
}
func isUnifiable(rv reflect.Value) bool {
if rv.CanSet() {
return true
}
if _, ok := rv.Interface().(TextUnmarshaler); ok {
return true
}
return false
}
func badtype(expected string, data interface{}) error {
return e("Expected %s but found '%T'.", expected, data)
}
func mismatch(user reflect.Value, expected string, data interface{}) error {
return e("Type mismatch for %s. Expected %s but found '%T'.",
user.Type().String(), expected, data)
}

View file

@ -0,0 +1,122 @@
package toml
import "strings"
// MetaData allows access to meta information about TOML data that may not
// be inferrable via reflection. In particular, whether a key has been defined
// and the TOML type of a key.
type MetaData struct {
mapping map[string]interface{}
types map[string]tomlType
keys []Key
decoded map[string]bool
context Key // Used only during decoding.
}
// IsDefined returns true if the key given exists in the TOML data. The key
// should be specified hierarchially. e.g.,
//
// // access the TOML key 'a.b.c'
// IsDefined("a", "b", "c")
//
// IsDefined will return false if an empty key given. Keys are case sensitive.
func (md *MetaData) IsDefined(key ...string) bool {
if len(key) == 0 {
return false
}
var hash map[string]interface{}
var ok bool
var hashOrVal interface{} = md.mapping
for _, k := range key {
if hash, ok = hashOrVal.(map[string]interface{}); !ok {
return false
}
if hashOrVal, ok = hash[k]; !ok {
return false
}
}
return true
}
// Type returns a string representation of the type of the key specified.
//
// Type will return the empty string if given an empty key or a key that
// does not exist. Keys are case sensitive.
func (md *MetaData) Type(key ...string) string {
fullkey := strings.Join(key, ".")
if typ, ok := md.types[fullkey]; ok {
return typ.typeString()
}
return ""
}
// Key is the type of any TOML key, including key groups. Use (MetaData).Keys
// to get values of this type.
type Key []string
func (k Key) String() string {
return strings.Join(k, ".")
}
func (k Key) maybeQuotedAll() string {
var ss []string
for i := range k {
ss = append(ss, k.maybeQuoted(i))
}
return strings.Join(ss, ".")
}
func (k Key) maybeQuoted(i int) string {
quote := false
for _, c := range k[i] {
if !isBareKeyChar(c) {
quote = true
break
}
}
if quote {
return "\"" + strings.Replace(k[i], "\"", "\\\"", -1) + "\""
} else {
return k[i]
}
}
func (k Key) add(piece string) Key {
newKey := make(Key, len(k)+1)
copy(newKey, k)
newKey[len(k)] = piece
return newKey
}
// Keys returns a slice of every key in the TOML data, including key groups.
// Each key is itself a slice, where the first element is the top of the
// hierarchy and the last is the most specific.
//
// The list will have the same order as the keys appeared in the TOML data.
//
// All keys returned are non-empty.
func (md *MetaData) Keys() []Key {
return md.keys
}
// Undecoded returns all keys that have not been decoded in the order in which
// they appear in the original TOML document.
//
// This includes keys that haven't been decoded because of a Primitive value.
// Once the Primitive value is decoded, the keys will be considered decoded.
//
// Also note that decoding into an empty interface will result in no decoding,
// and so no keys will be considered decoded.
//
// In this sense, the Undecoded keys correspond to keys in the TOML document
// that do not have a concrete type in your representation.
func (md *MetaData) Undecoded() []Key {
undecoded := make([]Key, 0, len(md.keys))
for _, key := range md.keys {
if !md.decoded[key.String()] {
undecoded = append(undecoded, key)
}
}
return undecoded
}

View file

@ -0,0 +1,27 @@
/*
Package toml provides facilities for decoding and encoding TOML configuration
files via reflection. There is also support for delaying decoding with
the Primitive type, and querying the set of keys in a TOML document with the
MetaData type.
The specification implemented: https://github.com/mojombo/toml
The sub-command github.com/BurntSushi/toml/cmd/tomlv can be used to verify
whether a file is a valid TOML document. It can also be used to print the
type of each key in a TOML document.
Testing
There are two important types of tests used for this package. The first is
contained inside '*_test.go' files and uses the standard Go unit testing
framework. These tests are primarily devoted to holistically testing the
decoder and encoder.
The second type of testing is used to verify the implementation's adherence
to the TOML specification. These tests have been factored into their own
project: https://github.com/BurntSushi/toml-test
The reason the tests are in a separate project is so that they can be used by
any implementation of TOML. Namely, it is language agnostic.
*/
package toml

View file

@ -0,0 +1,496 @@
package toml
import (
"bufio"
"errors"
"fmt"
"io"
"reflect"
"sort"
"strconv"
"strings"
"time"
)
type tomlEncodeError struct{ error }
var (
errArrayMixedElementTypes = errors.New(
"can't encode array with mixed element types")
errArrayNilElement = errors.New(
"can't encode array with nil element")
errNonString = errors.New(
"can't encode a map with non-string key type")
errAnonNonStruct = errors.New(
"can't encode an anonymous field that is not a struct")
errArrayNoTable = errors.New(
"TOML array element can't contain a table")
errNoKey = errors.New(
"top-level values must be a Go map or struct")
errAnything = errors.New("") // used in testing
)
var quotedReplacer = strings.NewReplacer(
"\t", "\\t",
"\n", "\\n",
"\r", "\\r",
"\"", "\\\"",
"\\", "\\\\",
)
// Encoder controls the encoding of Go values to a TOML document to some
// io.Writer.
//
// The indentation level can be controlled with the Indent field.
type Encoder struct {
// A single indentation level. By default it is two spaces.
Indent string
// hasWritten is whether we have written any output to w yet.
hasWritten bool
w *bufio.Writer
}
// NewEncoder returns a TOML encoder that encodes Go values to the io.Writer
// given. By default, a single indentation level is 2 spaces.
func NewEncoder(w io.Writer) *Encoder {
return &Encoder{
w: bufio.NewWriter(w),
Indent: " ",
}
}
// Encode writes a TOML representation of the Go value to the underlying
// io.Writer. If the value given cannot be encoded to a valid TOML document,
// then an error is returned.
//
// The mapping between Go values and TOML values should be precisely the same
// as for the Decode* functions. Similarly, the TextMarshaler interface is
// supported by encoding the resulting bytes as strings. (If you want to write
// arbitrary binary data then you will need to use something like base64 since
// TOML does not have any binary types.)
//
// When encoding TOML hashes (i.e., Go maps or structs), keys without any
// sub-hashes are encoded first.
//
// If a Go map is encoded, then its keys are sorted alphabetically for
// deterministic output. More control over this behavior may be provided if
// there is demand for it.
//
// Encoding Go values without a corresponding TOML representation---like map
// types with non-string keys---will cause an error to be returned. Similarly
// for mixed arrays/slices, arrays/slices with nil elements, embedded
// non-struct types and nested slices containing maps or structs.
// (e.g., [][]map[string]string is not allowed but []map[string]string is OK
// and so is []map[string][]string.)
func (enc *Encoder) Encode(v interface{}) error {
rv := eindirect(reflect.ValueOf(v))
if err := enc.safeEncode(Key([]string{}), rv); err != nil {
return err
}
return enc.w.Flush()
}
func (enc *Encoder) safeEncode(key Key, rv reflect.Value) (err error) {
defer func() {
if r := recover(); r != nil {
if terr, ok := r.(tomlEncodeError); ok {
err = terr.error
return
}
panic(r)
}
}()
enc.encode(key, rv)
return nil
}
func (enc *Encoder) encode(key Key, rv reflect.Value) {
// Special case. Time needs to be in ISO8601 format.
// Special case. If we can marshal the type to text, then we used that.
// Basically, this prevents the encoder for handling these types as
// generic structs (or whatever the underlying type of a TextMarshaler is).
switch rv.Interface().(type) {
case time.Time, TextMarshaler:
enc.keyEqElement(key, rv)
return
}
k := rv.Kind()
switch k {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
reflect.Uint64,
reflect.Float32, reflect.Float64, reflect.String, reflect.Bool:
enc.keyEqElement(key, rv)
case reflect.Array, reflect.Slice:
if typeEqual(tomlArrayHash, tomlTypeOfGo(rv)) {
enc.eArrayOfTables(key, rv)
} else {
enc.keyEqElement(key, rv)
}
case reflect.Interface:
if rv.IsNil() {
return
}
enc.encode(key, rv.Elem())
case reflect.Map:
if rv.IsNil() {
return
}
enc.eTable(key, rv)
case reflect.Ptr:
if rv.IsNil() {
return
}
enc.encode(key, rv.Elem())
case reflect.Struct:
enc.eTable(key, rv)
default:
panic(e("Unsupported type for key '%s': %s", key, k))
}
}
// eElement encodes any value that can be an array element (primitives and
// arrays).
func (enc *Encoder) eElement(rv reflect.Value) {
switch v := rv.Interface().(type) {
case time.Time:
// Special case time.Time as a primitive. Has to come before
// TextMarshaler below because time.Time implements
// encoding.TextMarshaler, but we need to always use UTC.
enc.wf(v.In(time.FixedZone("UTC", 0)).Format("2006-01-02T15:04:05Z"))
return
case TextMarshaler:
// Special case. Use text marshaler if it's available for this value.
if s, err := v.MarshalText(); err != nil {
encPanic(err)
} else {
enc.writeQuoted(string(s))
}
return
}
switch rv.Kind() {
case reflect.Bool:
enc.wf(strconv.FormatBool(rv.Bool()))
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
reflect.Int64:
enc.wf(strconv.FormatInt(rv.Int(), 10))
case reflect.Uint, reflect.Uint8, reflect.Uint16,
reflect.Uint32, reflect.Uint64:
enc.wf(strconv.FormatUint(rv.Uint(), 10))
case reflect.Float32:
enc.wf(floatAddDecimal(strconv.FormatFloat(rv.Float(), 'f', -1, 32)))
case reflect.Float64:
enc.wf(floatAddDecimal(strconv.FormatFloat(rv.Float(), 'f', -1, 64)))
case reflect.Array, reflect.Slice:
enc.eArrayOrSliceElement(rv)
case reflect.Interface:
enc.eElement(rv.Elem())
case reflect.String:
enc.writeQuoted(rv.String())
default:
panic(e("Unexpected primitive type: %s", rv.Kind()))
}
}
// By the TOML spec, all floats must have a decimal with at least one
// number on either side.
func floatAddDecimal(fstr string) string {
if !strings.Contains(fstr, ".") {
return fstr + ".0"
}
return fstr
}
func (enc *Encoder) writeQuoted(s string) {
enc.wf("\"%s\"", quotedReplacer.Replace(s))
}
func (enc *Encoder) eArrayOrSliceElement(rv reflect.Value) {
length := rv.Len()
enc.wf("[")
for i := 0; i < length; i++ {
elem := rv.Index(i)
enc.eElement(elem)
if i != length-1 {
enc.wf(", ")
}
}
enc.wf("]")
}
func (enc *Encoder) eArrayOfTables(key Key, rv reflect.Value) {
if len(key) == 0 {
encPanic(errNoKey)
}
for i := 0; i < rv.Len(); i++ {
trv := rv.Index(i)
if isNil(trv) {
continue
}
panicIfInvalidKey(key)
enc.newline()
enc.wf("%s[[%s]]", enc.indentStr(key), key.maybeQuotedAll())
enc.newline()
enc.eMapOrStruct(key, trv)
}
}
func (enc *Encoder) eTable(key Key, rv reflect.Value) {
panicIfInvalidKey(key)
if len(key) == 1 {
// Output an extra new line between top-level tables.
// (The newline isn't written if nothing else has been written though.)
enc.newline()
}
if len(key) > 0 {
enc.wf("%s[%s]", enc.indentStr(key), key.maybeQuotedAll())
enc.newline()
}
enc.eMapOrStruct(key, rv)
}
func (enc *Encoder) eMapOrStruct(key Key, rv reflect.Value) {
switch rv := eindirect(rv); rv.Kind() {
case reflect.Map:
enc.eMap(key, rv)
case reflect.Struct:
enc.eStruct(key, rv)
default:
panic("eTable: unhandled reflect.Value Kind: " + rv.Kind().String())
}
}
func (enc *Encoder) eMap(key Key, rv reflect.Value) {
rt := rv.Type()
if rt.Key().Kind() != reflect.String {
encPanic(errNonString)
}
// Sort keys so that we have deterministic output. And write keys directly
// underneath this key first, before writing sub-structs or sub-maps.
var mapKeysDirect, mapKeysSub []string
for _, mapKey := range rv.MapKeys() {
k := mapKey.String()
if typeIsHash(tomlTypeOfGo(rv.MapIndex(mapKey))) {
mapKeysSub = append(mapKeysSub, k)
} else {
mapKeysDirect = append(mapKeysDirect, k)
}
}
var writeMapKeys = func(mapKeys []string) {
sort.Strings(mapKeys)
for _, mapKey := range mapKeys {
mrv := rv.MapIndex(reflect.ValueOf(mapKey))
if isNil(mrv) {
// Don't write anything for nil fields.
continue
}
enc.encode(key.add(mapKey), mrv)
}
}
writeMapKeys(mapKeysDirect)
writeMapKeys(mapKeysSub)
}
func (enc *Encoder) eStruct(key Key, rv reflect.Value) {
// Write keys for fields directly under this key first, because if we write
// a field that creates a new table, then all keys under it will be in that
// table (not the one we're writing here).
rt := rv.Type()
var fieldsDirect, fieldsSub [][]int
var addFields func(rt reflect.Type, rv reflect.Value, start []int)
addFields = func(rt reflect.Type, rv reflect.Value, start []int) {
for i := 0; i < rt.NumField(); i++ {
f := rt.Field(i)
// skip unexporded fields
if f.PkgPath != "" {
continue
}
frv := rv.Field(i)
if f.Anonymous {
frv := eindirect(frv)
t := frv.Type()
if t.Kind() != reflect.Struct {
encPanic(errAnonNonStruct)
}
addFields(t, frv, f.Index)
} else if typeIsHash(tomlTypeOfGo(frv)) {
fieldsSub = append(fieldsSub, append(start, f.Index...))
} else {
fieldsDirect = append(fieldsDirect, append(start, f.Index...))
}
}
}
addFields(rt, rv, nil)
var writeFields = func(fields [][]int) {
for _, fieldIndex := range fields {
sft := rt.FieldByIndex(fieldIndex)
sf := rv.FieldByIndex(fieldIndex)
if isNil(sf) {
// Don't write anything for nil fields.
continue
}
keyName := sft.Tag.Get("toml")
if keyName == "-" {
continue
}
if keyName == "" {
keyName = sft.Name
}
enc.encode(key.add(keyName), sf)
}
}
writeFields(fieldsDirect)
writeFields(fieldsSub)
}
// tomlTypeName returns the TOML type name of the Go value's type. It is
// used to determine whether the types of array elements are mixed (which is
// forbidden). If the Go value is nil, then it is illegal for it to be an array
// element, and valueIsNil is returned as true.
// Returns the TOML type of a Go value. The type may be `nil`, which means
// no concrete TOML type could be found.
func tomlTypeOfGo(rv reflect.Value) tomlType {
if isNil(rv) || !rv.IsValid() {
return nil
}
switch rv.Kind() {
case reflect.Bool:
return tomlBool
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
reflect.Uint64:
return tomlInteger
case reflect.Float32, reflect.Float64:
return tomlFloat
case reflect.Array, reflect.Slice:
if typeEqual(tomlHash, tomlArrayType(rv)) {
return tomlArrayHash
} else {
return tomlArray
}
case reflect.Ptr, reflect.Interface:
return tomlTypeOfGo(rv.Elem())
case reflect.String:
return tomlString
case reflect.Map:
return tomlHash
case reflect.Struct:
switch rv.Interface().(type) {
case time.Time:
return tomlDatetime
case TextMarshaler:
return tomlString
default:
return tomlHash
}
default:
panic("unexpected reflect.Kind: " + rv.Kind().String())
}
}
// tomlArrayType returns the element type of a TOML array. The type returned
// may be nil if it cannot be determined (e.g., a nil slice or a zero length
// slize). This function may also panic if it finds a type that cannot be
// expressed in TOML (such as nil elements, heterogeneous arrays or directly
// nested arrays of tables).
func tomlArrayType(rv reflect.Value) tomlType {
if isNil(rv) || !rv.IsValid() || rv.Len() == 0 {
return nil
}
firstType := tomlTypeOfGo(rv.Index(0))
if firstType == nil {
encPanic(errArrayNilElement)
}
rvlen := rv.Len()
for i := 1; i < rvlen; i++ {
elem := rv.Index(i)
switch elemType := tomlTypeOfGo(elem); {
case elemType == nil:
encPanic(errArrayNilElement)
case !typeEqual(firstType, elemType):
encPanic(errArrayMixedElementTypes)
}
}
// If we have a nested array, then we must make sure that the nested
// array contains ONLY primitives.
// This checks arbitrarily nested arrays.
if typeEqual(firstType, tomlArray) || typeEqual(firstType, tomlArrayHash) {
nest := tomlArrayType(eindirect(rv.Index(0)))
if typeEqual(nest, tomlHash) || typeEqual(nest, tomlArrayHash) {
encPanic(errArrayNoTable)
}
}
return firstType
}
func (enc *Encoder) newline() {
if enc.hasWritten {
enc.wf("\n")
}
}
func (enc *Encoder) keyEqElement(key Key, val reflect.Value) {
if len(key) == 0 {
encPanic(errNoKey)
}
panicIfInvalidKey(key)
enc.wf("%s%s = ", enc.indentStr(key), key.maybeQuoted(len(key)-1))
enc.eElement(val)
enc.newline()
}
func (enc *Encoder) wf(format string, v ...interface{}) {
if _, err := fmt.Fprintf(enc.w, format, v...); err != nil {
encPanic(err)
}
enc.hasWritten = true
}
func (enc *Encoder) indentStr(key Key) string {
return strings.Repeat(enc.Indent, len(key)-1)
}
func encPanic(err error) {
panic(tomlEncodeError{err})
}
func eindirect(v reflect.Value) reflect.Value {
switch v.Kind() {
case reflect.Ptr, reflect.Interface:
return eindirect(v.Elem())
default:
return v
}
}
func isNil(rv reflect.Value) bool {
switch rv.Kind() {
case reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
return rv.IsNil()
default:
return false
}
}
func panicIfInvalidKey(key Key) {
for _, k := range key {
if len(k) == 0 {
encPanic(e("Key '%s' is not a valid table name. Key names "+
"cannot be empty.", key.maybeQuotedAll()))
}
}
}
func isValidKeyName(s string) bool {
return len(s) != 0
}

View file

@ -0,0 +1,19 @@
// +build go1.2
package toml
// In order to support Go 1.1, we define our own TextMarshaler and
// TextUnmarshaler types. For Go 1.2+, we just alias them with the
// standard library interfaces.
import (
"encoding"
)
// TextMarshaler is a synonym for encoding.TextMarshaler. It is defined here
// so that Go 1.1 can be supported.
type TextMarshaler encoding.TextMarshaler
// TextUnmarshaler is a synonym for encoding.TextUnmarshaler. It is defined
// here so that Go 1.1 can be supported.
type TextUnmarshaler encoding.TextUnmarshaler

View file

@ -0,0 +1,18 @@
// +build !go1.2
package toml
// These interfaces were introduced in Go 1.2, so we add them manually when
// compiling for Go 1.1.
// TextMarshaler is a synonym for encoding.TextMarshaler. It is defined here
// so that Go 1.1 can be supported.
type TextMarshaler interface {
MarshalText() (text []byte, err error)
}
// TextUnmarshaler is a synonym for encoding.TextUnmarshaler. It is defined
// here so that Go 1.1 can be supported.
type TextUnmarshaler interface {
UnmarshalText(text []byte) error
}

View file

@ -0,0 +1,874 @@
package toml
import (
"fmt"
"strings"
"unicode/utf8"
)
type itemType int
const (
itemError itemType = iota
itemNIL // used in the parser to indicate no type
itemEOF
itemText
itemString
itemRawString
itemMultilineString
itemRawMultilineString
itemBool
itemInteger
itemFloat
itemDatetime
itemArray // the start of an array
itemArrayEnd
itemTableStart
itemTableEnd
itemArrayTableStart
itemArrayTableEnd
itemKeyStart
itemCommentStart
)
const (
eof = 0
tableStart = '['
tableEnd = ']'
arrayTableStart = '['
arrayTableEnd = ']'
tableSep = '.'
keySep = '='
arrayStart = '['
arrayEnd = ']'
arrayValTerm = ','
commentStart = '#'
stringStart = '"'
stringEnd = '"'
rawStringStart = '\''
rawStringEnd = '\''
)
type stateFn func(lx *lexer) stateFn
type lexer struct {
input string
start int
pos int
width int
line int
state stateFn
items chan item
// A stack of state functions used to maintain context.
// The idea is to reuse parts of the state machine in various places.
// For example, values can appear at the top level or within arbitrarily
// nested arrays. The last state on the stack is used after a value has
// been lexed. Similarly for comments.
stack []stateFn
}
type item struct {
typ itemType
val string
line int
}
func (lx *lexer) nextItem() item {
for {
select {
case item := <-lx.items:
return item
default:
lx.state = lx.state(lx)
}
}
}
func lex(input string) *lexer {
lx := &lexer{
input: input + "\n",
state: lexTop,
line: 1,
items: make(chan item, 10),
stack: make([]stateFn, 0, 10),
}
return lx
}
func (lx *lexer) push(state stateFn) {
lx.stack = append(lx.stack, state)
}
func (lx *lexer) pop() stateFn {
if len(lx.stack) == 0 {
return lx.errorf("BUG in lexer: no states to pop.")
}
last := lx.stack[len(lx.stack)-1]
lx.stack = lx.stack[0 : len(lx.stack)-1]
return last
}
func (lx *lexer) current() string {
return lx.input[lx.start:lx.pos]
}
func (lx *lexer) emit(typ itemType) {
lx.items <- item{typ, lx.current(), lx.line}
lx.start = lx.pos
}
func (lx *lexer) emitTrim(typ itemType) {
lx.items <- item{typ, strings.TrimSpace(lx.current()), lx.line}
lx.start = lx.pos
}
func (lx *lexer) next() (r rune) {
if lx.pos >= len(lx.input) {
lx.width = 0
return eof
}
if lx.input[lx.pos] == '\n' {
lx.line++
}
r, lx.width = utf8.DecodeRuneInString(lx.input[lx.pos:])
lx.pos += lx.width
return r
}
// ignore skips over the pending input before this point.
func (lx *lexer) ignore() {
lx.start = lx.pos
}
// backup steps back one rune. Can be called only once per call of next.
func (lx *lexer) backup() {
lx.pos -= lx.width
if lx.pos < len(lx.input) && lx.input[lx.pos] == '\n' {
lx.line--
}
}
// accept consumes the next rune if it's equal to `valid`.
func (lx *lexer) accept(valid rune) bool {
if lx.next() == valid {
return true
}
lx.backup()
return false
}
// peek returns but does not consume the next rune in the input.
func (lx *lexer) peek() rune {
r := lx.next()
lx.backup()
return r
}
// errorf stops all lexing by emitting an error and returning `nil`.
// Note that any value that is a character is escaped if it's a special
// character (new lines, tabs, etc.).
func (lx *lexer) errorf(format string, values ...interface{}) stateFn {
lx.items <- item{
itemError,
fmt.Sprintf(format, values...),
lx.line,
}
return nil
}
// lexTop consumes elements at the top level of TOML data.
func lexTop(lx *lexer) stateFn {
r := lx.next()
if isWhitespace(r) || isNL(r) {
return lexSkip(lx, lexTop)
}
switch r {
case commentStart:
lx.push(lexTop)
return lexCommentStart
case tableStart:
return lexTableStart
case eof:
if lx.pos > lx.start {
return lx.errorf("Unexpected EOF.")
}
lx.emit(itemEOF)
return nil
}
// At this point, the only valid item can be a key, so we back up
// and let the key lexer do the rest.
lx.backup()
lx.push(lexTopEnd)
return lexKeyStart
}
// lexTopEnd is entered whenever a top-level item has been consumed. (A value
// or a table.) It must see only whitespace, and will turn back to lexTop
// upon a new line. If it sees EOF, it will quit the lexer successfully.
func lexTopEnd(lx *lexer) stateFn {
r := lx.next()
switch {
case r == commentStart:
// a comment will read to a new line for us.
lx.push(lexTop)
return lexCommentStart
case isWhitespace(r):
return lexTopEnd
case isNL(r):
lx.ignore()
return lexTop
case r == eof:
lx.ignore()
return lexTop
}
return lx.errorf("Expected a top-level item to end with a new line, "+
"comment or EOF, but got %q instead.", r)
}
// lexTable lexes the beginning of a table. Namely, it makes sure that
// it starts with a character other than '.' and ']'.
// It assumes that '[' has already been consumed.
// It also handles the case that this is an item in an array of tables.
// e.g., '[[name]]'.
func lexTableStart(lx *lexer) stateFn {
if lx.peek() == arrayTableStart {
lx.next()
lx.emit(itemArrayTableStart)
lx.push(lexArrayTableEnd)
} else {
lx.emit(itemTableStart)
lx.push(lexTableEnd)
}
return lexTableNameStart
}
func lexTableEnd(lx *lexer) stateFn {
lx.emit(itemTableEnd)
return lexTopEnd
}
func lexArrayTableEnd(lx *lexer) stateFn {
if r := lx.next(); r != arrayTableEnd {
return lx.errorf("Expected end of table array name delimiter %q, "+
"but got %q instead.", arrayTableEnd, r)
}
lx.emit(itemArrayTableEnd)
return lexTopEnd
}
func lexTableNameStart(lx *lexer) stateFn {
switch r := lx.peek(); {
case r == tableEnd || r == eof:
return lx.errorf("Unexpected end of table name. (Table names cannot " +
"be empty.)")
case r == tableSep:
return lx.errorf("Unexpected table separator. (Table names cannot " +
"be empty.)")
case r == stringStart || r == rawStringStart:
lx.ignore()
lx.push(lexTableNameEnd)
return lexValue // reuse string lexing
case isWhitespace(r):
return lexTableNameStart
default:
return lexBareTableName
}
}
// lexTableName lexes the name of a table. It assumes that at least one
// valid character for the table has already been read.
func lexBareTableName(lx *lexer) stateFn {
switch r := lx.next(); {
case isBareKeyChar(r):
return lexBareTableName
case r == tableSep || r == tableEnd:
lx.backup()
lx.emitTrim(itemText)
return lexTableNameEnd
default:
return lx.errorf("Bare keys cannot contain %q.", r)
}
}
// lexTableNameEnd reads the end of a piece of a table name, optionally
// consuming whitespace.
func lexTableNameEnd(lx *lexer) stateFn {
switch r := lx.next(); {
case isWhitespace(r):
return lexTableNameEnd
case r == tableSep:
lx.ignore()
return lexTableNameStart
case r == tableEnd:
return lx.pop()
default:
return lx.errorf("Expected '.' or ']' to end table name, but got %q "+
"instead.", r)
}
}
// lexKeyStart consumes a key name up until the first non-whitespace character.
// lexKeyStart will ignore whitespace.
func lexKeyStart(lx *lexer) stateFn {
r := lx.peek()
switch {
case r == keySep:
return lx.errorf("Unexpected key separator %q.", keySep)
case isWhitespace(r) || isNL(r):
lx.next()
return lexSkip(lx, lexKeyStart)
case r == stringStart || r == rawStringStart:
lx.ignore()
lx.emit(itemKeyStart)
lx.push(lexKeyEnd)
return lexValue // reuse string lexing
default:
lx.ignore()
lx.emit(itemKeyStart)
return lexBareKey
}
}
// lexBareKey consumes the text of a bare key. Assumes that the first character
// (which is not whitespace) has not yet been consumed.
func lexBareKey(lx *lexer) stateFn {
switch r := lx.next(); {
case isBareKeyChar(r):
return lexBareKey
case isWhitespace(r):
lx.emitTrim(itemText)
return lexKeyEnd
case r == keySep:
lx.backup()
lx.emitTrim(itemText)
return lexKeyEnd
default:
return lx.errorf("Bare keys cannot contain %q.", r)
}
}
// lexKeyEnd consumes the end of a key and trims whitespace (up to the key
// separator).
func lexKeyEnd(lx *lexer) stateFn {
switch r := lx.next(); {
case r == keySep:
return lexSkip(lx, lexValue)
case isWhitespace(r):
return lexSkip(lx, lexKeyEnd)
default:
return lx.errorf("Expected key separator %q, but got %q instead.",
keySep, r)
}
}
// lexValue starts the consumption of a value anywhere a value is expected.
// lexValue will ignore whitespace.
// After a value is lexed, the last state on the next is popped and returned.
func lexValue(lx *lexer) stateFn {
// We allow whitespace to precede a value, but NOT new lines.
// In array syntax, the array states are responsible for ignoring new
// lines.
r := lx.next()
if isWhitespace(r) {
return lexSkip(lx, lexValue)
}
switch {
case r == arrayStart:
lx.ignore()
lx.emit(itemArray)
return lexArrayValue
case r == stringStart:
if lx.accept(stringStart) {
if lx.accept(stringStart) {
lx.ignore() // Ignore """
return lexMultilineString
}
lx.backup()
}
lx.ignore() // ignore the '"'
return lexString
case r == rawStringStart:
if lx.accept(rawStringStart) {
if lx.accept(rawStringStart) {
lx.ignore() // Ignore """
return lexMultilineRawString
}
lx.backup()
}
lx.ignore() // ignore the "'"
return lexRawString
case r == 't':
return lexTrue
case r == 'f':
return lexFalse
case r == '-':
return lexNumberStart
case isDigit(r):
lx.backup() // avoid an extra state and use the same as above
return lexNumberOrDateStart
case r == '.': // special error case, be kind to users
return lx.errorf("Floats must start with a digit, not '.'.")
}
return lx.errorf("Expected value but found %q instead.", r)
}
// lexArrayValue consumes one value in an array. It assumes that '[' or ','
// have already been consumed. All whitespace and new lines are ignored.
func lexArrayValue(lx *lexer) stateFn {
r := lx.next()
switch {
case isWhitespace(r) || isNL(r):
return lexSkip(lx, lexArrayValue)
case r == commentStart:
lx.push(lexArrayValue)
return lexCommentStart
case r == arrayValTerm:
return lx.errorf("Unexpected array value terminator %q.",
arrayValTerm)
case r == arrayEnd:
return lexArrayEnd
}
lx.backup()
lx.push(lexArrayValueEnd)
return lexValue
}
// lexArrayValueEnd consumes the cruft between values of an array. Namely,
// it ignores whitespace and expects either a ',' or a ']'.
func lexArrayValueEnd(lx *lexer) stateFn {
r := lx.next()
switch {
case isWhitespace(r) || isNL(r):
return lexSkip(lx, lexArrayValueEnd)
case r == commentStart:
lx.push(lexArrayValueEnd)
return lexCommentStart
case r == arrayValTerm:
lx.ignore()
return lexArrayValue // move on to the next value
case r == arrayEnd:
return lexArrayEnd
}
return lx.errorf("Expected an array value terminator %q or an array "+
"terminator %q, but got %q instead.", arrayValTerm, arrayEnd, r)
}
// lexArrayEnd finishes the lexing of an array. It assumes that a ']' has
// just been consumed.
func lexArrayEnd(lx *lexer) stateFn {
lx.ignore()
lx.emit(itemArrayEnd)
return lx.pop()
}
// lexString consumes the inner contents of a string. It assumes that the
// beginning '"' has already been consumed and ignored.
func lexString(lx *lexer) stateFn {
r := lx.next()
switch {
case isNL(r):
return lx.errorf("Strings cannot contain new lines.")
case r == '\\':
lx.push(lexString)
return lexStringEscape
case r == stringEnd:
lx.backup()
lx.emit(itemString)
lx.next()
lx.ignore()
return lx.pop()
}
return lexString
}
// lexMultilineString consumes the inner contents of a string. It assumes that
// the beginning '"""' has already been consumed and ignored.
func lexMultilineString(lx *lexer) stateFn {
r := lx.next()
switch {
case r == '\\':
return lexMultilineStringEscape
case r == stringEnd:
if lx.accept(stringEnd) {
if lx.accept(stringEnd) {
lx.backup()
lx.backup()
lx.backup()
lx.emit(itemMultilineString)
lx.next()
lx.next()
lx.next()
lx.ignore()
return lx.pop()
}
lx.backup()
}
}
return lexMultilineString
}
// lexRawString consumes a raw string. Nothing can be escaped in such a string.
// It assumes that the beginning "'" has already been consumed and ignored.
func lexRawString(lx *lexer) stateFn {
r := lx.next()
switch {
case isNL(r):
return lx.errorf("Strings cannot contain new lines.")
case r == rawStringEnd:
lx.backup()
lx.emit(itemRawString)
lx.next()
lx.ignore()
return lx.pop()
}
return lexRawString
}
// lexMultilineRawString consumes a raw string. Nothing can be escaped in such
// a string. It assumes that the beginning "'" has already been consumed and
// ignored.
func lexMultilineRawString(lx *lexer) stateFn {
r := lx.next()
switch {
case r == rawStringEnd:
if lx.accept(rawStringEnd) {
if lx.accept(rawStringEnd) {
lx.backup()
lx.backup()
lx.backup()
lx.emit(itemRawMultilineString)
lx.next()
lx.next()
lx.next()
lx.ignore()
return lx.pop()
}
lx.backup()
}
}
return lexMultilineRawString
}
// lexMultilineStringEscape consumes an escaped character. It assumes that the
// preceding '\\' has already been consumed.
func lexMultilineStringEscape(lx *lexer) stateFn {
// Handle the special case first:
if isNL(lx.next()) {
lx.next()
return lexMultilineString
} else {
lx.backup()
lx.push(lexMultilineString)
return lexStringEscape(lx)
}
}
func lexStringEscape(lx *lexer) stateFn {
r := lx.next()
switch r {
case 'b':
fallthrough
case 't':
fallthrough
case 'n':
fallthrough
case 'f':
fallthrough
case 'r':
fallthrough
case '"':
fallthrough
case '\\':
return lx.pop()
case 'u':
return lexShortUnicodeEscape
case 'U':
return lexLongUnicodeEscape
}
return lx.errorf("Invalid escape character %q. Only the following "+
"escape characters are allowed: "+
"\\b, \\t, \\n, \\f, \\r, \\\", \\/, \\\\, "+
"\\uXXXX and \\UXXXXXXXX.", r)
}
func lexShortUnicodeEscape(lx *lexer) stateFn {
var r rune
for i := 0; i < 4; i++ {
r = lx.next()
if !isHexadecimal(r) {
return lx.errorf("Expected four hexadecimal digits after '\\u', "+
"but got '%s' instead.", lx.current())
}
}
return lx.pop()
}
func lexLongUnicodeEscape(lx *lexer) stateFn {
var r rune
for i := 0; i < 8; i++ {
r = lx.next()
if !isHexadecimal(r) {
return lx.errorf("Expected eight hexadecimal digits after '\\U', "+
"but got '%s' instead.", lx.current())
}
}
return lx.pop()
}
// lexNumberOrDateStart consumes either a (positive) integer, float or
// datetime. It assumes that NO negative sign has been consumed.
func lexNumberOrDateStart(lx *lexer) stateFn {
r := lx.next()
if !isDigit(r) {
if r == '.' {
return lx.errorf("Floats must start with a digit, not '.'.")
} else {
return lx.errorf("Expected a digit but got %q.", r)
}
}
return lexNumberOrDate
}
// lexNumberOrDate consumes either a (positive) integer, float or datetime.
func lexNumberOrDate(lx *lexer) stateFn {
r := lx.next()
switch {
case r == '-':
if lx.pos-lx.start != 5 {
return lx.errorf("All ISO8601 dates must be in full Zulu form.")
}
return lexDateAfterYear
case isDigit(r):
return lexNumberOrDate
case r == '.':
return lexFloatStart
}
lx.backup()
lx.emit(itemInteger)
return lx.pop()
}
// lexDateAfterYear consumes a full Zulu Datetime in ISO8601 format.
// It assumes that "YYYY-" has already been consumed.
func lexDateAfterYear(lx *lexer) stateFn {
formats := []rune{
// digits are '0'.
// everything else is direct equality.
'0', '0', '-', '0', '0',
'T',
'0', '0', ':', '0', '0', ':', '0', '0',
'Z',
}
for _, f := range formats {
r := lx.next()
if f == '0' {
if !isDigit(r) {
return lx.errorf("Expected digit in ISO8601 datetime, "+
"but found %q instead.", r)
}
} else if f != r {
return lx.errorf("Expected %q in ISO8601 datetime, "+
"but found %q instead.", f, r)
}
}
lx.emit(itemDatetime)
return lx.pop()
}
// lexNumberStart consumes either an integer or a float. It assumes that
// a negative sign has already been read, but that *no* digits have been
// consumed. lexNumberStart will move to the appropriate integer or float
// states.
func lexNumberStart(lx *lexer) stateFn {
// we MUST see a digit. Even floats have to start with a digit.
r := lx.next()
if !isDigit(r) {
if r == '.' {
return lx.errorf("Floats must start with a digit, not '.'.")
} else {
return lx.errorf("Expected a digit but got %q.", r)
}
}
return lexNumber
}
// lexNumber consumes an integer or a float after seeing the first digit.
func lexNumber(lx *lexer) stateFn {
r := lx.next()
switch {
case isDigit(r):
return lexNumber
case r == '.':
return lexFloatStart
}
lx.backup()
lx.emit(itemInteger)
return lx.pop()
}
// lexFloatStart starts the consumption of digits of a float after a '.'.
// Namely, at least one digit is required.
func lexFloatStart(lx *lexer) stateFn {
r := lx.next()
if !isDigit(r) {
return lx.errorf("Floats must have a digit after the '.', but got "+
"%q instead.", r)
}
return lexFloat
}
// lexFloat consumes the digits of a float after a '.'.
// Assumes that one digit has been consumed after a '.' already.
func lexFloat(lx *lexer) stateFn {
r := lx.next()
if isDigit(r) {
return lexFloat
}
lx.backup()
lx.emit(itemFloat)
return lx.pop()
}
// lexConst consumes the s[1:] in s. It assumes that s[0] has already been
// consumed.
func lexConst(lx *lexer, s string) stateFn {
for i := range s[1:] {
if r := lx.next(); r != rune(s[i+1]) {
return lx.errorf("Expected %q, but found %q instead.", s[:i+1],
s[:i]+string(r))
}
}
return nil
}
// lexTrue consumes the "rue" in "true". It assumes that 't' has already
// been consumed.
func lexTrue(lx *lexer) stateFn {
if fn := lexConst(lx, "true"); fn != nil {
return fn
}
lx.emit(itemBool)
return lx.pop()
}
// lexFalse consumes the "alse" in "false". It assumes that 'f' has already
// been consumed.
func lexFalse(lx *lexer) stateFn {
if fn := lexConst(lx, "false"); fn != nil {
return fn
}
lx.emit(itemBool)
return lx.pop()
}
// lexCommentStart begins the lexing of a comment. It will emit
// itemCommentStart and consume no characters, passing control to lexComment.
func lexCommentStart(lx *lexer) stateFn {
lx.ignore()
lx.emit(itemCommentStart)
return lexComment
}
// lexComment lexes an entire comment. It assumes that '#' has been consumed.
// It will consume *up to* the first new line character, and pass control
// back to the last state on the stack.
func lexComment(lx *lexer) stateFn {
r := lx.peek()
if isNL(r) || r == eof {
lx.emit(itemText)
return lx.pop()
}
lx.next()
return lexComment
}
// lexSkip ignores all slurped input and moves on to the next state.
func lexSkip(lx *lexer, nextState stateFn) stateFn {
return func(lx *lexer) stateFn {
lx.ignore()
return nextState
}
}
// isWhitespace returns true if `r` is a whitespace character according
// to the spec.
func isWhitespace(r rune) bool {
return r == '\t' || r == ' '
}
func isNL(r rune) bool {
return r == '\n' || r == '\r'
}
func isDigit(r rune) bool {
return r >= '0' && r <= '9'
}
func isHexadecimal(r rune) bool {
return (r >= '0' && r <= '9') ||
(r >= 'a' && r <= 'f') ||
(r >= 'A' && r <= 'F')
}
func isBareKeyChar(r rune) bool {
return (r >= 'A' && r <= 'Z') ||
(r >= 'a' && r <= 'z') ||
(r >= '0' && r <= '9') ||
r == '_' ||
r == '-'
}
func (itype itemType) String() string {
switch itype {
case itemError:
return "Error"
case itemNIL:
return "NIL"
case itemEOF:
return "EOF"
case itemText:
return "Text"
case itemString:
return "String"
case itemRawString:
return "String"
case itemMultilineString:
return "String"
case itemRawMultilineString:
return "String"
case itemBool:
return "Bool"
case itemInteger:
return "Integer"
case itemFloat:
return "Float"
case itemDatetime:
return "DateTime"
case itemTableStart:
return "TableStart"
case itemTableEnd:
return "TableEnd"
case itemKeyStart:
return "KeyStart"
case itemArray:
return "Array"
case itemArrayEnd:
return "ArrayEnd"
case itemCommentStart:
return "CommentStart"
}
panic(fmt.Sprintf("BUG: Unknown type '%d'.", int(itype)))
}
func (item item) String() string {
return fmt.Sprintf("(%s, %s)", item.typ.String(), item.val)
}

View file

@ -0,0 +1,498 @@
package toml
import (
"fmt"
"log"
"strconv"
"strings"
"time"
"unicode"
"unicode/utf8"
)
type parser struct {
mapping map[string]interface{}
types map[string]tomlType
lx *lexer
// A list of keys in the order that they appear in the TOML data.
ordered []Key
// the full key for the current hash in scope
context Key
// the base key name for everything except hashes
currentKey string
// rough approximation of line number
approxLine int
// A map of 'key.group.names' to whether they were created implicitly.
implicits map[string]bool
}
type parseError string
func (pe parseError) Error() string {
return string(pe)
}
func parse(data string) (p *parser, err error) {
defer func() {
if r := recover(); r != nil {
var ok bool
if err, ok = r.(parseError); ok {
return
}
panic(r)
}
}()
p = &parser{
mapping: make(map[string]interface{}),
types: make(map[string]tomlType),
lx: lex(data),
ordered: make([]Key, 0),
implicits: make(map[string]bool),
}
for {
item := p.next()
if item.typ == itemEOF {
break
}
p.topLevel(item)
}
return p, nil
}
func (p *parser) panicf(format string, v ...interface{}) {
msg := fmt.Sprintf("Near line %d (last key parsed '%s'): %s",
p.approxLine, p.current(), fmt.Sprintf(format, v...))
panic(parseError(msg))
}
func (p *parser) next() item {
it := p.lx.nextItem()
if it.typ == itemError {
p.panicf("%s", it.val)
}
return it
}
func (p *parser) bug(format string, v ...interface{}) {
log.Fatalf("BUG: %s\n\n", fmt.Sprintf(format, v...))
}
func (p *parser) expect(typ itemType) item {
it := p.next()
p.assertEqual(typ, it.typ)
return it
}
func (p *parser) assertEqual(expected, got itemType) {
if expected != got {
p.bug("Expected '%s' but got '%s'.", expected, got)
}
}
func (p *parser) topLevel(item item) {
switch item.typ {
case itemCommentStart:
p.approxLine = item.line
p.expect(itemText)
case itemTableStart:
kg := p.next()
p.approxLine = kg.line
var key Key
for ; kg.typ != itemTableEnd && kg.typ != itemEOF; kg = p.next() {
key = append(key, p.keyString(kg))
}
p.assertEqual(itemTableEnd, kg.typ)
p.establishContext(key, false)
p.setType("", tomlHash)
p.ordered = append(p.ordered, key)
case itemArrayTableStart:
kg := p.next()
p.approxLine = kg.line
var key Key
for ; kg.typ != itemArrayTableEnd && kg.typ != itemEOF; kg = p.next() {
key = append(key, p.keyString(kg))
}
p.assertEqual(itemArrayTableEnd, kg.typ)
p.establishContext(key, true)
p.setType("", tomlArrayHash)
p.ordered = append(p.ordered, key)
case itemKeyStart:
kname := p.next()
p.approxLine = kname.line
p.currentKey = p.keyString(kname)
val, typ := p.value(p.next())
p.setValue(p.currentKey, val)
p.setType(p.currentKey, typ)
p.ordered = append(p.ordered, p.context.add(p.currentKey))
p.currentKey = ""
default:
p.bug("Unexpected type at top level: %s", item.typ)
}
}
// Gets a string for a key (or part of a key in a table name).
func (p *parser) keyString(it item) string {
switch it.typ {
case itemText:
return it.val
case itemString, itemMultilineString,
itemRawString, itemRawMultilineString:
s, _ := p.value(it)
return s.(string)
default:
p.bug("Unexpected key type: %s", it.typ)
panic("unreachable")
}
}
// value translates an expected value from the lexer into a Go value wrapped
// as an empty interface.
func (p *parser) value(it item) (interface{}, tomlType) {
switch it.typ {
case itemString:
return p.replaceEscapes(it.val), p.typeOfPrimitive(it)
case itemMultilineString:
trimmed := stripFirstNewline(stripEscapedWhitespace(it.val))
return p.replaceEscapes(trimmed), p.typeOfPrimitive(it)
case itemRawString:
return it.val, p.typeOfPrimitive(it)
case itemRawMultilineString:
return stripFirstNewline(it.val), p.typeOfPrimitive(it)
case itemBool:
switch it.val {
case "true":
return true, p.typeOfPrimitive(it)
case "false":
return false, p.typeOfPrimitive(it)
}
p.bug("Expected boolean value, but got '%s'.", it.val)
case itemInteger:
num, err := strconv.ParseInt(it.val, 10, 64)
if err != nil {
// See comment below for floats describing why we make a
// distinction between a bug and a user error.
if e, ok := err.(*strconv.NumError); ok &&
e.Err == strconv.ErrRange {
p.panicf("Integer '%s' is out of the range of 64-bit "+
"signed integers.", it.val)
} else {
p.bug("Expected integer value, but got '%s'.", it.val)
}
}
return num, p.typeOfPrimitive(it)
case itemFloat:
num, err := strconv.ParseFloat(it.val, 64)
if err != nil {
// Distinguish float values. Normally, it'd be a bug if the lexer
// provides an invalid float, but it's possible that the float is
// out of range of valid values (which the lexer cannot determine).
// So mark the former as a bug but the latter as a legitimate user
// error.
//
// This is also true for integers.
if e, ok := err.(*strconv.NumError); ok &&
e.Err == strconv.ErrRange {
p.panicf("Float '%s' is out of the range of 64-bit "+
"IEEE-754 floating-point numbers.", it.val)
} else {
p.bug("Expected float value, but got '%s'.", it.val)
}
}
return num, p.typeOfPrimitive(it)
case itemDatetime:
t, err := time.Parse("2006-01-02T15:04:05Z", it.val)
if err != nil {
p.bug("Expected Zulu formatted DateTime, but got '%s'.", it.val)
}
return t, p.typeOfPrimitive(it)
case itemArray:
array := make([]interface{}, 0)
types := make([]tomlType, 0)
for it = p.next(); it.typ != itemArrayEnd; it = p.next() {
if it.typ == itemCommentStart {
p.expect(itemText)
continue
}
val, typ := p.value(it)
array = append(array, val)
types = append(types, typ)
}
return array, p.typeOfArray(types)
}
p.bug("Unexpected value type: %s", it.typ)
panic("unreachable")
}
// establishContext sets the current context of the parser,
// where the context is either a hash or an array of hashes. Which one is
// set depends on the value of the `array` parameter.
//
// Establishing the context also makes sure that the key isn't a duplicate, and
// will create implicit hashes automatically.
func (p *parser) establishContext(key Key, array bool) {
var ok bool
// Always start at the top level and drill down for our context.
hashContext := p.mapping
keyContext := make(Key, 0)
// We only need implicit hashes for key[0:-1]
for _, k := range key[0 : len(key)-1] {
_, ok = hashContext[k]
keyContext = append(keyContext, k)
// No key? Make an implicit hash and move on.
if !ok {
p.addImplicit(keyContext)
hashContext[k] = make(map[string]interface{})
}
// If the hash context is actually an array of tables, then set
// the hash context to the last element in that array.
//
// Otherwise, it better be a table, since this MUST be a key group (by
// virtue of it not being the last element in a key).
switch t := hashContext[k].(type) {
case []map[string]interface{}:
hashContext = t[len(t)-1]
case map[string]interface{}:
hashContext = t
default:
p.panicf("Key '%s' was already created as a hash.", keyContext)
}
}
p.context = keyContext
if array {
// If this is the first element for this array, then allocate a new
// list of tables for it.
k := key[len(key)-1]
if _, ok := hashContext[k]; !ok {
hashContext[k] = make([]map[string]interface{}, 0, 5)
}
// Add a new table. But make sure the key hasn't already been used
// for something else.
if hash, ok := hashContext[k].([]map[string]interface{}); ok {
hashContext[k] = append(hash, make(map[string]interface{}))
} else {
p.panicf("Key '%s' was already created and cannot be used as "+
"an array.", keyContext)
}
} else {
p.setValue(key[len(key)-1], make(map[string]interface{}))
}
p.context = append(p.context, key[len(key)-1])
}
// setValue sets the given key to the given value in the current context.
// It will make sure that the key hasn't already been defined, account for
// implicit key groups.
func (p *parser) setValue(key string, value interface{}) {
var tmpHash interface{}
var ok bool
hash := p.mapping
keyContext := make(Key, 0)
for _, k := range p.context {
keyContext = append(keyContext, k)
if tmpHash, ok = hash[k]; !ok {
p.bug("Context for key '%s' has not been established.", keyContext)
}
switch t := tmpHash.(type) {
case []map[string]interface{}:
// The context is a table of hashes. Pick the most recent table
// defined as the current hash.
hash = t[len(t)-1]
case map[string]interface{}:
hash = t
default:
p.bug("Expected hash to have type 'map[string]interface{}', but "+
"it has '%T' instead.", tmpHash)
}
}
keyContext = append(keyContext, key)
if _, ok := hash[key]; ok {
// Typically, if the given key has already been set, then we have
// to raise an error since duplicate keys are disallowed. However,
// it's possible that a key was previously defined implicitly. In this
// case, it is allowed to be redefined concretely. (See the
// `tests/valid/implicit-and-explicit-after.toml` test in `toml-test`.)
//
// But we have to make sure to stop marking it as an implicit. (So that
// another redefinition provokes an error.)
//
// Note that since it has already been defined (as a hash), we don't
// want to overwrite it. So our business is done.
if p.isImplicit(keyContext) {
p.removeImplicit(keyContext)
return
}
// Otherwise, we have a concrete key trying to override a previous
// key, which is *always* wrong.
p.panicf("Key '%s' has already been defined.", keyContext)
}
hash[key] = value
}
// setType sets the type of a particular value at a given key.
// It should be called immediately AFTER setValue.
//
// Note that if `key` is empty, then the type given will be applied to the
// current context (which is either a table or an array of tables).
func (p *parser) setType(key string, typ tomlType) {
keyContext := make(Key, 0, len(p.context)+1)
for _, k := range p.context {
keyContext = append(keyContext, k)
}
if len(key) > 0 { // allow type setting for hashes
keyContext = append(keyContext, key)
}
p.types[keyContext.String()] = typ
}
// addImplicit sets the given Key as having been created implicitly.
func (p *parser) addImplicit(key Key) {
p.implicits[key.String()] = true
}
// removeImplicit stops tagging the given key as having been implicitly
// created.
func (p *parser) removeImplicit(key Key) {
p.implicits[key.String()] = false
}
// isImplicit returns true if the key group pointed to by the key was created
// implicitly.
func (p *parser) isImplicit(key Key) bool {
return p.implicits[key.String()]
}
// current returns the full key name of the current context.
func (p *parser) current() string {
if len(p.currentKey) == 0 {
return p.context.String()
}
if len(p.context) == 0 {
return p.currentKey
}
return fmt.Sprintf("%s.%s", p.context, p.currentKey)
}
func stripFirstNewline(s string) string {
if len(s) == 0 || s[0] != '\n' {
return s
}
return s[1:len(s)]
}
func stripEscapedWhitespace(s string) string {
esc := strings.Split(s, "\\\n")
if len(esc) > 1 {
for i := 1; i < len(esc); i++ {
esc[i] = strings.TrimLeftFunc(esc[i], unicode.IsSpace)
}
}
return strings.Join(esc, "")
}
func (p *parser) replaceEscapes(str string) string {
var replaced []rune
s := []byte(str)
r := 0
for r < len(s) {
if s[r] != '\\' {
c, size := utf8.DecodeRune(s[r:])
r += size
replaced = append(replaced, c)
continue
}
r += 1
if r >= len(s) {
p.bug("Escape sequence at end of string.")
return ""
}
switch s[r] {
default:
p.bug("Expected valid escape code after \\, but got %q.", s[r])
return ""
case 'b':
replaced = append(replaced, rune(0x0008))
r += 1
case 't':
replaced = append(replaced, rune(0x0009))
r += 1
case 'n':
replaced = append(replaced, rune(0x000A))
r += 1
case 'f':
replaced = append(replaced, rune(0x000C))
r += 1
case 'r':
replaced = append(replaced, rune(0x000D))
r += 1
case '"':
replaced = append(replaced, rune(0x0022))
r += 1
case '\\':
replaced = append(replaced, rune(0x005C))
r += 1
case 'u':
// At this point, we know we have a Unicode escape of the form
// `uXXXX` at [r, r+5). (Because the lexer guarantees this
// for us.)
escaped := p.asciiEscapeToUnicode(s[r+1 : r+5])
replaced = append(replaced, escaped)
r += 5
case 'U':
// At this point, we know we have a Unicode escape of the form
// `uXXXX` at [r, r+9). (Because the lexer guarantees this
// for us.)
escaped := p.asciiEscapeToUnicode(s[r+1 : r+9])
replaced = append(replaced, escaped)
r += 9
}
}
return string(replaced)
}
func (p *parser) asciiEscapeToUnicode(bs []byte) rune {
s := string(bs)
hex, err := strconv.ParseUint(strings.ToLower(s), 16, 32)
if err != nil {
p.bug("Could not parse '%s' as a hexadecimal number, but the "+
"lexer claims it's OK: %s", s, err)
}
// BUG(burntsushi)
// I honestly don't understand how this works. I can't seem
// to find a way to make this fail. I figured this would fail on invalid
// UTF-8 characters like U+DCFF, but it doesn't.
if !utf8.ValidString(string(rune(hex))) {
p.panicf("Escaped character '\\u%s' is not valid UTF-8.", s)
}
return rune(hex)
}
func isStringType(ty itemType) bool {
return ty == itemString || ty == itemMultilineString ||
ty == itemRawString || ty == itemRawMultilineString
}

View file

@ -0,0 +1 @@
au BufWritePost *.go silent!make tags > /dev/null 2>&1

View file

@ -0,0 +1,91 @@
package toml
// tomlType represents any Go type that corresponds to a TOML type.
// While the first draft of the TOML spec has a simplistic type system that
// probably doesn't need this level of sophistication, we seem to be militating
// toward adding real composite types.
type tomlType interface {
typeString() string
}
// typeEqual accepts any two types and returns true if they are equal.
func typeEqual(t1, t2 tomlType) bool {
if t1 == nil || t2 == nil {
return false
}
return t1.typeString() == t2.typeString()
}
func typeIsHash(t tomlType) bool {
return typeEqual(t, tomlHash) || typeEqual(t, tomlArrayHash)
}
type tomlBaseType string
func (btype tomlBaseType) typeString() string {
return string(btype)
}
func (btype tomlBaseType) String() string {
return btype.typeString()
}
var (
tomlInteger tomlBaseType = "Integer"
tomlFloat tomlBaseType = "Float"
tomlDatetime tomlBaseType = "Datetime"
tomlString tomlBaseType = "String"
tomlBool tomlBaseType = "Bool"
tomlArray tomlBaseType = "Array"
tomlHash tomlBaseType = "Hash"
tomlArrayHash tomlBaseType = "ArrayHash"
)
// typeOfPrimitive returns a tomlType of any primitive value in TOML.
// Primitive values are: Integer, Float, Datetime, String and Bool.
//
// Passing a lexer item other than the following will cause a BUG message
// to occur: itemString, itemBool, itemInteger, itemFloat, itemDatetime.
func (p *parser) typeOfPrimitive(lexItem item) tomlType {
switch lexItem.typ {
case itemInteger:
return tomlInteger
case itemFloat:
return tomlFloat
case itemDatetime:
return tomlDatetime
case itemString:
return tomlString
case itemMultilineString:
return tomlString
case itemRawString:
return tomlString
case itemRawMultilineString:
return tomlString
case itemBool:
return tomlBool
}
p.bug("Cannot infer primitive type of lex item '%s'.", lexItem)
panic("unreachable")
}
// typeOfArray returns a tomlType for an array given a list of types of its
// values.
//
// In the current spec, if an array is homogeneous, then its type is always
// "Array". If the array is not homogeneous, an error is generated.
func (p *parser) typeOfArray(types []tomlType) tomlType {
// Empty arrays are cool.
if len(types) == 0 {
return tomlArray
}
theType := types[0]
for _, t := range types[1:] {
if !typeEqual(theType, t) {
p.panicf("Array contains values of type '%s' and '%s', but "+
"arrays must be homogeneous.", theType, t)
}
}
return tomlArray
}

View file

@ -0,0 +1,241 @@
package toml
// Struct field handling is adapted from code in encoding/json:
//
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the Go distribution.
import (
"reflect"
"sort"
"sync"
)
// A field represents a single field found in a struct.
type field struct {
name string // the name of the field (`toml` tag included)
tag bool // whether field has a `toml` tag
index []int // represents the depth of an anonymous field
typ reflect.Type // the type of the field
}
// byName sorts field by name, breaking ties with depth,
// then breaking ties with "name came from toml tag", then
// breaking ties with index sequence.
type byName []field
func (x byName) Len() int { return len(x) }
func (x byName) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
func (x byName) Less(i, j int) bool {
if x[i].name != x[j].name {
return x[i].name < x[j].name
}
if len(x[i].index) != len(x[j].index) {
return len(x[i].index) < len(x[j].index)
}
if x[i].tag != x[j].tag {
return x[i].tag
}
return byIndex(x).Less(i, j)
}
// byIndex sorts field by index sequence.
type byIndex []field
func (x byIndex) Len() int { return len(x) }
func (x byIndex) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
func (x byIndex) Less(i, j int) bool {
for k, xik := range x[i].index {
if k >= len(x[j].index) {
return false
}
if xik != x[j].index[k] {
return xik < x[j].index[k]
}
}
return len(x[i].index) < len(x[j].index)
}
// typeFields returns a list of fields that TOML should recognize for the given
// type. The algorithm is breadth-first search over the set of structs to
// include - the top struct and then any reachable anonymous structs.
func typeFields(t reflect.Type) []field {
// Anonymous fields to explore at the current level and the next.
current := []field{}
next := []field{{typ: t}}
// Count of queued names for current level and the next.
count := map[reflect.Type]int{}
nextCount := map[reflect.Type]int{}
// Types already visited at an earlier level.
visited := map[reflect.Type]bool{}
// Fields found.
var fields []field
for len(next) > 0 {
current, next = next, current[:0]
count, nextCount = nextCount, map[reflect.Type]int{}
for _, f := range current {
if visited[f.typ] {
continue
}
visited[f.typ] = true
// Scan f.typ for fields to include.
for i := 0; i < f.typ.NumField(); i++ {
sf := f.typ.Field(i)
if sf.PkgPath != "" { // unexported
continue
}
name := sf.Tag.Get("toml")
if name == "-" {
continue
}
index := make([]int, len(f.index)+1)
copy(index, f.index)
index[len(f.index)] = i
ft := sf.Type
if ft.Name() == "" && ft.Kind() == reflect.Ptr {
// Follow pointer.
ft = ft.Elem()
}
// Record found field and index sequence.
if name != "" || !sf.Anonymous || ft.Kind() != reflect.Struct {
tagged := name != ""
if name == "" {
name = sf.Name
}
fields = append(fields, field{name, tagged, index, ft})
if count[f.typ] > 1 {
// If there were multiple instances, add a second,
// so that the annihilation code will see a duplicate.
// It only cares about the distinction between 1 or 2,
// so don't bother generating any more copies.
fields = append(fields, fields[len(fields)-1])
}
continue
}
// Record new anonymous struct to explore in next round.
nextCount[ft]++
if nextCount[ft] == 1 {
f := field{name: ft.Name(), index: index, typ: ft}
next = append(next, f)
}
}
}
}
sort.Sort(byName(fields))
// Delete all fields that are hidden by the Go rules for embedded fields,
// except that fields with TOML tags are promoted.
// The fields are sorted in primary order of name, secondary order
// of field index length. Loop over names; for each name, delete
// hidden fields by choosing the one dominant field that survives.
out := fields[:0]
for advance, i := 0, 0; i < len(fields); i += advance {
// One iteration per name.
// Find the sequence of fields with the name of this first field.
fi := fields[i]
name := fi.name
for advance = 1; i+advance < len(fields); advance++ {
fj := fields[i+advance]
if fj.name != name {
break
}
}
if advance == 1 { // Only one field with this name
out = append(out, fi)
continue
}
dominant, ok := dominantField(fields[i : i+advance])
if ok {
out = append(out, dominant)
}
}
fields = out
sort.Sort(byIndex(fields))
return fields
}
// dominantField looks through the fields, all of which are known to
// have the same name, to find the single field that dominates the
// others using Go's embedding rules, modified by the presence of
// TOML tags. If there are multiple top-level fields, the boolean
// will be false: This condition is an error in Go and we skip all
// the fields.
func dominantField(fields []field) (field, bool) {
// The fields are sorted in increasing index-length order. The winner
// must therefore be one with the shortest index length. Drop all
// longer entries, which is easy: just truncate the slice.
length := len(fields[0].index)
tagged := -1 // Index of first tagged field.
for i, f := range fields {
if len(f.index) > length {
fields = fields[:i]
break
}
if f.tag {
if tagged >= 0 {
// Multiple tagged fields at the same level: conflict.
// Return no field.
return field{}, false
}
tagged = i
}
}
if tagged >= 0 {
return fields[tagged], true
}
// All remaining fields have the same length. If there's more than one,
// we have a conflict (two fields named "X" at the same level) and we
// return no field.
if len(fields) > 1 {
return field{}, false
}
return fields[0], true
}
var fieldCache struct {
sync.RWMutex
m map[reflect.Type][]field
}
// cachedTypeFields is like typeFields but uses a cache to avoid repeated work.
func cachedTypeFields(t reflect.Type) []field {
fieldCache.RLock()
f := fieldCache.m[t]
fieldCache.RUnlock()
if f != nil {
return f
}
// Compute fields without lock.
// Might duplicate effort but won't hold other computations back.
f = typeFields(t)
if f == nil {
f = []field{}
}
fieldCache.Lock()
if fieldCache.m == nil {
fieldCache.m = map[reflect.Type][]field{}
}
fieldCache.m[t] = f
fieldCache.Unlock()
return f
}

View file

@ -0,0 +1,23 @@
package etcd
// Add a new directory with a random etcd-generated key under the given path.
func (c *Client) AddChildDir(key string, ttl uint64) (*Response, error) {
raw, err := c.post(key, "", ttl)
if err != nil {
return nil, err
}
return raw.Unmarshal()
}
// Add a new file with a random etcd-generated key under the given path.
func (c *Client) AddChild(key string, value string, ttl uint64) (*Response, error) {
raw, err := c.post(key, value, ttl)
if err != nil {
return nil, err
}
return raw.Unmarshal()
}

View file

@ -0,0 +1,481 @@
package etcd
import (
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"io"
"io/ioutil"
"math/rand"
"net"
"net/http"
"net/url"
"os"
"path"
"strings"
"time"
)
// See SetConsistency for how to use these constants.
const (
// Using strings rather than iota because the consistency level
// could be persisted to disk, so it'd be better to use
// human-readable values.
STRONG_CONSISTENCY = "STRONG"
WEAK_CONSISTENCY = "WEAK"
)
const (
defaultBufferSize = 10
)
func init() {
rand.Seed(int64(time.Now().Nanosecond()))
}
type Config struct {
CertFile string `json:"certFile"`
KeyFile string `json:"keyFile"`
CaCertFile []string `json:"caCertFiles"`
DialTimeout time.Duration `json:"timeout"`
Consistency string `json:"consistency"`
}
type credentials struct {
username string
password string
}
type Client struct {
config Config `json:"config"`
cluster *Cluster `json:"cluster"`
httpClient *http.Client
credentials *credentials
transport *http.Transport
persistence io.Writer
cURLch chan string
// CheckRetry can be used to control the policy for failed requests
// and modify the cluster if needed.
// The client calls it before sending requests again, and
// stops retrying if CheckRetry returns some error. The cases that
// this function needs to handle include no response and unexpected
// http status code of response.
// If CheckRetry is nil, client will call the default one
// `DefaultCheckRetry`.
// Argument cluster is the etcd.Cluster object that these requests have been made on.
// Argument numReqs is the number of http.Requests that have been made so far.
// Argument lastResp is the http.Responses from the last request.
// Argument err is the reason of the failure.
CheckRetry func(cluster *Cluster, numReqs int,
lastResp http.Response, err error) error
}
// NewClient create a basic client that is configured to be used
// with the given machine list.
func NewClient(machines []string) *Client {
config := Config{
// default timeout is one second
DialTimeout: time.Second,
Consistency: WEAK_CONSISTENCY,
}
client := &Client{
cluster: NewCluster(machines),
config: config,
}
client.initHTTPClient()
client.saveConfig()
return client
}
// NewTLSClient create a basic client with TLS configuration
func NewTLSClient(machines []string, cert, key, caCert string) (*Client, error) {
// overwrite the default machine to use https
if len(machines) == 0 {
machines = []string{"https://127.0.0.1:4001"}
}
config := Config{
// default timeout is one second
DialTimeout: time.Second,
Consistency: WEAK_CONSISTENCY,
CertFile: cert,
KeyFile: key,
CaCertFile: make([]string, 0),
}
client := &Client{
cluster: NewCluster(machines),
config: config,
}
err := client.initHTTPSClient(cert, key)
if err != nil {
return nil, err
}
err = client.AddRootCA(caCert)
client.saveConfig()
return client, nil
}
// NewClientFromFile creates a client from a given file path.
// The given file is expected to use the JSON format.
func NewClientFromFile(fpath string) (*Client, error) {
fi, err := os.Open(fpath)
if err != nil {
return nil, err
}
defer func() {
if err := fi.Close(); err != nil {
panic(err)
}
}()
return NewClientFromReader(fi)
}
// NewClientFromReader creates a Client configured from a given reader.
// The configuration is expected to use the JSON format.
func NewClientFromReader(reader io.Reader) (*Client, error) {
c := new(Client)
b, err := ioutil.ReadAll(reader)
if err != nil {
return nil, err
}
err = json.Unmarshal(b, c)
if err != nil {
return nil, err
}
if c.config.CertFile == "" {
c.initHTTPClient()
} else {
err = c.initHTTPSClient(c.config.CertFile, c.config.KeyFile)
}
if err != nil {
return nil, err
}
for _, caCert := range c.config.CaCertFile {
if err := c.AddRootCA(caCert); err != nil {
return nil, err
}
}
return c, nil
}
// Override the Client's HTTP Transport object
func (c *Client) SetTransport(tr *http.Transport) {
c.httpClient.Transport = tr
c.transport = tr
}
func (c *Client) SetCredentials(username, password string) {
c.credentials = &credentials{username, password}
}
func (c *Client) Close() {
c.transport.DisableKeepAlives = true
c.transport.CloseIdleConnections()
}
// initHTTPClient initializes a HTTP client for etcd client
func (c *Client) initHTTPClient() {
c.transport = &http.Transport{
Dial: c.dial,
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
}
c.httpClient = &http.Client{Transport: c.transport}
}
// initHTTPClient initializes a HTTPS client for etcd client
func (c *Client) initHTTPSClient(cert, key string) error {
if cert == "" || key == "" {
return errors.New("Require both cert and key path")
}
tlsCert, err := tls.LoadX509KeyPair(cert, key)
if err != nil {
return err
}
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{tlsCert},
InsecureSkipVerify: true,
}
tr := &http.Transport{
TLSClientConfig: tlsConfig,
Dial: c.dial,
}
c.httpClient = &http.Client{Transport: tr}
return nil
}
// SetPersistence sets a writer to which the config will be
// written every time it's changed.
func (c *Client) SetPersistence(writer io.Writer) {
c.persistence = writer
}
// SetConsistency changes the consistency level of the client.
//
// When consistency is set to STRONG_CONSISTENCY, all requests,
// including GET, are sent to the leader. This means that, assuming
// the absence of leader failures, GET requests are guaranteed to see
// the changes made by previous requests.
//
// When consistency is set to WEAK_CONSISTENCY, other requests
// are still sent to the leader, but GET requests are sent to a
// random server from the server pool. This reduces the read
// load on the leader, but it's not guaranteed that the GET requests
// will see changes made by previous requests (they might have not
// yet been committed on non-leader servers).
func (c *Client) SetConsistency(consistency string) error {
if !(consistency == STRONG_CONSISTENCY || consistency == WEAK_CONSISTENCY) {
return errors.New("The argument must be either STRONG_CONSISTENCY or WEAK_CONSISTENCY.")
}
c.config.Consistency = consistency
return nil
}
// Sets the DialTimeout value
func (c *Client) SetDialTimeout(d time.Duration) {
c.config.DialTimeout = d
}
// AddRootCA adds a root CA cert for the etcd client
func (c *Client) AddRootCA(caCert string) error {
if c.httpClient == nil {
return errors.New("Client has not been initialized yet!")
}
certBytes, err := ioutil.ReadFile(caCert)
if err != nil {
return err
}
tr, ok := c.httpClient.Transport.(*http.Transport)
if !ok {
panic("AddRootCA(): Transport type assert should not fail")
}
if tr.TLSClientConfig.RootCAs == nil {
caCertPool := x509.NewCertPool()
ok = caCertPool.AppendCertsFromPEM(certBytes)
if ok {
tr.TLSClientConfig.RootCAs = caCertPool
}
tr.TLSClientConfig.InsecureSkipVerify = false
} else {
ok = tr.TLSClientConfig.RootCAs.AppendCertsFromPEM(certBytes)
}
if !ok {
err = errors.New("Unable to load caCert")
}
c.config.CaCertFile = append(c.config.CaCertFile, caCert)
c.saveConfig()
return err
}
// SetCluster updates cluster information using the given machine list.
func (c *Client) SetCluster(machines []string) bool {
success := c.internalSyncCluster(machines)
return success
}
func (c *Client) GetCluster() []string {
return c.cluster.Machines
}
// SyncCluster updates the cluster information using the internal machine list.
func (c *Client) SyncCluster() bool {
return c.internalSyncCluster(c.cluster.Machines)
}
// internalSyncCluster syncs cluster information using the given machine list.
func (c *Client) internalSyncCluster(machines []string) bool {
for _, machine := range machines {
httpPath := c.createHttpPath(machine, path.Join(version, "members"))
resp, err := c.httpClient.Get(httpPath)
if err != nil {
// try another machine in the cluster
continue
}
if resp.StatusCode != http.StatusOK { // fall-back to old endpoint
httpPath := c.createHttpPath(machine, path.Join(version, "machines"))
resp, err := c.httpClient.Get(httpPath)
if err != nil {
// try another machine in the cluster
continue
}
b, err := ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
// try another machine in the cluster
continue
}
// update Machines List
c.cluster.updateFromStr(string(b))
} else {
b, err := ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
// try another machine in the cluster
continue
}
var mCollection memberCollection
if err := json.Unmarshal(b, &mCollection); err != nil {
// try another machine
continue
}
urls := make([]string, 0)
for _, m := range mCollection {
urls = append(urls, m.ClientURLs...)
}
// update Machines List
c.cluster.updateFromStr(strings.Join(urls, ","))
}
logger.Debug("sync.machines ", c.cluster.Machines)
c.saveConfig()
return true
}
return false
}
// createHttpPath creates a complete HTTP URL.
// serverName should contain both the host name and a port number, if any.
func (c *Client) createHttpPath(serverName string, _path string) string {
u, err := url.Parse(serverName)
if err != nil {
panic(err)
}
u.Path = path.Join(u.Path, _path)
if u.Scheme == "" {
u.Scheme = "http"
}
return u.String()
}
// dial attempts to open a TCP connection to the provided address, explicitly
// enabling keep-alives with a one-second interval.
func (c *Client) dial(network, addr string) (net.Conn, error) {
conn, err := net.DialTimeout(network, addr, c.config.DialTimeout)
if err != nil {
return nil, err
}
tcpConn, ok := conn.(*net.TCPConn)
if !ok {
return nil, errors.New("Failed type-assertion of net.Conn as *net.TCPConn")
}
// Keep TCP alive to check whether or not the remote machine is down
if err = tcpConn.SetKeepAlive(true); err != nil {
return nil, err
}
if err = tcpConn.SetKeepAlivePeriod(time.Second); err != nil {
return nil, err
}
return tcpConn, nil
}
func (c *Client) OpenCURL() {
c.cURLch = make(chan string, defaultBufferSize)
}
func (c *Client) CloseCURL() {
c.cURLch = nil
}
func (c *Client) sendCURL(command string) {
go func() {
select {
case c.cURLch <- command:
default:
}
}()
}
func (c *Client) RecvCURL() string {
return <-c.cURLch
}
// saveConfig saves the current config using c.persistence.
func (c *Client) saveConfig() error {
if c.persistence != nil {
b, err := json.Marshal(c)
if err != nil {
return err
}
_, err = c.persistence.Write(b)
if err != nil {
return err
}
}
return nil
}
// MarshalJSON implements the Marshaller interface
// as defined by the standard JSON package.
func (c *Client) MarshalJSON() ([]byte, error) {
b, err := json.Marshal(struct {
Config Config `json:"config"`
Cluster *Cluster `json:"cluster"`
}{
Config: c.config,
Cluster: c.cluster,
})
if err != nil {
return nil, err
}
return b, nil
}
// UnmarshalJSON implements the Unmarshaller interface
// as defined by the standard JSON package.
func (c *Client) UnmarshalJSON(b []byte) error {
temp := struct {
Config Config `json:"config"`
Cluster *Cluster `json:"cluster"`
}{}
err := json.Unmarshal(b, &temp)
if err != nil {
return err
}
c.cluster = temp.Cluster
c.config = temp.Config
return nil
}

View file

@ -0,0 +1,37 @@
package etcd
import (
"math/rand"
"strings"
)
type Cluster struct {
Leader string `json:"leader"`
Machines []string `json:"machines"`
picked int
}
func NewCluster(machines []string) *Cluster {
// if an empty slice was sent in then just assume HTTP 4001 on localhost
if len(machines) == 0 {
machines = []string{"http://127.0.0.1:4001"}
}
// default leader and machines
return &Cluster{
Leader: "",
Machines: machines,
picked: rand.Intn(len(machines)),
}
}
func (cl *Cluster) failure() { cl.picked = rand.Intn(len(cl.Machines)) }
func (cl *Cluster) pick() string { return cl.Machines[cl.picked] }
func (cl *Cluster) updateFromStr(machines string) {
cl.Machines = strings.Split(machines, ",")
for i := range cl.Machines {
cl.Machines[i] = strings.TrimSpace(cl.Machines[i])
}
cl.picked = rand.Intn(len(cl.Machines))
}

View file

@ -0,0 +1,34 @@
package etcd
import "fmt"
func (c *Client) CompareAndDelete(key string, prevValue string, prevIndex uint64) (*Response, error) {
raw, err := c.RawCompareAndDelete(key, prevValue, prevIndex)
if err != nil {
return nil, err
}
return raw.Unmarshal()
}
func (c *Client) RawCompareAndDelete(key string, prevValue string, prevIndex uint64) (*RawResponse, error) {
if prevValue == "" && prevIndex == 0 {
return nil, fmt.Errorf("You must give either prevValue or prevIndex.")
}
options := Options{}
if prevValue != "" {
options["prevValue"] = prevValue
}
if prevIndex != 0 {
options["prevIndex"] = prevIndex
}
raw, err := c.delete(key, options)
if err != nil {
return nil, err
}
return raw, err
}

View file

@ -0,0 +1,36 @@
package etcd
import "fmt"
func (c *Client) CompareAndSwap(key string, value string, ttl uint64,
prevValue string, prevIndex uint64) (*Response, error) {
raw, err := c.RawCompareAndSwap(key, value, ttl, prevValue, prevIndex)
if err != nil {
return nil, err
}
return raw.Unmarshal()
}
func (c *Client) RawCompareAndSwap(key string, value string, ttl uint64,
prevValue string, prevIndex uint64) (*RawResponse, error) {
if prevValue == "" && prevIndex == 0 {
return nil, fmt.Errorf("You must give either prevValue or prevIndex.")
}
options := Options{}
if prevValue != "" {
options["prevValue"] = prevValue
}
if prevIndex != 0 {
options["prevIndex"] = prevIndex
}
raw, err := c.put(key, value, ttl, options)
if err != nil {
return nil, err
}
return raw, err
}

View file

@ -0,0 +1,55 @@
package etcd
import (
"fmt"
"io/ioutil"
"log"
"strings"
)
var logger *etcdLogger
func SetLogger(l *log.Logger) {
logger = &etcdLogger{l}
}
func GetLogger() *log.Logger {
return logger.log
}
type etcdLogger struct {
log *log.Logger
}
func (p *etcdLogger) Debug(args ...interface{}) {
msg := "DEBUG: " + fmt.Sprint(args...)
p.log.Println(msg)
}
func (p *etcdLogger) Debugf(f string, args ...interface{}) {
msg := "DEBUG: " + fmt.Sprintf(f, args...)
// Append newline if necessary
if !strings.HasSuffix(msg, "\n") {
msg = msg + "\n"
}
p.log.Print(msg)
}
func (p *etcdLogger) Warning(args ...interface{}) {
msg := "WARNING: " + fmt.Sprint(args...)
p.log.Println(msg)
}
func (p *etcdLogger) Warningf(f string, args ...interface{}) {
msg := "WARNING: " + fmt.Sprintf(f, args...)
// Append newline if necessary
if !strings.HasSuffix(msg, "\n") {
msg = msg + "\n"
}
p.log.Print(msg)
}
func init() {
// Default logger uses the go default log.
SetLogger(log.New(ioutil.Discard, "go-etcd", log.LstdFlags))
}

View file

@ -0,0 +1,40 @@
package etcd
// Delete deletes the given key.
//
// When recursive set to false, if the key points to a
// directory the method will fail.
//
// When recursive set to true, if the key points to a file,
// the file will be deleted; if the key points to a directory,
// then everything under the directory (including all child directories)
// will be deleted.
func (c *Client) Delete(key string, recursive bool) (*Response, error) {
raw, err := c.RawDelete(key, recursive, false)
if err != nil {
return nil, err
}
return raw.Unmarshal()
}
// DeleteDir deletes an empty directory or a key value pair
func (c *Client) DeleteDir(key string) (*Response, error) {
raw, err := c.RawDelete(key, false, true)
if err != nil {
return nil, err
}
return raw.Unmarshal()
}
func (c *Client) RawDelete(key string, recursive bool, dir bool) (*RawResponse, error) {
ops := Options{
"recursive": recursive,
"dir": dir,
}
return c.delete(key, ops)
}

View file

@ -0,0 +1,49 @@
package etcd
import (
"encoding/json"
"fmt"
)
const (
ErrCodeEtcdNotReachable = 501
ErrCodeUnhandledHTTPStatus = 502
)
var (
errorMap = map[int]string{
ErrCodeEtcdNotReachable: "All the given peers are not reachable",
}
)
type EtcdError struct {
ErrorCode int `json:"errorCode"`
Message string `json:"message"`
Cause string `json:"cause,omitempty"`
Index uint64 `json:"index"`
}
func (e EtcdError) Error() string {
return fmt.Sprintf("%v: %v (%v) [%v]", e.ErrorCode, e.Message, e.Cause, e.Index)
}
func newError(errorCode int, cause string, index uint64) *EtcdError {
return &EtcdError{
ErrorCode: errorCode,
Message: errorMap[errorCode],
Cause: cause,
Index: index,
}
}
func handleError(b []byte) error {
etcdErr := new(EtcdError)
err := json.Unmarshal(b, etcdErr)
if err != nil {
logger.Warningf("cannot unmarshal etcd error: %v", err)
return err
}
return etcdErr
}

View file

@ -0,0 +1,32 @@
package etcd
// Get gets the file or directory associated with the given key.
// If the key points to a directory, files and directories under
// it will be returned in sorted or unsorted order, depending on
// the sort flag.
// If recursive is set to false, contents under child directories
// will not be returned.
// If recursive is set to true, all the contents will be returned.
func (c *Client) Get(key string, sort, recursive bool) (*Response, error) {
raw, err := c.RawGet(key, sort, recursive)
if err != nil {
return nil, err
}
return raw.Unmarshal()
}
func (c *Client) RawGet(key string, sort, recursive bool) (*RawResponse, error) {
var q bool
if c.config.Consistency == STRONG_CONSISTENCY {
q = true
}
ops := Options{
"recursive": recursive,
"sorted": sort,
"quorum": q,
}
return c.get(key, ops)
}

View file

@ -0,0 +1,30 @@
package etcd
import "encoding/json"
type Member struct {
ID string `json:"id"`
Name string `json:"name"`
PeerURLs []string `json:"peerURLs"`
ClientURLs []string `json:"clientURLs"`
}
type memberCollection []Member
func (c *memberCollection) UnmarshalJSON(data []byte) error {
d := struct {
Members []Member
}{}
if err := json.Unmarshal(data, &d); err != nil {
return err
}
if d.Members == nil {
*c = make([]Member, 0)
return nil
}
*c = d.Members
return nil
}

View file

@ -0,0 +1,72 @@
package etcd
import (
"fmt"
"net/url"
"reflect"
)
type Options map[string]interface{}
// An internally-used data structure that represents a mapping
// between valid options and their kinds
type validOptions map[string]reflect.Kind
// Valid options for GET, PUT, POST, DELETE
// Using CAPITALIZED_UNDERSCORE to emphasize that these
// values are meant to be used as constants.
var (
VALID_GET_OPTIONS = validOptions{
"recursive": reflect.Bool,
"quorum": reflect.Bool,
"sorted": reflect.Bool,
"wait": reflect.Bool,
"waitIndex": reflect.Uint64,
}
VALID_PUT_OPTIONS = validOptions{
"prevValue": reflect.String,
"prevIndex": reflect.Uint64,
"prevExist": reflect.Bool,
"dir": reflect.Bool,
}
VALID_POST_OPTIONS = validOptions{}
VALID_DELETE_OPTIONS = validOptions{
"recursive": reflect.Bool,
"dir": reflect.Bool,
"prevValue": reflect.String,
"prevIndex": reflect.Uint64,
}
)
// Convert options to a string of HTML parameters
func (ops Options) toParameters(validOps validOptions) (string, error) {
p := "?"
values := url.Values{}
if ops == nil {
return "", nil
}
for k, v := range ops {
// Check if the given option is valid (that it exists)
kind := validOps[k]
if kind == reflect.Invalid {
return "", fmt.Errorf("Invalid option: %v", k)
}
// Check if the given option is of the valid type
t := reflect.TypeOf(v)
if kind != t.Kind() {
return "", fmt.Errorf("Option %s should be of %v kind, not of %v kind.",
k, kind, t.Kind())
}
values.Set(k, fmt.Sprintf("%v", v))
}
p += values.Encode()
return p, nil
}

View file

@ -0,0 +1,405 @@
package etcd
import (
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"path"
"strings"
"sync"
"time"
)
// Errors introduced by handling requests
var (
ErrRequestCancelled = errors.New("sending request is cancelled")
)
type RawRequest struct {
Method string
RelativePath string
Values url.Values
Cancel <-chan bool
}
// NewRawRequest returns a new RawRequest
func NewRawRequest(method, relativePath string, values url.Values, cancel <-chan bool) *RawRequest {
return &RawRequest{
Method: method,
RelativePath: relativePath,
Values: values,
Cancel: cancel,
}
}
// getCancelable issues a cancelable GET request
func (c *Client) getCancelable(key string, options Options,
cancel <-chan bool) (*RawResponse, error) {
logger.Debugf("get %s [%s]", key, c.cluster.pick())
p := keyToPath(key)
str, err := options.toParameters(VALID_GET_OPTIONS)
if err != nil {
return nil, err
}
p += str
req := NewRawRequest("GET", p, nil, cancel)
resp, err := c.SendRequest(req)
if err != nil {
return nil, err
}
return resp, nil
}
// get issues a GET request
func (c *Client) get(key string, options Options) (*RawResponse, error) {
return c.getCancelable(key, options, nil)
}
// put issues a PUT request
func (c *Client) put(key string, value string, ttl uint64,
options Options) (*RawResponse, error) {
logger.Debugf("put %s, %s, ttl: %d, [%s]", key, value, ttl, c.cluster.pick())
p := keyToPath(key)
str, err := options.toParameters(VALID_PUT_OPTIONS)
if err != nil {
return nil, err
}
p += str
req := NewRawRequest("PUT", p, buildValues(value, ttl), nil)
resp, err := c.SendRequest(req)
if err != nil {
return nil, err
}
return resp, nil
}
// post issues a POST request
func (c *Client) post(key string, value string, ttl uint64) (*RawResponse, error) {
logger.Debugf("post %s, %s, ttl: %d, [%s]", key, value, ttl, c.cluster.pick())
p := keyToPath(key)
req := NewRawRequest("POST", p, buildValues(value, ttl), nil)
resp, err := c.SendRequest(req)
if err != nil {
return nil, err
}
return resp, nil
}
// delete issues a DELETE request
func (c *Client) delete(key string, options Options) (*RawResponse, error) {
logger.Debugf("delete %s [%s]", key, c.cluster.pick())
p := keyToPath(key)
str, err := options.toParameters(VALID_DELETE_OPTIONS)
if err != nil {
return nil, err
}
p += str
req := NewRawRequest("DELETE", p, nil, nil)
resp, err := c.SendRequest(req)
if err != nil {
return nil, err
}
return resp, nil
}
// SendRequest sends a HTTP request and returns a Response as defined by etcd
func (c *Client) SendRequest(rr *RawRequest) (*RawResponse, error) {
var req *http.Request
var resp *http.Response
var httpPath string
var err error
var respBody []byte
var numReqs = 1
checkRetry := c.CheckRetry
if checkRetry == nil {
checkRetry = DefaultCheckRetry
}
cancelled := make(chan bool, 1)
reqLock := new(sync.Mutex)
if rr.Cancel != nil {
cancelRoutine := make(chan bool)
defer close(cancelRoutine)
go func() {
select {
case <-rr.Cancel:
cancelled <- true
logger.Debug("send.request is cancelled")
case <-cancelRoutine:
return
}
// Repeat canceling request until this thread is stopped
// because we have no idea about whether it succeeds.
for {
reqLock.Lock()
c.httpClient.Transport.(*http.Transport).CancelRequest(req)
reqLock.Unlock()
select {
case <-time.After(100 * time.Millisecond):
case <-cancelRoutine:
return
}
}
}()
}
// If we connect to a follower and consistency is required, retry until
// we connect to a leader
sleep := 25 * time.Millisecond
maxSleep := time.Second
for attempt := 0; ; attempt++ {
if attempt > 0 {
select {
case <-cancelled:
return nil, ErrRequestCancelled
case <-time.After(sleep):
sleep = sleep * 2
if sleep > maxSleep {
sleep = maxSleep
}
}
}
logger.Debug("Connecting to etcd: attempt ", attempt+1, " for ", rr.RelativePath)
// get httpPath if not set
if httpPath == "" {
httpPath = c.getHttpPath(rr.RelativePath)
}
// Return a cURL command if curlChan is set
if c.cURLch != nil {
command := fmt.Sprintf("curl -X %s %s", rr.Method, httpPath)
for key, value := range rr.Values {
command += fmt.Sprintf(" -d %s=%s", key, value[0])
}
if c.credentials != nil {
command += fmt.Sprintf(" -u %s", c.credentials.username)
}
c.sendCURL(command)
}
logger.Debug("send.request.to ", httpPath, " | method ", rr.Method)
req, err := func() (*http.Request, error) {
reqLock.Lock()
defer reqLock.Unlock()
if rr.Values == nil {
if req, err = http.NewRequest(rr.Method, httpPath, nil); err != nil {
return nil, err
}
} else {
body := strings.NewReader(rr.Values.Encode())
if req, err = http.NewRequest(rr.Method, httpPath, body); err != nil {
return nil, err
}
req.Header.Set("Content-Type",
"application/x-www-form-urlencoded; param=value")
}
return req, nil
}()
if err != nil {
return nil, err
}
if c.credentials != nil {
req.SetBasicAuth(c.credentials.username, c.credentials.password)
}
resp, err = c.httpClient.Do(req)
// clear previous httpPath
httpPath = ""
defer func() {
if resp != nil {
resp.Body.Close()
}
}()
// If the request was cancelled, return ErrRequestCancelled directly
select {
case <-cancelled:
return nil, ErrRequestCancelled
default:
}
numReqs++
// network error, change a machine!
if err != nil {
logger.Debug("network error: ", err.Error())
lastResp := http.Response{}
if checkErr := checkRetry(c.cluster, numReqs, lastResp, err); checkErr != nil {
return nil, checkErr
}
c.cluster.failure()
continue
}
// if there is no error, it should receive response
logger.Debug("recv.response.from ", httpPath)
if validHttpStatusCode[resp.StatusCode] {
// try to read byte code and break the loop
respBody, err = ioutil.ReadAll(resp.Body)
if err == nil {
logger.Debug("recv.success ", httpPath)
break
}
// ReadAll error may be caused due to cancel request
select {
case <-cancelled:
return nil, ErrRequestCancelled
default:
}
if err == io.ErrUnexpectedEOF {
// underlying connection was closed prematurely, probably by timeout
// TODO: empty body or unexpectedEOF can cause http.Transport to get hosed;
// this allows the client to detect that and take evasive action. Need
// to revisit once code.google.com/p/go/issues/detail?id=8648 gets fixed.
respBody = []byte{}
break
}
}
if resp.StatusCode == http.StatusTemporaryRedirect {
u, err := resp.Location()
if err != nil {
logger.Warning(err)
} else {
// set httpPath for following redirection
httpPath = u.String()
}
resp.Body.Close()
continue
}
if checkErr := checkRetry(c.cluster, numReqs, *resp,
errors.New("Unexpected HTTP status code")); checkErr != nil {
return nil, checkErr
}
resp.Body.Close()
}
r := &RawResponse{
StatusCode: resp.StatusCode,
Body: respBody,
Header: resp.Header,
}
return r, nil
}
// DefaultCheckRetry defines the retrying behaviour for bad HTTP requests
// If we have retried 2 * machine number, stop retrying.
// If status code is InternalServerError, sleep for 200ms.
func DefaultCheckRetry(cluster *Cluster, numReqs int, lastResp http.Response,
err error) error {
if isEmptyResponse(lastResp) {
// always retry if it failed to get response from one machine
return err
} else if !shouldRetry(lastResp) {
body := []byte("nil")
if lastResp.Body != nil {
if b, err := ioutil.ReadAll(lastResp.Body); err == nil {
body = b
}
}
errStr := fmt.Sprintf("unhandled http status [%s] with body [%s]", http.StatusText(lastResp.StatusCode), body)
return newError(ErrCodeUnhandledHTTPStatus, errStr, 0)
}
if numReqs > 2*len(cluster.Machines) {
errStr := fmt.Sprintf("failed to propose on members %v twice [last error: %v]", cluster.Machines, err)
return newError(ErrCodeEtcdNotReachable, errStr, 0)
}
if shouldRetry(lastResp) {
// sleep some time and expect leader election finish
time.Sleep(time.Millisecond * 200)
}
logger.Warning("bad response status code", lastResp.StatusCode)
return nil
}
func isEmptyResponse(r http.Response) bool { return r.StatusCode == 0 }
// shouldRetry returns whether the reponse deserves retry.
func shouldRetry(r http.Response) bool {
// TODO: only retry when the cluster is in leader election
// We cannot do it exactly because etcd doesn't support it well.
return r.StatusCode == http.StatusInternalServerError
}
func (c *Client) getHttpPath(s ...string) string {
fullPath := c.cluster.pick() + "/" + version
for _, seg := range s {
fullPath = fullPath + "/" + seg
}
return fullPath
}
// buildValues builds a url.Values map according to the given value and ttl
func buildValues(value string, ttl uint64) url.Values {
v := url.Values{}
if value != "" {
v.Set("value", value)
}
if ttl > 0 {
v.Set("ttl", fmt.Sprintf("%v", ttl))
}
return v
}
// convert key string to http path exclude version, including URL escaping
// for example: key[foo] -> path[keys/foo]
// key[/%z] -> path[keys/%25z]
// key[/] -> path[keys/]
func keyToPath(key string) string {
// URL-escape our key, except for slashes
p := strings.Replace(url.QueryEscape(path.Join("keys", key)), "%2F", "/", -1)
// corner case: if key is "/" or "//" ect
// path join will clear the tailing "/"
// we need to add it back
if p == "keys" {
p = "keys/"
}
return p
}

View file

@ -0,0 +1,89 @@
package etcd
import (
"encoding/json"
"net/http"
"strconv"
"time"
)
const (
rawResponse = iota
normalResponse
)
type responseType int
type RawResponse struct {
StatusCode int
Body []byte
Header http.Header
}
var (
validHttpStatusCode = map[int]bool{
http.StatusCreated: true,
http.StatusOK: true,
http.StatusBadRequest: true,
http.StatusNotFound: true,
http.StatusPreconditionFailed: true,
http.StatusForbidden: true,
}
)
// Unmarshal parses RawResponse and stores the result in Response
func (rr *RawResponse) Unmarshal() (*Response, error) {
if rr.StatusCode != http.StatusOK && rr.StatusCode != http.StatusCreated {
return nil, handleError(rr.Body)
}
resp := new(Response)
err := json.Unmarshal(rr.Body, resp)
if err != nil {
return nil, err
}
// attach index and term to response
resp.EtcdIndex, _ = strconv.ParseUint(rr.Header.Get("X-Etcd-Index"), 10, 64)
resp.RaftIndex, _ = strconv.ParseUint(rr.Header.Get("X-Raft-Index"), 10, 64)
resp.RaftTerm, _ = strconv.ParseUint(rr.Header.Get("X-Raft-Term"), 10, 64)
return resp, nil
}
type Response struct {
Action string `json:"action"`
Node *Node `json:"node"`
PrevNode *Node `json:"prevNode,omitempty"`
EtcdIndex uint64 `json:"etcdIndex"`
RaftIndex uint64 `json:"raftIndex"`
RaftTerm uint64 `json:"raftTerm"`
}
type Node struct {
Key string `json:"key, omitempty"`
Value string `json:"value,omitempty"`
Dir bool `json:"dir,omitempty"`
Expiration *time.Time `json:"expiration,omitempty"`
TTL int64 `json:"ttl,omitempty"`
Nodes Nodes `json:"nodes,omitempty"`
ModifiedIndex uint64 `json:"modifiedIndex,omitempty"`
CreatedIndex uint64 `json:"createdIndex,omitempty"`
}
type Nodes []*Node
// interfaces for sorting
func (ns Nodes) Len() int {
return len(ns)
}
func (ns Nodes) Less(i, j int) bool {
return ns[i].Key < ns[j].Key
}
func (ns Nodes) Swap(i, j int) {
ns[i], ns[j] = ns[j], ns[i]
}

View file

@ -0,0 +1,137 @@
package etcd
// Set sets the given key to the given value.
// It will create a new key value pair or replace the old one.
// It will not replace a existing directory.
func (c *Client) Set(key string, value string, ttl uint64) (*Response, error) {
raw, err := c.RawSet(key, value, ttl)
if err != nil {
return nil, err
}
return raw.Unmarshal()
}
// SetDir sets the given key to a directory.
// It will create a new directory or replace the old key value pair by a directory.
// It will not replace a existing directory.
func (c *Client) SetDir(key string, ttl uint64) (*Response, error) {
raw, err := c.RawSetDir(key, ttl)
if err != nil {
return nil, err
}
return raw.Unmarshal()
}
// CreateDir creates a directory. It succeeds only if
// the given key does not yet exist.
func (c *Client) CreateDir(key string, ttl uint64) (*Response, error) {
raw, err := c.RawCreateDir(key, ttl)
if err != nil {
return nil, err
}
return raw.Unmarshal()
}
// UpdateDir updates the given directory. It succeeds only if the
// given key already exists.
func (c *Client) UpdateDir(key string, ttl uint64) (*Response, error) {
raw, err := c.RawUpdateDir(key, ttl)
if err != nil {
return nil, err
}
return raw.Unmarshal()
}
// Create creates a file with the given value under the given key. It succeeds
// only if the given key does not yet exist.
func (c *Client) Create(key string, value string, ttl uint64) (*Response, error) {
raw, err := c.RawCreate(key, value, ttl)
if err != nil {
return nil, err
}
return raw.Unmarshal()
}
// CreateInOrder creates a file with a key that's guaranteed to be higher than other
// keys in the given directory. It is useful for creating queues.
func (c *Client) CreateInOrder(dir string, value string, ttl uint64) (*Response, error) {
raw, err := c.RawCreateInOrder(dir, value, ttl)
if err != nil {
return nil, err
}
return raw.Unmarshal()
}
// Update updates the given key to the given value. It succeeds only if the
// given key already exists.
func (c *Client) Update(key string, value string, ttl uint64) (*Response, error) {
raw, err := c.RawUpdate(key, value, ttl)
if err != nil {
return nil, err
}
return raw.Unmarshal()
}
func (c *Client) RawUpdateDir(key string, ttl uint64) (*RawResponse, error) {
ops := Options{
"prevExist": true,
"dir": true,
}
return c.put(key, "", ttl, ops)
}
func (c *Client) RawCreateDir(key string, ttl uint64) (*RawResponse, error) {
ops := Options{
"prevExist": false,
"dir": true,
}
return c.put(key, "", ttl, ops)
}
func (c *Client) RawSet(key string, value string, ttl uint64) (*RawResponse, error) {
return c.put(key, value, ttl, nil)
}
func (c *Client) RawSetDir(key string, ttl uint64) (*RawResponse, error) {
ops := Options{
"dir": true,
}
return c.put(key, "", ttl, ops)
}
func (c *Client) RawUpdate(key string, value string, ttl uint64) (*RawResponse, error) {
ops := Options{
"prevExist": true,
}
return c.put(key, value, ttl, ops)
}
func (c *Client) RawCreate(key string, value string, ttl uint64) (*RawResponse, error) {
ops := Options{
"prevExist": false,
}
return c.put(key, value, ttl, ops)
}
func (c *Client) RawCreateInOrder(dir string, value string, ttl uint64) (*RawResponse, error) {
return c.post(dir, value, ttl)
}

View file

@ -0,0 +1,6 @@
package etcd
const (
version = "v2"
packageVersion = "v2.0.0"
)

View file

@ -0,0 +1,103 @@
package etcd
import (
"errors"
)
// Errors introduced by the Watch command.
var (
ErrWatchStoppedByUser = errors.New("Watch stopped by the user via stop channel")
)
// If recursive is set to true the watch returns the first change under the given
// prefix since the given index.
//
// If recursive is set to false the watch returns the first change to the given key
// since the given index.
//
// To watch for the latest change, set waitIndex = 0.
//
// If a receiver channel is given, it will be a long-term watch. Watch will block at the
//channel. After someone receives the channel, it will go on to watch that
// prefix. If a stop channel is given, the client can close long-term watch using
// the stop channel.
func (c *Client) Watch(prefix string, waitIndex uint64, recursive bool,
receiver chan *Response, stop chan bool) (*Response, error) {
logger.Debugf("watch %s [%s]", prefix, c.cluster.Leader)
if receiver == nil {
raw, err := c.watchOnce(prefix, waitIndex, recursive, stop)
if err != nil {
return nil, err
}
return raw.Unmarshal()
}
defer close(receiver)
for {
raw, err := c.watchOnce(prefix, waitIndex, recursive, stop)
if err != nil {
return nil, err
}
resp, err := raw.Unmarshal()
if err != nil {
return nil, err
}
waitIndex = resp.Node.ModifiedIndex + 1
receiver <- resp
}
}
func (c *Client) RawWatch(prefix string, waitIndex uint64, recursive bool,
receiver chan *RawResponse, stop chan bool) (*RawResponse, error) {
logger.Debugf("rawWatch %s [%s]", prefix, c.cluster.Leader)
if receiver == nil {
return c.watchOnce(prefix, waitIndex, recursive, stop)
}
for {
raw, err := c.watchOnce(prefix, waitIndex, recursive, stop)
if err != nil {
return nil, err
}
resp, err := raw.Unmarshal()
if err != nil {
return nil, err
}
waitIndex = resp.Node.ModifiedIndex + 1
receiver <- raw
}
}
// helper func
// return when there is change under the given prefix
func (c *Client) watchOnce(key string, waitIndex uint64, recursive bool, stop chan bool) (*RawResponse, error) {
options := Options{
"wait": true,
}
if waitIndex > 0 {
options["waitIndex"] = waitIndex
}
if recursive {
options["recursive"] = true
}
resp, err := c.getCancelable(key, options, stop)
if err == ErrRequestCancelled {
return nil, ErrWatchStoppedByUser
}
return resp, err
}

View file

@ -0,0 +1,34 @@
language: go
go:
- 1.3
# - 1.4
# see https://github.com/moovweb/gvm/pull/116 for why Go 1.4 is currently disabled
# let us have speedy Docker-based Travis workers
sudo: false
before_install:
# Symlink below is needed for Travis CI to work correctly on personal forks of libkv
- ln -s $HOME/gopath/src/github.com/${TRAVIS_REPO_SLUG///libkv/} $HOME/gopath/src/github.com/docker
- go get code.google.com/p/go.tools/cmd/vet
- go get code.google.com/p/go.tools/cmd/cover
- go get github.com/mattn/goveralls
- go get github.com/golang/lint/golint
- go get github.com/GeertJohan/fgt
before_script:
- script/travis_consul.sh 0.5.2
- script/travis_etcd.sh 2.0.11
- script/travis_zk.sh 3.4.6
script:
- ./consul agent -server -bootstrap-expect 1 -data-dir /tmp/consul -config-file=./config.json 1>/dev/null &
- ./etcd/etcd --listen-client-urls 'http://0.0.0.0:4001' --advertise-client-urls 'http://127.0.0.1:4001' >/dev/null 2>&1 &
- ./zk/bin/zkServer.sh start ./zk/conf/zoo.cfg 1> /dev/null
- script/validate-gofmt
- go vet ./...
- fgt golint ./...
- go test -v -race ./...
- script/coverage
- goveralls -service=travis-ci -coverprofile=goverage.report

View file

@ -0,0 +1,191 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
Copyright 2014-2015 Docker, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View file

@ -0,0 +1,108 @@
# libkv
[![GoDoc](https://godoc.org/github.com/docker/libkv?status.png)](https://godoc.org/github.com/docker/libkv)
[![Build Status](https://travis-ci.org/docker/libkv.svg?branch=master)](https://travis-ci.org/docker/libkv)
[![Coverage Status](https://coveralls.io/repos/docker/libkv/badge.svg)](https://coveralls.io/r/docker/libkv)
`libkv` provides a `Go` native library to store metadata.
The goal of `libkv` is to abstract common store operations for multiple Key/Value backends and offer the same experience no matter which one of the backend you want to use.
For example, you can use it to store your metadata or for service discovery to register machines and endpoints inside your cluster.
You can also easily implement a generic *Leader Election* on top of it (see the [swarm/leadership](https://github.com/docker/swarm/tree/master/leadership) package).
As of now, `libkv` offers support for `Consul`, `Etcd` and `Zookeeper`.
## Example of usage
### Create a new store and use Put/Get
```go
package main
import (
"fmt"
"time"
"github.com/docker/libkv"
"github.com/docker/libkv/store"
log "github.com/Sirupsen/logrus"
)
func main() {
client := "localhost:8500"
// Initialize a new store with consul
kv, err = libkv.NewStore(
store.CONSUL, // or "consul"
[]string{client},
&store.Config{
ConnectionTimeout: 10*time.Second,
},
)
if err != nil {
log.Fatal("Cannot create store consul")
}
key := "foo"
err = kv.Put(key, []byte("bar"), nil)
if err != nil {
log.Error("Error trying to put value at key `", key, "`")
}
pair, err := kv.Get(key)
if err != nil {
log.Error("Error trying accessing value at key `", key, "`")
}
log.Info("value: ", string(pair.Value))
}
```
You can find other usage examples for `libkv` under the `docker/swarm` or `docker/libnetwork` repositories.
## Details
You should expect the same experience for basic operations like `Get`/`Put`, etc.
However calls like `WatchTree` may return different events (or number of events) depending on the backend (for now, `Etcd` and `Consul` will likely return more events than `Zookeeper` that you should triage properly).
## Create a new storage backend
A new **storage backend** should include those calls:
```go
type Store interface {
Put(key string, value []byte, options *WriteOptions) error
Get(key string) (*KVPair, error)
Delete(key string) error
Exists(key string) (bool, error)
Watch(key string, stopCh <-chan struct{}) (<-chan *KVPair, error)
WatchTree(directory string, stopCh <-chan struct{}) (<-chan []*KVPair, error)
NewLock(key string, options *LockOptions) (Locker, error)
List(directory string) ([]*KVPair, error)
DeleteTree(directory string) error
AtomicPut(key string, value []byte, previous *KVPair, options *WriteOptions) (bool, *KVPair, error)
AtomicDelete(key string, previous *KVPair) (bool, error)
Close()
}
```
You can get inspiration from existing backends to create a new one. This interface could be subject to changes to improve the experience of using the library and contributing to a new backend.
##Roadmap
- Make the API nicer to use (using `options`)
- Provide more options (`consistency` for example)
- Improve performance (remove extras `Get`/`List` operations)
- Add more exhaustive tests
- New backends?
##Contributing
Want to hack on libkv? [Docker's contributions guidelines](https://github.com/docker/docker/blob/master/CONTRIBUTING.md) apply.
##Copyright and license
Code and documentation copyright 2015 Docker, inc. Code released under the Apache 2.0 license. Docs released under Creative commons.

View file

@ -0,0 +1,29 @@
package libkv
import (
"github.com/docker/libkv/store"
"github.com/docker/libkv/store/consul"
"github.com/docker/libkv/store/etcd"
"github.com/docker/libkv/store/zookeeper"
)
// Initialize creates a new Store object, initializing the client
type Initialize func(addrs []string, options *store.Config) (store.Store, error)
var (
// Backend initializers
initializers = map[store.Backend]Initialize{
store.CONSUL: consul.New,
store.ETCD: etcd.New,
store.ZK: zookeeper.New,
}
)
// NewStore creates a an instance of store
func NewStore(backend store.Backend, addrs []string, options *store.Config) (store.Store, error) {
if init, exists := initializers[backend]; exists {
return init(addrs, options)
}
return nil, store.ErrNotSupported
}

View file

@ -0,0 +1,416 @@
package consul
import (
"crypto/tls"
"net/http"
"strings"
"sync"
"time"
"github.com/docker/libkv/store"
api "github.com/hashicorp/consul/api"
)
const (
// DefaultWatchWaitTime is how long we block for at a
// time to check if the watched key has changed. This
// affects the minimum time it takes to cancel a watch.
DefaultWatchWaitTime = 15 * time.Second
)
// Consul is the receiver type for the
// Store interface
type Consul struct {
sync.Mutex
config *api.Config
client *api.Client
ephemeralTTL time.Duration
}
type consulLock struct {
lock *api.Lock
}
// New creates a new Consul client given a list
// of endpoints and optional tls config
func New(endpoints []string, options *store.Config) (store.Store, error) {
s := &Consul{}
// Create Consul client
config := api.DefaultConfig()
s.config = config
config.HttpClient = http.DefaultClient
config.Address = endpoints[0]
config.Scheme = "http"
// Set options
if options != nil {
if options.TLS != nil {
s.setTLS(options.TLS)
}
if options.ConnectionTimeout != 0 {
s.setTimeout(options.ConnectionTimeout)
}
if options.EphemeralTTL != 0 {
s.setEphemeralTTL(options.EphemeralTTL)
}
}
// Creates a new client
client, err := api.NewClient(config)
if err != nil {
return nil, err
}
s.client = client
return s, nil
}
// SetTLS sets Consul TLS options
func (s *Consul) setTLS(tls *tls.Config) {
s.config.HttpClient.Transport = &http.Transport{
TLSClientConfig: tls,
}
s.config.Scheme = "https"
}
// SetTimeout sets the timout for connecting to Consul
func (s *Consul) setTimeout(time time.Duration) {
s.config.WaitTime = time
}
// SetEphemeralTTL sets the ttl for ephemeral nodes
func (s *Consul) setEphemeralTTL(ttl time.Duration) {
s.ephemeralTTL = ttl
}
// Normalize the key for usage in Consul
func (s *Consul) normalize(key string) string {
key = store.Normalize(key)
return strings.TrimPrefix(key, "/")
}
func (s *Consul) refreshSession(pair *api.KVPair) error {
// Check if there is any previous session with an active TTL
session, err := s.getActiveSession(pair.Key)
if err != nil {
return err
}
if session == "" {
entry := &api.SessionEntry{
Behavior: api.SessionBehaviorDelete,
TTL: s.ephemeralTTL.String(),
}
// Create the key session
session, _, err = s.client.Session().Create(entry, nil)
if err != nil {
return err
}
}
lockOpts := &api.LockOptions{
Key: pair.Key,
Session: session,
}
// Lock and ignore if lock is held
// It's just a placeholder for the
// ephemeral behavior
lock, _ := s.client.LockOpts(lockOpts)
if lock != nil {
lock.Lock(nil)
}
_, _, err = s.client.Session().Renew(session, nil)
if err != nil {
return s.refreshSession(pair)
}
return nil
}
// getActiveSession checks if the key already has
// a session attached
func (s *Consul) getActiveSession(key string) (string, error) {
pair, _, err := s.client.KV().Get(key, nil)
if err != nil {
return "", err
}
if pair != nil && pair.Session != "" {
return pair.Session, nil
}
return "", nil
}
// Get the value at "key", returns the last modified index
// to use in conjunction to CAS calls
func (s *Consul) Get(key string) (*store.KVPair, error) {
options := &api.QueryOptions{
AllowStale: false,
RequireConsistent: true,
}
pair, meta, err := s.client.KV().Get(s.normalize(key), options)
if err != nil {
return nil, err
}
// If pair is nil then the key does not exist
if pair == nil {
return nil, store.ErrKeyNotFound
}
return &store.KVPair{Key: pair.Key, Value: pair.Value, LastIndex: meta.LastIndex}, nil
}
// Put a value at "key"
func (s *Consul) Put(key string, value []byte, opts *store.WriteOptions) error {
key = s.normalize(key)
p := &api.KVPair{
Key: key,
Value: value,
}
if opts != nil && opts.Ephemeral {
// Create or refresh the session
err := s.refreshSession(p)
if err != nil {
return err
}
}
_, err := s.client.KV().Put(p, nil)
return err
}
// Delete a value at "key"
func (s *Consul) Delete(key string) error {
_, err := s.client.KV().Delete(s.normalize(key), nil)
return err
}
// Exists checks that the key exists inside the store
func (s *Consul) Exists(key string) (bool, error) {
_, err := s.Get(key)
if err != nil && err == store.ErrKeyNotFound {
return false, err
}
return true, nil
}
// List child nodes of a given directory
func (s *Consul) List(directory string) ([]*store.KVPair, error) {
pairs, _, err := s.client.KV().List(s.normalize(directory), nil)
if err != nil {
return nil, err
}
if len(pairs) == 0 {
return nil, store.ErrKeyNotFound
}
kv := []*store.KVPair{}
for _, pair := range pairs {
if pair.Key == directory {
continue
}
kv = append(kv, &store.KVPair{
Key: pair.Key,
Value: pair.Value,
LastIndex: pair.ModifyIndex,
})
}
return kv, nil
}
// DeleteTree deletes a range of keys under a given directory
func (s *Consul) DeleteTree(directory string) error {
_, err := s.client.KV().DeleteTree(s.normalize(directory), nil)
return err
}
// Watch for changes on a "key"
// It returns a channel that will receive changes or pass
// on errors. Upon creation, the current value will first
// be sent to the channel. Providing a non-nil stopCh can
// be used to stop watching.
func (s *Consul) Watch(key string, stopCh <-chan struct{}) (<-chan *store.KVPair, error) {
kv := s.client.KV()
watchCh := make(chan *store.KVPair)
go func() {
defer close(watchCh)
// Use a wait time in order to check if we should quit
// from time to time.
opts := &api.QueryOptions{WaitTime: DefaultWatchWaitTime}
for {
// Check if we should quit
select {
case <-stopCh:
return
default:
}
// Get the key
pair, meta, err := kv.Get(key, opts)
if err != nil {
return
}
// If LastIndex didn't change then it means `Get` returned
// because of the WaitTime and the key didn't changed.
if opts.WaitIndex == meta.LastIndex {
continue
}
opts.WaitIndex = meta.LastIndex
// Return the value to the channel
// FIXME: What happens when a key is deleted?
if pair != nil {
watchCh <- &store.KVPair{
Key: pair.Key,
Value: pair.Value,
LastIndex: pair.ModifyIndex,
}
}
}
}()
return watchCh, nil
}
// WatchTree watches for changes on a "directory"
// It returns a channel that will receive changes or pass
// on errors. Upon creating a watch, the current childs values
// will be sent to the channel .Providing a non-nil stopCh can
// be used to stop watching.
func (s *Consul) WatchTree(directory string, stopCh <-chan struct{}) (<-chan []*store.KVPair, error) {
kv := s.client.KV()
watchCh := make(chan []*store.KVPair)
go func() {
defer close(watchCh)
// Use a wait time in order to check if we should quit
// from time to time.
opts := &api.QueryOptions{WaitTime: DefaultWatchWaitTime}
for {
// Check if we should quit
select {
case <-stopCh:
return
default:
}
// Get all the childrens
pairs, meta, err := kv.List(directory, opts)
if err != nil {
return
}
// If LastIndex didn't change then it means `Get` returned
// because of the WaitTime and the child keys didn't change.
if opts.WaitIndex == meta.LastIndex {
continue
}
opts.WaitIndex = meta.LastIndex
// Return children KV pairs to the channel
kv := []*store.KVPair{}
for _, pair := range pairs {
if pair.Key == directory {
continue
}
kv = append(kv, &store.KVPair{
Key: pair.Key,
Value: pair.Value,
LastIndex: pair.ModifyIndex,
})
}
watchCh <- kv
}
}()
return watchCh, nil
}
// NewLock returns a handle to a lock struct which can
// be used to provide mutual exclusion on a key
func (s *Consul) NewLock(key string, options *store.LockOptions) (store.Locker, error) {
consulOpts := &api.LockOptions{
Key: s.normalize(key),
}
if options != nil {
consulOpts.Value = options.Value
}
l, err := s.client.LockOpts(consulOpts)
if err != nil {
return nil, err
}
return &consulLock{lock: l}, nil
}
// Lock attempts to acquire the lock and blocks while
// doing so. It returns a channel that is closed if our
// lock is lost or if an error occurs
func (l *consulLock) Lock() (<-chan struct{}, error) {
return l.lock.Lock(nil)
}
// Unlock the "key". Calling unlock while
// not holding the lock will throw an error
func (l *consulLock) Unlock() error {
return l.lock.Unlock()
}
// AtomicPut put a value at "key" if the key has not been
// modified in the meantime, throws an error if this is the case
func (s *Consul) AtomicPut(key string, value []byte, previous *store.KVPair, options *store.WriteOptions) (bool, *store.KVPair, error) {
if previous == nil {
return false, nil, store.ErrPreviousNotSpecified
}
p := &api.KVPair{Key: s.normalize(key), Value: value, ModifyIndex: previous.LastIndex}
if work, _, err := s.client.KV().CAS(p, nil); err != nil {
return false, nil, err
} else if !work {
return false, nil, store.ErrKeyModified
}
pair, err := s.Get(key)
if err != nil {
return false, nil, err
}
return true, pair, nil
}
// AtomicDelete deletes a value at "key" if the key has not
// been modified in the meantime, throws an error if this is the case
func (s *Consul) AtomicDelete(key string, previous *store.KVPair) (bool, error) {
if previous == nil {
return false, store.ErrPreviousNotSpecified
}
p := &api.KVPair{Key: s.normalize(key), ModifyIndex: previous.LastIndex}
if work, _, err := s.client.KV().DeleteCAS(p, nil); err != nil {
return false, err
} else if !work {
return false, store.ErrKeyModified
}
return true, nil
}
// Close closes the client connection
func (s *Consul) Close() {
return
}

View file

@ -0,0 +1,478 @@
package etcd
import (
"crypto/tls"
"net"
"net/http"
"strings"
"time"
etcd "github.com/coreos/go-etcd/etcd"
"github.com/docker/libkv/store"
)
// Etcd is the receiver type for the
// Store interface
type Etcd struct {
client *etcd.Client
ephemeralTTL time.Duration
}
type etcdLock struct {
client *etcd.Client
stopLock chan struct{}
key string
value string
last *etcd.Response
ttl uint64
}
const (
periodicSync = 10 * time.Minute
defaultLockTTL = 20 * time.Second
defaultUpdateTime = 5 * time.Second
)
// New creates a new Etcd client given a list
// of endpoints and an optional tls config
func New(addrs []string, options *store.Config) (store.Store, error) {
s := &Etcd{}
entries := store.CreateEndpoints(addrs, "http")
s.client = etcd.NewClient(entries)
// Set options
if options != nil {
if options.TLS != nil {
s.setTLS(options.TLS)
}
if options.ConnectionTimeout != 0 {
s.setTimeout(options.ConnectionTimeout)
}
if options.EphemeralTTL != 0 {
s.setEphemeralTTL(options.EphemeralTTL)
}
}
// Periodic SyncCluster
go func() {
for {
s.client.SyncCluster()
time.Sleep(periodicSync)
}
}()
return s, nil
}
// SetTLS sets the tls configuration given the path
// of certificate files
func (s *Etcd) setTLS(tls *tls.Config) {
// Change to https scheme
var addrs []string
entries := s.client.GetCluster()
for _, entry := range entries {
addrs = append(addrs, strings.Replace(entry, "http", "https", -1))
}
s.client.SetCluster(addrs)
// Set transport
t := http.Transport{
Dial: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial,
TLSHandshakeTimeout: 10 * time.Second,
TLSClientConfig: tls,
}
s.client.SetTransport(&t)
}
// setTimeout sets the timeout used for connecting to the store
func (s *Etcd) setTimeout(time time.Duration) {
s.client.SetDialTimeout(time)
}
// setEphemeralHeartbeat sets the heartbeat value to notify
// that a node is alive
func (s *Etcd) setEphemeralTTL(time time.Duration) {
s.ephemeralTTL = time
}
// createDirectory creates the entire path for a directory
// that does not exist
func (s *Etcd) createDirectory(path string) error {
if _, err := s.client.CreateDir(store.Normalize(path), 10); err != nil {
if etcdError, ok := err.(*etcd.EtcdError); ok {
// Skip key already exists
if etcdError.ErrorCode != 105 {
return err
}
} else {
return err
}
}
return nil
}
// Get the value at "key", returns the last modified index
// to use in conjunction to Atomic calls
func (s *Etcd) Get(key string) (pair *store.KVPair, err error) {
result, err := s.client.Get(store.Normalize(key), false, false)
if err != nil {
if etcdError, ok := err.(*etcd.EtcdError); ok {
// Not a Directory or Not a file
if etcdError.ErrorCode == 102 || etcdError.ErrorCode == 104 {
return nil, store.ErrKeyNotFound
}
}
return nil, err
}
pair = &store.KVPair{
Key: key,
Value: []byte(result.Node.Value),
LastIndex: result.Node.ModifiedIndex,
}
return pair, nil
}
// Put a value at "key"
func (s *Etcd) Put(key string, value []byte, opts *store.WriteOptions) error {
// Default TTL = 0 means no expiration
var ttl uint64
if opts != nil && opts.Ephemeral {
ttl = uint64(s.ephemeralTTL.Seconds())
}
if _, err := s.client.Set(key, string(value), ttl); err != nil {
if etcdError, ok := err.(*etcd.EtcdError); ok {
// Not a directory
if etcdError.ErrorCode == 104 {
// Remove the last element (the actual key)
// and create the full directory path
err = s.createDirectory(store.GetDirectory(key))
if err != nil {
return err
}
// Now that the directory is created, set the key
if _, err := s.client.Set(key, string(value), ttl); err != nil {
return err
}
}
}
return err
}
return nil
}
// Delete a value at "key"
func (s *Etcd) Delete(key string) error {
_, err := s.client.Delete(store.Normalize(key), false)
return err
}
// Exists checks if the key exists inside the store
func (s *Etcd) Exists(key string) (bool, error) {
entry, err := s.Get(key)
if err != nil && entry != nil {
if err == store.ErrKeyNotFound || entry.Value == nil {
return false, nil
}
return false, err
}
return true, nil
}
// Watch for changes on a "key"
// It returns a channel that will receive changes or pass
// on errors. Upon creation, the current value will first
// be sent to the channel. Providing a non-nil stopCh can
// be used to stop watching.
func (s *Etcd) Watch(key string, stopCh <-chan struct{}) (<-chan *store.KVPair, error) {
// Get the current value
current, err := s.Get(key)
if err != nil {
return nil, err
}
// Start an etcd watch.
// Note: etcd will send the current value through the channel.
etcdWatchCh := make(chan *etcd.Response)
etcdStopCh := make(chan bool)
go s.client.Watch(store.Normalize(key), 0, false, etcdWatchCh, etcdStopCh)
// Adapter goroutine: The goal here is to convert whatever
// format etcd is using into our interface.
watchCh := make(chan *store.KVPair)
go func() {
defer close(watchCh)
// Push the current value through the channel.
watchCh <- current
for {
select {
case result := <-etcdWatchCh:
watchCh <- &store.KVPair{
Key: key,
Value: []byte(result.Node.Value),
LastIndex: result.Node.ModifiedIndex,
}
case <-stopCh:
etcdStopCh <- true
return
}
}
}()
return watchCh, nil
}
// WatchTree watches for changes on a "directory"
// It returns a channel that will receive changes or pass
// on errors. Upon creating a watch, the current childs values
// will be sent to the channel .Providing a non-nil stopCh can
// be used to stop watching.
func (s *Etcd) WatchTree(directory string, stopCh <-chan struct{}) (<-chan []*store.KVPair, error) {
// Get child values
current, err := s.List(directory)
if err != nil {
return nil, err
}
// Start the watch
etcdWatchCh := make(chan *etcd.Response)
etcdStopCh := make(chan bool)
go s.client.Watch(store.Normalize(directory), 0, true, etcdWatchCh, etcdStopCh)
// Adapter goroutine: The goal here is to convert whatever
// format etcd is using into our interface.
watchCh := make(chan []*store.KVPair)
go func() {
defer close(watchCh)
// Push the current value through the channel.
watchCh <- current
for {
select {
case <-etcdWatchCh:
// FIXME: We should probably use the value pushed by the channel.
// However, Node.Nodes seems to be empty.
if list, err := s.List(directory); err == nil {
watchCh <- list
}
case <-stopCh:
etcdStopCh <- true
return
}
}
}()
return watchCh, nil
}
// AtomicPut put a value at "key" if the key has not been
// modified in the meantime, throws an error if this is the case
func (s *Etcd) AtomicPut(key string, value []byte, previous *store.KVPair, options *store.WriteOptions) (bool, *store.KVPair, error) {
if previous == nil {
return false, nil, store.ErrPreviousNotSpecified
}
meta, err := s.client.CompareAndSwap(store.Normalize(key), string(value), 0, "", previous.LastIndex)
if err != nil {
if etcdError, ok := err.(*etcd.EtcdError); ok {
// Compare Failed
if etcdError.ErrorCode == 101 {
return false, nil, store.ErrKeyModified
}
}
return false, nil, err
}
updated := &store.KVPair{
Key: key,
Value: value,
LastIndex: meta.Node.ModifiedIndex,
}
return true, updated, nil
}
// AtomicDelete deletes a value at "key" if the key
// has not been modified in the meantime, throws an
// error if this is the case
func (s *Etcd) AtomicDelete(key string, previous *store.KVPair) (bool, error) {
if previous == nil {
return false, store.ErrPreviousNotSpecified
}
_, err := s.client.CompareAndDelete(store.Normalize(key), "", previous.LastIndex)
if err != nil {
if etcdError, ok := err.(*etcd.EtcdError); ok {
// Compare failed
if etcdError.ErrorCode == 101 {
return false, store.ErrKeyModified
}
}
return false, err
}
return true, nil
}
// List child nodes of a given directory
func (s *Etcd) List(directory string) ([]*store.KVPair, error) {
resp, err := s.client.Get(store.Normalize(directory), true, true)
if err != nil {
return nil, err
}
kv := []*store.KVPair{}
for _, n := range resp.Node.Nodes {
key := strings.TrimLeft(n.Key, "/")
kv = append(kv, &store.KVPair{
Key: key,
Value: []byte(n.Value),
LastIndex: n.ModifiedIndex,
})
}
return kv, nil
}
// DeleteTree deletes a range of keys under a given directory
func (s *Etcd) DeleteTree(directory string) error {
_, err := s.client.Delete(store.Normalize(directory), true)
return err
}
// NewLock returns a handle to a lock struct which can
// be used to provide mutual exclusion on a key
func (s *Etcd) NewLock(key string, options *store.LockOptions) (lock store.Locker, err error) {
var value string
ttl := uint64(time.Duration(defaultLockTTL).Seconds())
// Apply options on Lock
if options != nil {
if options.Value != nil {
value = string(options.Value)
}
if options.TTL != 0 {
ttl = uint64(options.TTL.Seconds())
}
}
// Create lock object
lock = &etcdLock{
client: s.client,
key: key,
value: value,
ttl: ttl,
}
return lock, nil
}
// Lock attempts to acquire the lock and blocks while
// doing so. It returns a channel that is closed if our
// lock is lost or if an error occurs
func (l *etcdLock) Lock() (<-chan struct{}, error) {
key := store.Normalize(l.key)
// Lock holder channels
lockHeld := make(chan struct{})
stopLocking := make(chan struct{})
var lastIndex uint64
for {
resp, err := l.client.Create(key, l.value, l.ttl)
if err != nil {
if etcdError, ok := err.(*etcd.EtcdError); ok {
// Key already exists
if etcdError.ErrorCode != 105 {
lastIndex = ^uint64(0)
}
}
} else {
lastIndex = resp.Node.ModifiedIndex
}
_, err = l.client.CompareAndSwap(key, l.value, l.ttl, "", lastIndex)
if err == nil {
// Leader section
l.stopLock = stopLocking
go l.holdLock(key, lockHeld, stopLocking)
break
} else {
// Seeker section
chW := make(chan *etcd.Response)
chWStop := make(chan bool)
l.waitLock(key, chW, chWStop)
// Delete or Expire event occured
// Retry
}
}
return lockHeld, nil
}
// Hold the lock as long as we can
// Updates the key ttl periodically until we receive
// an explicit stop signal from the Unlock method
func (l *etcdLock) holdLock(key string, lockHeld chan struct{}, stopLocking chan struct{}) {
defer close(lockHeld)
update := time.NewTicker(defaultUpdateTime)
defer update.Stop()
var err error
for {
select {
case <-update.C:
l.last, err = l.client.Update(key, l.value, l.ttl)
if err != nil {
return
}
case <-stopLocking:
return
}
}
}
// WaitLock simply waits for the key to be available for creation
func (l *etcdLock) waitLock(key string, eventCh chan *etcd.Response, stopWatchCh chan bool) {
go l.client.Watch(key, 0, false, eventCh, stopWatchCh)
for event := range eventCh {
if event.Action == "delete" || event.Action == "expire" {
return
}
}
}
// Unlock the "key". Calling unlock while
// not holding the lock will throw an error
func (l *etcdLock) Unlock() error {
if l.stopLock != nil {
l.stopLock <- struct{}{}
}
if l.last != nil {
_, err := l.client.CompareAndDelete(store.Normalize(l.key), l.value, l.last.Node.ModifiedIndex)
if err != nil {
return err
}
}
return nil
}
// Close closes the client connection
func (s *Etcd) Close() {
return
}

View file

@ -0,0 +1,47 @@
package store
import (
"strings"
)
// CreateEndpoints creates a list of endpoints given the right scheme
func CreateEndpoints(addrs []string, scheme string) (entries []string) {
for _, addr := range addrs {
entries = append(entries, scheme+"://"+addr)
}
return entries
}
// Normalize the key for each store to the form:
//
// /path/to/key
//
func Normalize(key string) string {
return "/" + join(SplitKey(key))
}
// GetDirectory gets the full directory part of
// the key to the form:
//
// /path/to/
//
func GetDirectory(key string) string {
parts := SplitKey(key)
parts = parts[:len(parts)-1]
return "/" + join(parts)
}
// SplitKey splits the key to extract path informations
func SplitKey(key string) (path []string) {
if strings.Contains(key, "/") {
path = strings.Split(key, "/")
} else {
path = []string{key}
}
return path
}
// join the path parts with '/'
func join(parts []string) string {
return strings.Join(parts, "/")
}

View file

@ -0,0 +1,118 @@
package store
import (
"crypto/tls"
"errors"
"time"
)
// Backend represents a KV Store Backend
type Backend string
const (
// CONSUL backend
CONSUL = "consul"
// ETCD backend
ETCD = "etcd"
// ZK backend
ZK = "zk"
)
var (
// ErrNotSupported is thrown when the backend k/v store is not supported by libkv
ErrNotSupported = errors.New("Backend storage not supported yet, please choose another one")
// ErrNotImplemented is thrown when a method is not implemented by the current backend
ErrNotImplemented = errors.New("Call not implemented in current backend")
// ErrNotReachable is thrown when the API cannot be reached for issuing common store operations
ErrNotReachable = errors.New("Api not reachable")
// ErrCannotLock is thrown when there is an error acquiring a lock on a key
ErrCannotLock = errors.New("Error acquiring the lock")
// ErrKeyModified is thrown during an atomic operation if the index does not match the one in the store
ErrKeyModified = errors.New("Unable to complete atomic operation, key modified")
// ErrKeyNotFound is thrown when the key is not found in the store during a Get operation
ErrKeyNotFound = errors.New("Key not found in store")
// ErrPreviousNotSpecified is thrown when the previous value is not specified for an atomic operation
ErrPreviousNotSpecified = errors.New("Previous K/V pair should be provided for the Atomic operation")
)
// Config contains the options for a storage client
type Config struct {
TLS *tls.Config
ConnectionTimeout time.Duration
EphemeralTTL time.Duration
}
// Store represents the backend K/V storage
// Each store should support every call listed
// here. Or it couldn't be implemented as a K/V
// backend for libkv
type Store interface {
// Put a value at the specified key
Put(key string, value []byte, options *WriteOptions) error
// Get a value given its key
Get(key string) (*KVPair, error)
// Delete the value at the specified key
Delete(key string) error
// Verify if a Key exists in the store
Exists(key string) (bool, error)
// Watch for changes on a key
Watch(key string, stopCh <-chan struct{}) (<-chan *KVPair, error)
// WatchTree watches for changes on child nodes under
// a given a directory
WatchTree(directory string, stopCh <-chan struct{}) (<-chan []*KVPair, error)
// CreateLock for a given key.
// The returned Locker is not held and must be acquired
// with `.Lock`. The Value is optional.
NewLock(key string, options *LockOptions) (Locker, error)
// List the content of a given prefix
List(directory string) ([]*KVPair, error)
// DeleteTree deletes a range of keys under a given directory
DeleteTree(directory string) error
// Atomic operation on a single value
AtomicPut(key string, value []byte, previous *KVPair, options *WriteOptions) (bool, *KVPair, error)
// Atomic delete of a single value
AtomicDelete(key string, previous *KVPair) (bool, error)
// Close the store connection
Close()
}
// KVPair represents {Key, Value, Lastindex} tuple
type KVPair struct {
Key string
Value []byte
LastIndex uint64
}
// WriteOptions contains optional request parameters
type WriteOptions struct {
Heartbeat time.Duration
Ephemeral bool
}
// LockOptions contains optional request parameters
type LockOptions struct {
Value []byte // Optional, value to associate with the lock
TTL time.Duration // Optional, expiration ttl associated with the lock
}
// WatchCallback is used for watch methods on keys
// and is triggered on key change
type WatchCallback func(entries ...*KVPair)
// Locker provides locking mechanism on top of the store.
// Similar to `sync.Lock` except it may return errors.
type Locker interface {
Lock() (<-chan struct{}, error)
Unlock() error
}

View file

@ -0,0 +1,355 @@
package zookeeper
import (
"strings"
"time"
"github.com/docker/libkv/store"
zk "github.com/samuel/go-zookeeper/zk"
)
const defaultTimeout = 10 * time.Second
// Zookeeper is the receiver type for
// the Store interface
type Zookeeper struct {
timeout time.Duration
client *zk.Conn
}
type zookeeperLock struct {
client *zk.Conn
lock *zk.Lock
key string
value []byte
}
// New creates a new Zookeeper client given a
// list of endpoints and an optional tls config
func New(endpoints []string, options *store.Config) (store.Store, error) {
s := &Zookeeper{}
s.timeout = defaultTimeout
// Set options
if options != nil {
if options.ConnectionTimeout != 0 {
s.setTimeout(options.ConnectionTimeout)
}
}
// Connect to Zookeeper
conn, _, err := zk.Connect(endpoints, s.timeout)
if err != nil {
return nil, err
}
s.client = conn
return s, nil
}
// setTimeout sets the timeout for connecting to Zookeeper
func (s *Zookeeper) setTimeout(time time.Duration) {
s.timeout = time
}
// Get the value at "key", returns the last modified index
// to use in conjunction to Atomic calls
func (s *Zookeeper) Get(key string) (pair *store.KVPair, err error) {
resp, meta, err := s.client.Get(store.Normalize(key))
if err != nil {
return nil, err
}
// If resp is nil, the key does not exist
if resp == nil {
return nil, store.ErrKeyNotFound
}
pair = &store.KVPair{
Key: key,
Value: resp,
LastIndex: uint64(meta.Version),
}
return pair, nil
}
// createFullPath creates the entire path for a directory
// that does not exist
func (s *Zookeeper) createFullPath(path []string, ephemeral bool) error {
for i := 1; i <= len(path); i++ {
newpath := "/" + strings.Join(path[:i], "/")
if i == len(path) && ephemeral {
_, err := s.client.Create(newpath, []byte{1}, zk.FlagEphemeral, zk.WorldACL(zk.PermAll))
return err
}
_, err := s.client.Create(newpath, []byte{1}, 0, zk.WorldACL(zk.PermAll))
if err != nil {
// Skip if node already exists
if err != zk.ErrNodeExists {
return err
}
}
}
return nil
}
// Put a value at "key"
func (s *Zookeeper) Put(key string, value []byte, opts *store.WriteOptions) error {
fkey := store.Normalize(key)
exists, err := s.Exists(key)
if err != nil {
return err
}
if !exists {
if opts != nil && opts.Ephemeral {
s.createFullPath(store.SplitKey(key), opts.Ephemeral)
} else {
s.createFullPath(store.SplitKey(key), false)
}
}
_, err = s.client.Set(fkey, value, -1)
return err
}
// Delete a value at "key"
func (s *Zookeeper) Delete(key string) error {
err := s.client.Delete(store.Normalize(key), -1)
return err
}
// Exists checks if the key exists inside the store
func (s *Zookeeper) Exists(key string) (bool, error) {
exists, _, err := s.client.Exists(store.Normalize(key))
if err != nil {
return false, err
}
return exists, nil
}
// Watch for changes on a "key"
// It returns a channel that will receive changes or pass
// on errors. Upon creation, the current value will first
// be sent to the channel. Providing a non-nil stopCh can
// be used to stop watching.
func (s *Zookeeper) Watch(key string, stopCh <-chan struct{}) (<-chan *store.KVPair, error) {
// Get the key first
pair, err := s.Get(key)
if err != nil {
return nil, err
}
// Catch zk notifications and fire changes into the channel.
watchCh := make(chan *store.KVPair)
go func() {
defer close(watchCh)
// Get returns the current value to the channel prior
// to listening to any event that may occur on that key
watchCh <- pair
for {
_, _, eventCh, err := s.client.GetW(store.Normalize(key))
if err != nil {
return
}
select {
case e := <-eventCh:
if e.Type == zk.EventNodeDataChanged {
if entry, err := s.Get(key); err == nil {
watchCh <- entry
}
}
case <-stopCh:
// There is no way to stop GetW so just quit
return
}
}
}()
return watchCh, nil
}
// WatchTree watches for changes on a "directory"
// It returns a channel that will receive changes or pass
// on errors. Upon creating a watch, the current childs values
// will be sent to the channel .Providing a non-nil stopCh can
// be used to stop watching.
func (s *Zookeeper) WatchTree(directory string, stopCh <-chan struct{}) (<-chan []*store.KVPair, error) {
// List the childrens first
entries, err := s.List(directory)
if err != nil {
return nil, err
}
// Catch zk notifications and fire changes into the channel.
watchCh := make(chan []*store.KVPair)
go func() {
defer close(watchCh)
// List returns the children values to the channel
// prior to listening to any events that may occur
// on those keys
watchCh <- entries
for {
_, _, eventCh, err := s.client.ChildrenW(store.Normalize(directory))
if err != nil {
return
}
select {
case e := <-eventCh:
if e.Type == zk.EventNodeChildrenChanged {
if kv, err := s.List(directory); err == nil {
watchCh <- kv
}
}
case <-stopCh:
// There is no way to stop GetW so just quit
return
}
}
}()
return watchCh, nil
}
// List child nodes of a given directory
func (s *Zookeeper) List(directory string) ([]*store.KVPair, error) {
keys, stat, err := s.client.Children(store.Normalize(directory))
if err != nil {
return nil, err
}
kv := []*store.KVPair{}
// FIXME Costly Get request for each child key..
for _, key := range keys {
pair, err := s.Get(directory + store.Normalize(key))
if err != nil {
return nil, err
}
kv = append(kv, &store.KVPair{
Key: key,
Value: []byte(pair.Value),
LastIndex: uint64(stat.Version),
})
}
return kv, nil
}
// DeleteTree deletes a range of keys under a given directory
func (s *Zookeeper) DeleteTree(directory string) error {
pairs, err := s.List(directory)
if err != nil {
return err
}
var reqs []interface{}
for _, pair := range pairs {
reqs = append(reqs, &zk.DeleteRequest{
Path: store.Normalize(directory + "/" + pair.Key),
Version: -1,
})
}
_, err = s.client.Multi(reqs...)
return err
}
// AtomicPut put a value at "key" if the key has not been
// modified in the meantime, throws an error if this is the case
func (s *Zookeeper) AtomicPut(key string, value []byte, previous *store.KVPair, _ *store.WriteOptions) (bool, *store.KVPair, error) {
if previous == nil {
return false, nil, store.ErrPreviousNotSpecified
}
meta, err := s.client.Set(store.Normalize(key), value, int32(previous.LastIndex))
if err != nil {
// Compare Failed
if err == zk.ErrBadVersion {
return false, nil, store.ErrKeyModified
}
return false, nil, err
}
pair := &store.KVPair{
Key: key,
Value: value,
LastIndex: uint64(meta.Version),
}
return true, pair, nil
}
// AtomicDelete deletes a value at "key" if the key
// has not been modified in the meantime, throws an
// error if this is the case
func (s *Zookeeper) AtomicDelete(key string, previous *store.KVPair) (bool, error) {
if previous == nil {
return false, store.ErrPreviousNotSpecified
}
err := s.client.Delete(store.Normalize(key), int32(previous.LastIndex))
if err != nil {
if err == zk.ErrBadVersion {
return false, store.ErrKeyModified
}
return false, err
}
return true, nil
}
// NewLock returns a handle to a lock struct which can
// be used to provide mutual exclusion on a key
func (s *Zookeeper) NewLock(key string, options *store.LockOptions) (lock store.Locker, err error) {
value := []byte("")
// Apply options
if options != nil {
if options.Value != nil {
value = options.Value
}
}
lock = &zookeeperLock{
client: s.client,
key: store.Normalize(key),
value: value,
lock: zk.NewLock(s.client, store.Normalize(key), zk.WorldACL(zk.PermAll)),
}
return lock, err
}
// Lock attempts to acquire the lock and blocks while
// doing so. It returns a channel that is closed if our
// lock is lost or if an error occurs
func (l *zookeeperLock) Lock() (<-chan struct{}, error) {
err := l.lock.Lock()
if err == nil {
// We hold the lock, we can set our value
// FIXME: The value is left behind
// (problematic for leader election)
_, err = l.client.Set(l.key, l.value, -1)
}
return make(chan struct{}), err
}
// Unlock the "key". Calling unlock while
// not holding the lock will throw an error
func (l *zookeeperLock) Unlock() error {
return l.lock.Unlock()
}
// Close closes the client connection
func (s *Zookeeper) Close() {
s.client.Close()
}

View file

@ -1,4 +1,5 @@
Alessandro Boch <aboch@docker.com> (@aboch)
Alexandr Morozov <lk4d4@docker.com> (@LK4D4)
Arnaud Porterie <arnaud@docker.com> (@icecrime)
Madhu Venugopal <madhu@docker.com> (@mavenugo)
Jana Radhakrishnan <mrjana@docker.com> (@mrjana)
Madhu Venugopal <madhu@docker.com> (@mavenugo)

View file

@ -22,7 +22,7 @@ build: ${build_image}.created
${docker} make build-local
build-local:
$(shell which godep) go build -tags experimental ./...
$(shell which godep) go build -tags libnetwork_discovery ./...
check: ${build_image}.created
${docker} make check-local

View file

@ -48,10 +48,9 @@ There are many networking solutions available to suit a broad range of use-cases
}
// A container can join the endpoint by providing the container ID to the join
// api which returns the sandbox key which can be used to access the sandbox
// created for the container during join.
// api.
// Join acceps Variadic arguments which will be made use of by libnetwork and Drivers
_, err = ep.Join("container1",
err = ep.Join("container1",
libnetwork.JoinOptionHostname("test"),
libnetwork.JoinOptionDomainname("docker.io"))
if err != nil {

View file

@ -0,0 +1,90 @@
package config
import (
"strings"
"github.com/BurntSushi/toml"
)
// Config encapsulates configurations of various Libnetwork components
type Config struct {
Daemon DaemonCfg
Cluster ClusterCfg
Datastore DatastoreCfg
}
// DaemonCfg represents libnetwork core configuration
type DaemonCfg struct {
Debug bool
DefaultNetwork string
DefaultDriver string
}
// ClusterCfg represents cluster configuration
type ClusterCfg struct {
Discovery string
Address string
Heartbeat uint64
}
// DatastoreCfg represents Datastore configuration.
type DatastoreCfg struct {
Embedded bool
Client DatastoreClientCfg
}
// DatastoreClientCfg represents Datastore Client-only mode configuration
type DatastoreClientCfg struct {
Provider string
Address string
}
// ParseConfig parses the libnetwork configuration file
func ParseConfig(tomlCfgFile string) (*Config, error) {
var cfg Config
if _, err := toml.DecodeFile(tomlCfgFile, &cfg); err != nil {
return nil, err
}
return &cfg, nil
}
// Option is a option setter function type used to pass varios configurations
// to the controller
type Option func(c *Config)
// OptionDefaultNetwork function returns an option setter for a default network
func OptionDefaultNetwork(dn string) Option {
return func(c *Config) {
c.Daemon.DefaultNetwork = strings.TrimSpace(dn)
}
}
// OptionDefaultDriver function returns an option setter for default driver
func OptionDefaultDriver(dd string) Option {
return func(c *Config) {
c.Daemon.DefaultDriver = strings.TrimSpace(dd)
}
}
// OptionKVProvider function returns an option setter for kvstore provider
func OptionKVProvider(provider string) Option {
return func(c *Config) {
c.Datastore.Client.Provider = strings.TrimSpace(provider)
}
}
// OptionKVProviderURL function returns an option setter for kvstore url
func OptionKVProviderURL(url string) Option {
return func(c *Config) {
c.Datastore.Client.Address = strings.TrimSpace(url)
}
}
// ProcessOptions processes options and stores it in config
func (c *Config) ProcessOptions(options ...Option) {
for _, opt := range options {
if opt != nil {
opt(c)
}
}
}

View file

@ -0,0 +1,12 @@
title = "LibNetwork Configuration file"
[daemon]
debug = false
[cluster]
discovery = "token://swarm-discovery-token"
Address = "Cluster-wide reachable Host IP"
[datastore]
embedded = false
[datastore.client]
provider = "consul"
Address = "localhost:8500"

View file

@ -3,7 +3,7 @@ Package libnetwork provides the basic functionality and extension points to
create network namespaces and allocate interfaces for containers to use.
// Create a new controller instance
controller, _err := libnetwork.New()
controller, _err := libnetwork.New(nil)
// Select and configure the network driver
networkType := "bridge"
@ -33,10 +33,9 @@ create network namespaces and allocate interfaces for containers to use.
}
// A container can join the endpoint by providing the container ID to the join
// api which returns the sandbox key which can be used to access the sandbox
// created for the container during join.
// api.
// Join acceps Variadic arguments which will be made use of by libnetwork and Drivers
_, err = ep.Join("container1",
err = ep.Join("container1",
libnetwork.JoinOptionHostname("test"),
libnetwork.JoinOptionDomainname("docker.io"))
if err != nil {
@ -46,11 +45,17 @@ create network namespaces and allocate interfaces for containers to use.
package libnetwork
import (
"fmt"
"net"
"sync"
log "github.com/Sirupsen/logrus"
"github.com/docker/docker/pkg/plugins"
"github.com/docker/docker/pkg/stringid"
"github.com/docker/libnetwork/config"
"github.com/docker/libnetwork/datastore"
"github.com/docker/libnetwork/driverapi"
"github.com/docker/libnetwork/hostdiscovery"
"github.com/docker/libnetwork/sandbox"
"github.com/docker/libnetwork/types"
)
@ -61,6 +66,9 @@ type NetworkController interface {
// ConfigureNetworkDriver applies the passed options to the driver instance for the specified network type
ConfigureNetworkDriver(networkType string, options map[string]interface{}) error
// Config method returns the bootup configuration for the controller
Config() config.Config
// Create a new network. The options parameter carries network specific options.
// Labels support will be added in the near future.
NewNetwork(networkType, name string, options ...NetworkOption) (Network, error)
@ -85,11 +93,12 @@ type NetworkController interface {
// When the function returns true, the walk will stop.
type NetworkWalker func(nw Network) bool
type sandboxData struct {
sandbox sandbox.Sandbox
refCnt int
type driverData struct {
driver driverapi.Driver
capability driverapi.Capability
}
type driverTable map[string]*driverData
type networkTable map[types.UUID]*network
type endpointTable map[types.UUID]*endpoint
type sandboxTable map[string]*sandboxData
@ -98,38 +107,88 @@ type controller struct {
networks networkTable
drivers driverTable
sandboxes sandboxTable
cfg *config.Config
store datastore.DataStore
sync.Mutex
}
// New creates a new instance of network controller.
func New() (NetworkController, error) {
func New(cfgOptions ...config.Option) (NetworkController, error) {
var cfg *config.Config
if len(cfgOptions) > 0 {
cfg = &config.Config{}
cfg.ProcessOptions(cfgOptions...)
}
c := &controller{
cfg: cfg,
networks: networkTable{},
sandboxes: sandboxTable{},
drivers: driverTable{}}
if err := initDrivers(c); err != nil {
return nil, err
}
if cfg != nil {
if err := c.initDataStore(); err != nil {
// Failing to initalize datastore is a bad situation to be in.
// But it cannot fail creating the Controller
log.Debugf("Failed to Initialize Datastore due to %v. Operating in non-clustered mode", err)
}
if err := c.initDiscovery(); err != nil {
// Failing to initalize discovery is a bad situation to be in.
// But it cannot fail creating the Controller
log.Debugf("Failed to Initialize Discovery : %v", err)
}
}
return c, nil
}
func (c *controller) validateHostDiscoveryConfig() bool {
if c.cfg == nil || c.cfg.Cluster.Discovery == "" || c.cfg.Cluster.Address == "" {
return false
}
return true
}
func (c *controller) initDiscovery() error {
if c.cfg == nil {
return fmt.Errorf("discovery initialization requires a valid configuration")
}
hostDiscovery := hostdiscovery.NewHostDiscovery()
return hostDiscovery.StartDiscovery(&c.cfg.Cluster, c.hostJoinCallback, c.hostLeaveCallback)
}
func (c *controller) hostJoinCallback(hosts []net.IP) {
}
func (c *controller) hostLeaveCallback(hosts []net.IP) {
}
func (c *controller) Config() config.Config {
c.Lock()
defer c.Unlock()
return *c.cfg
}
func (c *controller) ConfigureNetworkDriver(networkType string, options map[string]interface{}) error {
c.Lock()
d, ok := c.drivers[networkType]
dd, ok := c.drivers[networkType]
c.Unlock()
if !ok {
return NetworkTypeError(networkType)
}
return d.Config(options)
return dd.driver.Config(options)
}
func (c *controller) RegisterDriver(networkType string, driver driverapi.Driver) error {
func (c *controller) RegisterDriver(networkType string, driver driverapi.Driver, capability driverapi.Capability) error {
c.Lock()
defer c.Unlock()
if _, ok := c.drivers[networkType]; ok {
return driverapi.ErrActiveRegistration(networkType)
}
c.drivers[networkType] = driver
c.drivers[networkType] = &driverData{driver, capability}
return nil
}
@ -139,18 +198,6 @@ func (c *controller) NewNetwork(networkType, name string, options ...NetworkOpti
if name == "" {
return nil, ErrInvalidName(name)
}
// Check if a driver for the specified network type is available
c.Lock()
d, ok := c.drivers[networkType]
c.Unlock()
if !ok {
var err error
d, err = c.loadDriver(networkType)
if err != nil {
return nil, err
}
}
// Check if a network already exists with the specified network name
c.Lock()
for _, n := range c.networks {
@ -163,27 +210,60 @@ func (c *controller) NewNetwork(networkType, name string, options ...NetworkOpti
// Construct the network object
network := &network{
name: name,
id: types.UUID(stringid.GenerateRandomID()),
ctrlr: c,
driver: d,
endpoints: endpointTable{},
name: name,
networkType: networkType,
id: types.UUID(stringid.GenerateRandomID()),
ctrlr: c,
endpoints: endpointTable{},
}
network.processOptions(options...)
// Create the network
if err := d.CreateNetwork(network.id, network.generic); err != nil {
if err := c.addNetwork(network); err != nil {
return nil, err
}
// Store the network handler in controller
c.Lock()
c.networks[network.id] = network
c.Unlock()
if err := c.updateNetworkToStore(network); err != nil {
if e := network.Delete(); e != nil {
log.Warnf("couldnt cleanup network %s: %v", network.name, err)
}
return nil, err
}
return network, nil
}
func (c *controller) addNetwork(n *network) error {
c.Lock()
// Check if a driver for the specified network type is available
dd, ok := c.drivers[n.networkType]
c.Unlock()
if !ok {
var err error
dd, err = c.loadDriver(n.networkType)
if err != nil {
return err
}
}
n.Lock()
n.driver = dd.driver
d := n.driver
n.Unlock()
// Create the network
if err := d.CreateNetwork(n.id, n.generic); err != nil {
return err
}
c.Lock()
c.networks[n.id] = n
c.Unlock()
return nil
}
func (c *controller) Networks() []Network {
c.Lock()
defer c.Unlock()
@ -239,52 +319,7 @@ func (c *controller) NetworkByID(id string) (Network, error) {
return nil, ErrNoSuchNetwork(id)
}
func (c *controller) sandboxAdd(key string, create bool) (sandbox.Sandbox, error) {
c.Lock()
defer c.Unlock()
sData, ok := c.sandboxes[key]
if !ok {
sb, err := sandbox.NewSandbox(key, create)
if err != nil {
return nil, err
}
sData = &sandboxData{sandbox: sb, refCnt: 1}
c.sandboxes[key] = sData
return sData.sandbox, nil
}
sData.refCnt++
return sData.sandbox, nil
}
func (c *controller) sandboxRm(key string) {
c.Lock()
defer c.Unlock()
sData := c.sandboxes[key]
sData.refCnt--
if sData.refCnt == 0 {
sData.sandbox.Destroy()
delete(c.sandboxes, key)
}
}
func (c *controller) sandboxGet(key string) sandbox.Sandbox {
c.Lock()
defer c.Unlock()
sData, ok := c.sandboxes[key]
if !ok {
return nil
}
return sData.sandbox
}
func (c *controller) loadDriver(networkType string) (driverapi.Driver, error) {
func (c *controller) loadDriver(networkType string) (*driverData, error) {
// Plugins pkg performs lazy loading of plugins that acts as remote drivers.
// As per the design, this Get call will result in remote driver discovery if there is a corresponding plugin available.
_, err := plugins.Get(networkType, driverapi.NetworkPluginEndpointType)
@ -296,11 +331,24 @@ func (c *controller) loadDriver(networkType string) (driverapi.Driver, error) {
}
c.Lock()
defer c.Unlock()
d, ok := c.drivers[networkType]
dd, ok := c.drivers[networkType]
if !ok {
return nil, ErrInvalidNetworkDriver(networkType)
}
return d, nil
return dd, nil
}
func (c *controller) isDriverGlobalScoped(networkType string) (bool, error) {
c.Lock()
dd, ok := c.drivers[networkType]
c.Unlock()
if !ok {
return false, types.NotFoundErrorf("driver not found for %s", networkType)
}
if dd.capability.Scope == driverapi.GlobalScope {
return true, nil
}
return false, nil
}
func (c *controller) GC() {

View file

@ -0,0 +1,174 @@
package datastore
import (
"encoding/json"
"reflect"
"strings"
"github.com/docker/libkv"
"github.com/docker/libkv/store"
"github.com/docker/libnetwork/config"
"github.com/docker/libnetwork/types"
)
//DataStore exported
type DataStore interface {
// GetObject gets data from datastore and unmarshals to the specified object
GetObject(key string, o interface{}) error
// PutObject adds a new Record based on an object into the datastore
PutObject(kvObject KV) error
// PutObjectAtomic provides an atomic add and update operation for a Record
PutObjectAtomic(kvObject KV) error
// DeleteObject deletes a record
DeleteObject(kvObject KV) error
// DeleteObjectAtomic performs an atomic delete operation
DeleteObjectAtomic(kvObject KV) error
// DeleteTree deletes a record
DeleteTree(kvObject KV) error
// KVStore returns access to the KV Store
KVStore() store.Store
}
// ErrKeyModified is raised for an atomic update when the update is working on a stale state
var ErrKeyModified = store.ErrKeyModified
type datastore struct {
store store.Store
}
//KV Key Value interface used by objects to be part of the DataStore
type KV interface {
// Key method lets an object to provide the Key to be used in KV Store
Key() []string
// KeyPrefix method lets an object to return immediate parent key that can be used for tree walk
KeyPrefix() []string
// Value method lets an object to marshal its content to be stored in the KV store
Value() []byte
// Index method returns the latest DB Index as seen by the object
Index() uint64
// SetIndex method allows the datastore to store the latest DB Index into the object
SetIndex(uint64)
}
const (
// NetworkKeyPrefix is the prefix for network key in the kv store
NetworkKeyPrefix = "network"
// EndpointKeyPrefix is the prefix for endpoint key in the kv store
EndpointKeyPrefix = "endpoint"
)
var rootChain = []string{"docker", "libnetwork"}
//Key provides convenient method to create a Key
func Key(key ...string) string {
keychain := append(rootChain, key...)
str := strings.Join(keychain, "/")
return str + "/"
}
//ParseKey provides convenient method to unpack the key to complement the Key function
func ParseKey(key string) ([]string, error) {
chain := strings.Split(strings.Trim(key, "/"), "/")
// The key must atleast be equal to the rootChain in order to be considered as valid
if len(chain) <= len(rootChain) || !reflect.DeepEqual(chain[0:len(rootChain)], rootChain) {
return nil, types.BadRequestErrorf("invalid Key : %s", key)
}
return chain[len(rootChain):], nil
}
// newClient used to connect to KV Store
func newClient(kv string, addrs string) (DataStore, error) {
store, err := libkv.NewStore(store.Backend(kv), []string{addrs}, &store.Config{})
if err != nil {
return nil, err
}
ds := &datastore{store: store}
return ds, nil
}
// NewDataStore creates a new instance of LibKV data store
func NewDataStore(cfg *config.DatastoreCfg) (DataStore, error) {
if cfg == nil {
return nil, types.BadRequestErrorf("invalid configuration passed to datastore")
}
// TODO : cfg.Embedded case
return newClient(cfg.Client.Provider, cfg.Client.Address)
}
// NewCustomDataStore can be used by clients to plugin cusom datatore that adhers to store.Store
func NewCustomDataStore(customStore store.Store) DataStore {
return &datastore{store: customStore}
}
func (ds *datastore) KVStore() store.Store {
return ds.store
}
// PutObjectAtomic adds a new Record based on an object into the datastore
func (ds *datastore) PutObjectAtomic(kvObject KV) error {
if kvObject == nil {
return types.BadRequestErrorf("invalid KV Object : nil")
}
kvObjValue := kvObject.Value()
if kvObjValue == nil {
return types.BadRequestErrorf("invalid KV Object with a nil Value for key %s", Key(kvObject.Key()...))
}
previous := &store.KVPair{Key: Key(kvObject.Key()...), LastIndex: kvObject.Index()}
_, pair, err := ds.store.AtomicPut(Key(kvObject.Key()...), kvObjValue, previous, nil)
if err != nil {
return err
}
kvObject.SetIndex(pair.LastIndex)
return nil
}
// PutObject adds a new Record based on an object into the datastore
func (ds *datastore) PutObject(kvObject KV) error {
if kvObject == nil {
return types.BadRequestErrorf("invalid KV Object : nil")
}
return ds.putObjectWithKey(kvObject, kvObject.Key()...)
}
func (ds *datastore) putObjectWithKey(kvObject KV, key ...string) error {
kvObjValue := kvObject.Value()
if kvObjValue == nil {
return types.BadRequestErrorf("invalid KV Object with a nil Value for key %s", Key(kvObject.Key()...))
}
return ds.store.Put(Key(key...), kvObjValue, nil)
}
// GetObject returns a record matching the key
func (ds *datastore) GetObject(key string, o interface{}) error {
kvPair, err := ds.store.Get(key)
if err != nil {
return err
}
return json.Unmarshal(kvPair.Value, o)
}
// DeleteObject unconditionally deletes a record from the store
func (ds *datastore) DeleteObject(kvObject KV) error {
return ds.store.Delete(Key(kvObject.Key()...))
}
// DeleteObjectAtomic performs atomic delete on a record
func (ds *datastore) DeleteObjectAtomic(kvObject KV) error {
if kvObject == nil {
return types.BadRequestErrorf("invalid KV Object : nil")
}
previous := &store.KVPair{Key: Key(kvObject.Key()...), LastIndex: kvObject.Index()}
_, err := ds.store.AtomicDelete(Key(kvObject.Key()...), previous)
return err
}
// DeleteTree unconditionally deletes a record from the store
func (ds *datastore) DeleteTree(kvObject KV) error {
return ds.store.DeleteTree(Key(kvObject.KeyPrefix()...))
}

View file

@ -0,0 +1,119 @@
package datastore
import (
"errors"
"github.com/docker/libkv/store"
"github.com/docker/libnetwork/types"
)
var (
// ErrNotImplmented exported
ErrNotImplmented = errors.New("Functionality not implemented")
)
// MockData exported
type MockData struct {
Data []byte
Index uint64
}
// MockStore exported
type MockStore struct {
db map[string]*MockData
}
// NewMockStore creates a Map backed Datastore that is useful for mocking
func NewMockStore() *MockStore {
db := make(map[string]*MockData)
return &MockStore{db}
}
// Get the value at "key", returns the last modified index
// to use in conjunction to CAS calls
func (s *MockStore) Get(key string) (*store.KVPair, error) {
mData := s.db[key]
if mData == nil {
return nil, nil
}
return &store.KVPair{Value: mData.Data, LastIndex: mData.Index}, nil
}
// Put a value at "key"
func (s *MockStore) Put(key string, value []byte, options *store.WriteOptions) error {
mData := s.db[key]
if mData == nil {
mData = &MockData{value, 0}
}
mData.Index = mData.Index + 1
s.db[key] = mData
return nil
}
// Delete a value at "key"
func (s *MockStore) Delete(key string) error {
delete(s.db, key)
return nil
}
// Exists checks that the key exists inside the store
func (s *MockStore) Exists(key string) (bool, error) {
_, ok := s.db[key]
return ok, nil
}
// List gets a range of values at "directory"
func (s *MockStore) List(prefix string) ([]*store.KVPair, error) {
return nil, ErrNotImplmented
}
// DeleteTree deletes a range of values at "directory"
func (s *MockStore) DeleteTree(prefix string) error {
delete(s.db, prefix)
return nil
}
// Watch a single key for modifications
func (s *MockStore) Watch(key string, stopCh <-chan struct{}) (<-chan *store.KVPair, error) {
return nil, ErrNotImplmented
}
// WatchTree triggers a watch on a range of values at "directory"
func (s *MockStore) WatchTree(prefix string, stopCh <-chan struct{}) (<-chan []*store.KVPair, error) {
return nil, ErrNotImplmented
}
// NewLock exposed
func (s *MockStore) NewLock(key string, options *store.LockOptions) (store.Locker, error) {
return nil, ErrNotImplmented
}
// AtomicPut put a value at "key" if the key has not been
// modified in the meantime, throws an error if this is the case
func (s *MockStore) AtomicPut(key string, newValue []byte, previous *store.KVPair, options *store.WriteOptions) (bool, *store.KVPair, error) {
mData := s.db[key]
if mData != nil && mData.Index != previous.LastIndex {
return false, nil, types.BadRequestErrorf("atomic put failed due to mismatched Index")
}
err := s.Put(key, newValue, nil)
if err != nil {
return false, nil, err
}
return true, &store.KVPair{Key: key, Value: newValue, LastIndex: s.db[key].Index}, nil
}
// AtomicDelete deletes a value at "key" if the key has not
// been modified in the meantime, throws an error if this is the case
func (s *MockStore) AtomicDelete(key string, previous *store.KVPair) (bool, error) {
mData := s.db[key]
if mData != nil && mData.Index != previous.LastIndex {
return false, types.BadRequestErrorf("atomic delete failed due to mismatched Index")
}
return true, s.Delete(key)
}
// Close closes the client connection
func (s *MockStore) Close() {
return
}

View file

@ -104,6 +104,10 @@ type JoinInfo interface {
// SetGatewayIPv6 sets the default IPv6 gateway when a container joins the endpoint.
SetGatewayIPv6(net.IP) error
// AddStaticRoute adds a routes to the sandbox.
// It may be used in addtion to or instead of a default gateway (as above).
AddStaticRoute(destination *net.IPNet, routeType int, nextHop net.IP, interfaceID int) error
// SetHostsPath sets the overriding /etc/hosts path to use for the container.
SetHostsPath(string) error
@ -114,5 +118,20 @@ type JoinInfo interface {
// DriverCallback provides a Callback interface for Drivers into LibNetwork
type DriverCallback interface {
// RegisterDriver provides a way for Remote drivers to dynamically register new NetworkType and associate with a driver instance
RegisterDriver(name string, driver Driver) error
RegisterDriver(name string, driver Driver, capability Capability) error
}
// Scope indicates the drivers scope capability
type Scope int
const (
// LocalScope represents the driver capable of providing networking services for containers in a single host
LocalScope Scope = iota
// GlobalScope represents the driver capable of providing networking services for containers across hosts
GlobalScope
)
// Capability represents the high level capabilities of the drivers which libnetwork can make use of
type Capability struct {
Scope Scope
}

View file

@ -4,7 +4,7 @@ import (
"errors"
"net"
"os/exec"
"strings"
"strconv"
"sync"
"github.com/Sirupsen/logrus"
@ -14,7 +14,6 @@ import (
"github.com/docker/libnetwork/netutils"
"github.com/docker/libnetwork/options"
"github.com/docker/libnetwork/portmapper"
"github.com/docker/libnetwork/sandbox"
"github.com/docker/libnetwork/types"
"github.com/vishvananda/netlink"
)
@ -30,16 +29,15 @@ const (
var (
ipAllocator *ipallocator.IPAllocator
portMapper *portmapper.PortMapper
)
// Configuration info for the "bridge" driver.
type Configuration struct {
// configuration info for the "bridge" driver.
type configuration struct {
EnableIPForwarding bool
}
// NetworkConfiguration for network specific configuration
type NetworkConfiguration struct {
// networkConfiguration for network specific configuration
type networkConfiguration struct {
BridgeName string
AddressIPv4 *net.IPNet
FixedCIDR *net.IPNet
@ -56,50 +54,53 @@ type NetworkConfiguration struct {
EnableUserlandProxy bool
}
// EndpointConfiguration represents the user specified configuration for the sandbox endpoint
type EndpointConfiguration struct {
// endpointConfiguration represents the user specified configuration for the sandbox endpoint
type endpointConfiguration struct {
MacAddress net.HardwareAddr
PortBindings []types.PortBinding
ExposedPorts []types.TransportPort
}
// ContainerConfiguration represents the user specified configuration for a container
type ContainerConfiguration struct {
// containerConfiguration represents the user specified configuration for a container
type containerConfiguration struct {
ParentEndpoints []string
ChildEndpoints []string
}
type bridgeEndpoint struct {
id types.UUID
intf *sandbox.Interface
srcName string
addr *net.IPNet
addrv6 *net.IPNet
macAddress net.HardwareAddr
config *EndpointConfiguration // User specified parameters
containerConfig *ContainerConfiguration
config *endpointConfiguration // User specified parameters
containerConfig *containerConfiguration
portMapping []types.PortBinding // Operation port bindings
}
type bridgeNetwork struct {
id types.UUID
bridge *bridgeInterface // The bridge's L3 interface
config *NetworkConfiguration
endpoints map[types.UUID]*bridgeEndpoint // key: endpoint id
id types.UUID
bridge *bridgeInterface // The bridge's L3 interface
config *networkConfiguration
endpoints map[types.UUID]*bridgeEndpoint // key: endpoint id
portMapper *portmapper.PortMapper
sync.Mutex
}
type driver struct {
config *Configuration
network *bridgeNetwork
config *configuration
network *bridgeNetwork
networks map[types.UUID]*bridgeNetwork
sync.Mutex
}
func init() {
ipAllocator = ipallocator.New()
portMapper = portmapper.New()
}
// New constructs a new bridge driver
func newDriver() driverapi.Driver {
return &driver{}
return &driver{networks: map[types.UUID]*bridgeNetwork{}}
}
// Init registers a new instance of bridge driver
@ -109,13 +110,15 @@ func Init(dc driverapi.DriverCallback) error {
if out, err := exec.Command("modprobe", "-va", "bridge", "nf_nat", "br_netfilter").Output(); err != nil {
logrus.Warnf("Running modprobe bridge nf_nat failed with message: %s, error: %v", out, err)
}
return dc.RegisterDriver(networkType, newDriver())
c := driverapi.Capability{
Scope: driverapi.LocalScope,
}
return dc.RegisterDriver(networkType, newDriver(), c)
}
// Validate performs a static validation on the network configuration parameters.
// Whatever can be assessed a priori before attempting any programming.
func (c *NetworkConfiguration) Validate() error {
func (c *networkConfiguration) Validate() error {
if c.Mtu < 0 {
return ErrInvalidMtu(c.Mtu)
}
@ -153,6 +156,167 @@ func (c *NetworkConfiguration) Validate() error {
return nil
}
// Conflicts check if two NetworkConfiguration objects overlap
func (c *networkConfiguration) Conflicts(o *networkConfiguration) bool {
if o == nil {
return false
}
// Also empty, becasue only one network with empty name is allowed
if c.BridgeName == o.BridgeName {
return true
}
// They must be in different subnets
if (c.AddressIPv4 != nil && o.AddressIPv4 != nil) &&
(c.AddressIPv4.Contains(o.AddressIPv4.IP) || o.AddressIPv4.Contains(c.AddressIPv4.IP)) {
return true
}
return false
}
// fromMap retrieve the configuration data from the map form.
func (c *networkConfiguration) fromMap(data map[string]interface{}) error {
var err error
if i, ok := data["BridgeName"]; ok && i != nil {
if c.BridgeName, ok = i.(string); !ok {
return types.BadRequestErrorf("invalid type for BridgeName value")
}
}
if i, ok := data["Mtu"]; ok && i != nil {
if s, ok := i.(string); ok {
if c.Mtu, err = strconv.Atoi(s); err != nil {
return types.BadRequestErrorf("failed to parse Mtu value: %s", err.Error())
}
} else {
return types.BadRequestErrorf("invalid type for Mtu value")
}
}
if i, ok := data["EnableIPv6"]; ok && i != nil {
if s, ok := i.(string); ok {
if c.EnableIPv6, err = strconv.ParseBool(s); err != nil {
return types.BadRequestErrorf("failed to parse EnableIPv6 value: %s", err.Error())
}
} else {
return types.BadRequestErrorf("invalid type for EnableIPv6 value")
}
}
if i, ok := data["EnableIPTables"]; ok && i != nil {
if s, ok := i.(string); ok {
if c.EnableIPTables, err = strconv.ParseBool(s); err != nil {
return types.BadRequestErrorf("failed to parse EnableIPTables value: %s", err.Error())
}
} else {
return types.BadRequestErrorf("invalid type for EnableIPTables value")
}
}
if i, ok := data["EnableIPMasquerade"]; ok && i != nil {
if s, ok := i.(string); ok {
if c.EnableIPMasquerade, err = strconv.ParseBool(s); err != nil {
return types.BadRequestErrorf("failed to parse EnableIPMasquerade value: %s", err.Error())
}
} else {
return types.BadRequestErrorf("invalid type for EnableIPMasquerade value")
}
}
if i, ok := data["EnableICC"]; ok && i != nil {
if s, ok := i.(string); ok {
if c.EnableICC, err = strconv.ParseBool(s); err != nil {
return types.BadRequestErrorf("failed to parse EnableICC value: %s", err.Error())
}
} else {
return types.BadRequestErrorf("invalid type for EnableICC value")
}
}
if i, ok := data["AllowNonDefaultBridge"]; ok && i != nil {
if s, ok := i.(string); ok {
if c.AllowNonDefaultBridge, err = strconv.ParseBool(s); err != nil {
return types.BadRequestErrorf("failed to parse AllowNonDefaultBridge value: %s", err.Error())
}
} else {
return types.BadRequestErrorf("invalid type for AllowNonDefaultBridge value")
}
}
if i, ok := data["AddressIPv4"]; ok && i != nil {
if s, ok := i.(string); ok {
if ip, nw, e := net.ParseCIDR(s); e == nil {
nw.IP = ip
c.AddressIPv4 = nw
} else {
return types.BadRequestErrorf("failed to parse AddressIPv4 value")
}
} else {
return types.BadRequestErrorf("invalid type for AddressIPv4 value")
}
}
if i, ok := data["FixedCIDR"]; ok && i != nil {
if s, ok := i.(string); ok {
if ip, nw, e := net.ParseCIDR(s); e == nil {
nw.IP = ip
c.FixedCIDR = nw
} else {
return types.BadRequestErrorf("failed to parse FixedCIDR value")
}
} else {
return types.BadRequestErrorf("invalid type for FixedCIDR value")
}
}
if i, ok := data["FixedCIDRv6"]; ok && i != nil {
if s, ok := i.(string); ok {
if ip, nw, e := net.ParseCIDR(s); e == nil {
nw.IP = ip
c.FixedCIDRv6 = nw
} else {
return types.BadRequestErrorf("failed to parse FixedCIDRv6 value")
}
} else {
return types.BadRequestErrorf("invalid type for FixedCIDRv6 value")
}
}
if i, ok := data["DefaultGatewayIPv4"]; ok && i != nil {
if s, ok := i.(string); ok {
if c.DefaultGatewayIPv4 = net.ParseIP(s); c.DefaultGatewayIPv4 == nil {
return types.BadRequestErrorf("failed to parse DefaultGatewayIPv4 value")
}
} else {
return types.BadRequestErrorf("invalid type for DefaultGatewayIPv4 value")
}
}
if i, ok := data["DefaultGatewayIPv6"]; ok && i != nil {
if s, ok := i.(string); ok {
if c.DefaultGatewayIPv6 = net.ParseIP(s); c.DefaultGatewayIPv6 == nil {
return types.BadRequestErrorf("failed to parse DefaultGatewayIPv6 value")
}
} else {
return types.BadRequestErrorf("invalid type for DefaultGatewayIPv6 value")
}
}
if i, ok := data["DefaultBindingIP"]; ok && i != nil {
if s, ok := i.(string); ok {
if c.DefaultBindingIP = net.ParseIP(s); c.DefaultBindingIP == nil {
return types.BadRequestErrorf("failed to parse DefaultBindingIP value")
}
} else {
return types.BadRequestErrorf("invalid type for DefaultBindingIP value")
}
}
return nil
}
func (n *bridgeNetwork) getEndpoint(eid types.UUID) (*bridgeEndpoint, error) {
n.Lock()
defer n.Unlock()
@ -168,8 +332,75 @@ func (n *bridgeNetwork) getEndpoint(eid types.UUID) (*bridgeEndpoint, error) {
return nil, nil
}
// Install/Removes the iptables rules needed to isolate this network
// from each of the other networks
func (n *bridgeNetwork) isolateNetwork(others []*bridgeNetwork, enable bool) error {
n.Lock()
thisV4 := n.bridge.bridgeIPv4
thisV6 := getV6Network(n.config, n.bridge)
n.Unlock()
// Install the rules to isolate this networks against each of the other networks
for _, o := range others {
o.Lock()
otherV4 := o.bridge.bridgeIPv4
otherV6 := getV6Network(o.config, o.bridge)
o.Unlock()
if !types.CompareIPNet(thisV4, otherV4) {
// It's ok to pass a.b.c.d/x, iptables will ignore the host subnet bits
if err := setINC(thisV4.String(), otherV4.String(), enable); err != nil {
return err
}
}
if thisV6 != nil && otherV6 != nil && !types.CompareIPNet(thisV6, otherV6) {
if err := setINC(thisV6.String(), otherV6.String(), enable); err != nil {
return err
}
}
}
return nil
}
// Checks whether this network's configuration for the network with this id conflicts with any of the passed networks
func (c *networkConfiguration) conflictsWithNetworks(id types.UUID, others []*bridgeNetwork) error {
for _, nw := range others {
nw.Lock()
nwID := nw.id
nwConfig := nw.config
nwBridge := nw.bridge
nw.Unlock()
if nwID == id {
continue
}
// Verify the name (which may have been set by newInterface()) does not conflict with
// existing bridge interfaces. Ironically the system chosen name gets stored in the config...
// Basically we are checking if the two original configs were both empty.
if nwConfig.BridgeName == c.BridgeName {
return types.ForbiddenErrorf("conflicts with network %s (%s) by bridge name", nwID, nwConfig.BridgeName)
}
// If this network config specifies the AddressIPv4, we need
// to make sure it does not conflict with any previously allocated
// bridges. This could not be completely caught by the config conflict
// check, because networks which config does not specify the AddressIPv4
// get their address and subnet selected by the driver (see electBridgeIPv4())
if c.AddressIPv4 != nil {
if nwBridge.bridgeIPv4.Contains(c.AddressIPv4.IP) ||
c.AddressIPv4.Contains(nwBridge.bridgeIPv4.IP) {
return types.ForbiddenErrorf("conflicts with network %s (%s) by ip network", nwID, nwConfig.BridgeName)
}
}
}
return nil
}
func (d *driver) Config(option map[string]interface{}) error {
var config *Configuration
var config *configuration
d.Lock()
defer d.Unlock()
@ -182,12 +413,12 @@ func (d *driver) Config(option map[string]interface{}) error {
if ok && genericData != nil {
switch opt := genericData.(type) {
case options.Generic:
opaqueConfig, err := options.GenerateFromModel(opt, &Configuration{})
opaqueConfig, err := options.GenerateFromModel(opt, &configuration{})
if err != nil {
return err
}
config = opaqueConfig.(*Configuration)
case *Configuration:
config = opaqueConfig.(*configuration)
case *configuration:
config = opt
default:
return &ErrInvalidDriverConfig{}
@ -195,7 +426,7 @@ func (d *driver) Config(option map[string]interface{}) error {
d.config = config
} else {
config = &Configuration{}
config = &configuration{}
}
if config.EnableIPForwarding {
@ -206,82 +437,170 @@ func (d *driver) Config(option map[string]interface{}) error {
}
func (d *driver) getNetwork(id types.UUID) (*bridgeNetwork, error) {
// Just a dummy function to return the only network managed by Bridge driver.
// But this API makes the caller code unchanged when we move to support multiple networks.
d.Lock()
defer d.Unlock()
return d.network, nil
}
func parseNetworkOptions(option options.Generic) (*NetworkConfiguration, error) {
var config *NetworkConfiguration
genericData, ok := option[netlabel.GenericData]
if ok && genericData != nil {
switch opt := genericData.(type) {
case options.Generic:
opaqueConfig, err := options.GenerateFromModel(opt, &NetworkConfiguration{})
if err != nil {
return nil, err
}
config = opaqueConfig.(*NetworkConfiguration)
case *NetworkConfiguration:
config = opt
default:
return nil, &ErrInvalidNetworkConfig{}
}
if err := config.Validate(); err != nil {
return nil, err
}
} else {
config = &NetworkConfiguration{}
if id == "" {
return nil, types.BadRequestErrorf("invalid network id: %s", id)
}
if nw, ok := d.networks[id]; ok {
return nw, nil
}
return nil, nil
}
func parseNetworkGenericOptions(data interface{}) (*networkConfiguration, error) {
var (
err error
config *networkConfiguration
)
switch opt := data.(type) {
case *networkConfiguration:
config = opt
case map[string]interface{}:
config = &networkConfiguration{
EnableICC: true,
EnableIPTables: true,
EnableIPMasquerade: true,
}
err = config.fromMap(opt)
case options.Generic:
var opaqueConfig interface{}
if opaqueConfig, err = options.GenerateFromModel(opt, config); err == nil {
config = opaqueConfig.(*networkConfiguration)
}
default:
err = types.BadRequestErrorf("do not recognize network configuration format: %T", opt)
}
return config, err
}
func parseNetworkOptions(option options.Generic) (*networkConfiguration, error) {
var err error
config := &networkConfiguration{}
// Parse generic label first, config will be re-assigned
if genData, ok := option[netlabel.GenericData]; ok && genData != nil {
if config, err = parseNetworkGenericOptions(genData); err != nil {
return nil, err
}
}
// Process well-known labels next
if _, ok := option[netlabel.EnableIPv6]; ok {
config.EnableIPv6 = option[netlabel.EnableIPv6].(bool)
}
// Finally validate the configuration
if err = config.Validate(); err != nil {
return nil, err
}
return config, nil
}
// Returns the non link-local IPv6 subnet for the containers attached to this bridge if found, nil otherwise
func getV6Network(config *networkConfiguration, i *bridgeInterface) *net.IPNet {
if config.FixedCIDRv6 != nil {
return config.FixedCIDRv6
}
if i.bridgeIPv6 != nil && i.bridgeIPv6.IP != nil && !i.bridgeIPv6.IP.IsLinkLocalUnicast() {
return i.bridgeIPv6
}
return nil
}
// Return a slice of networks over which caller can iterate safely
func (d *driver) getNetworks() []*bridgeNetwork {
d.Lock()
defer d.Unlock()
ls := make([]*bridgeNetwork, 0, len(d.networks))
for _, nw := range d.networks {
ls = append(ls, nw)
}
return ls
}
// Create a new network using bridge plugin
func (d *driver) CreateNetwork(id types.UUID, option map[string]interface{}) error {
var err error
// Driver must be configured
d.Lock()
// Sanity checks
if d.network != nil {
d.Lock()
if _, ok := d.networks[id]; ok {
d.Unlock()
return &ErrNetworkExists{}
return types.ForbiddenErrorf("network %s exists", id)
}
d.Unlock()
// Parse and validate the config. It should not conflict with existing networks' config
config, err := parseNetworkOptions(option)
if err != nil {
return err
}
networkList := d.getNetworks()
for _, nw := range networkList {
nw.Lock()
nwConfig := nw.config
nw.Unlock()
if nwConfig.Conflicts(config) {
return types.ForbiddenErrorf("conflicts with network %s (%s)", nw.id, nw.config.BridgeName)
}
}
// Create and set network handler in driver
d.network = &bridgeNetwork{id: id, endpoints: make(map[types.UUID]*bridgeEndpoint)}
network := d.network
network := &bridgeNetwork{
id: id,
endpoints: make(map[types.UUID]*bridgeEndpoint),
config: config,
portMapper: portmapper.New(),
}
d.Lock()
d.networks[id] = network
d.Unlock()
// On failure make sure to reset driver network handler to nil
defer func() {
if err != nil {
d.Lock()
d.network = nil
delete(d.networks, id)
d.Unlock()
}
}()
config, err := parseNetworkOptions(option)
if err != nil {
return err
}
network.config = config
// Create or retrieve the bridge L3 interface
bridgeIface := newInterface(config)
network.bridge = bridgeIface
// Verify the network configuration does not conflict with previously installed
// networks. This step is needed now because driver might have now set the bridge
// name on this config struct. And because we need to check for possible address
// conflicts, so we need to check against operationa lnetworks.
if err := config.conflictsWithNetworks(id, networkList); err != nil {
return err
}
setupNetworkIsolationRules := func(config *networkConfiguration, i *bridgeInterface) error {
defer func() {
if err != nil {
if err := network.isolateNetwork(networkList, false); err != nil {
logrus.Warnf("Failed on removing the inter-network iptables rules on cleanup: %v", err)
}
}
}()
err := network.isolateNetwork(networkList, true)
return err
}
// Prepare the bridge setup configuration
bridgeSetup := newBridgeSetup(config, bridgeIface)
@ -330,13 +649,16 @@ func (d *driver) CreateNetwork(id types.UUID, option map[string]interface{}) err
{!config.EnableUserlandProxy, setupLoopbackAdressesRouting},
// Setup IPTables.
{config.EnableIPTables, setupIPTables},
{config.EnableIPTables, network.setupIPTables},
// Setup DefaultGatewayIPv4
{config.DefaultGatewayIPv4 != nil, setupGatewayIPv4},
// Setup DefaultGatewayIPv6
{config.DefaultGatewayIPv6 != nil, setupGatewayIPv6},
// Add inter-network communication rules.
{config.EnableIPTables, setupNetworkIsolationRules},
} {
if step.Condition {
bridgeSetup.queueStep(step.Fn)
@ -359,8 +681,23 @@ func (d *driver) DeleteNetwork(nid types.UUID) error {
// Get network handler and remove it from driver
d.Lock()
n := d.network
d.network = nil
n, ok := d.networks[nid]
d.Unlock()
if !ok {
return types.InternalMaskableErrorf("network %s does not exist", nid)
}
n.Lock()
config := n.config
n.Unlock()
if config.BridgeName == DefaultBridgeName {
return types.ForbiddenErrorf("default network of type \"%s\" cannot be deleted", networkType)
}
d.Lock()
delete(d.networks, nid)
d.Unlock()
// On failure set network handler back in driver, but
@ -368,8 +705,8 @@ func (d *driver) DeleteNetwork(nid types.UUID) error {
defer func() {
if err != nil {
d.Lock()
if d.network == nil {
d.network = n
if _, ok := d.networks[nid]; !ok {
d.networks[nid] = n
}
d.Unlock()
}
@ -387,6 +724,22 @@ func (d *driver) DeleteNetwork(nid types.UUID) error {
return err
}
// In case of failures after this point, restore the network isolation rules
nwList := d.getNetworks()
defer func() {
if err != nil {
if err := n.isolateNetwork(nwList, true); err != nil {
logrus.Warnf("Failed on restoring the inter-network iptables rules on cleanup: %v", err)
}
}
}()
// Remove inter-network communication rules.
err = n.isolateNetwork(nwList, false)
if err != nil {
return err
}
// Programming
err = netlink.LinkDel(n.bridge.Link)
@ -409,9 +762,12 @@ func (d *driver) CreateEndpoint(nid, eid types.UUID, epInfo driverapi.EndpointIn
// Get the network handler and make sure it exists
d.Lock()
n := d.network
config := n.config
n, ok := d.networks[nid]
d.Unlock()
if !ok {
return types.NotFoundErrorf("network %s does not exist", nid)
}
if n == nil {
return driverapi.ErrNoNetwork(nid)
}
@ -457,13 +813,13 @@ func (d *driver) CreateEndpoint(nid, eid types.UUID, epInfo driverapi.EndpointIn
}()
// Generate a name for what will be the host side pipe interface
name1, err := generateIfaceName()
name1, err := netutils.GenerateIfaceName(vethPrefix, vethLen)
if err != nil {
return err
}
// Generate a name for what will be the sandbox side pipe interface
name2, err := generateIfaceName()
name2, err := netutils.GenerateIfaceName(vethPrefix, vethLen)
if err != nil {
return err
}
@ -498,6 +854,10 @@ func (d *driver) CreateEndpoint(nid, eid types.UUID, epInfo driverapi.EndpointIn
}
}()
n.Lock()
config := n.config
n.Unlock()
// Add bridge inherited attributes to pipe interfaces
if config.Mtu != 0 {
err = netlink.LinkSetMTU(host, config.Mtu)
@ -566,25 +926,20 @@ func (d *driver) CreateEndpoint(nid, eid types.UUID, epInfo driverapi.EndpointIn
}
// Create the sandbox side pipe interface
intf := &sandbox.Interface{}
intf.SrcName = name2
intf.DstName = containerVethPrefix
intf.Address = ipv4Addr
endpoint.srcName = name2
endpoint.addr = ipv4Addr
if config.EnableIPv6 {
intf.AddressIPv6 = ipv6Addr
endpoint.addrv6 = ipv6Addr
}
// Store the interface in endpoint, this is needed for cleanup on DeleteEndpoint()
endpoint.intf = intf
err = epInfo.AddInterface(ifaceID, endpoint.macAddress, *ipv4Addr, *ipv6Addr)
if err != nil {
return err
}
// Program any required port mapping and store them in the endpoint
endpoint.portMapping, err = allocatePorts(epConfig, intf, config.DefaultBindingIP, config.EnableUserlandProxy)
endpoint.portMapping, err = n.allocatePorts(epConfig, endpoint, config.DefaultBindingIP, config.EnableUserlandProxy)
if err != nil {
return err
}
@ -597,9 +952,12 @@ func (d *driver) DeleteEndpoint(nid, eid types.UUID) error {
// Get the network handler and make sure it exists
d.Lock()
n := d.network
config := n.config
n, ok := d.networks[nid]
d.Unlock()
if !ok {
return types.NotFoundErrorf("network %s does not exist", nid)
}
if n == nil {
return driverapi.ErrNoNetwork(nid)
}
@ -639,17 +997,21 @@ func (d *driver) DeleteEndpoint(nid, eid types.UUID) error {
}()
// Remove port mappings. Do not stop endpoint delete on unmap failure
releasePorts(ep)
n.releasePorts(ep)
// Release the v4 address allocated to this endpoint's sandbox interface
err = ipAllocator.ReleaseIP(n.bridge.bridgeIPv4, ep.intf.Address.IP)
err = ipAllocator.ReleaseIP(n.bridge.bridgeIPv4, ep.addr.IP)
if err != nil {
return err
}
n.Lock()
config := n.config
n.Unlock()
// Release the v6 address allocated to this endpoint's sandbox interface
if config.EnableIPv6 {
err := ipAllocator.ReleaseIP(n.bridge.bridgeIPv6, ep.intf.AddressIPv6.IP)
err := ipAllocator.ReleaseIP(n.bridge.bridgeIPv6, ep.addrv6.IP)
if err != nil {
return err
}
@ -657,7 +1019,7 @@ func (d *driver) DeleteEndpoint(nid, eid types.UUID) error {
// Try removal of link. Discard error: link pair might have
// already been deleted by sandbox delete.
link, err := netlink.LinkByName(ep.intf.SrcName)
link, err := netlink.LinkByName(ep.srcName)
if err == nil {
netlink.LinkDel(link)
}
@ -668,8 +1030,11 @@ func (d *driver) DeleteEndpoint(nid, eid types.UUID) error {
func (d *driver) EndpointOperInfo(nid, eid types.UUID) (map[string]interface{}, error) {
// Get the network handler and make sure it exists
d.Lock()
n := d.network
n, ok := d.networks[nid]
d.Unlock()
if !ok {
return nil, types.NotFoundErrorf("network %s does not exist", nid)
}
if n == nil {
return nil, driverapi.ErrNoNetwork(nid)
}
@ -737,7 +1102,7 @@ func (d *driver) Join(nid, eid types.UUID, sboxKey string, jinfo driverapi.JoinI
for _, iNames := range jinfo.InterfaceNames() {
// Make sure to set names on the correct interface ID.
if iNames.ID() == ifaceID {
err = iNames.SetNames(endpoint.intf.SrcName, endpoint.intf.DstName)
err = iNames.SetNames(endpoint.srcName, containerVethPrefix)
if err != nil {
return err
}
@ -786,7 +1151,7 @@ func (d *driver) Leave(nid, eid types.UUID) error {
func (d *driver) link(network *bridgeNetwork, endpoint *bridgeEndpoint, options map[string]interface{}, enable bool) error {
var (
cc *ContainerConfiguration
cc *containerConfiguration
err error
)
@ -815,8 +1180,8 @@ func (d *driver) link(network *bridgeNetwork, endpoint *bridgeEndpoint, options
return err
}
l := newLink(parentEndpoint.intf.Address.IP.String(),
endpoint.intf.Address.IP.String(),
l := newLink(parentEndpoint.addr.IP.String(),
endpoint.addr.IP.String(),
endpoint.config.ExposedPorts, network.config.BridgeName)
if enable {
err = l.Enable()
@ -848,8 +1213,8 @@ func (d *driver) link(network *bridgeNetwork, endpoint *bridgeEndpoint, options
continue
}
l := newLink(endpoint.intf.Address.IP.String(),
childEndpoint.intf.Address.IP.String(),
l := newLink(endpoint.addr.IP.String(),
childEndpoint.addr.IP.String(),
childEndpoint.config.ExposedPorts, network.config.BridgeName)
if enable {
err = l.Enable()
@ -877,12 +1242,12 @@ func (d *driver) Type() string {
return networkType
}
func parseEndpointOptions(epOptions map[string]interface{}) (*EndpointConfiguration, error) {
func parseEndpointOptions(epOptions map[string]interface{}) (*endpointConfiguration, error) {
if epOptions == nil {
return nil, nil
}
ec := &EndpointConfiguration{}
ec := &endpointConfiguration{}
if opt, ok := epOptions[netlabel.MacAddress]; ok {
if mac, ok := opt.(net.HardwareAddr); ok {
@ -911,7 +1276,7 @@ func parseEndpointOptions(epOptions map[string]interface{}) (*EndpointConfigurat
return ec, nil
}
func parseContainerOptions(cOptions map[string]interface{}) (*ContainerConfiguration, error) {
func parseContainerOptions(cOptions map[string]interface{}) (*containerConfiguration, error) {
if cOptions == nil {
return nil, nil
}
@ -921,12 +1286,12 @@ func parseContainerOptions(cOptions map[string]interface{}) (*ContainerConfigura
}
switch opt := genericData.(type) {
case options.Generic:
opaqueConfig, err := options.GenerateFromModel(opt, &ContainerConfiguration{})
opaqueConfig, err := options.GenerateFromModel(opt, &containerConfiguration{})
if err != nil {
return nil, err
}
return opaqueConfig.(*ContainerConfiguration), nil
case *ContainerConfiguration:
return opaqueConfig.(*containerConfiguration), nil
case *containerConfiguration:
return opt, nil
default:
return nil, nil
@ -958,28 +1323,9 @@ func generateMacAddr(ip net.IP) net.HardwareAddr {
return hw
}
func electMacAddress(epConfig *EndpointConfiguration, ip net.IP) net.HardwareAddr {
func electMacAddress(epConfig *endpointConfiguration, ip net.IP) net.HardwareAddr {
if epConfig != nil && epConfig.MacAddress != nil {
return epConfig.MacAddress
}
return generateMacAddr(ip)
}
// Generates a name to be used for a virtual ethernet
// interface. The name is constructed by 'veth' appended
// by a randomly generated hex value. (example: veth0f60e2c)
func generateIfaceName() (string, error) {
for i := 0; i < 3; i++ {
name, err := netutils.GenerateRandomName(vethPrefix, vethLen)
if err != nil {
continue
}
if _, err := net.InterfaceByName(name); err != nil {
if strings.Contains(err.Error(), "no such") {
return name, nil
}
return "", err
}
}
return "", &ErrIfaceName{}
}

View file

@ -22,10 +22,10 @@ type bridgeInterface struct {
}
// newInterface creates a new bridge interface structure. It attempts to find
// an already existing device identified by the Configuration BridgeName field,
// an already existing device identified by the configuration BridgeName field,
// or the default bridge name when unspecified), but doesn't attempt to create
// one when missing
func newInterface(config *NetworkConfiguration) *bridgeInterface {
func newInterface(config *networkConfiguration) *bridgeInterface {
i := &bridgeInterface{}
// Initialize the bridge name to the default if unspecified.

View file

@ -7,7 +7,6 @@ import (
"net"
"github.com/Sirupsen/logrus"
"github.com/docker/libnetwork/sandbox"
"github.com/docker/libnetwork/types"
)
@ -15,7 +14,7 @@ var (
defaultBindingIP = net.IPv4(0, 0, 0, 0)
)
func allocatePorts(epConfig *EndpointConfiguration, intf *sandbox.Interface, reqDefBindIP net.IP, ulPxyEnabled bool) ([]types.PortBinding, error) {
func (n *bridgeNetwork) allocatePorts(epConfig *endpointConfiguration, ep *bridgeEndpoint, reqDefBindIP net.IP, ulPxyEnabled bool) ([]types.PortBinding, error) {
if epConfig == nil || epConfig.PortBindings == nil {
return nil, nil
}
@ -25,16 +24,16 @@ func allocatePorts(epConfig *EndpointConfiguration, intf *sandbox.Interface, req
defHostIP = reqDefBindIP
}
return allocatePortsInternal(epConfig.PortBindings, intf.Address.IP, defHostIP, ulPxyEnabled)
return n.allocatePortsInternal(epConfig.PortBindings, ep.addr.IP, defHostIP, ulPxyEnabled)
}
func allocatePortsInternal(bindings []types.PortBinding, containerIP, defHostIP net.IP, ulPxyEnabled bool) ([]types.PortBinding, error) {
func (n *bridgeNetwork) allocatePortsInternal(bindings []types.PortBinding, containerIP, defHostIP net.IP, ulPxyEnabled bool) ([]types.PortBinding, error) {
bs := make([]types.PortBinding, 0, len(bindings))
for _, c := range bindings {
b := c.GetCopy()
if err := allocatePort(&b, containerIP, defHostIP, ulPxyEnabled); err != nil {
if err := n.allocatePort(&b, containerIP, defHostIP, ulPxyEnabled); err != nil {
// On allocation failure, release previously allocated ports. On cleanup error, just log a warning message
if cuErr := releasePortsInternal(bs); cuErr != nil {
if cuErr := n.releasePortsInternal(bs); cuErr != nil {
logrus.Warnf("Upon allocation failure for %v, failed to clear previously allocated port bindings: %v", b, cuErr)
}
return nil, err
@ -44,7 +43,7 @@ func allocatePortsInternal(bindings []types.PortBinding, containerIP, defHostIP
return bs, nil
}
func allocatePort(bnd *types.PortBinding, containerIP, defHostIP net.IP, ulPxyEnabled bool) error {
func (n *bridgeNetwork) allocatePort(bnd *types.PortBinding, containerIP, defHostIP net.IP, ulPxyEnabled bool) error {
var (
host net.Addr
err error
@ -66,7 +65,7 @@ func allocatePort(bnd *types.PortBinding, containerIP, defHostIP net.IP, ulPxyEn
// Try up to maxAllocatePortAttempts times to get a port that's not already allocated.
for i := 0; i < maxAllocatePortAttempts; i++ {
if host, err = portMapper.Map(container, bnd.HostIP, int(bnd.HostPort), ulPxyEnabled); err == nil {
if host, err = n.portMapper.Map(container, bnd.HostIP, int(bnd.HostPort), ulPxyEnabled); err == nil {
break
}
// There is no point in immediately retrying to map an explicitly chosen port.
@ -94,16 +93,16 @@ func allocatePort(bnd *types.PortBinding, containerIP, defHostIP net.IP, ulPxyEn
}
}
func releasePorts(ep *bridgeEndpoint) error {
return releasePortsInternal(ep.portMapping)
func (n *bridgeNetwork) releasePorts(ep *bridgeEndpoint) error {
return n.releasePortsInternal(ep.portMapping)
}
func releasePortsInternal(bindings []types.PortBinding) error {
func (n *bridgeNetwork) releasePortsInternal(bindings []types.PortBinding) error {
var errorBuf bytes.Buffer
// Attempt to release all port bindings, do not stop on failure
for _, m := range bindings {
if err := releasePort(m); err != nil {
if err := n.releasePort(m); err != nil {
errorBuf.WriteString(fmt.Sprintf("\ncould not release %v because of %v", m, err))
}
}
@ -114,11 +113,11 @@ func releasePortsInternal(bindings []types.PortBinding) error {
return nil
}
func releasePort(bnd types.PortBinding) error {
func (n *bridgeNetwork) releasePort(bnd types.PortBinding) error {
// Construct the host side transport address
host, err := bnd.HostAddr()
if err != nil {
return err
}
return portMapper.Unmap(host)
return n.portMapper.Unmap(host)
}

View file

@ -1,14 +1,14 @@
package bridge
type setupStep func(*NetworkConfiguration, *bridgeInterface) error
type setupStep func(*networkConfiguration, *bridgeInterface) error
type bridgeSetup struct {
config *NetworkConfiguration
config *networkConfiguration
bridge *bridgeInterface
steps []setupStep
}
func newBridgeSetup(c *NetworkConfiguration, i *bridgeInterface) *bridgeSetup {
func newBridgeSetup(c *networkConfiguration, i *bridgeInterface) *bridgeSetup {
return &bridgeSetup{config: c, bridge: i}
}

View file

@ -4,11 +4,12 @@ import (
log "github.com/Sirupsen/logrus"
"github.com/docker/docker/pkg/parsers/kernel"
"github.com/docker/libnetwork/netutils"
"github.com/docker/libnetwork/types"
"github.com/vishvananda/netlink"
)
// SetupDevice create a new bridge interface/
func setupDevice(config *NetworkConfiguration, i *bridgeInterface) error {
func setupDevice(config *networkConfiguration, i *bridgeInterface) error {
// We only attempt to create the bridge when the requested device name is
// the default one.
if config.BridgeName != DefaultBridgeName && !config.AllowNonDefaultBridge {
@ -31,11 +32,14 @@ func setupDevice(config *NetworkConfiguration, i *bridgeInterface) error {
}
// Call out to netlink to create the device.
return netlink.LinkAdd(i.Link)
if err = netlink.LinkAdd(i.Link); err != nil {
return types.InternalErrorf("Failed to program bridge link: %s", err.Error())
}
return nil
}
// SetupDeviceUp ups the given bridge interface.
func setupDeviceUp(config *NetworkConfiguration, i *bridgeInterface) error {
func setupDeviceUp(config *networkConfiguration, i *bridgeInterface) error {
err := netlink.LinkSetUp(i.Link)
if err != nil {
return err

View file

@ -4,7 +4,7 @@ import (
log "github.com/Sirupsen/logrus"
)
func setupFixedCIDRv4(config *NetworkConfiguration, i *bridgeInterface) error {
func setupFixedCIDRv4(config *networkConfiguration, i *bridgeInterface) error {
addrv4, _, err := i.addresses()
if err != nil {
return err

View file

@ -7,7 +7,7 @@ import (
"github.com/vishvananda/netlink"
)
func setupFixedCIDRv6(config *NetworkConfiguration, i *bridgeInterface) error {
func setupFixedCIDRv6(config *networkConfiguration, i *bridgeInterface) error {
log.Debugf("Using IPv6 subnet: %v", config.FixedCIDRv6)
if err := ipAllocator.RegisterSubnet(config.FixedCIDRv6, config.FixedCIDRv6); err != nil {
return &FixedCIDRv6Error{Net: config.FixedCIDRv6, Err: err}

View file

@ -10,7 +10,7 @@ const (
ipv4ForwardConfPerm = 0644
)
func setupIPForwarding(config *Configuration) error {
func setupIPForwarding(config *configuration) error {
// Sanity Check
if config.EnableIPForwarding == false {
return &ErrIPFwdCfg{}

View file

@ -13,7 +13,7 @@ const (
DockerChain = "DOCKER"
)
func setupIPTables(config *NetworkConfiguration, i *bridgeInterface) error {
func (n *bridgeNetwork) setupIPTables(config *networkConfiguration, i *bridgeInterface) error {
// Sanity check.
if config.EnableIPTables == false {
return IPTableCfgError(config.BridgeName)
@ -39,7 +39,7 @@ func setupIPTables(config *NetworkConfiguration, i *bridgeInterface) error {
return fmt.Errorf("Failed to create FILTER chain: %s", err.Error())
}
portMapper.SetIptablesChain(chain)
n.portMapper.SetIptablesChain(chain)
return nil
}
@ -171,3 +171,38 @@ func setIcc(bridgeIface string, iccEnable, insert bool) error {
return nil
}
// Control Inter Network Communication. Install/remove only if it is not/is present.
func setINC(network1, network2 string, enable bool) error {
var (
table = iptables.Filter
chain = "FORWARD"
args = [2][]string{{"-s", network1, "-d", network2, "-j", "DROP"}, {"-s", network2, "-d", network1, "-j", "DROP"}}
)
if enable {
for i := 0; i < 2; i++ {
if iptables.Exists(table, chain, args[i]...) {
continue
}
if output, err := iptables.Raw(append([]string{"-I", chain}, args[i]...)...); err != nil {
return fmt.Errorf("unable to add inter-network communication rule: %s", err.Error())
} else if len(output) != 0 {
return fmt.Errorf("error adding inter-network communication rule: %s", string(output))
}
}
} else {
for i := 0; i < 2; i++ {
if !iptables.Exists(table, chain, args[i]...) {
continue
}
if output, err := iptables.Raw(append([]string{"-D", chain}, args[i]...)...); err != nil {
return fmt.Errorf("unable to remove inter-network communication rule: %s", err.Error())
} else if len(output) != 0 {
return fmt.Errorf("error removing inter-network communication rule: %s", string(output))
}
}
}
return nil
}

View file

@ -4,7 +4,6 @@ import (
"fmt"
"io/ioutil"
"net"
"path/filepath"
log "github.com/Sirupsen/logrus"
@ -20,31 +19,25 @@ func init() {
// In theory this shouldn't matter - in practice there's bound to be a few scripts relying
// on the internal addressing or other stupid things like that.
// They shouldn't, but hey, let's not break them unless we really have to.
for _, addr := range []string{
"172.17.42.1/16", // Don't use 172.16.0.0/16, it conflicts with EC2 DNS 172.16.0.23
"10.0.42.1/16", // Don't even try using the entire /8, that's too intrusive
"10.1.42.1/16",
"10.42.42.1/16",
"172.16.42.1/24",
"172.16.43.1/24",
"172.16.44.1/24",
"10.0.42.1/24",
"10.0.43.1/24",
"192.168.42.1/24",
"192.168.43.1/24",
"192.168.44.1/24",
} {
ip, net, err := net.ParseCIDR(addr)
if err != nil {
log.Errorf("Failed to parse address %s", addr)
continue
}
net.IP = ip.To4()
bridgeNetworks = append(bridgeNetworks, net)
// Don't use 172.16.0.0/16, it conflicts with EC2 DNS 172.16.0.23
// 172.[17-31].42.1/16
mask := []byte{255, 255, 0, 0}
for i := 17; i < 32; i++ {
bridgeNetworks = append(bridgeNetworks, &net.IPNet{IP: []byte{172, byte(i), 42, 1}, Mask: mask})
}
// 10.[0-255].42.1/16
for i := 0; i < 256; i++ {
bridgeNetworks = append(bridgeNetworks, &net.IPNet{IP: []byte{10, byte(i), 42, 1}, Mask: mask})
}
// 192.168.[42-44].1/24
mask[2] = 255
for i := 42; i < 45; i++ {
bridgeNetworks = append(bridgeNetworks, &net.IPNet{IP: []byte{192, 168, byte(i), 1}, Mask: mask})
}
}
func setupBridgeIPv4(config *NetworkConfiguration, i *bridgeInterface) error {
func setupBridgeIPv4(config *networkConfiguration, i *bridgeInterface) error {
addrv4, _, err := i.addresses()
if err != nil {
return err
@ -81,12 +74,12 @@ func setupBridgeIPv4(config *NetworkConfiguration, i *bridgeInterface) error {
return nil
}
func allocateBridgeIP(config *NetworkConfiguration, i *bridgeInterface) error {
func allocateBridgeIP(config *networkConfiguration, i *bridgeInterface) error {
ipAllocator.RequestIP(i.bridgeIPv4, i.bridgeIPv4.IP)
return nil
}
func electBridgeIPv4(config *NetworkConfiguration) (*net.IPNet, error) {
func electBridgeIPv4(config *networkConfiguration) (*net.IPNet, error) {
// Use the requested IPv4 CIDR when available.
if config.AddressIPv4 != nil {
return config.AddressIPv4, nil
@ -112,7 +105,7 @@ func electBridgeIPv4(config *NetworkConfiguration) (*net.IPNet, error) {
return nil, IPv4AddrRangeError(config.BridgeName)
}
func setupGatewayIPv4(config *NetworkConfiguration, i *bridgeInterface) error {
func setupGatewayIPv4(config *networkConfiguration, i *bridgeInterface) error {
if !i.bridgeIPv4.Contains(config.DefaultGatewayIPv4) {
return &ErrInvalidGateway{}
}
@ -126,7 +119,7 @@ func setupGatewayIPv4(config *NetworkConfiguration, i *bridgeInterface) error {
return nil
}
func setupLoopbackAdressesRouting(config *NetworkConfiguration, i *bridgeInterface) error {
func setupLoopbackAdressesRouting(config *networkConfiguration, i *bridgeInterface) error {
// Enable loopback adresses routing
sysPath := filepath.Join("/proc/sys/net/ipv4/conf", config.BridgeName, "route_localnet")
if err := ioutil.WriteFile(sysPath, []byte{'1', '\n'}, 0644); err != nil {

View file

@ -26,7 +26,7 @@ func init() {
}
}
func setupBridgeIPv6(config *NetworkConfiguration, i *bridgeInterface) error {
func setupBridgeIPv6(config *networkConfiguration, i *bridgeInterface) error {
// Enable IPv6 on the bridge
procFile := "/proc/sys/net/ipv6/conf/" + config.BridgeName + "/disable_ipv6"
if err := ioutil.WriteFile(procFile, []byte{'0', '\n'}, ipv6ForwardConfPerm); err != nil {
@ -52,7 +52,7 @@ func setupBridgeIPv6(config *NetworkConfiguration, i *bridgeInterface) error {
return nil
}
func setupGatewayIPv6(config *NetworkConfiguration, i *bridgeInterface) error {
func setupGatewayIPv6(config *networkConfiguration, i *bridgeInterface) error {
if config.FixedCIDRv6 == nil {
return &ErrInvalidContainerSubnet{}
}
@ -69,7 +69,7 @@ func setupGatewayIPv6(config *NetworkConfiguration, i *bridgeInterface) error {
return nil
}
func setupIPv6Forwarding(config *NetworkConfiguration, i *bridgeInterface) error {
func setupIPv6Forwarding(config *networkConfiguration, i *bridgeInterface) error {
// Enable IPv6 forwarding
if err := ioutil.WriteFile("/proc/sys/net/ipv6/conf/default/forwarding", []byte{'1', '\n'}, ipv6ForwardConfPerm); err != nil {
logrus.Warnf("Unable to enable IPv6 default forwarding: %v", err)

View file

@ -4,7 +4,7 @@ import (
"github.com/vishvananda/netlink"
)
func setupVerifyAndReconcile(config *NetworkConfiguration, i *bridgeInterface) error {
func setupVerifyAndReconcile(config *networkConfiguration, i *bridgeInterface) error {
// Fetch a single IPv4 and a slice of IPv6 addresses from the bridge.
addrv4, addrsv6, err := i.addresses()
if err != nil {

View file

@ -1,17 +1,25 @@
package host
import (
"sync"
"github.com/docker/libnetwork/driverapi"
"github.com/docker/libnetwork/types"
)
const networkType = "host"
type driver struct{}
type driver struct {
network types.UUID
sync.Mutex
}
// Init registers a new instance of host driver
func Init(dc driverapi.DriverCallback) error {
return dc.RegisterDriver(networkType, &driver{})
c := driverapi.Capability{
Scope: driverapi.LocalScope,
}
return dc.RegisterDriver(networkType, &driver{}, c)
}
func (d *driver) Config(option map[string]interface{}) error {
@ -19,11 +27,20 @@ func (d *driver) Config(option map[string]interface{}) error {
}
func (d *driver) CreateNetwork(id types.UUID, option map[string]interface{}) error {
d.Lock()
defer d.Unlock()
if d.network != "" {
return types.ForbiddenErrorf("only one instance of \"%s\" network is allowed", networkType)
}
d.network = id
return nil
}
func (d *driver) DeleteNetwork(nid types.UUID) error {
return nil
return types.ForbiddenErrorf("network of type \"%s\" cannot be deleted", networkType)
}
func (d *driver) CreateEndpoint(nid, eid types.UUID, epInfo driverapi.EndpointInfo, epOptions map[string]interface{}) error {

View file

@ -1,17 +1,25 @@
package null
import (
"sync"
"github.com/docker/libnetwork/driverapi"
"github.com/docker/libnetwork/types"
)
const networkType = "null"
type driver struct{}
type driver struct {
network types.UUID
sync.Mutex
}
// Init registers a new instance of null driver
func Init(dc driverapi.DriverCallback) error {
return dc.RegisterDriver(networkType, &driver{})
c := driverapi.Capability{
Scope: driverapi.LocalScope,
}
return dc.RegisterDriver(networkType, &driver{}, c)
}
func (d *driver) Config(option map[string]interface{}) error {
@ -19,11 +27,20 @@ func (d *driver) Config(option map[string]interface{}) error {
}
func (d *driver) CreateNetwork(id types.UUID, option map[string]interface{}) error {
d.Lock()
defer d.Unlock()
if d.network != "" {
return types.ForbiddenErrorf("only one instance of \"%s\" network is allowed", networkType)
}
d.network = id
return nil
}
func (d *driver) DeleteNetwork(nid types.UUID) error {
return nil
return types.ForbiddenErrorf("network of type \"%s\" cannot be deleted", networkType)
}
func (d *driver) CreateEndpoint(nid, eid types.UUID, epInfo driverapi.EndpointInfo, epOptions map[string]interface{}) error {

View file

@ -23,7 +23,10 @@ func newDriver(name string, client *plugins.Client) driverapi.Driver {
// plugin is activated.
func Init(dc driverapi.DriverCallback) error {
plugins.Handle(driverapi.NetworkPluginEndpointType, func(name string, client *plugins.Client) {
if err := dc.RegisterDriver(name, newDriver(name, client)); err != nil {
c := driverapi.Capability{
Scope: driverapi.GlobalScope,
}
if err := dc.RegisterDriver(name, newDriver(name, client), c); err != nil {
log.Errorf("error registering driver for %s due to %v", name, err)
}
})
@ -168,7 +171,7 @@ func (d *driver) Join(nid, eid types.UUID, sboxKey string, jinfo driverapi.JoinI
return fmt.Errorf("no correlating interface %d in supplied interface names", i)
}
supplied := ifaceNames[i]
if err := iface.SetNames(supplied.SrcName, supplied.DstName); err != nil {
if err := iface.SetNames(supplied.SrcName, supplied.DstPrefix); err != nil {
return errorWithRollback(fmt.Sprintf("failed to set interface name: %s", err), d.Leave(nid, eid))
}
}
@ -190,6 +193,17 @@ func (d *driver) Join(nid, eid types.UUID, sboxKey string, jinfo driverapi.JoinI
return errorWithRollback(fmt.Sprintf("failed to set gateway IPv6: %v", addr), d.Leave(nid, eid))
}
}
if len(res.StaticRoutes) > 0 {
routes, err := res.parseStaticRoutes()
if err != nil {
return err
}
for _, route := range routes {
if jinfo.AddStaticRoute(route.Destination, route.RouteType, route.NextHop, route.InterfaceID) != nil {
return errorWithRollback(fmt.Sprintf("failed to set static route: %v", route), d.Leave(nid, eid))
}
}
}
if jinfo.SetHostsPath(res.HostsPath) != nil {
return errorWithRollback(fmt.Sprintf("failed to set hosts path: %s", res.HostsPath), d.Leave(nid, eid))
}

View file

@ -1,6 +1,11 @@
package remote
import "net"
import (
"fmt"
"net"
"github.com/docker/libnetwork/types"
)
type response struct {
Err string
@ -45,6 +50,13 @@ type endpointInterface struct {
MacAddress string
}
type staticRoute struct {
Destination string
RouteType int
NextHop string
InterfaceID int
}
type createEndpointResponse struct {
response
Interfaces []*endpointInterface
@ -67,9 +79,7 @@ type iface struct {
}
func (r *createEndpointResponse) parseInterfaces() ([]*iface, error) {
var (
ifaces = make([]*iface, len(r.Interfaces))
)
var ifaces = make([]*iface, len(r.Interfaces))
for i, inIf := range r.Interfaces {
var err error
outIf := &iface{ID: inIf.ID}
@ -93,6 +103,30 @@ func (r *createEndpointResponse) parseInterfaces() ([]*iface, error) {
return ifaces, nil
}
func (r *joinResponse) parseStaticRoutes() ([]*types.StaticRoute, error) {
var routes = make([]*types.StaticRoute, len(r.StaticRoutes))
for i, inRoute := range r.StaticRoutes {
var err error
outRoute := &types.StaticRoute{InterfaceID: inRoute.InterfaceID, RouteType: inRoute.RouteType}
if inRoute.Destination != "" {
if outRoute.Destination, err = toAddr(inRoute.Destination); err != nil {
return nil, err
}
}
if inRoute.NextHop != "" {
outRoute.NextHop = net.ParseIP(inRoute.NextHop)
if outRoute.NextHop == nil {
return nil, fmt.Errorf("failed to parse nexthop IP %s", inRoute.NextHop)
}
}
routes[i] = outRoute
}
return routes, nil
}
type deleteEndpointRequest struct {
NetworkID string
EndpointID string
@ -120,8 +154,8 @@ type joinRequest struct {
}
type ifaceName struct {
SrcName string
DstName string
SrcName string
DstPrefix string
}
type joinResponse struct {
@ -129,6 +163,7 @@ type joinResponse struct {
InterfaceNames []*ifaceName
Gateway string
GatewayIPv6 string
StaticRoutes []*staticRoute
HostsPath string
ResolvConfPath string
}

View file

@ -0,0 +1,58 @@
package windows
import (
"github.com/docker/libnetwork/driverapi"
"github.com/docker/libnetwork/types"
)
const networkType = "windows"
// TODO Windows. This is a placeholder for now
type driver struct{}
// Init registers a new instance of null driver
func Init(dc driverapi.DriverCallback) error {
c := driverapi.Capability{
Scope: driverapi.LocalScope,
}
return dc.RegisterDriver(networkType, &driver{}, c)
}
func (d *driver) Config(option map[string]interface{}) error {
return nil
}
func (d *driver) CreateNetwork(id types.UUID, option map[string]interface{}) error {
return nil
}
func (d *driver) DeleteNetwork(nid types.UUID) error {
return nil
}
func (d *driver) CreateEndpoint(nid, eid types.UUID, epInfo driverapi.EndpointInfo, epOptions map[string]interface{}) error {
return nil
}
func (d *driver) DeleteEndpoint(nid, eid types.UUID) error {
return nil
}
func (d *driver) EndpointOperInfo(nid, eid types.UUID) (map[string]interface{}, error) {
return make(map[string]interface{}, 0), nil
}
// Join method is invoked when a Sandbox is attached to an endpoint.
func (d *driver) Join(nid, eid types.UUID, sboxKey string, jinfo driverapi.JoinInfo, options map[string]interface{}) error {
return nil
}
// Leave method is invoked when a Sandbox detaches from an endpoint.
func (d *driver) Leave(nid, eid types.UUID) error {
return nil
}
func (d *driver) Type() string {
return networkType
}

View file

@ -0,0 +1,19 @@
package libnetwork
import (
"github.com/docker/libnetwork/driverapi"
"github.com/docker/libnetwork/drivers/null"
"github.com/docker/libnetwork/drivers/remote"
)
func initDrivers(dc driverapi.DriverCallback) error {
for _, fn := range [](func(driverapi.DriverCallback) error){
null.Init,
remote.Init,
} {
if err := fn(dc); err != nil {
return err
}
}
return nil
}

View file

@ -8,8 +8,6 @@ import (
"github.com/docker/libnetwork/drivers/remote"
)
type driverTable map[string]driverapi.Driver
func initDrivers(dc driverapi.DriverCallback) error {
for _, fn := range [](func(driverapi.DriverCallback) error){
bridge.Init,

View file

@ -0,0 +1,17 @@
package libnetwork
import (
"github.com/docker/libnetwork/driverapi"
"github.com/docker/libnetwork/drivers/windows"
)
func initDrivers(dc driverapi.DriverCallback) error {
for _, fn := range [](func(driverapi.DriverCallback) error){
windows.Init,
} {
if err := fn(dc); err != nil {
return err
}
}
return nil
}

View file

@ -2,14 +2,17 @@ package libnetwork
import (
"bytes"
"encoding/json"
"fmt"
"io/ioutil"
"os"
"path"
"path/filepath"
"sync"
"github.com/Sirupsen/logrus"
log "github.com/Sirupsen/logrus"
"github.com/docker/docker/pkg/ioutils"
"github.com/docker/libnetwork/datastore"
"github.com/docker/libnetwork/etchosts"
"github.com/docker/libnetwork/netlabel"
"github.com/docker/libnetwork/resolvconf"
@ -30,8 +33,8 @@ type Endpoint interface {
// Join creates a new sandbox for the given container ID and populates the
// network resources allocated for the endpoint and joins the sandbox to
// the endpoint. It returns the sandbox key to the caller
Join(containerID string, options ...EndpointOption) (*ContainerData, error)
// the endpoint.
Join(containerID string, options ...EndpointOption) error
// Leave removes the sandbox associated with container ID and detaches
// the network resources populated in the sandbox
@ -40,9 +43,12 @@ type Endpoint interface {
// Return certain operational data belonging to this endpoint
Info() EndpointInfo
// Info returns a collection of driver operational data related to this endpoint retrieved from the driver
// DriverInfo returns a collection of driver operational data related to this endpoint retrieved from the driver
DriverInfo() (map[string]interface{}, error)
// ContainerInfo returns the info available at the endpoint about the attached container
ContainerInfo() ContainerInfo
// Delete and detaches this endpoint from the network.
Delete() error
}
@ -78,6 +84,7 @@ type containerConfig struct {
resolvConfPathConfig
generic map[string]interface{}
useDefaultSandBox bool
prio int // higher the value, more the priority
}
type extraHost struct {
@ -95,22 +102,105 @@ type containerInfo struct {
id string
config containerConfig
data ContainerData
sync.Mutex
}
func (ci *containerInfo) ID() string {
return ci.id
}
func (ci *containerInfo) Labels() map[string]interface{} {
return ci.config.generic
}
type endpoint struct {
name string
id types.UUID
network *network
sandboxInfo *sandbox.Info
iFaces []*endpointInterface
joinInfo *endpointJoinInfo
container *containerInfo
exposedPorts []types.TransportPort
generic map[string]interface{}
joinLeaveDone chan struct{}
dbIndex uint64
sync.Mutex
}
func (ci *containerInfo) MarshalJSON() ([]byte, error) {
ci.Lock()
defer ci.Unlock()
// We are just interested in the container ID. This can be expanded to include all of containerInfo if there is a need
return json.Marshal(ci.id)
}
func (ci *containerInfo) UnmarshalJSON(b []byte) (err error) {
ci.Lock()
defer ci.Unlock()
var id string
if err := json.Unmarshal(b, &id); err != nil {
return err
}
ci.id = id
return nil
}
func (ep *endpoint) MarshalJSON() ([]byte, error) {
ep.Lock()
defer ep.Unlock()
epMap := make(map[string]interface{})
epMap["name"] = ep.name
epMap["id"] = string(ep.id)
epMap["ep_iface"] = ep.iFaces
epMap["exposed_ports"] = ep.exposedPorts
epMap["generic"] = ep.generic
if ep.container != nil {
epMap["container"] = ep.container
}
return json.Marshal(epMap)
}
func (ep *endpoint) UnmarshalJSON(b []byte) (err error) {
ep.Lock()
defer ep.Unlock()
var epMap map[string]interface{}
if err := json.Unmarshal(b, &epMap); err != nil {
return err
}
ep.name = epMap["name"].(string)
ep.id = types.UUID(epMap["id"].(string))
ib, _ := json.Marshal(epMap["ep_iface"])
var ifaces []endpointInterface
json.Unmarshal(ib, &ifaces)
ep.iFaces = make([]*endpointInterface, 0)
for _, iface := range ifaces {
ep.iFaces = append(ep.iFaces, &iface)
}
tb, _ := json.Marshal(epMap["exposed_ports"])
var tPorts []types.TransportPort
json.Unmarshal(tb, &tPorts)
ep.exposedPorts = tPorts
epc, ok := epMap["container"]
if ok {
cb, _ := json.Marshal(epc)
var cInfo containerInfo
json.Unmarshal(cb, &cInfo)
ep.container = &cInfo
}
if epMap["generic"] != nil {
ep.generic = epMap["generic"].(map[string]interface{})
}
return nil
}
const defaultPrefix = "/var/lib/docker/network/files"
func (ep *endpoint) ID() string {
@ -134,6 +224,52 @@ func (ep *endpoint) Network() string {
return ep.network.name
}
// endpoint Key structure : endpoint/network-id/endpoint-id
func (ep *endpoint) Key() []string {
ep.Lock()
n := ep.network
defer ep.Unlock()
return []string{datastore.EndpointKeyPrefix, string(n.id), string(ep.id)}
}
func (ep *endpoint) KeyPrefix() []string {
ep.Lock()
n := ep.network
defer ep.Unlock()
return []string{datastore.EndpointKeyPrefix, string(n.id)}
}
func (ep *endpoint) networkIDFromKey(key []string) (types.UUID, error) {
// endpoint Key structure : endpoint/network-id/endpoint-id
// its an invalid key if the key doesnt have all the 3 key elements above
if key == nil || len(key) < 3 || key[0] != datastore.EndpointKeyPrefix {
return types.UUID(""), fmt.Errorf("invalid endpoint key : %v", key)
}
// network-id is placed at index=1. pls refer to endpoint.Key() method
return types.UUID(key[1]), nil
}
func (ep *endpoint) Value() []byte {
b, err := json.Marshal(ep)
if err != nil {
return nil
}
return b
}
func (ep *endpoint) Index() uint64 {
ep.Lock()
defer ep.Unlock()
return ep.dbIndex
}
func (ep *endpoint) SetIndex(index uint64) {
ep.Lock()
defer ep.Unlock()
ep.dbIndex = index
}
func (ep *endpoint) processOptions(options ...EndpointOption) {
ep.Lock()
defer ep.Unlock()
@ -203,20 +339,27 @@ func (ep *endpoint) joinLeaveEnd() {
}
}
func (ep *endpoint) Join(containerID string, options ...EndpointOption) (*ContainerData, error) {
func (ep *endpoint) Join(containerID string, options ...EndpointOption) error {
var err error
if containerID == "" {
return nil, InvalidContainerIDError(containerID)
return InvalidContainerIDError(containerID)
}
ep.joinLeaveStart()
defer ep.joinLeaveEnd()
defer func() {
ep.joinLeaveEnd()
if err != nil {
if e := ep.Leave(containerID, options...); e != nil {
log.Warnf("couldnt leave endpoint : %v", ep.name, err)
}
}
}()
ep.Lock()
if ep.container != nil {
ep.Unlock()
return nil, ErrInvalidJoin{}
return ErrInvalidJoin{}
}
ep.container = &containerInfo{
@ -233,16 +376,14 @@ func (ep *endpoint) Join(containerID string, options ...EndpointOption) (*Contai
container := ep.container
network := ep.network
epid := ep.id
joinInfo := ep.joinInfo
ifaces := ep.iFaces
ep.Unlock()
defer func() {
ep.Lock()
if err != nil {
ep.Lock()
ep.container = nil
ep.Unlock()
}
ep.Unlock()
}()
network.Lock()
@ -260,63 +401,53 @@ func (ep *endpoint) Join(containerID string, options ...EndpointOption) (*Contai
err = driver.Join(nid, epid, sboxKey, ep, container.config.generic)
if err != nil {
return nil, err
return err
}
err = ep.buildHostsFiles()
if err != nil {
return nil, err
return err
}
err = ep.updateParentHosts()
if err != nil {
return nil, err
return err
}
err = ep.setupDNS()
if err != nil {
return nil, err
return err
}
sb, err := ctrlr.sandboxAdd(sboxKey, !container.config.useDefaultSandBox)
sb, err := ctrlr.sandboxAdd(sboxKey, !container.config.useDefaultSandBox, ep)
if err != nil {
return nil, err
return err
}
defer func() {
if err != nil {
ctrlr.sandboxRm(sboxKey)
ctrlr.sandboxRm(sboxKey, ep)
}
}()
for _, i := range ifaces {
iface := &sandbox.Interface{
SrcName: i.srcName,
DstName: i.dstPrefix,
Address: &i.addr,
}
if i.addrv6.IP.To16() != nil {
iface.AddressIPv6 = &i.addrv6
}
err = sb.AddInterface(iface)
if err != nil {
return nil, err
}
}
err = sb.SetGateway(joinInfo.gw)
if err != nil {
return nil, err
}
err = sb.SetGatewayIPv6(joinInfo.gw6)
if err != nil {
return nil, err
if err := network.ctrlr.updateEndpointToStore(ep); err != nil {
return err
}
container.data.SandboxKey = sb.Key()
cData := container.data
return nil
}
return &cData, nil
func (ep *endpoint) hasInterface(iName string) bool {
ep.Lock()
defer ep.Unlock()
for _, iface := range ep.iFaces {
if iface.srcName == iName {
return true
}
}
return false
}
func (ep *endpoint) Leave(containerID string, options ...EndpointOption) error {
@ -331,7 +462,7 @@ func (ep *endpoint) Leave(containerID string, options ...EndpointOption) error {
container := ep.container
n := ep.network
if container == nil || container.id == "" ||
if container == nil || container.id == "" || container.data.SandboxKey == "" ||
containerID == "" || container.id != containerID {
if container == nil {
err = ErrNoContainer{}
@ -350,33 +481,73 @@ func (ep *endpoint) Leave(containerID string, options ...EndpointOption) error {
ctrlr := n.ctrlr
n.Unlock()
err = driver.Leave(n.id, ep.id)
sb := ctrlr.sandboxGet(container.data.SandboxKey)
for _, i := range sb.Interfaces() {
err = sb.RemoveInterface(i)
if err != nil {
logrus.Debugf("Remove interface failed: %v", err)
}
if err := ctrlr.updateEndpointToStore(ep); err != nil {
ep.Lock()
ep.container = container
ep.Unlock()
return err
}
ctrlr.sandboxRm(container.data.SandboxKey)
err = driver.Leave(n.id, ep.id)
ctrlr.sandboxRm(container.data.SandboxKey, ep)
return err
}
func (ep *endpoint) Delete() error {
var err error
ep.Lock()
epid := ep.id
name := ep.name
n := ep.network
if ep.container != nil {
ep.Unlock()
return &ActiveContainerError{name: name, id: string(epid)}
}
n.Lock()
ctrlr := n.ctrlr
n.Unlock()
ep.Unlock()
if err = ctrlr.deleteEndpointFromStore(ep); err != nil {
return err
}
defer func() {
if err != nil {
ep.SetIndex(0)
if e := ctrlr.updateEndpointToStore(ep); e != nil {
log.Warnf("failed to recreate endpoint in store %s : %v", name, err)
}
}
}()
// Update the endpoint count in network and update it in the datastore
n.DecEndpointCnt()
if err = ctrlr.updateNetworkToStore(n); err != nil {
return err
}
defer func() {
if err != nil {
n.IncEndpointCnt()
if e := ctrlr.updateNetworkToStore(n); e != nil {
log.Warnf("failed to update network %s : %v", n.name, e)
}
}
}()
if err = ep.deleteEndpoint(); err != nil {
return err
}
return nil
}
func (ep *endpoint) deleteEndpoint() error {
ep.Lock()
n := ep.network
name := ep.name
epid := ep.id
ep.Unlock()
n.Lock()
@ -390,16 +561,17 @@ func (ep *endpoint) Delete() error {
driver := n.driver
delete(n.endpoints, epid)
n.Unlock()
defer func() {
if err != nil {
if err := driver.DeleteEndpoint(nid, epid); err != nil {
if _, ok := err.(types.ForbiddenError); ok {
n.Lock()
n.endpoints[epid] = ep
n.Unlock()
return err
}
}()
err = driver.DeleteEndpoint(nid, epid)
return err
log.Warnf("driver error deleting endpoint %s : %v", name, err)
}
return nil
}
func (ep *endpoint) buildHostsFiles() error {
@ -436,11 +608,6 @@ func (ep *endpoint) buildHostsFiles() error {
}
}
name := container.config.hostName
if container.config.domainName != "" {
name = name + "." + container.config.domainName
}
for _, extraHost := range container.config.extraHosts {
extraContent = append(extraContent,
etchosts.Record{Hosts: extraHost.name, IP: extraHost.IP})
@ -622,6 +789,14 @@ func EndpointOptionGeneric(generic map[string]interface{}) EndpointOption {
}
}
// JoinOptionPriority function returns an option setter for priority option to
// be passed to endpoint Join method.
func JoinOptionPriority(prio int) EndpointOption {
return func(ep *endpoint) {
ep.container.config.prio = prio
}
}
// JoinOptionHostname function returns an option setter for hostname option to
// be passed to endpoint Join method.
func JoinOptionHostname(name string) EndpointOption {

View file

@ -1,6 +1,7 @@
package libnetwork
import (
"encoding/json"
"net"
"github.com/docker/libnetwork/driverapi"
@ -39,6 +40,14 @@ type InterfaceInfo interface {
AddressIPv6() net.IPNet
}
// ContainerInfo provides an interface to retrieve the info about the container attached to the endpoint
type ContainerInfo interface {
// ID returns the ID of the container
ID() string
// Labels returns the container's labels
Labels() map[string]interface{}
}
type endpointInterface struct {
id int
mac net.HardwareAddr
@ -46,6 +55,60 @@ type endpointInterface struct {
addrv6 net.IPNet
srcName string
dstPrefix string
routes []*net.IPNet
}
func (epi *endpointInterface) MarshalJSON() ([]byte, error) {
epMap := make(map[string]interface{})
epMap["id"] = epi.id
epMap["mac"] = epi.mac.String()
epMap["addr"] = epi.addr.String()
epMap["addrv6"] = epi.addrv6.String()
epMap["srcName"] = epi.srcName
epMap["dstPrefix"] = epi.dstPrefix
var routes []string
for _, route := range epi.routes {
routes = append(routes, route.String())
}
epMap["routes"] = routes
return json.Marshal(epMap)
}
func (epi *endpointInterface) UnmarshalJSON(b []byte) (err error) {
var epMap map[string]interface{}
if err := json.Unmarshal(b, &epMap); err != nil {
return err
}
epi.id = int(epMap["id"].(float64))
mac, _ := net.ParseMAC(epMap["mac"].(string))
epi.mac = mac
_, ipnet, _ := net.ParseCIDR(epMap["addr"].(string))
if ipnet != nil {
epi.addr = *ipnet
}
_, ipnet, _ = net.ParseCIDR(epMap["addrv6"].(string))
if ipnet != nil {
epi.addrv6 = *ipnet
}
epi.srcName = epMap["srcName"].(string)
epi.dstPrefix = epMap["dstPrefix"].(string)
rb, _ := json.Marshal(epMap["routes"])
var routes []string
json.Unmarshal(rb, &routes)
epi.routes = make([]*net.IPNet, 0)
for _, route := range routes {
_, ipr, err := net.ParseCIDR(route)
if err == nil {
epi.routes = append(epi.routes, ipr)
}
}
return nil
}
type endpointJoinInfo struct {
@ -53,6 +116,19 @@ type endpointJoinInfo struct {
gw6 net.IP
hostsPath string
resolvConfPath string
StaticRoutes []*types.StaticRoute
}
func (ep *endpoint) ContainerInfo() ContainerInfo {
ep.Lock()
ci := ep.container
defer ep.Unlock()
// Need this since we return the interface
if ci == nil {
return nil
}
return ci
}
func (ep *endpoint) Info() EndpointInfo {
@ -114,25 +190,25 @@ func (ep *endpoint) AddInterface(id int, mac net.HardwareAddr, ipv4 net.IPNet, i
return nil
}
func (i *endpointInterface) ID() int {
return i.id
func (epi *endpointInterface) ID() int {
return epi.id
}
func (i *endpointInterface) MacAddress() net.HardwareAddr {
return types.GetMacCopy(i.mac)
func (epi *endpointInterface) MacAddress() net.HardwareAddr {
return types.GetMacCopy(epi.mac)
}
func (i *endpointInterface) Address() net.IPNet {
return (*types.GetIPNetCopy(&i.addr))
func (epi *endpointInterface) Address() net.IPNet {
return (*types.GetIPNetCopy(&epi.addr))
}
func (i *endpointInterface) AddressIPv6() net.IPNet {
return (*types.GetIPNetCopy(&i.addrv6))
func (epi *endpointInterface) AddressIPv6() net.IPNet {
return (*types.GetIPNetCopy(&epi.addrv6))
}
func (i *endpointInterface) SetNames(srcName string, dstPrefix string) error {
i.srcName = srcName
i.dstPrefix = dstPrefix
func (epi *endpointInterface) SetNames(srcName string, dstPrefix string) error {
epi.srcName = srcName
epi.dstPrefix = dstPrefix
return nil
}
@ -149,6 +225,35 @@ func (ep *endpoint) InterfaceNames() []driverapi.InterfaceNameInfo {
return iList
}
func (ep *endpoint) AddStaticRoute(destination *net.IPNet, routeType int, nextHop net.IP, interfaceID int) error {
ep.Lock()
defer ep.Unlock()
r := types.StaticRoute{Destination: destination, RouteType: routeType, NextHop: nextHop, InterfaceID: interfaceID}
if routeType == types.NEXTHOP {
// If the route specifies a next-hop, then it's loosely routed (i.e. not bound to a particular interface).
ep.joinInfo.StaticRoutes = append(ep.joinInfo.StaticRoutes, &r)
} else {
// If the route doesn't specify a next-hop, it must be a connected route, bound to an interface.
if err := ep.addInterfaceRoute(&r); err != nil {
return err
}
}
return nil
}
func (ep *endpoint) addInterfaceRoute(route *types.StaticRoute) error {
for _, iface := range ep.iFaces {
if iface.id == route.InterfaceID {
iface.routes = append(iface.routes, route.Destination)
return nil
}
}
return types.BadRequestErrorf("Interface with ID %d doesn't exist.",
route.InterfaceID)
}
func (ep *endpoint) SandboxKey() string {
ep.Lock()
defer ep.Unlock()

View file

@ -11,8 +11,8 @@ func (nsn ErrNoSuchNetwork) Error() string {
return fmt.Sprintf("network %s not found", string(nsn))
}
// BadRequest denotes the type of this error
func (nsn ErrNoSuchNetwork) BadRequest() {}
// NotFound denotes the type of this error
func (nsn ErrNoSuchNetwork) NotFound() {}
// ErrNoSuchEndpoint is returned when a endpoint query finds no result
type ErrNoSuchEndpoint string
@ -21,8 +21,8 @@ func (nse ErrNoSuchEndpoint) Error() string {
return fmt.Sprintf("endpoint %s not found", string(nse))
}
// BadRequest denotes the type of this error
func (nse ErrNoSuchEndpoint) BadRequest() {}
// NotFound denotes the type of this error
func (nse ErrNoSuchEndpoint) NotFound() {}
// ErrInvalidNetworkDriver is returned if an invalid driver
// name is passed.
@ -79,6 +79,13 @@ func (in ErrInvalidName) Error() string {
// BadRequest denotes the type of this error
func (in ErrInvalidName) BadRequest() {}
// ErrInvalidConfigFile type is returned when an invalid LibNetwork config file is detected
type ErrInvalidConfigFile string
func (cf ErrInvalidConfigFile) Error() string {
return fmt.Sprintf("Invalid Config file %q", string(cf))
}
// NetworkTypeError type is returned when the network type string is not
// known to libnetwork.
type NetworkTypeError string

View file

@ -0,0 +1,154 @@
// +build libnetwork_discovery
package hostdiscovery
import (
"errors"
"fmt"
"net"
"sync"
"time"
log "github.com/Sirupsen/logrus"
mapset "github.com/deckarep/golang-set"
"github.com/docker/libnetwork/config"
"github.com/docker/swarm/discovery"
// Anonymous import will be removed after we upgrade to latest swarm
_ "github.com/docker/swarm/discovery/file"
// Anonymous import will be removed after we upgrade to latest swarm
_ "github.com/docker/swarm/discovery/kv"
// Anonymous import will be removed after we upgrade to latest swarm
_ "github.com/docker/swarm/discovery/nodes"
// Anonymous import will be removed after we upgrade to latest swarm
_ "github.com/docker/swarm/discovery/token"
)
const defaultHeartbeat = time.Duration(10) * time.Second
const TTLFactor = 3
type hostDiscovery struct {
discovery discovery.Discovery
nodes mapset.Set
stopChan chan struct{}
sync.Mutex
}
// NewHostDiscovery function creates a host discovery object
func NewHostDiscovery() HostDiscovery {
return &hostDiscovery{nodes: mapset.NewSet(), stopChan: make(chan struct{})}
}
func (h *hostDiscovery) StartDiscovery(cfg *config.ClusterCfg, joinCallback JoinCallback, leaveCallback LeaveCallback) error {
if cfg == nil {
return fmt.Errorf("discovery requires a valid configuration")
}
hb := time.Duration(cfg.Heartbeat) * time.Second
if hb == 0 {
hb = defaultHeartbeat
}
d, err := discovery.New(cfg.Discovery, hb, TTLFactor*hb)
if err != nil {
return err
}
if ip := net.ParseIP(cfg.Address); ip == nil {
return errors.New("address config should be either ipv4 or ipv6 address")
}
if err := d.Register(cfg.Address + ":0"); err != nil {
return err
}
h.Lock()
h.discovery = d
h.Unlock()
discoveryCh, errCh := d.Watch(h.stopChan)
go h.monitorDiscovery(discoveryCh, errCh, joinCallback, leaveCallback)
go h.sustainHeartbeat(d, hb, cfg)
return nil
}
func (h *hostDiscovery) monitorDiscovery(ch <-chan discovery.Entries, errCh <-chan error, joinCallback JoinCallback, leaveCallback LeaveCallback) {
for {
select {
case entries := <-ch:
h.processCallback(entries, joinCallback, leaveCallback)
case err := <-errCh:
log.Errorf("discovery error: %v", err)
case <-h.stopChan:
return
}
}
}
func (h *hostDiscovery) StopDiscovery() error {
h.Lock()
stopChan := h.stopChan
h.discovery = nil
h.Unlock()
close(stopChan)
return nil
}
func (h *hostDiscovery) sustainHeartbeat(d discovery.Discovery, hb time.Duration, config *config.ClusterCfg) {
for {
select {
case <-h.stopChan:
return
case <-time.After(hb):
if err := d.Register(config.Address + ":0"); err != nil {
log.Warn(err)
}
}
}
}
func (h *hostDiscovery) processCallback(entries discovery.Entries, joinCallback JoinCallback, leaveCallback LeaveCallback) {
updated := hosts(entries)
h.Lock()
existing := h.nodes
added, removed := diff(existing, updated)
h.nodes = updated
h.Unlock()
if len(added) > 0 {
joinCallback(added)
}
if len(removed) > 0 {
leaveCallback(removed)
}
}
func diff(existing mapset.Set, updated mapset.Set) (added []net.IP, removed []net.IP) {
addSlice := updated.Difference(existing).ToSlice()
removeSlice := existing.Difference(updated).ToSlice()
for _, ip := range addSlice {
added = append(added, net.ParseIP(ip.(string)))
}
for _, ip := range removeSlice {
removed = append(removed, net.ParseIP(ip.(string)))
}
return
}
func (h *hostDiscovery) Fetch() ([]net.IP, error) {
h.Lock()
defer h.Unlock()
ips := []net.IP{}
for _, ipstr := range h.nodes.ToSlice() {
ips = append(ips, net.ParseIP(ipstr.(string)))
}
return ips, nil
}
func hosts(entries discovery.Entries) mapset.Set {
hosts := mapset.NewSet()
for _, entry := range entries {
hosts.Add(entry.Host)
}
return hosts
}

View file

@ -0,0 +1,23 @@
package hostdiscovery
import (
"net"
"github.com/docker/libnetwork/config"
)
// JoinCallback provides a callback event for new node joining the cluster
type JoinCallback func(entries []net.IP)
// LeaveCallback provides a callback event for node leaving the cluster
type LeaveCallback func(entries []net.IP)
// HostDiscovery primary interface
type HostDiscovery interface {
// StartDiscovery initiates the discovery process and provides appropriate callbacks
StartDiscovery(*config.ClusterCfg, JoinCallback, LeaveCallback) error
// StopDiscovery stops the discovery perocess
StopDiscovery() error
// Fetch returns a list of host IPs that are currently discovered
Fetch() ([]net.IP, error)
}

View file

@ -0,0 +1,28 @@
// +build !libnetwork_discovery
package hostdiscovery
import (
"net"
"github.com/docker/libnetwork/config"
)
type hostDiscovery struct{}
// NewHostDiscovery function creates a host discovery object
func NewHostDiscovery() HostDiscovery {
return &hostDiscovery{}
}
func (h *hostDiscovery) StartDiscovery(cfg *config.ClusterCfg, joinCallback JoinCallback, leaveCallback LeaveCallback) error {
return nil
}
func (h *hostDiscovery) StopDiscovery() error {
return nil
}
func (h *hostDiscovery) Fetch() ([]net.IP, error) {
return []net.IP{}, nil
}

View file

@ -0,0 +1,6 @@
title = "LibNetwork Configuration file"
[cluster]
discovery = "token://08469efb104bce980931ed24c8eb03a2"
Address = "1.1.1.1"
Heartbeat = 3

View file

@ -9,7 +9,9 @@ import (
"fmt"
"io"
"net"
"strings"
"github.com/docker/libnetwork/types"
"github.com/vishvananda/netlink"
)
@ -147,3 +149,22 @@ func GenerateRandomName(prefix string, size int) (string, error) {
}
return prefix + hex.EncodeToString(id)[:size], nil
}
// GenerateIfaceName returns an interface name using the passed in
// prefix and the length of random bytes. The api ensures that the
// there are is no interface which exists with that name.
func GenerateIfaceName(prefix string, len int) (string, error) {
for i := 0; i < 3; i++ {
name, err := GenerateRandomName(prefix, len)
if err != nil {
continue
}
if _, err := net.InterfaceByName(name); err != nil {
if strings.Contains(err.Error(), "no such") {
return name, nil
}
return "", err
}
}
return "", types.InternalErrorf("could not generate interface name")
}

View file

@ -1,9 +1,12 @@
package libnetwork
import (
"encoding/json"
"sync"
log "github.com/Sirupsen/logrus"
"github.com/docker/docker/pkg/stringid"
"github.com/docker/libnetwork/datastore"
"github.com/docker/libnetwork/driverapi"
"github.com/docker/libnetwork/netlabel"
"github.com/docker/libnetwork/options"
@ -54,8 +57,10 @@ type network struct {
id types.UUID
driver driverapi.Driver
enableIPv6 bool
endpointCnt uint64
endpoints endpointTable
generic options.Generic
dbIndex uint64
sync.Mutex
}
@ -84,6 +89,85 @@ func (n *network) Type() string {
return n.driver.Type()
}
func (n *network) Key() []string {
n.Lock()
defer n.Unlock()
return []string{datastore.NetworkKeyPrefix, string(n.id)}
}
func (n *network) KeyPrefix() []string {
return []string{datastore.NetworkKeyPrefix}
}
func (n *network) Value() []byte {
n.Lock()
defer n.Unlock()
b, err := json.Marshal(n)
if err != nil {
return nil
}
return b
}
func (n *network) Index() uint64 {
n.Lock()
defer n.Unlock()
return n.dbIndex
}
func (n *network) SetIndex(index uint64) {
n.Lock()
n.dbIndex = index
n.Unlock()
}
func (n *network) EndpointCnt() uint64 {
n.Lock()
defer n.Unlock()
return n.endpointCnt
}
func (n *network) IncEndpointCnt() {
n.Lock()
n.endpointCnt++
n.Unlock()
}
func (n *network) DecEndpointCnt() {
n.Lock()
n.endpointCnt--
n.Unlock()
}
// TODO : Can be made much more generic with the help of reflection (but has some golang limitations)
func (n *network) MarshalJSON() ([]byte, error) {
netMap := make(map[string]interface{})
netMap["name"] = n.name
netMap["id"] = string(n.id)
netMap["networkType"] = n.networkType
netMap["endpointCnt"] = n.endpointCnt
netMap["enableIPv6"] = n.enableIPv6
netMap["generic"] = n.generic
return json.Marshal(netMap)
}
// TODO : Can be made much more generic with the help of reflection (but has some golang limitations)
func (n *network) UnmarshalJSON(b []byte) (err error) {
var netMap map[string]interface{}
if err := json.Unmarshal(b, &netMap); err != nil {
return err
}
n.name = netMap["name"].(string)
n.id = types.UUID(netMap["id"].(string))
n.networkType = netMap["networkType"].(string)
n.endpointCnt = uint64(netMap["endpointCnt"].(float64))
n.enableIPv6 = netMap["enableIPv6"].(bool)
if netMap["generic"] != nil {
n.generic = netMap["generic"].(map[string]interface{})
}
return nil
}
// NetworkOption is a option setter function type used to pass varios options to
// NewNetwork method. The various setter functions of type NetworkOption are
// provided by libnetwork, they look like NetworkOptionXXXX(...)
@ -111,53 +195,129 @@ func (n *network) processOptions(options ...NetworkOption) {
func (n *network) Delete() error {
var err error
n.ctrlr.Lock()
_, ok := n.ctrlr.networks[n.id]
n.Lock()
ctrlr := n.ctrlr
n.Unlock()
ctrlr.Lock()
_, ok := ctrlr.networks[n.id]
ctrlr.Unlock()
if !ok {
n.ctrlr.Unlock()
return &UnknownNetworkError{name: n.name, id: string(n.id)}
}
n.Lock()
numEps := len(n.endpoints)
n.Unlock()
numEps := n.EndpointCnt()
if numEps != 0 {
n.ctrlr.Unlock()
return &ActiveEndpointsError{name: n.name, id: string(n.id)}
}
delete(n.ctrlr.networks, n.id)
// deleteNetworkFromStore performs an atomic delete operation and the network.endpointCnt field will help
// prevent any possible race between endpoint join and network delete
if err = ctrlr.deleteNetworkFromStore(n); err != nil {
if err == datastore.ErrKeyModified {
return types.InternalErrorf("operation in progress. delete failed for network %s. Please try again.")
}
return err
}
if err = n.deleteNetwork(); err != nil {
return err
}
return nil
}
func (n *network) deleteNetwork() error {
n.Lock()
id := n.id
d := n.driver
n.ctrlr.Lock()
delete(n.ctrlr.networks, id)
n.ctrlr.Unlock()
defer func() {
if err != nil {
n.Unlock()
if err := d.DeleteNetwork(n.id); err != nil {
// Forbidden Errors should be honored
if _, ok := err.(types.ForbiddenError); ok {
n.ctrlr.Lock()
n.ctrlr.networks[n.id] = n
n.ctrlr.Unlock()
return err
}
log.Warnf("driver error deleting network %s : %v", n.name, err)
}
return nil
}
func (n *network) addEndpoint(ep *endpoint) error {
var err error
n.Lock()
n.endpoints[ep.id] = ep
d := n.driver
n.Unlock()
defer func() {
if err != nil {
n.Lock()
delete(n.endpoints, ep.id)
n.Unlock()
}
}()
err = n.driver.DeleteNetwork(n.id)
return err
err = d.CreateEndpoint(n.id, ep.id, ep, ep.generic)
if err != nil {
return err
}
return nil
}
func (n *network) CreateEndpoint(name string, options ...EndpointOption) (Endpoint, error) {
var err error
if name == "" {
return nil, ErrInvalidName(name)
}
if _, err = n.EndpointByName(name); err == nil {
return nil, types.ForbiddenErrorf("service endpoint with name %s already exists", name)
}
ep := &endpoint{name: name, iFaces: []*endpointInterface{}, generic: make(map[string]interface{})}
ep.id = types.UUID(stringid.GenerateRandomID())
ep.network = n
ep.processOptions(options...)
d := n.driver
err := d.CreateEndpoint(n.id, ep.id, ep, ep.generic)
if err != nil {
n.Lock()
ctrlr := n.ctrlr
n.Unlock()
n.IncEndpointCnt()
if err = ctrlr.updateNetworkToStore(n); err != nil {
return nil, err
}
defer func() {
if err != nil {
n.DecEndpointCnt()
if err = ctrlr.updateNetworkToStore(n); err != nil {
log.Warnf("endpoint count cleanup failed when updating network for %s : %v", name, err)
}
}
}()
if err = n.addEndpoint(ep); err != nil {
return nil, err
}
defer func() {
if err != nil {
if e := ep.Delete(); ep != nil {
log.Warnf("cleaning up endpoint failed %s : %v", name, e)
}
}
}()
if err = ctrlr.updateEndpointToStore(ep); err != nil {
return nil, err
}
n.Lock()
n.endpoints[ep.id] = ep
n.Unlock()
return ep, nil
}
@ -214,3 +374,10 @@ func (n *network) EndpointByID(id string) (Endpoint, error) {
}
return nil, ErrNoSuchEndpoint(id)
}
func (n *network) isGlobalScoped() (bool, error) {
n.Lock()
c := n.ctrlr
n.Unlock()
return c.isDriverGlobalScoped(n.networkType)
}

View file

@ -1,81 +0,0 @@
package sandbox
import (
"fmt"
"net"
"os"
"runtime"
"github.com/vishvananda/netlink"
"github.com/vishvananda/netns"
)
func configureInterface(iface netlink.Link, settings *Interface) error {
ifaceName := iface.Attrs().Name
ifaceConfigurators := []struct {
Fn func(netlink.Link, *Interface) error
ErrMessage string
}{
{setInterfaceName, fmt.Sprintf("error renaming interface %q to %q", ifaceName, settings.DstName)},
{setInterfaceIP, fmt.Sprintf("error setting interface %q IP to %q", ifaceName, settings.Address)},
{setInterfaceIPv6, fmt.Sprintf("error setting interface %q IPv6 to %q", ifaceName, settings.AddressIPv6)},
}
for _, config := range ifaceConfigurators {
if err := config.Fn(iface, settings); err != nil {
return fmt.Errorf("%s: %v", config.ErrMessage, err)
}
}
return nil
}
func programGateway(path string, gw net.IP) error {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
origns, err := netns.Get()
if err != nil {
return err
}
defer origns.Close()
f, err := os.OpenFile(path, os.O_RDONLY, 0)
if err != nil {
return fmt.Errorf("failed get network namespace %q: %v", path, err)
}
defer f.Close()
nsFD := f.Fd()
if err = netns.Set(netns.NsHandle(nsFD)); err != nil {
return err
}
defer netns.Set(origns)
gwRoutes, err := netlink.RouteGet(gw)
if err != nil {
return fmt.Errorf("route for the gateway could not be found: %v", err)
}
return netlink.RouteAdd(&netlink.Route{
Scope: netlink.SCOPE_UNIVERSE,
LinkIndex: gwRoutes[0].LinkIndex,
Gw: gw,
})
}
func setInterfaceIP(iface netlink.Link, settings *Interface) error {
ipAddr := &netlink.Addr{IPNet: settings.Address, Label: ""}
return netlink.AddrAdd(iface, ipAddr)
}
func setInterfaceIPv6(iface netlink.Link, settings *Interface) error {
if settings.AddressIPv6 == nil {
return nil
}
ipAddr := &netlink.Addr{IPNet: settings.AddressIPv6, Label: ""}
return netlink.AddrAdd(iface, ipAddr)
}
func setInterfaceName(iface netlink.Link, settings *Interface) error {
return netlink.LinkSetName(iface, settings.DstName)
}

View file

@ -0,0 +1,313 @@
package sandbox
import (
"fmt"
"net"
"sync"
"github.com/docker/libnetwork/types"
"github.com/vishvananda/netlink"
)
// IfaceOption is a function option type to set interface options
type IfaceOption func(i *nwIface)
type nwIface struct {
srcName string
dstName string
master string
dstMaster string
address *net.IPNet
addressIPv6 *net.IPNet
routes []*net.IPNet
bridge bool
ns *networkNamespace
sync.Mutex
}
func (i *nwIface) SrcName() string {
i.Lock()
defer i.Unlock()
return i.srcName
}
func (i *nwIface) DstName() string {
i.Lock()
defer i.Unlock()
return i.dstName
}
func (i *nwIface) DstMaster() string {
i.Lock()
defer i.Unlock()
return i.dstMaster
}
func (i *nwIface) Bridge() bool {
i.Lock()
defer i.Unlock()
return i.bridge
}
func (i *nwIface) Master() string {
i.Lock()
defer i.Unlock()
return i.master
}
func (i *nwIface) Address() *net.IPNet {
i.Lock()
defer i.Unlock()
return types.GetIPNetCopy(i.address)
}
func (i *nwIface) AddressIPv6() *net.IPNet {
i.Lock()
defer i.Unlock()
return types.GetIPNetCopy(i.addressIPv6)
}
func (i *nwIface) Routes() []*net.IPNet {
i.Lock()
defer i.Unlock()
routes := make([]*net.IPNet, len(i.routes))
for index, route := range i.routes {
r := types.GetIPNetCopy(route)
routes[index] = r
}
return routes
}
func (n *networkNamespace) Interfaces() []Interface {
n.Lock()
defer n.Unlock()
ifaces := make([]Interface, len(n.iFaces))
for i, iface := range n.iFaces {
ifaces[i] = iface
}
return ifaces
}
func (i *nwIface) Remove() error {
i.Lock()
n := i.ns
i.Unlock()
n.Lock()
path := n.path
n.Unlock()
return nsInvoke(path, func(nsFD int) error { return nil }, func(callerFD int) error {
// Find the network inteerface identified by the DstName attribute.
iface, err := netlink.LinkByName(i.DstName())
if err != nil {
return err
}
// Down the interface before configuring
if err := netlink.LinkSetDown(iface); err != nil {
return err
}
err = netlink.LinkSetName(iface, i.SrcName())
if err != nil {
fmt.Println("LinkSetName failed: ", err)
return err
}
// if it is a bridge just delete it.
if i.Bridge() {
if err := netlink.LinkDel(iface); err != nil {
return fmt.Errorf("failed deleting bridge %q: %v", i.SrcName(), err)
}
} else {
// Move the network interface to caller namespace.
if err := netlink.LinkSetNsFd(iface, callerFD); err != nil {
fmt.Println("LinkSetNsPid failed: ", err)
return err
}
}
n.Lock()
for index, intf := range n.iFaces {
if intf == i {
n.iFaces = append(n.iFaces[:index], n.iFaces[index+1:]...)
break
}
}
n.Unlock()
return nil
})
}
func (n *networkNamespace) findDstMaster(srcName string) string {
n.Lock()
defer n.Unlock()
for _, i := range n.iFaces {
// The master should match the srcname of the interface and the
// master interface should be of type bridge.
if i.SrcName() == srcName && i.Bridge() {
return i.DstName()
}
}
return ""
}
func (n *networkNamespace) AddInterface(srcName, dstPrefix string, options ...IfaceOption) error {
i := &nwIface{srcName: srcName, dstName: dstPrefix, ns: n}
i.processInterfaceOptions(options...)
if i.master != "" {
i.dstMaster = n.findDstMaster(i.master)
if i.dstMaster == "" {
return fmt.Errorf("could not find an appropriate master %q for %q",
i.master, i.srcName)
}
}
n.Lock()
i.dstName = fmt.Sprintf("%s%d", i.dstName, n.nextIfIndex)
n.nextIfIndex++
path := n.path
n.Unlock()
return nsInvoke(path, func(nsFD int) error {
// If it is a bridge interface we have to create the bridge inside
// the namespace so don't try to lookup the interface using srcName
if i.bridge {
return nil
}
// Find the network interface identified by the SrcName attribute.
iface, err := netlink.LinkByName(i.srcName)
if err != nil {
return fmt.Errorf("failed to get link by name %q: %v", i.srcName, err)
}
// Move the network interface to the destination namespace.
if err := netlink.LinkSetNsFd(iface, nsFD); err != nil {
return fmt.Errorf("failed to set namespace on link %q: %v", i.srcName, err)
}
return nil
}, func(callerFD int) error {
if i.bridge {
link := &netlink.Bridge{
LinkAttrs: netlink.LinkAttrs{
Name: i.srcName,
},
}
if err := netlink.LinkAdd(link); err != nil {
return fmt.Errorf("failed to create bridge %q: %v", i.srcName, err)
}
}
// Find the network interface identified by the SrcName attribute.
iface, err := netlink.LinkByName(i.srcName)
if err != nil {
return fmt.Errorf("failed to get link by name %q: %v", i.srcName, err)
}
// Down the interface before configuring
if err := netlink.LinkSetDown(iface); err != nil {
return fmt.Errorf("failed to set link down: %v", err)
}
// Configure the interface now this is moved in the proper namespace.
if err := configureInterface(iface, i); err != nil {
return err
}
// Up the interface.
if err := netlink.LinkSetUp(iface); err != nil {
return fmt.Errorf("failed to set link up: %v", err)
}
n.Lock()
n.iFaces = append(n.iFaces, i)
n.Unlock()
return nil
})
}
func configureInterface(iface netlink.Link, i *nwIface) error {
ifaceName := iface.Attrs().Name
ifaceConfigurators := []struct {
Fn func(netlink.Link, *nwIface) error
ErrMessage string
}{
{setInterfaceName, fmt.Sprintf("error renaming interface %q to %q", ifaceName, i.DstName())},
{setInterfaceIP, fmt.Sprintf("error setting interface %q IP to %q", ifaceName, i.Address())},
{setInterfaceIPv6, fmt.Sprintf("error setting interface %q IPv6 to %q", ifaceName, i.AddressIPv6())},
{setInterfaceRoutes, fmt.Sprintf("error setting interface %q routes to %q", ifaceName, i.Routes())},
{setInterfaceMaster, fmt.Sprintf("error setting interface %q master to %q", ifaceName, i.DstMaster())},
}
for _, config := range ifaceConfigurators {
if err := config.Fn(iface, i); err != nil {
return fmt.Errorf("%s: %v", config.ErrMessage, err)
}
}
return nil
}
func setInterfaceMaster(iface netlink.Link, i *nwIface) error {
if i.DstMaster() == "" {
return nil
}
return netlink.LinkSetMaster(iface, &netlink.Bridge{
LinkAttrs: netlink.LinkAttrs{Name: i.DstMaster()}})
}
func setInterfaceIP(iface netlink.Link, i *nwIface) error {
if i.Address() == nil {
return nil
}
ipAddr := &netlink.Addr{IPNet: i.Address(), Label: ""}
return netlink.AddrAdd(iface, ipAddr)
}
func setInterfaceIPv6(iface netlink.Link, i *nwIface) error {
if i.AddressIPv6() == nil {
return nil
}
ipAddr := &netlink.Addr{IPNet: i.AddressIPv6(), Label: ""}
return netlink.AddrAdd(iface, ipAddr)
}
func setInterfaceName(iface netlink.Link, i *nwIface) error {
return netlink.LinkSetName(iface, i.DstName())
}
func setInterfaceRoutes(iface netlink.Link, i *nwIface) error {
for _, route := range i.Routes() {
err := netlink.RouteAdd(&netlink.Route{
Scope: netlink.SCOPE_LINK,
LinkIndex: iface.Attrs().Index,
Dst: route,
})
if err != nil {
return err
}
}
return nil
}

View file

@ -0,0 +1,4 @@
package sandbox
// IfaceOption is a function option type to set interface options
type IfaceOption func()

View file

@ -12,6 +12,7 @@ import (
log "github.com/Sirupsen/logrus"
"github.com/docker/docker/pkg/reexec"
"github.com/docker/libnetwork/types"
"github.com/vishvananda/netlink"
"github.com/vishvananda/netns"
)
@ -31,9 +32,12 @@ var (
// interface. It represents a linux network namespace, and moves an interface
// into it when called on method AddInterface or sets the gateway etc.
type networkNamespace struct {
path string
sinfo *Info
nextIfIndex int
path string
iFaces []*nwIface
gw net.IP
gwv6 net.IP
staticRoutes []*types.StaticRoute
nextIfIndex int
sync.Mutex
}
@ -134,12 +138,16 @@ func GenerateKey(containerID string) string {
// NewSandbox provides a new sandbox instance created in an os specific way
// provided a key which uniquely identifies the sandbox
func NewSandbox(key string, osCreate bool) (Sandbox, error) {
info, err := createNetworkNamespace(key, osCreate)
err := createNetworkNamespace(key, osCreate)
if err != nil {
return nil, err
}
return &networkNamespace{path: key, sinfo: info}, nil
return &networkNamespace{path: key}, nil
}
func (n *networkNamespace) InterfaceOptions() IfaceOptionSetter {
return n
}
func reexecCreateNamespace() {
@ -156,18 +164,18 @@ func reexecCreateNamespace() {
}
}
func createNetworkNamespace(path string, osCreate bool) (*Info, error) {
func createNetworkNamespace(path string, osCreate bool) error {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
origns, err := netns.Get()
if err != nil {
return nil, err
return err
}
defer origns.Close()
if err := createNamespaceFile(path); err != nil {
return nil, err
return err
}
cmd := &exec.Cmd{
@ -181,12 +189,10 @@ func createNetworkNamespace(path string, osCreate bool) (*Info, error) {
cmd.SysProcAttr.Cloneflags = syscall.CLONE_NEWNET
}
if err := cmd.Run(); err != nil {
return nil, fmt.Errorf("namespace creation reexec command failed: %v", err)
return fmt.Errorf("namespace creation reexec command failed: %v", err)
}
interfaces := []*Interface{}
info := &Info{Interfaces: interfaces}
return info, nil
return nil
}
func unmountNamespaceFile(path string) {
@ -224,7 +230,7 @@ func loopbackUp() error {
return netlink.LinkSetUp(iface)
}
func (n *networkNamespace) RemoveInterface(i *Interface) error {
func nsInvoke(path string, prefunc func(nsFD int) error, postfunc func(callerFD int) error) error {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
@ -234,75 +240,18 @@ func (n *networkNamespace) RemoveInterface(i *Interface) error {
}
defer origns.Close()
f, err := os.OpenFile(n.path, os.O_RDONLY, 0)
f, err := os.OpenFile(path, os.O_RDONLY, 0)
if err != nil {
return fmt.Errorf("failed get network namespace %q: %v", n.path, err)
return fmt.Errorf("failed get network namespace %q: %v", path, err)
}
defer f.Close()
nsFD := f.Fd()
if err = netns.Set(netns.NsHandle(nsFD)); err != nil {
return err
}
defer netns.Set(origns)
// Find the network inteerface identified by the DstName attribute.
iface, err := netlink.LinkByName(i.DstName)
if err != nil {
return err
}
// Down the interface before configuring
if err := netlink.LinkSetDown(iface); err != nil {
return err
}
err = netlink.LinkSetName(iface, i.SrcName)
if err != nil {
fmt.Println("LinkSetName failed: ", err)
return err
}
// Move the network interface to caller namespace.
if err := netlink.LinkSetNsFd(iface, int(origns)); err != nil {
fmt.Println("LinkSetNsPid failed: ", err)
return err
}
return nil
}
func (n *networkNamespace) AddInterface(i *Interface) error {
n.Lock()
i.DstName = fmt.Sprintf("%s%d", i.DstName, n.nextIfIndex)
n.nextIfIndex++
n.Unlock()
runtime.LockOSThread()
defer runtime.UnlockOSThread()
origns, err := netns.Get()
if err != nil {
return err
}
defer origns.Close()
f, err := os.OpenFile(n.path, os.O_RDONLY, 0)
if err != nil {
return fmt.Errorf("failed get network namespace %q: %v", n.path, err)
}
defer f.Close()
// Find the network interface identified by the SrcName attribute.
iface, err := netlink.LinkByName(i.SrcName)
if err != nil {
return err
}
// Move the network interface to the destination namespace.
nsFD := f.Fd()
if err := netlink.LinkSetNsFd(iface, int(nsFD)); err != nil {
return err
// Invoked before the namespace switch happens but after the namespace file
// handle is obtained.
if err := prefunc(int(nsFD)); err != nil {
return fmt.Errorf("failed in prefunc: %v", err)
}
if err = netns.Set(netns.NsHandle(nsFD)); err != nil {
@ -310,56 +259,19 @@ func (n *networkNamespace) AddInterface(i *Interface) error {
}
defer netns.Set(origns)
// Down the interface before configuring
if err := netlink.LinkSetDown(iface); err != nil {
return err
}
// Configure the interface now this is moved in the proper namespace.
if err := configureInterface(iface, i); err != nil {
return err
}
// Up the interface.
if err := netlink.LinkSetUp(iface); err != nil {
return err
}
// Invoked after the namespace switch.
return postfunc(int(origns))
}
func (n *networkNamespace) nsPath() string {
n.Lock()
n.sinfo.Interfaces = append(n.sinfo.Interfaces, i)
n.Unlock()
defer n.Unlock()
return nil
return n.path
}
func (n *networkNamespace) SetGateway(gw net.IP) error {
if len(gw) == 0 {
return nil
}
err := programGateway(n.path, gw)
if err == nil {
n.sinfo.Gateway = gw
}
return err
}
func (n *networkNamespace) SetGatewayIPv6(gw net.IP) error {
if len(gw) == 0 {
return nil
}
err := programGateway(n.path, gw)
if err == nil {
n.sinfo.GatewayIPv6 = gw
}
return err
}
func (n *networkNamespace) Interfaces() []*Interface {
return n.sinfo.Interfaces
func (n *networkNamespace) Info() Info {
return n
}
func (n *networkNamespace) Key() string {

View file

@ -0,0 +1,8 @@
// +build !linux,!windows
package sandbox
// GC triggers garbage collection of namespace path right away
// and waits for it.
func GC() {
}

View file

@ -0,0 +1,23 @@
package sandbox
// GenerateKey generates a sandbox key based on the passed
// container id.
func GenerateKey(containerID string) string {
maxLen := 12
if len(containerID) < maxLen {
maxLen = len(containerID)
}
return containerID[:maxLen]
}
// NewSandbox provides a new sandbox instance created in an os specific way
// provided a key which uniquely identifies the sandbox
func NewSandbox(key string, osCreate bool) (Sandbox, error) {
return nil, nil
}
// GC triggers garbage collection of namespace path right away
// and waits for it.
func GC() {
}

View file

@ -0,0 +1,41 @@
package sandbox
import "net"
func (i *nwIface) processInterfaceOptions(options ...IfaceOption) {
for _, opt := range options {
if opt != nil {
opt(i)
}
}
}
func (n *networkNamespace) Bridge(isBridge bool) IfaceOption {
return func(i *nwIface) {
i.bridge = isBridge
}
}
func (n *networkNamespace) Master(name string) IfaceOption {
return func(i *nwIface) {
i.master = name
}
}
func (n *networkNamespace) Address(addr *net.IPNet) IfaceOption {
return func(i *nwIface) {
i.address = addr
}
}
func (n *networkNamespace) AddressIPv6(addr *net.IPNet) IfaceOption {
return func(i *nwIface) {
i.addressIPv6 = addr
}
}
func (n *networkNamespace) Routes(routes []*net.IPNet) IfaceOption {
return func(i *nwIface) {
i.routes = routes
}
}

View file

@ -0,0 +1,198 @@
package sandbox
import (
"fmt"
"net"
"github.com/docker/libnetwork/types"
"github.com/vishvananda/netlink"
)
func (n *networkNamespace) Gateway() net.IP {
n.Lock()
defer n.Unlock()
return n.gw
}
func (n *networkNamespace) GatewayIPv6() net.IP {
n.Lock()
defer n.Unlock()
return n.gwv6
}
func (n *networkNamespace) StaticRoutes() []*types.StaticRoute {
n.Lock()
defer n.Unlock()
routes := make([]*types.StaticRoute, len(n.staticRoutes))
for i, route := range n.staticRoutes {
r := route.GetCopy()
routes[i] = r
}
return routes
}
func (n *networkNamespace) setGateway(gw net.IP) {
n.Lock()
n.gw = gw
n.Unlock()
}
func (n *networkNamespace) setGatewayIPv6(gwv6 net.IP) {
n.Lock()
n.gwv6 = gwv6
n.Unlock()
}
func (n *networkNamespace) SetGateway(gw net.IP) error {
// Silently return if the gateway is empty
if len(gw) == 0 {
return nil
}
err := programGateway(n.nsPath(), gw, true)
if err == nil {
n.setGateway(gw)
}
return err
}
func (n *networkNamespace) UnsetGateway() error {
gw := n.Gateway()
// Silently return if the gateway is empty
if len(gw) == 0 {
return nil
}
err := programGateway(n.nsPath(), gw, false)
if err == nil {
n.setGateway(net.IP{})
}
return err
}
func programGateway(path string, gw net.IP, isAdd bool) error {
return nsInvoke(path, func(nsFD int) error { return nil }, func(callerFD int) error {
gwRoutes, err := netlink.RouteGet(gw)
if err != nil {
return fmt.Errorf("route for the gateway could not be found: %v", err)
}
if isAdd {
return netlink.RouteAdd(&netlink.Route{
Scope: netlink.SCOPE_UNIVERSE,
LinkIndex: gwRoutes[0].LinkIndex,
Gw: gw,
})
}
return netlink.RouteDel(&netlink.Route{
Scope: netlink.SCOPE_UNIVERSE,
LinkIndex: gwRoutes[0].LinkIndex,
Gw: gw,
})
})
}
// Program a route in to the namespace routing table.
func programRoute(path string, dest *net.IPNet, nh net.IP) error {
return nsInvoke(path, func(nsFD int) error { return nil }, func(callerFD int) error {
gwRoutes, err := netlink.RouteGet(nh)
if err != nil {
return fmt.Errorf("route for the next hop could not be found: %v", err)
}
return netlink.RouteAdd(&netlink.Route{
Scope: netlink.SCOPE_UNIVERSE,
LinkIndex: gwRoutes[0].LinkIndex,
Gw: gwRoutes[0].Gw,
Dst: dest,
})
})
}
// Delete a route from the namespace routing table.
func removeRoute(path string, dest *net.IPNet, nh net.IP) error {
return nsInvoke(path, func(nsFD int) error { return nil }, func(callerFD int) error {
gwRoutes, err := netlink.RouteGet(nh)
if err != nil {
return fmt.Errorf("route for the next hop could not be found: %v", err)
}
return netlink.RouteDel(&netlink.Route{
Scope: netlink.SCOPE_UNIVERSE,
LinkIndex: gwRoutes[0].LinkIndex,
Gw: gwRoutes[0].Gw,
Dst: dest,
})
})
}
func (n *networkNamespace) SetGatewayIPv6(gwv6 net.IP) error {
// Silently return if the gateway is empty
if len(gwv6) == 0 {
return nil
}
err := programGateway(n.nsPath(), gwv6, true)
if err == nil {
n.SetGatewayIPv6(gwv6)
}
return err
}
func (n *networkNamespace) UnsetGatewayIPv6() error {
gwv6 := n.GatewayIPv6()
// Silently return if the gateway is empty
if len(gwv6) == 0 {
return nil
}
err := programGateway(n.nsPath(), gwv6, false)
if err == nil {
n.Lock()
n.gwv6 = net.IP{}
n.Unlock()
}
return err
}
func (n *networkNamespace) AddStaticRoute(r *types.StaticRoute) error {
err := programRoute(n.nsPath(), r.Destination, r.NextHop)
if err == nil {
n.Lock()
n.staticRoutes = append(n.staticRoutes, r)
n.Unlock()
}
return err
}
func (n *networkNamespace) RemoveStaticRoute(r *types.StaticRoute) error {
n.Lock()
err := removeRoute(n.nsPath(), r.Destination, r.NextHop)
if err == nil {
n.Lock()
lastIndex := len(n.staticRoutes) - 1
for i, v := range n.staticRoutes {
if v == r {
// Overwrite the route we're removing with the last element
n.staticRoutes[i] = n.staticRoutes[lastIndex]
// Shorten the slice to trim the extra element
n.staticRoutes = n.staticRoutes[:lastIndex]
break
}
}
n.Unlock()
}
return err
}

View file

@ -12,22 +12,12 @@ type Sandbox interface {
// The path where the network namespace is mounted.
Key() string
// The collection of Interface previously added with the AddInterface
// method. Note that this doesn't incude network interfaces added in any
// other way (such as the default loopback interface which are automatically
// created on creation of a sandbox).
Interfaces() []*Interface
// Add an existing Interface to this sandbox. The operation will rename
// from the Interface SrcName to DstName as it moves, and reconfigure the
// interface according to the specified settings. The caller is expected
// to only provide a prefix for DstName. The AddInterface api will auto-generate
// an appropriate suffix for the DstName to disambiguate.
AddInterface(*Interface) error
// Remove an interface from the sandbox by renamin to original name
// and moving it out of the sandbox.
RemoveInterface(*Interface) error
AddInterface(SrcName string, DstPrefix string, options ...IfaceOption) error
// Set default IPv4 gateway for the sandbox
SetGateway(gw net.IP) error
@ -35,23 +25,69 @@ type Sandbox interface {
// Set default IPv6 gateway for the sandbox
SetGatewayIPv6(gw net.IP) error
// Unset the previously set default IPv4 gateway in the sandbox
UnsetGateway() error
// Unset the previously set default IPv6 gateway in the sandbox
UnsetGatewayIPv6() error
// Add a static route to the sandbox.
AddStaticRoute(*types.StaticRoute) error
// Remove a static route from the sandbox.
RemoveStaticRoute(*types.StaticRoute) error
// Returns an interface with methods to set interface options.
InterfaceOptions() IfaceOptionSetter
// Returns an interface with methods to get sandbox state.
Info() Info
// Destroy the sandbox
Destroy() error
}
// IfaceOptionSetter interface defines the option setter methods for interface options.
type IfaceOptionSetter interface {
// Bridge returns an option setter to set if the interface is a bridge.
Bridge(bool) IfaceOption
// Address returns an option setter to set IPv4 address.
Address(*net.IPNet) IfaceOption
// Address returns an option setter to set IPv6 address.
AddressIPv6(*net.IPNet) IfaceOption
// Master returns an option setter to set the master interface if any for this
// interface. The master interface name should refer to the srcname of a
// previously added interface of type bridge.
Master(string) IfaceOption
// Address returns an option setter to set interface routes.
Routes([]*net.IPNet) IfaceOption
}
// Info represents all possible information that
// the driver wants to place in the sandbox which includes
// interfaces, routes and gateway
type Info struct {
Interfaces []*Interface
type Info interface {
// The collection of Interface previously added with the AddInterface
// method. Note that this doesn't incude network interfaces added in any
// other way (such as the default loopback interface which are automatically
// created on creation of a sandbox).
Interfaces() []Interface
// IPv4 gateway for the sandbox.
Gateway net.IP
Gateway() net.IP
// IPv6 gateway for the sandbox.
GatewayIPv6 net.IP
GatewayIPv6() net.IP
// TODO: Add routes and ip tables etc.
// Additional static routes for the sandbox. (Note that directly
// connected routes are stored on the particular interface they refer to.)
StaticRoutes() []*types.StaticRoute
// TODO: Add ip tables etc.
}
// Interface represents the settings and identity of a network device. It is
@ -59,101 +95,32 @@ type Info struct {
// caller to use this information when moving interface SrcName from host
// namespace to DstName in a different net namespace with the appropriate
// network settings.
type Interface struct {
type Interface interface {
// The name of the interface in the origin network namespace.
SrcName string
SrcName() string
// The name that will be assigned to the interface once moves inside a
// network namespace. When the caller passes in a DstName, it is only
// expected to pass a prefix. The name will modified with an appropriately
// auto-generated suffix.
DstName string
DstName() string
// IPv4 address for the interface.
Address *net.IPNet
Address() *net.IPNet
// IPv6 address for the interface.
AddressIPv6 *net.IPNet
}
// GetCopy returns a copy of this Interface structure
func (i *Interface) GetCopy() *Interface {
return &Interface{
SrcName: i.SrcName,
DstName: i.DstName,
Address: types.GetIPNetCopy(i.Address),
AddressIPv6: types.GetIPNetCopy(i.AddressIPv6),
}
}
// Equal checks if this instance of Interface is equal to the passed one
func (i *Interface) Equal(o *Interface) bool {
if i == o {
return true
}
if o == nil {
return false
}
if i.SrcName != o.SrcName || i.DstName != o.DstName {
return false
}
if !types.CompareIPNet(i.Address, o.Address) {
return false
}
if !types.CompareIPNet(i.AddressIPv6, o.AddressIPv6) {
return false
}
return true
}
// GetCopy returns a copy of this SandboxInfo structure
func (s *Info) GetCopy() *Info {
list := make([]*Interface, len(s.Interfaces))
for i, iface := range s.Interfaces {
list[i] = iface.GetCopy()
}
gw := types.GetIPCopy(s.Gateway)
gw6 := types.GetIPCopy(s.GatewayIPv6)
return &Info{Interfaces: list, Gateway: gw, GatewayIPv6: gw6}
}
// Equal checks if this instance of SandboxInfo is equal to the passed one
func (s *Info) Equal(o *Info) bool {
if s == o {
return true
}
if o == nil {
return false
}
if !s.Gateway.Equal(o.Gateway) {
return false
}
if !s.GatewayIPv6.Equal(o.GatewayIPv6) {
return false
}
if (s.Interfaces == nil && o.Interfaces != nil) ||
(s.Interfaces != nil && o.Interfaces == nil) ||
(len(s.Interfaces) != len(o.Interfaces)) {
return false
}
// Note: At the moment, the two lists must be in the same order
for i := 0; i < len(s.Interfaces); i++ {
if !s.Interfaces[i].Equal(o.Interfaces[i]) {
return false
}
}
return true
AddressIPv6() *net.IPNet
// IP routes for the interface.
Routes() []*net.IPNet
// Bridge returns true if the interface is a bridge
Bridge() bool
// Master returns the srcname of the master interface for this interface.
Master() string
// Remove an interface from the sandbox by renaming to original name
// and moving it out of the sandbox.
Remove() error
}

View file

@ -1,15 +1,22 @@
// +build !linux
// +build !linux,!windows
package sandbox
import "errors"
var (
// ErrNotImplemented is for platforms which don't implement sandbox
ErrNotImplemented = errors.New("not implemented")
)
// NewSandbox provides a new sandbox instance created in an os specific way
// provided a key which uniquely identifies the sandbox
func NewSandbox(key string) (Sandbox, error) {
func NewSandbox(key string, osCreate bool) (Sandbox, error) {
return nil, ErrNotImplemented
}
// GenerateKey generates a sandbox key based on the passed
// container id.
func GenerateKey(containerID string) string {
return ""
}

View file

@ -0,0 +1,245 @@
package libnetwork
import (
"container/heap"
"sync"
"github.com/Sirupsen/logrus"
"github.com/docker/libnetwork/sandbox"
)
type epHeap []*endpoint
type sandboxData struct {
sbox sandbox.Sandbox
refCnt int
endpoints epHeap
sync.Mutex
}
func (eh epHeap) Len() int { return len(eh) }
func (eh epHeap) Less(i, j int) bool {
eh[i].Lock()
eh[j].Lock()
defer eh[j].Unlock()
defer eh[i].Unlock()
if eh[i].container.config.prio == eh[j].container.config.prio {
return eh[i].network.Name() < eh[j].network.Name()
}
return eh[i].container.config.prio > eh[j].container.config.prio
}
func (eh epHeap) Swap(i, j int) { eh[i], eh[j] = eh[j], eh[i] }
func (eh *epHeap) Push(x interface{}) {
*eh = append(*eh, x.(*endpoint))
}
func (eh *epHeap) Pop() interface{} {
old := *eh
n := len(old)
x := old[n-1]
*eh = old[0 : n-1]
return x
}
func (s *sandboxData) updateGateway(ep *endpoint) error {
sb := s.sandbox()
if err := sb.UnsetGateway(); err != nil {
return err
}
if err := sb.UnsetGatewayIPv6(); err != nil {
return err
}
if ep == nil {
return nil
}
ep.Lock()
joinInfo := ep.joinInfo
ep.Unlock()
if err := sb.SetGateway(joinInfo.gw); err != nil {
return err
}
if err := sb.SetGatewayIPv6(joinInfo.gw6); err != nil {
return err
}
return nil
}
func (s *sandboxData) addEndpoint(ep *endpoint) error {
ep.Lock()
joinInfo := ep.joinInfo
ifaces := ep.iFaces
ep.Unlock()
sb := s.sandbox()
for _, i := range ifaces {
var ifaceOptions []sandbox.IfaceOption
ifaceOptions = append(ifaceOptions, sb.InterfaceOptions().Address(&i.addr),
sb.InterfaceOptions().Routes(i.routes))
if i.addrv6.IP.To16() != nil {
ifaceOptions = append(ifaceOptions,
sb.InterfaceOptions().AddressIPv6(&i.addrv6))
}
if err := sb.AddInterface(i.srcName, i.dstPrefix, ifaceOptions...); err != nil {
return err
}
}
if joinInfo != nil {
// Set up non-interface routes.
for _, r := range ep.joinInfo.StaticRoutes {
if err := sb.AddStaticRoute(r); err != nil {
return err
}
}
}
s.Lock()
heap.Push(&s.endpoints, ep)
highEp := s.endpoints[0]
s.Unlock()
if ep == highEp {
if err := s.updateGateway(ep); err != nil {
return err
}
}
s.Lock()
s.refCnt++
s.Unlock()
return nil
}
func (s *sandboxData) rmEndpoint(ep *endpoint) int {
ep.Lock()
joinInfo := ep.joinInfo
ep.Unlock()
sb := s.sandbox()
for _, i := range sb.Info().Interfaces() {
// Only remove the interfaces owned by this endpoint from the sandbox.
if ep.hasInterface(i.SrcName()) {
if err := i.Remove(); err != nil {
logrus.Debugf("Remove interface failed: %v", err)
}
}
}
// Remove non-interface routes.
for _, r := range joinInfo.StaticRoutes {
if err := sb.RemoveStaticRoute(r); err != nil {
logrus.Debugf("Remove route failed: %v", err)
}
}
// We don't check if s.endpoints is empty here because
// it should never be empty during a rmEndpoint call and
// if it is we will rightfully panic here
s.Lock()
highEpBefore := s.endpoints[0]
var (
i int
e *endpoint
)
for i, e = range s.endpoints {
if e == ep {
break
}
}
heap.Remove(&s.endpoints, i)
var highEpAfter *endpoint
if len(s.endpoints) > 0 {
highEpAfter = s.endpoints[0]
}
s.Unlock()
if highEpBefore != highEpAfter {
s.updateGateway(highEpAfter)
}
s.Lock()
s.refCnt--
refCnt := s.refCnt
s.Unlock()
if refCnt == 0 {
s.sandbox().Destroy()
}
return refCnt
}
func (s *sandboxData) sandbox() sandbox.Sandbox {
s.Lock()
defer s.Unlock()
return s.sbox
}
func (c *controller) sandboxAdd(key string, create bool, ep *endpoint) (sandbox.Sandbox, error) {
c.Lock()
sData, ok := c.sandboxes[key]
c.Unlock()
if !ok {
sb, err := sandbox.NewSandbox(key, create)
if err != nil {
return nil, err
}
sData = &sandboxData{
sbox: sb,
endpoints: epHeap{},
}
heap.Init(&sData.endpoints)
c.Lock()
c.sandboxes[key] = sData
c.Unlock()
}
if err := sData.addEndpoint(ep); err != nil {
return nil, err
}
return sData.sandbox(), nil
}
func (c *controller) sandboxRm(key string, ep *endpoint) {
c.Lock()
sData := c.sandboxes[key]
c.Unlock()
if sData.rmEndpoint(ep) == 0 {
c.Lock()
delete(c.sandboxes, key)
c.Unlock()
}
}
func (c *controller) sandboxGet(key string) sandbox.Sandbox {
c.Lock()
sData, ok := c.sandboxes[key]
c.Unlock()
if !ok {
return nil
}
return sData.sandbox()
}

View file

@ -0,0 +1,299 @@
package libnetwork
import (
"encoding/json"
"fmt"
log "github.com/Sirupsen/logrus"
"github.com/docker/libnetwork/datastore"
"github.com/docker/libnetwork/types"
)
func (c *controller) validateDatastoreConfig() bool {
if c.cfg == nil || c.cfg.Datastore.Client.Provider == "" || c.cfg.Datastore.Client.Address == "" {
return false
}
return true
}
func (c *controller) initDataStore() error {
c.Lock()
cfg := c.cfg
c.Unlock()
if !c.validateDatastoreConfig() {
return fmt.Errorf("datastore initialization requires a valid configuration")
}
store, err := datastore.NewDataStore(&cfg.Datastore)
if err != nil {
return err
}
c.Lock()
c.store = store
c.Unlock()
return c.watchStore()
}
func (c *controller) newNetworkFromStore(n *network) error {
n.Lock()
n.ctrlr = c
n.endpoints = endpointTable{}
n.Unlock()
return c.addNetwork(n)
}
func (c *controller) updateNetworkToStore(n *network) error {
global, err := n.isGlobalScoped()
if err != nil || !global {
return err
}
c.Lock()
cs := c.store
c.Unlock()
if cs == nil {
log.Debugf("datastore not initialized. Network %s is not added to the store", n.Name())
return nil
}
return cs.PutObjectAtomic(n)
}
func (c *controller) deleteNetworkFromStore(n *network) error {
global, err := n.isGlobalScoped()
if err != nil || !global {
return err
}
c.Lock()
cs := c.store
c.Unlock()
if cs == nil {
log.Debugf("datastore not initialized. Network %s is not deleted from datastore", n.Name())
return nil
}
if err := cs.DeleteObjectAtomic(n); err != nil {
return err
}
return nil
}
func (c *controller) getNetworkFromStore(nid types.UUID) (*network, error) {
n := network{id: nid}
if err := c.store.GetObject(datastore.Key(n.Key()...), &n); err != nil {
return nil, err
}
return &n, nil
}
func (c *controller) newEndpointFromStore(key string, ep *endpoint) error {
ep.Lock()
n := ep.network
id := ep.id
ep.Unlock()
if n == nil {
// Possibly the watch event for the network has not shown up yet
// Try to get network from the store
nid, err := networkIDFromEndpointKey(key, ep)
if err != nil {
return err
}
n, err = c.getNetworkFromStore(nid)
if err != nil {
return err
}
if err := c.newNetworkFromStore(n); err != nil {
return err
}
n = c.networks[nid]
}
_, err := n.EndpointByID(string(id))
if err != nil {
if _, ok := err.(ErrNoSuchEndpoint); ok {
return n.addEndpoint(ep)
}
}
return err
}
func (c *controller) updateEndpointToStore(ep *endpoint) error {
ep.Lock()
n := ep.network
name := ep.name
ep.Unlock()
global, err := n.isGlobalScoped()
if err != nil || !global {
return err
}
c.Lock()
cs := c.store
c.Unlock()
if cs == nil {
log.Debugf("datastore not initialized. endpoint %s is not added to the store", name)
return nil
}
return cs.PutObjectAtomic(ep)
}
func (c *controller) getEndpointFromStore(eid types.UUID) (*endpoint, error) {
ep := endpoint{id: eid}
if err := c.store.GetObject(datastore.Key(ep.Key()...), &ep); err != nil {
return nil, err
}
return &ep, nil
}
func (c *controller) deleteEndpointFromStore(ep *endpoint) error {
ep.Lock()
n := ep.network
ep.Unlock()
global, err := n.isGlobalScoped()
if err != nil || !global {
return err
}
c.Lock()
cs := c.store
c.Unlock()
if cs == nil {
log.Debugf("datastore not initialized. endpoint %s is not deleted from datastore", ep.Name())
return nil
}
if err := cs.DeleteObjectAtomic(ep); err != nil {
return err
}
return nil
}
func (c *controller) watchStore() error {
c.Lock()
cs := c.store
c.Unlock()
nwPairs, err := cs.KVStore().WatchTree(datastore.Key(datastore.NetworkKeyPrefix), nil)
if err != nil {
return err
}
epPairs, err := cs.KVStore().WatchTree(datastore.Key(datastore.EndpointKeyPrefix), nil)
if err != nil {
return err
}
go func() {
for {
select {
case nws := <-nwPairs:
for _, kve := range nws {
var n network
err := json.Unmarshal(kve.Value, &n)
if err != nil {
log.Error(err)
continue
}
n.dbIndex = kve.LastIndex
c.Lock()
existing, ok := c.networks[n.id]
c.Unlock()
if ok {
existing.Lock()
// Skip existing network update
if existing.dbIndex != n.dbIndex {
existing.dbIndex = n.dbIndex
existing.endpointCnt = n.endpointCnt
}
existing.Unlock()
continue
}
if err = c.newNetworkFromStore(&n); err != nil {
log.Error(err)
}
}
case eps := <-epPairs:
for _, epe := range eps {
var ep endpoint
err := json.Unmarshal(epe.Value, &ep)
if err != nil {
log.Error(err)
continue
}
ep.dbIndex = epe.LastIndex
n, err := c.networkFromEndpointKey(epe.Key, &ep)
if err != nil {
if _, ok := err.(ErrNoSuchNetwork); !ok {
log.Error(err)
continue
}
}
if n != nil {
ep.network = n.(*network)
}
if c.processEndpointUpdate(&ep) {
err = c.newEndpointFromStore(epe.Key, &ep)
if err != nil {
log.Error(err)
}
}
}
}
}
}()
return nil
}
func (c *controller) networkFromEndpointKey(key string, ep *endpoint) (Network, error) {
nid, err := networkIDFromEndpointKey(key, ep)
if err != nil {
return nil, err
}
return c.NetworkByID(string(nid))
}
func networkIDFromEndpointKey(key string, ep *endpoint) (types.UUID, error) {
eKey, err := datastore.ParseKey(key)
if err != nil {
return types.UUID(""), err
}
return ep.networkIDFromKey(eKey)
}
func (c *controller) processEndpointUpdate(ep *endpoint) bool {
nw := ep.network
if nw == nil {
return true
}
nw.Lock()
id := nw.id
nw.Unlock()
c.Lock()
n, ok := c.networks[id]
c.Unlock()
if !ok {
return true
}
existing, _ := n.EndpointByID(string(ep.id))
if existing == nil {
return true
}
ee := existing.(*endpoint)
ee.Lock()
if ee.dbIndex != ep.dbIndex {
ee.dbIndex = ep.dbIndex
if ee.container != nil && ep.container != nil {
// we care only about the container id
ee.container.id = ep.container.id
} else {
// we still care only about the container id, but this is a short-cut to communicate join or leave operation
ee.container = ep.container
}
}
ee.Unlock()
return false
}

View file

@ -1,34 +0,0 @@
package libnetwork
import (
"fmt"
"runtime"
"syscall"
)
// Via http://git.kernel.org/cgit/linux/kernel/git/torvalds/linux.git/commit/?id=7b21fddd087678a70ad64afc0f632e0f1071b092
//
// We need different setns values for the different platforms and arch
// We are declaring the macro here because the SETNS syscall does not exist in th stdlib
var setNsMap = map[string]uintptr{
"linux/386": 346,
"linux/amd64": 308,
"linux/arm": 374,
"linux/ppc64": 350,
"linux/ppc64le": 350,
"linux/s390x": 339,
}
func setns(fd uintptr, flags uintptr) error {
ns, exists := setNsMap[fmt.Sprintf("%s/%s", runtime.GOOS, runtime.GOARCH)]
if !exists {
return fmt.Errorf("unsupported platform %s/%s", runtime.GOOS, runtime.GOARCH)
}
_, _, err := syscall.RawSyscall(ns, fd, flags, 0)
if err != 0 {
return err
}
return nil
}

View file

@ -22,7 +22,7 @@ func (t *TransportPort) GetCopy() TransportPort {
return TransportPort{Proto: t.Proto, Port: t.Port}
}
// PortBinding represent a port binding between the container an the host
// PortBinding represent a port binding between the container and the host
type PortBinding struct {
Proto Protocol
IP net.IP
@ -184,6 +184,40 @@ func CompareIPNet(a, b *net.IPNet) bool {
return a.IP.Equal(b.IP) && bytes.Equal(a.Mask, b.Mask)
}
const (
// NEXTHOP indicates a StaticRoute with an IP next hop.
NEXTHOP = iota
// CONNECTED indicates a StaticRoute with a interface for directly connected peers.
CONNECTED
)
// StaticRoute is a statically-provisioned IP route.
type StaticRoute struct {
Destination *net.IPNet
RouteType int // NEXT_HOP or CONNECTED
// NextHop will be resolved by the kernel (i.e. as a loose hop).
NextHop net.IP
// InterfaceID must refer to a defined interface on the
// Endpoint to which the routes are specified. Routes specified this way
// are interpreted as directly connected to the specified interface (no
// next hop will be used).
InterfaceID int
}
// GetCopy returns a copy of this StaticRoute structure
func (r *StaticRoute) GetCopy() *StaticRoute {
d := GetIPNetCopy(r.Destination)
nh := GetIPCopy(r.NextHop)
return &StaticRoute{Destination: d,
RouteType: r.RouteType,
NextHop: nh,
InterfaceID: r.InterfaceID}
}
/******************************
* Well-known Error Interfaces
******************************/

View file

@ -0,0 +1,39 @@
Consul API client
=================
This package provides the `api` package which attempts to
provide programmatic access to the full Consul API.
Currently, all of the Consul APIs included in version 0.3 are supported.
Documentation
=============
The full documentation is available on [Godoc](http://godoc.org/github.com/hashicorp/consul/api)
Usage
=====
Below is an example of using the Consul client:
```go
// Get a new client, with KV endpoints
client, _ := api.NewClient(api.DefaultConfig())
kv := client.KV()
// PUT a new KV pair
p := &api.KVPair{Key: "foo", Value: []byte("test")}
_, err := kv.Put(p, nil)
if err != nil {
panic(err)
}
// Lookup the pair
pair, _, err := kv.Get("foo", nil)
if err != nil {
panic(err)
}
fmt.Printf("KV: %v", pair)
```

View file

@ -0,0 +1,140 @@
package api
const (
// ACLCLientType is the client type token
ACLClientType = "client"
// ACLManagementType is the management type token
ACLManagementType = "management"
)
// ACLEntry is used to represent an ACL entry
type ACLEntry struct {
CreateIndex uint64
ModifyIndex uint64
ID string
Name string
Type string
Rules string
}
// ACL can be used to query the ACL endpoints
type ACL struct {
c *Client
}
// ACL returns a handle to the ACL endpoints
func (c *Client) ACL() *ACL {
return &ACL{c}
}
// Create is used to generate a new token with the given parameters
func (a *ACL) Create(acl *ACLEntry, q *WriteOptions) (string, *WriteMeta, error) {
r := a.c.newRequest("PUT", "/v1/acl/create")
r.setWriteOptions(q)
r.obj = acl
rtt, resp, err := requireOK(a.c.doRequest(r))
if err != nil {
return "", nil, err
}
defer resp.Body.Close()
wm := &WriteMeta{RequestTime: rtt}
var out struct{ ID string }
if err := decodeBody(resp, &out); err != nil {
return "", nil, err
}
return out.ID, wm, nil
}
// Update is used to update the rules of an existing token
func (a *ACL) Update(acl *ACLEntry, q *WriteOptions) (*WriteMeta, error) {
r := a.c.newRequest("PUT", "/v1/acl/update")
r.setWriteOptions(q)
r.obj = acl
rtt, resp, err := requireOK(a.c.doRequest(r))
if err != nil {
return nil, err
}
defer resp.Body.Close()
wm := &WriteMeta{RequestTime: rtt}
return wm, nil
}
// Destroy is used to destroy a given ACL token ID
func (a *ACL) Destroy(id string, q *WriteOptions) (*WriteMeta, error) {
r := a.c.newRequest("PUT", "/v1/acl/destroy/"+id)
r.setWriteOptions(q)
rtt, resp, err := requireOK(a.c.doRequest(r))
if err != nil {
return nil, err
}
resp.Body.Close()
wm := &WriteMeta{RequestTime: rtt}
return wm, nil
}
// Clone is used to return a new token cloned from an existing one
func (a *ACL) Clone(id string, q *WriteOptions) (string, *WriteMeta, error) {
r := a.c.newRequest("PUT", "/v1/acl/clone/"+id)
r.setWriteOptions(q)
rtt, resp, err := requireOK(a.c.doRequest(r))
if err != nil {
return "", nil, err
}
defer resp.Body.Close()
wm := &WriteMeta{RequestTime: rtt}
var out struct{ ID string }
if err := decodeBody(resp, &out); err != nil {
return "", nil, err
}
return out.ID, wm, nil
}
// Info is used to query for information about an ACL token
func (a *ACL) Info(id string, q *QueryOptions) (*ACLEntry, *QueryMeta, error) {
r := a.c.newRequest("GET", "/v1/acl/info/"+id)
r.setQueryOptions(q)
rtt, resp, err := requireOK(a.c.doRequest(r))
if err != nil {
return nil, nil, err
}
defer resp.Body.Close()
qm := &QueryMeta{}
parseQueryMeta(resp, qm)
qm.RequestTime = rtt
var entries []*ACLEntry
if err := decodeBody(resp, &entries); err != nil {
return nil, nil, err
}
if len(entries) > 0 {
return entries[0], qm, nil
}
return nil, qm, nil
}
// List is used to get all the ACL tokens
func (a *ACL) List(q *QueryOptions) ([]*ACLEntry, *QueryMeta, error) {
r := a.c.newRequest("GET", "/v1/acl/list")
r.setQueryOptions(q)
rtt, resp, err := requireOK(a.c.doRequest(r))
if err != nil {
return nil, nil, err
}
defer resp.Body.Close()
qm := &QueryMeta{}
parseQueryMeta(resp, qm)
qm.RequestTime = rtt
var entries []*ACLEntry
if err := decodeBody(resp, &entries); err != nil {
return nil, nil, err
}
return entries, qm, nil
}

View file

@ -0,0 +1,334 @@
package api
import (
"fmt"
)
// AgentCheck represents a check known to the agent
type AgentCheck struct {
Node string
CheckID string
Name string
Status string
Notes string
Output string
ServiceID string
ServiceName string
}
// AgentService represents a service known to the agent
type AgentService struct {
ID string
Service string
Tags []string
Port int
Address string
}
// AgentMember represents a cluster member known to the agent
type AgentMember struct {
Name string
Addr string
Port uint16
Tags map[string]string
Status int
ProtocolMin uint8
ProtocolMax uint8
ProtocolCur uint8
DelegateMin uint8
DelegateMax uint8
DelegateCur uint8
}
// AgentServiceRegistration is used to register a new service
type AgentServiceRegistration struct {
ID string `json:",omitempty"`
Name string `json:",omitempty"`
Tags []string `json:",omitempty"`
Port int `json:",omitempty"`
Address string `json:",omitempty"`
Check *AgentServiceCheck
Checks AgentServiceChecks
}
// AgentCheckRegistration is used to register a new check
type AgentCheckRegistration struct {
ID string `json:",omitempty"`
Name string `json:",omitempty"`
Notes string `json:",omitempty"`
ServiceID string `json:",omitempty"`
AgentServiceCheck
}
// AgentServiceCheck is used to create an associated
// check for a service
type AgentServiceCheck struct {
Script string `json:",omitempty"`
Interval string `json:",omitempty"`
Timeout string `json:",omitempty"`
TTL string `json:",omitempty"`
HTTP string `json:",omitempty"`
Status string `json:",omitempty"`
}
type AgentServiceChecks []*AgentServiceCheck
// Agent can be used to query the Agent endpoints
type Agent struct {
c *Client
// cache the node name
nodeName string
}
// Agent returns a handle to the agent endpoints
func (c *Client) Agent() *Agent {
return &Agent{c: c}
}
// Self is used to query the agent we are speaking to for
// information about itself
func (a *Agent) Self() (map[string]map[string]interface{}, error) {
r := a.c.newRequest("GET", "/v1/agent/self")
_, resp, err := requireOK(a.c.doRequest(r))
if err != nil {
return nil, err
}
defer resp.Body.Close()
var out map[string]map[string]interface{}
if err := decodeBody(resp, &out); err != nil {
return nil, err
}
return out, nil
}
// NodeName is used to get the node name of the agent
func (a *Agent) NodeName() (string, error) {
if a.nodeName != "" {
return a.nodeName, nil
}
info, err := a.Self()
if err != nil {
return "", err
}
name := info["Config"]["NodeName"].(string)
a.nodeName = name
return name, nil
}
// Checks returns the locally registered checks
func (a *Agent) Checks() (map[string]*AgentCheck, error) {
r := a.c.newRequest("GET", "/v1/agent/checks")
_, resp, err := requireOK(a.c.doRequest(r))
if err != nil {
return nil, err
}
defer resp.Body.Close()
var out map[string]*AgentCheck
if err := decodeBody(resp, &out); err != nil {
return nil, err
}
return out, nil
}
// Services returns the locally registered services
func (a *Agent) Services() (map[string]*AgentService, error) {
r := a.c.newRequest("GET", "/v1/agent/services")
_, resp, err := requireOK(a.c.doRequest(r))
if err != nil {
return nil, err
}
defer resp.Body.Close()
var out map[string]*AgentService
if err := decodeBody(resp, &out); err != nil {
return nil, err
}
return out, nil
}
// Members returns the known gossip members. The WAN
// flag can be used to query a server for WAN members.
func (a *Agent) Members(wan bool) ([]*AgentMember, error) {
r := a.c.newRequest("GET", "/v1/agent/members")
if wan {
r.params.Set("wan", "1")
}
_, resp, err := requireOK(a.c.doRequest(r))
if err != nil {
return nil, err
}
defer resp.Body.Close()
var out []*AgentMember
if err := decodeBody(resp, &out); err != nil {
return nil, err
}
return out, nil
}
// ServiceRegister is used to register a new service with
// the local agent
func (a *Agent) ServiceRegister(service *AgentServiceRegistration) error {
r := a.c.newRequest("PUT", "/v1/agent/service/register")
r.obj = service
_, resp, err := requireOK(a.c.doRequest(r))
if err != nil {
return err
}
resp.Body.Close()
return nil
}
// ServiceDeregister is used to deregister a service with
// the local agent
func (a *Agent) ServiceDeregister(serviceID string) error {
r := a.c.newRequest("PUT", "/v1/agent/service/deregister/"+serviceID)
_, resp, err := requireOK(a.c.doRequest(r))
if err != nil {
return err
}
resp.Body.Close()
return nil
}
// PassTTL is used to set a TTL check to the passing state
func (a *Agent) PassTTL(checkID, note string) error {
return a.UpdateTTL(checkID, note, "pass")
}
// WarnTTL is used to set a TTL check to the warning state
func (a *Agent) WarnTTL(checkID, note string) error {
return a.UpdateTTL(checkID, note, "warn")
}
// FailTTL is used to set a TTL check to the failing state
func (a *Agent) FailTTL(checkID, note string) error {
return a.UpdateTTL(checkID, note, "fail")
}
// UpdateTTL is used to update the TTL of a check
func (a *Agent) UpdateTTL(checkID, note, status string) error {
switch status {
case "pass":
case "warn":
case "fail":
default:
return fmt.Errorf("Invalid status: %s", status)
}
endpoint := fmt.Sprintf("/v1/agent/check/%s/%s", status, checkID)
r := a.c.newRequest("PUT", endpoint)
r.params.Set("note", note)
_, resp, err := requireOK(a.c.doRequest(r))
if err != nil {
return err
}
resp.Body.Close()
return nil
}
// CheckRegister is used to register a new check with
// the local agent
func (a *Agent) CheckRegister(check *AgentCheckRegistration) error {
r := a.c.newRequest("PUT", "/v1/agent/check/register")
r.obj = check
_, resp, err := requireOK(a.c.doRequest(r))
if err != nil {
return err
}
resp.Body.Close()
return nil
}
// CheckDeregister is used to deregister a check with
// the local agent
func (a *Agent) CheckDeregister(checkID string) error {
r := a.c.newRequest("PUT", "/v1/agent/check/deregister/"+checkID)
_, resp, err := requireOK(a.c.doRequest(r))
if err != nil {
return err
}
resp.Body.Close()
return nil
}
// Join is used to instruct the agent to attempt a join to
// another cluster member
func (a *Agent) Join(addr string, wan bool) error {
r := a.c.newRequest("PUT", "/v1/agent/join/"+addr)
if wan {
r.params.Set("wan", "1")
}
_, resp, err := requireOK(a.c.doRequest(r))
if err != nil {
return err
}
resp.Body.Close()
return nil
}
// ForceLeave is used to have the agent eject a failed node
func (a *Agent) ForceLeave(node string) error {
r := a.c.newRequest("PUT", "/v1/agent/force-leave/"+node)
_, resp, err := requireOK(a.c.doRequest(r))
if err != nil {
return err
}
resp.Body.Close()
return nil
}
// EnableServiceMaintenance toggles service maintenance mode on
// for the given service ID.
func (a *Agent) EnableServiceMaintenance(serviceID, reason string) error {
r := a.c.newRequest("PUT", "/v1/agent/service/maintenance/"+serviceID)
r.params.Set("enable", "true")
r.params.Set("reason", reason)
_, resp, err := requireOK(a.c.doRequest(r))
if err != nil {
return err
}
resp.Body.Close()
return nil
}
// DisableServiceMaintenance toggles service maintenance mode off
// for the given service ID.
func (a *Agent) DisableServiceMaintenance(serviceID string) error {
r := a.c.newRequest("PUT", "/v1/agent/service/maintenance/"+serviceID)
r.params.Set("enable", "false")
_, resp, err := requireOK(a.c.doRequest(r))
if err != nil {
return err
}
resp.Body.Close()
return nil
}
// EnableNodeMaintenance toggles node maintenance mode on for the
// agent we are connected to.
func (a *Agent) EnableNodeMaintenance(reason string) error {
r := a.c.newRequest("PUT", "/v1/agent/maintenance")
r.params.Set("enable", "true")
r.params.Set("reason", reason)
_, resp, err := requireOK(a.c.doRequest(r))
if err != nil {
return err
}
resp.Body.Close()
return nil
}
// DisableNodeMaintenance toggles node maintenance mode off for the
// agent we are connected to.
func (a *Agent) DisableNodeMaintenance() error {
r := a.c.newRequest("PUT", "/v1/agent/maintenance")
r.params.Set("enable", "false")
_, resp, err := requireOK(a.c.doRequest(r))
if err != nil {
return err
}
resp.Body.Close()
return nil
}

View file

@ -0,0 +1,442 @@
package api
import (
"bytes"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"log"
"net"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"time"
)
// QueryOptions are used to parameterize a query
type QueryOptions struct {
// Providing a datacenter overwrites the DC provided
// by the Config
Datacenter string
// AllowStale allows any Consul server (non-leader) to service
// a read. This allows for lower latency and higher throughput
AllowStale bool
// RequireConsistent forces the read to be fully consistent.
// This is more expensive but prevents ever performing a stale
// read.
RequireConsistent bool
// WaitIndex is used to enable a blocking query. Waits
// until the timeout or the next index is reached
WaitIndex uint64
// WaitTime is used to bound the duration of a wait.
// Defaults to that of the Config, but can be overriden.
WaitTime time.Duration
// Token is used to provide a per-request ACL token
// which overrides the agent's default token.
Token string
}
// WriteOptions are used to parameterize a write
type WriteOptions struct {
// Providing a datacenter overwrites the DC provided
// by the Config
Datacenter string
// Token is used to provide a per-request ACL token
// which overrides the agent's default token.
Token string
}
// QueryMeta is used to return meta data about a query
type QueryMeta struct {
// LastIndex. This can be used as a WaitIndex to perform
// a blocking query
LastIndex uint64
// Time of last contact from the leader for the
// server servicing the request
LastContact time.Duration
// Is there a known leader
KnownLeader bool
// How long did the request take
RequestTime time.Duration
}
// WriteMeta is used to return meta data about a write
type WriteMeta struct {
// How long did the request take
RequestTime time.Duration
}
// HttpBasicAuth is used to authenticate http client with HTTP Basic Authentication
type HttpBasicAuth struct {
// Username to use for HTTP Basic Authentication
Username string
// Password to use for HTTP Basic Authentication
Password string
}
// Config is used to configure the creation of a client
type Config struct {
// Address is the address of the Consul server
Address string
// Scheme is the URI scheme for the Consul server
Scheme string
// Datacenter to use. If not provided, the default agent datacenter is used.
Datacenter string
// HttpClient is the client to use. Default will be
// used if not provided.
HttpClient *http.Client
// HttpAuth is the auth info to use for http access.
HttpAuth *HttpBasicAuth
// WaitTime limits how long a Watch will block. If not provided,
// the agent default values will be used.
WaitTime time.Duration
// Token is used to provide a per-request ACL token
// which overrides the agent's default token.
Token string
}
// DefaultConfig returns a default configuration for the client
func DefaultConfig() *Config {
config := &Config{
Address: "127.0.0.1:8500",
Scheme: "http",
HttpClient: http.DefaultClient,
}
if addr := os.Getenv("CONSUL_HTTP_ADDR"); addr != "" {
config.Address = addr
}
if token := os.Getenv("CONSUL_HTTP_TOKEN"); token != "" {
config.Token = token
}
if auth := os.Getenv("CONSUL_HTTP_AUTH"); auth != "" {
var username, password string
if strings.Contains(auth, ":") {
split := strings.SplitN(auth, ":", 2)
username = split[0]
password = split[1]
} else {
username = auth
}
config.HttpAuth = &HttpBasicAuth{
Username: username,
Password: password,
}
}
if ssl := os.Getenv("CONSUL_HTTP_SSL"); ssl != "" {
enabled, err := strconv.ParseBool(ssl)
if err != nil {
log.Printf("[WARN] client: could not parse CONSUL_HTTP_SSL: %s", err)
}
if enabled {
config.Scheme = "https"
}
}
if verify := os.Getenv("CONSUL_HTTP_SSL_VERIFY"); verify != "" {
doVerify, err := strconv.ParseBool(verify)
if err != nil {
log.Printf("[WARN] client: could not parse CONSUL_HTTP_SSL_VERIFY: %s", err)
}
if !doVerify {
config.HttpClient.Transport = &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
}
}
}
return config
}
// Client provides a client to the Consul API
type Client struct {
config Config
}
// NewClient returns a new client
func NewClient(config *Config) (*Client, error) {
// bootstrap the config
defConfig := DefaultConfig()
if len(config.Address) == 0 {
config.Address = defConfig.Address
}
if len(config.Scheme) == 0 {
config.Scheme = defConfig.Scheme
}
if config.HttpClient == nil {
config.HttpClient = defConfig.HttpClient
}
if parts := strings.SplitN(config.Address, "unix://", 2); len(parts) == 2 {
config.HttpClient = &http.Client{
Transport: &http.Transport{
Dial: func(_, _ string) (net.Conn, error) {
return net.Dial("unix", parts[1])
},
},
}
config.Address = parts[1]
}
client := &Client{
config: *config,
}
return client, nil
}
// request is used to help build up a request
type request struct {
config *Config
method string
url *url.URL
params url.Values
body io.Reader
obj interface{}
}
// setQueryOptions is used to annotate the request with
// additional query options
func (r *request) setQueryOptions(q *QueryOptions) {
if q == nil {
return
}
if q.Datacenter != "" {
r.params.Set("dc", q.Datacenter)
}
if q.AllowStale {
r.params.Set("stale", "")
}
if q.RequireConsistent {
r.params.Set("consistent", "")
}
if q.WaitIndex != 0 {
r.params.Set("index", strconv.FormatUint(q.WaitIndex, 10))
}
if q.WaitTime != 0 {
r.params.Set("wait", durToMsec(q.WaitTime))
}
if q.Token != "" {
r.params.Set("token", q.Token)
}
}
// durToMsec converts a duration to a millisecond specified string
func durToMsec(dur time.Duration) string {
return fmt.Sprintf("%dms", dur/time.Millisecond)
}
// setWriteOptions is used to annotate the request with
// additional write options
func (r *request) setWriteOptions(q *WriteOptions) {
if q == nil {
return
}
if q.Datacenter != "" {
r.params.Set("dc", q.Datacenter)
}
if q.Token != "" {
r.params.Set("token", q.Token)
}
}
// toHTTP converts the request to an HTTP request
func (r *request) toHTTP() (*http.Request, error) {
// Encode the query parameters
r.url.RawQuery = r.params.Encode()
// Check if we should encode the body
if r.body == nil && r.obj != nil {
if b, err := encodeBody(r.obj); err != nil {
return nil, err
} else {
r.body = b
}
}
// Create the HTTP request
req, err := http.NewRequest(r.method, r.url.RequestURI(), r.body)
if err != nil {
return nil, err
}
req.URL.Host = r.url.Host
req.URL.Scheme = r.url.Scheme
req.Host = r.url.Host
// Setup auth
if r.config.HttpAuth != nil {
req.SetBasicAuth(r.config.HttpAuth.Username, r.config.HttpAuth.Password)
}
return req, nil
}
// newRequest is used to create a new request
func (c *Client) newRequest(method, path string) *request {
r := &request{
config: &c.config,
method: method,
url: &url.URL{
Scheme: c.config.Scheme,
Host: c.config.Address,
Path: path,
},
params: make(map[string][]string),
}
if c.config.Datacenter != "" {
r.params.Set("dc", c.config.Datacenter)
}
if c.config.WaitTime != 0 {
r.params.Set("wait", durToMsec(r.config.WaitTime))
}
if c.config.Token != "" {
r.params.Set("token", r.config.Token)
}
return r
}
// doRequest runs a request with our client
func (c *Client) doRequest(r *request) (time.Duration, *http.Response, error) {
req, err := r.toHTTP()
if err != nil {
return 0, nil, err
}
start := time.Now()
resp, err := c.config.HttpClient.Do(req)
diff := time.Now().Sub(start)
return diff, resp, err
}
// Query is used to do a GET request against an endpoint
// and deserialize the response into an interface using
// standard Consul conventions.
func (c *Client) query(endpoint string, out interface{}, q *QueryOptions) (*QueryMeta, error) {
r := c.newRequest("GET", endpoint)
r.setQueryOptions(q)
rtt, resp, err := requireOK(c.doRequest(r))
if err != nil {
return nil, err
}
defer resp.Body.Close()
qm := &QueryMeta{}
parseQueryMeta(resp, qm)
qm.RequestTime = rtt
if err := decodeBody(resp, out); err != nil {
return nil, err
}
return qm, nil
}
// write is used to do a PUT request against an endpoint
// and serialize/deserialized using the standard Consul conventions.
func (c *Client) write(endpoint string, in, out interface{}, q *WriteOptions) (*WriteMeta, error) {
r := c.newRequest("PUT", endpoint)
r.setWriteOptions(q)
r.obj = in
rtt, resp, err := requireOK(c.doRequest(r))
if err != nil {
return nil, err
}
defer resp.Body.Close()
wm := &WriteMeta{RequestTime: rtt}
if out != nil {
if err := decodeBody(resp, &out); err != nil {
return nil, err
}
}
return wm, nil
}
// parseQueryMeta is used to help parse query meta-data
func parseQueryMeta(resp *http.Response, q *QueryMeta) error {
header := resp.Header
// Parse the X-Consul-Index
index, err := strconv.ParseUint(header.Get("X-Consul-Index"), 10, 64)
if err != nil {
return fmt.Errorf("Failed to parse X-Consul-Index: %v", err)
}
q.LastIndex = index
// Parse the X-Consul-LastContact
last, err := strconv.ParseUint(header.Get("X-Consul-LastContact"), 10, 64)
if err != nil {
return fmt.Errorf("Failed to parse X-Consul-LastContact: %v", err)
}
q.LastContact = time.Duration(last) * time.Millisecond
// Parse the X-Consul-KnownLeader
switch header.Get("X-Consul-KnownLeader") {
case "true":
q.KnownLeader = true
default:
q.KnownLeader = false
}
return nil
}
// decodeBody is used to JSON decode a body
func decodeBody(resp *http.Response, out interface{}) error {
dec := json.NewDecoder(resp.Body)
return dec.Decode(out)
}
// encodeBody is used to encode a request body
func encodeBody(obj interface{}) (io.Reader, error) {
buf := bytes.NewBuffer(nil)
enc := json.NewEncoder(buf)
if err := enc.Encode(obj); err != nil {
return nil, err
}
return buf, nil
}
// requireOK is used to wrap doRequest and check for a 200
func requireOK(d time.Duration, resp *http.Response, e error) (time.Duration, *http.Response, error) {
if e != nil {
if resp != nil {
resp.Body.Close()
}
return d, nil, e
}
if resp.StatusCode != 200 {
var buf bytes.Buffer
io.Copy(&buf, resp.Body)
resp.Body.Close()
return d, nil, fmt.Errorf("Unexpected response code: %d (%s)", resp.StatusCode, buf.Bytes())
}
return d, resp, nil
}

Some files were not shown because too many files have changed in this diff Show more