瀏覽代碼

Merge pull request #1073 from sanimej/udp

Fix the handling for concurrent queries over UDP
Alessandro Boch 9 年之前
父節點
當前提交
90a1eb68e4
共有 1 個文件被更改,包括 76 次插入10 次删除
  1. 76 10
      libnetwork/resolver.go

+ 76 - 10
libnetwork/resolver.go

@@ -49,8 +49,14 @@ const (
 	defaultRespSize = 512
 	maxConcurrent   = 50
 	logInterval     = 2 * time.Second
+	maxDNSID        = 65536
 )
 
+type clientConn struct {
+	dnsID      uint16
+	respWriter dns.ResponseWriter
+}
+
 type extDNSEntry struct {
 	ipStr   string
 	extConn net.Conn
@@ -69,6 +75,7 @@ type resolver struct {
 	count      int32
 	tStamp     time.Time
 	queryLock  sync.Mutex
+	client     map[uint16]clientConn
 }
 
 func init() {
@@ -78,8 +85,9 @@ func init() {
 // NewResolver creates a new instance of the Resolver
 func NewResolver(sb *sandbox) Resolver {
 	return &resolver{
-		sb:  sb,
-		err: fmt.Errorf("setup not done yet"),
+		sb:     sb,
+		err:    fmt.Errorf("setup not done yet"),
+		client: make(map[uint16]clientConn),
 	}
 }
 
@@ -375,7 +383,9 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
 			extConn.SetDeadline(time.Now().Add(extIOTimeout))
 			co := &dns.Conn{Conn: extConn}
 
-			if r.concurrentQueryInc() == false {
+			// forwardQueryStart stores required context to mux multiple client queries over
+			// one connection; and limits the number of outstanding concurrent queries.
+			if r.forwardQueryStart(w, query) == false {
 				old := r.tStamp
 				r.tStamp = time.Now()
 				if r.tStamp.Sub(old) > logInterval {
@@ -391,18 +401,25 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
 			}()
 			err = co.WriteMsg(query)
 			if err != nil {
-				r.concurrentQueryDec()
+				r.forwardQueryEnd(w, query)
 				log.Debugf("Send to DNS server failed, %s", err)
 				continue
 			}
 
 			resp, err = co.ReadMsg()
-			r.concurrentQueryDec()
 			if err != nil {
+				r.forwardQueryEnd(w, query)
 				log.Debugf("Read from DNS server failed, %s", err)
 				continue
 			}
 
+			// Retrieves the context for the forwarded query and returns the client connection
+			// to send the reply to
+			w = r.forwardQueryEnd(w, resp)
+			if w == nil {
+				continue
+			}
+
 			resp.Compress = true
 			break
 		}
@@ -418,22 +435,71 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
 	}
 }
 
-func (r *resolver) concurrentQueryInc() bool {
+func (r *resolver) forwardQueryStart(w dns.ResponseWriter, msg *dns.Msg) bool {
+	proto := w.LocalAddr().Network()
+	dnsID := uint16(rand.Intn(maxDNSID))
+
+	cc := clientConn{
+		dnsID:      msg.Id,
+		respWriter: w,
+	}
+
 	r.queryLock.Lock()
 	defer r.queryLock.Unlock()
+
 	if r.count == maxConcurrent {
 		return false
 	}
 	r.count++
+
+	switch proto {
+	case "tcp":
+		break
+	case "udp":
+		for ok := true; ok == true; dnsID = uint16(rand.Intn(maxDNSID)) {
+			_, ok = r.client[dnsID]
+		}
+		log.Debugf("client dns id %v, changed id %v", msg.Id, dnsID)
+		r.client[dnsID] = cc
+		msg.Id = dnsID
+	default:
+		log.Errorf("Invalid protocol..")
+		return false
+	}
+
 	return true
 }
 
-func (r *resolver) concurrentQueryDec() bool {
+func (r *resolver) forwardQueryEnd(w dns.ResponseWriter, msg *dns.Msg) dns.ResponseWriter {
+	var (
+		cc clientConn
+		ok bool
+	)
+	proto := w.LocalAddr().Network()
+
 	r.queryLock.Lock()
 	defer r.queryLock.Unlock()
+
 	if r.count == 0 {
-		return false
+		log.Errorf("Invalid concurrent query count")
+	} else {
+		r.count--
 	}
-	r.count--
-	return true
+
+	switch proto {
+	case "tcp":
+		break
+	case "udp":
+		if cc, ok = r.client[msg.Id]; ok == false {
+			log.Debugf("Can't retrieve client context for dns id %v", msg.Id)
+			return nil
+		}
+		delete(r.client, msg.Id)
+		msg.Id = cc.dnsID
+		w = cc.respWriter
+	default:
+		log.Errorf("Invalid protocol")
+		return nil
+	}
+	return w
 }