浏览代码

Merge pull request #1125 from sanimej/bugs

Fix a panic in handling forwarded queries
Jana Radhakrishnan 9 年之前
父节点
当前提交
4d59574cb3
共有 1 个文件被更改,包括 13 次插入9 次删除
  1. 13 9
      libnetwork/resolver.go

+ 13 - 9
libnetwork/resolver.go

@@ -292,6 +292,7 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
 		extConn net.Conn
 		resp    *dns.Msg
 		err     error
+		writer  dns.ResponseWriter
 	)
 
 	if query == nil || len(query.Question) == 0 {
@@ -329,7 +330,9 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
 		if resp.Len() > maxSize {
 			truncateResp(resp, maxSize, proto == "tcp")
 		}
+		writer = w
 	} else {
+		queryID := query.Id
 		for i := 0; i < maxExtDNS; i++ {
 			extDNS := &r.extDNSList[i]
 			if extDNS.ipStr == "" {
@@ -375,7 +378,7 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
 
 			// forwardQueryStart stores required context to mux multiple client queries over
 			// one connection; and limits the number of outstanding concurrent queries.
-			if r.forwardQueryStart(w, query) == false {
+			if r.forwardQueryStart(w, query, queryID) == false {
 				old := r.tStamp
 				r.tStamp = time.Now()
 				if r.tStamp.Sub(old) > logInterval {
@@ -405,32 +408,33 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
 
 			// Retrieves the context for the forwarded query and returns the client connection
 			// to send the reply to
-			w = r.forwardQueryEnd(w, resp)
-			if w == nil {
+			writer = r.forwardQueryEnd(w, resp)
+			if writer == nil {
 				continue
 			}
 
 			resp.Compress = true
 			break
 		}
-
-		if resp == nil || w == nil {
+		if resp == nil || writer == nil {
 			return
 		}
 	}
 
-	err = w.WriteMsg(resp)
-	if err != nil {
+	if writer == nil {
+		return
+	}
+	if err = writer.WriteMsg(resp); err != nil {
 		log.Errorf("error writing resolver resp, %s", err)
 	}
 }
 
-func (r *resolver) forwardQueryStart(w dns.ResponseWriter, msg *dns.Msg) bool {
+func (r *resolver) forwardQueryStart(w dns.ResponseWriter, msg *dns.Msg, queryID uint16) bool {
 	proto := w.LocalAddr().Network()
 	dnsID := uint16(rand.Intn(maxDNSID))
 
 	cc := clientConn{
-		dnsID:      msg.Id,
+		dnsID:      queryID,
 		respWriter: w,
 	}