Merge pull request #44664 from corhere/embedded-resolver-fixes

libnetwork: improve embedded DNS resolver
This commit is contained in:
Bjorn Neergaard 2023-02-23 12:25:58 -07:00 committed by GitHub
commit 855c684708
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 524 additions and 293 deletions

View file

@ -223,7 +223,7 @@ type network struct {
persist bool
drvOnce *sync.Once
resolverOnce sync.Once //nolint:nolintlint,unused // only used on windows
resolver []Resolver
resolver []*Resolver
internal bool
attachable bool
inDelete bool

View file

@ -1,6 +1,7 @@
package libnetwork
import (
"context"
"fmt"
"math/rand"
"net"
@ -11,29 +12,10 @@ import (
"github.com/docker/docker/libnetwork/types"
"github.com/miekg/dns"
"github.com/sirupsen/logrus"
"golang.org/x/sync/semaphore"
"golang.org/x/time/rate"
)
// Resolver represents the embedded DNS server in Docker. It operates
// by listening on container's loopback interface for DNS queries.
type Resolver interface {
// Start starts the name server for the container
Start() error
// Stop stops the name server for the container. Stopped resolver
// can be reused after running the SetupFunc again.
Stop()
// SetupFunc provides the setup function that should be run
// in the container's network namespace.
SetupFunc(int) func()
// NameServer returns the IP of the DNS resolver for the
// containers.
NameServer() string
// SetExtServers configures the external nameservers the resolver
// should use to forward queries
SetExtServers([]extDNSEntry)
// ResolverOptions returns resolv.conf options that should be set
ResolverOptions() []string
}
// DNSBackend represents a backend DNS resolver used for DNS name
// resolution. All the queries to the resolver are forwarded to the
// backend resolver.
@ -60,24 +42,25 @@ type DNSBackend interface {
}
const (
dnsPort = "53"
ptrIPv4domain = ".in-addr.arpa."
ptrIPv6domain = ".ip6.arpa."
respTTL = 600
maxExtDNS = 3 // max number of external servers to try
extIOTimeout = 4 * time.Second
defaultRespSize = 512
maxConcurrent = 1024
logInterval = 2 * time.Second
dnsPort = "53"
ptrIPv4domain = ".in-addr.arpa."
ptrIPv6domain = ".ip6.arpa."
respTTL = 600
maxExtDNS = 3 // max number of external servers to try
extIOTimeout = 4 * time.Second
maxConcurrent = 1024
logInterval = 2 * time.Second
)
type extDNSEntry struct {
IPStr string
port uint16 // for testing
HostLoopback bool
}
// resolver implements the Resolver interface
type resolver struct {
// Resolver is the embedded DNS server in Docker. It operates by listening on
// the container's loopback interface for DNS queries.
type Resolver struct {
backend DNSBackend
extDNSList [maxExtDNS]extDNSEntry
server *dns.Server
@ -85,26 +68,30 @@ type resolver struct {
tcpServer *dns.Server
tcpListen *net.TCPListener
err error
count int32
tStamp time.Time
queryLock sync.Mutex
listenAddress string
proxyDNS bool
startCh chan struct{}
fwdSem *semaphore.Weighted // Limit the number of concurrent external DNS requests in-flight
logInverval rate.Sometimes // Rate-limit logging about hitting the fwdSem limit
}
// NewResolver creates a new instance of the Resolver
func NewResolver(address string, proxyDNS bool, backend DNSBackend) Resolver {
return &resolver{
func NewResolver(address string, proxyDNS bool, backend DNSBackend) *Resolver {
return &Resolver{
backend: backend,
proxyDNS: proxyDNS,
listenAddress: address,
err: fmt.Errorf("setup not done yet"),
startCh: make(chan struct{}, 1),
fwdSem: semaphore.NewWeighted(maxConcurrent),
logInverval: rate.Sometimes{Interval: logInterval},
}
}
func (r *resolver) SetupFunc(port int) func() {
// SetupFunc returns the setup function that should be run in the container's
// network namespace.
func (r *Resolver) SetupFunc(port int) func() {
return func() {
var err error
@ -135,7 +122,8 @@ func (r *resolver) SetupFunc(port int) func() {
}
}
func (r *resolver) Start() error {
// Start starts the name server for the container.
func (r *Resolver) Start() error {
r.startCh <- struct{}{}
defer func() { <-r.startCh }()
@ -148,7 +136,7 @@ func (r *resolver) Start() error {
return fmt.Errorf("setting up IP table rules failed: %v", err)
}
s := &dns.Server{Handler: r, PacketConn: r.conn}
s := &dns.Server{Handler: dns.HandlerFunc(r.serveDNS), PacketConn: r.conn}
r.server = s
go func() {
if err := s.ActivateAndServe(); err != nil {
@ -156,7 +144,7 @@ func (r *resolver) Start() error {
}
}()
tcpServer := &dns.Server{Handler: r, Listener: r.tcpListen}
tcpServer := &dns.Server{Handler: dns.HandlerFunc(r.serveDNS), Listener: r.tcpListen}
r.tcpServer = tcpServer
go func() {
if err := tcpServer.ActivateAndServe(); err != nil {
@ -166,7 +154,9 @@ func (r *resolver) Start() error {
return nil
}
func (r *resolver) Stop() {
// Stop stops the name server for the container. A stopped resolver can be
// reused after running the SetupFunc again.
func (r *Resolver) Stop() {
r.startCh <- struct{}{}
defer func() { <-r.startCh }()
@ -179,12 +169,12 @@ func (r *resolver) Stop() {
r.conn = nil
r.tcpServer = nil
r.err = fmt.Errorf("setup not done yet")
r.tStamp = time.Time{}
r.count = 0
r.queryLock = sync.Mutex{}
r.fwdSem = semaphore.NewWeighted(maxConcurrent)
}
func (r *resolver) SetExtServers(extDNS []extDNSEntry) {
// SetExtServers configures the external nameservers the resolver should use
// when forwarding queries.
func (r *Resolver) SetExtServers(extDNS []extDNSEntry) {
l := len(extDNS)
if l > maxExtDNS {
l = maxExtDNS
@ -194,11 +184,13 @@ func (r *resolver) SetExtServers(extDNS []extDNSEntry) {
}
}
func (r *resolver) NameServer() string {
// NameServer returns the IP of the DNS resolver for the containers.
func (r *Resolver) NameServer() string {
return r.listenAddress
}
func (r *resolver) ResolverOptions() []string {
// ResolverOptions returns resolv.conf options that should be set.
func (r *Resolver) ResolverOptions() []string {
return []string{"ndots:0"}
}
@ -230,7 +222,7 @@ func createRespMsg(query *dns.Msg) *dns.Msg {
return resp
}
func (r *resolver) handleMXQuery(query *dns.Msg) (*dns.Msg, error) {
func (r *Resolver) handleMXQuery(query *dns.Msg) (*dns.Msg, error) {
name := query.Question[0].Name
addrv4, _ := r.backend.ResolveName(name, types.IPv4)
addrv6, _ := r.backend.ResolveName(name, types.IPv6)
@ -247,7 +239,7 @@ func (r *resolver) handleMXQuery(query *dns.Msg) (*dns.Msg, error) {
return resp, nil
}
func (r *resolver) handleIPQuery(query *dns.Msg, ipType int) (*dns.Msg, error) {
func (r *Resolver) handleIPQuery(query *dns.Msg, ipType int) (*dns.Msg, error) {
var (
addr []net.IP
ipv6Miss bool
@ -289,27 +281,24 @@ func (r *resolver) handleIPQuery(query *dns.Msg, ipType int) (*dns.Msg, error) {
return resp, nil
}
func (r *resolver) handlePTRQuery(query *dns.Msg) (*dns.Msg, error) {
var (
parts []string
ptr = query.Question[0].Name
)
if strings.HasSuffix(ptr, ptrIPv4domain) {
parts = strings.Split(ptr, ptrIPv4domain)
} else if strings.HasSuffix(ptr, ptrIPv6domain) {
parts = strings.Split(ptr, ptrIPv6domain)
} else {
return nil, fmt.Errorf("invalid PTR query, %v", ptr)
func (r *Resolver) handlePTRQuery(query *dns.Msg) (*dns.Msg, error) {
ptr := query.Question[0].Name
name, after, found := strings.Cut(ptr, ptrIPv4domain)
if !found || after != "" {
name, after, found = strings.Cut(ptr, ptrIPv6domain)
}
host := r.backend.ResolveIP(parts[0])
if len(host) == 0 {
if !found || after != "" {
// Not a known IPv4 or IPv6 PTR domain.
// Maybe the external DNS servers know what to do with the query?
return nil, nil
}
logrus.Debugf("[resolver] lookup for IP %s: name %s", parts[0], host)
host := r.backend.ResolveIP(name)
if host == "" {
return nil, nil
}
logrus.Debugf("[resolver] lookup for IP %s: name %s", name, host)
fqdn := dns.Fqdn(host)
resp := new(dns.Msg)
@ -323,7 +312,7 @@ func (r *resolver) handlePTRQuery(query *dns.Msg) (*dns.Msg, error) {
return resp, nil
}
func (r *resolver) handleSRVQuery(query *dns.Msg) (*dns.Msg, error) {
func (r *Resolver) handleSRVQuery(query *dns.Msg) (*dns.Msg, error) {
svc := query.Question[0].Name
srv, ip := r.backend.ResolveService(svc)
@ -351,28 +340,10 @@ func (r *resolver) handleSRVQuery(query *dns.Msg) (*dns.Msg, error) {
return resp, nil
}
func truncateResp(resp *dns.Msg, maxSize int, isTCP bool) {
if !isTCP {
resp.Truncated = true
}
srv := resp.Question[0].Qtype == dns.TypeSRV
// trim the Answer RRs one by one till the whole message fits
// within the reply size
for resp.Len() > maxSize {
resp.Answer = resp.Answer[:len(resp.Answer)-1]
if srv && len(resp.Extra) > 0 {
resp.Extra = resp.Extra[:len(resp.Extra)-1]
}
}
}
func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
func (r *Resolver) serveDNS(w dns.ResponseWriter, query *dns.Msg) {
var (
extConn net.Conn
resp *dns.Msg
err error
resp *dns.Msg
err error
)
if query == nil || len(query.Question) == 0 {
@ -397,171 +368,195 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
logrus.Debugf("[resolver] query type %s is not supported by the embedded DNS and will be forwarded to external DNS", dns.TypeToString[queryType])
}
reply := func(msg *dns.Msg) {
if err = w.WriteMsg(msg); err != nil {
logrus.WithError(err).Errorf("[resolver] failed to write response")
}
}
if err != nil {
logrus.WithError(err).Errorf("[resolver] failed to handle query: %s (%s)", queryName, dns.TypeToString[queryType])
reply(new(dns.Msg).SetRcode(query, dns.RcodeServerFailure))
return
}
if resp == nil {
// If the backend doesn't support proxying dns request
// fail the response
if !r.proxyDNS {
resp = new(dns.Msg)
resp.SetRcode(query, dns.RcodeServerFailure)
if err := w.WriteMsg(resp); err != nil {
logrus.WithError(err).Error("[resolver] error writing dns response")
if resp != nil {
// We are the authoritative DNS server for this request so it's
// on us to truncate the response message to the size limit
// negotiated by the client.
maxSize := dns.MinMsgSize
if w.LocalAddr().Network() == "tcp" {
maxSize = dns.MaxMsgSize
} else {
if optRR := query.IsEdns0(); optRR != nil {
if udpsize := int(optRR.UDPSize()); udpsize > maxSize {
maxSize = udpsize
}
}
return
}
resp.Truncate(maxSize)
reply(resp)
return
}
if r.proxyDNS {
// If the user sets ndots > 0 explicitly and the query is
// in the root domain don't forward it out. We will return
// failure and let the client retry with the search domain
// attached
switch queryType {
case dns.TypeA, dns.TypeAAAA:
if r.backend.NdotsSet() && !strings.Contains(strings.TrimSuffix(queryName, "."), ".") {
resp = createRespMsg(query)
}
// attached.
if (queryType == dns.TypeA || queryType == dns.TypeAAAA) && r.backend.NdotsSet() &&
!strings.Contains(strings.TrimSuffix(queryName, "."), ".") {
resp = createRespMsg(query)
} else {
resp = r.forwardExtDNS(w.LocalAddr().Network(), query)
}
}
proto := w.LocalAddr().Network()
maxSize := 0
if proto == "tcp" {
maxSize = dns.MaxMsgSize - 1
} else if proto == "udp" {
optRR := query.IsEdns0()
if optRR != nil {
maxSize = int(optRR.UDPSize())
}
if maxSize < defaultRespSize {
maxSize = defaultRespSize
if resp == nil {
// We were unable to get an answer from any of the upstream DNS
// servers or the backend doesn't support proxying DNS requests.
resp = new(dns.Msg).SetRcode(query, dns.RcodeServerFailure)
}
reply(resp)
}
func (r *Resolver) dialExtDNS(proto string, server extDNSEntry) (net.Conn, error) {
var (
extConn net.Conn
dialErr error
)
extConnect := func() {
if server.port == 0 {
server.port = 53
}
addr := fmt.Sprintf("%s:%d", server.IPStr, server.port)
extConn, dialErr = net.DialTimeout(proto, addr, extIOTimeout)
}
if resp != nil {
if resp.Len() > maxSize {
truncateResp(resp, maxSize, proto == "tcp")
}
if server.HostLoopback {
extConnect()
} else {
for i := 0; i < maxExtDNS; i++ {
extDNS := &r.extDNSList[i]
if extDNS.IPStr == "" {
break
}
extConnect := func() {
addr := fmt.Sprintf("%s:%d", extDNS.IPStr, 53)
extConn, err = net.DialTimeout(proto, addr, extIOTimeout)
}
execErr := r.backend.ExecFunc(extConnect)
if execErr != nil {
return nil, execErr
}
}
if dialErr != nil {
return nil, dialErr
}
if extDNS.HostLoopback {
extConnect()
} else {
execErr := r.backend.ExecFunc(extConnect)
if execErr != nil {
logrus.Warn(execErr)
continue
}
}
if err != nil {
logrus.WithField("retries", i).Warnf("[resolver] connect failed: %s", err)
continue
}
logrus.Debugf("[resolver] query %s (%s) from %s, forwarding to %s:%s", queryName, dns.TypeToString[queryType],
extConn.LocalAddr().String(), proto, extDNS.IPStr)
return extConn, nil
}
// Timeout has to be set for every IO operation.
if err := extConn.SetDeadline(time.Now().Add(extIOTimeout)); err != nil {
logrus.WithError(err).Error("[resolver] error setting conn deadline")
}
co := &dns.Conn{
Conn: extConn,
UDPSize: uint16(maxSize),
}
defer co.Close()
// limits the number of outstanding concurrent queries.
if !r.forwardQueryStart() {
old := r.tStamp
r.tStamp = time.Now()
if r.tStamp.Sub(old) > logInterval {
logrus.Errorf("[resolver] more than %v concurrent queries from %s", maxConcurrent, extConn.LocalAddr().String())
}
continue
}
err = co.WriteMsg(query)
if err != nil {
r.forwardQueryEnd()
logrus.Debugf("[resolver] send to DNS server failed, %s", err)
continue
}
resp, err = co.ReadMsg()
// Truncated DNS replies should be sent to the client so that the
// client can retry over TCP
if err != nil && (resp == nil || !resp.Truncated) {
r.forwardQueryEnd()
logrus.WithError(err).Warnf("[resolver] failed to read from DNS server: %s, query: %s", extConn.RemoteAddr().String(), query.Question[0].String())
continue
}
r.forwardQueryEnd()
if resp == nil {
logrus.Debugf("[resolver] external DNS %s:%s returned empty response for %q", proto, extDNS.IPStr, queryName)
break
}
switch resp.Rcode {
case dns.RcodeServerFailure, dns.RcodeRefused:
// Server returned FAILURE: continue with the next external DNS server
// Server returned REFUSED: this can be a transitional status, so continue with the next external DNS server
logrus.Debugf("[resolver] external DNS %s:%s responded with %s for %q", proto, extDNS.IPStr, statusString(resp.Rcode), queryName)
continue
case dns.RcodeNameError:
// Server returned NXDOMAIN. Stop resolution if it's an authoritative answer (see RFC 8020: https://tools.ietf.org/html/rfc8020#section-2)
logrus.Debugf("[resolver] external DNS %s:%s responded with %s for %q", proto, extDNS.IPStr, statusString(resp.Rcode), queryName)
if resp.Authoritative {
break
}
continue
case dns.RcodeSuccess:
// All is well
default:
// Server gave some error. Log the error, and continue with the next external DNS server
logrus.Debugf("[resolver] external DNS %s:%s responded with %s (code %d) for %q", proto, extDNS.IPStr, statusString(resp.Rcode), resp.Rcode, queryName)
continue
}
answers := 0
for _, rr := range resp.Answer {
h := rr.Header()
switch h.Rrtype {
case dns.TypeA:
answers++
ip := rr.(*dns.A).A
logrus.Debugf("[resolver] received A record %q for %q from %s:%s", ip, h.Name, proto, extDNS.IPStr)
r.backend.HandleQueryResp(h.Name, ip)
case dns.TypeAAAA:
answers++
ip := rr.(*dns.AAAA).AAAA
logrus.Debugf("[resolver] received AAAA record %q for %q from %s:%s", ip, h.Name, proto, extDNS.IPStr)
r.backend.HandleQueryResp(h.Name, ip)
}
}
if resp.Answer == nil || answers == 0 {
logrus.Debugf("[resolver] external DNS %s:%s did not return any %s records for %q", proto, extDNS.IPStr, dns.TypeToString[queryType], queryName)
}
resp.Compress = true
func (r *Resolver) forwardExtDNS(proto string, query *dns.Msg) *dns.Msg {
queryName, queryType := query.Question[0].Name, query.Question[0].Qtype
for _, extDNS := range r.extDNSList {
if extDNS.IPStr == "" {
break
}
if resp == nil {
return
// limits the number of outstanding concurrent queries.
ctx, cancel := context.WithTimeout(context.Background(), extIOTimeout)
err := r.fwdSem.Acquire(ctx, 1)
cancel()
if err != nil {
r.logInverval.Do(func() {
logrus.Errorf("[resolver] more than %v concurrent queries", maxConcurrent)
})
return new(dns.Msg).SetRcode(query, dns.RcodeRefused)
}
resp := func() *dns.Msg {
defer r.fwdSem.Release(1)
return r.exchange(proto, extDNS, query)
}()
if resp == nil {
continue
}
switch resp.Rcode {
case dns.RcodeServerFailure, dns.RcodeRefused:
// Server returned FAILURE: continue with the next external DNS server
// Server returned REFUSED: this can be a transitional status, so continue with the next external DNS server
logrus.Debugf("[resolver] external DNS %s:%s responded with %s for %q", proto, extDNS.IPStr, statusString(resp.Rcode), queryName)
continue
case dns.RcodeNameError:
// Server returned NXDOMAIN. Stop resolution if it's an authoritative answer (see RFC 8020: https://tools.ietf.org/html/rfc8020#section-2)
logrus.Debugf("[resolver] external DNS %s:%s responded with %s for %q", proto, extDNS.IPStr, statusString(resp.Rcode), queryName)
if resp.Authoritative {
break
}
continue
case dns.RcodeSuccess:
// All is well
default:
// Server gave some error. Log the error, and continue with the next external DNS server
logrus.Debugf("[resolver] external DNS %s:%s responded with %s (code %d) for %q", proto, extDNS.IPStr, statusString(resp.Rcode), resp.Rcode, queryName)
continue
}
answers := 0
for _, rr := range resp.Answer {
h := rr.Header()
switch h.Rrtype {
case dns.TypeA:
answers++
ip := rr.(*dns.A).A
logrus.Debugf("[resolver] received A record %q for %q from %s:%s", ip, h.Name, proto, extDNS.IPStr)
r.backend.HandleQueryResp(h.Name, ip)
case dns.TypeAAAA:
answers++
ip := rr.(*dns.AAAA).AAAA
logrus.Debugf("[resolver] received AAAA record %q for %q from %s:%s", ip, h.Name, proto, extDNS.IPStr)
r.backend.HandleQueryResp(h.Name, ip)
}
}
if resp.Answer == nil || answers == 0 {
logrus.Debugf("[resolver] external DNS %s:%s did not return any %s records for %q", proto, extDNS.IPStr, dns.TypeToString[queryType], queryName)
}
resp.Compress = true
return resp
}
if err = w.WriteMsg(resp); err != nil {
logrus.WithError(err).Errorf("[resolver] failed to write response")
return nil
}
func (r *Resolver) exchange(proto string, extDNS extDNSEntry, query *dns.Msg) *dns.Msg {
extConn, err := r.dialExtDNS(proto, extDNS)
if err != nil {
logrus.WithError(err).Warn("[resolver] connect failed")
return nil
}
defer extConn.Close()
log := logrus.WithFields(logrus.Fields{
"dns-server": extConn.RemoteAddr().Network() + ":" + extConn.RemoteAddr().String(),
"client-addr": extConn.LocalAddr().Network() + ":" + extConn.LocalAddr().String(),
"question": query.Question[0].String(),
})
log.Debug("[resolver] forwarding query")
resp, _, err := (&dns.Client{
Timeout: extIOTimeout,
// Following the robustness principle, make a best-effort
// attempt to receive oversized response messages without
// truncating them on our end to forward verbatim to the client.
// Some DNS servers (e.g. Mikrotik RouterOS) don't support
// EDNS(0) and may send replies over UDP longer than 512 bytes
// regardless of what size limit, if any, was advertized in the
// query message. Note that ExchangeWithConn will override this
// value if it detects an EDNS OPT record in query so only
// oversized replies to non-EDNS queries will benefit.
UDPSize: dns.MaxMsgSize,
}).ExchangeWithConn(query, &dns.Conn{Conn: extConn})
if err != nil {
logrus.WithError(err).Errorf("[resolver] failed to query DNS server: %s, query: %s", extConn.RemoteAddr().String(), query.Question[0].String())
return nil
}
if resp == nil {
// Should be impossible, so make noise if it happens anyway.
log.Error("[resolver] external DNS returned empty response")
}
return resp
}
func statusString(responseCode int) string {
@ -570,26 +565,3 @@ func statusString(responseCode int) string {
}
return "UNKNOWN"
}
func (r *resolver) forwardQueryStart() bool {
r.queryLock.Lock()
defer r.queryLock.Unlock()
if r.count == maxConcurrent {
return false
}
r.count++
return true
}
func (r *resolver) forwardQueryEnd() {
r.queryLock.Lock()
defer r.queryLock.Unlock()
if r.count == 0 {
logrus.Error("[resolver] invalid concurrent query count")
} else {
r.count--
}
}

View file

@ -1,6 +1,8 @@
package libnetwork
import (
"encoding/hex"
"errors"
"net"
"runtime"
"syscall"
@ -10,6 +12,7 @@ import (
"github.com/docker/docker/libnetwork/testutils"
"github.com/miekg/dns"
"github.com/sirupsen/logrus"
"gotest.tools/v3/assert"
"gotest.tools/v3/skip"
)
@ -23,7 +26,8 @@ func (a *tstaddr) String() string { return "127.0.0.1" }
// a simple writer that implements dns.ResponseWriter for unit testing purposes
type tstwriter struct {
msg *dns.Msg
localAddr net.Addr
msg *dns.Msg
}
func (w *tstwriter) WriteMsg(m *dns.Msg) (err error) {
@ -33,7 +37,12 @@ func (w *tstwriter) WriteMsg(m *dns.Msg) (err error) {
func (w *tstwriter) Write(m []byte) (int, error) { return 0, nil }
func (w *tstwriter) LocalAddr() net.Addr { return new(tstaddr) }
func (w *tstwriter) LocalAddr() net.Addr {
if w.localAddr != nil {
return w.localAddr
}
return new(tstaddr)
}
func (w *tstwriter) RemoteAddr() net.Addr { return new(tstaddr) }
@ -50,12 +59,14 @@ func (w *tstwriter) GetResponse() *dns.Msg { return w.msg }
func (w *tstwriter) ClearResponse() { w.msg = nil }
func checkNonNullResponse(t *testing.T, m *dns.Msg) {
t.Helper()
if m == nil {
t.Fatal("Null DNS response found. Non Null response msg expected.")
}
}
func checkDNSAnswersCount(t *testing.T, m *dns.Msg, expected int) {
t.Helper()
answers := len(m.Answer)
if answers != expected {
t.Fatalf("Expected number of answers in response: %d. Found: %d", expected, answers)
@ -63,12 +74,14 @@ func checkDNSAnswersCount(t *testing.T, m *dns.Msg, expected int) {
}
func checkDNSResponseCode(t *testing.T, m *dns.Msg, expected int) {
t.Helper()
if m.MsgHdr.Rcode != expected {
t.Fatalf("Expected DNS response code: %d. Found: %d", expected, m.MsgHdr.Rcode)
}
}
func checkDNSRRType(t *testing.T, actual, expected uint16) {
t.Helper()
if actual != expected {
t.Fatalf("Expected DNS Rrtype: %d. Found: %d", expected, actual)
}
@ -130,7 +143,7 @@ func TestDNSIPQuery(t *testing.T) {
for _, name := range names {
q := new(dns.Msg)
q.SetQuestion(name, dns.TypeA)
r.(*resolver).ServeDNS(w, q)
r.serveDNS(w, q)
resp := w.GetResponse()
checkNonNullResponse(t, resp)
t.Log("Response: ", resp.String())
@ -150,7 +163,7 @@ func TestDNSIPQuery(t *testing.T) {
// test MX query with name1 results in Success response with 0 answer records
q := new(dns.Msg)
q.SetQuestion("name1", dns.TypeMX)
r.(*resolver).ServeDNS(w, q)
r.serveDNS(w, q)
resp := w.GetResponse()
checkNonNullResponse(t, resp)
t.Log("Response: ", resp.String())
@ -162,7 +175,7 @@ func TestDNSIPQuery(t *testing.T) {
// since this is a unit test env, we disable proxying DNS above which results in ServFail rather than NXDOMAIN
q = new(dns.Msg)
q.SetQuestion("nonexistent", dns.TypeMX)
r.(*resolver).ServeDNS(w, q)
r.serveDNS(w, q)
resp = w.GetResponse()
checkNonNullResponse(t, resp)
t.Log("Response: ", resp.String())
@ -278,10 +291,169 @@ func TestDNSProxyServFail(t *testing.T) {
localDNSEntries = append(localDNSEntries, extTestDNSEntry)
// this should generate two requests: the first will fail leading to a retry
r.(*resolver).SetExtServers(localDNSEntries)
r.(*resolver).ServeDNS(w, q)
r.SetExtServers(localDNSEntries)
r.serveDNS(w, q)
if nRequests != 2 {
t.Fatalf("Expected 2 DNS querries. Found: %d", nRequests)
}
t.Logf("Expected number of DNS requests generated")
}
// Packet 24 extracted from
// https://gist.github.com/vojtad/3bac63b8c91b1ec50e8d8b36047317fa/raw/7d75eb3d3448381bf252ae55ea5123a132c46658/host.pcap
// (https://github.com/moby/moby/issues/44575)
// which is a non-compliant DNS reply > 512B (w/o EDNS(0)) to the query
//
// s3.amazonaws.com. IN A
const oversizedDNSReplyMsg = "\xf5\x11\x81\x80\x00\x01\x00\x20\x00\x00\x00\x00\x02\x73\x33\x09" +
"\x61\x6d\x61\x7a\x6f\x6e\x61\x77\x73\x03\x63\x6f\x6d\x00\x00\x01" +
"\x00\x01\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd9" +
"\x11\x9e\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd8" +
"\x4c\x66\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd8" +
"\xda\x10\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd9" +
"\x01\x3e\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd9" +
"\x88\x68\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd9" +
"\x66\x9e\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd9" +
"\x5f\x28\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd8" +
"\x8e\x4e\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x00\x00\x04\x36\xe7" +
"\x84\xf0\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x00\x00\x04\x34\xd8" +
"\x92\x45\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd8" +
"\x8f\xa6\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x36\xe7" +
"\xc0\xd0\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd9" +
"\xfe\x28\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd8" +
"\xaa\x3d\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd8" +
"\x4e\x56\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd9" +
"\xea\xb0\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd8" +
"\x6d\xed\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x04\x00\x04\x34\xd8" +
"\x28\x00\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x00\x00\x04\x34\xd9" +
"\xe9\x78\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x00\x00\x04\x34\xd9" +
"\x6e\x9e\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x00\x00\x04\x34\xd9" +
"\x45\x86\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x00\x00\x04\x34\xd8" +
"\x30\x38\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x00\x00\x04\x36\xe7" +
"\xc6\xa8\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x00\x00\x04\x03\x05" +
"\x01\x9d\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x05\x00\x04\x34\xd9" +
"\xa8\xe8\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x05\x00\x04\x34\xd9" +
"\x64\xa6\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x05\x00\x04\x34\xd8" +
"\x3c\x48\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x05\x00\x04\x34\xd8" +
"\x35\x20\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x05\x00\x04\x34\xd9" +
"\x54\xf6\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x05\x00\x04\x34\xd9" +
"\x5d\x36\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x05\x00\x04\x34\xd9" +
"\x30\x36\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\x05\x00\x04\x36\xe7" +
"\x83\x90"
// Regression test for https://github.com/moby/moby/issues/44575
func TestOversizedDNSReply(t *testing.T) {
srv, err := net.ListenPacket("udp", "127.0.0.1:0")
assert.NilError(t, err)
defer srv.Close()
go func() {
buf := make([]byte, 65536)
for {
n, src, err := srv.ReadFrom(buf)
if errors.Is(err, net.ErrClosed) {
return
}
t.Logf("[<-%v]\n%s", src, hex.Dump(buf[:n]))
if n < 2 {
continue
}
resp := []byte(oversizedDNSReplyMsg)
resp[0], resp[1] = buf[0], buf[1] // Copy query ID into response.
_, err = srv.WriteTo(resp, src)
if errors.Is(err, net.ErrClosed) {
return
}
if err != nil {
t.Log(err)
}
}
}()
srvAddr := srv.LocalAddr().(*net.UDPAddr)
rsv := NewResolver("", true, noopDNSBackend{})
rsv.SetExtServers([]extDNSEntry{
{IPStr: srvAddr.IP.String(), port: uint16(srvAddr.Port), HostLoopback: true},
})
// The resolver logs lots of valuable info at level debug. Redirect it
// to t.Log() so the log spew is emitted only if the test fails.
defer redirectLogrusTo(t)()
w := &tstwriter{localAddr: srv.LocalAddr()}
q := new(dns.Msg).SetQuestion("s3.amazonaws.com.", dns.TypeA)
rsv.serveDNS(w, q)
resp := w.GetResponse()
checkNonNullResponse(t, resp)
t.Log("Response: ", resp.String())
checkDNSResponseCode(t, resp, dns.RcodeSuccess)
assert.Assert(t, len(resp.Answer) >= 1)
checkDNSRRType(t, resp.Answer[0].Header().Rrtype, dns.TypeA)
}
func redirectLogrusTo(t *testing.T) func() {
oldLevel, oldOut := logrus.StandardLogger().Level, logrus.StandardLogger().Out
logrus.StandardLogger().SetLevel(logrus.DebugLevel)
logrus.SetOutput(tlogWriter{t})
return func() {
logrus.StandardLogger().SetLevel(oldLevel)
logrus.StandardLogger().SetOutput(oldOut)
}
}
type tlogWriter struct{ t *testing.T }
func (w tlogWriter) Write(p []byte) (n int, err error) {
w.t.Logf("%s", p)
return len(p), nil
}
type noopDNSBackend struct{ DNSBackend }
func (noopDNSBackend) ResolveName(name string, iplen int) ([]net.IP, bool) { return nil, false }
func (noopDNSBackend) ExecFunc(f func()) error { f(); return nil }
func (noopDNSBackend) NdotsSet() bool { return false }
func (noopDNSBackend) HandleQueryResp(name string, ip net.IP) {}
func TestReplySERVFAIL(t *testing.T) {
cases := []struct {
name string
q *dns.Msg
proxyDNS bool
}{
{
name: "InternalError",
q: new(dns.Msg).SetQuestion("_sip._tcp.example.com.", dns.TypeSRV),
},
{
name: "ProxyDNS=false",
q: new(dns.Msg).SetQuestion("example.com.", dns.TypeA),
},
{
name: "ProxyDNS=true", // No extDNS servers configured -> no answer from any upstream
q: new(dns.Msg).SetQuestion("example.com.", dns.TypeA),
proxyDNS: true,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
defer redirectLogrusTo(t)
rsv := NewResolver("", tt.proxyDNS, badSRVDNSBackend{})
w := &tstwriter{}
rsv.serveDNS(w, tt.q)
resp := w.GetResponse()
checkNonNullResponse(t, resp)
t.Log("Response: ", resp.String())
checkDNSResponseCode(t, resp, dns.RcodeServerFailure)
})
}
}
type badSRVDNSBackend struct{ noopDNSBackend }
func (badSRVDNSBackend) ResolveService(name string) ([]*net.SRV, []net.IP) {
return []*net.SRV{nil, nil, nil}, nil // Mismatched slice lengths
}

View file

@ -4,20 +4,20 @@
package libnetwork
import (
"fmt"
"net"
"github.com/docker/docker/libnetwork/iptables"
"github.com/sirupsen/logrus"
)
const (
// outputChain used for docker embed dns
// output chain used for docker embedded DNS resolver
outputChain = "DOCKER_OUTPUT"
//postroutingchain used for docker embed dns
postroutingchain = "DOCKER_POSTROUTING"
// postrouting chain used for docker embedded DNS resolver
postroutingChain = "DOCKER_POSTROUTING"
)
func (r *resolver) setupIPTable() error {
func (r *Resolver) setupIPTable() error {
if r.err != nil {
return r.err
}
@ -27,36 +27,60 @@ func (r *resolver) setupIPTable() error {
_, tcpPort, _ := net.SplitHostPort(ltcpaddr)
rules := [][]string{
{"-t", "nat", "-I", outputChain, "-d", resolverIP, "-p", "udp", "--dport", dnsPort, "-j", "DNAT", "--to-destination", laddr},
{"-t", "nat", "-I", postroutingchain, "-s", resolverIP, "-p", "udp", "--sport", ipPort, "-j", "SNAT", "--to-source", ":" + dnsPort},
{"-t", "nat", "-I", postroutingChain, "-s", resolverIP, "-p", "udp", "--sport", ipPort, "-j", "SNAT", "--to-source", ":" + dnsPort},
{"-t", "nat", "-I", outputChain, "-d", resolverIP, "-p", "tcp", "--dport", dnsPort, "-j", "DNAT", "--to-destination", ltcpaddr},
{"-t", "nat", "-I", postroutingchain, "-s", resolverIP, "-p", "tcp", "--sport", tcpPort, "-j", "SNAT", "--to-source", ":" + dnsPort},
{"-t", "nat", "-I", postroutingChain, "-s", resolverIP, "-p", "tcp", "--sport", tcpPort, "-j", "SNAT", "--to-source", ":" + dnsPort},
}
return r.backend.ExecFunc(func() {
var setupErr error
err := r.backend.ExecFunc(func() {
// TODO IPv6 support
iptable := iptables.GetIptable(iptables.IPv4)
// insert outputChain and postroutingchain
err := iptable.RawCombinedOutputNative("-t", "nat", "-C", "OUTPUT", "-d", resolverIP, "-j", outputChain)
if err == nil {
iptable.RawCombinedOutputNative("-t", "nat", "-F", outputChain)
if err := iptable.RawCombinedOutputNative("-t", "nat", "-F", outputChain); err != nil {
setupErr = err
return
}
} else {
iptable.RawCombinedOutputNative("-t", "nat", "-N", outputChain)
iptable.RawCombinedOutputNative("-t", "nat", "-I", "OUTPUT", "-d", resolverIP, "-j", outputChain)
if err := iptable.RawCombinedOutputNative("-t", "nat", "-N", outputChain); err != nil {
setupErr = err
return
}
if err := iptable.RawCombinedOutputNative("-t", "nat", "-I", "OUTPUT", "-d", resolverIP, "-j", outputChain); err != nil {
setupErr = err
return
}
}
err = iptable.RawCombinedOutputNative("-t", "nat", "-C", "POSTROUTING", "-d", resolverIP, "-j", postroutingchain)
err = iptable.RawCombinedOutputNative("-t", "nat", "-C", "POSTROUTING", "-d", resolverIP, "-j", postroutingChain)
if err == nil {
iptable.RawCombinedOutputNative("-t", "nat", "-F", postroutingchain)
if err := iptable.RawCombinedOutputNative("-t", "nat", "-F", postroutingChain); err != nil {
setupErr = err
return
}
} else {
iptable.RawCombinedOutputNative("-t", "nat", "-N", postroutingchain)
iptable.RawCombinedOutputNative("-t", "nat", "-I", "POSTROUTING", "-d", resolverIP, "-j", postroutingchain)
if err := iptable.RawCombinedOutputNative("-t", "nat", "-N", postroutingChain); err != nil {
setupErr = err
return
}
if err := iptable.RawCombinedOutputNative("-t", "nat", "-I", "POSTROUTING", "-d", resolverIP, "-j", postroutingChain); err != nil {
setupErr = err
return
}
}
for _, rule := range rules {
if iptable.RawCombinedOutputNative(rule...) != nil {
logrus.Errorf("set up rule failed, %v", rule)
setupErr = fmt.Errorf("set up rule failed, %v", rule)
return
}
}
})
if err != nil {
return err
}
return setupErr
}

View file

@ -3,6 +3,6 @@
package libnetwork
func (r *resolver) setupIPTable() error {
func (r *Resolver) setupIPTable() error {
return nil
}

View file

@ -38,7 +38,7 @@ type Sandbox struct {
extDNS []extDNSEntry
osSbox osl.Sandbox
controller *Controller
resolver Resolver
resolver *Resolver
resolverOnce sync.Once
endpoints []*Endpoint
epPriority map[string]int

View file

@ -91,7 +91,7 @@ require (
golang.org/x/sync v0.1.0
golang.org/x/sys v0.5.0
golang.org/x/text v0.7.0
golang.org/x/time v0.1.0
golang.org/x/time v0.3.0
google.golang.org/genproto v0.0.0-20220706185917-7780775163c4
google.golang.org/grpc v1.50.1
gotest.tools/v3 v3.4.0

View file

@ -1424,8 +1424,8 @@ golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxb
golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.1.0 h1:xYY+Bajn2a7VBmTM5GikTmnK8ZuX8YgnQCqZpbBNtmA=
golang.org/x/time v0.1.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20181011042414-1f849cf54d09/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

View file

@ -83,7 +83,7 @@ func (lim *Limiter) Burst() int {
// TokensAt returns the number of tokens available at time t.
func (lim *Limiter) TokensAt(t time.Time) float64 {
lim.mu.Lock()
_, _, tokens := lim.advance(t) // does not mutute lim
_, tokens := lim.advance(t) // does not mutate lim
lim.mu.Unlock()
return tokens
}
@ -183,7 +183,7 @@ func (r *Reservation) CancelAt(t time.Time) {
return
}
// advance time to now
t, _, tokens := r.lim.advance(t)
t, tokens := r.lim.advance(t)
// calculate new number of tokens
tokens += restoreTokens
if burst := float64(r.lim.burst); tokens > burst {
@ -304,7 +304,7 @@ func (lim *Limiter) SetLimitAt(t time.Time, newLimit Limit) {
lim.mu.Lock()
defer lim.mu.Unlock()
t, _, tokens := lim.advance(t)
t, tokens := lim.advance(t)
lim.last = t
lim.tokens = tokens
@ -321,7 +321,7 @@ func (lim *Limiter) SetBurstAt(t time.Time, newBurst int) {
lim.mu.Lock()
defer lim.mu.Unlock()
t, _, tokens := lim.advance(t)
t, tokens := lim.advance(t)
lim.last = t
lim.tokens = tokens
@ -356,7 +356,7 @@ func (lim *Limiter) reserveN(t time.Time, n int, maxFutureReserve time.Duration)
}
}
t, last, tokens := lim.advance(t)
t, tokens := lim.advance(t)
// Calculate the remaining number of tokens resulting from the request.
tokens -= float64(n)
@ -379,15 +379,11 @@ func (lim *Limiter) reserveN(t time.Time, n int, maxFutureReserve time.Duration)
if ok {
r.tokens = n
r.timeToAct = t.Add(waitDuration)
}
// Update state
if ok {
// Update state
lim.last = t
lim.tokens = tokens
lim.lastEvent = r.timeToAct
} else {
lim.last = last
}
return r
@ -396,7 +392,7 @@ func (lim *Limiter) reserveN(t time.Time, n int, maxFutureReserve time.Duration)
// advance calculates and returns an updated state for lim resulting from the passage of time.
// lim is not changed.
// advance requires that lim.mu is held.
func (lim *Limiter) advance(t time.Time) (newT time.Time, newLast time.Time, newTokens float64) {
func (lim *Limiter) advance(t time.Time) (newT time.Time, newTokens float64) {
last := lim.last
if t.Before(last) {
last = t
@ -409,7 +405,7 @@ func (lim *Limiter) advance(t time.Time) (newT time.Time, newLast time.Time, new
if burst := float64(lim.burst); tokens > burst {
tokens = burst
}
return t, last, tokens
return t, tokens
}
// durationFromTokens is a unit conversion function from the number of tokens to the duration

67
vendor/golang.org/x/time/rate/sometimes.go generated vendored Normal file
View file

@ -0,0 +1,67 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package rate
import (
"sync"
"time"
)
// Sometimes will perform an action occasionally. The First, Every, and
// Interval fields govern the behavior of Do, which performs the action.
// A zero Sometimes value will perform an action exactly once.
//
// # Example: logging with rate limiting
//
// var sometimes = rate.Sometimes{First: 3, Interval: 10*time.Second}
// func Spammy() {
// sometimes.Do(func() { log.Info("here I am!") })
// }
type Sometimes struct {
First int // if non-zero, the first N calls to Do will run f.
Every int // if non-zero, every Nth call to Do will run f.
Interval time.Duration // if non-zero and Interval has elapsed since f's last run, Do will run f.
mu sync.Mutex
count int // number of Do calls
last time.Time // last time f was run
}
// Do runs the function f as allowed by First, Every, and Interval.
//
// The model is a union (not intersection) of filters. The first call to Do
// always runs f. Subsequent calls to Do run f if allowed by First or Every or
// Interval.
//
// A non-zero First:N causes the first N Do(f) calls to run f.
//
// A non-zero Every:M causes every Mth Do(f) call, starting with the first, to
// run f.
//
// A non-zero Interval causes Do(f) to run f if Interval has elapsed since
// Do last ran f.
//
// Specifying multiple filters produces the union of these execution streams.
// For example, specifying both First:N and Every:M causes the first N Do(f)
// calls and every Mth Do(f) call, starting with the first, to run f. See
// Examples for more.
//
// If Do is called multiple times simultaneously, the calls will block and run
// serially. Therefore, Do is intended for lightweight operations.
//
// Because a call to Do may block until f returns, if f causes Do to be called,
// it will deadlock.
func (s *Sometimes) Do(f func()) {
s.mu.Lock()
defer s.mu.Unlock()
if s.count == 0 ||
(s.First > 0 && s.count < s.First) ||
(s.Every > 0 && s.count%s.Every == 0) ||
(s.Interval > 0 && time.Since(s.last) >= s.Interval) {
f()
s.last = time.Now()
}
s.count++
}

2
vendor/modules.txt vendored
View file

@ -1087,7 +1087,7 @@ golang.org/x/text/secure/bidirule
golang.org/x/text/transform
golang.org/x/text/unicode/bidi
golang.org/x/text/unicode/norm
# golang.org/x/time v0.1.0
# golang.org/x/time v0.3.0
## explicit
golang.org/x/time/rate
# google.golang.org/api v0.93.0