Prechádzať zdrojové kódy

libnetwork: reply SERVFAIL on resolve error

...instead of silently dropping the DNS query.

Signed-off-by: Cory Snider <csnider@mirantis.com>
Cory Snider 2 rokov pred
rodič
commit
0bd30e90bb
2 zmenil súbory, kde vykonal 35 pridanie a 8 odobranie
  1. 1 1
      libnetwork/resolver.go
  2. 34 7
      libnetwork/resolver_test.go

+ 1 - 1
libnetwork/resolver.go

@@ -399,7 +399,7 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
 
 
 	if err != nil {
 	if err != nil {
 		logrus.WithError(err).Errorf("[resolver] failed to handle query: %s (%s)", queryName, dns.TypeToString[queryType])
 		logrus.WithError(err).Errorf("[resolver] failed to handle query: %s (%s)", queryName, dns.TypeToString[queryType])
-		return
+		resp = new(dns.Msg).SetRcode(query, dns.RcodeServerFailure)
 	}
 	}
 
 
 	if resp == nil {
 	if resp == nil {

+ 34 - 7
libnetwork/resolver_test.go

@@ -59,12 +59,14 @@ func (w *tstwriter) GetResponse() *dns.Msg { return w.msg }
 func (w *tstwriter) ClearResponse() { w.msg = nil }
 func (w *tstwriter) ClearResponse() { w.msg = nil }
 
 
 func checkNonNullResponse(t *testing.T, m *dns.Msg) {
 func checkNonNullResponse(t *testing.T, m *dns.Msg) {
+	t.Helper()
 	if m == nil {
 	if m == nil {
 		t.Fatal("Null DNS response found. Non Null response msg expected.")
 		t.Fatal("Null DNS response found. Non Null response msg expected.")
 	}
 	}
 }
 }
 
 
 func checkDNSAnswersCount(t *testing.T, m *dns.Msg, expected int) {
 func checkDNSAnswersCount(t *testing.T, m *dns.Msg, expected int) {
+	t.Helper()
 	answers := len(m.Answer)
 	answers := len(m.Answer)
 	if answers != expected {
 	if answers != expected {
 		t.Fatalf("Expected number of answers in response: %d. Found: %d", expected, answers)
 		t.Fatalf("Expected number of answers in response: %d. Found: %d", expected, answers)
@@ -72,12 +74,14 @@ func checkDNSAnswersCount(t *testing.T, m *dns.Msg, expected int) {
 }
 }
 
 
 func checkDNSResponseCode(t *testing.T, m *dns.Msg, expected int) {
 func checkDNSResponseCode(t *testing.T, m *dns.Msg, expected int) {
+	t.Helper()
 	if m.MsgHdr.Rcode != expected {
 	if m.MsgHdr.Rcode != expected {
 		t.Fatalf("Expected DNS response code: %d. Found: %d", expected, m.MsgHdr.Rcode)
 		t.Fatalf("Expected DNS response code: %d. Found: %d", expected, m.MsgHdr.Rcode)
 	}
 	}
 }
 }
 
 
 func checkDNSRRType(t *testing.T, actual, expected uint16) {
 func checkDNSRRType(t *testing.T, actual, expected uint16) {
+	t.Helper()
 	if actual != expected {
 	if actual != expected {
 		t.Fatalf("Expected DNS Rrtype: %d. Found: %d", expected, actual)
 		t.Fatalf("Expected DNS Rrtype: %d. Found: %d", expected, actual)
 	}
 	}
@@ -373,13 +377,7 @@ func TestOversizedDNSReply(t *testing.T) {
 
 
 	// The resolver logs lots of valuable info at level debug. Redirect it
 	// The resolver logs lots of valuable info at level debug. Redirect it
 	// to t.Log() so the log spew is emitted only if the test fails.
 	// to t.Log() so the log spew is emitted only if the test fails.
-	oldLevel, oldOut := logrus.StandardLogger().Level, logrus.StandardLogger().Out
-	defer func() {
-		logrus.StandardLogger().SetLevel(oldLevel)
-		logrus.StandardLogger().SetOutput(oldOut)
-	}()
-	logrus.StandardLogger().SetLevel(logrus.DebugLevel)
-	logrus.SetOutput(tlogWriter{t})
+	defer redirectLogrusTo(t)()
 
 
 	w := &tstwriter{localAddr: srv.LocalAddr()}
 	w := &tstwriter{localAddr: srv.LocalAddr()}
 	q := new(dns.Msg).SetQuestion("s3.amazonaws.com.", dns.TypeA)
 	q := new(dns.Msg).SetQuestion("s3.amazonaws.com.", dns.TypeA)
@@ -392,6 +390,16 @@ func TestOversizedDNSReply(t *testing.T) {
 	checkDNSRRType(t, resp.Answer[0].Header().Rrtype, dns.TypeA)
 	checkDNSRRType(t, resp.Answer[0].Header().Rrtype, dns.TypeA)
 }
 }
 
 
+func redirectLogrusTo(t *testing.T) func() {
+	oldLevel, oldOut := logrus.StandardLogger().Level, logrus.StandardLogger().Out
+	logrus.StandardLogger().SetLevel(logrus.DebugLevel)
+	logrus.SetOutput(tlogWriter{t})
+	return func() {
+		logrus.StandardLogger().SetLevel(oldLevel)
+		logrus.StandardLogger().SetOutput(oldOut)
+	}
+}
+
 type tlogWriter struct{ t *testing.T }
 type tlogWriter struct{ t *testing.T }
 
 
 func (w tlogWriter) Write(p []byte) (n int, err error) {
 func (w tlogWriter) Write(p []byte) (n int, err error) {
@@ -408,3 +416,22 @@ func (noopDNSBackend) ExecFunc(f func()) error { f(); return nil }
 func (noopDNSBackend) NdotsSet() bool { return false }
 func (noopDNSBackend) NdotsSet() bool { return false }
 
 
 func (noopDNSBackend) HandleQueryResp(name string, ip net.IP) {}
 func (noopDNSBackend) HandleQueryResp(name string, ip net.IP) {}
+
+func TestReplySERVFAILOnInternalError(t *testing.T) {
+	defer redirectLogrusTo(t)
+
+	rsv := NewResolver("", false, badSRVDNSBackend{}).(*resolver)
+	w := &tstwriter{}
+	q := new(dns.Msg).SetQuestion("_sip._tcp.example.com.", dns.TypeSRV)
+	rsv.ServeDNS(w, q)
+	resp := w.GetResponse()
+	checkNonNullResponse(t, resp)
+	t.Log("Response: ", resp.String())
+	checkDNSResponseCode(t, resp, dns.RcodeServerFailure)
+}
+
+type badSRVDNSBackend struct{ noopDNSBackend }
+
+func (badSRVDNSBackend) ResolveService(name string) ([]*net.SRV, []net.IP) {
+	return []*net.SRV{nil, nil, nil}, nil // Mismatched slice lengths
+}