diff --git a/libnetwork/osl/namespace_linux.go b/libnetwork/osl/namespace_linux.go index 6da6f95b01..9f22d80772 100644 --- a/libnetwork/osl/namespace_linux.go +++ b/libnetwork/osl/namespace_linux.go @@ -600,24 +600,29 @@ func (n *networkNamespace) checkLoV6() { } func setIPv6(nspath, iface string, enable bool) error { - origNS, err := netns.Get() - if err != nil { - return fmt.Errorf("failed to get current network namespace: %w", err) - } - defer origNS.Close() - - namespace, err := netns.GetFromPath(nspath) - if err != nil { - return fmt.Errorf("failed get network namespace %q: %w", nspath, err) - } - defer namespace.Close() - errCh := make(chan error, 1) go func() { defer close(errCh) + namespace, err := netns.GetFromPath(nspath) + if err != nil { + errCh <- fmt.Errorf("failed get network namespace %q: %w", nspath, err) + return + } + defer namespace.Close() + runtime.LockOSThread() + + origNS, err := netns.Get() + if err != nil { + runtime.UnlockOSThread() + errCh <- fmt.Errorf("failed to get current network namespace: %w", err) + return + } + defer origNS.Close() + if err = netns.Set(namespace); err != nil { + runtime.UnlockOSThread() errCh <- fmt.Errorf("setting into container netns %q failed: %w", nspath, err) return } diff --git a/libnetwork/resolver.go b/libnetwork/resolver.go index 1304e95df2..ab19b7b08f 100644 --- a/libnetwork/resolver.go +++ b/libnetwork/resolver.go @@ -71,6 +71,7 @@ type Resolver struct { listenAddress string proxyDNS bool startCh chan struct{} + logger *logrus.Logger fwdSem *semaphore.Weighted // Limit the number of concurrent external DNS requests in-flight logInverval rate.Sometimes // Rate-limit logging about hitting the fwdSem limit @@ -89,6 +90,13 @@ func NewResolver(address string, proxyDNS bool, backend DNSBackend) *Resolver { } } +func (r *Resolver) log() *logrus.Logger { + if r.logger == nil { + return logrus.StandardLogger() + } + return r.logger +} + // SetupFunc returns the setup function that should be run in the container's // network namespace. func (r *Resolver) SetupFunc(port int) func() { @@ -140,7 +148,7 @@ func (r *Resolver) Start() error { r.server = s go func() { if err := s.ActivateAndServe(); err != nil { - logrus.WithError(err).Error("[resolver] failed to start PacketConn DNS server") + r.log().WithError(err).Error("[resolver] failed to start PacketConn DNS server") } }() @@ -148,7 +156,7 @@ func (r *Resolver) Start() error { r.tcpServer = tcpServer go func() { if err := tcpServer.ActivateAndServe(); err != nil { - logrus.WithError(err).Error("[resolver] failed to start TCP DNS server") + r.log().WithError(err).Error("[resolver] failed to start TCP DNS server") } }() return nil @@ -249,7 +257,7 @@ func (r *Resolver) handleIPQuery(query *dns.Msg, ipType int) (*dns.Msg, error) { if addr == nil && ipv6Miss { // Send a reply without any Answer sections - logrus.Debugf("[resolver] lookup name %s present without IPv6 address", name) + r.log().Debugf("[resolver] lookup name %s present without IPv6 address", name) resp := createRespMsg(query) return resp, nil } @@ -257,7 +265,7 @@ func (r *Resolver) handleIPQuery(query *dns.Msg, ipType int) (*dns.Msg, error) { return nil, nil } - logrus.Debugf("[resolver] lookup for %s: IP %v", name, addr) + r.log().Debugf("[resolver] lookup for %s: IP %v", name, addr) resp := createRespMsg(query) if len(addr) > 1 { @@ -298,7 +306,7 @@ func (r *Resolver) handlePTRQuery(query *dns.Msg) (*dns.Msg, error) { return nil, nil } - logrus.Debugf("[resolver] lookup for IP %s: name %s", name, host) + r.log().Debugf("[resolver] lookup for IP %s: name %s", name, host) fqdn := dns.Fqdn(host) resp := new(dns.Msg) @@ -365,17 +373,17 @@ func (r *Resolver) serveDNS(w dns.ResponseWriter, query *dns.Msg) { case dns.TypeSRV: resp, err = r.handleSRVQuery(query) default: - logrus.Debugf("[resolver] query type %s is not supported by the embedded DNS and will be forwarded to external DNS", dns.TypeToString[queryType]) + r.log().Debugf("[resolver] query type %s is not supported by the embedded DNS and will be forwarded to external DNS", dns.TypeToString[queryType]) } reply := func(msg *dns.Msg) { if err = w.WriteMsg(msg); err != nil { - logrus.WithError(err).Errorf("[resolver] failed to write response") + r.log().WithError(err).Errorf("[resolver] failed to write response") } } if err != nil { - logrus.WithError(err).Errorf("[resolver] failed to handle query: %s (%s)", queryName, dns.TypeToString[queryType]) + r.log().WithError(err).Errorf("[resolver] failed to handle query: %s (%s)", queryName, dns.TypeToString[queryType]) reply(new(dns.Msg).SetRcode(query, dns.RcodeServerFailure)) return } @@ -460,7 +468,7 @@ func (r *Resolver) forwardExtDNS(proto string, query *dns.Msg) *dns.Msg { cancel() if err != nil { r.logInverval.Do(func() { - logrus.Errorf("[resolver] more than %v concurrent queries", maxConcurrent) + r.log().Errorf("[resolver] more than %v concurrent queries", maxConcurrent) }) return new(dns.Msg).SetRcode(query, dns.RcodeRefused) } @@ -476,7 +484,7 @@ func (r *Resolver) forwardExtDNS(proto string, query *dns.Msg) *dns.Msg { case dns.RcodeServerFailure, dns.RcodeRefused: // Server returned FAILURE: continue with the next external DNS server // Server returned REFUSED: this can be a transitional status, so continue with the next external DNS server - logrus.Debugf("[resolver] external DNS %s:%s returned failure:\n%s", proto, extDNS.IPStr, resp) + r.log().Debugf("[resolver] external DNS %s:%s returned failure:\n%s", proto, extDNS.IPStr, resp) continue } answers := 0 @@ -486,17 +494,17 @@ func (r *Resolver) forwardExtDNS(proto string, query *dns.Msg) *dns.Msg { case dns.TypeA: answers++ ip := rr.(*dns.A).A - logrus.Debugf("[resolver] received A record %q for %q from %s:%s", ip, h.Name, proto, extDNS.IPStr) + r.log().Debugf("[resolver] received A record %q for %q from %s:%s", ip, h.Name, proto, extDNS.IPStr) r.backend.HandleQueryResp(h.Name, ip) case dns.TypeAAAA: answers++ ip := rr.(*dns.AAAA).AAAA - logrus.Debugf("[resolver] received AAAA record %q for %q from %s:%s", ip, h.Name, proto, extDNS.IPStr) + r.log().Debugf("[resolver] received AAAA record %q for %q from %s:%s", ip, h.Name, proto, extDNS.IPStr) r.backend.HandleQueryResp(h.Name, ip) } } if len(resp.Answer) == 0 { - logrus.Debugf("[resolver] external DNS %s:%s returned response with no answers:\n%s", proto, extDNS.IPStr, resp) + r.log().Debugf("[resolver] external DNS %s:%s returned response with no answers:\n%s", proto, extDNS.IPStr, resp) } resp.Compress = true return resp @@ -508,12 +516,12 @@ func (r *Resolver) forwardExtDNS(proto string, query *dns.Msg) *dns.Msg { func (r *Resolver) exchange(proto string, extDNS extDNSEntry, query *dns.Msg) *dns.Msg { extConn, err := r.dialExtDNS(proto, extDNS) if err != nil { - logrus.WithError(err).Warn("[resolver] connect failed") + r.log().WithError(err).Warn("[resolver] connect failed") return nil } defer extConn.Close() - log := logrus.WithFields(logrus.Fields{ + log := r.log().WithFields(logrus.Fields{ "dns-server": extConn.RemoteAddr().Network() + ":" + extConn.RemoteAddr().String(), "client-addr": extConn.LocalAddr().Network() + ":" + extConn.LocalAddr().String(), "question": query.Question[0].String(), @@ -534,7 +542,7 @@ func (r *Resolver) exchange(proto string, extDNS extDNSEntry, query *dns.Msg) *d UDPSize: dns.MaxMsgSize, }).ExchangeWithConn(query, &dns.Conn{Conn: extConn}) if err != nil { - logrus.WithError(err).Errorf("[resolver] failed to query DNS server: %s, query: %s", extConn.RemoteAddr().String(), query.Question[0].String()) + r.log().WithError(err).Errorf("[resolver] failed to query DNS server: %s, query: %s", extConn.RemoteAddr().String(), query.Question[0].String()) return nil } diff --git a/libnetwork/resolver_test.go b/libnetwork/resolver_test.go index 4637a4298e..733e1992de 100644 --- a/libnetwork/resolver_test.go +++ b/libnetwork/resolver_test.go @@ -19,16 +19,22 @@ import ( // a simple/null address type that will be used to fake a local address for unit testing type tstaddr struct { + network string } -func (a *tstaddr) Network() string { return "tcp" } +func (a *tstaddr) Network() string { + if a.network != "" { + return a.network + } + return "tcp" +} -func (a *tstaddr) String() string { return "127.0.0.1" } +func (a *tstaddr) String() string { return "(fake)" } // a simple writer that implements dns.ResponseWriter for unit testing purposes type tstwriter struct { - localAddr net.Addr - msg *dns.Msg + network string + msg *dns.Msg } func (w *tstwriter) WriteMsg(m *dns.Msg) (err error) { @@ -39,13 +45,12 @@ func (w *tstwriter) WriteMsg(m *dns.Msg) (err error) { func (w *tstwriter) Write(m []byte) (int, error) { return 0, nil } func (w *tstwriter) LocalAddr() net.Addr { - if w.localAddr != nil { - return w.localAddr - } - return new(tstaddr) + return &tstaddr{network: w.network} } -func (w *tstwriter) RemoteAddr() net.Addr { return new(tstaddr) } +func (w *tstwriter) RemoteAddr() net.Addr { + return &tstaddr{network: w.network} +} func (w *tstwriter) TsigStatus() error { return nil } @@ -372,15 +377,14 @@ func TestOversizedDNSReply(t *testing.T) { srvAddr := srv.LocalAddr().(*net.UDPAddr) rsv := NewResolver("", true, noopDNSBackend{}) + // The resolver logs lots of valuable info at level debug. Redirect it + // to t.Log() so the log spew is emitted only if the test fails. + rsv.logger = testLogger(t) rsv.SetExtServers([]extDNSEntry{ {IPStr: srvAddr.IP.String(), port: uint16(srvAddr.Port), HostLoopback: true}, }) - // The resolver logs lots of valuable info at level debug. Redirect it - // to t.Log() so the log spew is emitted only if the test fails. - defer redirectLogrusTo(t)() - - w := &tstwriter{localAddr: srv.LocalAddr()} + w := &tstwriter{network: srvAddr.Network()} q := new(dns.Msg).SetQuestion("s3.amazonaws.com.", dns.TypeA) rsv.serveDNS(w, q) resp := w.GetResponse() @@ -391,14 +395,11 @@ func TestOversizedDNSReply(t *testing.T) { checkDNSRRType(t, resp.Answer[0].Header().Rrtype, dns.TypeA) } -func redirectLogrusTo(t *testing.T) func() { - oldLevel, oldOut := logrus.StandardLogger().Level, logrus.StandardLogger().Out - logrus.StandardLogger().SetLevel(logrus.DebugLevel) - logrus.SetOutput(tlogWriter{t}) - return func() { - logrus.StandardLogger().SetLevel(oldLevel) - logrus.StandardLogger().SetOutput(oldOut) - } +func testLogger(t *testing.T) *logrus.Logger { + logger := logrus.New() + logger.SetLevel(logrus.DebugLevel) + logger.SetOutput(tlogWriter{t}) + return logger } type tlogWriter struct{ t *testing.T } @@ -440,9 +441,8 @@ func TestReplySERVFAIL(t *testing.T) { } for _, tt := range cases { t.Run(tt.name, func(t *testing.T) { - defer redirectLogrusTo(t) - rsv := NewResolver("", tt.proxyDNS, badSRVDNSBackend{}) + rsv.logger = testLogger(t) w := &tstwriter{} rsv.serveDNS(w, tt.q) resp := w.GetResponse() @@ -494,6 +494,13 @@ func TestProxyNXDOMAIN(t *testing.T) { <-serveDone }() + // This test, by virtue of running a server and client in different + // not-locked-to-thread goroutines, happens to be a good canary for + // whether we are leaking unlocked OS threads set to the wrong network + // namespace. Make a best-effort attempt to detect that situation so we + // are not left chasing ghosts next time. + testutils.AssertSocketSameNetNS(t, srv.PacketConn.(*net.UDPConn)) + srvAddr := srv.PacketConn.LocalAddr().(*net.UDPAddr) rsv := NewResolver("", true, noopDNSBackend{}) rsv.SetExtServers([]extDNSEntry{ @@ -502,9 +509,9 @@ func TestProxyNXDOMAIN(t *testing.T) { // The resolver logs lots of valuable info at level debug. Redirect it // to t.Log() so the log spew is emitted only if the test fails. - defer redirectLogrusTo(t)() + rsv.logger = testLogger(t) - w := &tstwriter{localAddr: srv.PacketConn.LocalAddr()} + w := &tstwriter{network: srvAddr.Network()} q := new(dns.Msg).SetQuestion("example.net.", dns.TypeA) rsv.serveDNS(w, q) resp := w.GetResponse() diff --git a/libnetwork/testutils/sanity_linux.go b/libnetwork/testutils/sanity_linux.go new file mode 100644 index 0000000000..8c85e1a896 --- /dev/null +++ b/libnetwork/testutils/sanity_linux.go @@ -0,0 +1,43 @@ +package testutils + +import ( + "errors" + "syscall" + "testing" + + "github.com/vishvananda/netns" + "golang.org/x/sys/unix" + "gotest.tools/v3/assert" +) + +// AssertSocketSameNetNS makes a best-effort attempt to assert that conn is in +// the same network namespace as the current goroutine's thread. +func AssertSocketSameNetNS(t testing.TB, conn syscall.Conn) { + t.Helper() + + sc, err := conn.SyscallConn() + assert.NilError(t, err) + sc.Control(func(fd uintptr) { + srvnsfd, err := unix.IoctlRetInt(int(fd), unix.SIOCGSKNS) + if err != nil { + if errors.Is(err, unix.EPERM) { + t.Log("Cannot determine socket's network namespace. Do we have CAP_NET_ADMIN?") + return + } + if errors.Is(err, unix.ENOSYS) { + t.Log("Cannot query socket's network namespace due to missing kernel support.") + return + } + t.Fatal(err) + } + srvns := netns.NsHandle(srvnsfd) + defer srvns.Close() + + curns, err := netns.Get() + assert.NilError(t, err) + defer curns.Close() + if !srvns.Equal(curns) { + t.Fatalf("Socket is in network namespace %s, but test goroutine is in %s", srvns, curns) + } + }) +} diff --git a/libnetwork/testutils/sanity_notlinux.go b/libnetwork/testutils/sanity_notlinux.go new file mode 100644 index 0000000000..ed58a6dbda --- /dev/null +++ b/libnetwork/testutils/sanity_notlinux.go @@ -0,0 +1,11 @@ +//go:build !linux + +package testutils + +import ( + "syscall" + "testing" +) + +// AssertSocketSameNetNS is a no-op on platforms other than Linux. +func AssertSocketSameNetNS(t testing.TB, conn syscall.Conn) {}