Jelajahi Sumber

Merge pull request #21897 from calavera/remove_authorization_refs_from_api

Move middleware to interfaces.
Alexander Morozov 9 tahun lalu
induk
melakukan
9438572fc0

+ 2 - 18
api/server/middleware.go

@@ -2,10 +2,8 @@ package server
 
 import (
 	"github.com/Sirupsen/logrus"
-	"github.com/docker/docker/api"
 	"github.com/docker/docker/api/server/httputils"
 	"github.com/docker/docker/api/server/middleware"
-	"github.com/docker/docker/pkg/authorization"
 )
 
 // handleWithGlobalMiddlwares wraps the handler function for a request with
@@ -14,27 +12,13 @@ import (
 func (s *Server) handleWithGlobalMiddlewares(handler httputils.APIFunc) httputils.APIFunc {
 	next := handler
 
-	handleVersion := middleware.NewVersionMiddleware(s.cfg.Version, api.DefaultVersion, api.MinVersion)
-	next = handleVersion(next)
-
-	if s.cfg.EnableCors {
-		handleCORS := middleware.NewCORSMiddleware(s.cfg.CorsHeaders)
-		next = handleCORS(next)
+	for _, m := range s.middlewares {
+		next = m.WrapHandler(next)
 	}
 
-	handleUserAgent := middleware.NewUserAgentMiddleware(s.cfg.Version)
-	next = handleUserAgent(next)
-
-	// Only want this on debug level
 	if s.cfg.Logging && logrus.GetLevel() == logrus.DebugLevel {
 		next = middleware.DebugRequestMiddleware(next)
 	}
 
-	if len(s.cfg.AuthorizationPluginNames) > 0 {
-		s.authZPlugins = authorization.NewPlugins(s.cfg.AuthorizationPluginNames)
-		handleAuthorization := middleware.NewAuthorizationMiddleware(s.authZPlugins)
-		next = handleAuthorization(next)
-	}
-
 	return next
 }

+ 0 - 50
api/server/middleware/authorization.go

@@ -1,50 +0,0 @@
-package middleware
-
-import (
-	"net/http"
-
-	"github.com/Sirupsen/logrus"
-	"github.com/docker/docker/api/server/httputils"
-	"github.com/docker/docker/pkg/authorization"
-	"golang.org/x/net/context"
-)
-
-// NewAuthorizationMiddleware creates a new Authorization middleware.
-func NewAuthorizationMiddleware(plugins []authorization.Plugin) Middleware {
-	return func(handler httputils.APIFunc) httputils.APIFunc {
-		return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
-
-			user := ""
-			userAuthNMethod := ""
-
-			// Default authorization using existing TLS connection credentials
-			// FIXME: Non trivial authorization mechanisms (such as advanced certificate validations, kerberos support
-			// and ldap) will be extracted using AuthN feature, which is tracked under:
-			// https://github.com/docker/docker/pull/20883
-			if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
-				user = r.TLS.PeerCertificates[0].Subject.CommonName
-				userAuthNMethod = "TLS"
-			}
-
-			authCtx := authorization.NewCtx(plugins, user, userAuthNMethod, r.Method, r.RequestURI)
-
-			if err := authCtx.AuthZRequest(w, r); err != nil {
-				logrus.Errorf("AuthZRequest for %s %s returned error: %s", r.Method, r.RequestURI, err)
-				return err
-			}
-
-			rw := authorization.NewResponseModifier(w)
-
-			if err := handler(ctx, rw, r, vars); err != nil {
-				logrus.Errorf("Handler for %s %s returned error: %s", r.Method, r.RequestURI, err)
-				return err
-			}
-
-			if err := authCtx.AuthZResponse(rw, r); err != nil {
-				logrus.Errorf("AuthZResponse for %s %s returned error: %s", r.Method, r.RequestURI, err)
-				return err
-			}
-			return nil
-		}
-	}
-}

+ 24 - 20
api/server/middleware/cors.go

@@ -4,30 +4,34 @@ import (
 	"net/http"
 
 	"github.com/Sirupsen/logrus"
-	"github.com/docker/docker/api/server/httputils"
 	"golang.org/x/net/context"
 )
 
-// NewCORSMiddleware creates a new CORS middleware.
-func NewCORSMiddleware(defaultHeaders string) Middleware {
-	return func(handler httputils.APIFunc) httputils.APIFunc {
-		return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
-			// If "api-cors-header" is not given, but "api-enable-cors" is true, we set cors to "*"
-			// otherwise, all head values will be passed to HTTP handler
-			corsHeaders := defaultHeaders
-			if corsHeaders == "" {
-				corsHeaders = "*"
-			}
+// CORSMiddleware injects CORS headers to each request
+// when it's configured.
+type CORSMiddleware struct {
+	defaultHeaders string
+}
 
-			writeCorsHeaders(w, r, corsHeaders)
-			return handler(ctx, w, r, vars)
-		}
-	}
+// NewCORSMiddleware creates a new CORSMiddleware with default headers.
+func NewCORSMiddleware(d string) CORSMiddleware {
+	return CORSMiddleware{defaultHeaders: d}
 }
 
-func writeCorsHeaders(w http.ResponseWriter, r *http.Request, corsHeaders string) {
-	logrus.Debugf("CORS header is enabled and set to: %s", corsHeaders)
-	w.Header().Add("Access-Control-Allow-Origin", corsHeaders)
-	w.Header().Add("Access-Control-Allow-Headers", "Origin, X-Requested-With, Content-Type, Accept, X-Registry-Auth")
-	w.Header().Add("Access-Control-Allow-Methods", "HEAD, GET, POST, DELETE, PUT, OPTIONS")
+// WrapHandler returns a new handler function wrapping the previous one in the request chain.
+func (c CORSMiddleware) WrapHandler(handler func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error) func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
+	return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
+		// If "api-cors-header" is not given, but "api-enable-cors" is true, we set cors to "*"
+		// otherwise, all head values will be passed to HTTP handler
+		corsHeaders := c.defaultHeaders
+		if corsHeaders == "" {
+			corsHeaders = "*"
+		}
+
+		logrus.Debugf("CORS header is enabled and set to: %s", corsHeaders)
+		w.Header().Add("Access-Control-Allow-Origin", corsHeaders)
+		w.Header().Add("Access-Control-Allow-Headers", "Origin, X-Requested-With, Content-Type, Accept, X-Registry-Auth")
+		w.Header().Add("Access-Control-Allow-Methods", "HEAD, GET, POST, DELETE, PUT, OPTIONS")
+		return handler(ctx, w, r, vars)
+	}
 }

+ 1 - 1
api/server/middleware/debug.go

@@ -13,7 +13,7 @@ import (
 )
 
 // DebugRequestMiddleware dumps the request to logger
-func DebugRequestMiddleware(handler httputils.APIFunc) httputils.APIFunc {
+func DebugRequestMiddleware(handler func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error) func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
 	return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
 		logrus.Debugf("Calling %s %s", r.Method, r.RequestURI)
 

+ 10 - 4
api/server/middleware/middleware.go

@@ -1,7 +1,13 @@
 package middleware
 
-import "github.com/docker/docker/api/server/httputils"
+import (
+	"net/http"
 
-// Middleware is an adapter to allow the use of ordinary functions as Docker API filters.
-// Any function that has the appropriate signature can be registered as a middleware.
-type Middleware func(handler httputils.APIFunc) httputils.APIFunc
+	"golang.org/x/net/context"
+)
+
+// Middleware is an interface to allow the use of ordinary functions as Docker API filters.
+// Any struct that has the appropriate signature can be registered as a middleware.
+type Middleware interface {
+	WrapHandler(func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error) func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error
+}

+ 31 - 21
api/server/middleware/user_agent.go

@@ -10,28 +10,38 @@ import (
 	"golang.org/x/net/context"
 )
 
-// NewUserAgentMiddleware creates a new UserAgent middleware.
-func NewUserAgentMiddleware(versionCheck string) Middleware {
-	serverVersion := version.Version(versionCheck)
-
-	return func(handler httputils.APIFunc) httputils.APIFunc {
-		return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
-			ctx = context.WithValue(ctx, httputils.UAStringKey, r.Header.Get("User-Agent"))
-
-			if strings.Contains(r.Header.Get("User-Agent"), "Docker-Client/") {
-				userAgent := strings.Split(r.Header.Get("User-Agent"), "/")
-
-				// v1.20 onwards includes the GOOS of the client after the version
-				// such as Docker/1.7.0 (linux)
-				if len(userAgent) == 2 && strings.Contains(userAgent[1], " ") {
-					userAgent[1] = strings.Split(userAgent[1], " ")[0]
-				}
-
-				if len(userAgent) == 2 && !serverVersion.Equal(version.Version(userAgent[1])) {
-					logrus.Debugf("Client and server don't have the same version (client: %s, server: %s)", userAgent[1], serverVersion)
-				}
+// UserAgentMiddleware is a middleware that
+// validates the client user-agent.
+type UserAgentMiddleware struct {
+	serverVersion version.Version
+}
+
+// NewUserAgentMiddleware creates a new UserAgentMiddleware
+// with the server version.
+func NewUserAgentMiddleware(s version.Version) UserAgentMiddleware {
+	return UserAgentMiddleware{
+		serverVersion: s,
+	}
+}
+
+// WrapHandler returns a new handler function wrapping the previous one in the request chain.
+func (u UserAgentMiddleware) WrapHandler(handler func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error) func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
+	return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
+		ctx = context.WithValue(ctx, httputils.UAStringKey, r.Header.Get("User-Agent"))
+
+		if strings.Contains(r.Header.Get("User-Agent"), "Docker-Client/") {
+			userAgent := strings.Split(r.Header.Get("User-Agent"), "/")
+
+			// v1.20 onwards includes the GOOS of the client after the version
+			// such as Docker/1.7.0 (linux)
+			if len(userAgent) == 2 && strings.Contains(userAgent[1], " ") {
+				userAgent[1] = strings.Split(userAgent[1], " ")[0]
+			}
+
+			if len(userAgent) == 2 && !u.serverVersion.Equal(version.Version(userAgent[1])) {
+				logrus.Debugf("Client and server don't have the same version (client: %s, server: %s)", userAgent[1], u.serverVersion)
 			}
-			return handler(ctx, w, r, vars)
 		}
+		return handler(ctx, w, r, vars)
 	}
 }

+ 37 - 23
api/server/middleware/version.go

@@ -5,7 +5,6 @@ import (
 	"net/http"
 	"runtime"
 
-	"github.com/docker/docker/api/server/httputils"
 	"github.com/docker/docker/pkg/version"
 	"golang.org/x/net/context"
 )
@@ -18,28 +17,43 @@ func (badRequestError) HTTPErrorStatusCode() int {
 	return http.StatusBadRequest
 }
 
-// NewVersionMiddleware creates a new Version middleware.
-func NewVersionMiddleware(versionCheck string, defaultVersion, minVersion version.Version) Middleware {
-	serverVersion := version.Version(versionCheck)
-
-	return func(handler httputils.APIFunc) httputils.APIFunc {
-		return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
-			apiVersion := version.Version(vars["version"])
-			if apiVersion == "" {
-				apiVersion = defaultVersion
-			}
-
-			if apiVersion.GreaterThan(defaultVersion) {
-				return badRequestError{fmt.Errorf("client is newer than server (client API version: %s, server API version: %s)", apiVersion, defaultVersion)}
-			}
-			if apiVersion.LessThan(minVersion) {
-				return badRequestError{fmt.Errorf("client version %s is too old. Minimum supported API version is %s, please upgrade your client to a newer version", apiVersion, minVersion)}
-			}
-
-			header := fmt.Sprintf("Docker/%s (%s)", serverVersion, runtime.GOOS)
-			w.Header().Set("Server", header)
-			ctx = context.WithValue(ctx, httputils.APIVersionKey, apiVersion)
-			return handler(ctx, w, r, vars)
+// VersionMiddleware is a middleware that
+// validates the client and server versions.
+type VersionMiddleware struct {
+	serverVersion  version.Version
+	defaultVersion version.Version
+	minVersion     version.Version
+}
+
+// NewVersionMiddleware creates a new VersionMiddleware
+// with the default versions.
+func NewVersionMiddleware(s, d, m version.Version) VersionMiddleware {
+	return VersionMiddleware{
+		serverVersion:  s,
+		defaultVersion: d,
+		minVersion:     m,
+	}
+}
+
+// WrapHandler returns a new handler function wrapping the previous one in the request chain.
+func (v VersionMiddleware) WrapHandler(handler func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error) func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
+	return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
+		apiVersion := version.Version(vars["version"])
+		if apiVersion == "" {
+			apiVersion = v.defaultVersion
 		}
+
+		if apiVersion.GreaterThan(v.defaultVersion) {
+			return badRequestError{fmt.Errorf("client is newer than server (client API version: %s, server API version: %s)", apiVersion, v.defaultVersion)}
+		}
+		if apiVersion.LessThan(v.minVersion) {
+			return badRequestError{fmt.Errorf("client version %s is too old. Minimum supported API version is %s, please upgrade your client to a newer version", apiVersion, v.minVersion)}
+		}
+
+		header := fmt.Sprintf("Docker/%s (%s)", v.serverVersion, runtime.GOOS)
+		w.Header().Set("Server", header)
+		ctx = context.WithValue(ctx, "api-version", apiVersion)
+		return handler(ctx, w, r, vars)
 	}
+
 }

+ 4 - 4
api/server/middleware/version_test.go

@@ -21,8 +21,8 @@ func TestVersionMiddleware(t *testing.T) {
 
 	defaultVersion := version.Version("1.10.0")
 	minVersion := version.Version("1.2.0")
-	m := NewVersionMiddleware(defaultVersion.String(), defaultVersion, minVersion)
-	h := m(handler)
+	m := NewVersionMiddleware(defaultVersion, defaultVersion, minVersion)
+	h := m.WrapHandler(handler)
 
 	req, _ := http.NewRequest("GET", "/containers/json", nil)
 	resp := httptest.NewRecorder()
@@ -42,8 +42,8 @@ func TestVersionMiddlewareWithErrors(t *testing.T) {
 
 	defaultVersion := version.Version("1.10.0")
 	minVersion := version.Version("1.2.0")
-	m := NewVersionMiddleware(defaultVersion.String(), defaultVersion, minVersion)
-	h := m(handler)
+	m := NewVersionMiddleware(defaultVersion, defaultVersion, minVersion)
+	h := m.WrapHandler(handler)
 
 	req, _ := http.NewRequest("GET", "/containers/json", nil)
 	resp := httptest.NewRecorder()

+ 14 - 9
api/server/server.go

@@ -8,8 +8,8 @@ import (
 
 	"github.com/Sirupsen/logrus"
 	"github.com/docker/docker/api/server/httputils"
+	"github.com/docker/docker/api/server/middleware"
 	"github.com/docker/docker/api/server/router"
-	"github.com/docker/docker/pkg/authorization"
 	"github.com/gorilla/mux"
 	"golang.org/x/net/context"
 )
@@ -20,13 +20,12 @@ const versionMatcher = "/v{version:[0-9.]+}"
 
 // Config provides the configuration for the API server
 type Config struct {
-	Logging                  bool
-	EnableCors               bool
-	CorsHeaders              string
-	AuthorizationPluginNames []string
-	Version                  string
-	SocketGroup              string
-	TLSConfig                *tls.Config
+	Logging     bool
+	EnableCors  bool
+	CorsHeaders string
+	Version     string
+	SocketGroup string
+	TLSConfig   *tls.Config
 }
 
 // Server contains instance details for the server
@@ -34,8 +33,8 @@ type Server struct {
 	cfg           *Config
 	servers       []*HTTPServer
 	routers       []router.Router
-	authZPlugins  []authorization.Plugin
 	routerSwapper *routerSwapper
+	middlewares   []middleware.Middleware
 }
 
 // New returns a new instance of the server based on the specified configuration.
@@ -46,6 +45,12 @@ func New(cfg *Config) *Server {
 	}
 }
 
+// UseMiddleware appends a new middleware to the request chain.
+// This needs to be called before the API routes are configured.
+func (s *Server) UseMiddleware(m middleware.Middleware) {
+	s.middlewares = append(s.middlewares, m)
+}
+
 // Accept sets a listener the server accepts connections into.
 func (s *Server) Accept(addr string, listeners ...net.Listener) {
 	for _, listener := range listeners {

+ 5 - 0
api/server/server_test.go

@@ -6,7 +6,10 @@ import (
 	"strings"
 	"testing"
 
+	"github.com/docker/docker/api"
 	"github.com/docker/docker/api/server/httputils"
+	"github.com/docker/docker/api/server/middleware"
+	"github.com/docker/docker/pkg/version"
 
 	"golang.org/x/net/context"
 )
@@ -19,6 +22,8 @@ func TestMiddlewares(t *testing.T) {
 		cfg: cfg,
 	}
 
+	srv.UseMiddleware(middleware.NewVersionMiddleware(version.Version("0.1omega2"), api.DefaultVersion, api.MinVersion))
+
 	req, _ := http.NewRequest("GET", "/containers/json", nil)
 	resp := httptest.NewRecorder()
 	ctx := context.Background()

+ 29 - 4
docker/daemon.go

@@ -14,7 +14,9 @@ import (
 
 	"github.com/Sirupsen/logrus"
 	"github.com/docker/distribution/uuid"
+	"github.com/docker/docker/api"
 	apiserver "github.com/docker/docker/api/server"
+	"github.com/docker/docker/api/server/middleware"
 	"github.com/docker/docker/api/server/router"
 	"github.com/docker/docker/api/server/router/build"
 	"github.com/docker/docker/api/server/router/container"
@@ -29,12 +31,14 @@ import (
 	"github.com/docker/docker/dockerversion"
 	"github.com/docker/docker/libcontainerd"
 	"github.com/docker/docker/opts"
+	"github.com/docker/docker/pkg/authorization"
 	"github.com/docker/docker/pkg/jsonlog"
 	"github.com/docker/docker/pkg/listeners"
 	flag "github.com/docker/docker/pkg/mflag"
 	"github.com/docker/docker/pkg/pidfile"
 	"github.com/docker/docker/pkg/signal"
 	"github.com/docker/docker/pkg/system"
+	"github.com/docker/docker/pkg/version"
 	"github.com/docker/docker/registry"
 	"github.com/docker/docker/runconfig"
 	"github.com/docker/docker/utils"
@@ -208,10 +212,9 @@ func (cli *DaemonCli) CmdDaemon(args ...string) error {
 	}
 
 	serverConfig := &apiserver.Config{
-		AuthorizationPluginNames: cli.Config.AuthorizationPlugins,
-		Logging:                  true,
-		SocketGroup:              cli.Config.SocketGroup,
-		Version:                  dockerversion.Version,
+		Logging:     true,
+		SocketGroup: cli.Config.SocketGroup,
+		Version:     dockerversion.Version,
 	}
 	serverConfig = setPlatformServerConfig(serverConfig, cli.Config)
 
@@ -301,6 +304,7 @@ func (cli *DaemonCli) CmdDaemon(args ...string) error {
 		"graphdriver": d.GraphDriverName(),
 	}).Info("Docker daemon")
 
+	cli.initMiddlewares(api, serverConfig)
 	initRouter(api, d)
 
 	reload := func(config *daemon.Config) {
@@ -433,3 +437,24 @@ func initRouter(s *apiserver.Server, d *daemon.Daemon) {
 
 	s.InitRouter(utils.IsDebugEnabled(), routers...)
 }
+
+func (cli *DaemonCli) initMiddlewares(s *apiserver.Server, cfg *apiserver.Config) {
+	v := version.Version(cfg.Version)
+
+	vm := middleware.NewVersionMiddleware(v, api.DefaultVersion, api.MinVersion)
+	s.UseMiddleware(vm)
+
+	if cfg.EnableCors {
+		c := middleware.NewCORSMiddleware(cfg.CorsHeaders)
+		s.UseMiddleware(c)
+	}
+
+	u := middleware.NewUserAgentMiddleware(v)
+	s.UseMiddleware(u)
+
+	if len(cli.Config.AuthorizationPlugins) > 0 {
+		authZPlugins := authorization.NewPlugins(cli.Config.AuthorizationPlugins)
+		handleAuthorization := authorization.NewMiddleware(authZPlugins)
+		s.UseMiddleware(handleAuthorization)
+	}
+}

+ 60 - 0
pkg/authorization/middleware.go

@@ -0,0 +1,60 @@
+package authorization
+
+import (
+	"net/http"
+
+	"github.com/Sirupsen/logrus"
+	"golang.org/x/net/context"
+)
+
+// Middleware uses a list of plugins to
+// handle authorization in the API requests.
+type Middleware struct {
+	plugins []Plugin
+}
+
+// NewMiddleware creates a new Middleware
+// with a slice of plugins.
+func NewMiddleware(p []Plugin) Middleware {
+	return Middleware{
+		plugins: p,
+	}
+}
+
+// WrapHandler returns a new handler function wrapping the previous one in the request chain.
+func (m Middleware) WrapHandler(handler func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error) func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
+	return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
+
+		user := ""
+		userAuthNMethod := ""
+
+		// Default authorization using existing TLS connection credentials
+		// FIXME: Non trivial authorization mechanisms (such as advanced certificate validations, kerberos support
+		// and ldap) will be extracted using AuthN feature, which is tracked under:
+		// https://github.com/docker/docker/pull/20883
+		if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
+			user = r.TLS.PeerCertificates[0].Subject.CommonName
+			userAuthNMethod = "TLS"
+		}
+
+		authCtx := NewCtx(m.plugins, user, userAuthNMethod, r.Method, r.RequestURI)
+
+		if err := authCtx.AuthZRequest(w, r); err != nil {
+			logrus.Errorf("AuthZRequest for %s %s returned error: %s", r.Method, r.RequestURI, err)
+			return err
+		}
+
+		rw := NewResponseModifier(w)
+
+		if err := handler(ctx, rw, r, vars); err != nil {
+			logrus.Errorf("Handler for %s %s returned error: %s", r.Method, r.RequestURI, err)
+			return err
+		}
+
+		if err := authCtx.AuthZResponse(rw, r); err != nil {
+			logrus.Errorf("AuthZResponse for %s %s returned error: %s", r.Method, r.RequestURI, err)
+			return err
+		}
+		return nil
+	}
+}