소스 검색

Merge pull request #45586 from corhere/fix-flaky-resolver-test

libnetwork/osl: restore the right thread's netns
Bjorn Neergaard 2 년 전
부모
커밋
ecbd126d6a
5개의 변경된 파일128개의 추가작업 그리고 54개의 파일을 삭제
  1. 17 12
      libnetwork/osl/namespace_linux.go
  2. 24 16
      libnetwork/resolver.go
  3. 33 26
      libnetwork/resolver_test.go
  4. 43 0
      libnetwork/testutils/sanity_linux.go
  5. 11 0
      libnetwork/testutils/sanity_notlinux.go

+ 17 - 12
libnetwork/osl/namespace_linux.go

@@ -600,24 +600,29 @@ func (n *networkNamespace) checkLoV6() {
 }
 
 func setIPv6(nspath, iface string, enable bool) error {
-	origNS, err := netns.Get()
-	if err != nil {
-		return fmt.Errorf("failed to get current network namespace: %w", err)
-	}
-	defer origNS.Close()
-
-	namespace, err := netns.GetFromPath(nspath)
-	if err != nil {
-		return fmt.Errorf("failed get network namespace %q: %w", nspath, err)
-	}
-	defer namespace.Close()
-
 	errCh := make(chan error, 1)
 	go func() {
 		defer close(errCh)
 
+		namespace, err := netns.GetFromPath(nspath)
+		if err != nil {
+			errCh <- fmt.Errorf("failed get network namespace %q: %w", nspath, err)
+			return
+		}
+		defer namespace.Close()
+
 		runtime.LockOSThread()
+
+		origNS, err := netns.Get()
+		if err != nil {
+			runtime.UnlockOSThread()
+			errCh <- fmt.Errorf("failed to get current network namespace: %w", err)
+			return
+		}
+		defer origNS.Close()
+
 		if err = netns.Set(namespace); err != nil {
+			runtime.UnlockOSThread()
 			errCh <- fmt.Errorf("setting into container netns %q failed: %w", nspath, err)
 			return
 		}

+ 24 - 16
libnetwork/resolver.go

@@ -71,6 +71,7 @@ type Resolver struct {
 	listenAddress string
 	proxyDNS      bool
 	startCh       chan struct{}
+	logger        *logrus.Logger
 
 	fwdSem      *semaphore.Weighted // Limit the number of concurrent external DNS requests in-flight
 	logInverval rate.Sometimes      // Rate-limit logging about hitting the fwdSem limit
@@ -89,6 +90,13 @@ func NewResolver(address string, proxyDNS bool, backend DNSBackend) *Resolver {
 	}
 }
 
+func (r *Resolver) log() *logrus.Logger {
+	if r.logger == nil {
+		return logrus.StandardLogger()
+	}
+	return r.logger
+}
+
 // SetupFunc returns the setup function that should be run in the container's
 // network namespace.
 func (r *Resolver) SetupFunc(port int) func() {
@@ -140,7 +148,7 @@ func (r *Resolver) Start() error {
 	r.server = s
 	go func() {
 		if err := s.ActivateAndServe(); err != nil {
-			logrus.WithError(err).Error("[resolver] failed to start PacketConn DNS server")
+			r.log().WithError(err).Error("[resolver] failed to start PacketConn DNS server")
 		}
 	}()
 
@@ -148,7 +156,7 @@ func (r *Resolver) Start() error {
 	r.tcpServer = tcpServer
 	go func() {
 		if err := tcpServer.ActivateAndServe(); err != nil {
-			logrus.WithError(err).Error("[resolver] failed to start TCP DNS server")
+			r.log().WithError(err).Error("[resolver] failed to start TCP DNS server")
 		}
 	}()
 	return nil
@@ -249,7 +257,7 @@ func (r *Resolver) handleIPQuery(query *dns.Msg, ipType int) (*dns.Msg, error) {
 
 	if addr == nil && ipv6Miss {
 		// Send a reply without any Answer sections
-		logrus.Debugf("[resolver] lookup name %s present without IPv6 address", name)
+		r.log().Debugf("[resolver] lookup name %s present without IPv6 address", name)
 		resp := createRespMsg(query)
 		return resp, nil
 	}
@@ -257,7 +265,7 @@ func (r *Resolver) handleIPQuery(query *dns.Msg, ipType int) (*dns.Msg, error) {
 		return nil, nil
 	}
 
-	logrus.Debugf("[resolver] lookup for %s: IP %v", name, addr)
+	r.log().Debugf("[resolver] lookup for %s: IP %v", name, addr)
 
 	resp := createRespMsg(query)
 	if len(addr) > 1 {
@@ -298,7 +306,7 @@ func (r *Resolver) handlePTRQuery(query *dns.Msg) (*dns.Msg, error) {
 		return nil, nil
 	}
 
-	logrus.Debugf("[resolver] lookup for IP %s: name %s", name, host)
+	r.log().Debugf("[resolver] lookup for IP %s: name %s", name, host)
 	fqdn := dns.Fqdn(host)
 
 	resp := new(dns.Msg)
@@ -365,17 +373,17 @@ func (r *Resolver) serveDNS(w dns.ResponseWriter, query *dns.Msg) {
 	case dns.TypeSRV:
 		resp, err = r.handleSRVQuery(query)
 	default:
-		logrus.Debugf("[resolver] query type %s is not supported by the embedded DNS and will be forwarded to external DNS", dns.TypeToString[queryType])
+		r.log().Debugf("[resolver] query type %s is not supported by the embedded DNS and will be forwarded to external DNS", dns.TypeToString[queryType])
 	}
 
 	reply := func(msg *dns.Msg) {
 		if err = w.WriteMsg(msg); err != nil {
-			logrus.WithError(err).Errorf("[resolver] failed to write response")
+			r.log().WithError(err).Errorf("[resolver] failed to write response")
 		}
 	}
 
 	if err != nil {
-		logrus.WithError(err).Errorf("[resolver] failed to handle query: %s (%s)", queryName, dns.TypeToString[queryType])
+		r.log().WithError(err).Errorf("[resolver] failed to handle query: %s (%s)", queryName, dns.TypeToString[queryType])
 		reply(new(dns.Msg).SetRcode(query, dns.RcodeServerFailure))
 		return
 	}
@@ -460,7 +468,7 @@ func (r *Resolver) forwardExtDNS(proto string, query *dns.Msg) *dns.Msg {
 		cancel()
 		if err != nil {
 			r.logInverval.Do(func() {
-				logrus.Errorf("[resolver] more than %v concurrent queries", maxConcurrent)
+				r.log().Errorf("[resolver] more than %v concurrent queries", maxConcurrent)
 			})
 			return new(dns.Msg).SetRcode(query, dns.RcodeRefused)
 		}
@@ -476,7 +484,7 @@ func (r *Resolver) forwardExtDNS(proto string, query *dns.Msg) *dns.Msg {
 		case dns.RcodeServerFailure, dns.RcodeRefused:
 			// Server returned FAILURE: continue with the next external DNS server
 			// Server returned REFUSED: this can be a transitional status, so continue with the next external DNS server
-			logrus.Debugf("[resolver] external DNS %s:%s returned failure:\n%s", proto, extDNS.IPStr, resp)
+			r.log().Debugf("[resolver] external DNS %s:%s returned failure:\n%s", proto, extDNS.IPStr, resp)
 			continue
 		}
 		answers := 0
@@ -486,17 +494,17 @@ func (r *Resolver) forwardExtDNS(proto string, query *dns.Msg) *dns.Msg {
 			case dns.TypeA:
 				answers++
 				ip := rr.(*dns.A).A
-				logrus.Debugf("[resolver] received A record %q for %q from %s:%s", ip, h.Name, proto, extDNS.IPStr)
+				r.log().Debugf("[resolver] received A record %q for %q from %s:%s", ip, h.Name, proto, extDNS.IPStr)
 				r.backend.HandleQueryResp(h.Name, ip)
 			case dns.TypeAAAA:
 				answers++
 				ip := rr.(*dns.AAAA).AAAA
-				logrus.Debugf("[resolver] received AAAA record %q for %q from %s:%s", ip, h.Name, proto, extDNS.IPStr)
+				r.log().Debugf("[resolver] received AAAA record %q for %q from %s:%s", ip, h.Name, proto, extDNS.IPStr)
 				r.backend.HandleQueryResp(h.Name, ip)
 			}
 		}
 		if len(resp.Answer) == 0 {
-			logrus.Debugf("[resolver] external DNS %s:%s returned response with no answers:\n%s", proto, extDNS.IPStr, resp)
+			r.log().Debugf("[resolver] external DNS %s:%s returned response with no answers:\n%s", proto, extDNS.IPStr, resp)
 		}
 		resp.Compress = true
 		return resp
@@ -508,12 +516,12 @@ func (r *Resolver) forwardExtDNS(proto string, query *dns.Msg) *dns.Msg {
 func (r *Resolver) exchange(proto string, extDNS extDNSEntry, query *dns.Msg) *dns.Msg {
 	extConn, err := r.dialExtDNS(proto, extDNS)
 	if err != nil {
-		logrus.WithError(err).Warn("[resolver] connect failed")
+		r.log().WithError(err).Warn("[resolver] connect failed")
 		return nil
 	}
 	defer extConn.Close()
 
-	log := logrus.WithFields(logrus.Fields{
+	log := r.log().WithFields(logrus.Fields{
 		"dns-server":  extConn.RemoteAddr().Network() + ":" + extConn.RemoteAddr().String(),
 		"client-addr": extConn.LocalAddr().Network() + ":" + extConn.LocalAddr().String(),
 		"question":    query.Question[0].String(),
@@ -534,7 +542,7 @@ func (r *Resolver) exchange(proto string, extDNS extDNSEntry, query *dns.Msg) *d
 		UDPSize: dns.MaxMsgSize,
 	}).ExchangeWithConn(query, &dns.Conn{Conn: extConn})
 	if err != nil {
-		logrus.WithError(err).Errorf("[resolver] failed to query DNS server: %s, query: %s", extConn.RemoteAddr().String(), query.Question[0].String())
+		r.log().WithError(err).Errorf("[resolver] failed to query DNS server: %s, query: %s", extConn.RemoteAddr().String(), query.Question[0].String())
 		return nil
 	}
 

+ 33 - 26
libnetwork/resolver_test.go

@@ -19,16 +19,22 @@ import (
 
 // a simple/null address type that will be used to fake a local address for unit testing
 type tstaddr struct {
+	network string
 }
 
-func (a *tstaddr) Network() string { return "tcp" }
+func (a *tstaddr) Network() string {
+	if a.network != "" {
+		return a.network
+	}
+	return "tcp"
+}
 
-func (a *tstaddr) String() string { return "127.0.0.1" }
+func (a *tstaddr) String() string { return "(fake)" }
 
 // a simple writer that implements dns.ResponseWriter for unit testing purposes
 type tstwriter struct {
-	localAddr net.Addr
-	msg       *dns.Msg
+	network string
+	msg     *dns.Msg
 }
 
 func (w *tstwriter) WriteMsg(m *dns.Msg) (err error) {
@@ -39,13 +45,12 @@ func (w *tstwriter) WriteMsg(m *dns.Msg) (err error) {
 func (w *tstwriter) Write(m []byte) (int, error) { return 0, nil }
 
 func (w *tstwriter) LocalAddr() net.Addr {
-	if w.localAddr != nil {
-		return w.localAddr
-	}
-	return new(tstaddr)
+	return &tstaddr{network: w.network}
 }
 
-func (w *tstwriter) RemoteAddr() net.Addr { return new(tstaddr) }
+func (w *tstwriter) RemoteAddr() net.Addr {
+	return &tstaddr{network: w.network}
+}
 
 func (w *tstwriter) TsigStatus() error { return nil }
 
@@ -372,15 +377,14 @@ func TestOversizedDNSReply(t *testing.T) {
 
 	srvAddr := srv.LocalAddr().(*net.UDPAddr)
 	rsv := NewResolver("", true, noopDNSBackend{})
+	// The resolver logs lots of valuable info at level debug. Redirect it
+	// to t.Log() so the log spew is emitted only if the test fails.
+	rsv.logger = testLogger(t)
 	rsv.SetExtServers([]extDNSEntry{
 		{IPStr: srvAddr.IP.String(), port: uint16(srvAddr.Port), HostLoopback: true},
 	})
 
-	// The resolver logs lots of valuable info at level debug. Redirect it
-	// to t.Log() so the log spew is emitted only if the test fails.
-	defer redirectLogrusTo(t)()
-
-	w := &tstwriter{localAddr: srv.LocalAddr()}
+	w := &tstwriter{network: srvAddr.Network()}
 	q := new(dns.Msg).SetQuestion("s3.amazonaws.com.", dns.TypeA)
 	rsv.serveDNS(w, q)
 	resp := w.GetResponse()
@@ -391,14 +395,11 @@ func TestOversizedDNSReply(t *testing.T) {
 	checkDNSRRType(t, resp.Answer[0].Header().Rrtype, dns.TypeA)
 }
 
-func redirectLogrusTo(t *testing.T) func() {
-	oldLevel, oldOut := logrus.StandardLogger().Level, logrus.StandardLogger().Out
-	logrus.StandardLogger().SetLevel(logrus.DebugLevel)
-	logrus.SetOutput(tlogWriter{t})
-	return func() {
-		logrus.StandardLogger().SetLevel(oldLevel)
-		logrus.StandardLogger().SetOutput(oldOut)
-	}
+func testLogger(t *testing.T) *logrus.Logger {
+	logger := logrus.New()
+	logger.SetLevel(logrus.DebugLevel)
+	logger.SetOutput(tlogWriter{t})
+	return logger
 }
 
 type tlogWriter struct{ t *testing.T }
@@ -440,9 +441,8 @@ func TestReplySERVFAIL(t *testing.T) {
 	}
 	for _, tt := range cases {
 		t.Run(tt.name, func(t *testing.T) {
-			defer redirectLogrusTo(t)
-
 			rsv := NewResolver("", tt.proxyDNS, badSRVDNSBackend{})
+			rsv.logger = testLogger(t)
 			w := &tstwriter{}
 			rsv.serveDNS(w, tt.q)
 			resp := w.GetResponse()
@@ -494,6 +494,13 @@ func TestProxyNXDOMAIN(t *testing.T) {
 		<-serveDone
 	}()
 
+	// This test, by virtue of running a server and client in different
+	// not-locked-to-thread goroutines, happens to be a good canary for
+	// whether we are leaking unlocked OS threads set to the wrong network
+	// namespace. Make a best-effort attempt to detect that situation so we
+	// are not left chasing ghosts next time.
+	testutils.AssertSocketSameNetNS(t, srv.PacketConn.(*net.UDPConn))
+
 	srvAddr := srv.PacketConn.LocalAddr().(*net.UDPAddr)
 	rsv := NewResolver("", true, noopDNSBackend{})
 	rsv.SetExtServers([]extDNSEntry{
@@ -502,9 +509,9 @@ func TestProxyNXDOMAIN(t *testing.T) {
 
 	// The resolver logs lots of valuable info at level debug. Redirect it
 	// to t.Log() so the log spew is emitted only if the test fails.
-	defer redirectLogrusTo(t)()
+	rsv.logger = testLogger(t)
 
-	w := &tstwriter{localAddr: srv.PacketConn.LocalAddr()}
+	w := &tstwriter{network: srvAddr.Network()}
 	q := new(dns.Msg).SetQuestion("example.net.", dns.TypeA)
 	rsv.serveDNS(w, q)
 	resp := w.GetResponse()

+ 43 - 0
libnetwork/testutils/sanity_linux.go

@@ -0,0 +1,43 @@
+package testutils
+
+import (
+	"errors"
+	"syscall"
+	"testing"
+
+	"github.com/vishvananda/netns"
+	"golang.org/x/sys/unix"
+	"gotest.tools/v3/assert"
+)
+
+// AssertSocketSameNetNS makes a best-effort attempt to assert that conn is in
+// the same network namespace as the current goroutine's thread.
+func AssertSocketSameNetNS(t testing.TB, conn syscall.Conn) {
+	t.Helper()
+
+	sc, err := conn.SyscallConn()
+	assert.NilError(t, err)
+	sc.Control(func(fd uintptr) {
+		srvnsfd, err := unix.IoctlRetInt(int(fd), unix.SIOCGSKNS)
+		if err != nil {
+			if errors.Is(err, unix.EPERM) {
+				t.Log("Cannot determine socket's network namespace. Do we have CAP_NET_ADMIN?")
+				return
+			}
+			if errors.Is(err, unix.ENOSYS) {
+				t.Log("Cannot query socket's network namespace due to missing kernel support.")
+				return
+			}
+			t.Fatal(err)
+		}
+		srvns := netns.NsHandle(srvnsfd)
+		defer srvns.Close()
+
+		curns, err := netns.Get()
+		assert.NilError(t, err)
+		defer curns.Close()
+		if !srvns.Equal(curns) {
+			t.Fatalf("Socket is in network namespace %s, but test goroutine is in %s", srvns, curns)
+		}
+	})
+}

+ 11 - 0
libnetwork/testutils/sanity_notlinux.go

@@ -0,0 +1,11 @@
+//go:build !linux
+
+package testutils
+
+import (
+	"syscall"
+	"testing"
+)
+
+// AssertSocketSameNetNS is a no-op on platforms other than Linux.
+func AssertSocketSameNetNS(t testing.TB, conn syscall.Conn) {}