libnetwork: leave global logger alone in tests

Swapping out the global logger on the fly is causing tests to flake out
by logging to a test's log output after the test function has returned.
Refactor Resolver to use a dependency-injected logger and the resolver
unit tests to inject a private logger instance into the Resolver under
test.

Signed-off-by: Cory Snider <csnider@mirantis.com>
This commit is contained in:
Cory Snider 2023-05-19 11:27:15 -04:00
parent 0cc6e445d7
commit d4f3858a40
2 changed files with 34 additions and 31 deletions

View file

@ -71,6 +71,7 @@ type Resolver struct {
listenAddress string
proxyDNS bool
startCh chan struct{}
logger *logrus.Logger
fwdSem *semaphore.Weighted // Limit the number of concurrent external DNS requests in-flight
logInverval rate.Sometimes // Rate-limit logging about hitting the fwdSem limit
@ -89,6 +90,13 @@ func NewResolver(address string, proxyDNS bool, backend DNSBackend) *Resolver {
}
}
func (r *Resolver) log() *logrus.Logger {
if r.logger == nil {
return logrus.StandardLogger()
}
return r.logger
}
// SetupFunc returns the setup function that should be run in the container's
// network namespace.
func (r *Resolver) SetupFunc(port int) func() {
@ -140,7 +148,7 @@ func (r *Resolver) Start() error {
r.server = s
go func() {
if err := s.ActivateAndServe(); err != nil {
logrus.WithError(err).Error("[resolver] failed to start PacketConn DNS server")
r.log().WithError(err).Error("[resolver] failed to start PacketConn DNS server")
}
}()
@ -148,7 +156,7 @@ func (r *Resolver) Start() error {
r.tcpServer = tcpServer
go func() {
if err := tcpServer.ActivateAndServe(); err != nil {
logrus.WithError(err).Error("[resolver] failed to start TCP DNS server")
r.log().WithError(err).Error("[resolver] failed to start TCP DNS server")
}
}()
return nil
@ -249,7 +257,7 @@ func (r *Resolver) handleIPQuery(query *dns.Msg, ipType int) (*dns.Msg, error) {
if addr == nil && ipv6Miss {
// Send a reply without any Answer sections
logrus.Debugf("[resolver] lookup name %s present without IPv6 address", name)
r.log().Debugf("[resolver] lookup name %s present without IPv6 address", name)
resp := createRespMsg(query)
return resp, nil
}
@ -257,7 +265,7 @@ func (r *Resolver) handleIPQuery(query *dns.Msg, ipType int) (*dns.Msg, error) {
return nil, nil
}
logrus.Debugf("[resolver] lookup for %s: IP %v", name, addr)
r.log().Debugf("[resolver] lookup for %s: IP %v", name, addr)
resp := createRespMsg(query)
if len(addr) > 1 {
@ -298,7 +306,7 @@ func (r *Resolver) handlePTRQuery(query *dns.Msg) (*dns.Msg, error) {
return nil, nil
}
logrus.Debugf("[resolver] lookup for IP %s: name %s", name, host)
r.log().Debugf("[resolver] lookup for IP %s: name %s", name, host)
fqdn := dns.Fqdn(host)
resp := new(dns.Msg)
@ -365,17 +373,17 @@ func (r *Resolver) serveDNS(w dns.ResponseWriter, query *dns.Msg) {
case dns.TypeSRV:
resp, err = r.handleSRVQuery(query)
default:
logrus.Debugf("[resolver] query type %s is not supported by the embedded DNS and will be forwarded to external DNS", dns.TypeToString[queryType])
r.log().Debugf("[resolver] query type %s is not supported by the embedded DNS and will be forwarded to external DNS", dns.TypeToString[queryType])
}
reply := func(msg *dns.Msg) {
if err = w.WriteMsg(msg); err != nil {
logrus.WithError(err).Errorf("[resolver] failed to write response")
r.log().WithError(err).Errorf("[resolver] failed to write response")
}
}
if err != nil {
logrus.WithError(err).Errorf("[resolver] failed to handle query: %s (%s)", queryName, dns.TypeToString[queryType])
r.log().WithError(err).Errorf("[resolver] failed to handle query: %s (%s)", queryName, dns.TypeToString[queryType])
reply(new(dns.Msg).SetRcode(query, dns.RcodeServerFailure))
return
}
@ -460,7 +468,7 @@ func (r *Resolver) forwardExtDNS(proto string, query *dns.Msg) *dns.Msg {
cancel()
if err != nil {
r.logInverval.Do(func() {
logrus.Errorf("[resolver] more than %v concurrent queries", maxConcurrent)
r.log().Errorf("[resolver] more than %v concurrent queries", maxConcurrent)
})
return new(dns.Msg).SetRcode(query, dns.RcodeRefused)
}
@ -476,7 +484,7 @@ func (r *Resolver) forwardExtDNS(proto string, query *dns.Msg) *dns.Msg {
case dns.RcodeServerFailure, dns.RcodeRefused:
// Server returned FAILURE: continue with the next external DNS server
// Server returned REFUSED: this can be a transitional status, so continue with the next external DNS server
logrus.Debugf("[resolver] external DNS %s:%s returned failure:\n%s", proto, extDNS.IPStr, resp)
r.log().Debugf("[resolver] external DNS %s:%s returned failure:\n%s", proto, extDNS.IPStr, resp)
continue
}
answers := 0
@ -486,17 +494,17 @@ func (r *Resolver) forwardExtDNS(proto string, query *dns.Msg) *dns.Msg {
case dns.TypeA:
answers++
ip := rr.(*dns.A).A
logrus.Debugf("[resolver] received A record %q for %q from %s:%s", ip, h.Name, proto, extDNS.IPStr)
r.log().Debugf("[resolver] received A record %q for %q from %s:%s", ip, h.Name, proto, extDNS.IPStr)
r.backend.HandleQueryResp(h.Name, ip)
case dns.TypeAAAA:
answers++
ip := rr.(*dns.AAAA).AAAA
logrus.Debugf("[resolver] received AAAA record %q for %q from %s:%s", ip, h.Name, proto, extDNS.IPStr)
r.log().Debugf("[resolver] received AAAA record %q for %q from %s:%s", ip, h.Name, proto, extDNS.IPStr)
r.backend.HandleQueryResp(h.Name, ip)
}
}
if len(resp.Answer) == 0 {
logrus.Debugf("[resolver] external DNS %s:%s returned response with no answers:\n%s", proto, extDNS.IPStr, resp)
r.log().Debugf("[resolver] external DNS %s:%s returned response with no answers:\n%s", proto, extDNS.IPStr, resp)
}
resp.Compress = true
return resp
@ -508,12 +516,12 @@ func (r *Resolver) forwardExtDNS(proto string, query *dns.Msg) *dns.Msg {
func (r *Resolver) exchange(proto string, extDNS extDNSEntry, query *dns.Msg) *dns.Msg {
extConn, err := r.dialExtDNS(proto, extDNS)
if err != nil {
logrus.WithError(err).Warn("[resolver] connect failed")
r.log().WithError(err).Warn("[resolver] connect failed")
return nil
}
defer extConn.Close()
log := logrus.WithFields(logrus.Fields{
log := r.log().WithFields(logrus.Fields{
"dns-server": extConn.RemoteAddr().Network() + ":" + extConn.RemoteAddr().String(),
"client-addr": extConn.LocalAddr().Network() + ":" + extConn.LocalAddr().String(),
"question": query.Question[0].String(),
@ -534,7 +542,7 @@ func (r *Resolver) exchange(proto string, extDNS extDNSEntry, query *dns.Msg) *d
UDPSize: dns.MaxMsgSize,
}).ExchangeWithConn(query, &dns.Conn{Conn: extConn})
if err != nil {
logrus.WithError(err).Errorf("[resolver] failed to query DNS server: %s, query: %s", extConn.RemoteAddr().String(), query.Question[0].String())
r.log().WithError(err).Errorf("[resolver] failed to query DNS server: %s, query: %s", extConn.RemoteAddr().String(), query.Question[0].String())
return nil
}

View file

@ -377,14 +377,13 @@ func TestOversizedDNSReply(t *testing.T) {
srvAddr := srv.LocalAddr().(*net.UDPAddr)
rsv := NewResolver("", true, noopDNSBackend{})
// The resolver logs lots of valuable info at level debug. Redirect it
// to t.Log() so the log spew is emitted only if the test fails.
rsv.logger = testLogger(t)
rsv.SetExtServers([]extDNSEntry{
{IPStr: srvAddr.IP.String(), port: uint16(srvAddr.Port), HostLoopback: true},
})
// The resolver logs lots of valuable info at level debug. Redirect it
// to t.Log() so the log spew is emitted only if the test fails.
defer redirectLogrusTo(t)()
w := &tstwriter{network: srvAddr.Network()}
q := new(dns.Msg).SetQuestion("s3.amazonaws.com.", dns.TypeA)
rsv.serveDNS(w, q)
@ -396,14 +395,11 @@ func TestOversizedDNSReply(t *testing.T) {
checkDNSRRType(t, resp.Answer[0].Header().Rrtype, dns.TypeA)
}
func redirectLogrusTo(t *testing.T) func() {
oldLevel, oldOut := logrus.StandardLogger().Level, logrus.StandardLogger().Out
logrus.StandardLogger().SetLevel(logrus.DebugLevel)
logrus.SetOutput(tlogWriter{t})
return func() {
logrus.StandardLogger().SetLevel(oldLevel)
logrus.StandardLogger().SetOutput(oldOut)
}
func testLogger(t *testing.T) *logrus.Logger {
logger := logrus.New()
logger.SetLevel(logrus.DebugLevel)
logger.SetOutput(tlogWriter{t})
return logger
}
type tlogWriter struct{ t *testing.T }
@ -445,9 +441,8 @@ func TestReplySERVFAIL(t *testing.T) {
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
defer redirectLogrusTo(t)
rsv := NewResolver("", tt.proxyDNS, badSRVDNSBackend{})
rsv.logger = testLogger(t)
w := &tstwriter{}
rsv.serveDNS(w, tt.q)
resp := w.GetResponse()
@ -507,7 +502,7 @@ func TestProxyNXDOMAIN(t *testing.T) {
// The resolver logs lots of valuable info at level debug. Redirect it
// to t.Log() so the log spew is emitted only if the test fails.
defer redirectLogrusTo(t)()
rsv.logger = testLogger(t)
w := &tstwriter{network: srvAddr.Network()}
q := new(dns.Msg).SetQuestion("example.net.", dns.TypeA)