Browse Source

cmd/dockerd: gracefully shut down the API server

As of Go 1.8, "net/http".Server provides facilities to close all
listeners, making the same facilities in server.Server redundant.
http.Server also improves upon server.Server by additionally providing a
facility to also wait for outstanding requests to complete after closing
all listeners. Leverage those facilities to give in-flight requests up
to five seconds to finish up after all containers have been shut down.

Signed-off-by: Cory Snider <csnider@mirantis.com>
Cory Snider 2 năm trước cách đây
mục cha
commit
12bf850c84
3 tập tin đã thay đổi với 108 bổ sung122 xóa
  1. 3 80
      api/server/server.go
  2. 101 36
      cmd/dockerd/daemon.go
  3. 4 6
      cmd/dockerd/daemon_linux_test.go

+ 3 - 80
api/server/server.go

@@ -2,10 +2,7 @@ package server // import "github.com/docker/docker/api/server"
 
 import (
 	"context"
-	"net"
 	"net/http"
-	"strings"
-	"time"
 
 	"github.com/docker/docker/api/server/httpstatus"
 	"github.com/docker/docker/api/server/httputils"
@@ -23,8 +20,6 @@ const versionMatcher = "/v{version:[0-9.]+}"
 
 // Server contains instance details for the server
 type Server struct {
-	servers     []*HTTPServer
-	routers     []router.Router
 	middlewares []middleware.Middleware
 }
 
@@ -34,71 +29,6 @@ func (s *Server) UseMiddleware(m middleware.Middleware) {
 	s.middlewares = append(s.middlewares, m)
 }
 
-// Accept sets a listener the server accepts connections into.
-func (s *Server) Accept(addr string, listeners ...net.Listener) {
-	for _, listener := range listeners {
-		httpServer := &HTTPServer{
-			srv: &http.Server{
-				Addr:              addr,
-				ReadHeaderTimeout: 5 * time.Minute, // "G112: Potential Slowloris Attack (gosec)"; not a real concern for our use, so setting a long timeout.
-			},
-			l: listener,
-		}
-		s.servers = append(s.servers, httpServer)
-	}
-}
-
-// Close closes servers and thus stop receiving requests
-func (s *Server) Close() {
-	for _, srv := range s.servers {
-		if err := srv.Close(); err != nil {
-			logrus.Error(err)
-		}
-	}
-}
-
-// Serve starts listening for inbound requests.
-func (s *Server) Serve() error {
-	var chErrors = make(chan error, len(s.servers))
-	for _, srv := range s.servers {
-		srv.srv.Handler = s.createMux()
-		go func(srv *HTTPServer) {
-			var err error
-			logrus.Infof("API listen on %s", srv.l.Addr())
-			if err = srv.Serve(); err != nil && strings.Contains(err.Error(), "use of closed network connection") {
-				err = nil
-			}
-			chErrors <- err
-		}(srv)
-	}
-
-	for range s.servers {
-		err := <-chErrors
-		if err != nil {
-			return err
-		}
-	}
-	return nil
-}
-
-// HTTPServer contains an instance of http server and the listener.
-// srv *http.Server, contains configuration to create an http server and a mux router with all api end points.
-// l   net.Listener, is a TCP or Socket listener that dispatches incoming request to the router.
-type HTTPServer struct {
-	srv *http.Server
-	l   net.Listener
-}
-
-// Serve starts listening for inbound requests.
-func (s *HTTPServer) Serve() error {
-	return s.srv.Serve(s.l)
-}
-
-// Close closes the HTTPServer from listening for the inbound requests.
-func (s *HTTPServer) Close() error {
-	return s.l.Close()
-}
-
 func (s *Server) makeHTTPHandler(handler httputils.APIFunc) http.HandlerFunc {
 	return func(w http.ResponseWriter, r *http.Request) {
 		// Define the context that we'll pass around to share info
@@ -130,12 +60,6 @@ func (s *Server) makeHTTPHandler(handler httputils.APIFunc) http.HandlerFunc {
 	}
 }
 
-// InitRouter initializes the list of routers for the server.
-// This method also enables the Go profiler.
-func (s *Server) InitRouter(routers ...router.Router) {
-	s.routers = append(s.routers, routers...)
-}
-
 type pageNotFoundError struct{}
 
 func (pageNotFoundError) Error() string {
@@ -144,12 +68,12 @@ func (pageNotFoundError) Error() string {
 
 func (pageNotFoundError) NotFound() {}
 
-// createMux initializes the main router the server uses.
-func (s *Server) createMux() *mux.Router {
+// CreateMux returns a new mux with all the routers registered.
+func (s *Server) CreateMux(routers ...router.Router) *mux.Router {
 	m := mux.NewRouter()
 
 	logrus.Debug("Registering routers")
-	for _, apiRouter := range s.routers {
+	for _, apiRouter := range routers {
 		for _, r := range apiRouter.Routes() {
 			f := s.makeHTTPHandler(r.Handler())
 
@@ -160,7 +84,6 @@ func (s *Server) createMux() *mux.Router {
 	}
 
 	debugRouter := debug.NewRouter()
-	s.routers = append(s.routers, debugRouter)
 	for _, r := range debugRouter.Routes() {
 		f := s.makeHTTPHandler(r.Handler())
 		m.Path("/debug" + r.Path()).Handler(f)

+ 101 - 36
cmd/dockerd/daemon.go

@@ -5,11 +5,13 @@ import (
 	"crypto/tls"
 	"fmt"
 	"net"
+	"net/http"
 	"os"
 	"path/filepath"
 	"runtime"
 	"sort"
 	"strings"
+	"sync"
 	"time"
 
 	containerddefaults "github.com/containerd/containerd/defaults"
@@ -65,14 +67,18 @@ type DaemonCli struct {
 	configFile *string
 	flags      *pflag.FlagSet
 
-	api             apiserver.Server
 	d               *daemon.Daemon
 	authzMiddleware *authorization.Middleware // authzMiddleware enables to dynamically reload the authorization plugins
+
+	stopOnce    sync.Once
+	apiShutdown chan struct{}
 }
 
 // NewDaemonCli returns a daemon CLI
 func NewDaemonCli() *DaemonCli {
-	return &DaemonCli{}
+	return &DaemonCli{
+		apiShutdown: make(chan struct{}),
+	}
 }
 
 func (cli *DaemonCli) start(opts *daemonOptions) (err error) {
@@ -161,7 +167,7 @@ func (cli *DaemonCli) start(opts *daemonOptions) (err error) {
 		}
 	}
 
-	hosts, err := loadListeners(cli, tlsConfig)
+	lss, hosts, err := loadListeners(cli.Config, tlsConfig)
 	if err != nil {
 		return errors.Wrap(err, "failed to load listeners")
 	}
@@ -177,20 +183,51 @@ func (cli *DaemonCli) start(opts *daemonOptions) (err error) {
 	}
 	defer cancel()
 
-	stopc := make(chan bool)
-	defer close(stopc)
-
-	trap.Trap(func() {
-		cli.stop()
-		<-stopc // wait for daemonCli.start() to return
-	})
+	httpServer := &http.Server{
+		ReadHeaderTimeout: 5 * time.Minute, // "G112: Potential Slowloris Attack (gosec)"; not a real concern for our use, so setting a long timeout.
+	}
+	apiShutdownCtx, apiShutdownCancel := context.WithCancel(context.Background())
+	apiShutdownDone := make(chan struct{})
+	trap.Trap(cli.stop)
+	go func() {
+		// Block until cli.stop() has been called.
+		// It may have already been called, and that's okay.
+		// Any httpServer.Serve() calls made after
+		// httpServer.Shutdown() will return immediately,
+		// which is what we want.
+		<-cli.apiShutdown
+		err := httpServer.Shutdown(apiShutdownCtx)
+		if err != nil {
+			logrus.WithError(err).Error("Error shutting down http server")
+		}
+		close(apiShutdownDone)
+	}()
+	defer func() {
+		select {
+		case <-cli.apiShutdown:
+			// cli.stop() has been called and the daemon has completed
+			// shutting down. Give the HTTP server a little more time to
+			// finish handling any outstanding requests if needed.
+			tmr := time.AfterFunc(5*time.Second, apiShutdownCancel)
+			defer tmr.Stop()
+			<-apiShutdownDone
+		default:
+			// cli.start() has returned without cli.stop() being called,
+			// e.g. because the daemon failed to start.
+			// Stop the HTTP server with no grace period.
+			if closeErr := httpServer.Close(); closeErr != nil {
+				logrus.WithError(closeErr).Error("Error closing http server")
+			}
+		}
+	}()
 
 	// Notify that the API is active, but before daemon is set up.
 	preNotifyReady()
 
 	pluginStore := plugin.NewStore()
 
-	cli.authzMiddleware = initMiddlewares(&cli.api, cli.Config, pluginStore)
+	var apiServer apiserver.Server
+	cli.authzMiddleware = initMiddlewares(&apiServer, cli.Config, pluginStore)
 
 	d, err := daemon.NewDaemon(ctx, cli.Config, pluginStore, cli.authzMiddleware)
 	if err != nil {
@@ -229,10 +266,9 @@ func (cli *DaemonCli) start(opts *daemonOptions) (err error) {
 	if err != nil {
 		return err
 	}
-	routerOptions.api = &cli.api
 	routerOptions.cluster = c
 
-	initRouter(routerOptions)
+	httpServer.Handler = apiServer.CreateMux(routerOptions.Build()...)
 
 	go d.ProcessClusterNotifications(ctx, c.GetWatchStream())
 
@@ -243,10 +279,30 @@ func (cli *DaemonCli) start(opts *daemonOptions) (err error) {
 
 	// Daemon is fully initialized. Start handling API traffic
 	// and wait for serve API to complete.
-	errAPI := cli.api.Serve()
-	if errAPI != nil {
-		logrus.WithError(errAPI).Error("ServeAPI error")
+	var (
+		apiWG  sync.WaitGroup
+		errAPI = make(chan error, 1)
+	)
+	for _, ls := range lss {
+		apiWG.Add(1)
+		go func(ls net.Listener) {
+			defer apiWG.Done()
+			logrus.Infof("API listen on %s", ls.Addr())
+			if err := httpServer.Serve(ls); err != http.ErrServerClosed {
+				logrus.WithFields(logrus.Fields{
+					logrus.ErrorKey: err,
+					"listener":      ls.Addr(),
+				}).Error("ServeAPI error")
+
+				select {
+				case errAPI <- err:
+				default:
+				}
+			}
+		}(ls)
 	}
+	apiWG.Wait()
+	close(errAPI)
 
 	c.Cleanup()
 
@@ -257,8 +313,8 @@ func (cli *DaemonCli) start(opts *daemonOptions) (err error) {
 	// Stop notification processing and any background processes
 	cancel()
 
-	if errAPI != nil {
-		return errors.Wrap(errAPI, "shutting down due to ServeAPI error")
+	if err, ok := <-errAPI; ok {
+		return errors.Wrap(err, "shutting down due to ServeAPI error")
 	}
 
 	logrus.Info("Daemon shutdown complete")
@@ -271,7 +327,6 @@ type routerOptions struct {
 	features       *map[string]bool
 	buildkit       *buildkit.Builder
 	daemon         *daemon.Daemon
-	api            *apiserver.Server
 	cluster        *cluster.Cluster
 }
 
@@ -356,7 +411,14 @@ func (cli *DaemonCli) reloadConfig() {
 }
 
 func (cli *DaemonCli) stop() {
-	cli.api.Close()
+	// Signal that the API server should shut down as soon as possible.
+	// This construct is used rather than directly shutting down the HTTP
+	// server to avoid any issues if this method is called before the server
+	// has been instantiated in cli.start(). If this method is called first,
+	// the HTTP server will be shut down immediately upon instantiation.
+	cli.stopOnce.Do(func() {
+		close(cli.apiShutdown)
+	})
 }
 
 // shutdownDaemon just wraps daemon.Shutdown() to handle a timeout in case
@@ -498,7 +560,7 @@ func normalizeHosts(config *config.Config) error {
 	return nil
 }
 
-func initRouter(opts routerOptions) {
+func (opts routerOptions) Build() []router.Router {
 	decoder := runconfig.ContainerDecoder{
 		GetSysInfo: func() *sysinfo.SysInfo {
 			return opts.daemon.RawSysInfo()
@@ -543,7 +605,7 @@ func initRouter(opts routerOptions) {
 		}
 	}
 
-	opts.api.InitRouter(routers...)
+	return routers
 }
 
 func initMiddlewares(s *apiserver.Server, cfg *config.Config, pluginStore plugingetter.PluginGetter) *authorization.Middleware {
@@ -647,17 +709,20 @@ func checkTLSAuthOK(c *config.Config) bool {
 	return true
 }
 
-func loadListeners(cli *DaemonCli, tlsConfig *tls.Config) ([]string, error) {
-	if len(cli.Config.Hosts) == 0 {
-		return nil, errors.New("no hosts configured")
+func loadListeners(cfg *config.Config, tlsConfig *tls.Config) ([]net.Listener, []string, error) {
+	if len(cfg.Hosts) == 0 {
+		return nil, nil, errors.New("no hosts configured")
 	}
-	var hosts []string
+	var (
+		hosts []string
+		lss   []net.Listener
+	)
 
-	for i := 0; i < len(cli.Config.Hosts); i++ {
-		protoAddr := cli.Config.Hosts[i]
+	for i := 0; i < len(cfg.Hosts); i++ {
+		protoAddr := cfg.Hosts[i]
 		proto, addr, ok := strings.Cut(protoAddr, "://")
 		if !ok {
-			return nil, fmt.Errorf("bad format %s, expected PROTO://ADDR", protoAddr)
+			return nil, nil, fmt.Errorf("bad format %s, expected PROTO://ADDR", protoAddr)
 		}
 
 		// It's a bad idea to bind to TCP without tlsverify.
@@ -669,10 +734,10 @@ func loadListeners(cli *DaemonCli, tlsConfig *tls.Config) ([]string, error) {
 
 			// If TLSVerify is explicitly set to false we'll take that as "Please let me shoot myself in the foot"
 			// We do not want to continue to support a default mode where tls verification is disabled, so we do some extra warnings here and eventually remove support
-			if !checkTLSAuthOK(cli.Config) {
+			if !checkTLSAuthOK(cfg) {
 				ipAddr, _, err := net.SplitHostPort(addr)
 				if err != nil {
-					return nil, errors.Wrap(err, "error parsing tcp address")
+					return nil, nil, errors.Wrap(err, "error parsing tcp address")
 				}
 
 				// shortcut all this extra stuff for literal "localhost"
@@ -702,19 +767,19 @@ func loadListeners(cli *DaemonCli, tlsConfig *tls.Config) ([]string, error) {
 		// If we're binding to a TCP port, make sure that a container doesn't try to use it.
 		if proto == "tcp" {
 			if err := allocateDaemonPort(addr); err != nil {
-				return nil, err
+				return nil, nil, err
 			}
 		}
-		ls, err := listeners.Init(proto, addr, cli.Config.SocketGroup, tlsConfig)
+		ls, err := listeners.Init(proto, addr, cfg.SocketGroup, tlsConfig)
 		if err != nil {
-			return nil, err
+			return nil, nil, err
 		}
 		logrus.Debugf("Listener created for HTTP on %s (%s)", proto, addr)
 		hosts = append(hosts, addr)
-		cli.api.Accept(addr, ls...)
+		lss = append(lss, ls...)
 	}
 
-	return hosts, nil
+	return lss, hosts, nil
 }
 
 func createAndStartCluster(cli *DaemonCli, d *daemon.Daemon) (*cluster.Cluster, error) {

+ 4 - 6
cmd/dockerd/daemon_linux_test.go

@@ -42,14 +42,12 @@ func initListenerTestPhase1() {
 }
 
 func initListenerTestPhase2() {
-	cli := &DaemonCli{
-		Config: &config.Config{
-			CommonConfig: config.CommonConfig{
-				Hosts: []string{"fd://"},
-			},
+	cfg := &config.Config{
+		CommonConfig: config.CommonConfig{
+			Hosts: []string{"fd://"},
 		},
 	}
-	_, err := loadListeners(cli, nil)
+	_, _, err := loadListeners(cfg, nil)
 	var resp listenerTestResponse
 	if err != nil {
 		resp.Err = err.Error()