Pārlūkot izejas kodu

Merge pull request #46262 from thaJeztah/libnetwork_resolv_cleanups

libnetwork: resolve: assorted cleanups
Sebastiaan van Stijn 1 gadu atpakaļ
vecāks
revīzija
389b21a341
1 mainītis faili ar 41 papildinājumiem un 46 dzēšanām
  1. 41 46
      libnetwork/resolver.go

+ 41 - 46
libnetwork/resolver.go

@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"fmt"
 	"math/rand"
 	"math/rand"
 	"net"
 	"net"
+	"strconv"
 	"strings"
 	"strings"
 	"sync"
 	"sync"
 	"time"
 	"time"
@@ -199,10 +200,6 @@ func (r *Resolver) ResolverOptions() []string {
 	return []string{"ndots:0"}
 	return []string{"ndots:0"}
 }
 }
 
 
-func setCommonFlags(msg *dns.Msg) {
-	msg.RecursionAvailable = true
-}
-
 //nolint:gosec // The RNG is not used in a security-sensitive context.
 //nolint:gosec // The RNG is not used in a security-sensitive context.
 var (
 var (
 	shuffleRNG   = rand.New(rand.NewSource(time.Now().Unix()))
 	shuffleRNG   = rand.New(rand.NewSource(time.Now().Unix()))
@@ -220,9 +217,9 @@ func shuffleAddr(addr []net.IP) []net.IP {
 }
 }
 
 
 func createRespMsg(query *dns.Msg) *dns.Msg {
 func createRespMsg(query *dns.Msg) *dns.Msg {
-	resp := new(dns.Msg)
+	resp := &dns.Msg{}
 	resp.SetReply(query)
 	resp.SetReply(query)
-	setCommonFlags(resp)
+	resp.RecursionAvailable = true
 
 
 	return resp
 	return resp
 }
 }
@@ -270,17 +267,17 @@ func (r *Resolver) handleIPQuery(query *dns.Msg, ipType int) (*dns.Msg, error) {
 	}
 	}
 	if ipType == types.IPv4 {
 	if ipType == types.IPv4 {
 		for _, ip := range addr {
 		for _, ip := range addr {
-			rr := new(dns.A)
-			rr.Hdr = dns.RR_Header{Name: name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: respTTL}
-			rr.A = ip
-			resp.Answer = append(resp.Answer, rr)
+			resp.Answer = append(resp.Answer, &dns.A{
+				Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: respTTL},
+				A:   ip,
+			})
 		}
 		}
 	} else {
 	} else {
 		for _, ip := range addr {
 		for _, ip := range addr {
-			rr := new(dns.AAAA)
-			rr.Hdr = dns.RR_Header{Name: name, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: respTTL}
-			rr.AAAA = ip
-			resp.Answer = append(resp.Answer, rr)
+			resp.Answer = append(resp.Answer, &dns.AAAA{
+				Hdr:  dns.RR_Header{Name: name, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: respTTL},
+				AAAA: ip,
+			})
 		}
 		}
 	}
 	}
 	return resp, nil
 	return resp, nil
@@ -306,14 +303,11 @@ func (r *Resolver) handlePTRQuery(query *dns.Msg) (*dns.Msg, error) {
 	r.log().Debugf("[resolver] lookup for IP %s: name %s", name, host)
 	r.log().Debugf("[resolver] lookup for IP %s: name %s", name, host)
 	fqdn := dns.Fqdn(host)
 	fqdn := dns.Fqdn(host)
 
 
-	resp := new(dns.Msg)
-	resp.SetReply(query)
-	setCommonFlags(resp)
-
-	rr := new(dns.PTR)
-	rr.Hdr = dns.RR_Header{Name: ptr, Rrtype: dns.TypePTR, Class: dns.ClassINET, Ttl: respTTL}
-	rr.Ptr = fqdn
-	resp.Answer = append(resp.Answer, rr)
+	resp := createRespMsg(query)
+	resp.Answer = append(resp.Answer, &dns.PTR{
+		Hdr: dns.RR_Header{Name: ptr, Rrtype: dns.TypePTR, Class: dns.ClassINET, Ttl: respTTL},
+		Ptr: fqdn,
+	})
 	return resp, nil
 	return resp, nil
 }
 }
 
 
@@ -331,16 +325,15 @@ func (r *Resolver) handleSRVQuery(query *dns.Msg) (*dns.Msg, error) {
 	resp := createRespMsg(query)
 	resp := createRespMsg(query)
 
 
 	for i, r := range srv {
 	for i, r := range srv {
-		rr := new(dns.SRV)
-		rr.Hdr = dns.RR_Header{Name: svc, Rrtype: dns.TypePTR, Class: dns.ClassINET, Ttl: respTTL}
-		rr.Port = r.Port
-		rr.Target = r.Target
-		resp.Answer = append(resp.Answer, rr)
-
-		rr1 := new(dns.A)
-		rr1.Hdr = dns.RR_Header{Name: r.Target, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: respTTL}
-		rr1.A = ip[i]
-		resp.Extra = append(resp.Extra, rr1)
+		resp.Answer = append(resp.Answer, &dns.SRV{
+			Hdr:    dns.RR_Header{Name: svc, Rrtype: dns.TypePTR, Class: dns.ClassINET, Ttl: respTTL},
+			Port:   r.Port,
+			Target: r.Target,
+		})
+		resp.Extra = append(resp.Extra, &dns.A{
+			Hdr: dns.RR_Header{Name: r.Target, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: respTTL},
+			A:   ip[i],
+		})
 	}
 	}
 	return resp, nil
 	return resp, nil
 }
 }
@@ -425,26 +418,28 @@ func (r *Resolver) serveDNS(w dns.ResponseWriter, query *dns.Msg) {
 	reply(resp)
 	reply(resp)
 }
 }
 
 
+const defaultPort = "53"
+
 func (r *Resolver) dialExtDNS(proto string, server extDNSEntry) (net.Conn, error) {
 func (r *Resolver) dialExtDNS(proto string, server extDNSEntry) (net.Conn, error) {
+	port := defaultPort
+	if server.port != 0 {
+		port = strconv.FormatUint(uint64(server.port), 10)
+	}
+	addr := net.JoinHostPort(server.IPStr, port)
+
+	if server.HostLoopback {
+		return net.DialTimeout(proto, addr, extIOTimeout)
+	}
+
 	var (
 	var (
 		extConn net.Conn
 		extConn net.Conn
 		dialErr error
 		dialErr error
 	)
 	)
-	extConnect := func() {
-		if server.port == 0 {
-			server.port = 53
-		}
-		addr := fmt.Sprintf("%s:%d", server.IPStr, server.port)
+	err := r.backend.ExecFunc(func() {
 		extConn, dialErr = net.DialTimeout(proto, addr, extIOTimeout)
 		extConn, dialErr = net.DialTimeout(proto, addr, extIOTimeout)
-	}
-
-	if server.HostLoopback {
-		extConnect()
-	} else {
-		execErr := r.backend.ExecFunc(extConnect)
-		if execErr != nil {
-			return nil, execErr
-		}
+	})
+	if err != nil {
+		return nil, err
 	}
 	}
 	if dialErr != nil {
 	if dialErr != nil {
 		return nil, dialErr
 		return nil, dialErr