Merge pull request #45598 from corhere/backport-24.0/fix-flaky-resolver-test

[24.0 backport] libnetwork/osl: restore the right thread's netns
This commit is contained in:
Sebastiaan van Stijn 2023-05-23 19:15:38 +02:00 committed by GitHub
commit 1b0d37bdc2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 128 additions and 54 deletions

View file

@ -600,24 +600,29 @@ func (n *networkNamespace) checkLoV6() {
}
func setIPv6(nspath, iface string, enable bool) error {
origNS, err := netns.Get()
if err != nil {
return fmt.Errorf("failed to get current network namespace: %w", err)
}
defer origNS.Close()
namespace, err := netns.GetFromPath(nspath)
if err != nil {
return fmt.Errorf("failed get network namespace %q: %w", nspath, err)
}
defer namespace.Close()
errCh := make(chan error, 1)
go func() {
defer close(errCh)
namespace, err := netns.GetFromPath(nspath)
if err != nil {
errCh <- fmt.Errorf("failed get network namespace %q: %w", nspath, err)
return
}
defer namespace.Close()
runtime.LockOSThread()
origNS, err := netns.Get()
if err != nil {
runtime.UnlockOSThread()
errCh <- fmt.Errorf("failed to get current network namespace: %w", err)
return
}
defer origNS.Close()
if err = netns.Set(namespace); err != nil {
runtime.UnlockOSThread()
errCh <- fmt.Errorf("setting into container netns %q failed: %w", nspath, err)
return
}

View file

@ -71,6 +71,7 @@ type Resolver struct {
listenAddress string
proxyDNS bool
startCh chan struct{}
logger *logrus.Logger
fwdSem *semaphore.Weighted // Limit the number of concurrent external DNS requests in-flight
logInverval rate.Sometimes // Rate-limit logging about hitting the fwdSem limit
@ -89,6 +90,13 @@ func NewResolver(address string, proxyDNS bool, backend DNSBackend) *Resolver {
}
}
func (r *Resolver) log() *logrus.Logger {
if r.logger == nil {
return logrus.StandardLogger()
}
return r.logger
}
// SetupFunc returns the setup function that should be run in the container's
// network namespace.
func (r *Resolver) SetupFunc(port int) func() {
@ -140,7 +148,7 @@ func (r *Resolver) Start() error {
r.server = s
go func() {
if err := s.ActivateAndServe(); err != nil {
logrus.WithError(err).Error("[resolver] failed to start PacketConn DNS server")
r.log().WithError(err).Error("[resolver] failed to start PacketConn DNS server")
}
}()
@ -148,7 +156,7 @@ func (r *Resolver) Start() error {
r.tcpServer = tcpServer
go func() {
if err := tcpServer.ActivateAndServe(); err != nil {
logrus.WithError(err).Error("[resolver] failed to start TCP DNS server")
r.log().WithError(err).Error("[resolver] failed to start TCP DNS server")
}
}()
return nil
@ -249,7 +257,7 @@ func (r *Resolver) handleIPQuery(query *dns.Msg, ipType int) (*dns.Msg, error) {
if addr == nil && ipv6Miss {
// Send a reply without any Answer sections
logrus.Debugf("[resolver] lookup name %s present without IPv6 address", name)
r.log().Debugf("[resolver] lookup name %s present without IPv6 address", name)
resp := createRespMsg(query)
return resp, nil
}
@ -257,7 +265,7 @@ func (r *Resolver) handleIPQuery(query *dns.Msg, ipType int) (*dns.Msg, error) {
return nil, nil
}
logrus.Debugf("[resolver] lookup for %s: IP %v", name, addr)
r.log().Debugf("[resolver] lookup for %s: IP %v", name, addr)
resp := createRespMsg(query)
if len(addr) > 1 {
@ -298,7 +306,7 @@ func (r *Resolver) handlePTRQuery(query *dns.Msg) (*dns.Msg, error) {
return nil, nil
}
logrus.Debugf("[resolver] lookup for IP %s: name %s", name, host)
r.log().Debugf("[resolver] lookup for IP %s: name %s", name, host)
fqdn := dns.Fqdn(host)
resp := new(dns.Msg)
@ -365,17 +373,17 @@ func (r *Resolver) serveDNS(w dns.ResponseWriter, query *dns.Msg) {
case dns.TypeSRV:
resp, err = r.handleSRVQuery(query)
default:
logrus.Debugf("[resolver] query type %s is not supported by the embedded DNS and will be forwarded to external DNS", dns.TypeToString[queryType])
r.log().Debugf("[resolver] query type %s is not supported by the embedded DNS and will be forwarded to external DNS", dns.TypeToString[queryType])
}
reply := func(msg *dns.Msg) {
if err = w.WriteMsg(msg); err != nil {
logrus.WithError(err).Errorf("[resolver] failed to write response")
r.log().WithError(err).Errorf("[resolver] failed to write response")
}
}
if err != nil {
logrus.WithError(err).Errorf("[resolver] failed to handle query: %s (%s)", queryName, dns.TypeToString[queryType])
r.log().WithError(err).Errorf("[resolver] failed to handle query: %s (%s)", queryName, dns.TypeToString[queryType])
reply(new(dns.Msg).SetRcode(query, dns.RcodeServerFailure))
return
}
@ -460,7 +468,7 @@ func (r *Resolver) forwardExtDNS(proto string, query *dns.Msg) *dns.Msg {
cancel()
if err != nil {
r.logInverval.Do(func() {
logrus.Errorf("[resolver] more than %v concurrent queries", maxConcurrent)
r.log().Errorf("[resolver] more than %v concurrent queries", maxConcurrent)
})
return new(dns.Msg).SetRcode(query, dns.RcodeRefused)
}
@ -476,7 +484,7 @@ func (r *Resolver) forwardExtDNS(proto string, query *dns.Msg) *dns.Msg {
case dns.RcodeServerFailure, dns.RcodeRefused:
// Server returned FAILURE: continue with the next external DNS server
// Server returned REFUSED: this can be a transitional status, so continue with the next external DNS server
logrus.Debugf("[resolver] external DNS %s:%s returned failure:\n%s", proto, extDNS.IPStr, resp)
r.log().Debugf("[resolver] external DNS %s:%s returned failure:\n%s", proto, extDNS.IPStr, resp)
continue
}
answers := 0
@ -486,17 +494,17 @@ func (r *Resolver) forwardExtDNS(proto string, query *dns.Msg) *dns.Msg {
case dns.TypeA:
answers++
ip := rr.(*dns.A).A
logrus.Debugf("[resolver] received A record %q for %q from %s:%s", ip, h.Name, proto, extDNS.IPStr)
r.log().Debugf("[resolver] received A record %q for %q from %s:%s", ip, h.Name, proto, extDNS.IPStr)
r.backend.HandleQueryResp(h.Name, ip)
case dns.TypeAAAA:
answers++
ip := rr.(*dns.AAAA).AAAA
logrus.Debugf("[resolver] received AAAA record %q for %q from %s:%s", ip, h.Name, proto, extDNS.IPStr)
r.log().Debugf("[resolver] received AAAA record %q for %q from %s:%s", ip, h.Name, proto, extDNS.IPStr)
r.backend.HandleQueryResp(h.Name, ip)
}
}
if len(resp.Answer) == 0 {
logrus.Debugf("[resolver] external DNS %s:%s returned response with no answers:\n%s", proto, extDNS.IPStr, resp)
r.log().Debugf("[resolver] external DNS %s:%s returned response with no answers:\n%s", proto, extDNS.IPStr, resp)
}
resp.Compress = true
return resp
@ -508,12 +516,12 @@ func (r *Resolver) forwardExtDNS(proto string, query *dns.Msg) *dns.Msg {
func (r *Resolver) exchange(proto string, extDNS extDNSEntry, query *dns.Msg) *dns.Msg {
extConn, err := r.dialExtDNS(proto, extDNS)
if err != nil {
logrus.WithError(err).Warn("[resolver] connect failed")
r.log().WithError(err).Warn("[resolver] connect failed")
return nil
}
defer extConn.Close()
log := logrus.WithFields(logrus.Fields{
log := r.log().WithFields(logrus.Fields{
"dns-server": extConn.RemoteAddr().Network() + ":" + extConn.RemoteAddr().String(),
"client-addr": extConn.LocalAddr().Network() + ":" + extConn.LocalAddr().String(),
"question": query.Question[0].String(),
@ -534,7 +542,7 @@ func (r *Resolver) exchange(proto string, extDNS extDNSEntry, query *dns.Msg) *d
UDPSize: dns.MaxMsgSize,
}).ExchangeWithConn(query, &dns.Conn{Conn: extConn})
if err != nil {
logrus.WithError(err).Errorf("[resolver] failed to query DNS server: %s, query: %s", extConn.RemoteAddr().String(), query.Question[0].String())
r.log().WithError(err).Errorf("[resolver] failed to query DNS server: %s, query: %s", extConn.RemoteAddr().String(), query.Question[0].String())
return nil
}

View file

@ -19,16 +19,22 @@ import (
// a simple/null address type that will be used to fake a local address for unit testing
type tstaddr struct {
network string
}
func (a *tstaddr) Network() string { return "tcp" }
func (a *tstaddr) Network() string {
if a.network != "" {
return a.network
}
return "tcp"
}
func (a *tstaddr) String() string { return "127.0.0.1" }
func (a *tstaddr) String() string { return "(fake)" }
// a simple writer that implements dns.ResponseWriter for unit testing purposes
type tstwriter struct {
localAddr net.Addr
msg *dns.Msg
network string
msg *dns.Msg
}
func (w *tstwriter) WriteMsg(m *dns.Msg) (err error) {
@ -39,13 +45,12 @@ func (w *tstwriter) WriteMsg(m *dns.Msg) (err error) {
func (w *tstwriter) Write(m []byte) (int, error) { return 0, nil }
func (w *tstwriter) LocalAddr() net.Addr {
if w.localAddr != nil {
return w.localAddr
}
return new(tstaddr)
return &tstaddr{network: w.network}
}
func (w *tstwriter) RemoteAddr() net.Addr { return new(tstaddr) }
func (w *tstwriter) RemoteAddr() net.Addr {
return &tstaddr{network: w.network}
}
func (w *tstwriter) TsigStatus() error { return nil }
@ -372,15 +377,14 @@ func TestOversizedDNSReply(t *testing.T) {
srvAddr := srv.LocalAddr().(*net.UDPAddr)
rsv := NewResolver("", true, noopDNSBackend{})
// The resolver logs lots of valuable info at level debug. Redirect it
// to t.Log() so the log spew is emitted only if the test fails.
rsv.logger = testLogger(t)
rsv.SetExtServers([]extDNSEntry{
{IPStr: srvAddr.IP.String(), port: uint16(srvAddr.Port), HostLoopback: true},
})
// The resolver logs lots of valuable info at level debug. Redirect it
// to t.Log() so the log spew is emitted only if the test fails.
defer redirectLogrusTo(t)()
w := &tstwriter{localAddr: srv.LocalAddr()}
w := &tstwriter{network: srvAddr.Network()}
q := new(dns.Msg).SetQuestion("s3.amazonaws.com.", dns.TypeA)
rsv.serveDNS(w, q)
resp := w.GetResponse()
@ -391,14 +395,11 @@ func TestOversizedDNSReply(t *testing.T) {
checkDNSRRType(t, resp.Answer[0].Header().Rrtype, dns.TypeA)
}
func redirectLogrusTo(t *testing.T) func() {
oldLevel, oldOut := logrus.StandardLogger().Level, logrus.StandardLogger().Out
logrus.StandardLogger().SetLevel(logrus.DebugLevel)
logrus.SetOutput(tlogWriter{t})
return func() {
logrus.StandardLogger().SetLevel(oldLevel)
logrus.StandardLogger().SetOutput(oldOut)
}
func testLogger(t *testing.T) *logrus.Logger {
logger := logrus.New()
logger.SetLevel(logrus.DebugLevel)
logger.SetOutput(tlogWriter{t})
return logger
}
type tlogWriter struct{ t *testing.T }
@ -440,9 +441,8 @@ func TestReplySERVFAIL(t *testing.T) {
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
defer redirectLogrusTo(t)
rsv := NewResolver("", tt.proxyDNS, badSRVDNSBackend{})
rsv.logger = testLogger(t)
w := &tstwriter{}
rsv.serveDNS(w, tt.q)
resp := w.GetResponse()
@ -494,6 +494,13 @@ func TestProxyNXDOMAIN(t *testing.T) {
<-serveDone
}()
// This test, by virtue of running a server and client in different
// not-locked-to-thread goroutines, happens to be a good canary for
// whether we are leaking unlocked OS threads set to the wrong network
// namespace. Make a best-effort attempt to detect that situation so we
// are not left chasing ghosts next time.
testutils.AssertSocketSameNetNS(t, srv.PacketConn.(*net.UDPConn))
srvAddr := srv.PacketConn.LocalAddr().(*net.UDPAddr)
rsv := NewResolver("", true, noopDNSBackend{})
rsv.SetExtServers([]extDNSEntry{
@ -502,9 +509,9 @@ func TestProxyNXDOMAIN(t *testing.T) {
// The resolver logs lots of valuable info at level debug. Redirect it
// to t.Log() so the log spew is emitted only if the test fails.
defer redirectLogrusTo(t)()
rsv.logger = testLogger(t)
w := &tstwriter{localAddr: srv.PacketConn.LocalAddr()}
w := &tstwriter{network: srvAddr.Network()}
q := new(dns.Msg).SetQuestion("example.net.", dns.TypeA)
rsv.serveDNS(w, q)
resp := w.GetResponse()

View file

@ -0,0 +1,43 @@
package testutils
import (
"errors"
"syscall"
"testing"
"github.com/vishvananda/netns"
"golang.org/x/sys/unix"
"gotest.tools/v3/assert"
)
// AssertSocketSameNetNS makes a best-effort attempt to assert that conn is in
// the same network namespace as the current goroutine's thread.
func AssertSocketSameNetNS(t testing.TB, conn syscall.Conn) {
t.Helper()
sc, err := conn.SyscallConn()
assert.NilError(t, err)
sc.Control(func(fd uintptr) {
srvnsfd, err := unix.IoctlRetInt(int(fd), unix.SIOCGSKNS)
if err != nil {
if errors.Is(err, unix.EPERM) {
t.Log("Cannot determine socket's network namespace. Do we have CAP_NET_ADMIN?")
return
}
if errors.Is(err, unix.ENOSYS) {
t.Log("Cannot query socket's network namespace due to missing kernel support.")
return
}
t.Fatal(err)
}
srvns := netns.NsHandle(srvnsfd)
defer srvns.Close()
curns, err := netns.Get()
assert.NilError(t, err)
defer curns.Close()
if !srvns.Equal(curns) {
t.Fatalf("Socket is in network namespace %s, but test goroutine is in %s", srvns, curns)
}
})
}

View file

@ -0,0 +1,11 @@
//go:build !linux
package testutils
import (
"syscall"
"testing"
)
// AssertSocketSameNetNS is a no-op on platforms other than Linux.
func AssertSocketSameNetNS(t testing.TB, conn syscall.Conn) {}