|
@@ -47,6 +47,8 @@ const (
|
|
maxExtDNS = 3 //max number of external servers to try
|
|
maxExtDNS = 3 //max number of external servers to try
|
|
extIOTimeout = 3 * time.Second
|
|
extIOTimeout = 3 * time.Second
|
|
defaultRespSize = 512
|
|
defaultRespSize = 512
|
|
|
|
+ maxConcurrent = 50
|
|
|
|
+ logInterval = 2 * time.Second
|
|
)
|
|
)
|
|
|
|
|
|
type extDNSEntry struct {
|
|
type extDNSEntry struct {
|
|
@@ -64,6 +66,9 @@ type resolver struct {
|
|
tcpServer *dns.Server
|
|
tcpServer *dns.Server
|
|
tcpListen *net.TCPListener
|
|
tcpListen *net.TCPListener
|
|
err error
|
|
err error
|
|
|
|
+ count int32
|
|
|
|
+ tStamp time.Time
|
|
|
|
+ queryLock sync.Mutex
|
|
}
|
|
}
|
|
|
|
|
|
func init() {
|
|
func init() {
|
|
@@ -162,6 +167,9 @@ func (r *resolver) Stop() {
|
|
r.conn = nil
|
|
r.conn = nil
|
|
r.tcpServer = nil
|
|
r.tcpServer = nil
|
|
r.err = fmt.Errorf("setup not done yet")
|
|
r.err = fmt.Errorf("setup not done yet")
|
|
|
|
+ r.tStamp = time.Time{}
|
|
|
|
+ r.count = 0
|
|
|
|
+ r.queryLock = sync.Mutex{}
|
|
}
|
|
}
|
|
|
|
|
|
func (r *resolver) SetExtServers(dns []string) {
|
|
func (r *resolver) SetExtServers(dns []string) {
|
|
@@ -328,7 +336,8 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
|
|
if extDNS.ipStr == "" {
|
|
if extDNS.ipStr == "" {
|
|
break
|
|
break
|
|
}
|
|
}
|
|
- log.Debugf("Querying ext dns %s:%s for %s[%d]", proto, extDNS.ipStr, name, query.Question[0].Qtype)
|
|
|
|
|
|
+ log.Debugf("Query %s[%d] from %s, forwarding to %s:%s", name, query.Question[0].Qtype,
|
|
|
|
+ w.LocalAddr().String(), proto, extDNS.ipStr)
|
|
|
|
|
|
extConnect := func() {
|
|
extConnect := func() {
|
|
addr := fmt.Sprintf("%s:%d", extDNS.ipStr, 53)
|
|
addr := fmt.Sprintf("%s:%d", extDNS.ipStr, 53)
|
|
@@ -366,6 +375,15 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
|
|
extConn.SetDeadline(time.Now().Add(extIOTimeout))
|
|
extConn.SetDeadline(time.Now().Add(extIOTimeout))
|
|
co := &dns.Conn{Conn: extConn}
|
|
co := &dns.Conn{Conn: extConn}
|
|
|
|
|
|
|
|
+ if r.concurrentQueryInc() == false {
|
|
|
|
+ old := r.tStamp
|
|
|
|
+ r.tStamp = time.Now()
|
|
|
|
+ if r.tStamp.Sub(old) > logInterval {
|
|
|
|
+ log.Errorf("More than %v concurrent queries from %s", maxConcurrent, w.LocalAddr().String())
|
|
|
|
+ }
|
|
|
|
+ continue
|
|
|
|
+ }
|
|
|
|
+
|
|
defer func() {
|
|
defer func() {
|
|
if proto == "tcp" {
|
|
if proto == "tcp" {
|
|
co.Close()
|
|
co.Close()
|
|
@@ -373,11 +391,13 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
|
|
}()
|
|
}()
|
|
err = co.WriteMsg(query)
|
|
err = co.WriteMsg(query)
|
|
if err != nil {
|
|
if err != nil {
|
|
|
|
+ r.concurrentQueryDec()
|
|
log.Debugf("Send to DNS server failed, %s", err)
|
|
log.Debugf("Send to DNS server failed, %s", err)
|
|
continue
|
|
continue
|
|
}
|
|
}
|
|
|
|
|
|
resp, err = co.ReadMsg()
|
|
resp, err = co.ReadMsg()
|
|
|
|
+ r.concurrentQueryDec()
|
|
if err != nil {
|
|
if err != nil {
|
|
log.Debugf("Read from DNS server failed, %s", err)
|
|
log.Debugf("Read from DNS server failed, %s", err)
|
|
continue
|
|
continue
|
|
@@ -397,3 +417,23 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
|
|
log.Errorf("error writing resolver resp, %s", err)
|
|
log.Errorf("error writing resolver resp, %s", err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
+
|
|
|
|
+func (r *resolver) concurrentQueryInc() bool {
|
|
|
|
+ r.queryLock.Lock()
|
|
|
|
+ defer r.queryLock.Unlock()
|
|
|
|
+ if r.count == maxConcurrent {
|
|
|
|
+ return false
|
|
|
|
+ }
|
|
|
|
+ r.count++
|
|
|
|
+ return true
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func (r *resolver) concurrentQueryDec() bool {
|
|
|
|
+ r.queryLock.Lock()
|
|
|
|
+ defer r.queryLock.Unlock()
|
|
|
|
+ if r.count == 0 {
|
|
|
|
+ return false
|
|
|
|
+ }
|
|
|
|
+ r.count--
|
|
|
|
+ return true
|
|
|
|
+}
|