Merge pull request #45598 from corhere/backport-24.0/fix-flaky-resolver-test
[24.0 backport] libnetwork/osl: restore the right thread's netns
This commit is contained in:
commit
1b0d37bdc2
5 changed files with 128 additions and 54 deletions
|
@ -600,24 +600,29 @@ func (n *networkNamespace) checkLoV6() {
|
|||
}
|
||||
|
||||
func setIPv6(nspath, iface string, enable bool) error {
|
||||
origNS, err := netns.Get()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get current network namespace: %w", err)
|
||||
}
|
||||
defer origNS.Close()
|
||||
|
||||
namespace, err := netns.GetFromPath(nspath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed get network namespace %q: %w", nspath, err)
|
||||
}
|
||||
defer namespace.Close()
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
defer close(errCh)
|
||||
|
||||
namespace, err := netns.GetFromPath(nspath)
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("failed get network namespace %q: %w", nspath, err)
|
||||
return
|
||||
}
|
||||
defer namespace.Close()
|
||||
|
||||
runtime.LockOSThread()
|
||||
|
||||
origNS, err := netns.Get()
|
||||
if err != nil {
|
||||
runtime.UnlockOSThread()
|
||||
errCh <- fmt.Errorf("failed to get current network namespace: %w", err)
|
||||
return
|
||||
}
|
||||
defer origNS.Close()
|
||||
|
||||
if err = netns.Set(namespace); err != nil {
|
||||
runtime.UnlockOSThread()
|
||||
errCh <- fmt.Errorf("setting into container netns %q failed: %w", nspath, err)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -71,6 +71,7 @@ type Resolver struct {
|
|||
listenAddress string
|
||||
proxyDNS bool
|
||||
startCh chan struct{}
|
||||
logger *logrus.Logger
|
||||
|
||||
fwdSem *semaphore.Weighted // Limit the number of concurrent external DNS requests in-flight
|
||||
logInverval rate.Sometimes // Rate-limit logging about hitting the fwdSem limit
|
||||
|
@ -89,6 +90,13 @@ func NewResolver(address string, proxyDNS bool, backend DNSBackend) *Resolver {
|
|||
}
|
||||
}
|
||||
|
||||
func (r *Resolver) log() *logrus.Logger {
|
||||
if r.logger == nil {
|
||||
return logrus.StandardLogger()
|
||||
}
|
||||
return r.logger
|
||||
}
|
||||
|
||||
// SetupFunc returns the setup function that should be run in the container's
|
||||
// network namespace.
|
||||
func (r *Resolver) SetupFunc(port int) func() {
|
||||
|
@ -140,7 +148,7 @@ func (r *Resolver) Start() error {
|
|||
r.server = s
|
||||
go func() {
|
||||
if err := s.ActivateAndServe(); err != nil {
|
||||
logrus.WithError(err).Error("[resolver] failed to start PacketConn DNS server")
|
||||
r.log().WithError(err).Error("[resolver] failed to start PacketConn DNS server")
|
||||
}
|
||||
}()
|
||||
|
||||
|
@ -148,7 +156,7 @@ func (r *Resolver) Start() error {
|
|||
r.tcpServer = tcpServer
|
||||
go func() {
|
||||
if err := tcpServer.ActivateAndServe(); err != nil {
|
||||
logrus.WithError(err).Error("[resolver] failed to start TCP DNS server")
|
||||
r.log().WithError(err).Error("[resolver] failed to start TCP DNS server")
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
|
@ -249,7 +257,7 @@ func (r *Resolver) handleIPQuery(query *dns.Msg, ipType int) (*dns.Msg, error) {
|
|||
|
||||
if addr == nil && ipv6Miss {
|
||||
// Send a reply without any Answer sections
|
||||
logrus.Debugf("[resolver] lookup name %s present without IPv6 address", name)
|
||||
r.log().Debugf("[resolver] lookup name %s present without IPv6 address", name)
|
||||
resp := createRespMsg(query)
|
||||
return resp, nil
|
||||
}
|
||||
|
@ -257,7 +265,7 @@ func (r *Resolver) handleIPQuery(query *dns.Msg, ipType int) (*dns.Msg, error) {
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
logrus.Debugf("[resolver] lookup for %s: IP %v", name, addr)
|
||||
r.log().Debugf("[resolver] lookup for %s: IP %v", name, addr)
|
||||
|
||||
resp := createRespMsg(query)
|
||||
if len(addr) > 1 {
|
||||
|
@ -298,7 +306,7 @@ func (r *Resolver) handlePTRQuery(query *dns.Msg) (*dns.Msg, error) {
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
logrus.Debugf("[resolver] lookup for IP %s: name %s", name, host)
|
||||
r.log().Debugf("[resolver] lookup for IP %s: name %s", name, host)
|
||||
fqdn := dns.Fqdn(host)
|
||||
|
||||
resp := new(dns.Msg)
|
||||
|
@ -365,17 +373,17 @@ func (r *Resolver) serveDNS(w dns.ResponseWriter, query *dns.Msg) {
|
|||
case dns.TypeSRV:
|
||||
resp, err = r.handleSRVQuery(query)
|
||||
default:
|
||||
logrus.Debugf("[resolver] query type %s is not supported by the embedded DNS and will be forwarded to external DNS", dns.TypeToString[queryType])
|
||||
r.log().Debugf("[resolver] query type %s is not supported by the embedded DNS and will be forwarded to external DNS", dns.TypeToString[queryType])
|
||||
}
|
||||
|
||||
reply := func(msg *dns.Msg) {
|
||||
if err = w.WriteMsg(msg); err != nil {
|
||||
logrus.WithError(err).Errorf("[resolver] failed to write response")
|
||||
r.log().WithError(err).Errorf("[resolver] failed to write response")
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logrus.WithError(err).Errorf("[resolver] failed to handle query: %s (%s)", queryName, dns.TypeToString[queryType])
|
||||
r.log().WithError(err).Errorf("[resolver] failed to handle query: %s (%s)", queryName, dns.TypeToString[queryType])
|
||||
reply(new(dns.Msg).SetRcode(query, dns.RcodeServerFailure))
|
||||
return
|
||||
}
|
||||
|
@ -460,7 +468,7 @@ func (r *Resolver) forwardExtDNS(proto string, query *dns.Msg) *dns.Msg {
|
|||
cancel()
|
||||
if err != nil {
|
||||
r.logInverval.Do(func() {
|
||||
logrus.Errorf("[resolver] more than %v concurrent queries", maxConcurrent)
|
||||
r.log().Errorf("[resolver] more than %v concurrent queries", maxConcurrent)
|
||||
})
|
||||
return new(dns.Msg).SetRcode(query, dns.RcodeRefused)
|
||||
}
|
||||
|
@ -476,7 +484,7 @@ func (r *Resolver) forwardExtDNS(proto string, query *dns.Msg) *dns.Msg {
|
|||
case dns.RcodeServerFailure, dns.RcodeRefused:
|
||||
// Server returned FAILURE: continue with the next external DNS server
|
||||
// Server returned REFUSED: this can be a transitional status, so continue with the next external DNS server
|
||||
logrus.Debugf("[resolver] external DNS %s:%s returned failure:\n%s", proto, extDNS.IPStr, resp)
|
||||
r.log().Debugf("[resolver] external DNS %s:%s returned failure:\n%s", proto, extDNS.IPStr, resp)
|
||||
continue
|
||||
}
|
||||
answers := 0
|
||||
|
@ -486,17 +494,17 @@ func (r *Resolver) forwardExtDNS(proto string, query *dns.Msg) *dns.Msg {
|
|||
case dns.TypeA:
|
||||
answers++
|
||||
ip := rr.(*dns.A).A
|
||||
logrus.Debugf("[resolver] received A record %q for %q from %s:%s", ip, h.Name, proto, extDNS.IPStr)
|
||||
r.log().Debugf("[resolver] received A record %q for %q from %s:%s", ip, h.Name, proto, extDNS.IPStr)
|
||||
r.backend.HandleQueryResp(h.Name, ip)
|
||||
case dns.TypeAAAA:
|
||||
answers++
|
||||
ip := rr.(*dns.AAAA).AAAA
|
||||
logrus.Debugf("[resolver] received AAAA record %q for %q from %s:%s", ip, h.Name, proto, extDNS.IPStr)
|
||||
r.log().Debugf("[resolver] received AAAA record %q for %q from %s:%s", ip, h.Name, proto, extDNS.IPStr)
|
||||
r.backend.HandleQueryResp(h.Name, ip)
|
||||
}
|
||||
}
|
||||
if len(resp.Answer) == 0 {
|
||||
logrus.Debugf("[resolver] external DNS %s:%s returned response with no answers:\n%s", proto, extDNS.IPStr, resp)
|
||||
r.log().Debugf("[resolver] external DNS %s:%s returned response with no answers:\n%s", proto, extDNS.IPStr, resp)
|
||||
}
|
||||
resp.Compress = true
|
||||
return resp
|
||||
|
@ -508,12 +516,12 @@ func (r *Resolver) forwardExtDNS(proto string, query *dns.Msg) *dns.Msg {
|
|||
func (r *Resolver) exchange(proto string, extDNS extDNSEntry, query *dns.Msg) *dns.Msg {
|
||||
extConn, err := r.dialExtDNS(proto, extDNS)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Warn("[resolver] connect failed")
|
||||
r.log().WithError(err).Warn("[resolver] connect failed")
|
||||
return nil
|
||||
}
|
||||
defer extConn.Close()
|
||||
|
||||
log := logrus.WithFields(logrus.Fields{
|
||||
log := r.log().WithFields(logrus.Fields{
|
||||
"dns-server": extConn.RemoteAddr().Network() + ":" + extConn.RemoteAddr().String(),
|
||||
"client-addr": extConn.LocalAddr().Network() + ":" + extConn.LocalAddr().String(),
|
||||
"question": query.Question[0].String(),
|
||||
|
@ -534,7 +542,7 @@ func (r *Resolver) exchange(proto string, extDNS extDNSEntry, query *dns.Msg) *d
|
|||
UDPSize: dns.MaxMsgSize,
|
||||
}).ExchangeWithConn(query, &dns.Conn{Conn: extConn})
|
||||
if err != nil {
|
||||
logrus.WithError(err).Errorf("[resolver] failed to query DNS server: %s, query: %s", extConn.RemoteAddr().String(), query.Question[0].String())
|
||||
r.log().WithError(err).Errorf("[resolver] failed to query DNS server: %s, query: %s", extConn.RemoteAddr().String(), query.Question[0].String())
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -19,16 +19,22 @@ import (
|
|||
|
||||
// a simple/null address type that will be used to fake a local address for unit testing
|
||||
type tstaddr struct {
|
||||
network string
|
||||
}
|
||||
|
||||
func (a *tstaddr) Network() string { return "tcp" }
|
||||
func (a *tstaddr) Network() string {
|
||||
if a.network != "" {
|
||||
return a.network
|
||||
}
|
||||
return "tcp"
|
||||
}
|
||||
|
||||
func (a *tstaddr) String() string { return "127.0.0.1" }
|
||||
func (a *tstaddr) String() string { return "(fake)" }
|
||||
|
||||
// a simple writer that implements dns.ResponseWriter for unit testing purposes
|
||||
type tstwriter struct {
|
||||
localAddr net.Addr
|
||||
msg *dns.Msg
|
||||
network string
|
||||
msg *dns.Msg
|
||||
}
|
||||
|
||||
func (w *tstwriter) WriteMsg(m *dns.Msg) (err error) {
|
||||
|
@ -39,13 +45,12 @@ func (w *tstwriter) WriteMsg(m *dns.Msg) (err error) {
|
|||
func (w *tstwriter) Write(m []byte) (int, error) { return 0, nil }
|
||||
|
||||
func (w *tstwriter) LocalAddr() net.Addr {
|
||||
if w.localAddr != nil {
|
||||
return w.localAddr
|
||||
}
|
||||
return new(tstaddr)
|
||||
return &tstaddr{network: w.network}
|
||||
}
|
||||
|
||||
func (w *tstwriter) RemoteAddr() net.Addr { return new(tstaddr) }
|
||||
func (w *tstwriter) RemoteAddr() net.Addr {
|
||||
return &tstaddr{network: w.network}
|
||||
}
|
||||
|
||||
func (w *tstwriter) TsigStatus() error { return nil }
|
||||
|
||||
|
@ -372,15 +377,14 @@ func TestOversizedDNSReply(t *testing.T) {
|
|||
|
||||
srvAddr := srv.LocalAddr().(*net.UDPAddr)
|
||||
rsv := NewResolver("", true, noopDNSBackend{})
|
||||
// The resolver logs lots of valuable info at level debug. Redirect it
|
||||
// to t.Log() so the log spew is emitted only if the test fails.
|
||||
rsv.logger = testLogger(t)
|
||||
rsv.SetExtServers([]extDNSEntry{
|
||||
{IPStr: srvAddr.IP.String(), port: uint16(srvAddr.Port), HostLoopback: true},
|
||||
})
|
||||
|
||||
// The resolver logs lots of valuable info at level debug. Redirect it
|
||||
// to t.Log() so the log spew is emitted only if the test fails.
|
||||
defer redirectLogrusTo(t)()
|
||||
|
||||
w := &tstwriter{localAddr: srv.LocalAddr()}
|
||||
w := &tstwriter{network: srvAddr.Network()}
|
||||
q := new(dns.Msg).SetQuestion("s3.amazonaws.com.", dns.TypeA)
|
||||
rsv.serveDNS(w, q)
|
||||
resp := w.GetResponse()
|
||||
|
@ -391,14 +395,11 @@ func TestOversizedDNSReply(t *testing.T) {
|
|||
checkDNSRRType(t, resp.Answer[0].Header().Rrtype, dns.TypeA)
|
||||
}
|
||||
|
||||
func redirectLogrusTo(t *testing.T) func() {
|
||||
oldLevel, oldOut := logrus.StandardLogger().Level, logrus.StandardLogger().Out
|
||||
logrus.StandardLogger().SetLevel(logrus.DebugLevel)
|
||||
logrus.SetOutput(tlogWriter{t})
|
||||
return func() {
|
||||
logrus.StandardLogger().SetLevel(oldLevel)
|
||||
logrus.StandardLogger().SetOutput(oldOut)
|
||||
}
|
||||
func testLogger(t *testing.T) *logrus.Logger {
|
||||
logger := logrus.New()
|
||||
logger.SetLevel(logrus.DebugLevel)
|
||||
logger.SetOutput(tlogWriter{t})
|
||||
return logger
|
||||
}
|
||||
|
||||
type tlogWriter struct{ t *testing.T }
|
||||
|
@ -440,9 +441,8 @@ func TestReplySERVFAIL(t *testing.T) {
|
|||
}
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
defer redirectLogrusTo(t)
|
||||
|
||||
rsv := NewResolver("", tt.proxyDNS, badSRVDNSBackend{})
|
||||
rsv.logger = testLogger(t)
|
||||
w := &tstwriter{}
|
||||
rsv.serveDNS(w, tt.q)
|
||||
resp := w.GetResponse()
|
||||
|
@ -494,6 +494,13 @@ func TestProxyNXDOMAIN(t *testing.T) {
|
|||
<-serveDone
|
||||
}()
|
||||
|
||||
// This test, by virtue of running a server and client in different
|
||||
// not-locked-to-thread goroutines, happens to be a good canary for
|
||||
// whether we are leaking unlocked OS threads set to the wrong network
|
||||
// namespace. Make a best-effort attempt to detect that situation so we
|
||||
// are not left chasing ghosts next time.
|
||||
testutils.AssertSocketSameNetNS(t, srv.PacketConn.(*net.UDPConn))
|
||||
|
||||
srvAddr := srv.PacketConn.LocalAddr().(*net.UDPAddr)
|
||||
rsv := NewResolver("", true, noopDNSBackend{})
|
||||
rsv.SetExtServers([]extDNSEntry{
|
||||
|
@ -502,9 +509,9 @@ func TestProxyNXDOMAIN(t *testing.T) {
|
|||
|
||||
// The resolver logs lots of valuable info at level debug. Redirect it
|
||||
// to t.Log() so the log spew is emitted only if the test fails.
|
||||
defer redirectLogrusTo(t)()
|
||||
rsv.logger = testLogger(t)
|
||||
|
||||
w := &tstwriter{localAddr: srv.PacketConn.LocalAddr()}
|
||||
w := &tstwriter{network: srvAddr.Network()}
|
||||
q := new(dns.Msg).SetQuestion("example.net.", dns.TypeA)
|
||||
rsv.serveDNS(w, q)
|
||||
resp := w.GetResponse()
|
||||
|
|
43
libnetwork/testutils/sanity_linux.go
Normal file
43
libnetwork/testutils/sanity_linux.go
Normal file
|
@ -0,0 +1,43 @@
|
|||
package testutils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"syscall"
|
||||
"testing"
|
||||
|
||||
"github.com/vishvananda/netns"
|
||||
"golang.org/x/sys/unix"
|
||||
"gotest.tools/v3/assert"
|
||||
)
|
||||
|
||||
// AssertSocketSameNetNS makes a best-effort attempt to assert that conn is in
|
||||
// the same network namespace as the current goroutine's thread.
|
||||
func AssertSocketSameNetNS(t testing.TB, conn syscall.Conn) {
|
||||
t.Helper()
|
||||
|
||||
sc, err := conn.SyscallConn()
|
||||
assert.NilError(t, err)
|
||||
sc.Control(func(fd uintptr) {
|
||||
srvnsfd, err := unix.IoctlRetInt(int(fd), unix.SIOCGSKNS)
|
||||
if err != nil {
|
||||
if errors.Is(err, unix.EPERM) {
|
||||
t.Log("Cannot determine socket's network namespace. Do we have CAP_NET_ADMIN?")
|
||||
return
|
||||
}
|
||||
if errors.Is(err, unix.ENOSYS) {
|
||||
t.Log("Cannot query socket's network namespace due to missing kernel support.")
|
||||
return
|
||||
}
|
||||
t.Fatal(err)
|
||||
}
|
||||
srvns := netns.NsHandle(srvnsfd)
|
||||
defer srvns.Close()
|
||||
|
||||
curns, err := netns.Get()
|
||||
assert.NilError(t, err)
|
||||
defer curns.Close()
|
||||
if !srvns.Equal(curns) {
|
||||
t.Fatalf("Socket is in network namespace %s, but test goroutine is in %s", srvns, curns)
|
||||
}
|
||||
})
|
||||
}
|
11
libnetwork/testutils/sanity_notlinux.go
Normal file
11
libnetwork/testutils/sanity_notlinux.go
Normal file
|
@ -0,0 +1,11 @@
|
|||
//go:build !linux
|
||||
|
||||
package testutils
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// AssertSocketSameNetNS is a no-op on platforms other than Linux.
|
||||
func AssertSocketSameNetNS(t testing.TB, conn syscall.Conn) {}
|
Loading…
Reference in a new issue