From 78792eae6842f021c38cebcc2348eb06cd08ec0e Mon Sep 17 00:00:00 2001 From: Cory Snider Date: Fri, 6 Jan 2023 18:25:00 -0500 Subject: [PATCH 01/14] libnetwork: add regression test for issue 44575 Signed-off-by: Cory Snider --- libnetwork/resolver.go | 7 +- libnetwork/resolver_test.go | 127 +++++++++++++++++++++++++++++++++++- 2 files changed, 131 insertions(+), 3 deletions(-) diff --git a/libnetwork/resolver.go b/libnetwork/resolver.go index 41546ac541..6684d46f4b 100644 --- a/libnetwork/resolver.go +++ b/libnetwork/resolver.go @@ -73,6 +73,7 @@ const ( type extDNSEntry struct { IPStr string + port uint16 // for testing HostLoopback bool } @@ -451,7 +452,11 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) { break } extConnect := func() { - addr := fmt.Sprintf("%s:%d", extDNS.IPStr, 53) + port := extDNS.port + if port == 0 { + port = 53 + } + addr := fmt.Sprintf("%s:%d", extDNS.IPStr, port) extConn, err = net.DialTimeout(proto, addr, extIOTimeout) } diff --git a/libnetwork/resolver_test.go b/libnetwork/resolver_test.go index e5bd477f9c..05bb9b1017 100644 --- a/libnetwork/resolver_test.go +++ b/libnetwork/resolver_test.go @@ -1,6 +1,8 @@ package libnetwork import ( + "encoding/hex" + "errors" "net" "runtime" "syscall" @@ -10,6 +12,7 @@ import ( "github.com/docker/docker/libnetwork/testutils" "github.com/miekg/dns" "github.com/sirupsen/logrus" + "gotest.tools/v3/assert" "gotest.tools/v3/skip" ) @@ -23,7 +26,8 @@ func (a *tstaddr) String() string { return "127.0.0.1" } // a simple writer that implements dns.ResponseWriter for unit testing purposes type tstwriter struct { - msg *dns.Msg + localAddr net.Addr + msg *dns.Msg } func (w *tstwriter) WriteMsg(m *dns.Msg) (err error) { @@ -33,7 +37,12 @@ func (w *tstwriter) WriteMsg(m *dns.Msg) (err error) { func (w *tstwriter) Write(m []byte) (int, error) { return 0, nil } -func (w *tstwriter) LocalAddr() net.Addr { return new(tstaddr) } +func (w *tstwriter) LocalAddr() net.Addr { + if w.localAddr != nil { + return w.localAddr + } + return new(tstaddr) +} func (w *tstwriter) RemoteAddr() net.Addr { return new(tstaddr) } @@ -285,3 +294,117 @@ func TestDNSProxyServFail(t *testing.T) { } t.Logf("Expected number of DNS requests generated") } + +// Packet 24 extracted from +// https://gist.github.com/vojtad/3bac63b8c91b1ec50e8d8b36047317fa/raw/7d75eb3d3448381bf252ae55ea5123a132c46658/host.pcap +// (https://github.com/moby/moby/issues/44575) +// which is a non-compliant DNS reply > 512B (w/o EDNS(0)) to the query +// +// s3.amazonaws.com. IN A +const oversizedDNSReplyMsg = "\xf5\x11\x81\x80\x00\x01\x00\x20\x00\x00\x00\x00\x02\x73\x33\x09" + + "\x61\x6d\x61\x7a\x6f\x6e\x61\x77\x73\x03\x63\x6f\x6d\x00\x00\x01" + + "\x00\x01\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd9" + + "\x11\x9e\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd8" + + "\x4c\x66\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd8" + + "\xda\x10\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd9" + + "\x01\x3e\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd9" + + "\x88\x68\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd9" + + "\x66\x9e\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd9" + + "\x5f\x28\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd8" + + "\x8e\x4e\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x00\x00\x04\x36\xe7" + + "\x84\xf0\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x00\x00\x04\x34\xd8" + + "\x92\x45\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd8" + + "\x8f\xa6\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x36\xe7" + + "\xc0\xd0\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd9" + + "\xfe\x28\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd8" + + "\xaa\x3d\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd8" + + "\x4e\x56\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd9" + + "\xea\xb0\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd8" + + "\x6d\xed\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd8" + + "\x28\x00\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x00\x00\x04\x34\xd9" + + "\xe9\x78\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x00\x00\x04\x34\xd9" + + "\x6e\x9e\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x00\x00\x04\x34\xd9" + + "\x45\x86\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x00\x00\x04\x34\xd8" + + "\x30\x38\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x00\x00\x04\x36\xe7" + + "\xc6\xa8\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x00\x00\x04\x03\x05" + + "\x01\x9d\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x05\x00\x04\x34\xd9" + + "\xa8\xe8\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x05\x00\x04\x34\xd9" + + "\x64\xa6\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x05\x00\x04\x34\xd8" + + "\x3c\x48\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x05\x00\x04\x34\xd8" + + "\x35\x20\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x05\x00\x04\x34\xd9" + + "\x54\xf6\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x05\x00\x04\x34\xd9" + + "\x5d\x36\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x05\x00\x04\x34\xd9" + + "\x30\x36\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x05\x00\x04\x36\xe7" + + "\x83\x90" + +// Regression test for https://github.com/moby/moby/issues/44575 +func TestOversizedDNSReply(t *testing.T) { + srv, err := net.ListenPacket("udp", "127.0.0.1:0") + assert.NilError(t, err) + defer srv.Close() + go func() { + buf := make([]byte, 65536) + for { + n, src, err := srv.ReadFrom(buf) + if errors.Is(err, net.ErrClosed) { + return + } + t.Logf("[<-%v]\n%s", src, hex.Dump(buf[:n])) + if n < 2 { + continue + } + resp := []byte(oversizedDNSReplyMsg) + resp[0], resp[1] = buf[0], buf[1] // Copy query ID into response. + _, err = srv.WriteTo(resp, src) + if errors.Is(err, net.ErrClosed) { + return + } + if err != nil { + t.Log(err) + } + } + }() + + srvAddr := srv.LocalAddr().(*net.UDPAddr) + rsv := NewResolver("", true, noopDNSBackend{}).(*resolver) + rsv.SetExtServers([]extDNSEntry{ + {IPStr: srvAddr.IP.String(), port: uint16(srvAddr.Port), HostLoopback: true}, + }) + + // The resolver logs lots of valuable info at level debug. Redirect it + // to t.Log() so the log spew is emitted only if the test fails. + oldLevel, oldOut := logrus.StandardLogger().Level, logrus.StandardLogger().Out + defer func() { + logrus.StandardLogger().SetLevel(oldLevel) + logrus.StandardLogger().SetOutput(oldOut) + }() + logrus.StandardLogger().SetLevel(logrus.DebugLevel) + logrus.SetOutput(tlogWriter{t}) + + w := &tstwriter{localAddr: srv.LocalAddr()} + q := new(dns.Msg).SetQuestion("s3.amazonaws.com.", dns.TypeA) + rsv.ServeDNS(w, q) + resp := w.GetResponse() + checkNonNullResponse(t, resp) + t.Log("Response: ", resp.String()) + checkDNSResponseCode(t, resp, dns.RcodeSuccess) + assert.Assert(t, len(resp.Answer) >= 1) + checkDNSRRType(t, resp.Answer[0].Header().Rrtype, dns.TypeA) +} + +type tlogWriter struct{ t *testing.T } + +func (w tlogWriter) Write(p []byte) (n int, err error) { + w.t.Logf("%s", p) + return len(p), nil +} + +type noopDNSBackend struct{ DNSBackend } + +func (noopDNSBackend) ResolveName(name string, iplen int) ([]net.IP, bool) { return nil, false } + +func (noopDNSBackend) ExecFunc(f func()) error { f(); return nil } + +func (noopDNSBackend) NdotsSet() bool { return false } + +func (noopDNSBackend) HandleQueryResp(name string, ip net.IP) {} From 92aa6e6282dd855d027597ba85ac30d05ff87fd9 Mon Sep 17 00:00:00 2001 From: Cory Snider Date: Thu, 8 Dec 2022 16:33:45 -0500 Subject: [PATCH 02/14] libnetwork: extract fn for external DNS forwarding Signed-off-by: Cory Snider --- libnetwork/resolver.go | 242 +++++++++++++++++++++-------------------- 1 file changed, 126 insertions(+), 116 deletions(-) diff --git a/libnetwork/resolver.go b/libnetwork/resolver.go index 6684d46f4b..774fbfdc51 100644 --- a/libnetwork/resolver.go +++ b/libnetwork/resolver.go @@ -371,9 +371,8 @@ func truncateResp(resp *dns.Msg, maxSize int, isTCP bool) { func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) { var ( - extConn net.Conn - resp *dns.Msg - err error + resp *dns.Msg + err error ) if query == nil || len(query.Question) == 0 { @@ -446,119 +445,7 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) { truncateResp(resp, maxSize, proto == "tcp") } } else { - for i := 0; i < maxExtDNS; i++ { - extDNS := &r.extDNSList[i] - if extDNS.IPStr == "" { - break - } - extConnect := func() { - port := extDNS.port - if port == 0 { - port = 53 - } - addr := fmt.Sprintf("%s:%d", extDNS.IPStr, port) - extConn, err = net.DialTimeout(proto, addr, extIOTimeout) - } - - if extDNS.HostLoopback { - extConnect() - } else { - execErr := r.backend.ExecFunc(extConnect) - if execErr != nil { - logrus.Warn(execErr) - continue - } - } - if err != nil { - logrus.WithField("retries", i).Warnf("[resolver] connect failed: %s", err) - continue - } - logrus.Debugf("[resolver] query %s (%s) from %s, forwarding to %s:%s", queryName, dns.TypeToString[queryType], - extConn.LocalAddr().String(), proto, extDNS.IPStr) - - // Timeout has to be set for every IO operation. - if err := extConn.SetDeadline(time.Now().Add(extIOTimeout)); err != nil { - logrus.WithError(err).Error("[resolver] error setting conn deadline") - } - co := &dns.Conn{ - Conn: extConn, - UDPSize: uint16(maxSize), - } - defer co.Close() - - // limits the number of outstanding concurrent queries. - if !r.forwardQueryStart() { - old := r.tStamp - r.tStamp = time.Now() - if r.tStamp.Sub(old) > logInterval { - logrus.Errorf("[resolver] more than %v concurrent queries from %s", maxConcurrent, extConn.LocalAddr().String()) - } - continue - } - - err = co.WriteMsg(query) - if err != nil { - r.forwardQueryEnd() - logrus.Debugf("[resolver] send to DNS server failed, %s", err) - continue - } - - resp, err = co.ReadMsg() - // Truncated DNS replies should be sent to the client so that the - // client can retry over TCP - if err != nil && (resp == nil || !resp.Truncated) { - r.forwardQueryEnd() - logrus.WithError(err).Warnf("[resolver] failed to read from DNS server: %s, query: %s", extConn.RemoteAddr().String(), query.Question[0].String()) - continue - } - r.forwardQueryEnd() - - if resp == nil { - logrus.Debugf("[resolver] external DNS %s:%s returned empty response for %q", proto, extDNS.IPStr, queryName) - break - } - switch resp.Rcode { - case dns.RcodeServerFailure, dns.RcodeRefused: - // Server returned FAILURE: continue with the next external DNS server - // Server returned REFUSED: this can be a transitional status, so continue with the next external DNS server - logrus.Debugf("[resolver] external DNS %s:%s responded with %s for %q", proto, extDNS.IPStr, statusString(resp.Rcode), queryName) - continue - case dns.RcodeNameError: - // Server returned NXDOMAIN. Stop resolution if it's an authoritative answer (see RFC 8020: https://tools.ietf.org/html/rfc8020#section-2) - logrus.Debugf("[resolver] external DNS %s:%s responded with %s for %q", proto, extDNS.IPStr, statusString(resp.Rcode), queryName) - if resp.Authoritative { - break - } - continue - case dns.RcodeSuccess: - // All is well - default: - // Server gave some error. Log the error, and continue with the next external DNS server - logrus.Debugf("[resolver] external DNS %s:%s responded with %s (code %d) for %q", proto, extDNS.IPStr, statusString(resp.Rcode), resp.Rcode, queryName) - continue - } - answers := 0 - for _, rr := range resp.Answer { - h := rr.Header() - switch h.Rrtype { - case dns.TypeA: - answers++ - ip := rr.(*dns.A).A - logrus.Debugf("[resolver] received A record %q for %q from %s:%s", ip, h.Name, proto, extDNS.IPStr) - r.backend.HandleQueryResp(h.Name, ip) - case dns.TypeAAAA: - answers++ - ip := rr.(*dns.AAAA).AAAA - logrus.Debugf("[resolver] received AAAA record %q for %q from %s:%s", ip, h.Name, proto, extDNS.IPStr) - r.backend.HandleQueryResp(h.Name, ip) - } - } - if resp.Answer == nil || answers == 0 { - logrus.Debugf("[resolver] external DNS %s:%s did not return any %s records for %q", proto, extDNS.IPStr, dns.TypeToString[queryType], queryName) - } - resp.Compress = true - break - } + resp = r.forwardExtDNS(proto, maxSize, query) if resp == nil { return } @@ -569,6 +456,129 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) { } } +func (r *resolver) forwardExtDNS(proto string, maxSize int, query *dns.Msg) *dns.Msg { + queryName, queryType := query.Question[0].Name, query.Question[0].Qtype + var resp *dns.Msg + for i := 0; i < maxExtDNS; i++ { + extDNS := &r.extDNSList[i] + if extDNS.IPStr == "" { + break + } + var ( + extConn net.Conn + err error + ) + extConnect := func() { + port := extDNS.port + if port == 0 { + port = 53 + } + addr := fmt.Sprintf("%s:%d", extDNS.IPStr, port) + extConn, err = net.DialTimeout(proto, addr, extIOTimeout) + } + + if extDNS.HostLoopback { + extConnect() + } else { + execErr := r.backend.ExecFunc(extConnect) + if execErr != nil { + logrus.Warn(execErr) + continue + } + } + if err != nil { + logrus.WithField("retries", i).Warnf("[resolver] connect failed: %s", err) + continue + } + logrus.Debugf("[resolver] query %s (%s) from %s, forwarding to %s:%s", queryName, dns.TypeToString[queryType], + extConn.LocalAddr().String(), proto, extDNS.IPStr) + + // Timeout has to be set for every IO operation. + if err := extConn.SetDeadline(time.Now().Add(extIOTimeout)); err != nil { + logrus.WithError(err).Error("[resolver] error setting conn deadline") + } + co := &dns.Conn{ + Conn: extConn, + UDPSize: uint16(maxSize), + } + defer co.Close() + + // limits the number of outstanding concurrent queries. + if !r.forwardQueryStart() { + old := r.tStamp + r.tStamp = time.Now() + if r.tStamp.Sub(old) > logInterval { + logrus.Errorf("[resolver] more than %v concurrent queries from %s", maxConcurrent, extConn.LocalAddr().String()) + } + continue + } + + err = co.WriteMsg(query) + if err != nil { + r.forwardQueryEnd() + logrus.Debugf("[resolver] send to DNS server failed, %s", err) + continue + } + + resp, err = co.ReadMsg() + // Truncated DNS replies should be sent to the client so that the + // client can retry over TCP + if err != nil && (resp == nil || !resp.Truncated) { + r.forwardQueryEnd() + logrus.WithError(err).Warnf("[resolver] failed to read from DNS server: %s, query: %s", extConn.RemoteAddr().String(), query.Question[0].String()) + continue + } + r.forwardQueryEnd() + + if resp == nil { + logrus.Debugf("[resolver] external DNS %s:%s returned empty response for %q", proto, extDNS.IPStr, queryName) + break + } + switch resp.Rcode { + case dns.RcodeServerFailure, dns.RcodeRefused: + // Server returned FAILURE: continue with the next external DNS server + // Server returned REFUSED: this can be a transitional status, so continue with the next external DNS server + logrus.Debugf("[resolver] external DNS %s:%s responded with %s for %q", proto, extDNS.IPStr, statusString(resp.Rcode), queryName) + continue + case dns.RcodeNameError: + // Server returned NXDOMAIN. Stop resolution if it's an authoritative answer (see RFC 8020: https://tools.ietf.org/html/rfc8020#section-2) + logrus.Debugf("[resolver] external DNS %s:%s responded with %s for %q", proto, extDNS.IPStr, statusString(resp.Rcode), queryName) + if resp.Authoritative { + break + } + continue + case dns.RcodeSuccess: + // All is well + default: + // Server gave some error. Log the error, and continue with the next external DNS server + logrus.Debugf("[resolver] external DNS %s:%s responded with %s (code %d) for %q", proto, extDNS.IPStr, statusString(resp.Rcode), resp.Rcode, queryName) + continue + } + answers := 0 + for _, rr := range resp.Answer { + h := rr.Header() + switch h.Rrtype { + case dns.TypeA: + answers++ + ip := rr.(*dns.A).A + logrus.Debugf("[resolver] received A record %q for %q from %s:%s", ip, h.Name, proto, extDNS.IPStr) + r.backend.HandleQueryResp(h.Name, ip) + case dns.TypeAAAA: + answers++ + ip := rr.(*dns.AAAA).AAAA + logrus.Debugf("[resolver] received AAAA record %q for %q from %s:%s", ip, h.Name, proto, extDNS.IPStr) + r.backend.HandleQueryResp(h.Name, ip) + } + } + if resp.Answer == nil || answers == 0 { + logrus.Debugf("[resolver] external DNS %s:%s did not return any %s records for %q", proto, extDNS.IPStr, dns.TypeToString[queryType], queryName) + } + resp.Compress = true + break + } + return resp +} + func statusString(responseCode int) string { if s, ok := dns.RcodeToString[responseCode]; ok { return s From 0bd30e90bb3ade9f1262237d3ebda331c1be8a23 Mon Sep 17 00:00:00 2001 From: Cory Snider Date: Tue, 13 Dec 2022 18:54:43 -0500 Subject: [PATCH 03/14] libnetwork: reply SERVFAIL on resolve error ...instead of silently dropping the DNS query. Signed-off-by: Cory Snider --- libnetwork/resolver.go | 2 +- libnetwork/resolver_test.go | 41 ++++++++++++++++++++++++++++++------- 2 files changed, 35 insertions(+), 8 deletions(-) diff --git a/libnetwork/resolver.go b/libnetwork/resolver.go index 774fbfdc51..8106c506f2 100644 --- a/libnetwork/resolver.go +++ b/libnetwork/resolver.go @@ -399,7 +399,7 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) { if err != nil { logrus.WithError(err).Errorf("[resolver] failed to handle query: %s (%s)", queryName, dns.TypeToString[queryType]) - return + resp = new(dns.Msg).SetRcode(query, dns.RcodeServerFailure) } if resp == nil { diff --git a/libnetwork/resolver_test.go b/libnetwork/resolver_test.go index 05bb9b1017..59471082e5 100644 --- a/libnetwork/resolver_test.go +++ b/libnetwork/resolver_test.go @@ -59,12 +59,14 @@ func (w *tstwriter) GetResponse() *dns.Msg { return w.msg } func (w *tstwriter) ClearResponse() { w.msg = nil } func checkNonNullResponse(t *testing.T, m *dns.Msg) { + t.Helper() if m == nil { t.Fatal("Null DNS response found. Non Null response msg expected.") } } func checkDNSAnswersCount(t *testing.T, m *dns.Msg, expected int) { + t.Helper() answers := len(m.Answer) if answers != expected { t.Fatalf("Expected number of answers in response: %d. Found: %d", expected, answers) @@ -72,12 +74,14 @@ func checkDNSAnswersCount(t *testing.T, m *dns.Msg, expected int) { } func checkDNSResponseCode(t *testing.T, m *dns.Msg, expected int) { + t.Helper() if m.MsgHdr.Rcode != expected { t.Fatalf("Expected DNS response code: %d. Found: %d", expected, m.MsgHdr.Rcode) } } func checkDNSRRType(t *testing.T, actual, expected uint16) { + t.Helper() if actual != expected { t.Fatalf("Expected DNS Rrtype: %d. Found: %d", expected, actual) } @@ -373,13 +377,7 @@ func TestOversizedDNSReply(t *testing.T) { // The resolver logs lots of valuable info at level debug. Redirect it // to t.Log() so the log spew is emitted only if the test fails. - oldLevel, oldOut := logrus.StandardLogger().Level, logrus.StandardLogger().Out - defer func() { - logrus.StandardLogger().SetLevel(oldLevel) - logrus.StandardLogger().SetOutput(oldOut) - }() - logrus.StandardLogger().SetLevel(logrus.DebugLevel) - logrus.SetOutput(tlogWriter{t}) + defer redirectLogrusTo(t)() w := &tstwriter{localAddr: srv.LocalAddr()} q := new(dns.Msg).SetQuestion("s3.amazonaws.com.", dns.TypeA) @@ -392,6 +390,16 @@ func TestOversizedDNSReply(t *testing.T) { checkDNSRRType(t, resp.Answer[0].Header().Rrtype, dns.TypeA) } +func redirectLogrusTo(t *testing.T) func() { + oldLevel, oldOut := logrus.StandardLogger().Level, logrus.StandardLogger().Out + logrus.StandardLogger().SetLevel(logrus.DebugLevel) + logrus.SetOutput(tlogWriter{t}) + return func() { + logrus.StandardLogger().SetLevel(oldLevel) + logrus.StandardLogger().SetOutput(oldOut) + } +} + type tlogWriter struct{ t *testing.T } func (w tlogWriter) Write(p []byte) (n int, err error) { @@ -408,3 +416,22 @@ func (noopDNSBackend) ExecFunc(f func()) error { f(); return nil } func (noopDNSBackend) NdotsSet() bool { return false } func (noopDNSBackend) HandleQueryResp(name string, ip net.IP) {} + +func TestReplySERVFAILOnInternalError(t *testing.T) { + defer redirectLogrusTo(t) + + rsv := NewResolver("", false, badSRVDNSBackend{}).(*resolver) + w := &tstwriter{} + q := new(dns.Msg).SetQuestion("_sip._tcp.example.com.", dns.TypeSRV) + rsv.ServeDNS(w, q) + resp := w.GetResponse() + checkNonNullResponse(t, resp) + t.Log("Response: ", resp.String()) + checkDNSResponseCode(t, resp, dns.RcodeServerFailure) +} + +type badSRVDNSBackend struct{ noopDNSBackend } + +func (badSRVDNSBackend) ResolveService(name string) ([]*net.SRV, []net.IP) { + return []*net.SRV{nil, nil, nil}, nil // Mismatched slice lengths +} From 8a35fb0d1c67468e3a4275500dcc420966d68d95 Mon Sep 17 00:00:00 2001 From: Cory Snider Date: Tue, 13 Dec 2022 19:14:26 -0500 Subject: [PATCH 04/14] libnetwork: refactor ServeDNS for readability Signed-off-by: Cory Snider --- libnetwork/resolver.go | 44 +++++++++++++++--------------------------- 1 file changed, 16 insertions(+), 28 deletions(-) diff --git a/libnetwork/resolver.go b/libnetwork/resolver.go index 8106c506f2..a3db5fad64 100644 --- a/libnetwork/resolver.go +++ b/libnetwork/resolver.go @@ -402,30 +402,6 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) { resp = new(dns.Msg).SetRcode(query, dns.RcodeServerFailure) } - if resp == nil { - // If the backend doesn't support proxying dns request - // fail the response - if !r.proxyDNS { - resp = new(dns.Msg) - resp.SetRcode(query, dns.RcodeServerFailure) - if err := w.WriteMsg(resp); err != nil { - logrus.WithError(err).Error("[resolver] error writing dns response") - } - return - } - - // If the user sets ndots > 0 explicitly and the query is - // in the root domain don't forward it out. We will return - // failure and let the client retry with the search domain - // attached - switch queryType { - case dns.TypeA, dns.TypeAAAA: - if r.backend.NdotsSet() && !strings.Contains(strings.TrimSuffix(queryName, "."), ".") { - resp = createRespMsg(query) - } - } - } - proto := w.LocalAddr().Network() maxSize := 0 if proto == "tcp" { @@ -444,11 +420,23 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) { if resp.Len() > maxSize { truncateResp(resp, maxSize, proto == "tcp") } - } else { - resp = r.forwardExtDNS(proto, maxSize, query) - if resp == nil { - return + } else if r.proxyDNS { + // If the user sets ndots > 0 explicitly and the query is + // in the root domain don't forward it out. We will return + // failure and let the client retry with the search domain + // attached. + if (queryType == dns.TypeA || queryType == dns.TypeAAAA) && r.backend.NdotsSet() && + !strings.Contains(strings.TrimSuffix(queryName, "."), ".") { + resp = createRespMsg(query) + } else { + resp = r.forwardExtDNS(proto, maxSize, query) + if resp == nil { + return + } } + } else { + // The backend doesn't support proxying DNS requests. + resp = new(dns.Msg).SetRcode(query, dns.RcodeServerFailure) } if err = w.WriteMsg(resp); err != nil { From 860e83e52f0854b5c742b00995e13dc7a23ff325 Mon Sep 17 00:00:00 2001 From: Cory Snider Date: Tue, 13 Dec 2022 19:36:41 -0500 Subject: [PATCH 05/14] libnetwork: get rid of truncation red herring The TC flag in a DNS message indicates that the sender had to truncate it to fit within the length limit of the transmission channel. It does NOT indicate that part of the message was lost before reaching the recipient. Older versions of github.com/miekg/dns conflated the two cases by returning ErrTruncated from ReadMsg() if the message was parsed without error but had the TC flag set. The version of miekg/dns currently vendored no longer returns an error when a well-formed DNS message is received which has its TC flag set, but there was some confusion on how to update libnetwork to deal with this behaviour change. Truncated DNS replies are no longer different from any other reply message: they are normal replies which do not need any special- case handling to proxy back to the client. Signed-off-by: Cory Snider --- libnetwork/resolver.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/libnetwork/resolver.go b/libnetwork/resolver.go index a3db5fad64..95d3cd3127 100644 --- a/libnetwork/resolver.go +++ b/libnetwork/resolver.go @@ -509,9 +509,7 @@ func (r *resolver) forwardExtDNS(proto string, maxSize int, query *dns.Msg) *dns } resp, err = co.ReadMsg() - // Truncated DNS replies should be sent to the client so that the - // client can retry over TCP - if err != nil && (resp == nil || !resp.Truncated) { + if err != nil { r.forwardQueryEnd() logrus.WithError(err).Warnf("[resolver] failed to read from DNS server: %s, query: %s", extConn.RemoteAddr().String(), query.Question[0].String()) continue From 51cdd7ceac267974d4c2204655dbcd68b5e87d2c Mon Sep 17 00:00:00 2001 From: Cory Snider Date: Thu, 15 Dec 2022 19:21:11 -0500 Subject: [PATCH 06/14] libnetwork: truncate DNS msgs using library method (*dns.Msg).Truncate() is more intelligent and standards-compliant about truncating DNS response messages than our hand-rolled version. Fix a silly fencepost error the max TCP message size: the limit is dns.MaxMsgSize (65535), full stop. Signed-off-by: Cory Snider --- libnetwork/resolver.go | 54 +++++++++++++----------------------------- 1 file changed, 16 insertions(+), 38 deletions(-) diff --git a/libnetwork/resolver.go b/libnetwork/resolver.go index 95d3cd3127..77f4095d69 100644 --- a/libnetwork/resolver.go +++ b/libnetwork/resolver.go @@ -60,15 +60,14 @@ type DNSBackend interface { } const ( - dnsPort = "53" - ptrIPv4domain = ".in-addr.arpa." - ptrIPv6domain = ".ip6.arpa." - respTTL = 600 - maxExtDNS = 3 // max number of external servers to try - extIOTimeout = 4 * time.Second - defaultRespSize = 512 - maxConcurrent = 1024 - logInterval = 2 * time.Second + dnsPort = "53" + ptrIPv4domain = ".in-addr.arpa." + ptrIPv6domain = ".ip6.arpa." + respTTL = 600 + maxExtDNS = 3 // max number of external servers to try + extIOTimeout = 4 * time.Second + maxConcurrent = 1024 + logInterval = 2 * time.Second ) type extDNSEntry struct { @@ -352,23 +351,6 @@ func (r *resolver) handleSRVQuery(query *dns.Msg) (*dns.Msg, error) { return resp, nil } -func truncateResp(resp *dns.Msg, maxSize int, isTCP bool) { - if !isTCP { - resp.Truncated = true - } - - srv := resp.Question[0].Qtype == dns.TypeSRV - // trim the Answer RRs one by one till the whole message fits - // within the reply size - for resp.Len() > maxSize { - resp.Answer = resp.Answer[:len(resp.Answer)-1] - - if srv && len(resp.Extra) > 0 { - resp.Extra = resp.Extra[:len(resp.Extra)-1] - } - } -} - func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) { var ( resp *dns.Msg @@ -403,23 +385,19 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) { } proto := w.LocalAddr().Network() - maxSize := 0 + maxSize := dns.MinMsgSize if proto == "tcp" { - maxSize = dns.MaxMsgSize - 1 - } else if proto == "udp" { - optRR := query.IsEdns0() - if optRR != nil { - maxSize = int(optRR.UDPSize()) - } - if maxSize < defaultRespSize { - maxSize = defaultRespSize + maxSize = dns.MaxMsgSize + } else { + if optRR := query.IsEdns0(); optRR != nil { + if udpsize := int(optRR.UDPSize()); udpsize > maxSize { + maxSize = udpsize + } } } if resp != nil { - if resp.Len() > maxSize { - truncateResp(resp, maxSize, proto == "tcp") - } + resp.Truncate(maxSize) } else if r.proxyDNS { // If the user sets ndots > 0 explicitly and the query is // in the root domain don't forward it out. We will return From 854ec3ffb34ffc4b53c18f6a4fe62107e0b9b53a Mon Sep 17 00:00:00 2001 From: Cory Snider Date: Fri, 16 Dec 2022 13:28:07 -0500 Subject: [PATCH 07/14] libnetwork: extract dialExtDNS to method Signed-off-by: Cory Snider --- libnetwork/resolver.go | 56 +++++++++++++++++++++++------------------- 1 file changed, 31 insertions(+), 25 deletions(-) diff --git a/libnetwork/resolver.go b/libnetwork/resolver.go index 77f4095d69..1f99e92023 100644 --- a/libnetwork/resolver.go +++ b/libnetwork/resolver.go @@ -422,38 +422,44 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) { } } +func (r *resolver) dialExtDNS(proto string, server extDNSEntry) (net.Conn, error) { + var ( + extConn net.Conn + dialErr error + ) + extConnect := func() { + if server.port == 0 { + server.port = 53 + } + addr := fmt.Sprintf("%s:%d", server.IPStr, server.port) + extConn, dialErr = net.DialTimeout(proto, addr, extIOTimeout) + } + + if server.HostLoopback { + extConnect() + } else { + execErr := r.backend.ExecFunc(extConnect) + if execErr != nil { + return nil, execErr + } + } + if dialErr != nil { + return nil, dialErr + } + + return extConn, nil +} + func (r *resolver) forwardExtDNS(proto string, maxSize int, query *dns.Msg) *dns.Msg { queryName, queryType := query.Question[0].Name, query.Question[0].Qtype var resp *dns.Msg - for i := 0; i < maxExtDNS; i++ { - extDNS := &r.extDNSList[i] + for i, extDNS := range r.extDNSList { if extDNS.IPStr == "" { break } - var ( - extConn net.Conn - err error - ) - extConnect := func() { - port := extDNS.port - if port == 0 { - port = 53 - } - addr := fmt.Sprintf("%s:%d", extDNS.IPStr, port) - extConn, err = net.DialTimeout(proto, addr, extIOTimeout) - } - - if extDNS.HostLoopback { - extConnect() - } else { - execErr := r.backend.ExecFunc(extConnect) - if execErr != nil { - logrus.Warn(execErr) - continue - } - } + extConn, err := r.dialExtDNS(proto, extDNS) if err != nil { - logrus.WithField("retries", i).Warnf("[resolver] connect failed: %s", err) + logrus.WithField("retries", i).WithError(err).Warn("[resolver] connect failed") continue } logrus.Debugf("[resolver] query %s (%s) from %s, forwarding to %s:%s", queryName, dns.TypeToString[queryType], From 9cf8c4f68955457f15288701eb59da045349b214 Mon Sep 17 00:00:00 2001 From: Cory Snider Date: Fri, 16 Dec 2022 14:32:45 -0500 Subject: [PATCH 08/14] libnetwork: extract DNS client exchange to method forwardExtDNS() will now continue with the next external DNS sever if co.ReadMsg() returns (nil, nil). Previously it would abort resolving the query and not reply to the container client. The implementation of ReadMsg() in the currently- vendored version of miekg/dns cannot return (nil, nil) so the difference is immaterial in practice. Signed-off-by: Cory Snider --- libnetwork/resolver.go | 83 +++++++++++++++++++++++------------------- 1 file changed, 45 insertions(+), 38 deletions(-) diff --git a/libnetwork/resolver.go b/libnetwork/resolver.go index 1f99e92023..e5ba96ca7e 100644 --- a/libnetwork/resolver.go +++ b/libnetwork/resolver.go @@ -452,58 +452,26 @@ func (r *resolver) dialExtDNS(proto string, server extDNSEntry) (net.Conn, error func (r *resolver) forwardExtDNS(proto string, maxSize int, query *dns.Msg) *dns.Msg { queryName, queryType := query.Question[0].Name, query.Question[0].Qtype - var resp *dns.Msg - for i, extDNS := range r.extDNSList { + for _, extDNS := range r.extDNSList { if extDNS.IPStr == "" { break } - extConn, err := r.dialExtDNS(proto, extDNS) - if err != nil { - logrus.WithField("retries", i).WithError(err).Warn("[resolver] connect failed") - continue - } - logrus.Debugf("[resolver] query %s (%s) from %s, forwarding to %s:%s", queryName, dns.TypeToString[queryType], - extConn.LocalAddr().String(), proto, extDNS.IPStr) - - // Timeout has to be set for every IO operation. - if err := extConn.SetDeadline(time.Now().Add(extIOTimeout)); err != nil { - logrus.WithError(err).Error("[resolver] error setting conn deadline") - } - co := &dns.Conn{ - Conn: extConn, - UDPSize: uint16(maxSize), - } - defer co.Close() // limits the number of outstanding concurrent queries. if !r.forwardQueryStart() { old := r.tStamp r.tStamp = time.Now() if r.tStamp.Sub(old) > logInterval { - logrus.Errorf("[resolver] more than %v concurrent queries from %s", maxConcurrent, extConn.LocalAddr().String()) + logrus.Errorf("[resolver] more than %v concurrent queries", maxConcurrent) } continue } - - err = co.WriteMsg(query) - if err != nil { - r.forwardQueryEnd() - logrus.Debugf("[resolver] send to DNS server failed, %s", err) - continue - } - - resp, err = co.ReadMsg() - if err != nil { - r.forwardQueryEnd() - logrus.WithError(err).Warnf("[resolver] failed to read from DNS server: %s, query: %s", extConn.RemoteAddr().String(), query.Question[0].String()) - continue - } + resp := r.exchange(proto, extDNS, maxSize, query) r.forwardQueryEnd() - if resp == nil { - logrus.Debugf("[resolver] external DNS %s:%s returned empty response for %q", proto, extDNS.IPStr, queryName) - break + continue } + switch resp.Rcode { case dns.RcodeServerFailure, dns.RcodeRefused: // Server returned FAILURE: continue with the next external DNS server @@ -544,7 +512,46 @@ func (r *resolver) forwardExtDNS(proto string, maxSize int, query *dns.Msg) *dns logrus.Debugf("[resolver] external DNS %s:%s did not return any %s records for %q", proto, extDNS.IPStr, dns.TypeToString[queryType], queryName) } resp.Compress = true - break + return resp + } + + return nil +} + +func (r *resolver) exchange(proto string, extDNS extDNSEntry, maxSize int, query *dns.Msg) *dns.Msg { + queryName, queryType := query.Question[0].Name, query.Question[0].Qtype + extConn, err := r.dialExtDNS(proto, extDNS) + if err != nil { + logrus.WithError(err).Warn("[resolver] connect failed") + return nil + } + logrus.Debugf("[resolver] query %s (%s) from %s, forwarding to %s:%s", queryName, dns.TypeToString[queryType], + extConn.LocalAddr().String(), proto, extDNS.IPStr) + + // Timeout has to be set for every IO operation. + if err := extConn.SetDeadline(time.Now().Add(extIOTimeout)); err != nil { + logrus.WithError(err).Error("[resolver] error setting conn deadline") + } + co := &dns.Conn{ + Conn: extConn, + UDPSize: uint16(maxSize), + } + defer co.Close() + + err = co.WriteMsg(query) + if err != nil { + logrus.Debugf("[resolver] send to DNS server failed, %s", err) + return nil + } + + resp, err := co.ReadMsg() + if err != nil { + logrus.WithError(err).Warnf("[resolver] failed to read from DNS server: %s, query: %s", extConn.RemoteAddr().String(), query.Question[0].String()) + return nil + } + + if resp == nil { + logrus.Debugf("[resolver] external DNS %s:%s returned empty response for %q", proto, extDNS.IPStr, queryName) } return resp } From e6258e65906434569e9d4a54ca0a80fcf40a5bda Mon Sep 17 00:00:00 2001 From: Cory Snider Date: Fri, 16 Dec 2022 15:03:55 -0500 Subject: [PATCH 09/14] libnetwork: reply SERVFAIL if DNS forwarding fails Fixes moby/moby issue 44575 Signed-off-by: Cory Snider --- libnetwork/resolver.go | 11 +++++----- libnetwork/resolver_test.go | 42 ++++++++++++++++++++++++++++--------- 2 files changed, 37 insertions(+), 16 deletions(-) diff --git a/libnetwork/resolver.go b/libnetwork/resolver.go index e5ba96ca7e..39861d9b94 100644 --- a/libnetwork/resolver.go +++ b/libnetwork/resolver.go @@ -408,15 +408,14 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) { resp = createRespMsg(query) } else { resp = r.forwardExtDNS(proto, maxSize, query) - if resp == nil { - return - } } - } else { - // The backend doesn't support proxying DNS requests. - resp = new(dns.Msg).SetRcode(query, dns.RcodeServerFailure) } + if resp == nil { + // We were unable to get an answer from any of the upstream DNS + // servers or the backend doesn't support proxying DNS requests. + resp = new(dns.Msg).SetRcode(query, dns.RcodeServerFailure) + } if err = w.WriteMsg(resp); err != nil { logrus.WithError(err).Errorf("[resolver] failed to write response") } diff --git a/libnetwork/resolver_test.go b/libnetwork/resolver_test.go index 59471082e5..8cd28a0c8a 100644 --- a/libnetwork/resolver_test.go +++ b/libnetwork/resolver_test.go @@ -417,17 +417,39 @@ func (noopDNSBackend) NdotsSet() bool { return false } func (noopDNSBackend) HandleQueryResp(name string, ip net.IP) {} -func TestReplySERVFAILOnInternalError(t *testing.T) { - defer redirectLogrusTo(t) +func TestReplySERVFAIL(t *testing.T) { + cases := []struct { + name string + q *dns.Msg + proxyDNS bool + }{ + { + name: "InternalError", + q: new(dns.Msg).SetQuestion("_sip._tcp.example.com.", dns.TypeSRV), + }, + { + name: "ProxyDNS=false", + q: new(dns.Msg).SetQuestion("example.com.", dns.TypeA), + }, + { + name: "ProxyDNS=true", // No extDNS servers configured -> no answer from any upstream + q: new(dns.Msg).SetQuestion("example.com.", dns.TypeA), + proxyDNS: true, + }, + } + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + defer redirectLogrusTo(t) - rsv := NewResolver("", false, badSRVDNSBackend{}).(*resolver) - w := &tstwriter{} - q := new(dns.Msg).SetQuestion("_sip._tcp.example.com.", dns.TypeSRV) - rsv.ServeDNS(w, q) - resp := w.GetResponse() - checkNonNullResponse(t, resp) - t.Log("Response: ", resp.String()) - checkDNSResponseCode(t, resp, dns.RcodeServerFailure) + rsv := NewResolver("", tt.proxyDNS, badSRVDNSBackend{}).(*resolver) + w := &tstwriter{} + rsv.ServeDNS(w, tt.q) + resp := w.GetResponse() + checkNonNullResponse(t, resp) + t.Log("Response: ", resp.String()) + checkDNSResponseCode(t, resp, dns.RcodeServerFailure) + }) + } } type badSRVDNSBackend struct{ noopDNSBackend } From a1f7c644be5661627227b029f779ddb25d00eee6 Mon Sep 17 00:00:00 2001 From: Cory Snider Date: Fri, 16 Dec 2022 15:35:33 -0500 Subject: [PATCH 10/14] libnetwork: use dns.Client for forwarded requests It handles figuring out the UDP receive buffer size and setting IO timeouts, which simplifies our code. It is also more robust to receiving UDP replies to earlier queries which timed out. Log failures to perform a client exchange at level error so they are more visible to operators and administrators. Signed-off-by: Cory Snider --- libnetwork/resolver.go | 98 ++++++++++++++++++++++++------------------ 1 file changed, 55 insertions(+), 43 deletions(-) diff --git a/libnetwork/resolver.go b/libnetwork/resolver.go index 39861d9b94..5b05160f15 100644 --- a/libnetwork/resolver.go +++ b/libnetwork/resolver.go @@ -379,26 +379,38 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) { logrus.Debugf("[resolver] query type %s is not supported by the embedded DNS and will be forwarded to external DNS", dns.TypeToString[queryType]) } - if err != nil { - logrus.WithError(err).Errorf("[resolver] failed to handle query: %s (%s)", queryName, dns.TypeToString[queryType]) - resp = new(dns.Msg).SetRcode(query, dns.RcodeServerFailure) - } - - proto := w.LocalAddr().Network() - maxSize := dns.MinMsgSize - if proto == "tcp" { - maxSize = dns.MaxMsgSize - } else { - if optRR := query.IsEdns0(); optRR != nil { - if udpsize := int(optRR.UDPSize()); udpsize > maxSize { - maxSize = udpsize - } + reply := func(msg *dns.Msg) { + if err = w.WriteMsg(msg); err != nil { + logrus.WithError(err).Errorf("[resolver] failed to write response") } } + if err != nil { + logrus.WithError(err).Errorf("[resolver] failed to handle query: %s (%s)", queryName, dns.TypeToString[queryType]) + reply(new(dns.Msg).SetRcode(query, dns.RcodeServerFailure)) + return + } + if resp != nil { + // We are the authoritative DNS server for this request so it's + // on us to truncate the response message to the size limit + // negotiated by the client. + maxSize := dns.MinMsgSize + if w.LocalAddr().Network() == "tcp" { + maxSize = dns.MaxMsgSize + } else { + if optRR := query.IsEdns0(); optRR != nil { + if udpsize := int(optRR.UDPSize()); udpsize > maxSize { + maxSize = udpsize + } + } + } resp.Truncate(maxSize) - } else if r.proxyDNS { + reply(resp) + return + } + + if r.proxyDNS { // If the user sets ndots > 0 explicitly and the query is // in the root domain don't forward it out. We will return // failure and let the client retry with the search domain @@ -407,7 +419,7 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) { !strings.Contains(strings.TrimSuffix(queryName, "."), ".") { resp = createRespMsg(query) } else { - resp = r.forwardExtDNS(proto, maxSize, query) + resp = r.forwardExtDNS(w.LocalAddr().Network(), query) } } @@ -416,9 +428,7 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) { // servers or the backend doesn't support proxying DNS requests. resp = new(dns.Msg).SetRcode(query, dns.RcodeServerFailure) } - if err = w.WriteMsg(resp); err != nil { - logrus.WithError(err).Errorf("[resolver] failed to write response") - } + reply(resp) } func (r *resolver) dialExtDNS(proto string, server extDNSEntry) (net.Conn, error) { @@ -449,7 +459,7 @@ func (r *resolver) dialExtDNS(proto string, server extDNSEntry) (net.Conn, error return extConn, nil } -func (r *resolver) forwardExtDNS(proto string, maxSize int, query *dns.Msg) *dns.Msg { +func (r *resolver) forwardExtDNS(proto string, query *dns.Msg) *dns.Msg { queryName, queryType := query.Question[0].Name, query.Question[0].Qtype for _, extDNS := range r.extDNSList { if extDNS.IPStr == "" { @@ -465,7 +475,7 @@ func (r *resolver) forwardExtDNS(proto string, maxSize int, query *dns.Msg) *dns } continue } - resp := r.exchange(proto, extDNS, maxSize, query) + resp := r.exchange(proto, extDNS, query) r.forwardQueryEnd() if resp == nil { continue @@ -517,40 +527,42 @@ func (r *resolver) forwardExtDNS(proto string, maxSize int, query *dns.Msg) *dns return nil } -func (r *resolver) exchange(proto string, extDNS extDNSEntry, maxSize int, query *dns.Msg) *dns.Msg { - queryName, queryType := query.Question[0].Name, query.Question[0].Qtype +func (r *resolver) exchange(proto string, extDNS extDNSEntry, query *dns.Msg) *dns.Msg { extConn, err := r.dialExtDNS(proto, extDNS) if err != nil { logrus.WithError(err).Warn("[resolver] connect failed") return nil } - logrus.Debugf("[resolver] query %s (%s) from %s, forwarding to %s:%s", queryName, dns.TypeToString[queryType], - extConn.LocalAddr().String(), proto, extDNS.IPStr) + defer extConn.Close() - // Timeout has to be set for every IO operation. - if err := extConn.SetDeadline(time.Now().Add(extIOTimeout)); err != nil { - logrus.WithError(err).Error("[resolver] error setting conn deadline") - } - co := &dns.Conn{ - Conn: extConn, - UDPSize: uint16(maxSize), - } - defer co.Close() + log := logrus.WithFields(logrus.Fields{ + "dns-server": extConn.RemoteAddr().Network() + ":" + extConn.RemoteAddr().String(), + "client-addr": extConn.LocalAddr().Network() + ":" + extConn.LocalAddr().String(), + "question": query.Question[0].String(), + }) + log.Debug("[resolver] forwarding query") - err = co.WriteMsg(query) + resp, _, err := (&dns.Client{ + Timeout: extIOTimeout, + // Following the robustness principle, make a best-effort + // attempt to receive oversized response messages without + // truncating them on our end to forward verbatim to the client. + // Some DNS servers (e.g. Mikrotik RouterOS) don't support + // EDNS(0) and may send replies over UDP longer than 512 bytes + // regardless of what size limit, if any, was advertized in the + // query message. Note that ExchangeWithConn will override this + // value if it detects an EDNS OPT record in query so only + // oversized replies to non-EDNS queries will benefit. + UDPSize: dns.MaxMsgSize, + }).ExchangeWithConn(query, &dns.Conn{Conn: extConn}) if err != nil { - logrus.Debugf("[resolver] send to DNS server failed, %s", err) - return nil - } - - resp, err := co.ReadMsg() - if err != nil { - logrus.WithError(err).Warnf("[resolver] failed to read from DNS server: %s, query: %s", extConn.RemoteAddr().String(), query.Question[0].String()) + logrus.WithError(err).Errorf("[resolver] failed to query DNS server: %s, query: %s", extConn.RemoteAddr().String(), query.Question[0].String()) return nil } if resp == nil { - logrus.Debugf("[resolver] external DNS %s:%s returned empty response for %q", proto, extDNS.IPStr, queryName) + // Should be impossible, so make noise if it happens anyway. + log.Error("[resolver] external DNS returned empty response") } return resp } From 25b51cad3d4af7c2e24bcb038117649f2536b698 Mon Sep 17 00:00:00 2001 From: Cory Snider Date: Fri, 16 Dec 2022 16:26:06 -0500 Subject: [PATCH 11/14] libnetwork: replace ad-hoc semaphore implementation ...for limiting concurrent external DNS requests with "golang.org/x/sync/semaphore".Weighted. Replace the ad-hoc rate limiter for when the concurrency limit is hit (which contains a data-race bug) with "golang.org/x/time/rate".Sometimes. Immediately retrying with the next server if the concurrency limit has been hit just further compounds the problem. Wait on the semaphore and refuse the query if it could not be acquired in a reasonable amount of time. Signed-off-by: Cory Snider --- libnetwork/resolver.go | 57 +++++++----------- vendor.mod | 2 +- vendor.sum | 4 +- vendor/golang.org/x/time/rate/rate.go | 20 +++---- vendor/golang.org/x/time/rate/sometimes.go | 67 ++++++++++++++++++++++ vendor/modules.txt | 2 +- 6 files changed, 99 insertions(+), 53 deletions(-) create mode 100644 vendor/golang.org/x/time/rate/sometimes.go diff --git a/libnetwork/resolver.go b/libnetwork/resolver.go index 5b05160f15..ecd1d198d9 100644 --- a/libnetwork/resolver.go +++ b/libnetwork/resolver.go @@ -1,6 +1,7 @@ package libnetwork import ( + "context" "fmt" "math/rand" "net" @@ -11,6 +12,8 @@ import ( "github.com/docker/docker/libnetwork/types" "github.com/miekg/dns" "github.com/sirupsen/logrus" + "golang.org/x/sync/semaphore" + "golang.org/x/time/rate" ) // Resolver represents the embedded DNS server in Docker. It operates @@ -85,12 +88,12 @@ type resolver struct { tcpServer *dns.Server tcpListen *net.TCPListener err error - count int32 - tStamp time.Time - queryLock sync.Mutex listenAddress string proxyDNS bool startCh chan struct{} + + fwdSem *semaphore.Weighted // Limit the number of concurrent external DNS requests in-flight + logInverval rate.Sometimes // Rate-limit logging about hitting the fwdSem limit } // NewResolver creates a new instance of the Resolver @@ -101,6 +104,8 @@ func NewResolver(address string, proxyDNS bool, backend DNSBackend) Resolver { listenAddress: address, err: fmt.Errorf("setup not done yet"), startCh: make(chan struct{}, 1), + fwdSem: semaphore.NewWeighted(maxConcurrent), + logInverval: rate.Sometimes{Interval: logInterval}, } } @@ -179,9 +184,7 @@ func (r *resolver) Stop() { r.conn = nil r.tcpServer = nil r.err = fmt.Errorf("setup not done yet") - r.tStamp = time.Time{} - r.count = 0 - r.queryLock = sync.Mutex{} + r.fwdSem = semaphore.NewWeighted(maxConcurrent) } func (r *resolver) SetExtServers(extDNS []extDNSEntry) { @@ -467,16 +470,19 @@ func (r *resolver) forwardExtDNS(proto string, query *dns.Msg) *dns.Msg { } // limits the number of outstanding concurrent queries. - if !r.forwardQueryStart() { - old := r.tStamp - r.tStamp = time.Now() - if r.tStamp.Sub(old) > logInterval { + ctx, cancel := context.WithTimeout(context.Background(), extIOTimeout) + err := r.fwdSem.Acquire(ctx, 1) + cancel() + if err != nil { + r.logInverval.Do(func() { logrus.Errorf("[resolver] more than %v concurrent queries", maxConcurrent) - } - continue + }) + return new(dns.Msg).SetRcode(query, dns.RcodeRefused) } - resp := r.exchange(proto, extDNS, query) - r.forwardQueryEnd() + resp := func() *dns.Msg { + defer r.fwdSem.Release(1) + return r.exchange(proto, extDNS, query) + }() if resp == nil { continue } @@ -573,26 +579,3 @@ func statusString(responseCode int) string { } return "UNKNOWN" } - -func (r *resolver) forwardQueryStart() bool { - r.queryLock.Lock() - defer r.queryLock.Unlock() - - if r.count == maxConcurrent { - return false - } - r.count++ - - return true -} - -func (r *resolver) forwardQueryEnd() { - r.queryLock.Lock() - defer r.queryLock.Unlock() - - if r.count == 0 { - logrus.Error("[resolver] invalid concurrent query count") - } else { - r.count-- - } -} diff --git a/vendor.mod b/vendor.mod index 522f7121c0..9bd28d3acb 100644 --- a/vendor.mod +++ b/vendor.mod @@ -91,7 +91,7 @@ require ( golang.org/x/sync v0.1.0 golang.org/x/sys v0.5.0 golang.org/x/text v0.7.0 - golang.org/x/time v0.1.0 + golang.org/x/time v0.3.0 google.golang.org/genproto v0.0.0-20220706185917-7780775163c4 google.golang.org/grpc v1.50.1 gotest.tools/v3 v3.4.0 diff --git a/vendor.sum b/vendor.sum index abbf46f2e3..8ba921d1ec 100644 --- a/vendor.sum +++ b/vendor.sum @@ -1422,8 +1422,8 @@ golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxb golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.1.0 h1:xYY+Bajn2a7VBmTM5GikTmnK8ZuX8YgnQCqZpbBNtmA= -golang.org/x/time v0.1.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= +golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20181011042414-1f849cf54d09/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/vendor/golang.org/x/time/rate/rate.go b/vendor/golang.org/x/time/rate/rate.go index 8f7c29f156..f0e0cf3cb1 100644 --- a/vendor/golang.org/x/time/rate/rate.go +++ b/vendor/golang.org/x/time/rate/rate.go @@ -83,7 +83,7 @@ func (lim *Limiter) Burst() int { // TokensAt returns the number of tokens available at time t. func (lim *Limiter) TokensAt(t time.Time) float64 { lim.mu.Lock() - _, _, tokens := lim.advance(t) // does not mutute lim + _, tokens := lim.advance(t) // does not mutate lim lim.mu.Unlock() return tokens } @@ -183,7 +183,7 @@ func (r *Reservation) CancelAt(t time.Time) { return } // advance time to now - t, _, tokens := r.lim.advance(t) + t, tokens := r.lim.advance(t) // calculate new number of tokens tokens += restoreTokens if burst := float64(r.lim.burst); tokens > burst { @@ -304,7 +304,7 @@ func (lim *Limiter) SetLimitAt(t time.Time, newLimit Limit) { lim.mu.Lock() defer lim.mu.Unlock() - t, _, tokens := lim.advance(t) + t, tokens := lim.advance(t) lim.last = t lim.tokens = tokens @@ -321,7 +321,7 @@ func (lim *Limiter) SetBurstAt(t time.Time, newBurst int) { lim.mu.Lock() defer lim.mu.Unlock() - t, _, tokens := lim.advance(t) + t, tokens := lim.advance(t) lim.last = t lim.tokens = tokens @@ -356,7 +356,7 @@ func (lim *Limiter) reserveN(t time.Time, n int, maxFutureReserve time.Duration) } } - t, last, tokens := lim.advance(t) + t, tokens := lim.advance(t) // Calculate the remaining number of tokens resulting from the request. tokens -= float64(n) @@ -379,15 +379,11 @@ func (lim *Limiter) reserveN(t time.Time, n int, maxFutureReserve time.Duration) if ok { r.tokens = n r.timeToAct = t.Add(waitDuration) - } - // Update state - if ok { + // Update state lim.last = t lim.tokens = tokens lim.lastEvent = r.timeToAct - } else { - lim.last = last } return r @@ -396,7 +392,7 @@ func (lim *Limiter) reserveN(t time.Time, n int, maxFutureReserve time.Duration) // advance calculates and returns an updated state for lim resulting from the passage of time. // lim is not changed. // advance requires that lim.mu is held. -func (lim *Limiter) advance(t time.Time) (newT time.Time, newLast time.Time, newTokens float64) { +func (lim *Limiter) advance(t time.Time) (newT time.Time, newTokens float64) { last := lim.last if t.Before(last) { last = t @@ -409,7 +405,7 @@ func (lim *Limiter) advance(t time.Time) (newT time.Time, newLast time.Time, new if burst := float64(lim.burst); tokens > burst { tokens = burst } - return t, last, tokens + return t, tokens } // durationFromTokens is a unit conversion function from the number of tokens to the duration diff --git a/vendor/golang.org/x/time/rate/sometimes.go b/vendor/golang.org/x/time/rate/sometimes.go new file mode 100644 index 0000000000..6ba99ddb67 --- /dev/null +++ b/vendor/golang.org/x/time/rate/sometimes.go @@ -0,0 +1,67 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package rate + +import ( + "sync" + "time" +) + +// Sometimes will perform an action occasionally. The First, Every, and +// Interval fields govern the behavior of Do, which performs the action. +// A zero Sometimes value will perform an action exactly once. +// +// # Example: logging with rate limiting +// +// var sometimes = rate.Sometimes{First: 3, Interval: 10*time.Second} +// func Spammy() { +// sometimes.Do(func() { log.Info("here I am!") }) +// } +type Sometimes struct { + First int // if non-zero, the first N calls to Do will run f. + Every int // if non-zero, every Nth call to Do will run f. + Interval time.Duration // if non-zero and Interval has elapsed since f's last run, Do will run f. + + mu sync.Mutex + count int // number of Do calls + last time.Time // last time f was run +} + +// Do runs the function f as allowed by First, Every, and Interval. +// +// The model is a union (not intersection) of filters. The first call to Do +// always runs f. Subsequent calls to Do run f if allowed by First or Every or +// Interval. +// +// A non-zero First:N causes the first N Do(f) calls to run f. +// +// A non-zero Every:M causes every Mth Do(f) call, starting with the first, to +// run f. +// +// A non-zero Interval causes Do(f) to run f if Interval has elapsed since +// Do last ran f. +// +// Specifying multiple filters produces the union of these execution streams. +// For example, specifying both First:N and Every:M causes the first N Do(f) +// calls and every Mth Do(f) call, starting with the first, to run f. See +// Examples for more. +// +// If Do is called multiple times simultaneously, the calls will block and run +// serially. Therefore, Do is intended for lightweight operations. +// +// Because a call to Do may block until f returns, if f causes Do to be called, +// it will deadlock. +func (s *Sometimes) Do(f func()) { + s.mu.Lock() + defer s.mu.Unlock() + if s.count == 0 || + (s.First > 0 && s.count < s.First) || + (s.Every > 0 && s.count%s.Every == 0) || + (s.Interval > 0 && time.Since(s.last) >= s.Interval) { + f() + s.last = time.Now() + } + s.count++ +} diff --git a/vendor/modules.txt b/vendor/modules.txt index 62fffe33b2..6846c904c8 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -1086,7 +1086,7 @@ golang.org/x/text/secure/bidirule golang.org/x/text/transform golang.org/x/text/unicode/bidi golang.org/x/text/unicode/norm -# golang.org/x/time v0.1.0 +# golang.org/x/time v0.3.0 ## explicit golang.org/x/time/rate # google.golang.org/api v0.93.0 From 8f5a9a741b70852bc6c9d675c6e0c4944022b467 Mon Sep 17 00:00:00 2001 From: Cory Snider Date: Fri, 6 Jan 2023 19:50:18 -0500 Subject: [PATCH 12/14] libnetwork: fail loudly on resolver iptables setup Signed-off-by: Cory Snider --- libnetwork/resolver_unix.go | 54 ++++++++++++++++++++++++++----------- 1 file changed, 39 insertions(+), 15 deletions(-) diff --git a/libnetwork/resolver_unix.go b/libnetwork/resolver_unix.go index d308437d66..e16251112e 100644 --- a/libnetwork/resolver_unix.go +++ b/libnetwork/resolver_unix.go @@ -4,17 +4,17 @@ package libnetwork import ( + "fmt" "net" "github.com/docker/docker/libnetwork/iptables" - "github.com/sirupsen/logrus" ) const ( - // outputChain used for docker embed dns + // output chain used for docker embedded DNS resolver outputChain = "DOCKER_OUTPUT" - //postroutingchain used for docker embed dns - postroutingchain = "DOCKER_POSTROUTING" + // postrouting chain used for docker embedded DNS resolver + postroutingChain = "DOCKER_POSTROUTING" ) func (r *resolver) setupIPTable() error { @@ -27,36 +27,60 @@ func (r *resolver) setupIPTable() error { _, tcpPort, _ := net.SplitHostPort(ltcpaddr) rules := [][]string{ {"-t", "nat", "-I", outputChain, "-d", resolverIP, "-p", "udp", "--dport", dnsPort, "-j", "DNAT", "--to-destination", laddr}, - {"-t", "nat", "-I", postroutingchain, "-s", resolverIP, "-p", "udp", "--sport", ipPort, "-j", "SNAT", "--to-source", ":" + dnsPort}, + {"-t", "nat", "-I", postroutingChain, "-s", resolverIP, "-p", "udp", "--sport", ipPort, "-j", "SNAT", "--to-source", ":" + dnsPort}, {"-t", "nat", "-I", outputChain, "-d", resolverIP, "-p", "tcp", "--dport", dnsPort, "-j", "DNAT", "--to-destination", ltcpaddr}, - {"-t", "nat", "-I", postroutingchain, "-s", resolverIP, "-p", "tcp", "--sport", tcpPort, "-j", "SNAT", "--to-source", ":" + dnsPort}, + {"-t", "nat", "-I", postroutingChain, "-s", resolverIP, "-p", "tcp", "--sport", tcpPort, "-j", "SNAT", "--to-source", ":" + dnsPort}, } - return r.backend.ExecFunc(func() { + var setupErr error + err := r.backend.ExecFunc(func() { // TODO IPv6 support iptable := iptables.GetIptable(iptables.IPv4) // insert outputChain and postroutingchain err := iptable.RawCombinedOutputNative("-t", "nat", "-C", "OUTPUT", "-d", resolverIP, "-j", outputChain) if err == nil { - iptable.RawCombinedOutputNative("-t", "nat", "-F", outputChain) + if err := iptable.RawCombinedOutputNative("-t", "nat", "-F", outputChain); err != nil { + setupErr = err + return + } } else { - iptable.RawCombinedOutputNative("-t", "nat", "-N", outputChain) - iptable.RawCombinedOutputNative("-t", "nat", "-I", "OUTPUT", "-d", resolverIP, "-j", outputChain) + if err := iptable.RawCombinedOutputNative("-t", "nat", "-N", outputChain); err != nil { + setupErr = err + return + } + if err := iptable.RawCombinedOutputNative("-t", "nat", "-I", "OUTPUT", "-d", resolverIP, "-j", outputChain); err != nil { + setupErr = err + return + } } - err = iptable.RawCombinedOutputNative("-t", "nat", "-C", "POSTROUTING", "-d", resolverIP, "-j", postroutingchain) + err = iptable.RawCombinedOutputNative("-t", "nat", "-C", "POSTROUTING", "-d", resolverIP, "-j", postroutingChain) if err == nil { - iptable.RawCombinedOutputNative("-t", "nat", "-F", postroutingchain) + if err := iptable.RawCombinedOutputNative("-t", "nat", "-F", postroutingChain); err != nil { + setupErr = err + return + } } else { - iptable.RawCombinedOutputNative("-t", "nat", "-N", postroutingchain) - iptable.RawCombinedOutputNative("-t", "nat", "-I", "POSTROUTING", "-d", resolverIP, "-j", postroutingchain) + if err := iptable.RawCombinedOutputNative("-t", "nat", "-N", postroutingChain); err != nil { + setupErr = err + return + } + if err := iptable.RawCombinedOutputNative("-t", "nat", "-I", "POSTROUTING", "-d", resolverIP, "-j", postroutingChain); err != nil { + setupErr = err + return + } } for _, rule := range rules { if iptable.RawCombinedOutputNative(rule...) != nil { - logrus.Errorf("set up rule failed, %v", rule) + setupErr = fmt.Errorf("set up rule failed, %v", rule) + return } } }) + if err != nil { + return err + } + return setupErr } From faaa4fdf1897085bf8d65f5a1be67551daf9811c Mon Sep 17 00:00:00 2001 From: Cory Snider Date: Fri, 6 Jan 2023 20:11:15 -0500 Subject: [PATCH 13/14] libnetwork: forward unknown PTR queries externally PTR queries with domain names unknown to us are not necessarily invalid. Act like a well-behaved middlebox and fall back to forwarding externally, same as we do with the other query types. Signed-off-by: Cory Snider --- libnetwork/resolver.go | 29 +++++++++++++---------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/libnetwork/resolver.go b/libnetwork/resolver.go index ecd1d198d9..eafc2204d6 100644 --- a/libnetwork/resolver.go +++ b/libnetwork/resolver.go @@ -293,26 +293,23 @@ func (r *resolver) handleIPQuery(query *dns.Msg, ipType int) (*dns.Msg, error) { } func (r *resolver) handlePTRQuery(query *dns.Msg) (*dns.Msg, error) { - var ( - parts []string - ptr = query.Question[0].Name - ) - - if strings.HasSuffix(ptr, ptrIPv4domain) { - parts = strings.Split(ptr, ptrIPv4domain) - } else if strings.HasSuffix(ptr, ptrIPv6domain) { - parts = strings.Split(ptr, ptrIPv6domain) - } else { - return nil, fmt.Errorf("invalid PTR query, %v", ptr) + ptr := query.Question[0].Name + name, after, found := strings.Cut(ptr, ptrIPv4domain) + if !found || after != "" { + name, after, found = strings.Cut(ptr, ptrIPv6domain) } - - host := r.backend.ResolveIP(parts[0]) - - if len(host) == 0 { + if !found || after != "" { + // Not a known IPv4 or IPv6 PTR domain. + // Maybe the external DNS servers know what to do with the query? return nil, nil } - logrus.Debugf("[resolver] lookup for IP %s: name %s", parts[0], host) + host := r.backend.ResolveIP(name) + if host == "" { + return nil, nil + } + + logrus.Debugf("[resolver] lookup for IP %s: name %s", name, host) fqdn := dns.Fqdn(host) resp := new(dns.Msg) From f8cfd3a61fd1824f8b6a72b97fcf13ec3f5306bb Mon Sep 17 00:00:00 2001 From: Cory Snider Date: Fri, 20 Jan 2023 16:58:23 -0500 Subject: [PATCH 14/14] libnetwork: devirtualize Resolver type https://github.com/golang/go/wiki/CodeReviewComments#interfaces Signed-off-by: Cory Snider --- libnetwork/network.go | 2 +- libnetwork/resolver.go | 71 ++++++++++++++-------------------- libnetwork/resolver_test.go | 18 ++++----- libnetwork/resolver_unix.go | 2 +- libnetwork/resolver_windows.go | 2 +- libnetwork/sandbox.go | 2 +- 6 files changed, 43 insertions(+), 54 deletions(-) diff --git a/libnetwork/network.go b/libnetwork/network.go index 417caeff46..bbd321584d 100644 --- a/libnetwork/network.go +++ b/libnetwork/network.go @@ -223,7 +223,7 @@ type network struct { persist bool drvOnce *sync.Once resolverOnce sync.Once //nolint:nolintlint,unused // only used on windows - resolver []Resolver + resolver []*Resolver internal bool attachable bool inDelete bool diff --git a/libnetwork/resolver.go b/libnetwork/resolver.go index eafc2204d6..e8e6944039 100644 --- a/libnetwork/resolver.go +++ b/libnetwork/resolver.go @@ -16,27 +16,6 @@ import ( "golang.org/x/time/rate" ) -// Resolver represents the embedded DNS server in Docker. It operates -// by listening on container's loopback interface for DNS queries. -type Resolver interface { - // Start starts the name server for the container - Start() error - // Stop stops the name server for the container. Stopped resolver - // can be reused after running the SetupFunc again. - Stop() - // SetupFunc provides the setup function that should be run - // in the container's network namespace. - SetupFunc(int) func() - // NameServer returns the IP of the DNS resolver for the - // containers. - NameServer() string - // SetExtServers configures the external nameservers the resolver - // should use to forward queries - SetExtServers([]extDNSEntry) - // ResolverOptions returns resolv.conf options that should be set - ResolverOptions() []string -} - // DNSBackend represents a backend DNS resolver used for DNS name // resolution. All the queries to the resolver are forwarded to the // backend resolver. @@ -79,8 +58,9 @@ type extDNSEntry struct { HostLoopback bool } -// resolver implements the Resolver interface -type resolver struct { +// Resolver is the embedded DNS server in Docker. It operates by listening on +// the container's loopback interface for DNS queries. +type Resolver struct { backend DNSBackend extDNSList [maxExtDNS]extDNSEntry server *dns.Server @@ -97,8 +77,8 @@ type resolver struct { } // NewResolver creates a new instance of the Resolver -func NewResolver(address string, proxyDNS bool, backend DNSBackend) Resolver { - return &resolver{ +func NewResolver(address string, proxyDNS bool, backend DNSBackend) *Resolver { + return &Resolver{ backend: backend, proxyDNS: proxyDNS, listenAddress: address, @@ -109,7 +89,9 @@ func NewResolver(address string, proxyDNS bool, backend DNSBackend) Resolver { } } -func (r *resolver) SetupFunc(port int) func() { +// SetupFunc returns the setup function that should be run in the container's +// network namespace. +func (r *Resolver) SetupFunc(port int) func() { return func() { var err error @@ -140,7 +122,8 @@ func (r *resolver) SetupFunc(port int) func() { } } -func (r *resolver) Start() error { +// Start starts the name server for the container. +func (r *Resolver) Start() error { r.startCh <- struct{}{} defer func() { <-r.startCh }() @@ -153,7 +136,7 @@ func (r *resolver) Start() error { return fmt.Errorf("setting up IP table rules failed: %v", err) } - s := &dns.Server{Handler: r, PacketConn: r.conn} + s := &dns.Server{Handler: dns.HandlerFunc(r.serveDNS), PacketConn: r.conn} r.server = s go func() { if err := s.ActivateAndServe(); err != nil { @@ -161,7 +144,7 @@ func (r *resolver) Start() error { } }() - tcpServer := &dns.Server{Handler: r, Listener: r.tcpListen} + tcpServer := &dns.Server{Handler: dns.HandlerFunc(r.serveDNS), Listener: r.tcpListen} r.tcpServer = tcpServer go func() { if err := tcpServer.ActivateAndServe(); err != nil { @@ -171,7 +154,9 @@ func (r *resolver) Start() error { return nil } -func (r *resolver) Stop() { +// Stop stops the name server for the container. A stopped resolver can be +// reused after running the SetupFunc again. +func (r *Resolver) Stop() { r.startCh <- struct{}{} defer func() { <-r.startCh }() @@ -187,7 +172,9 @@ func (r *resolver) Stop() { r.fwdSem = semaphore.NewWeighted(maxConcurrent) } -func (r *resolver) SetExtServers(extDNS []extDNSEntry) { +// SetExtServers configures the external nameservers the resolver should use +// when forwarding queries. +func (r *Resolver) SetExtServers(extDNS []extDNSEntry) { l := len(extDNS) if l > maxExtDNS { l = maxExtDNS @@ -197,11 +184,13 @@ func (r *resolver) SetExtServers(extDNS []extDNSEntry) { } } -func (r *resolver) NameServer() string { +// NameServer returns the IP of the DNS resolver for the containers. +func (r *Resolver) NameServer() string { return r.listenAddress } -func (r *resolver) ResolverOptions() []string { +// ResolverOptions returns resolv.conf options that should be set. +func (r *Resolver) ResolverOptions() []string { return []string{"ndots:0"} } @@ -233,7 +222,7 @@ func createRespMsg(query *dns.Msg) *dns.Msg { return resp } -func (r *resolver) handleMXQuery(query *dns.Msg) (*dns.Msg, error) { +func (r *Resolver) handleMXQuery(query *dns.Msg) (*dns.Msg, error) { name := query.Question[0].Name addrv4, _ := r.backend.ResolveName(name, types.IPv4) addrv6, _ := r.backend.ResolveName(name, types.IPv6) @@ -250,7 +239,7 @@ func (r *resolver) handleMXQuery(query *dns.Msg) (*dns.Msg, error) { return resp, nil } -func (r *resolver) handleIPQuery(query *dns.Msg, ipType int) (*dns.Msg, error) { +func (r *Resolver) handleIPQuery(query *dns.Msg, ipType int) (*dns.Msg, error) { var ( addr []net.IP ipv6Miss bool @@ -292,7 +281,7 @@ func (r *resolver) handleIPQuery(query *dns.Msg, ipType int) (*dns.Msg, error) { return resp, nil } -func (r *resolver) handlePTRQuery(query *dns.Msg) (*dns.Msg, error) { +func (r *Resolver) handlePTRQuery(query *dns.Msg) (*dns.Msg, error) { ptr := query.Question[0].Name name, after, found := strings.Cut(ptr, ptrIPv4domain) if !found || after != "" { @@ -323,7 +312,7 @@ func (r *resolver) handlePTRQuery(query *dns.Msg) (*dns.Msg, error) { return resp, nil } -func (r *resolver) handleSRVQuery(query *dns.Msg) (*dns.Msg, error) { +func (r *Resolver) handleSRVQuery(query *dns.Msg) (*dns.Msg, error) { svc := query.Question[0].Name srv, ip := r.backend.ResolveService(svc) @@ -351,7 +340,7 @@ func (r *resolver) handleSRVQuery(query *dns.Msg) (*dns.Msg, error) { return resp, nil } -func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) { +func (r *Resolver) serveDNS(w dns.ResponseWriter, query *dns.Msg) { var ( resp *dns.Msg err error @@ -431,7 +420,7 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) { reply(resp) } -func (r *resolver) dialExtDNS(proto string, server extDNSEntry) (net.Conn, error) { +func (r *Resolver) dialExtDNS(proto string, server extDNSEntry) (net.Conn, error) { var ( extConn net.Conn dialErr error @@ -459,7 +448,7 @@ func (r *resolver) dialExtDNS(proto string, server extDNSEntry) (net.Conn, error return extConn, nil } -func (r *resolver) forwardExtDNS(proto string, query *dns.Msg) *dns.Msg { +func (r *Resolver) forwardExtDNS(proto string, query *dns.Msg) *dns.Msg { queryName, queryType := query.Question[0].Name, query.Question[0].Qtype for _, extDNS := range r.extDNSList { if extDNS.IPStr == "" { @@ -530,7 +519,7 @@ func (r *resolver) forwardExtDNS(proto string, query *dns.Msg) *dns.Msg { return nil } -func (r *resolver) exchange(proto string, extDNS extDNSEntry, query *dns.Msg) *dns.Msg { +func (r *Resolver) exchange(proto string, extDNS extDNSEntry, query *dns.Msg) *dns.Msg { extConn, err := r.dialExtDNS(proto, extDNS) if err != nil { logrus.WithError(err).Warn("[resolver] connect failed") diff --git a/libnetwork/resolver_test.go b/libnetwork/resolver_test.go index 8cd28a0c8a..e782de6e2a 100644 --- a/libnetwork/resolver_test.go +++ b/libnetwork/resolver_test.go @@ -143,7 +143,7 @@ func TestDNSIPQuery(t *testing.T) { for _, name := range names { q := new(dns.Msg) q.SetQuestion(name, dns.TypeA) - r.(*resolver).ServeDNS(w, q) + r.serveDNS(w, q) resp := w.GetResponse() checkNonNullResponse(t, resp) t.Log("Response: ", resp.String()) @@ -163,7 +163,7 @@ func TestDNSIPQuery(t *testing.T) { // test MX query with name1 results in Success response with 0 answer records q := new(dns.Msg) q.SetQuestion("name1", dns.TypeMX) - r.(*resolver).ServeDNS(w, q) + r.serveDNS(w, q) resp := w.GetResponse() checkNonNullResponse(t, resp) t.Log("Response: ", resp.String()) @@ -175,7 +175,7 @@ func TestDNSIPQuery(t *testing.T) { // since this is a unit test env, we disable proxying DNS above which results in ServFail rather than NXDOMAIN q = new(dns.Msg) q.SetQuestion("nonexistent", dns.TypeMX) - r.(*resolver).ServeDNS(w, q) + r.serveDNS(w, q) resp = w.GetResponse() checkNonNullResponse(t, resp) t.Log("Response: ", resp.String()) @@ -291,8 +291,8 @@ func TestDNSProxyServFail(t *testing.T) { localDNSEntries = append(localDNSEntries, extTestDNSEntry) // this should generate two requests: the first will fail leading to a retry - r.(*resolver).SetExtServers(localDNSEntries) - r.(*resolver).ServeDNS(w, q) + r.SetExtServers(localDNSEntries) + r.serveDNS(w, q) if nRequests != 2 { t.Fatalf("Expected 2 DNS querries. Found: %d", nRequests) } @@ -370,7 +370,7 @@ func TestOversizedDNSReply(t *testing.T) { }() srvAddr := srv.LocalAddr().(*net.UDPAddr) - rsv := NewResolver("", true, noopDNSBackend{}).(*resolver) + rsv := NewResolver("", true, noopDNSBackend{}) rsv.SetExtServers([]extDNSEntry{ {IPStr: srvAddr.IP.String(), port: uint16(srvAddr.Port), HostLoopback: true}, }) @@ -381,7 +381,7 @@ func TestOversizedDNSReply(t *testing.T) { w := &tstwriter{localAddr: srv.LocalAddr()} q := new(dns.Msg).SetQuestion("s3.amazonaws.com.", dns.TypeA) - rsv.ServeDNS(w, q) + rsv.serveDNS(w, q) resp := w.GetResponse() checkNonNullResponse(t, resp) t.Log("Response: ", resp.String()) @@ -441,9 +441,9 @@ func TestReplySERVFAIL(t *testing.T) { t.Run(tt.name, func(t *testing.T) { defer redirectLogrusTo(t) - rsv := NewResolver("", tt.proxyDNS, badSRVDNSBackend{}).(*resolver) + rsv := NewResolver("", tt.proxyDNS, badSRVDNSBackend{}) w := &tstwriter{} - rsv.ServeDNS(w, tt.q) + rsv.serveDNS(w, tt.q) resp := w.GetResponse() checkNonNullResponse(t, resp) t.Log("Response: ", resp.String()) diff --git a/libnetwork/resolver_unix.go b/libnetwork/resolver_unix.go index e16251112e..7b0511bcff 100644 --- a/libnetwork/resolver_unix.go +++ b/libnetwork/resolver_unix.go @@ -17,7 +17,7 @@ const ( postroutingChain = "DOCKER_POSTROUTING" ) -func (r *resolver) setupIPTable() error { +func (r *Resolver) setupIPTable() error { if r.err != nil { return r.err } diff --git a/libnetwork/resolver_windows.go b/libnetwork/resolver_windows.go index 3d422fcd06..a3b17fcb4d 100644 --- a/libnetwork/resolver_windows.go +++ b/libnetwork/resolver_windows.go @@ -3,6 +3,6 @@ package libnetwork -func (r *resolver) setupIPTable() error { +func (r *Resolver) setupIPTable() error { return nil } diff --git a/libnetwork/sandbox.go b/libnetwork/sandbox.go index 80698c2905..194844ca7b 100644 --- a/libnetwork/sandbox.go +++ b/libnetwork/sandbox.go @@ -38,7 +38,7 @@ type Sandbox struct { extDNS []extDNSEntry osSbox osl.Sandbox controller *Controller - resolver Resolver + resolver *Resolver resolverOnce sync.Once endpoints []*Endpoint epPriority map[string]int