|
@@ -30,9 +30,6 @@ type Resolver interface {
|
|
|
// SetExtServers configures the external nameservers the resolver
|
|
|
// should use to forward queries
|
|
|
SetExtServers([]string)
|
|
|
- // FlushExtServers clears the cached UDP connections to external
|
|
|
- // nameservers
|
|
|
- FlushExtServers()
|
|
|
// ResolverOptions returns resolv.conf options that should be set
|
|
|
ResolverOptions() []string
|
|
|
}
|
|
@@ -48,35 +45,12 @@ const (
|
|
|
defaultRespSize = 512
|
|
|
maxConcurrent = 100
|
|
|
logInterval = 2 * time.Second
|
|
|
- maxDNSID = 65536
|
|
|
)
|
|
|
|
|
|
-type clientConn struct {
|
|
|
- dnsID uint16
|
|
|
- respWriter dns.ResponseWriter
|
|
|
-}
|
|
|
-
|
|
|
type extDNSEntry struct {
|
|
|
- ipStr string
|
|
|
- extConn net.Conn
|
|
|
- extOnce sync.Once
|
|
|
-}
|
|
|
-
|
|
|
-type sboxQuery struct {
|
|
|
- sboxID string
|
|
|
- dnsID uint16
|
|
|
+ ipStr string
|
|
|
}
|
|
|
|
|
|
-type clientConnGC struct {
|
|
|
- toDelete bool
|
|
|
- client clientConn
|
|
|
-}
|
|
|
-
|
|
|
-var (
|
|
|
- queryGCMutex sync.Mutex
|
|
|
- queryGC map[sboxQuery]*clientConnGC
|
|
|
-)
|
|
|
-
|
|
|
// resolver implements the Resolver interface
|
|
|
type resolver struct {
|
|
|
sb *sandbox
|
|
@@ -89,34 +63,17 @@ type resolver struct {
|
|
|
count int32
|
|
|
tStamp time.Time
|
|
|
queryLock sync.Mutex
|
|
|
- client map[uint16]clientConn
|
|
|
}
|
|
|
|
|
|
func init() {
|
|
|
rand.Seed(time.Now().Unix())
|
|
|
- queryGC = make(map[sboxQuery]*clientConnGC)
|
|
|
- go func() {
|
|
|
- ticker := time.NewTicker(1 * time.Minute)
|
|
|
- for range ticker.C {
|
|
|
- queryGCMutex.Lock()
|
|
|
- for query, conn := range queryGC {
|
|
|
- if !conn.toDelete {
|
|
|
- conn.toDelete = true
|
|
|
- continue
|
|
|
- }
|
|
|
- delete(queryGC, query)
|
|
|
- }
|
|
|
- queryGCMutex.Unlock()
|
|
|
- }
|
|
|
- }()
|
|
|
}
|
|
|
|
|
|
// NewResolver creates a new instance of the Resolver
|
|
|
func NewResolver(sb *sandbox) Resolver {
|
|
|
return &resolver{
|
|
|
- sb: sb,
|
|
|
- err: fmt.Errorf("setup not done yet"),
|
|
|
- client: make(map[uint16]clientConn),
|
|
|
+ sb: sb,
|
|
|
+ err: fmt.Errorf("setup not done yet"),
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -173,20 +130,7 @@ func (r *resolver) Start() error {
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
-func (r *resolver) FlushExtServers() {
|
|
|
- for i := 0; i < maxExtDNS; i++ {
|
|
|
- if r.extDNSList[i].extConn != nil {
|
|
|
- r.extDNSList[i].extConn.Close()
|
|
|
- }
|
|
|
-
|
|
|
- r.extDNSList[i].extConn = nil
|
|
|
- r.extDNSList[i].extOnce = sync.Once{}
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
func (r *resolver) Stop() {
|
|
|
- r.FlushExtServers()
|
|
|
-
|
|
|
if r.server != nil {
|
|
|
r.server.Shutdown()
|
|
|
}
|
|
@@ -355,7 +299,6 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
|
|
|
extConn net.Conn
|
|
|
resp *dns.Msg
|
|
|
err error
|
|
|
- writer dns.ResponseWriter
|
|
|
)
|
|
|
|
|
|
if query == nil || len(query.Question) == 0 {
|
|
@@ -397,10 +340,7 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
|
|
|
if resp.Len() > maxSize {
|
|
|
truncateResp(resp, maxSize, proto == "tcp")
|
|
|
}
|
|
|
- writer = w
|
|
|
} else {
|
|
|
- queryID := query.Id
|
|
|
- extQueryLoop:
|
|
|
for i := 0; i < maxExtDNS; i++ {
|
|
|
extDNS := &r.extDNSList[i]
|
|
|
if extDNS.ipStr == "" {
|
|
@@ -411,30 +351,9 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
|
|
|
extConn, err = net.DialTimeout(proto, addr, extIOTimeout)
|
|
|
}
|
|
|
|
|
|
- // For udp clients connection is persisted to reuse for further queries.
|
|
|
- // Accessing extDNS.extConn be a race here between go rouines. Hence the
|
|
|
- // connection setup is done in a Once block and fetch the extConn again
|
|
|
- extConn = extDNS.extConn
|
|
|
- if extConn == nil || proto == "tcp" {
|
|
|
- if proto == "udp" {
|
|
|
- extDNS.extOnce.Do(func() {
|
|
|
- r.sb.execFunc(extConnect)
|
|
|
- extDNS.extConn = extConn
|
|
|
- })
|
|
|
- extConn = extDNS.extConn
|
|
|
- } else {
|
|
|
- r.sb.execFunc(extConnect)
|
|
|
- }
|
|
|
- if err != nil {
|
|
|
- log.Debugf("Connect failed, %s", err)
|
|
|
- continue
|
|
|
- }
|
|
|
- }
|
|
|
- // If two go routines are executing in parralel one will
|
|
|
- // block on the Once.Do and in case of error connecting
|
|
|
- // to the external server it will end up with a nil err
|
|
|
- // but extConn also being nil.
|
|
|
- if extConn == nil {
|
|
|
+ r.sb.execFunc(extConnect)
|
|
|
+ if err != nil {
|
|
|
+ log.Debugf("Connect failed, %s", err)
|
|
|
continue
|
|
|
}
|
|
|
log.Debugf("Query %s[%d] from %s, forwarding to %s:%s", name, query.Question[0].Qtype,
|
|
@@ -443,10 +362,10 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
|
|
|
// Timeout has to be set for every IO operation.
|
|
|
extConn.SetDeadline(time.Now().Add(extIOTimeout))
|
|
|
co := &dns.Conn{Conn: extConn}
|
|
|
+ defer co.Close()
|
|
|
|
|
|
- // forwardQueryStart stores required context to mux multiple client queries over
|
|
|
- // one connection; and limits the number of outstanding concurrent queries.
|
|
|
- if r.forwardQueryStart(w, query, queryID) == false {
|
|
|
+ // limits the number of outstanding concurrent queries.
|
|
|
+ if r.forwardQueryStart() == false {
|
|
|
old := r.tStamp
|
|
|
r.tStamp = time.Now()
|
|
|
if r.tStamp.Sub(old) > logInterval {
|
|
@@ -455,69 +374,38 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
|
|
|
continue
|
|
|
}
|
|
|
|
|
|
- defer func() {
|
|
|
- if proto == "tcp" {
|
|
|
- co.Close()
|
|
|
- }
|
|
|
- }()
|
|
|
err = co.WriteMsg(query)
|
|
|
if err != nil {
|
|
|
- r.forwardQueryEnd(w, query)
|
|
|
+ r.forwardQueryEnd()
|
|
|
log.Debugf("Send to DNS server failed, %s", err)
|
|
|
continue
|
|
|
}
|
|
|
- for {
|
|
|
- // If a reply comes after a read timeout it will remain in the socket buffer
|
|
|
- // and will be read after sending next query. To ignore such stale replies
|
|
|
- // save the query context in a GC queue when read timesout. On the next reply
|
|
|
- // if the context is present in the GC queue its a old reply. Ignore it and
|
|
|
- // read again
|
|
|
- resp, err = co.ReadMsg()
|
|
|
- if err != nil {
|
|
|
- if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
|
|
|
- r.addQueryToGC(w, query)
|
|
|
- }
|
|
|
- r.forwardQueryEnd(w, query)
|
|
|
- log.Debugf("Read from DNS server failed, %s", err)
|
|
|
- continue extQueryLoop
|
|
|
- }
|
|
|
|
|
|
- if !r.checkRespInGC(w, resp) {
|
|
|
- break
|
|
|
- }
|
|
|
- }
|
|
|
- // Retrieves the context for the forwarded query and returns the client connection
|
|
|
- // to send the reply to
|
|
|
- writer = r.forwardQueryEnd(w, resp)
|
|
|
- if writer == nil {
|
|
|
+ resp, err = co.ReadMsg()
|
|
|
+ // Truncated DNS replies should be sent to the client so that the
|
|
|
+ // client can retry over TCP
|
|
|
+ if err != nil && err != dns.ErrTruncated {
|
|
|
+ r.forwardQueryEnd()
|
|
|
+ log.Debugf("Read from DNS server failed, %s", err)
|
|
|
continue
|
|
|
}
|
|
|
|
|
|
+ r.forwardQueryEnd()
|
|
|
+
|
|
|
resp.Compress = true
|
|
|
break
|
|
|
}
|
|
|
- if resp == nil || writer == nil {
|
|
|
+ if resp == nil {
|
|
|
return
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- if writer == nil {
|
|
|
- return
|
|
|
- }
|
|
|
- if err = writer.WriteMsg(resp); err != nil {
|
|
|
+ if err = w.WriteMsg(resp); err != nil {
|
|
|
log.Errorf("error writing resolver resp, %s", err)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func (r *resolver) forwardQueryStart(w dns.ResponseWriter, msg *dns.Msg, queryID uint16) bool {
|
|
|
- proto := w.LocalAddr().Network()
|
|
|
- dnsID := uint16(rand.Intn(maxDNSID))
|
|
|
-
|
|
|
- cc := clientConn{
|
|
|
- dnsID: queryID,
|
|
|
- respWriter: w,
|
|
|
- }
|
|
|
-
|
|
|
+func (r *resolver) forwardQueryStart() bool {
|
|
|
r.queryLock.Lock()
|
|
|
defer r.queryLock.Unlock()
|
|
|
|
|
@@ -526,74 +414,10 @@ func (r *resolver) forwardQueryStart(w dns.ResponseWriter, msg *dns.Msg, queryID
|
|
|
}
|
|
|
r.count++
|
|
|
|
|
|
- switch proto {
|
|
|
- case "tcp":
|
|
|
- break
|
|
|
- case "udp":
|
|
|
- for ok := true; ok == true; dnsID = uint16(rand.Intn(maxDNSID)) {
|
|
|
- _, ok = r.client[dnsID]
|
|
|
- }
|
|
|
- log.Debugf("client dns id %v, changed id %v", queryID, dnsID)
|
|
|
- r.client[dnsID] = cc
|
|
|
- msg.Id = dnsID
|
|
|
- default:
|
|
|
- log.Errorf("Invalid protocol..")
|
|
|
- return false
|
|
|
- }
|
|
|
-
|
|
|
return true
|
|
|
}
|
|
|
|
|
|
-func (r *resolver) addQueryToGC(w dns.ResponseWriter, msg *dns.Msg) {
|
|
|
- if w.LocalAddr().Network() != "udp" {
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- r.queryLock.Lock()
|
|
|
- cc, ok := r.client[msg.Id]
|
|
|
- r.queryLock.Unlock()
|
|
|
- if !ok {
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- query := sboxQuery{
|
|
|
- sboxID: r.sb.ID(),
|
|
|
- dnsID: msg.Id,
|
|
|
- }
|
|
|
- clientGC := &clientConnGC{
|
|
|
- client: cc,
|
|
|
- }
|
|
|
- queryGCMutex.Lock()
|
|
|
- queryGC[query] = clientGC
|
|
|
- queryGCMutex.Unlock()
|
|
|
-}
|
|
|
-
|
|
|
-func (r *resolver) checkRespInGC(w dns.ResponseWriter, msg *dns.Msg) bool {
|
|
|
- if w.LocalAddr().Network() != "udp" {
|
|
|
- return false
|
|
|
- }
|
|
|
-
|
|
|
- query := sboxQuery{
|
|
|
- sboxID: r.sb.ID(),
|
|
|
- dnsID: msg.Id,
|
|
|
- }
|
|
|
-
|
|
|
- queryGCMutex.Lock()
|
|
|
- defer queryGCMutex.Unlock()
|
|
|
- if _, ok := queryGC[query]; ok {
|
|
|
- delete(queryGC, query)
|
|
|
- return true
|
|
|
- }
|
|
|
- return false
|
|
|
-}
|
|
|
-
|
|
|
-func (r *resolver) forwardQueryEnd(w dns.ResponseWriter, msg *dns.Msg) dns.ResponseWriter {
|
|
|
- var (
|
|
|
- cc clientConn
|
|
|
- ok bool
|
|
|
- )
|
|
|
- proto := w.LocalAddr().Network()
|
|
|
-
|
|
|
+func (r *resolver) forwardQueryEnd() {
|
|
|
r.queryLock.Lock()
|
|
|
defer r.queryLock.Unlock()
|
|
|
|
|
@@ -602,22 +426,4 @@ func (r *resolver) forwardQueryEnd(w dns.ResponseWriter, msg *dns.Msg) dns.Respo
|
|
|
} else {
|
|
|
r.count--
|
|
|
}
|
|
|
-
|
|
|
- switch proto {
|
|
|
- case "tcp":
|
|
|
- break
|
|
|
- case "udp":
|
|
|
- if cc, ok = r.client[msg.Id]; ok == false {
|
|
|
- log.Debugf("Can't retrieve client context for dns id %v", msg.Id)
|
|
|
- return nil
|
|
|
- }
|
|
|
- log.Debugf("dns msg id %v, client id %v", msg.Id, cc.dnsID)
|
|
|
- delete(r.client, msg.Id)
|
|
|
- msg.Id = cc.dnsID
|
|
|
- w = cc.respWriter
|
|
|
- default:
|
|
|
- log.Errorf("Invalid protocol")
|
|
|
- return nil
|
|
|
- }
|
|
|
- return w
|
|
|
}
|