|
@@ -3,30 +3,40 @@
|
|
|
package dns
|
|
|
|
|
|
import (
|
|
|
- "bytes"
|
|
|
+ "context"
|
|
|
"crypto/tls"
|
|
|
"encoding/binary"
|
|
|
+ "errors"
|
|
|
"io"
|
|
|
"net"
|
|
|
+ "strings"
|
|
|
"sync"
|
|
|
- "sync/atomic"
|
|
|
"time"
|
|
|
)
|
|
|
|
|
|
// Default maximum number of TCP queries before we close the socket.
|
|
|
const maxTCPQueries = 128
|
|
|
|
|
|
-// Interval for stop worker if no load
|
|
|
-const idleWorkerTimeout = 10 * time.Second
|
|
|
-
|
|
|
-// Maximum number of workers
|
|
|
-const maxWorkersCount = 10000
|
|
|
+// aLongTimeAgo is a non-zero time, far in the past, used for
|
|
|
+// immediate cancelation of network operations.
|
|
|
+var aLongTimeAgo = time.Unix(1, 0)
|
|
|
|
|
|
// Handler is implemented by any value that implements ServeDNS.
|
|
|
type Handler interface {
|
|
|
ServeDNS(w ResponseWriter, r *Msg)
|
|
|
}
|
|
|
|
|
|
+// The HandlerFunc type is an adapter to allow the use of
|
|
|
+// ordinary functions as DNS handlers. If f is a function
|
|
|
+// with the appropriate signature, HandlerFunc(f) is a
|
|
|
+// Handler object that calls f.
|
|
|
+type HandlerFunc func(ResponseWriter, *Msg)
|
|
|
+
|
|
|
+// ServeDNS calls f(w, r).
|
|
|
+func (f HandlerFunc) ServeDNS(w ResponseWriter, r *Msg) {
|
|
|
+ f(w, r)
|
|
|
+}
|
|
|
+
|
|
|
// A ResponseWriter interface is used by an DNS handler to
|
|
|
// construct an DNS response.
|
|
|
type ResponseWriter interface {
|
|
@@ -49,11 +59,17 @@ type ResponseWriter interface {
|
|
|
Hijack()
|
|
|
}
|
|
|
|
|
|
+// A ConnectionStater interface is used by a DNS Handler to access TLS connection state
|
|
|
+// when available.
|
|
|
+type ConnectionStater interface {
|
|
|
+ ConnectionState() *tls.ConnectionState
|
|
|
+}
|
|
|
+
|
|
|
type response struct {
|
|
|
- msg []byte
|
|
|
+ closed bool // connection has been closed
|
|
|
hijacked bool // connection has been hijacked by handler
|
|
|
- tsigStatus error
|
|
|
tsigTimersOnly bool
|
|
|
+ tsigStatus error
|
|
|
tsigRequestMAC string
|
|
|
tsigSecret map[string]string // the tsig secrets
|
|
|
udp *net.UDPConn // i/o connection if UDP was used
|
|
@@ -62,35 +78,6 @@ type response struct {
|
|
|
writer Writer // writer to output the raw DNS bits
|
|
|
}
|
|
|
|
|
|
-// ServeMux is an DNS request multiplexer. It matches the
|
|
|
-// zone name of each incoming request against a list of
|
|
|
-// registered patterns add calls the handler for the pattern
|
|
|
-// that most closely matches the zone name. ServeMux is DNSSEC aware, meaning
|
|
|
-// that queries for the DS record are redirected to the parent zone (if that
|
|
|
-// is also registered), otherwise the child gets the query.
|
|
|
-// ServeMux is also safe for concurrent access from multiple goroutines.
|
|
|
-type ServeMux struct {
|
|
|
- z map[string]Handler
|
|
|
- m *sync.RWMutex
|
|
|
-}
|
|
|
-
|
|
|
-// NewServeMux allocates and returns a new ServeMux.
|
|
|
-func NewServeMux() *ServeMux { return &ServeMux{z: make(map[string]Handler), m: new(sync.RWMutex)} }
|
|
|
-
|
|
|
-// DefaultServeMux is the default ServeMux used by Serve.
|
|
|
-var DefaultServeMux = NewServeMux()
|
|
|
-
|
|
|
-// The HandlerFunc type is an adapter to allow the use of
|
|
|
-// ordinary functions as DNS handlers. If f is a function
|
|
|
-// with the appropriate signature, HandlerFunc(f) is a
|
|
|
-// Handler object that calls f.
|
|
|
-type HandlerFunc func(ResponseWriter, *Msg)
|
|
|
-
|
|
|
-// ServeDNS calls f(w, r).
|
|
|
-func (f HandlerFunc) ServeDNS(w ResponseWriter, r *Msg) {
|
|
|
- f(w, r)
|
|
|
-}
|
|
|
-
|
|
|
// HandleFailed returns a HandlerFunc that returns SERVFAIL for every request it gets.
|
|
|
func HandleFailed(w ResponseWriter, r *Msg) {
|
|
|
m := new(Msg)
|
|
@@ -99,8 +86,6 @@ func HandleFailed(w ResponseWriter, r *Msg) {
|
|
|
w.WriteMsg(m)
|
|
|
}
|
|
|
|
|
|
-func failedHandler() Handler { return HandlerFunc(HandleFailed) }
|
|
|
-
|
|
|
// ListenAndServe Starts a server on address and network specified Invoke handler
|
|
|
// for incoming queries.
|
|
|
func ListenAndServe(addr string, network string, handler Handler) error {
|
|
@@ -139,99 +124,6 @@ func ActivateAndServe(l net.Listener, p net.PacketConn, handler Handler) error {
|
|
|
return server.ActivateAndServe()
|
|
|
}
|
|
|
|
|
|
-func (mux *ServeMux) match(q string, t uint16) Handler {
|
|
|
- mux.m.RLock()
|
|
|
- defer mux.m.RUnlock()
|
|
|
- var handler Handler
|
|
|
- b := make([]byte, len(q)) // worst case, one label of length q
|
|
|
- off := 0
|
|
|
- end := false
|
|
|
- for {
|
|
|
- l := len(q[off:])
|
|
|
- for i := 0; i < l; i++ {
|
|
|
- b[i] = q[off+i]
|
|
|
- if b[i] >= 'A' && b[i] <= 'Z' {
|
|
|
- b[i] |= ('a' - 'A')
|
|
|
- }
|
|
|
- }
|
|
|
- if h, ok := mux.z[string(b[:l])]; ok { // causes garbage, might want to change the map key
|
|
|
- if t != TypeDS {
|
|
|
- return h
|
|
|
- }
|
|
|
- // Continue for DS to see if we have a parent too, if so delegeate to the parent
|
|
|
- handler = h
|
|
|
- }
|
|
|
- off, end = NextLabel(q, off)
|
|
|
- if end {
|
|
|
- break
|
|
|
- }
|
|
|
- }
|
|
|
- // Wildcard match, if we have found nothing try the root zone as a last resort.
|
|
|
- if h, ok := mux.z["."]; ok {
|
|
|
- return h
|
|
|
- }
|
|
|
- return handler
|
|
|
-}
|
|
|
-
|
|
|
-// Handle adds a handler to the ServeMux for pattern.
|
|
|
-func (mux *ServeMux) Handle(pattern string, handler Handler) {
|
|
|
- if pattern == "" {
|
|
|
- panic("dns: invalid pattern " + pattern)
|
|
|
- }
|
|
|
- mux.m.Lock()
|
|
|
- mux.z[Fqdn(pattern)] = handler
|
|
|
- mux.m.Unlock()
|
|
|
-}
|
|
|
-
|
|
|
-// HandleFunc adds a handler function to the ServeMux for pattern.
|
|
|
-func (mux *ServeMux) HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) {
|
|
|
- mux.Handle(pattern, HandlerFunc(handler))
|
|
|
-}
|
|
|
-
|
|
|
-// HandleRemove deregistrars the handler specific for pattern from the ServeMux.
|
|
|
-func (mux *ServeMux) HandleRemove(pattern string) {
|
|
|
- if pattern == "" {
|
|
|
- panic("dns: invalid pattern " + pattern)
|
|
|
- }
|
|
|
- mux.m.Lock()
|
|
|
- delete(mux.z, Fqdn(pattern))
|
|
|
- mux.m.Unlock()
|
|
|
-}
|
|
|
-
|
|
|
-// ServeDNS dispatches the request to the handler whose
|
|
|
-// pattern most closely matches the request message. If DefaultServeMux
|
|
|
-// is used the correct thing for DS queries is done: a possible parent
|
|
|
-// is sought.
|
|
|
-// If no handler is found a standard SERVFAIL message is returned
|
|
|
-// If the request message does not have exactly one question in the
|
|
|
-// question section a SERVFAIL is returned, unlesss Unsafe is true.
|
|
|
-func (mux *ServeMux) ServeDNS(w ResponseWriter, request *Msg) {
|
|
|
- var h Handler
|
|
|
- if len(request.Question) < 1 { // allow more than one question
|
|
|
- h = failedHandler()
|
|
|
- } else {
|
|
|
- if h = mux.match(request.Question[0].Name, request.Question[0].Qtype); h == nil {
|
|
|
- h = failedHandler()
|
|
|
- }
|
|
|
- }
|
|
|
- h.ServeDNS(w, request)
|
|
|
-}
|
|
|
-
|
|
|
-// Handle registers the handler with the given pattern
|
|
|
-// in the DefaultServeMux. The documentation for
|
|
|
-// ServeMux explains how patterns are matched.
|
|
|
-func Handle(pattern string, handler Handler) { DefaultServeMux.Handle(pattern, handler) }
|
|
|
-
|
|
|
-// HandleRemove deregisters the handle with the given pattern
|
|
|
-// in the DefaultServeMux.
|
|
|
-func HandleRemove(pattern string) { DefaultServeMux.HandleRemove(pattern) }
|
|
|
-
|
|
|
-// HandleFunc registers the handler function with the given pattern
|
|
|
-// in the DefaultServeMux.
|
|
|
-func HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) {
|
|
|
- DefaultServeMux.HandleFunc(pattern, handler)
|
|
|
-}
|
|
|
-
|
|
|
// Writer writes raw DNS messages; each call to Write should send an entire message.
|
|
|
type Writer interface {
|
|
|
io.Writer
|
|
@@ -253,11 +145,11 @@ type defaultReader struct {
|
|
|
*Server
|
|
|
}
|
|
|
|
|
|
-func (dr *defaultReader) ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
|
|
|
+func (dr defaultReader) ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
|
|
|
return dr.readTCP(conn, timeout)
|
|
|
}
|
|
|
|
|
|
-func (dr *defaultReader) ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) {
|
|
|
+func (dr defaultReader) ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) {
|
|
|
return dr.readUDP(conn, timeout)
|
|
|
}
|
|
|
|
|
@@ -294,9 +186,6 @@ type Server struct {
|
|
|
IdleTimeout func() time.Duration
|
|
|
// Secret(s) for Tsig map[<zonename>]<base64 secret>. The zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2).
|
|
|
TsigSecret map[string]string
|
|
|
- // Unsafe instructs the server to disregard any sanity checks and directly hand the message to
|
|
|
- // the handler. It will specifically not check if the query has the QR bit not set.
|
|
|
- Unsafe bool
|
|
|
// If NotifyStartedFunc is set it is called once the server has started listening.
|
|
|
NotifyStartedFunc func()
|
|
|
// DecorateReader is optional, allows customization of the process that reads raw DNS messages.
|
|
@@ -305,65 +194,64 @@ type Server struct {
|
|
|
DecorateWriter DecorateWriter
|
|
|
// Maximum number of TCP queries before we close the socket. Default is maxTCPQueries (unlimited if -1).
|
|
|
MaxTCPQueries int
|
|
|
+ // Whether to set the SO_REUSEPORT socket option, allowing multiple listeners to be bound to a single address.
|
|
|
+ // It is only supported on go1.11+ and when using ListenAndServe.
|
|
|
+ ReusePort bool
|
|
|
+ // AcceptMsgFunc will check the incoming message and will reject it early in the process.
|
|
|
+ // By default DefaultMsgAcceptFunc will be used.
|
|
|
+ MsgAcceptFunc MsgAcceptFunc
|
|
|
|
|
|
- // UDP packet or TCP connection queue
|
|
|
- queue chan *response
|
|
|
- // Workers count
|
|
|
- workersCount int32
|
|
|
// Shutdown handling
|
|
|
- lock sync.RWMutex
|
|
|
- started bool
|
|
|
+ lock sync.RWMutex
|
|
|
+ started bool
|
|
|
+ shutdown chan struct{}
|
|
|
+ conns map[net.Conn]struct{}
|
|
|
+
|
|
|
+ // A pool for UDP message buffers.
|
|
|
+ udpPool sync.Pool
|
|
|
}
|
|
|
|
|
|
-func (srv *Server) worker(w *response) {
|
|
|
- srv.serve(w)
|
|
|
+func (srv *Server) isStarted() bool {
|
|
|
+ srv.lock.RLock()
|
|
|
+ started := srv.started
|
|
|
+ srv.lock.RUnlock()
|
|
|
+ return started
|
|
|
+}
|
|
|
|
|
|
- for {
|
|
|
- count := atomic.LoadInt32(&srv.workersCount)
|
|
|
- if count > maxWorkersCount {
|
|
|
- return
|
|
|
- }
|
|
|
- if atomic.CompareAndSwapInt32(&srv.workersCount, count, count+1) {
|
|
|
- break
|
|
|
- }
|
|
|
+func makeUDPBuffer(size int) func() interface{} {
|
|
|
+ return func() interface{} {
|
|
|
+ return make([]byte, size)
|
|
|
}
|
|
|
+}
|
|
|
|
|
|
- defer atomic.AddInt32(&srv.workersCount, -1)
|
|
|
+func (srv *Server) init() {
|
|
|
+ srv.shutdown = make(chan struct{})
|
|
|
+ srv.conns = make(map[net.Conn]struct{})
|
|
|
|
|
|
- inUse := false
|
|
|
- timeout := time.NewTimer(idleWorkerTimeout)
|
|
|
- defer timeout.Stop()
|
|
|
-LOOP:
|
|
|
- for {
|
|
|
- select {
|
|
|
- case w, ok := <-srv.queue:
|
|
|
- if !ok {
|
|
|
- break LOOP
|
|
|
- }
|
|
|
- inUse = true
|
|
|
- srv.serve(w)
|
|
|
- case <-timeout.C:
|
|
|
- if !inUse {
|
|
|
- break LOOP
|
|
|
- }
|
|
|
- inUse = false
|
|
|
- timeout.Reset(idleWorkerTimeout)
|
|
|
- }
|
|
|
+ if srv.UDPSize == 0 {
|
|
|
+ srv.UDPSize = MinMsgSize
|
|
|
+ }
|
|
|
+ if srv.MsgAcceptFunc == nil {
|
|
|
+ srv.MsgAcceptFunc = DefaultMsgAcceptFunc
|
|
|
}
|
|
|
+ if srv.Handler == nil {
|
|
|
+ srv.Handler = DefaultServeMux
|
|
|
+ }
|
|
|
+
|
|
|
+ srv.udpPool.New = makeUDPBuffer(srv.UDPSize)
|
|
|
}
|
|
|
|
|
|
-func (srv *Server) spawnWorker(w *response) {
|
|
|
- select {
|
|
|
- case srv.queue <- w:
|
|
|
- default:
|
|
|
- go srv.worker(w)
|
|
|
- }
|
|
|
+func unlockOnce(l sync.Locker) func() {
|
|
|
+ var once sync.Once
|
|
|
+ return func() { once.Do(l.Unlock) }
|
|
|
}
|
|
|
|
|
|
// ListenAndServe starts a nameserver on the configured address in *Server.
|
|
|
func (srv *Server) ListenAndServe() error {
|
|
|
+ unlock := unlockOnce(&srv.lock)
|
|
|
srv.lock.Lock()
|
|
|
- defer srv.lock.Unlock()
|
|
|
+ defer unlock()
|
|
|
+
|
|
|
if srv.started {
|
|
|
return &Error{err: "server already started"}
|
|
|
}
|
|
@@ -372,63 +260,46 @@ func (srv *Server) ListenAndServe() error {
|
|
|
if addr == "" {
|
|
|
addr = ":domain"
|
|
|
}
|
|
|
- if srv.UDPSize == 0 {
|
|
|
- srv.UDPSize = MinMsgSize
|
|
|
- }
|
|
|
- srv.queue = make(chan *response)
|
|
|
- defer close(srv.queue)
|
|
|
+
|
|
|
+ srv.init()
|
|
|
+
|
|
|
switch srv.Net {
|
|
|
case "tcp", "tcp4", "tcp6":
|
|
|
- a, err := net.ResolveTCPAddr(srv.Net, addr)
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
- l, err := net.ListenTCP(srv.Net, a)
|
|
|
+ l, err := listenTCP(srv.Net, addr, srv.ReusePort)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
srv.Listener = l
|
|
|
srv.started = true
|
|
|
- srv.lock.Unlock()
|
|
|
- err = srv.serveTCP(l)
|
|
|
- srv.lock.Lock() // to satisfy the defer at the top
|
|
|
- return err
|
|
|
+ unlock()
|
|
|
+ return srv.serveTCP(l)
|
|
|
case "tcp-tls", "tcp4-tls", "tcp6-tls":
|
|
|
- network := "tcp"
|
|
|
- if srv.Net == "tcp4-tls" {
|
|
|
- network = "tcp4"
|
|
|
- } else if srv.Net == "tcp6-tls" {
|
|
|
- network = "tcp6"
|
|
|
+ if srv.TLSConfig == nil || (len(srv.TLSConfig.Certificates) == 0 && srv.TLSConfig.GetCertificate == nil) {
|
|
|
+ return errors.New("dns: neither Certificates nor GetCertificate set in Config")
|
|
|
}
|
|
|
-
|
|
|
- l, err := tls.Listen(network, addr, srv.TLSConfig)
|
|
|
+ network := strings.TrimSuffix(srv.Net, "-tls")
|
|
|
+ l, err := listenTCP(network, addr, srv.ReusePort)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
+ l = tls.NewListener(l, srv.TLSConfig)
|
|
|
srv.Listener = l
|
|
|
srv.started = true
|
|
|
- srv.lock.Unlock()
|
|
|
- err = srv.serveTCP(l)
|
|
|
- srv.lock.Lock() // to satisfy the defer at the top
|
|
|
- return err
|
|
|
+ unlock()
|
|
|
+ return srv.serveTCP(l)
|
|
|
case "udp", "udp4", "udp6":
|
|
|
- a, err := net.ResolveUDPAddr(srv.Net, addr)
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
- l, err := net.ListenUDP(srv.Net, a)
|
|
|
+ l, err := listenUDP(srv.Net, addr, srv.ReusePort)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
- if e := setUDPSocketOptions(l); e != nil {
|
|
|
+ u := l.(*net.UDPConn)
|
|
|
+ if e := setUDPSocketOptions(u); e != nil {
|
|
|
return e
|
|
|
}
|
|
|
srv.PacketConn = l
|
|
|
srv.started = true
|
|
|
- srv.lock.Unlock()
|
|
|
- err = srv.serveUDP(l)
|
|
|
- srv.lock.Lock() // to satisfy the defer at the top
|
|
|
- return err
|
|
|
+ unlock()
|
|
|
+ return srv.serveUDP(u)
|
|
|
}
|
|
|
return &Error{err: "bad network"}
|
|
|
}
|
|
@@ -436,20 +307,19 @@ func (srv *Server) ListenAndServe() error {
|
|
|
// ActivateAndServe starts a nameserver with the PacketConn or Listener
|
|
|
// configured in *Server. Its main use is to start a server from systemd.
|
|
|
func (srv *Server) ActivateAndServe() error {
|
|
|
+ unlock := unlockOnce(&srv.lock)
|
|
|
srv.lock.Lock()
|
|
|
- defer srv.lock.Unlock()
|
|
|
+ defer unlock()
|
|
|
+
|
|
|
if srv.started {
|
|
|
return &Error{err: "server already started"}
|
|
|
}
|
|
|
|
|
|
+ srv.init()
|
|
|
+
|
|
|
pConn := srv.PacketConn
|
|
|
l := srv.Listener
|
|
|
- srv.queue = make(chan *response)
|
|
|
- defer close(srv.queue)
|
|
|
if pConn != nil {
|
|
|
- if srv.UDPSize == 0 {
|
|
|
- srv.UDPSize = MinMsgSize
|
|
|
- }
|
|
|
// Check PacketConn interface's type is valid and value
|
|
|
// is not nil
|
|
|
if t, ok := pConn.(*net.UDPConn); ok && t != nil {
|
|
@@ -457,18 +327,14 @@ func (srv *Server) ActivateAndServe() error {
|
|
|
return e
|
|
|
}
|
|
|
srv.started = true
|
|
|
- srv.lock.Unlock()
|
|
|
- e := srv.serveUDP(t)
|
|
|
- srv.lock.Lock() // to satisfy the defer at the top
|
|
|
- return e
|
|
|
+ unlock()
|
|
|
+ return srv.serveUDP(t)
|
|
|
}
|
|
|
}
|
|
|
if l != nil {
|
|
|
srv.started = true
|
|
|
- srv.lock.Unlock()
|
|
|
- e := srv.serveTCP(l)
|
|
|
- srv.lock.Lock() // to satisfy the defer at the top
|
|
|
- return e
|
|
|
+ unlock()
|
|
|
+ return srv.serveTCP(l)
|
|
|
}
|
|
|
return &Error{err: "bad listeners"}
|
|
|
}
|
|
@@ -476,30 +342,63 @@ func (srv *Server) ActivateAndServe() error {
|
|
|
// Shutdown shuts down a server. After a call to Shutdown, ListenAndServe and
|
|
|
// ActivateAndServe will return.
|
|
|
func (srv *Server) Shutdown() error {
|
|
|
+ return srv.ShutdownContext(context.Background())
|
|
|
+}
|
|
|
+
|
|
|
+// ShutdownContext shuts down a server. After a call to ShutdownContext,
|
|
|
+// ListenAndServe and ActivateAndServe will return.
|
|
|
+//
|
|
|
+// A context.Context may be passed to limit how long to wait for connections
|
|
|
+// to terminate.
|
|
|
+func (srv *Server) ShutdownContext(ctx context.Context) error {
|
|
|
srv.lock.Lock()
|
|
|
if !srv.started {
|
|
|
srv.lock.Unlock()
|
|
|
return &Error{err: "server not started"}
|
|
|
}
|
|
|
+
|
|
|
srv.started = false
|
|
|
- srv.lock.Unlock()
|
|
|
|
|
|
if srv.PacketConn != nil {
|
|
|
- srv.PacketConn.Close()
|
|
|
+ srv.PacketConn.SetReadDeadline(aLongTimeAgo) // Unblock reads
|
|
|
}
|
|
|
+
|
|
|
if srv.Listener != nil {
|
|
|
srv.Listener.Close()
|
|
|
}
|
|
|
- return nil
|
|
|
+
|
|
|
+ for rw := range srv.conns {
|
|
|
+ rw.SetReadDeadline(aLongTimeAgo) // Unblock reads
|
|
|
+ }
|
|
|
+
|
|
|
+ srv.lock.Unlock()
|
|
|
+
|
|
|
+ if testShutdownNotify != nil {
|
|
|
+ testShutdownNotify.Broadcast()
|
|
|
+ }
|
|
|
+
|
|
|
+ var ctxErr error
|
|
|
+ select {
|
|
|
+ case <-srv.shutdown:
|
|
|
+ case <-ctx.Done():
|
|
|
+ ctxErr = ctx.Err()
|
|
|
+ }
|
|
|
+
|
|
|
+ if srv.PacketConn != nil {
|
|
|
+ srv.PacketConn.Close()
|
|
|
+ }
|
|
|
+
|
|
|
+ return ctxErr
|
|
|
}
|
|
|
|
|
|
+var testShutdownNotify *sync.Cond
|
|
|
+
|
|
|
// getReadTimeout is a helper func to use system timeout if server did not intend to change it.
|
|
|
func (srv *Server) getReadTimeout() time.Duration {
|
|
|
- rtimeout := dnsTimeout
|
|
|
if srv.ReadTimeout != 0 {
|
|
|
- rtimeout = srv.ReadTimeout
|
|
|
+ return srv.ReadTimeout
|
|
|
}
|
|
|
- return rtimeout
|
|
|
+ return dnsTimeout
|
|
|
}
|
|
|
|
|
|
// serveTCP starts a TCP listener for the server.
|
|
@@ -510,22 +409,32 @@ func (srv *Server) serveTCP(l net.Listener) error {
|
|
|
srv.NotifyStartedFunc()
|
|
|
}
|
|
|
|
|
|
- for {
|
|
|
+ var wg sync.WaitGroup
|
|
|
+ defer func() {
|
|
|
+ wg.Wait()
|
|
|
+ close(srv.shutdown)
|
|
|
+ }()
|
|
|
+
|
|
|
+ for srv.isStarted() {
|
|
|
rw, err := l.Accept()
|
|
|
- srv.lock.RLock()
|
|
|
- if !srv.started {
|
|
|
- srv.lock.RUnlock()
|
|
|
- return nil
|
|
|
- }
|
|
|
- srv.lock.RUnlock()
|
|
|
if err != nil {
|
|
|
+ if !srv.isStarted() {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
if neterr, ok := err.(net.Error); ok && neterr.Temporary() {
|
|
|
continue
|
|
|
}
|
|
|
return err
|
|
|
}
|
|
|
- srv.spawnWorker(&response{tsigSecret: srv.TsigSecret, tcp: rw})
|
|
|
+ srv.lock.Lock()
|
|
|
+ // Track the connection to allow unblocking reads on shutdown.
|
|
|
+ srv.conns[rw] = struct{}{}
|
|
|
+ srv.lock.Unlock()
|
|
|
+ wg.Add(1)
|
|
|
+ go srv.serveTCPConn(&wg, rw)
|
|
|
}
|
|
|
+
|
|
|
+ return nil
|
|
|
}
|
|
|
|
|
|
// serveUDP starts a UDP listener for the server.
|
|
@@ -536,58 +445,57 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
|
|
|
srv.NotifyStartedFunc()
|
|
|
}
|
|
|
|
|
|
- reader := Reader(&defaultReader{srv})
|
|
|
+ reader := Reader(defaultReader{srv})
|
|
|
if srv.DecorateReader != nil {
|
|
|
reader = srv.DecorateReader(reader)
|
|
|
}
|
|
|
|
|
|
+ var wg sync.WaitGroup
|
|
|
+ defer func() {
|
|
|
+ wg.Wait()
|
|
|
+ close(srv.shutdown)
|
|
|
+ }()
|
|
|
+
|
|
|
rtimeout := srv.getReadTimeout()
|
|
|
// deadline is not used here
|
|
|
- for {
|
|
|
+ for srv.isStarted() {
|
|
|
m, s, err := reader.ReadUDP(l, rtimeout)
|
|
|
- srv.lock.RLock()
|
|
|
- if !srv.started {
|
|
|
- srv.lock.RUnlock()
|
|
|
- return nil
|
|
|
- }
|
|
|
- srv.lock.RUnlock()
|
|
|
if err != nil {
|
|
|
+ if !srv.isStarted() {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
if netErr, ok := err.(net.Error); ok && netErr.Temporary() {
|
|
|
continue
|
|
|
}
|
|
|
return err
|
|
|
}
|
|
|
if len(m) < headerSize {
|
|
|
+ if cap(m) == srv.UDPSize {
|
|
|
+ srv.udpPool.Put(m[:srv.UDPSize])
|
|
|
+ }
|
|
|
continue
|
|
|
}
|
|
|
- srv.spawnWorker(&response{msg: m, tsigSecret: srv.TsigSecret, udp: l, udpSession: s})
|
|
|
+ wg.Add(1)
|
|
|
+ go srv.serveUDPPacket(&wg, m, l, s)
|
|
|
}
|
|
|
+
|
|
|
+ return nil
|
|
|
}
|
|
|
|
|
|
-func (srv *Server) serve(w *response) {
|
|
|
+// Serve a new TCP connection.
|
|
|
+func (srv *Server) serveTCPConn(wg *sync.WaitGroup, rw net.Conn) {
|
|
|
+ w := &response{tsigSecret: srv.TsigSecret, tcp: rw}
|
|
|
if srv.DecorateWriter != nil {
|
|
|
w.writer = srv.DecorateWriter(w)
|
|
|
} else {
|
|
|
w.writer = w
|
|
|
}
|
|
|
|
|
|
- if w.udp != nil {
|
|
|
- // serve UDP
|
|
|
- srv.serveDNS(w)
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- reader := Reader(&defaultReader{srv})
|
|
|
+ reader := Reader(defaultReader{srv})
|
|
|
if srv.DecorateReader != nil {
|
|
|
reader = srv.DecorateReader(reader)
|
|
|
}
|
|
|
|
|
|
- defer func() {
|
|
|
- if !w.hijacked {
|
|
|
- w.Close()
|
|
|
- }
|
|
|
- }()
|
|
|
-
|
|
|
idleTimeout := tcpIdleTimeout
|
|
|
if srv.IdleTimeout != nil {
|
|
|
idleTimeout = srv.IdleTimeout()
|
|
@@ -600,15 +508,14 @@ func (srv *Server) serve(w *response) {
|
|
|
limit = maxTCPQueries
|
|
|
}
|
|
|
|
|
|
- for q := 0; q < limit || limit == -1; q++ {
|
|
|
- var err error
|
|
|
- w.msg, err = reader.ReadTCP(w.tcp, timeout)
|
|
|
+ for q := 0; (q < limit || limit == -1) && srv.isStarted(); q++ {
|
|
|
+ m, err := reader.ReadTCP(w.tcp, timeout)
|
|
|
if err != nil {
|
|
|
// TODO(tmthrgd): handle error
|
|
|
break
|
|
|
}
|
|
|
- srv.serveDNS(w)
|
|
|
- if w.tcp == nil {
|
|
|
+ srv.serveDNS(m, w)
|
|
|
+ if w.closed {
|
|
|
break // Close() was called
|
|
|
}
|
|
|
if w.hijacked {
|
|
@@ -618,18 +525,67 @@ func (srv *Server) serve(w *response) {
|
|
|
// idle timeout.
|
|
|
timeout = idleTimeout
|
|
|
}
|
|
|
+
|
|
|
+ if !w.hijacked {
|
|
|
+ w.Close()
|
|
|
+ }
|
|
|
+
|
|
|
+ srv.lock.Lock()
|
|
|
+ delete(srv.conns, w.tcp)
|
|
|
+ srv.lock.Unlock()
|
|
|
+
|
|
|
+ wg.Done()
|
|
|
}
|
|
|
|
|
|
-func (srv *Server) serveDNS(w *response) {
|
|
|
- req := new(Msg)
|
|
|
- err := req.Unpack(w.msg)
|
|
|
- if err != nil { // Send a FormatError back
|
|
|
- x := new(Msg)
|
|
|
- x.SetRcodeFormatError(req)
|
|
|
- w.WriteMsg(x)
|
|
|
+// Serve a new UDP request.
|
|
|
+func (srv *Server) serveUDPPacket(wg *sync.WaitGroup, m []byte, u *net.UDPConn, s *SessionUDP) {
|
|
|
+ w := &response{tsigSecret: srv.TsigSecret, udp: u, udpSession: s}
|
|
|
+ if srv.DecorateWriter != nil {
|
|
|
+ w.writer = srv.DecorateWriter(w)
|
|
|
+ } else {
|
|
|
+ w.writer = w
|
|
|
+ }
|
|
|
+
|
|
|
+ srv.serveDNS(m, w)
|
|
|
+ wg.Done()
|
|
|
+}
|
|
|
+
|
|
|
+func (srv *Server) serveDNS(m []byte, w *response) {
|
|
|
+ dh, off, err := unpackMsgHdr(m, 0)
|
|
|
+ if err != nil {
|
|
|
+ // Let client hang, they are sending crap; any reply can be used to amplify.
|
|
|
return
|
|
|
}
|
|
|
- if !srv.Unsafe && req.Response {
|
|
|
+
|
|
|
+ req := new(Msg)
|
|
|
+ req.setHdr(dh)
|
|
|
+
|
|
|
+ switch action := srv.MsgAcceptFunc(dh); action {
|
|
|
+ case MsgAccept:
|
|
|
+ if req.unpack(dh, m, off) == nil {
|
|
|
+ break
|
|
|
+ }
|
|
|
+
|
|
|
+ fallthrough
|
|
|
+ case MsgReject, MsgRejectNotImplemented:
|
|
|
+ opcode := req.Opcode
|
|
|
+ req.SetRcodeFormatError(req)
|
|
|
+ req.Zero = false
|
|
|
+ if action == MsgRejectNotImplemented {
|
|
|
+ req.Opcode = opcode
|
|
|
+ req.Rcode = RcodeNotImplemented
|
|
|
+ }
|
|
|
+
|
|
|
+ // Are we allowed to delete any OPT records here?
|
|
|
+ req.Ns, req.Answer, req.Extra = nil, nil, nil
|
|
|
+
|
|
|
+ w.WriteMsg(req)
|
|
|
+ fallthrough
|
|
|
+ case MsgIgnore:
|
|
|
+ if w.udp != nil && cap(m) == srv.UDPSize {
|
|
|
+ srv.udpPool.Put(m[:srv.UDPSize])
|
|
|
+ }
|
|
|
+
|
|
|
return
|
|
|
}
|
|
|
|
|
@@ -637,7 +593,7 @@ func (srv *Server) serveDNS(w *response) {
|
|
|
if w.tsigSecret != nil {
|
|
|
if t := req.IsTsig(); t != nil {
|
|
|
if secret, ok := w.tsigSecret[t.Hdr.Name]; ok {
|
|
|
- w.tsigStatus = TsigVerify(w.msg, secret, "", false)
|
|
|
+ w.tsigStatus = TsigVerify(m, secret, "", false)
|
|
|
} else {
|
|
|
w.tsigStatus = ErrSecret
|
|
|
}
|
|
@@ -646,54 +602,49 @@ func (srv *Server) serveDNS(w *response) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- handler := srv.Handler
|
|
|
- if handler == nil {
|
|
|
- handler = DefaultServeMux
|
|
|
+ if w.udp != nil && cap(m) == srv.UDPSize {
|
|
|
+ srv.udpPool.Put(m[:srv.UDPSize])
|
|
|
}
|
|
|
|
|
|
- handler.ServeDNS(w, req) // Writes back to the client
|
|
|
+ srv.Handler.ServeDNS(w, req) // Writes back to the client
|
|
|
}
|
|
|
|
|
|
func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
|
|
|
- conn.SetReadDeadline(time.Now().Add(timeout))
|
|
|
- l := make([]byte, 2)
|
|
|
- n, err := conn.Read(l)
|
|
|
- if err != nil || n != 2 {
|
|
|
- if err != nil {
|
|
|
- return nil, err
|
|
|
- }
|
|
|
- return nil, ErrShortRead
|
|
|
- }
|
|
|
- length := binary.BigEndian.Uint16(l)
|
|
|
- if length == 0 {
|
|
|
- return nil, ErrShortRead
|
|
|
+ // If we race with ShutdownContext, the read deadline may
|
|
|
+ // have been set in the distant past to unblock the read
|
|
|
+ // below. We must not override it, otherwise we may block
|
|
|
+ // ShutdownContext.
|
|
|
+ srv.lock.RLock()
|
|
|
+ if srv.started {
|
|
|
+ conn.SetReadDeadline(time.Now().Add(timeout))
|
|
|
}
|
|
|
- m := make([]byte, int(length))
|
|
|
- n, err = conn.Read(m[:int(length)])
|
|
|
- if err != nil || n == 0 {
|
|
|
- if err != nil {
|
|
|
- return nil, err
|
|
|
- }
|
|
|
- return nil, ErrShortRead
|
|
|
+ srv.lock.RUnlock()
|
|
|
+
|
|
|
+ var length uint16
|
|
|
+ if err := binary.Read(conn, binary.BigEndian, &length); err != nil {
|
|
|
+ return nil, err
|
|
|
}
|
|
|
- i := n
|
|
|
- for i < int(length) {
|
|
|
- j, err := conn.Read(m[i:int(length)])
|
|
|
- if err != nil {
|
|
|
- return nil, err
|
|
|
- }
|
|
|
- i += j
|
|
|
+
|
|
|
+ m := make([]byte, length)
|
|
|
+ if _, err := io.ReadFull(conn, m); err != nil {
|
|
|
+ return nil, err
|
|
|
}
|
|
|
- n = i
|
|
|
- m = m[:n]
|
|
|
+
|
|
|
return m, nil
|
|
|
}
|
|
|
|
|
|
func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) {
|
|
|
- conn.SetReadDeadline(time.Now().Add(timeout))
|
|
|
- m := make([]byte, srv.UDPSize)
|
|
|
+ srv.lock.RLock()
|
|
|
+ if srv.started {
|
|
|
+ // See the comment in readTCP above.
|
|
|
+ conn.SetReadDeadline(time.Now().Add(timeout))
|
|
|
+ }
|
|
|
+ srv.lock.RUnlock()
|
|
|
+
|
|
|
+ m := srv.udpPool.Get().([]byte)
|
|
|
n, s, err := ReadFromSessionUDP(conn, m)
|
|
|
if err != nil {
|
|
|
+ srv.udpPool.Put(m)
|
|
|
return nil, nil, err
|
|
|
}
|
|
|
m = m[:n]
|
|
@@ -702,6 +653,10 @@ func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *S
|
|
|
|
|
|
// WriteMsg implements the ResponseWriter.WriteMsg method.
|
|
|
func (w *response) WriteMsg(m *Msg) (err error) {
|
|
|
+ if w.closed {
|
|
|
+ return &Error{err: "WriteMsg called after Close"}
|
|
|
+ }
|
|
|
+
|
|
|
var data []byte
|
|
|
if w.tsigSecret != nil { // if no secrets, dont check for the tsig (which is a longer check)
|
|
|
if t := m.IsTsig(); t != nil {
|
|
@@ -723,42 +678,50 @@ func (w *response) WriteMsg(m *Msg) (err error) {
|
|
|
|
|
|
// Write implements the ResponseWriter.Write method.
|
|
|
func (w *response) Write(m []byte) (int, error) {
|
|
|
+ if w.closed {
|
|
|
+ return 0, &Error{err: "Write called after Close"}
|
|
|
+ }
|
|
|
+
|
|
|
switch {
|
|
|
case w.udp != nil:
|
|
|
- n, err := WriteToSessionUDP(w.udp, m, w.udpSession)
|
|
|
- return n, err
|
|
|
+ return WriteToSessionUDP(w.udp, m, w.udpSession)
|
|
|
case w.tcp != nil:
|
|
|
- lm := len(m)
|
|
|
- if lm < 2 {
|
|
|
- return 0, io.ErrShortBuffer
|
|
|
- }
|
|
|
- if lm > MaxMsgSize {
|
|
|
+ if len(m) > MaxMsgSize {
|
|
|
return 0, &Error{err: "message too large"}
|
|
|
}
|
|
|
- l := make([]byte, 2, 2+lm)
|
|
|
- binary.BigEndian.PutUint16(l, uint16(lm))
|
|
|
- m = append(l, m...)
|
|
|
|
|
|
- n, err := io.Copy(w.tcp, bytes.NewReader(m))
|
|
|
+ l := make([]byte, 2)
|
|
|
+ binary.BigEndian.PutUint16(l, uint16(len(m)))
|
|
|
+
|
|
|
+ n, err := (&net.Buffers{l, m}).WriteTo(w.tcp)
|
|
|
return int(n), err
|
|
|
+ default:
|
|
|
+ panic("dns: internal error: udp and tcp both nil")
|
|
|
}
|
|
|
- panic("not reached")
|
|
|
}
|
|
|
|
|
|
// LocalAddr implements the ResponseWriter.LocalAddr method.
|
|
|
func (w *response) LocalAddr() net.Addr {
|
|
|
- if w.tcp != nil {
|
|
|
+ switch {
|
|
|
+ case w.udp != nil:
|
|
|
+ return w.udp.LocalAddr()
|
|
|
+ case w.tcp != nil:
|
|
|
return w.tcp.LocalAddr()
|
|
|
+ default:
|
|
|
+ panic("dns: internal error: udp and tcp both nil")
|
|
|
}
|
|
|
- return w.udp.LocalAddr()
|
|
|
}
|
|
|
|
|
|
// RemoteAddr implements the ResponseWriter.RemoteAddr method.
|
|
|
func (w *response) RemoteAddr() net.Addr {
|
|
|
- if w.tcp != nil {
|
|
|
+ switch {
|
|
|
+ case w.udpSession != nil:
|
|
|
+ return w.udpSession.RemoteAddr()
|
|
|
+ case w.tcp != nil:
|
|
|
return w.tcp.RemoteAddr()
|
|
|
+ default:
|
|
|
+ panic("dns: internal error: udpSession and tcp both nil")
|
|
|
}
|
|
|
- return w.udpSession.RemoteAddr()
|
|
|
}
|
|
|
|
|
|
// TsigStatus implements the ResponseWriter.TsigStatus method.
|
|
@@ -772,11 +735,30 @@ func (w *response) Hijack() { w.hijacked = true }
|
|
|
|
|
|
// Close implements the ResponseWriter.Close method
|
|
|
func (w *response) Close() error {
|
|
|
- // Can't close the udp conn, as that is actually the listener.
|
|
|
- if w.tcp != nil {
|
|
|
- e := w.tcp.Close()
|
|
|
- w.tcp = nil
|
|
|
- return e
|
|
|
+ if w.closed {
|
|
|
+ return &Error{err: "connection already closed"}
|
|
|
+ }
|
|
|
+ w.closed = true
|
|
|
+
|
|
|
+ switch {
|
|
|
+ case w.udp != nil:
|
|
|
+ // Can't close the udp conn, as that is actually the listener.
|
|
|
+ return nil
|
|
|
+ case w.tcp != nil:
|
|
|
+ return w.tcp.Close()
|
|
|
+ default:
|
|
|
+ panic("dns: internal error: udp and tcp both nil")
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// ConnectionState() implements the ConnectionStater.ConnectionState() interface.
|
|
|
+func (w *response) ConnectionState() *tls.ConnectionState {
|
|
|
+ type tlsConnectionStater interface {
|
|
|
+ ConnectionState() tls.ConnectionState
|
|
|
+ }
|
|
|
+ if v, ok := w.tcp.(tlsConnectionStater); ok {
|
|
|
+ t := v.ConnectionState()
|
|
|
+ return &t
|
|
|
}
|
|
|
return nil
|
|
|
}
|