|
@@ -59,12 +59,14 @@ func (w *tstwriter) GetResponse() *dns.Msg { return w.msg }
|
|
func (w *tstwriter) ClearResponse() { w.msg = nil }
|
|
func (w *tstwriter) ClearResponse() { w.msg = nil }
|
|
|
|
|
|
func checkNonNullResponse(t *testing.T, m *dns.Msg) {
|
|
func checkNonNullResponse(t *testing.T, m *dns.Msg) {
|
|
|
|
+ t.Helper()
|
|
if m == nil {
|
|
if m == nil {
|
|
t.Fatal("Null DNS response found. Non Null response msg expected.")
|
|
t.Fatal("Null DNS response found. Non Null response msg expected.")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
func checkDNSAnswersCount(t *testing.T, m *dns.Msg, expected int) {
|
|
func checkDNSAnswersCount(t *testing.T, m *dns.Msg, expected int) {
|
|
|
|
+ t.Helper()
|
|
answers := len(m.Answer)
|
|
answers := len(m.Answer)
|
|
if answers != expected {
|
|
if answers != expected {
|
|
t.Fatalf("Expected number of answers in response: %d. Found: %d", expected, answers)
|
|
t.Fatalf("Expected number of answers in response: %d. Found: %d", expected, answers)
|
|
@@ -72,12 +74,14 @@ func checkDNSAnswersCount(t *testing.T, m *dns.Msg, expected int) {
|
|
}
|
|
}
|
|
|
|
|
|
func checkDNSResponseCode(t *testing.T, m *dns.Msg, expected int) {
|
|
func checkDNSResponseCode(t *testing.T, m *dns.Msg, expected int) {
|
|
|
|
+ t.Helper()
|
|
if m.MsgHdr.Rcode != expected {
|
|
if m.MsgHdr.Rcode != expected {
|
|
t.Fatalf("Expected DNS response code: %d. Found: %d", expected, m.MsgHdr.Rcode)
|
|
t.Fatalf("Expected DNS response code: %d. Found: %d", expected, m.MsgHdr.Rcode)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
func checkDNSRRType(t *testing.T, actual, expected uint16) {
|
|
func checkDNSRRType(t *testing.T, actual, expected uint16) {
|
|
|
|
+ t.Helper()
|
|
if actual != expected {
|
|
if actual != expected {
|
|
t.Fatalf("Expected DNS Rrtype: %d. Found: %d", expected, actual)
|
|
t.Fatalf("Expected DNS Rrtype: %d. Found: %d", expected, actual)
|
|
}
|
|
}
|
|
@@ -373,13 +377,7 @@ func TestOversizedDNSReply(t *testing.T) {
|
|
|
|
|
|
// The resolver logs lots of valuable info at level debug. Redirect it
|
|
// The resolver logs lots of valuable info at level debug. Redirect it
|
|
// to t.Log() so the log spew is emitted only if the test fails.
|
|
// to t.Log() so the log spew is emitted only if the test fails.
|
|
- oldLevel, oldOut := logrus.StandardLogger().Level, logrus.StandardLogger().Out
|
|
|
|
- defer func() {
|
|
|
|
- logrus.StandardLogger().SetLevel(oldLevel)
|
|
|
|
- logrus.StandardLogger().SetOutput(oldOut)
|
|
|
|
- }()
|
|
|
|
- logrus.StandardLogger().SetLevel(logrus.DebugLevel)
|
|
|
|
- logrus.SetOutput(tlogWriter{t})
|
|
|
|
|
|
+ defer redirectLogrusTo(t)()
|
|
|
|
|
|
w := &tstwriter{localAddr: srv.LocalAddr()}
|
|
w := &tstwriter{localAddr: srv.LocalAddr()}
|
|
q := new(dns.Msg).SetQuestion("s3.amazonaws.com.", dns.TypeA)
|
|
q := new(dns.Msg).SetQuestion("s3.amazonaws.com.", dns.TypeA)
|
|
@@ -392,6 +390,16 @@ func TestOversizedDNSReply(t *testing.T) {
|
|
checkDNSRRType(t, resp.Answer[0].Header().Rrtype, dns.TypeA)
|
|
checkDNSRRType(t, resp.Answer[0].Header().Rrtype, dns.TypeA)
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+func redirectLogrusTo(t *testing.T) func() {
|
|
|
|
+ oldLevel, oldOut := logrus.StandardLogger().Level, logrus.StandardLogger().Out
|
|
|
|
+ logrus.StandardLogger().SetLevel(logrus.DebugLevel)
|
|
|
|
+ logrus.SetOutput(tlogWriter{t})
|
|
|
|
+ return func() {
|
|
|
|
+ logrus.StandardLogger().SetLevel(oldLevel)
|
|
|
|
+ logrus.StandardLogger().SetOutput(oldOut)
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
+
|
|
type tlogWriter struct{ t *testing.T }
|
|
type tlogWriter struct{ t *testing.T }
|
|
|
|
|
|
func (w tlogWriter) Write(p []byte) (n int, err error) {
|
|
func (w tlogWriter) Write(p []byte) (n int, err error) {
|
|
@@ -408,3 +416,22 @@ func (noopDNSBackend) ExecFunc(f func()) error { f(); return nil }
|
|
func (noopDNSBackend) NdotsSet() bool { return false }
|
|
func (noopDNSBackend) NdotsSet() bool { return false }
|
|
|
|
|
|
func (noopDNSBackend) HandleQueryResp(name string, ip net.IP) {}
|
|
func (noopDNSBackend) HandleQueryResp(name string, ip net.IP) {}
|
|
|
|
+
|
|
|
|
+func TestReplySERVFAILOnInternalError(t *testing.T) {
|
|
|
|
+ defer redirectLogrusTo(t)
|
|
|
|
+
|
|
|
|
+ rsv := NewResolver("", false, badSRVDNSBackend{}).(*resolver)
|
|
|
|
+ w := &tstwriter{}
|
|
|
|
+ q := new(dns.Msg).SetQuestion("_sip._tcp.example.com.", dns.TypeSRV)
|
|
|
|
+ rsv.ServeDNS(w, q)
|
|
|
|
+ resp := w.GetResponse()
|
|
|
|
+ checkNonNullResponse(t, resp)
|
|
|
|
+ t.Log("Response: ", resp.String())
|
|
|
|
+ checkDNSResponseCode(t, resp, dns.RcodeServerFailure)
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+type badSRVDNSBackend struct{ noopDNSBackend }
|
|
|
|
+
|
|
|
|
+func (badSRVDNSBackend) ResolveService(name string) ([]*net.SRV, []net.IP) {
|
|
|
|
+ return []*net.SRV{nil, nil, nil}, nil // Mismatched slice lengths
|
|
|
|
+}
|