Browse Source

Add support for SRV query in embedded DNS

Signed-off-by: Santhosh Manohar <santhosh@docker.com>
Santhosh Manohar 9 years ago
parent
commit
0051e39750

+ 100 - 0
libnetwork/libnetwork_internal_test.go

@@ -318,6 +318,106 @@ func TestAuxAddresses(t *testing.T) {
 	}
 }
 
+func TestSRVServiceQuery(t *testing.T) {
+	c, err := New()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer c.Stop()
+
+	n, err := c.NewNetwork("bridge", "net1", "", nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer func() {
+		if err := n.Delete(); err != nil {
+			t.Fatal(err)
+		}
+	}()
+
+	ep, err := n.CreateEndpoint("testep")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	sb, err := c.NewSandbox("c1")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer func() {
+		if err := sb.Delete(); err != nil {
+			t.Fatal(err)
+		}
+	}()
+
+	err = ep.Join(sb)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	sr := svcInfo{
+		svcMap:     make(map[string][]net.IP),
+		svcIPv6Map: make(map[string][]net.IP),
+		ipMap:      make(map[string]string),
+		service:    make(map[string][]servicePorts),
+	}
+	// backing container for the service
+	cTarget := serviceTarget{
+		name: "task1.web.swarm",
+		ip:   net.ParseIP("192.168.10.2"),
+		port: 80,
+	}
+	// backing host for the service
+	hTarget := serviceTarget{
+		name: "node1.docker-cluster",
+		ip:   net.ParseIP("10.10.10.2"),
+		port: 45321,
+	}
+	httpPort := servicePorts{
+		portName: "_http",
+		proto:    "_tcp",
+		target:   []serviceTarget{cTarget},
+	}
+
+	extHTTPPort := servicePorts{
+		portName: "_host_http",
+		proto:    "_tcp",
+		target:   []serviceTarget{hTarget},
+	}
+	sr.service["web.swarm"] = append(sr.service["web.swarm"], httpPort)
+	sr.service["web.swarm"] = append(sr.service["web.swarm"], extHTTPPort)
+
+	c.(*controller).svcRecords[n.ID()] = sr
+
+	_, ip, err := ep.Info().Sandbox().ResolveService("_http._tcp.web.swarm")
+	if err != nil {
+		t.Fatal(err)
+	}
+	if len(ip) == 0 {
+		t.Fatal(err)
+	}
+	if ip[0].String() != "192.168.10.2" {
+		t.Fatal(err)
+	}
+
+	_, ip, err = ep.Info().Sandbox().ResolveService("_host_http._tcp.web.swarm")
+	if err != nil {
+		t.Fatal(err)
+	}
+	if len(ip) == 0 {
+		t.Fatal(err)
+	}
+	if ip[0].String() != "10.10.10.2" {
+		t.Fatal(err)
+	}
+
+	// Try resolving a service name with invalid protocol, should fail..
+	_, _, err = ep.Info().Sandbox().ResolveService("_http._icmp.web.swarm")
+	if err == nil {
+		t.Fatal(err)
+	}
+}
+
 func TestIpamReleaseOnNetDriverFailures(t *testing.T) {
 	if !testutils.IsRunningInContainer() {
 		defer testutils.SetupTestOSContext(t)()

+ 4 - 0
libnetwork/libnetwork_test.go

@@ -1224,6 +1224,10 @@ func (f *fakeSandbox) ResolveIP(ip string) string {
 	return ""
 }
 
+func (f *fakeSandbox) ResolveService(name string) ([]*net.SRV, []net.IP, error) {
+	return nil, nil, nil
+}
+
 func (f *fakeSandbox) Endpoints() []libnetwork.Endpoint {
 	return nil
 }

+ 1 - 0
libnetwork/network.go

@@ -74,6 +74,7 @@ type svcInfo struct {
 	svcMap     map[string][]net.IP
 	svcIPv6Map map[string][]net.IP
 	ipMap      map[string]string
+	service    map[string][]servicePorts
 }
 
 // IpamConf contains all the ipam related configurations for a network

+ 40 - 3
libnetwork/resolver.go

@@ -275,15 +275,48 @@ func (r *resolver) handlePTRQuery(ptr string, query *dns.Msg) (*dns.Msg, error)
 	return resp, nil
 }
 
+func (r *resolver) handleSRVQuery(svc string, query *dns.Msg) (*dns.Msg, error) {
+	srv, ip, err := r.sb.ResolveService(svc)
+
+	if err != nil {
+		return nil, err
+	}
+	if len(srv) != len(ip) {
+		return nil, fmt.Errorf("invalid reply for SRV query %s", svc)
+	}
+
+	resp := createRespMsg(query)
+
+	for i, r := range srv {
+		rr := new(dns.SRV)
+		rr.Hdr = dns.RR_Header{Name: svc, Rrtype: dns.TypePTR, Class: dns.ClassINET, Ttl: respTTL}
+		rr.Port = r.Port
+		rr.Target = r.Target
+		resp.Answer = append(resp.Answer, rr)
+
+		rr1 := new(dns.A)
+		rr1.Hdr = dns.RR_Header{Name: r.Target, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: respTTL}
+		rr1.A = ip[i]
+		resp.Extra = append(resp.Extra, rr1)
+	}
+	return resp, nil
+
+}
+
 func truncateResp(resp *dns.Msg, maxSize int, isTCP bool) {
 	if !isTCP {
 		resp.Truncated = true
 	}
 
+	srv := resp.Question[0].Qtype == dns.TypeSRV
 	// trim the Answer RRs one by one till the whole message fits
 	// within the reply size
 	for resp.Len() > maxSize {
 		resp.Answer = resp.Answer[:len(resp.Answer)-1]
+
+		if srv && len(resp.Extra) > 0 {
+			resp.Extra = resp.Extra[:len(resp.Extra)-1]
+		}
 	}
 }
 
@@ -299,12 +332,16 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
 		return
 	}
 	name := query.Question[0].Name
-	if query.Question[0].Qtype == dns.TypeA {
+
+	switch query.Question[0].Qtype {
+	case dns.TypeA:
 		resp, err = r.handleIPQuery(name, query, types.IPv4)
-	} else if query.Question[0].Qtype == dns.TypeAAAA {
+	case dns.TypeAAAA:
 		resp, err = r.handleIPQuery(name, query, types.IPv6)
-	} else if query.Question[0].Qtype == dns.TypePTR {
+	case dns.TypePTR:
 		resp, err = r.handlePTRQuery(name, query)
+	case dns.TypeSRV:
+		resp, err = r.handleSRVQuery(name, query)
 	}
 
 	if err != nil {

+ 58 - 0
libnetwork/sandbox.go

@@ -45,6 +45,9 @@ type Sandbox interface {
 	// ResolveIP returns the service name for the passed in IP. IP is in reverse dotted
 	// notation; the format used for DNS PTR records
 	ResolveIP(name string) string
+	// ResolveService returns all the backend details about the containers or hosts
+	// backing a service. Its purpose is to satisfy an SRV query
+	ResolveService(name string) ([]*net.SRV, []net.IP, error)
 	// Endpoints returns all the endpoints connected to the sandbox
 	Endpoints() []Endpoint
 }
@@ -425,6 +428,61 @@ func (sb *sandbox) execFunc(f func()) {
 	sb.osSbox.InvokeFunc(f)
 }
 
+func (sb *sandbox) ResolveService(name string) ([]*net.SRV, []net.IP, error) {
+	srv := []*net.SRV{}
+	ip := []net.IP{}
+
+	log.Debugf("Service name To resolve: %v", name)
+
+	parts := strings.Split(name, ".")
+	if len(parts) < 3 {
+		return nil, nil, fmt.Errorf("invalid service name, %s", name)
+	}
+
+	portName := parts[0]
+	proto := parts[1]
+	if proto != "_tcp" && proto != "_udp" {
+		return nil, nil, fmt.Errorf("invalid protocol in service, %s", name)
+	}
+	svcName := strings.Join(parts[2:], ".")
+
+	for _, ep := range sb.getConnectedEndpoints() {
+		n := ep.getNetwork()
+
+		sr, ok := n.getController().svcRecords[n.ID()]
+		if !ok {
+			continue
+		}
+
+		svcs, ok := sr.service[svcName]
+		if !ok {
+			continue
+		}
+
+		for _, svc := range svcs {
+			if svc.portName != portName {
+				continue
+			}
+			if svc.proto != proto {
+				continue
+			}
+			for _, t := range svc.target {
+				srv = append(srv,
+					&net.SRV{
+						Target: t.name,
+						Port:   t.port,
+					})
+
+				ip = append(ip, t.ip)
+			}
+		}
+		if len(srv) > 0 {
+			break
+		}
+	}
+	return srv, ip, nil
+}
+
 func (sb *sandbox) ResolveName(name string, ipType int) ([]net.IP, bool) {
 	// Embedded server owns the docker network domain. Resolution should work
 	// for both container_name and container_name.network_name

+ 13 - 0
libnetwork/service.go

@@ -2,6 +2,19 @@ package libnetwork
 
 import "net"
 
+// backing container or host's info
+type serviceTarget struct {
+	name string
+	ip   net.IP
+	port uint16
+}
+
+type servicePorts struct {
+	portName string
+	proto    string
+	target   []serviceTarget
+}
+
 type service struct {
 	name     string
 	id       string