Browse Source

Merge pull request #1062 from sanimej/fixes

Limit number of concurrent DNS queries
Alessandro Boch 9 năm trước cách đây
mục cha
commit
8be202014d
1 tập tin đã thay đổi với 41 bổ sung1 xóa
  1. 41 1
      libnetwork/resolver.go

+ 41 - 1
libnetwork/resolver.go

@@ -47,6 +47,8 @@ const (
 	maxExtDNS       = 3 //max number of external servers to try
 	extIOTimeout    = 3 * time.Second
 	defaultRespSize = 512
+	maxConcurrent   = 50
+	logInterval     = 2 * time.Second
 )
 
 type extDNSEntry struct {
@@ -64,6 +66,9 @@ type resolver struct {
 	tcpServer  *dns.Server
 	tcpListen  *net.TCPListener
 	err        error
+	count      int32
+	tStamp     time.Time
+	queryLock  sync.Mutex
 }
 
 func init() {
@@ -162,6 +167,9 @@ func (r *resolver) Stop() {
 	r.conn = nil
 	r.tcpServer = nil
 	r.err = fmt.Errorf("setup not done yet")
+	r.tStamp = time.Time{}
+	r.count = 0
+	r.queryLock = sync.Mutex{}
 }
 
 func (r *resolver) SetExtServers(dns []string) {
@@ -328,7 +336,8 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
 			if extDNS.ipStr == "" {
 				break
 			}
-			log.Debugf("Querying ext dns %s:%s for %s[%d]", proto, extDNS.ipStr, name, query.Question[0].Qtype)
+			log.Debugf("Query %s[%d] from %s, forwarding to %s:%s", name, query.Question[0].Qtype,
+				w.LocalAddr().String(), proto, extDNS.ipStr)
 
 			extConnect := func() {
 				addr := fmt.Sprintf("%s:%d", extDNS.ipStr, 53)
@@ -366,6 +375,15 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
 			extConn.SetDeadline(time.Now().Add(extIOTimeout))
 			co := &dns.Conn{Conn: extConn}
 
+			if r.concurrentQueryInc() == false {
+				old := r.tStamp
+				r.tStamp = time.Now()
+				if r.tStamp.Sub(old) > logInterval {
+					log.Errorf("More than %v concurrent queries from %s", maxConcurrent, w.LocalAddr().String())
+				}
+				continue
+			}
+
 			defer func() {
 				if proto == "tcp" {
 					co.Close()
@@ -373,11 +391,13 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
 			}()
 			err = co.WriteMsg(query)
 			if err != nil {
+				r.concurrentQueryDec()
 				log.Debugf("Send to DNS server failed, %s", err)
 				continue
 			}
 
 			resp, err = co.ReadMsg()
+			r.concurrentQueryDec()
 			if err != nil {
 				log.Debugf("Read from DNS server failed, %s", err)
 				continue
@@ -397,3 +417,23 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
 		log.Errorf("error writing resolver resp, %s", err)
 	}
 }
+
+func (r *resolver) concurrentQueryInc() bool {
+	r.queryLock.Lock()
+	defer r.queryLock.Unlock()
+	if r.count == maxConcurrent {
+		return false
+	}
+	r.count++
+	return true
+}
+
+func (r *resolver) concurrentQueryDec() bool {
+	r.queryLock.Lock()
+	defer r.queryLock.Unlock()
+	if r.count == 0 {
+		return false
+	}
+	r.count--
+	return true
+}