Browse Source

Merge pull request #44664 from corhere/embedded-resolver-fixes

libnetwork: improve embedded DNS resolver
Bjorn Neergaard 2 years ago
parent
commit
855c684708

+ 1 - 1
libnetwork/network.go

@@ -223,7 +223,7 @@ type network struct {
 	persist          bool
 	persist          bool
 	drvOnce          *sync.Once
 	drvOnce          *sync.Once
 	resolverOnce     sync.Once //nolint:nolintlint,unused // only used on windows
 	resolverOnce     sync.Once //nolint:nolintlint,unused // only used on windows
-	resolver         []Resolver
+	resolver         []*Resolver
 	internal         bool
 	internal         bool
 	attachable       bool
 	attachable       bool
 	inDelete         bool
 	inDelete         bool

+ 218 - 246
libnetwork/resolver.go

@@ -1,6 +1,7 @@
 package libnetwork
 package libnetwork
 
 
 import (
 import (
+	"context"
 	"fmt"
 	"fmt"
 	"math/rand"
 	"math/rand"
 	"net"
 	"net"
@@ -11,29 +12,10 @@ import (
 	"github.com/docker/docker/libnetwork/types"
 	"github.com/docker/docker/libnetwork/types"
 	"github.com/miekg/dns"
 	"github.com/miekg/dns"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
+	"golang.org/x/sync/semaphore"
+	"golang.org/x/time/rate"
 )
 )
 
 
-// Resolver represents the embedded DNS server in Docker. It operates
-// by listening on container's loopback interface for DNS queries.
-type Resolver interface {
-	// Start starts the name server for the container
-	Start() error
-	// Stop stops the name server for the container. Stopped resolver
-	// can be reused after running the SetupFunc again.
-	Stop()
-	// SetupFunc provides the setup function that should be run
-	// in the container's network namespace.
-	SetupFunc(int) func()
-	// NameServer returns the IP of the DNS resolver for the
-	// containers.
-	NameServer() string
-	// SetExtServers configures the external nameservers the resolver
-	// should use to forward queries
-	SetExtServers([]extDNSEntry)
-	// ResolverOptions returns resolv.conf options that should be set
-	ResolverOptions() []string
-}
-
 // DNSBackend represents a backend DNS resolver used for DNS name
 // DNSBackend represents a backend DNS resolver used for DNS name
 // resolution. All the queries to the resolver are forwarded to the
 // resolution. All the queries to the resolver are forwarded to the
 // backend resolver.
 // backend resolver.
@@ -60,24 +42,25 @@ type DNSBackend interface {
 }
 }
 
 
 const (
 const (
-	dnsPort         = "53"
-	ptrIPv4domain   = ".in-addr.arpa."
-	ptrIPv6domain   = ".ip6.arpa."
-	respTTL         = 600
-	maxExtDNS       = 3 // max number of external servers to try
-	extIOTimeout    = 4 * time.Second
-	defaultRespSize = 512
-	maxConcurrent   = 1024
-	logInterval     = 2 * time.Second
+	dnsPort       = "53"
+	ptrIPv4domain = ".in-addr.arpa."
+	ptrIPv6domain = ".ip6.arpa."
+	respTTL       = 600
+	maxExtDNS     = 3 // max number of external servers to try
+	extIOTimeout  = 4 * time.Second
+	maxConcurrent = 1024
+	logInterval   = 2 * time.Second
 )
 )
 
 
 type extDNSEntry struct {
 type extDNSEntry struct {
 	IPStr        string
 	IPStr        string
+	port         uint16 // for testing
 	HostLoopback bool
 	HostLoopback bool
 }
 }
 
 
-// resolver implements the Resolver interface
-type resolver struct {
+// Resolver is the embedded DNS server in Docker. It operates by listening on
+// the container's loopback interface for DNS queries.
+type Resolver struct {
 	backend       DNSBackend
 	backend       DNSBackend
 	extDNSList    [maxExtDNS]extDNSEntry
 	extDNSList    [maxExtDNS]extDNSEntry
 	server        *dns.Server
 	server        *dns.Server
@@ -85,26 +68,30 @@ type resolver struct {
 	tcpServer     *dns.Server
 	tcpServer     *dns.Server
 	tcpListen     *net.TCPListener
 	tcpListen     *net.TCPListener
 	err           error
 	err           error
-	count         int32
-	tStamp        time.Time
-	queryLock     sync.Mutex
 	listenAddress string
 	listenAddress string
 	proxyDNS      bool
 	proxyDNS      bool
 	startCh       chan struct{}
 	startCh       chan struct{}
+
+	fwdSem      *semaphore.Weighted // Limit the number of concurrent external DNS requests in-flight
+	logInverval rate.Sometimes      // Rate-limit logging about hitting the fwdSem limit
 }
 }
 
 
 // NewResolver creates a new instance of the Resolver
 // NewResolver creates a new instance of the Resolver
-func NewResolver(address string, proxyDNS bool, backend DNSBackend) Resolver {
-	return &resolver{
+func NewResolver(address string, proxyDNS bool, backend DNSBackend) *Resolver {
+	return &Resolver{
 		backend:       backend,
 		backend:       backend,
 		proxyDNS:      proxyDNS,
 		proxyDNS:      proxyDNS,
 		listenAddress: address,
 		listenAddress: address,
 		err:           fmt.Errorf("setup not done yet"),
 		err:           fmt.Errorf("setup not done yet"),
 		startCh:       make(chan struct{}, 1),
 		startCh:       make(chan struct{}, 1),
+		fwdSem:        semaphore.NewWeighted(maxConcurrent),
+		logInverval:   rate.Sometimes{Interval: logInterval},
 	}
 	}
 }
 }
 
 
-func (r *resolver) SetupFunc(port int) func() {
+// SetupFunc returns the setup function that should be run in the container's
+// network namespace.
+func (r *Resolver) SetupFunc(port int) func() {
 	return func() {
 	return func() {
 		var err error
 		var err error
 
 
@@ -135,7 +122,8 @@ func (r *resolver) SetupFunc(port int) func() {
 	}
 	}
 }
 }
 
 
-func (r *resolver) Start() error {
+// Start starts the name server for the container.
+func (r *Resolver) Start() error {
 	r.startCh <- struct{}{}
 	r.startCh <- struct{}{}
 	defer func() { <-r.startCh }()
 	defer func() { <-r.startCh }()
 
 
@@ -148,7 +136,7 @@ func (r *resolver) Start() error {
 		return fmt.Errorf("setting up IP table rules failed: %v", err)
 		return fmt.Errorf("setting up IP table rules failed: %v", err)
 	}
 	}
 
 
-	s := &dns.Server{Handler: r, PacketConn: r.conn}
+	s := &dns.Server{Handler: dns.HandlerFunc(r.serveDNS), PacketConn: r.conn}
 	r.server = s
 	r.server = s
 	go func() {
 	go func() {
 		if err := s.ActivateAndServe(); err != nil {
 		if err := s.ActivateAndServe(); err != nil {
@@ -156,7 +144,7 @@ func (r *resolver) Start() error {
 		}
 		}
 	}()
 	}()
 
 
-	tcpServer := &dns.Server{Handler: r, Listener: r.tcpListen}
+	tcpServer := &dns.Server{Handler: dns.HandlerFunc(r.serveDNS), Listener: r.tcpListen}
 	r.tcpServer = tcpServer
 	r.tcpServer = tcpServer
 	go func() {
 	go func() {
 		if err := tcpServer.ActivateAndServe(); err != nil {
 		if err := tcpServer.ActivateAndServe(); err != nil {
@@ -166,7 +154,9 @@ func (r *resolver) Start() error {
 	return nil
 	return nil
 }
 }
 
 
-func (r *resolver) Stop() {
+// Stop stops the name server for the container. A stopped resolver can be
+// reused after running the SetupFunc again.
+func (r *Resolver) Stop() {
 	r.startCh <- struct{}{}
 	r.startCh <- struct{}{}
 	defer func() { <-r.startCh }()
 	defer func() { <-r.startCh }()
 
 
@@ -179,12 +169,12 @@ func (r *resolver) Stop() {
 	r.conn = nil
 	r.conn = nil
 	r.tcpServer = nil
 	r.tcpServer = nil
 	r.err = fmt.Errorf("setup not done yet")
 	r.err = fmt.Errorf("setup not done yet")
-	r.tStamp = time.Time{}
-	r.count = 0
-	r.queryLock = sync.Mutex{}
+	r.fwdSem = semaphore.NewWeighted(maxConcurrent)
 }
 }
 
 
-func (r *resolver) SetExtServers(extDNS []extDNSEntry) {
+// SetExtServers configures the external nameservers the resolver should use
+// when forwarding queries.
+func (r *Resolver) SetExtServers(extDNS []extDNSEntry) {
 	l := len(extDNS)
 	l := len(extDNS)
 	if l > maxExtDNS {
 	if l > maxExtDNS {
 		l = maxExtDNS
 		l = maxExtDNS
@@ -194,11 +184,13 @@ func (r *resolver) SetExtServers(extDNS []extDNSEntry) {
 	}
 	}
 }
 }
 
 
-func (r *resolver) NameServer() string {
+// NameServer returns the IP of the DNS resolver for the containers.
+func (r *Resolver) NameServer() string {
 	return r.listenAddress
 	return r.listenAddress
 }
 }
 
 
-func (r *resolver) ResolverOptions() []string {
+// ResolverOptions returns resolv.conf options that should be set.
+func (r *Resolver) ResolverOptions() []string {
 	return []string{"ndots:0"}
 	return []string{"ndots:0"}
 }
 }
 
 
@@ -230,7 +222,7 @@ func createRespMsg(query *dns.Msg) *dns.Msg {
 	return resp
 	return resp
 }
 }
 
 
-func (r *resolver) handleMXQuery(query *dns.Msg) (*dns.Msg, error) {
+func (r *Resolver) handleMXQuery(query *dns.Msg) (*dns.Msg, error) {
 	name := query.Question[0].Name
 	name := query.Question[0].Name
 	addrv4, _ := r.backend.ResolveName(name, types.IPv4)
 	addrv4, _ := r.backend.ResolveName(name, types.IPv4)
 	addrv6, _ := r.backend.ResolveName(name, types.IPv6)
 	addrv6, _ := r.backend.ResolveName(name, types.IPv6)
@@ -247,7 +239,7 @@ func (r *resolver) handleMXQuery(query *dns.Msg) (*dns.Msg, error) {
 	return resp, nil
 	return resp, nil
 }
 }
 
 
-func (r *resolver) handleIPQuery(query *dns.Msg, ipType int) (*dns.Msg, error) {
+func (r *Resolver) handleIPQuery(query *dns.Msg, ipType int) (*dns.Msg, error) {
 	var (
 	var (
 		addr     []net.IP
 		addr     []net.IP
 		ipv6Miss bool
 		ipv6Miss bool
@@ -289,27 +281,24 @@ func (r *resolver) handleIPQuery(query *dns.Msg, ipType int) (*dns.Msg, error) {
 	return resp, nil
 	return resp, nil
 }
 }
 
 
-func (r *resolver) handlePTRQuery(query *dns.Msg) (*dns.Msg, error) {
-	var (
-		parts []string
-		ptr   = query.Question[0].Name
-	)
-
-	if strings.HasSuffix(ptr, ptrIPv4domain) {
-		parts = strings.Split(ptr, ptrIPv4domain)
-	} else if strings.HasSuffix(ptr, ptrIPv6domain) {
-		parts = strings.Split(ptr, ptrIPv6domain)
-	} else {
-		return nil, fmt.Errorf("invalid PTR query, %v", ptr)
+func (r *Resolver) handlePTRQuery(query *dns.Msg) (*dns.Msg, error) {
+	ptr := query.Question[0].Name
+	name, after, found := strings.Cut(ptr, ptrIPv4domain)
+	if !found || after != "" {
+		name, after, found = strings.Cut(ptr, ptrIPv6domain)
+	}
+	if !found || after != "" {
+		// Not a known IPv4 or IPv6 PTR domain.
+		// Maybe the external DNS servers know what to do with the query?
+		return nil, nil
 	}
 	}
 
 
-	host := r.backend.ResolveIP(parts[0])
-
-	if len(host) == 0 {
+	host := r.backend.ResolveIP(name)
+	if host == "" {
 		return nil, nil
 		return nil, nil
 	}
 	}
 
 
-	logrus.Debugf("[resolver] lookup for IP %s: name %s", parts[0], host)
+	logrus.Debugf("[resolver] lookup for IP %s: name %s", name, host)
 	fqdn := dns.Fqdn(host)
 	fqdn := dns.Fqdn(host)
 
 
 	resp := new(dns.Msg)
 	resp := new(dns.Msg)
@@ -323,7 +312,7 @@ func (r *resolver) handlePTRQuery(query *dns.Msg) (*dns.Msg, error) {
 	return resp, nil
 	return resp, nil
 }
 }
 
 
-func (r *resolver) handleSRVQuery(query *dns.Msg) (*dns.Msg, error) {
+func (r *Resolver) handleSRVQuery(query *dns.Msg) (*dns.Msg, error) {
 	svc := query.Question[0].Name
 	svc := query.Question[0].Name
 	srv, ip := r.backend.ResolveService(svc)
 	srv, ip := r.backend.ResolveService(svc)
 
 
@@ -351,28 +340,10 @@ func (r *resolver) handleSRVQuery(query *dns.Msg) (*dns.Msg, error) {
 	return resp, nil
 	return resp, nil
 }
 }
 
 
-func truncateResp(resp *dns.Msg, maxSize int, isTCP bool) {
-	if !isTCP {
-		resp.Truncated = true
-	}
-
-	srv := resp.Question[0].Qtype == dns.TypeSRV
-	// trim the Answer RRs one by one till the whole message fits
-	// within the reply size
-	for resp.Len() > maxSize {
-		resp.Answer = resp.Answer[:len(resp.Answer)-1]
-
-		if srv && len(resp.Extra) > 0 {
-			resp.Extra = resp.Extra[:len(resp.Extra)-1]
-		}
-	}
-}
-
-func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
+func (r *Resolver) serveDNS(w dns.ResponseWriter, query *dns.Msg) {
 	var (
 	var (
-		extConn net.Conn
-		resp    *dns.Msg
-		err     error
+		resp *dns.Msg
+		err  error
 	)
 	)
 
 
 	if query == nil || len(query.Question) == 0 {
 	if query == nil || len(query.Question) == 0 {
@@ -397,199 +368,200 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
 		logrus.Debugf("[resolver] query type %s is not supported by the embedded DNS and will be forwarded to external DNS", dns.TypeToString[queryType])
 		logrus.Debugf("[resolver] query type %s is not supported by the embedded DNS and will be forwarded to external DNS", dns.TypeToString[queryType])
 	}
 	}
 
 
+	reply := func(msg *dns.Msg) {
+		if err = w.WriteMsg(msg); err != nil {
+			logrus.WithError(err).Errorf("[resolver] failed to write response")
+		}
+	}
+
 	if err != nil {
 	if err != nil {
 		logrus.WithError(err).Errorf("[resolver] failed to handle query: %s (%s)", queryName, dns.TypeToString[queryType])
 		logrus.WithError(err).Errorf("[resolver] failed to handle query: %s (%s)", queryName, dns.TypeToString[queryType])
+		reply(new(dns.Msg).SetRcode(query, dns.RcodeServerFailure))
 		return
 		return
 	}
 	}
 
 
-	if resp == nil {
-		// If the backend doesn't support proxying dns request
-		// fail the response
-		if !r.proxyDNS {
-			resp = new(dns.Msg)
-			resp.SetRcode(query, dns.RcodeServerFailure)
-			if err := w.WriteMsg(resp); err != nil {
-				logrus.WithError(err).Error("[resolver] error writing dns response")
+	if resp != nil {
+		// We are the authoritative DNS server for this request so it's
+		// on us to truncate the response message to the size limit
+		// negotiated by the client.
+		maxSize := dns.MinMsgSize
+		if w.LocalAddr().Network() == "tcp" {
+			maxSize = dns.MaxMsgSize
+		} else {
+			if optRR := query.IsEdns0(); optRR != nil {
+				if udpsize := int(optRR.UDPSize()); udpsize > maxSize {
+					maxSize = udpsize
+				}
 			}
 			}
-			return
 		}
 		}
+		resp.Truncate(maxSize)
+		reply(resp)
+		return
+	}
 
 
+	if r.proxyDNS {
 		// If the user sets ndots > 0 explicitly and the query is
 		// If the user sets ndots > 0 explicitly and the query is
 		// in the root domain don't forward it out. We will return
 		// in the root domain don't forward it out. We will return
 		// failure and let the client retry with the search domain
 		// failure and let the client retry with the search domain
-		// attached
-		switch queryType {
-		case dns.TypeA, dns.TypeAAAA:
-			if r.backend.NdotsSet() && !strings.Contains(strings.TrimSuffix(queryName, "."), ".") {
-				resp = createRespMsg(query)
-			}
+		// attached.
+		if (queryType == dns.TypeA || queryType == dns.TypeAAAA) && r.backend.NdotsSet() &&
+			!strings.Contains(strings.TrimSuffix(queryName, "."), ".") {
+			resp = createRespMsg(query)
+		} else {
+			resp = r.forwardExtDNS(w.LocalAddr().Network(), query)
 		}
 		}
 	}
 	}
 
 
-	proto := w.LocalAddr().Network()
-	maxSize := 0
-	if proto == "tcp" {
-		maxSize = dns.MaxMsgSize - 1
-	} else if proto == "udp" {
-		optRR := query.IsEdns0()
-		if optRR != nil {
-			maxSize = int(optRR.UDPSize())
-		}
-		if maxSize < defaultRespSize {
-			maxSize = defaultRespSize
-		}
+	if resp == nil {
+		// We were unable to get an answer from any of the upstream DNS
+		// servers or the backend doesn't support proxying DNS requests.
+		resp = new(dns.Msg).SetRcode(query, dns.RcodeServerFailure)
 	}
 	}
+	reply(resp)
+}
 
 
-	if resp != nil {
-		if resp.Len() > maxSize {
-			truncateResp(resp, maxSize, proto == "tcp")
+func (r *Resolver) dialExtDNS(proto string, server extDNSEntry) (net.Conn, error) {
+	var (
+		extConn net.Conn
+		dialErr error
+	)
+	extConnect := func() {
+		if server.port == 0 {
+			server.port = 53
 		}
 		}
-	} else {
-		for i := 0; i < maxExtDNS; i++ {
-			extDNS := &r.extDNSList[i]
-			if extDNS.IPStr == "" {
-				break
-			}
-			extConnect := func() {
-				addr := fmt.Sprintf("%s:%d", extDNS.IPStr, 53)
-				extConn, err = net.DialTimeout(proto, addr, extIOTimeout)
-			}
+		addr := fmt.Sprintf("%s:%d", server.IPStr, server.port)
+		extConn, dialErr = net.DialTimeout(proto, addr, extIOTimeout)
+	}
 
 
-			if extDNS.HostLoopback {
-				extConnect()
-			} else {
-				execErr := r.backend.ExecFunc(extConnect)
-				if execErr != nil {
-					logrus.Warn(execErr)
-					continue
-				}
-			}
-			if err != nil {
-				logrus.WithField("retries", i).Warnf("[resolver] connect failed: %s", err)
-				continue
-			}
-			logrus.Debugf("[resolver] query %s (%s) from %s, forwarding to %s:%s", queryName, dns.TypeToString[queryType],
-				extConn.LocalAddr().String(), proto, extDNS.IPStr)
+	if server.HostLoopback {
+		extConnect()
+	} else {
+		execErr := r.backend.ExecFunc(extConnect)
+		if execErr != nil {
+			return nil, execErr
+		}
+	}
+	if dialErr != nil {
+		return nil, dialErr
+	}
 
 
-			// Timeout has to be set for every IO operation.
-			if err := extConn.SetDeadline(time.Now().Add(extIOTimeout)); err != nil {
-				logrus.WithError(err).Error("[resolver] error setting conn deadline")
-			}
-			co := &dns.Conn{
-				Conn:    extConn,
-				UDPSize: uint16(maxSize),
-			}
-			defer co.Close()
-
-			// limits the number of outstanding concurrent queries.
-			if !r.forwardQueryStart() {
-				old := r.tStamp
-				r.tStamp = time.Now()
-				if r.tStamp.Sub(old) > logInterval {
-					logrus.Errorf("[resolver] more than %v concurrent queries from %s", maxConcurrent, extConn.LocalAddr().String())
-				}
-				continue
-			}
+	return extConn, nil
+}
 
 
-			err = co.WriteMsg(query)
-			if err != nil {
-				r.forwardQueryEnd()
-				logrus.Debugf("[resolver] send to DNS server failed, %s", err)
-				continue
-			}
+func (r *Resolver) forwardExtDNS(proto string, query *dns.Msg) *dns.Msg {
+	queryName, queryType := query.Question[0].Name, query.Question[0].Qtype
+	for _, extDNS := range r.extDNSList {
+		if extDNS.IPStr == "" {
+			break
+		}
 
 
-			resp, err = co.ReadMsg()
-			// Truncated DNS replies should be sent to the client so that the
-			// client can retry over TCP
-			if err != nil && (resp == nil || !resp.Truncated) {
-				r.forwardQueryEnd()
-				logrus.WithError(err).Warnf("[resolver] failed to read from DNS server: %s, query: %s", extConn.RemoteAddr().String(), query.Question[0].String())
-				continue
-			}
-			r.forwardQueryEnd()
+		// limits the number of outstanding concurrent queries.
+		ctx, cancel := context.WithTimeout(context.Background(), extIOTimeout)
+		err := r.fwdSem.Acquire(ctx, 1)
+		cancel()
+		if err != nil {
+			r.logInverval.Do(func() {
+				logrus.Errorf("[resolver] more than %v concurrent queries", maxConcurrent)
+			})
+			return new(dns.Msg).SetRcode(query, dns.RcodeRefused)
+		}
+		resp := func() *dns.Msg {
+			defer r.fwdSem.Release(1)
+			return r.exchange(proto, extDNS, query)
+		}()
+		if resp == nil {
+			continue
+		}
 
 
-			if resp == nil {
-				logrus.Debugf("[resolver] external DNS %s:%s returned empty response for %q", proto, extDNS.IPStr, queryName)
+		switch resp.Rcode {
+		case dns.RcodeServerFailure, dns.RcodeRefused:
+			// Server returned FAILURE: continue with the next external DNS server
+			// Server returned REFUSED: this can be a transitional status, so continue with the next external DNS server
+			logrus.Debugf("[resolver] external DNS %s:%s responded with %s for %q", proto, extDNS.IPStr, statusString(resp.Rcode), queryName)
+			continue
+		case dns.RcodeNameError:
+			// Server returned NXDOMAIN. Stop resolution if it's an authoritative answer (see RFC 8020: https://tools.ietf.org/html/rfc8020#section-2)
+			logrus.Debugf("[resolver] external DNS %s:%s responded with %s for %q", proto, extDNS.IPStr, statusString(resp.Rcode), queryName)
+			if resp.Authoritative {
 				break
 				break
 			}
 			}
-			switch resp.Rcode {
-			case dns.RcodeServerFailure, dns.RcodeRefused:
-				// Server returned FAILURE: continue with the next external DNS server
-				// Server returned REFUSED: this can be a transitional status, so continue with the next external DNS server
-				logrus.Debugf("[resolver] external DNS %s:%s responded with %s for %q", proto, extDNS.IPStr, statusString(resp.Rcode), queryName)
-				continue
-			case dns.RcodeNameError:
-				// Server returned NXDOMAIN. Stop resolution if it's an authoritative answer (see RFC 8020: https://tools.ietf.org/html/rfc8020#section-2)
-				logrus.Debugf("[resolver] external DNS %s:%s responded with %s for %q", proto, extDNS.IPStr, statusString(resp.Rcode), queryName)
-				if resp.Authoritative {
-					break
-				}
-				continue
-			case dns.RcodeSuccess:
-				// All is well
-			default:
-				// Server gave some error. Log the error, and continue with the next external DNS server
-				logrus.Debugf("[resolver] external DNS %s:%s responded with %s (code %d) for %q", proto, extDNS.IPStr, statusString(resp.Rcode), resp.Rcode, queryName)
-				continue
-			}
-			answers := 0
-			for _, rr := range resp.Answer {
-				h := rr.Header()
-				switch h.Rrtype {
-				case dns.TypeA:
-					answers++
-					ip := rr.(*dns.A).A
-					logrus.Debugf("[resolver] received A record %q for %q from %s:%s", ip, h.Name, proto, extDNS.IPStr)
-					r.backend.HandleQueryResp(h.Name, ip)
-				case dns.TypeAAAA:
-					answers++
-					ip := rr.(*dns.AAAA).AAAA
-					logrus.Debugf("[resolver] received AAAA record %q for %q from %s:%s", ip, h.Name, proto, extDNS.IPStr)
-					r.backend.HandleQueryResp(h.Name, ip)
-				}
-			}
-			if resp.Answer == nil || answers == 0 {
-				logrus.Debugf("[resolver] external DNS %s:%s did not return any %s records for %q", proto, extDNS.IPStr, dns.TypeToString[queryType], queryName)
+			continue
+		case dns.RcodeSuccess:
+			// All is well
+		default:
+			// Server gave some error. Log the error, and continue with the next external DNS server
+			logrus.Debugf("[resolver] external DNS %s:%s responded with %s (code %d) for %q", proto, extDNS.IPStr, statusString(resp.Rcode), resp.Rcode, queryName)
+			continue
+		}
+		answers := 0
+		for _, rr := range resp.Answer {
+			h := rr.Header()
+			switch h.Rrtype {
+			case dns.TypeA:
+				answers++
+				ip := rr.(*dns.A).A
+				logrus.Debugf("[resolver] received A record %q for %q from %s:%s", ip, h.Name, proto, extDNS.IPStr)
+				r.backend.HandleQueryResp(h.Name, ip)
+			case dns.TypeAAAA:
+				answers++
+				ip := rr.(*dns.AAAA).AAAA
+				logrus.Debugf("[resolver] received AAAA record %q for %q from %s:%s", ip, h.Name, proto, extDNS.IPStr)
+				r.backend.HandleQueryResp(h.Name, ip)
 			}
 			}
-			resp.Compress = true
-			break
 		}
 		}
-		if resp == nil {
-			return
+		if resp.Answer == nil || answers == 0 {
+			logrus.Debugf("[resolver] external DNS %s:%s did not return any %s records for %q", proto, extDNS.IPStr, dns.TypeToString[queryType], queryName)
 		}
 		}
+		resp.Compress = true
+		return resp
 	}
 	}
 
 
-	if err = w.WriteMsg(resp); err != nil {
-		logrus.WithError(err).Errorf("[resolver] failed to write response")
-	}
+	return nil
 }
 }
 
 
-func statusString(responseCode int) string {
-	if s, ok := dns.RcodeToString[responseCode]; ok {
-		return s
+func (r *Resolver) exchange(proto string, extDNS extDNSEntry, query *dns.Msg) *dns.Msg {
+	extConn, err := r.dialExtDNS(proto, extDNS)
+	if err != nil {
+		logrus.WithError(err).Warn("[resolver] connect failed")
+		return nil
 	}
 	}
-	return "UNKNOWN"
-}
-
-func (r *resolver) forwardQueryStart() bool {
-	r.queryLock.Lock()
-	defer r.queryLock.Unlock()
-
-	if r.count == maxConcurrent {
-		return false
+	defer extConn.Close()
+
+	log := logrus.WithFields(logrus.Fields{
+		"dns-server":  extConn.RemoteAddr().Network() + ":" + extConn.RemoteAddr().String(),
+		"client-addr": extConn.LocalAddr().Network() + ":" + extConn.LocalAddr().String(),
+		"question":    query.Question[0].String(),
+	})
+	log.Debug("[resolver] forwarding query")
+
+	resp, _, err := (&dns.Client{
+		Timeout: extIOTimeout,
+		// Following the robustness principle, make a best-effort
+		// attempt to receive oversized response messages without
+		// truncating them on our end to forward verbatim to the client.
+		// Some DNS servers (e.g. Mikrotik RouterOS) don't support
+		// EDNS(0) and may send replies over UDP longer than 512 bytes
+		// regardless of what size limit, if any, was advertised in the
+		// query message. Note that ExchangeWithConn will override this
+		// value if it detects an EDNS OPT record in query so only
+		// oversized replies to non-EDNS queries will benefit.
+		UDPSize: dns.MaxMsgSize,
+	}).ExchangeWithConn(query, &dns.Conn{Conn: extConn})
+	if err != nil {
+		logrus.WithError(err).Errorf("[resolver] failed to query DNS server: %s, query: %s", extConn.RemoteAddr().String(), query.Question[0].String())
+		return nil
 	}
 	}
-	r.count++
 
 
-	return true
+	if resp == nil {
+		// Should be impossible, so make noise if it happens anyway.
+		log.Error("[resolver] external DNS returned empty response")
+	}
+	return resp
 }
 }
 
 
-func (r *resolver) forwardQueryEnd() {
-	r.queryLock.Lock()
-	defer r.queryLock.Unlock()
-
-	if r.count == 0 {
-		logrus.Error("[resolver] invalid concurrent query count")
-	} else {
-		r.count--
+func statusString(responseCode int) string {
+	if s, ok := dns.RcodeToString[responseCode]; ok {
+		return s
 	}
 	}
+	return "UNKNOWN"
 }
 }

+ 179 - 7
libnetwork/resolver_test.go

@@ -1,6 +1,8 @@
 package libnetwork
 package libnetwork
 
 
 import (
 import (
+	"encoding/hex"
+	"errors"
 	"net"
 	"net"
 	"runtime"
 	"runtime"
 	"syscall"
 	"syscall"
@@ -10,6 +12,7 @@ import (
 	"github.com/docker/docker/libnetwork/testutils"
 	"github.com/docker/docker/libnetwork/testutils"
 	"github.com/miekg/dns"
 	"github.com/miekg/dns"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
+	"gotest.tools/v3/assert"
 	"gotest.tools/v3/skip"
 	"gotest.tools/v3/skip"
 )
 )
 
 
@@ -23,7 +26,8 @@ func (a *tstaddr) String() string { return "127.0.0.1" }
 
 
 // a simple writer that implements dns.ResponseWriter for unit testing purposes
 // a simple writer that implements dns.ResponseWriter for unit testing purposes
 type tstwriter struct {
 type tstwriter struct {
-	msg *dns.Msg
+	localAddr net.Addr
+	msg       *dns.Msg
 }
 }
 
 
 func (w *tstwriter) WriteMsg(m *dns.Msg) (err error) {
 func (w *tstwriter) WriteMsg(m *dns.Msg) (err error) {
@@ -33,7 +37,12 @@ func (w *tstwriter) WriteMsg(m *dns.Msg) (err error) {
 
 
 func (w *tstwriter) Write(m []byte) (int, error) { return 0, nil }
 func (w *tstwriter) Write(m []byte) (int, error) { return 0, nil }
 
 
-func (w *tstwriter) LocalAddr() net.Addr { return new(tstaddr) }
+func (w *tstwriter) LocalAddr() net.Addr {
+	if w.localAddr != nil {
+		return w.localAddr
+	}
+	return new(tstaddr)
+}
 
 
 func (w *tstwriter) RemoteAddr() net.Addr { return new(tstaddr) }
 func (w *tstwriter) RemoteAddr() net.Addr { return new(tstaddr) }
 
 
@@ -50,12 +59,14 @@ func (w *tstwriter) GetResponse() *dns.Msg { return w.msg }
 func (w *tstwriter) ClearResponse() { w.msg = nil }
 func (w *tstwriter) ClearResponse() { w.msg = nil }
 
 
 func checkNonNullResponse(t *testing.T, m *dns.Msg) {
 func checkNonNullResponse(t *testing.T, m *dns.Msg) {
+	t.Helper()
 	if m == nil {
 	if m == nil {
 		t.Fatal("Null DNS response found. Non Null response msg expected.")
 		t.Fatal("Null DNS response found. Non Null response msg expected.")
 	}
 	}
 }
 }
 
 
 func checkDNSAnswersCount(t *testing.T, m *dns.Msg, expected int) {
 func checkDNSAnswersCount(t *testing.T, m *dns.Msg, expected int) {
+	t.Helper()
 	answers := len(m.Answer)
 	answers := len(m.Answer)
 	if answers != expected {
 	if answers != expected {
 		t.Fatalf("Expected number of answers in response: %d. Found: %d", expected, answers)
 		t.Fatalf("Expected number of answers in response: %d. Found: %d", expected, answers)
@@ -63,12 +74,14 @@ func checkDNSAnswersCount(t *testing.T, m *dns.Msg, expected int) {
 }
 }
 
 
 func checkDNSResponseCode(t *testing.T, m *dns.Msg, expected int) {
 func checkDNSResponseCode(t *testing.T, m *dns.Msg, expected int) {
+	t.Helper()
 	if m.MsgHdr.Rcode != expected {
 	if m.MsgHdr.Rcode != expected {
 		t.Fatalf("Expected DNS response code: %d. Found: %d", expected, m.MsgHdr.Rcode)
 		t.Fatalf("Expected DNS response code: %d. Found: %d", expected, m.MsgHdr.Rcode)
 	}
 	}
 }
 }
 
 
 func checkDNSRRType(t *testing.T, actual, expected uint16) {
 func checkDNSRRType(t *testing.T, actual, expected uint16) {
+	t.Helper()
 	if actual != expected {
 	if actual != expected {
 		t.Fatalf("Expected DNS Rrtype: %d. Found: %d", expected, actual)
 		t.Fatalf("Expected DNS Rrtype: %d. Found: %d", expected, actual)
 	}
 	}
@@ -130,7 +143,7 @@ func TestDNSIPQuery(t *testing.T) {
 	for _, name := range names {
 	for _, name := range names {
 		q := new(dns.Msg)
 		q := new(dns.Msg)
 		q.SetQuestion(name, dns.TypeA)
 		q.SetQuestion(name, dns.TypeA)
-		r.(*resolver).ServeDNS(w, q)
+		r.serveDNS(w, q)
 		resp := w.GetResponse()
 		resp := w.GetResponse()
 		checkNonNullResponse(t, resp)
 		checkNonNullResponse(t, resp)
 		t.Log("Response: ", resp.String())
 		t.Log("Response: ", resp.String())
@@ -150,7 +163,7 @@ func TestDNSIPQuery(t *testing.T) {
 	// test MX query with name1 results in Success response with 0 answer records
 	// test MX query with name1 results in Success response with 0 answer records
 	q := new(dns.Msg)
 	q := new(dns.Msg)
 	q.SetQuestion("name1", dns.TypeMX)
 	q.SetQuestion("name1", dns.TypeMX)
-	r.(*resolver).ServeDNS(w, q)
+	r.serveDNS(w, q)
 	resp := w.GetResponse()
 	resp := w.GetResponse()
 	checkNonNullResponse(t, resp)
 	checkNonNullResponse(t, resp)
 	t.Log("Response: ", resp.String())
 	t.Log("Response: ", resp.String())
@@ -162,7 +175,7 @@ func TestDNSIPQuery(t *testing.T) {
 	// since this is a unit test env, we disable proxying DNS above which results in ServFail rather than NXDOMAIN
 	// since this is a unit test env, we disable proxying DNS above which results in ServFail rather than NXDOMAIN
 	q = new(dns.Msg)
 	q = new(dns.Msg)
 	q.SetQuestion("nonexistent", dns.TypeMX)
 	q.SetQuestion("nonexistent", dns.TypeMX)
-	r.(*resolver).ServeDNS(w, q)
+	r.serveDNS(w, q)
 	resp = w.GetResponse()
 	resp = w.GetResponse()
 	checkNonNullResponse(t, resp)
 	checkNonNullResponse(t, resp)
 	t.Log("Response: ", resp.String())
 	t.Log("Response: ", resp.String())
@@ -278,10 +291,169 @@ func TestDNSProxyServFail(t *testing.T) {
 	localDNSEntries = append(localDNSEntries, extTestDNSEntry)
 	localDNSEntries = append(localDNSEntries, extTestDNSEntry)
 
 
 	// this should generate two requests: the first will fail leading to a retry
 	// this should generate two requests: the first will fail leading to a retry
-	r.(*resolver).SetExtServers(localDNSEntries)
-	r.(*resolver).ServeDNS(w, q)
+	r.SetExtServers(localDNSEntries)
+	r.serveDNS(w, q)
 	if nRequests != 2 {
 	if nRequests != 2 {
 		t.Fatalf("Expected 2 DNS querries. Found: %d", nRequests)
 		t.Fatalf("Expected 2 DNS querries. Found: %d", nRequests)
 	}
 	}
 	t.Logf("Expected number of DNS requests generated")
 	t.Logf("Expected number of DNS requests generated")
 }
 }
+
+// Packet 24 extracted from
+// https://gist.github.com/vojtad/3bac63b8c91b1ec50e8d8b36047317fa/raw/7d75eb3d3448381bf252ae55ea5123a132c46658/host.pcap
+// (https://github.com/moby/moby/issues/44575)
+// which is a non-compliant DNS reply > 512B (w/o EDNS(0)) to the query
+//
+//	s3.amazonaws.com. IN A
+const oversizedDNSReplyMsg = "\xf5\x11\x81\x80\x00\x01\x00\x20\x00\x00\x00\x00\x02\x73\x33\x09" +
+	"\x61\x6d\x61\x7a\x6f\x6e\x61\x77\x73\x03\x63\x6f\x6d\x00\x00\x01" +
+	"\x00\x01\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd9" +
+	"\x11\x9e\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd8" +
+	"\x4c\x66\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd8" +
+	"\xda\x10\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd9" +
+	"\x01\x3e\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd9" +
+	"\x88\x68\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd9" +
+	"\x66\x9e\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd9" +
+	"\x5f\x28\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd8" +
+	"\x8e\x4e\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x00\x00\x04\x36\xe7" +
+	"\x84\xf0\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x00\x00\x04\x34\xd8" +
+	"\x92\x45\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd8" +
+	"\x8f\xa6\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x36\xe7" +
+	"\xc0\xd0\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd9" +
+	"\xfe\x28\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd8" +
+	"\xaa\x3d\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd8" +
+	"\x4e\x56\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd9" +
+	"\xea\xb0\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd8" +
+	"\x6d\xed\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd8" +
+	"\x28\x00\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x00\x00\x04\x34\xd9" +
+	"\xe9\x78\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x00\x00\x04\x34\xd9" +
+	"\x6e\x9e\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x00\x00\x04\x34\xd9" +
+	"\x45\x86\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x00\x00\x04\x34\xd8" +
+	"\x30\x38\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x00\x00\x04\x36\xe7" +
+	"\xc6\xa8\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x00\x00\x04\x03\x05" +
+	"\x01\x9d\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x05\x00\x04\x34\xd9" +
+	"\xa8\xe8\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x05\x00\x04\x34\xd9" +
+	"\x64\xa6\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x05\x00\x04\x34\xd8" +
+	"\x3c\x48\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x05\x00\x04\x34\xd8" +
+	"\x35\x20\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x05\x00\x04\x34\xd9" +
+	"\x54\xf6\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x05\x00\x04\x34\xd9" +
+	"\x5d\x36\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x05\x00\x04\x34\xd9" +
+	"\x30\x36\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x05\x00\x04\x36\xe7" +
+	"\x83\x90"
+
+// TestOversizedDNSReply is a regression test for
+// https://github.com/moby/moby/issues/44575. A fake upstream server answers
+// every UDP datagram with the canned oversizedDNSReplyMsg, and the test
+// asserts that the resolver still returns a successful response whose first
+// answer is an A record.
+func TestOversizedDNSReply(t *testing.T) {
+	upstream, err := net.ListenPacket("udp", "127.0.0.1:0")
+	assert.NilError(t, err)
+	defer upstream.Close()
+
+	// Fake upstream: reply to each datagram with the canned message until
+	// the socket is closed.
+	serve := func() {
+		pkt := make([]byte, 65536)
+		for {
+			n, peer, err := upstream.ReadFrom(pkt)
+			if errors.Is(err, net.ErrClosed) {
+				return
+			}
+			t.Logf("[<-%v]\n%s", peer, hex.Dump(pkt[:n]))
+			if n < 2 {
+				// Too short to carry a query ID; nothing to answer.
+				continue
+			}
+			reply := []byte(oversizedDNSReplyMsg)
+			copy(reply[:2], pkt[:2]) // Echo the query ID back in the response.
+			if _, err = upstream.WriteTo(reply, peer); err != nil {
+				if errors.Is(err, net.ErrClosed) {
+					return
+				}
+				t.Log(err)
+			}
+		}
+	}
+	go serve()
+
+	upstreamAddr := upstream.LocalAddr().(*net.UDPAddr)
+	extServers := []extDNSEntry{{
+		IPStr:        upstreamAddr.IP.String(),
+		port:         uint16(upstreamAddr.Port),
+		HostLoopback: true,
+	}}
+	rsv := NewResolver("", true, noopDNSBackend{})
+	rsv.SetExtServers(extServers)
+
+	// Route the resolver's debug-level logging through t.Log so that it is
+	// shown only when the test fails.
+	defer redirectLogrusTo(t)()
+
+	w := &tstwriter{localAddr: upstream.LocalAddr()}
+	query := new(dns.Msg)
+	query.SetQuestion("s3.amazonaws.com.", dns.TypeA)
+	rsv.serveDNS(w, query)
+
+	resp := w.GetResponse()
+	checkNonNullResponse(t, resp)
+	t.Log("Response: ", resp.String())
+	checkDNSResponseCode(t, resp, dns.RcodeSuccess)
+	assert.Assert(t, len(resp.Answer) >= 1)
+	checkDNSRRType(t, resp.Answer[0].Header().Rrtype, dns.TypeA)
+}
+
+// redirectLogrusTo raises the standard logrus logger to debug level and
+// sends its output to t's log. The returned function puts the previous level
+// and output back; callers are expected to invoke it when they are done.
+func redirectLogrusTo(t *testing.T) func() {
+	std := logrus.StandardLogger()
+	prevLevel := std.Level
+	prevOut := std.Out
+	std.SetLevel(logrus.DebugLevel)
+	logrus.SetOutput(tlogWriter{t})
+	return func() {
+		std.SetLevel(prevLevel)
+		std.SetOutput(prevOut)
+	}
+}
+
+// tlogWriter adapts a *testing.T so it can be used as a log output.
+type tlogWriter struct{ t *testing.T }
+
+// Write forwards p to t.Logf and reports the whole slice as written.
+func (w tlogWriter) Write(p []byte) (n int, err error) {
+	w.t.Logf("%s", p)
+	n = len(p)
+	return n, nil
+}
+
+// noopDNSBackend is a stub backend for tests. It returns no addresses for
+// any name, runs ExecFunc callbacks inline and ignores query responses.
+// Methods not overridden below come from the embedded DNSBackend, which the
+// tests leave unset.
+type noopDNSBackend struct{ DNSBackend }
+
+// ResolveName always returns no addresses and false.
+func (noopDNSBackend) ResolveName(name string, iplen int) ([]net.IP, bool) {
+	return nil, false
+}
+
+// ExecFunc calls f directly and always returns nil.
+func (noopDNSBackend) ExecFunc(f func()) error {
+	f()
+	return nil
+}
+
+// NdotsSet always returns false.
+func (noopDNSBackend) NdotsSet() bool {
+	return false
+}
+
+// HandleQueryResp does nothing.
+func (noopDNSBackend) HandleQueryResp(name string, ip net.IP) {
+}
+
+// TestReplySERVFAIL checks that serveDNS answers with SERVFAIL in three
+// situations: the backend returns malformed SRV data, proxying is disabled
+// and the backend returns no addresses for the name, and proxying is enabled
+// but no external servers are configured.
+func TestReplySERVFAIL(t *testing.T) {
+	cases := []struct {
+		name     string
+		q        *dns.Msg
+		proxyDNS bool
+	}{
+		{
+			name: "InternalError",
+			q:    new(dns.Msg).SetQuestion("_sip._tcp.example.com.", dns.TypeSRV),
+		},
+		{
+			name: "ProxyDNS=false",
+			q:    new(dns.Msg).SetQuestion("example.com.", dns.TypeA),
+		},
+		{
+			name:     "ProxyDNS=true", // No extDNS servers configured -> no answer from any upstream
+			q:        new(dns.Msg).SetQuestion("example.com.", dns.TypeA),
+			proxyDNS: true,
+		},
+	}
+	for _, tt := range cases {
+		t.Run(tt.name, func(t *testing.T) {
+			// Note the trailing (): redirectLogrusTo must run now, and only
+			// the restore function it returns is deferred. Deferring the bare
+			// call would install the redirect as the subtest exits and never
+			// undo it, leaving the global logger pointed at a finished test.
+			defer redirectLogrusTo(t)()
+
+			rsv := NewResolver("", tt.proxyDNS, badSRVDNSBackend{})
+			w := &tstwriter{}
+			rsv.serveDNS(w, tt.q)
+			resp := w.GetResponse()
+			checkNonNullResponse(t, resp)
+			t.Log("Response: ", resp.String())
+			checkDNSResponseCode(t, resp, dns.RcodeServerFailure)
+		})
+	}
+}
+
+// badSRVDNSBackend is a noopDNSBackend whose ResolveService returns
+// inconsistent data, used to exercise the resolver's internal-error path.
+type badSRVDNSBackend struct{ noopDNSBackend }
+
+// ResolveService returns three nil SRV records and no IP addresses, so the
+// two result slices have mismatched lengths.
+func (badSRVDNSBackend) ResolveService(name string) ([]*net.SRV, []net.IP) {
+	srvs := make([]*net.SRV, 3) // three nil entries
+	return srvs, nil
+}

+ 40 - 16
libnetwork/resolver_unix.go

@@ -4,20 +4,20 @@
 package libnetwork
 package libnetwork
 
 
 import (
 import (
+	"fmt"
 	"net"
 	"net"
 
 
 	"github.com/docker/docker/libnetwork/iptables"
 	"github.com/docker/docker/libnetwork/iptables"
-	"github.com/sirupsen/logrus"
 )
 )
 
 
 const (
 const (
-	// outputChain used for docker embed dns
+	// output chain used for docker embedded DNS resolver
 	outputChain = "DOCKER_OUTPUT"
 	outputChain = "DOCKER_OUTPUT"
-	//postroutingchain used for docker embed dns
-	postroutingchain = "DOCKER_POSTROUTING"
+	// postrouting chain used for docker embedded DNS resolver
+	postroutingChain = "DOCKER_POSTROUTING"
 )
 )
 
 
-func (r *resolver) setupIPTable() error {
+func (r *Resolver) setupIPTable() error {
 	if r.err != nil {
 	if r.err != nil {
 		return r.err
 		return r.err
 	}
 	}
@@ -27,36 +27,60 @@ func (r *resolver) setupIPTable() error {
 	_, tcpPort, _ := net.SplitHostPort(ltcpaddr)
 	_, tcpPort, _ := net.SplitHostPort(ltcpaddr)
 	rules := [][]string{
 	rules := [][]string{
 		{"-t", "nat", "-I", outputChain, "-d", resolverIP, "-p", "udp", "--dport", dnsPort, "-j", "DNAT", "--to-destination", laddr},
 		{"-t", "nat", "-I", outputChain, "-d", resolverIP, "-p", "udp", "--dport", dnsPort, "-j", "DNAT", "--to-destination", laddr},
-		{"-t", "nat", "-I", postroutingchain, "-s", resolverIP, "-p", "udp", "--sport", ipPort, "-j", "SNAT", "--to-source", ":" + dnsPort},
+		{"-t", "nat", "-I", postroutingChain, "-s", resolverIP, "-p", "udp", "--sport", ipPort, "-j", "SNAT", "--to-source", ":" + dnsPort},
 		{"-t", "nat", "-I", outputChain, "-d", resolverIP, "-p", "tcp", "--dport", dnsPort, "-j", "DNAT", "--to-destination", ltcpaddr},
 		{"-t", "nat", "-I", outputChain, "-d", resolverIP, "-p", "tcp", "--dport", dnsPort, "-j", "DNAT", "--to-destination", ltcpaddr},
-		{"-t", "nat", "-I", postroutingchain, "-s", resolverIP, "-p", "tcp", "--sport", tcpPort, "-j", "SNAT", "--to-source", ":" + dnsPort},
+		{"-t", "nat", "-I", postroutingChain, "-s", resolverIP, "-p", "tcp", "--sport", tcpPort, "-j", "SNAT", "--to-source", ":" + dnsPort},
 	}
 	}
 
 
-	return r.backend.ExecFunc(func() {
+	var setupErr error
+	err := r.backend.ExecFunc(func() {
 		// TODO IPv6 support
 		// TODO IPv6 support
 		iptable := iptables.GetIptable(iptables.IPv4)
 		iptable := iptables.GetIptable(iptables.IPv4)
 
 
 		// insert outputChain and postroutingchain
 		// insert outputChain and postroutingchain
 		err := iptable.RawCombinedOutputNative("-t", "nat", "-C", "OUTPUT", "-d", resolverIP, "-j", outputChain)
 		err := iptable.RawCombinedOutputNative("-t", "nat", "-C", "OUTPUT", "-d", resolverIP, "-j", outputChain)
 		if err == nil {
 		if err == nil {
-			iptable.RawCombinedOutputNative("-t", "nat", "-F", outputChain)
+			if err := iptable.RawCombinedOutputNative("-t", "nat", "-F", outputChain); err != nil {
+				setupErr = err
+				return
+			}
 		} else {
 		} else {
-			iptable.RawCombinedOutputNative("-t", "nat", "-N", outputChain)
-			iptable.RawCombinedOutputNative("-t", "nat", "-I", "OUTPUT", "-d", resolverIP, "-j", outputChain)
+			if err := iptable.RawCombinedOutputNative("-t", "nat", "-N", outputChain); err != nil {
+				setupErr = err
+				return
+			}
+			if err := iptable.RawCombinedOutputNative("-t", "nat", "-I", "OUTPUT", "-d", resolverIP, "-j", outputChain); err != nil {
+				setupErr = err
+				return
+			}
 		}
 		}
 
 
-		err = iptable.RawCombinedOutputNative("-t", "nat", "-C", "POSTROUTING", "-d", resolverIP, "-j", postroutingchain)
+		err = iptable.RawCombinedOutputNative("-t", "nat", "-C", "POSTROUTING", "-d", resolverIP, "-j", postroutingChain)
 		if err == nil {
 		if err == nil {
-			iptable.RawCombinedOutputNative("-t", "nat", "-F", postroutingchain)
+			if err := iptable.RawCombinedOutputNative("-t", "nat", "-F", postroutingChain); err != nil {
+				setupErr = err
+				return
+			}
 		} else {
 		} else {
-			iptable.RawCombinedOutputNative("-t", "nat", "-N", postroutingchain)
-			iptable.RawCombinedOutputNative("-t", "nat", "-I", "POSTROUTING", "-d", resolverIP, "-j", postroutingchain)
+			if err := iptable.RawCombinedOutputNative("-t", "nat", "-N", postroutingChain); err != nil {
+				setupErr = err
+				return
+			}
+			if err := iptable.RawCombinedOutputNative("-t", "nat", "-I", "POSTROUTING", "-d", resolverIP, "-j", postroutingChain); err != nil {
+				setupErr = err
+				return
+			}
 		}
 		}
 
 
 		for _, rule := range rules {
 		for _, rule := range rules {
 			if iptable.RawCombinedOutputNative(rule...) != nil {
 			if iptable.RawCombinedOutputNative(rule...) != nil {
-				logrus.Errorf("set up rule failed, %v", rule)
+				setupErr = fmt.Errorf("set up rule failed, %v", rule)
+				return
 			}
 			}
 		}
 		}
 	})
 	})
+	if err != nil {
+		return err
+	}
+	return setupErr
 }
 }

+ 1 - 1
libnetwork/resolver_windows.go

@@ -3,6 +3,6 @@
 
 
 package libnetwork
 package libnetwork
 
 
-func (r *resolver) setupIPTable() error {
+func (r *Resolver) setupIPTable() error {
 	return nil
 	return nil
 }
 }

+ 1 - 1
libnetwork/sandbox.go

@@ -38,7 +38,7 @@ type Sandbox struct {
 	extDNS             []extDNSEntry
 	extDNS             []extDNSEntry
 	osSbox             osl.Sandbox
 	osSbox             osl.Sandbox
 	controller         *Controller
 	controller         *Controller
-	resolver           Resolver
+	resolver           *Resolver
 	resolverOnce       sync.Once
 	resolverOnce       sync.Once
 	endpoints          []*Endpoint
 	endpoints          []*Endpoint
 	epPriority         map[string]int
 	epPriority         map[string]int

+ 1 - 1
vendor.mod

@@ -91,7 +91,7 @@ require (
 	golang.org/x/sync v0.1.0
 	golang.org/x/sync v0.1.0
 	golang.org/x/sys v0.5.0
 	golang.org/x/sys v0.5.0
 	golang.org/x/text v0.7.0
 	golang.org/x/text v0.7.0
-	golang.org/x/time v0.1.0
+	golang.org/x/time v0.3.0
 	google.golang.org/genproto v0.0.0-20220706185917-7780775163c4
 	google.golang.org/genproto v0.0.0-20220706185917-7780775163c4
 	google.golang.org/grpc v1.50.1
 	google.golang.org/grpc v1.50.1
 	gotest.tools/v3 v3.4.0
 	gotest.tools/v3 v3.4.0

+ 2 - 2
vendor.sum

@@ -1424,8 +1424,8 @@ golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxb
 golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
 golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
 golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
 golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
 golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
 golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
-golang.org/x/time v0.1.0 h1:xYY+Bajn2a7VBmTM5GikTmnK8ZuX8YgnQCqZpbBNtmA=
-golang.org/x/time v0.1.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
+golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
+golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
 golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 golang.org/x/tools v0.0.0-20181011042414-1f849cf54d09/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 golang.org/x/tools v0.0.0-20181011042414-1f849cf54d09/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

+ 8 - 12
vendor/golang.org/x/time/rate/rate.go

@@ -83,7 +83,7 @@ func (lim *Limiter) Burst() int {
 // TokensAt returns the number of tokens available at time t.
 // TokensAt returns the number of tokens available at time t.
 func (lim *Limiter) TokensAt(t time.Time) float64 {
 func (lim *Limiter) TokensAt(t time.Time) float64 {
 	lim.mu.Lock()
 	lim.mu.Lock()
-	_, _, tokens := lim.advance(t) // does not mutute lim
+	_, tokens := lim.advance(t) // does not mutate lim
 	lim.mu.Unlock()
 	lim.mu.Unlock()
 	return tokens
 	return tokens
 }
 }
@@ -183,7 +183,7 @@ func (r *Reservation) CancelAt(t time.Time) {
 		return
 		return
 	}
 	}
 	// advance time to now
 	// advance time to now
-	t, _, tokens := r.lim.advance(t)
+	t, tokens := r.lim.advance(t)
 	// calculate new number of tokens
 	// calculate new number of tokens
 	tokens += restoreTokens
 	tokens += restoreTokens
 	if burst := float64(r.lim.burst); tokens > burst {
 	if burst := float64(r.lim.burst); tokens > burst {
@@ -304,7 +304,7 @@ func (lim *Limiter) SetLimitAt(t time.Time, newLimit Limit) {
 	lim.mu.Lock()
 	lim.mu.Lock()
 	defer lim.mu.Unlock()
 	defer lim.mu.Unlock()
 
 
-	t, _, tokens := lim.advance(t)
+	t, tokens := lim.advance(t)
 
 
 	lim.last = t
 	lim.last = t
 	lim.tokens = tokens
 	lim.tokens = tokens
@@ -321,7 +321,7 @@ func (lim *Limiter) SetBurstAt(t time.Time, newBurst int) {
 	lim.mu.Lock()
 	lim.mu.Lock()
 	defer lim.mu.Unlock()
 	defer lim.mu.Unlock()
 
 
-	t, _, tokens := lim.advance(t)
+	t, tokens := lim.advance(t)
 
 
 	lim.last = t
 	lim.last = t
 	lim.tokens = tokens
 	lim.tokens = tokens
@@ -356,7 +356,7 @@ func (lim *Limiter) reserveN(t time.Time, n int, maxFutureReserve time.Duration)
 		}
 		}
 	}
 	}
 
 
-	t, last, tokens := lim.advance(t)
+	t, tokens := lim.advance(t)
 
 
 	// Calculate the remaining number of tokens resulting from the request.
 	// Calculate the remaining number of tokens resulting from the request.
 	tokens -= float64(n)
 	tokens -= float64(n)
@@ -379,15 +379,11 @@ func (lim *Limiter) reserveN(t time.Time, n int, maxFutureReserve time.Duration)
 	if ok {
 	if ok {
 		r.tokens = n
 		r.tokens = n
 		r.timeToAct = t.Add(waitDuration)
 		r.timeToAct = t.Add(waitDuration)
-	}
 
 
-	// Update state
-	if ok {
+		// Update state
 		lim.last = t
 		lim.last = t
 		lim.tokens = tokens
 		lim.tokens = tokens
 		lim.lastEvent = r.timeToAct
 		lim.lastEvent = r.timeToAct
-	} else {
-		lim.last = last
 	}
 	}
 
 
 	return r
 	return r
@@ -396,7 +392,7 @@ func (lim *Limiter) reserveN(t time.Time, n int, maxFutureReserve time.Duration)
 // advance calculates and returns an updated state for lim resulting from the passage of time.
 // advance calculates and returns an updated state for lim resulting from the passage of time.
 // lim is not changed.
 // lim is not changed.
 // advance requires that lim.mu is held.
 // advance requires that lim.mu is held.
-func (lim *Limiter) advance(t time.Time) (newT time.Time, newLast time.Time, newTokens float64) {
+func (lim *Limiter) advance(t time.Time) (newT time.Time, newTokens float64) {
 	last := lim.last
 	last := lim.last
 	if t.Before(last) {
 	if t.Before(last) {
 		last = t
 		last = t
@@ -409,7 +405,7 @@ func (lim *Limiter) advance(t time.Time) (newT time.Time, newLast time.Time, new
 	if burst := float64(lim.burst); tokens > burst {
 	if burst := float64(lim.burst); tokens > burst {
 		tokens = burst
 		tokens = burst
 	}
 	}
-	return t, last, tokens
+	return t, tokens
 }
 }
 
 
 // durationFromTokens is a unit conversion function from the number of tokens to the duration
 // durationFromTokens is a unit conversion function from the number of tokens to the duration

+ 67 - 0
vendor/golang.org/x/time/rate/sometimes.go

@@ -0,0 +1,67 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package rate
+
+import (
+	"sync"
+	"time"
+)
+
+// Sometimes will perform an action occasionally.  The First, Every, and
+// Interval fields govern the behavior of Do, which performs the action.
+// A zero Sometimes value will perform an action exactly once.
+//
+// # Example: logging with rate limiting
+//
+//	var sometimes = rate.Sometimes{First: 3, Interval: 10*time.Second}
+//	func Spammy() {
+//	        sometimes.Do(func() { log.Info("here I am!") })
+//	}
+type Sometimes struct {
+	First    int           // if non-zero, the first N calls to Do will run f.
+	Every    int           // if non-zero, every Nth call to Do will run f.
+	Interval time.Duration // if non-zero and Interval has elapsed since f's last run, Do will run f.
+
+	mu    sync.Mutex
+	count int       // number of Do calls
+	last  time.Time // last time f was run
+}
+
+// Do runs the function f as allowed by First, Every, and Interval.
+//
+// The model is a union (not intersection) of filters.  The first call to Do
+// always runs f.  Subsequent calls to Do run f if allowed by First or Every or
+// Interval.
+//
+// A non-zero First:N causes the first N Do(f) calls to run f.
+//
+// A non-zero Every:M causes every Mth Do(f) call, starting with the first, to
+// run f.
+//
+// A non-zero Interval causes Do(f) to run f if Interval has elapsed since
+// Do last ran f.
+//
+// Specifying multiple filters produces the union of these execution streams.
+// For example, specifying both First:N and Every:M causes the first N Do(f)
+// calls and every Mth Do(f) call, starting with the first, to run f.  See
+// Examples for more.
+//
+// If Do is called multiple times simultaneously, the calls will block and run
+// serially.  Therefore, Do is intended for lightweight operations.
+//
+// Because a call to Do may block until f returns, if f causes Do to be called,
+// it will deadlock.
+func (s *Sometimes) Do(f func()) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	if s.count == 0 ||
+		(s.First > 0 && s.count < s.First) ||
+		(s.Every > 0 && s.count%s.Every == 0) ||
+		(s.Interval > 0 && time.Since(s.last) >= s.Interval) {
+		f()
+		s.last = time.Now()
+	}
+	s.count++
+}

+ 1 - 1
vendor/modules.txt

@@ -1087,7 +1087,7 @@ golang.org/x/text/secure/bidirule
 golang.org/x/text/transform
 golang.org/x/text/transform
 golang.org/x/text/unicode/bidi
 golang.org/x/text/unicode/bidi
 golang.org/x/text/unicode/norm
 golang.org/x/text/unicode/norm
-# golang.org/x/time v0.1.0
+# golang.org/x/time v0.3.0
 ## explicit
 ## explicit
 golang.org/x/time/rate
 golang.org/x/time/rate
 # google.golang.org/api v0.93.0
 # google.golang.org/api v0.93.0