diff --git a/libnetwork/controller.go b/libnetwork/controller.go index ae6f493168..b110f3b32c 100644 --- a/libnetwork/controller.go +++ b/libnetwork/controller.go @@ -284,6 +284,7 @@ func (c *controller) addNetwork(n *network) error { } n.Lock() + n.svcRecords = svcMap{} n.driver = dd.driver d := n.driver n.Unlock() diff --git a/libnetwork/endpoint.go b/libnetwork/endpoint.go index a4bce82a5b..72ab9a297b 100644 --- a/libnetwork/endpoint.go +++ b/libnetwork/endpoint.go @@ -571,9 +571,39 @@ func (ep *endpoint) deleteEndpoint() error { } log.Warnf("driver error deleting endpoint %s : %v", name, err) } + + n.updateSvcRecord(ep, false) return nil } +func (ep *endpoint) addHostEntries(recs []etchosts.Record) { + ep.Lock() + container := ep.container + ep.Unlock() + + if container == nil { + return + } + + if err := etchosts.Add(container.config.hostsPath, recs); err != nil { + log.Warnf("Failed adding service host entries to the running container: %v", err) + } +} + +func (ep *endpoint) deleteHostEntries(recs []etchosts.Record) { + ep.Lock() + container := ep.container + ep.Unlock() + + if container == nil { + return + } + + if err := etchosts.Delete(container.config.hostsPath, recs); err != nil { + log.Warnf("Failed deleting service host entries to the running container: %v", err) + } +} + func (ep *endpoint) buildHostsFiles() error { var extraContent []etchosts.Record @@ -581,6 +611,7 @@ func (ep *endpoint) buildHostsFiles() error { container := ep.container joinInfo := ep.joinInfo ifaces := ep.iFaces + n := ep.network ep.Unlock() if container == nil { @@ -613,6 +644,8 @@ func (ep *endpoint) buildHostsFiles() error { etchosts.Record{Hosts: extraHost.name, IP: extraHost.IP}) } + extraContent = append(extraContent, n.getSvcRecords()...) + IP := "" if len(ifaces) != 0 && ifaces[0] != nil { IP = ifaces[0].addr.IP.String() diff --git a/libnetwork/etchosts/etchosts.go b/libnetwork/etchosts/etchosts.go index 88e6b63e70..9095b483b4 100644 --- a/libnetwork/etchosts/etchosts.go +++ b/libnetwork/etchosts/etchosts.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "io/ioutil" + "os" "regexp" ) @@ -65,6 +66,45 @@ func Build(path, IP, hostname, domainname string, extraContent []Record) error { return ioutil.WriteFile(path, content.Bytes(), 0644) } +// Add adds an arbitrary number of Records to an already existing /etc/hosts file +func Add(path string, recs []Record) error { + f, err := os.Open(path) + if err != nil { + return err + } + + content := bytes.NewBuffer(nil) + + _, err = content.ReadFrom(f) + if err != nil { + return err + } + + for _, r := range recs { + if _, err := r.WriteTo(content); err != nil { + return err + } + } + + return ioutil.WriteFile(path, content.Bytes(), 0644) +} + +// Delete deletes an arbitrary number of Records already existing in /etc/hosts file +func Delete(path string, recs []Record) error { + old, err := ioutil.ReadFile(path) + if err != nil { + return err + } + + regexpStr := fmt.Sprintf("\\S*\\t%s\\n", regexp.QuoteMeta(recs[0].Hosts)) + for _, r := range recs[1:] { + regexpStr = regexpStr + "|" + fmt.Sprintf("\\S*\\t%s\\n", regexp.QuoteMeta(r.Hosts)) + } + + var re = regexp.MustCompile(regexpStr) + return ioutil.WriteFile(path, re.ReplaceAll(old, []byte("")), 0644) +} + // Update all IP addresses where hostname matches. // path is path to host file // IP is new IP address diff --git a/libnetwork/etchosts/etchosts_test.go b/libnetwork/etchosts/etchosts_test.go index 8c8b87c016..ce17d57455 100644 --- a/libnetwork/etchosts/etchosts_test.go +++ b/libnetwork/etchosts/etchosts_test.go @@ -134,3 +134,82 @@ func TestUpdate(t *testing.T) { t.Fatalf("Expected to find '%s' got '%s'", expected, content) } } + +func TestAdd(t *testing.T) { + file, err := ioutil.TempFile("", "") + if err != nil { + t.Fatal(err) + } + defer os.Remove(file.Name()) + + err = Build(file.Name(), "", "", "", nil) + if err != nil { + t.Fatal(err) + } + + if err := Add(file.Name(), []Record{ + Record{ + Hosts: "testhostname", + IP: "2.2.2.2", + }, + }); err != nil { + t.Fatal(err) + } + + content, err := ioutil.ReadFile(file.Name()) + if err != nil { + t.Fatal(err) + } + + if expected := "2.2.2.2\ttesthostname\n"; !bytes.Contains(content, []byte(expected)) { + t.Fatalf("Expected to find '%s' got '%s'", expected, content) + } +} + +func TestDelete(t *testing.T) { + file, err := ioutil.TempFile("", "") + if err != nil { + t.Fatal(err) + } + defer os.Remove(file.Name()) + + err = Build(file.Name(), "", "", "", nil) + if err != nil { + t.Fatal(err) + } + + if err := Add(file.Name(), []Record{ + Record{ + Hosts: "testhostname1", + IP: "1.1.1.1", + }, + Record{ + Hosts: "testhostname2", + IP: "2.2.2.2", + }, + }); err != nil { + t.Fatal(err) + } + + if err := Delete(file.Name(), []Record{ + Record{ + Hosts: "testhostname1", + IP: "1.1.1.1", + }, + }); err != nil { + t.Fatal(err) + } + + content, err := ioutil.ReadFile(file.Name()) + if err != nil { + t.Fatal(err) + } + + if expected := "2.2.2.2\ttesthostname2\n"; !bytes.Contains(content, []byte(expected)) { + t.Fatalf("Expected to find '%s' got '%s'", expected, content) + } + + if expected := "1.1.1.1\ttesthostname1\n"; bytes.Contains(content, []byte(expected)) { + t.Fatalf("Did not expect to find '%s' got '%s'", expected, content) + } +} diff --git a/libnetwork/network.go b/libnetwork/network.go index 1c48832ee7..de1bbb6684 100644 --- a/libnetwork/network.go +++ b/libnetwork/network.go @@ -2,6 +2,7 @@ package libnetwork import ( "encoding/json" + "net" "sync" log "github.com/Sirupsen/logrus" @@ -9,6 +10,7 @@ import ( "github.com/docker/libnetwork/config" "github.com/docker/libnetwork/datastore" "github.com/docker/libnetwork/driverapi" + "github.com/docker/libnetwork/etchosts" "github.com/docker/libnetwork/netlabel" "github.com/docker/libnetwork/options" "github.com/docker/libnetwork/types" @@ -51,6 +53,8 @@ type Network interface { // When the function returns true, the walk will stop. type EndpointWalker func(ep Endpoint) bool +type svcMap map[string]net.IP + type network struct { ctrlr *controller name string @@ -62,6 +66,7 @@ type network struct { endpoints endpointTable generic options.Generic dbIndex uint64 + svcRecords svcMap stopWatchCh chan struct{} sync.Mutex } @@ -272,6 +277,8 @@ func (n *network) addEndpoint(ep *endpoint) error { if err != nil { return err } + + n.updateSvcRecord(ep, true) return nil } @@ -384,3 +391,62 @@ func (n *network) isGlobalScoped() (bool, error) { n.Unlock() return c.isDriverGlobalScoped(n.networkType) } + +func (n *network) updateSvcRecord(ep *endpoint, isAdd bool) { + n.Lock() + var recs []etchosts.Record + for _, iface := range ep.InterfaceList() { + if isAdd { + n.svcRecords[ep.Name()] = iface.Address().IP + n.svcRecords[ep.Name()+"."+n.name] = iface.Address().IP + } else { + delete(n.svcRecords, ep.Name()) + delete(n.svcRecords, ep.Name()+"."+n.name) + } + + recs = append(recs, etchosts.Record{ + Hosts: ep.Name(), + IP: iface.Address().IP.String(), + }) + + recs = append(recs, etchosts.Record{ + Hosts: ep.Name() + "." + n.name, + IP: iface.Address().IP.String(), + }) + } + n.Unlock() + + var epList []*endpoint + n.WalkEndpoints(func(e Endpoint) bool { + cEp := e.(*endpoint) + cEp.Lock() + if cEp.container != nil { + epList = append(epList, cEp) + } + cEp.Unlock() + return false + }) + + for _, cEp := range epList { + if isAdd { + cEp.addHostEntries(recs) + } else { + cEp.deleteHostEntries(recs) + } + } +} + +func (n *network) getSvcRecords() []etchosts.Record { + n.Lock() + defer n.Unlock() + + var recs []etchosts.Record + for h, ip := range n.svcRecords { + recs = append(recs, etchosts.Record{ + Hosts: h, + IP: ip.String(), + }) + } + + return recs +}