|
@@ -49,8 +49,14 @@ const (
|
|
|
defaultRespSize = 512
|
|
|
maxConcurrent = 50
|
|
|
logInterval = 2 * time.Second
|
|
|
+ maxDNSID = 65536
|
|
|
)
|
|
|
|
|
|
+type clientConn struct {
|
|
|
+ dnsID uint16
|
|
|
+ respWriter dns.ResponseWriter
|
|
|
+}
|
|
|
+
|
|
|
type extDNSEntry struct {
|
|
|
ipStr string
|
|
|
extConn net.Conn
|
|
@@ -69,6 +75,7 @@ type resolver struct {
|
|
|
count int32
|
|
|
tStamp time.Time
|
|
|
queryLock sync.Mutex
|
|
|
+ client map[uint16]clientConn
|
|
|
}
|
|
|
|
|
|
func init() {
|
|
@@ -78,8 +85,9 @@ func init() {
|
|
|
// NewResolver creates a new instance of the Resolver
|
|
|
func NewResolver(sb *sandbox) Resolver {
|
|
|
return &resolver{
|
|
|
- sb: sb,
|
|
|
- err: fmt.Errorf("setup not done yet"),
|
|
|
+ sb: sb,
|
|
|
+ err: fmt.Errorf("setup not done yet"),
|
|
|
+ client: make(map[uint16]clientConn),
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -375,7 +383,9 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
|
|
|
extConn.SetDeadline(time.Now().Add(extIOTimeout))
|
|
|
co := &dns.Conn{Conn: extConn}
|
|
|
|
|
|
- if r.concurrentQueryInc() == false {
|
|
|
+ // forwardQueryStart stores required context to mux multiple client queries over
|
|
|
+ // one connection; and limits the number of outstanding concurrent queries.
|
|
|
+ if r.forwardQueryStart(w, query) == false {
|
|
|
old := r.tStamp
|
|
|
r.tStamp = time.Now()
|
|
|
if r.tStamp.Sub(old) > logInterval {
|
|
@@ -391,18 +401,25 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
|
|
|
}()
|
|
|
err = co.WriteMsg(query)
|
|
|
if err != nil {
|
|
|
- r.concurrentQueryDec()
|
|
|
+ r.forwardQueryEnd(w, query)
|
|
|
log.Debugf("Send to DNS server failed, %s", err)
|
|
|
continue
|
|
|
}
|
|
|
|
|
|
resp, err = co.ReadMsg()
|
|
|
- r.concurrentQueryDec()
|
|
|
if err != nil {
|
|
|
+ r.forwardQueryEnd(w, query)
|
|
|
log.Debugf("Read from DNS server failed, %s", err)
|
|
|
continue
|
|
|
}
|
|
|
|
|
|
+ // Retrieves the context for the forwarded query and returns the client connection
|
|
|
+ // to send the reply to
|
|
|
+ w = r.forwardQueryEnd(w, resp)
|
|
|
+ if w == nil {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
resp.Compress = true
|
|
|
break
|
|
|
}
|
|
@@ -418,22 +435,71 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func (r *resolver) concurrentQueryInc() bool {
|
|
|
+func (r *resolver) forwardQueryStart(w dns.ResponseWriter, msg *dns.Msg) bool {
|
|
|
+ proto := w.LocalAddr().Network()
|
|
|
+ dnsID := uint16(rand.Intn(maxDNSID))
|
|
|
+
|
|
|
+ cc := clientConn{
|
|
|
+ dnsID: msg.Id,
|
|
|
+ respWriter: w,
|
|
|
+ }
|
|
|
+
|
|
|
r.queryLock.Lock()
|
|
|
defer r.queryLock.Unlock()
|
|
|
+
|
|
|
if r.count == maxConcurrent {
|
|
|
return false
|
|
|
}
|
|
|
r.count++
|
|
|
+
|
|
|
+ switch proto {
|
|
|
+ case "tcp":
|
|
|
+ break
|
|
|
+ case "udp":
|
|
|
+ for ok := true; ok == true; dnsID = uint16(rand.Intn(maxDNSID)) {
|
|
|
+ _, ok = r.client[dnsID]
|
|
|
+ }
|
|
|
+ log.Debugf("client dns id %v, changed id %v", msg.Id, dnsID)
|
|
|
+ r.client[dnsID] = cc
|
|
|
+ msg.Id = dnsID
|
|
|
+ default:
|
|
|
+ log.Errorf("Invalid protocol..")
|
|
|
+ return false
|
|
|
+ }
|
|
|
+
|
|
|
return true
|
|
|
}
|
|
|
|
|
|
-func (r *resolver) concurrentQueryDec() bool {
|
|
|
+func (r *resolver) forwardQueryEnd(w dns.ResponseWriter, msg *dns.Msg) dns.ResponseWriter {
|
|
|
+ var (
|
|
|
+ cc clientConn
|
|
|
+ ok bool
|
|
|
+ )
|
|
|
+ proto := w.LocalAddr().Network()
|
|
|
+
|
|
|
r.queryLock.Lock()
|
|
|
defer r.queryLock.Unlock()
|
|
|
+
|
|
|
if r.count == 0 {
|
|
|
- return false
|
|
|
+ log.Errorf("Invalid concurrent query count")
|
|
|
+ } else {
|
|
|
+ r.count--
|
|
|
}
|
|
|
- r.count--
|
|
|
- return true
|
|
|
+
|
|
|
+ switch proto {
|
|
|
+ case "tcp":
|
|
|
+ break
|
|
|
+ case "udp":
|
|
|
+ if cc, ok = r.client[msg.Id]; ok == false {
|
|
|
+ log.Debugf("Can't retrieve client context for dns id %v", msg.Id)
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ delete(r.client, msg.Id)
|
|
|
+ msg.Id = cc.dnsID
|
|
|
+ w = cc.respWriter
|
|
|
+ default:
|
|
|
+ log.Errorf("Invalid protocol")
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ return w
|
|
|
}
|