瀏覽代碼

libnetwork: leave global logger alone in tests

Swapping out the global logger on the fly is causing tests to flake out
by logging to a test's log output after the test function has returned.
Refactor Resolver to use a dependency-injected logger and the resolver
unit tests to inject a private logger instance into the Resolver under
test.

Signed-off-by: Cory Snider <csnider@mirantis.com>
Cory Snider 2 年之前
父節點
當前提交
d4f3858a40
共有 2 個文件被更改,包括 34 次插入31 次删除
  1. 24 16
      libnetwork/resolver.go
  2. 10 15
      libnetwork/resolver_test.go

+ 24 - 16
libnetwork/resolver.go

@@ -71,6 +71,7 @@ type Resolver struct {
 	listenAddress string
 	listenAddress string
 	proxyDNS      bool
 	proxyDNS      bool
 	startCh       chan struct{}
 	startCh       chan struct{}
+	logger        *logrus.Logger
 
 
 	fwdSem      *semaphore.Weighted // Limit the number of concurrent external DNS requests in-flight
 	fwdSem      *semaphore.Weighted // Limit the number of concurrent external DNS requests in-flight
 	logInverval rate.Sometimes      // Rate-limit logging about hitting the fwdSem limit
 	logInverval rate.Sometimes      // Rate-limit logging about hitting the fwdSem limit
@@ -89,6 +90,13 @@ func NewResolver(address string, proxyDNS bool, backend DNSBackend) *Resolver {
 	}
 	}
 }
 }
 
 
+func (r *Resolver) log() *logrus.Logger {
+	if r.logger == nil {
+		return logrus.StandardLogger()
+	}
+	return r.logger
+}
+
 // SetupFunc returns the setup function that should be run in the container's
 // SetupFunc returns the setup function that should be run in the container's
 // network namespace.
 // network namespace.
 func (r *Resolver) SetupFunc(port int) func() {
 func (r *Resolver) SetupFunc(port int) func() {
@@ -140,7 +148,7 @@ func (r *Resolver) Start() error {
 	r.server = s
 	r.server = s
 	go func() {
 	go func() {
 		if err := s.ActivateAndServe(); err != nil {
 		if err := s.ActivateAndServe(); err != nil {
-			logrus.WithError(err).Error("[resolver] failed to start PacketConn DNS server")
+			r.log().WithError(err).Error("[resolver] failed to start PacketConn DNS server")
 		}
 		}
 	}()
 	}()
 
 
@@ -148,7 +156,7 @@ func (r *Resolver) Start() error {
 	r.tcpServer = tcpServer
 	r.tcpServer = tcpServer
 	go func() {
 	go func() {
 		if err := tcpServer.ActivateAndServe(); err != nil {
 		if err := tcpServer.ActivateAndServe(); err != nil {
-			logrus.WithError(err).Error("[resolver] failed to start TCP DNS server")
+			r.log().WithError(err).Error("[resolver] failed to start TCP DNS server")
 		}
 		}
 	}()
 	}()
 	return nil
 	return nil
@@ -249,7 +257,7 @@ func (r *Resolver) handleIPQuery(query *dns.Msg, ipType int) (*dns.Msg, error) {
 
 
 	if addr == nil && ipv6Miss {
 	if addr == nil && ipv6Miss {
 		// Send a reply without any Answer sections
 		// Send a reply without any Answer sections
-		logrus.Debugf("[resolver] lookup name %s present without IPv6 address", name)
+		r.log().Debugf("[resolver] lookup name %s present without IPv6 address", name)
 		resp := createRespMsg(query)
 		resp := createRespMsg(query)
 		return resp, nil
 		return resp, nil
 	}
 	}
@@ -257,7 +265,7 @@ func (r *Resolver) handleIPQuery(query *dns.Msg, ipType int) (*dns.Msg, error) {
 		return nil, nil
 		return nil, nil
 	}
 	}
 
 
-	logrus.Debugf("[resolver] lookup for %s: IP %v", name, addr)
+	r.log().Debugf("[resolver] lookup for %s: IP %v", name, addr)
 
 
 	resp := createRespMsg(query)
 	resp := createRespMsg(query)
 	if len(addr) > 1 {
 	if len(addr) > 1 {
@@ -298,7 +306,7 @@ func (r *Resolver) handlePTRQuery(query *dns.Msg) (*dns.Msg, error) {
 		return nil, nil
 		return nil, nil
 	}
 	}
 
 
-	logrus.Debugf("[resolver] lookup for IP %s: name %s", name, host)
+	r.log().Debugf("[resolver] lookup for IP %s: name %s", name, host)
 	fqdn := dns.Fqdn(host)
 	fqdn := dns.Fqdn(host)
 
 
 	resp := new(dns.Msg)
 	resp := new(dns.Msg)
@@ -365,17 +373,17 @@ func (r *Resolver) serveDNS(w dns.ResponseWriter, query *dns.Msg) {
 	case dns.TypeSRV:
 	case dns.TypeSRV:
 		resp, err = r.handleSRVQuery(query)
 		resp, err = r.handleSRVQuery(query)
 	default:
 	default:
-		logrus.Debugf("[resolver] query type %s is not supported by the embedded DNS and will be forwarded to external DNS", dns.TypeToString[queryType])
+		r.log().Debugf("[resolver] query type %s is not supported by the embedded DNS and will be forwarded to external DNS", dns.TypeToString[queryType])
 	}
 	}
 
 
 	reply := func(msg *dns.Msg) {
 	reply := func(msg *dns.Msg) {
 		if err = w.WriteMsg(msg); err != nil {
 		if err = w.WriteMsg(msg); err != nil {
-			logrus.WithError(err).Errorf("[resolver] failed to write response")
+			r.log().WithError(err).Errorf("[resolver] failed to write response")
 		}
 		}
 	}
 	}
 
 
 	if err != nil {
 	if err != nil {
-		logrus.WithError(err).Errorf("[resolver] failed to handle query: %s (%s)", queryName, dns.TypeToString[queryType])
+		r.log().WithError(err).Errorf("[resolver] failed to handle query: %s (%s)", queryName, dns.TypeToString[queryType])
 		reply(new(dns.Msg).SetRcode(query, dns.RcodeServerFailure))
 		reply(new(dns.Msg).SetRcode(query, dns.RcodeServerFailure))
 		return
 		return
 	}
 	}
@@ -460,7 +468,7 @@ func (r *Resolver) forwardExtDNS(proto string, query *dns.Msg) *dns.Msg {
 		cancel()
 		cancel()
 		if err != nil {
 		if err != nil {
 			r.logInverval.Do(func() {
 			r.logInverval.Do(func() {
-				logrus.Errorf("[resolver] more than %v concurrent queries", maxConcurrent)
+				r.log().Errorf("[resolver] more than %v concurrent queries", maxConcurrent)
 			})
 			})
 			return new(dns.Msg).SetRcode(query, dns.RcodeRefused)
 			return new(dns.Msg).SetRcode(query, dns.RcodeRefused)
 		}
 		}
@@ -476,7 +484,7 @@ func (r *Resolver) forwardExtDNS(proto string, query *dns.Msg) *dns.Msg {
 		case dns.RcodeServerFailure, dns.RcodeRefused:
 		case dns.RcodeServerFailure, dns.RcodeRefused:
 			// Server returned FAILURE: continue with the next external DNS server
 			// Server returned FAILURE: continue with the next external DNS server
 			// Server returned REFUSED: this can be a transitional status, so continue with the next external DNS server
 			// Server returned REFUSED: this can be a transitional status, so continue with the next external DNS server
-			logrus.Debugf("[resolver] external DNS %s:%s returned failure:\n%s", proto, extDNS.IPStr, resp)
+			r.log().Debugf("[resolver] external DNS %s:%s returned failure:\n%s", proto, extDNS.IPStr, resp)
 			continue
 			continue
 		}
 		}
 		answers := 0
 		answers := 0
@@ -486,17 +494,17 @@ func (r *Resolver) forwardExtDNS(proto string, query *dns.Msg) *dns.Msg {
 			case dns.TypeA:
 			case dns.TypeA:
 				answers++
 				answers++
 				ip := rr.(*dns.A).A
 				ip := rr.(*dns.A).A
-				logrus.Debugf("[resolver] received A record %q for %q from %s:%s", ip, h.Name, proto, extDNS.IPStr)
+				r.log().Debugf("[resolver] received A record %q for %q from %s:%s", ip, h.Name, proto, extDNS.IPStr)
 				r.backend.HandleQueryResp(h.Name, ip)
 				r.backend.HandleQueryResp(h.Name, ip)
 			case dns.TypeAAAA:
 			case dns.TypeAAAA:
 				answers++
 				answers++
 				ip := rr.(*dns.AAAA).AAAA
 				ip := rr.(*dns.AAAA).AAAA
-				logrus.Debugf("[resolver] received AAAA record %q for %q from %s:%s", ip, h.Name, proto, extDNS.IPStr)
+				r.log().Debugf("[resolver] received AAAA record %q for %q from %s:%s", ip, h.Name, proto, extDNS.IPStr)
 				r.backend.HandleQueryResp(h.Name, ip)
 				r.backend.HandleQueryResp(h.Name, ip)
 			}
 			}
 		}
 		}
 		if len(resp.Answer) == 0 {
 		if len(resp.Answer) == 0 {
-			logrus.Debugf("[resolver] external DNS %s:%s returned response with no answers:\n%s", proto, extDNS.IPStr, resp)
+			r.log().Debugf("[resolver] external DNS %s:%s returned response with no answers:\n%s", proto, extDNS.IPStr, resp)
 		}
 		}
 		resp.Compress = true
 		resp.Compress = true
 		return resp
 		return resp
@@ -508,12 +516,12 @@ func (r *Resolver) forwardExtDNS(proto string, query *dns.Msg) *dns.Msg {
 func (r *Resolver) exchange(proto string, extDNS extDNSEntry, query *dns.Msg) *dns.Msg {
 func (r *Resolver) exchange(proto string, extDNS extDNSEntry, query *dns.Msg) *dns.Msg {
 	extConn, err := r.dialExtDNS(proto, extDNS)
 	extConn, err := r.dialExtDNS(proto, extDNS)
 	if err != nil {
 	if err != nil {
-		logrus.WithError(err).Warn("[resolver] connect failed")
+		r.log().WithError(err).Warn("[resolver] connect failed")
 		return nil
 		return nil
 	}
 	}
 	defer extConn.Close()
 	defer extConn.Close()
 
 
-	log := logrus.WithFields(logrus.Fields{
+	log := r.log().WithFields(logrus.Fields{
 		"dns-server":  extConn.RemoteAddr().Network() + ":" + extConn.RemoteAddr().String(),
 		"dns-server":  extConn.RemoteAddr().Network() + ":" + extConn.RemoteAddr().String(),
 		"client-addr": extConn.LocalAddr().Network() + ":" + extConn.LocalAddr().String(),
 		"client-addr": extConn.LocalAddr().Network() + ":" + extConn.LocalAddr().String(),
 		"question":    query.Question[0].String(),
 		"question":    query.Question[0].String(),
@@ -534,7 +542,7 @@ func (r *Resolver) exchange(proto string, extDNS extDNSEntry, query *dns.Msg) *d
 		UDPSize: dns.MaxMsgSize,
 		UDPSize: dns.MaxMsgSize,
 	}).ExchangeWithConn(query, &dns.Conn{Conn: extConn})
 	}).ExchangeWithConn(query, &dns.Conn{Conn: extConn})
 	if err != nil {
 	if err != nil {
-		logrus.WithError(err).Errorf("[resolver] failed to query DNS server: %s, query: %s", extConn.RemoteAddr().String(), query.Question[0].String())
+		r.log().WithError(err).Errorf("[resolver] failed to query DNS server: %s, query: %s", extConn.RemoteAddr().String(), query.Question[0].String())
 		return nil
 		return nil
 	}
 	}
 
 

+ 10 - 15
libnetwork/resolver_test.go

@@ -377,14 +377,13 @@ func TestOversizedDNSReply(t *testing.T) {
 
 
 	srvAddr := srv.LocalAddr().(*net.UDPAddr)
 	srvAddr := srv.LocalAddr().(*net.UDPAddr)
 	rsv := NewResolver("", true, noopDNSBackend{})
 	rsv := NewResolver("", true, noopDNSBackend{})
+	// The resolver logs lots of valuable info at level debug. Redirect it
+	// to t.Log() so the log spew is emitted only if the test fails.
+	rsv.logger = testLogger(t)
 	rsv.SetExtServers([]extDNSEntry{
 	rsv.SetExtServers([]extDNSEntry{
 		{IPStr: srvAddr.IP.String(), port: uint16(srvAddr.Port), HostLoopback: true},
 		{IPStr: srvAddr.IP.String(), port: uint16(srvAddr.Port), HostLoopback: true},
 	})
 	})
 
 
-	// The resolver logs lots of valuable info at level debug. Redirect it
-	// to t.Log() so the log spew is emitted only if the test fails.
-	defer redirectLogrusTo(t)()
-
 	w := &tstwriter{network: srvAddr.Network()}
 	w := &tstwriter{network: srvAddr.Network()}
 	q := new(dns.Msg).SetQuestion("s3.amazonaws.com.", dns.TypeA)
 	q := new(dns.Msg).SetQuestion("s3.amazonaws.com.", dns.TypeA)
 	rsv.serveDNS(w, q)
 	rsv.serveDNS(w, q)
@@ -396,14 +395,11 @@ func TestOversizedDNSReply(t *testing.T) {
 	checkDNSRRType(t, resp.Answer[0].Header().Rrtype, dns.TypeA)
 	checkDNSRRType(t, resp.Answer[0].Header().Rrtype, dns.TypeA)
 }
 }
 
 
-func redirectLogrusTo(t *testing.T) func() {
-	oldLevel, oldOut := logrus.StandardLogger().Level, logrus.StandardLogger().Out
-	logrus.StandardLogger().SetLevel(logrus.DebugLevel)
-	logrus.SetOutput(tlogWriter{t})
-	return func() {
-		logrus.StandardLogger().SetLevel(oldLevel)
-		logrus.StandardLogger().SetOutput(oldOut)
-	}
+func testLogger(t *testing.T) *logrus.Logger {
+	logger := logrus.New()
+	logger.SetLevel(logrus.DebugLevel)
+	logger.SetOutput(tlogWriter{t})
+	return logger
 }
 }
 
 
 type tlogWriter struct{ t *testing.T }
 type tlogWriter struct{ t *testing.T }
@@ -445,9 +441,8 @@ func TestReplySERVFAIL(t *testing.T) {
 	}
 	}
 	for _, tt := range cases {
 	for _, tt := range cases {
 		t.Run(tt.name, func(t *testing.T) {
 		t.Run(tt.name, func(t *testing.T) {
-			defer redirectLogrusTo(t)
-
 			rsv := NewResolver("", tt.proxyDNS, badSRVDNSBackend{})
 			rsv := NewResolver("", tt.proxyDNS, badSRVDNSBackend{})
+			rsv.logger = testLogger(t)
 			w := &tstwriter{}
 			w := &tstwriter{}
 			rsv.serveDNS(w, tt.q)
 			rsv.serveDNS(w, tt.q)
 			resp := w.GetResponse()
 			resp := w.GetResponse()
@@ -507,7 +502,7 @@ func TestProxyNXDOMAIN(t *testing.T) {
 
 
 	// The resolver logs lots of valuable info at level debug. Redirect it
 	// The resolver logs lots of valuable info at level debug. Redirect it
 	// to t.Log() so the log spew is emitted only if the test fails.
 	// to t.Log() so the log spew is emitted only if the test fails.
-	defer redirectLogrusTo(t)()
+	rsv.logger = testLogger(t)
 
 
 	w := &tstwriter{network: srvAddr.Network()}
 	w := &tstwriter{network: srvAddr.Network()}
 	q := new(dns.Msg).SetQuestion("example.net.", dns.TypeA)
 	q := new(dns.Msg).SetQuestion("example.net.", dns.TypeA)