Browse Source

Merge pull request #16331 from calavera/server_middlewares

Organize server pre-func logic in middlewares.
Phil Estes 9 years ago
parent
commit
9dc0973655
5 changed files with 274 additions and 59 deletions
  1. 130 0
      api/server/middleware.go
  2. 74 0
      api/server/middleware_test.go
  3. 8 59
      api/server/server.go
  4. 35 0
      api/server/server_test.go
  5. 27 0
      errors/server.go

+ 130 - 0
api/server/middleware.go

@@ -0,0 +1,130 @@
+package server
+
+import (
+	"net/http"
+	"runtime"
+	"strings"
+
+	"github.com/Sirupsen/logrus"
+	"github.com/docker/docker/api"
+	"github.com/docker/docker/autogen/dockerversion"
+	"github.com/docker/docker/context"
+	"github.com/docker/docker/errors"
+	"github.com/docker/docker/pkg/stringid"
+	"github.com/docker/docker/pkg/version"
+)
+
+// middleware is an adapter to allow the use of ordinary functions as Docker API filters.
+// Any function that has the appropriate signature can be register as a middleware.
+type middleware func(handler HTTPAPIFunc) HTTPAPIFunc
+
+// loggingMiddleware logs each request when logging is enabled.
+func (s *Server) loggingMiddleware(handler HTTPAPIFunc) HTTPAPIFunc {
+	return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
+		if s.cfg.Logging {
+			logrus.Infof("%s %s", r.Method, r.RequestURI)
+		}
+		return handler(ctx, w, r, vars)
+	}
+}
+
+// requestIDMiddleware generates a uniq ID for each request.
+// This ID travels inside the context for tracing purposes.
+func requestIDMiddleware(handler HTTPAPIFunc) HTTPAPIFunc {
+	return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
+		reqID := stringid.TruncateID(stringid.GenerateNonCryptoID())
+		ctx = context.WithValue(ctx, context.RequestID, reqID)
+		return handler(ctx, w, r, vars)
+	}
+}
+
+// userAgentMiddleware checks the User-Agent header looking for a valid docker client spec.
+func (s *Server) userAgentMiddleware(handler HTTPAPIFunc) HTTPAPIFunc {
+	return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
+		if strings.Contains(r.Header.Get("User-Agent"), "Docker-Client/") {
+			dockerVersion := version.Version(s.cfg.Version)
+
+			userAgent := strings.Split(r.Header.Get("User-Agent"), "/")
+
+			// v1.20 onwards includes the GOOS of the client after the version
+			// such as Docker/1.7.0 (linux)
+			if len(userAgent) == 2 && strings.Contains(userAgent[1], " ") {
+				userAgent[1] = strings.Split(userAgent[1], " ")[0]
+			}
+
+			if len(userAgent) == 2 && !dockerVersion.Equal(version.Version(userAgent[1])) {
+				logrus.Debugf("Warning: client and server don't have the same version (client: %s, server: %s)", userAgent[1], dockerVersion)
+			}
+		}
+		return handler(ctx, w, r, vars)
+	}
+}
+
+// corsMiddleware sets the CORS header expectations in the server.
+func (s *Server) corsMiddleware(handler HTTPAPIFunc) HTTPAPIFunc {
+	return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
+		// If "api-cors-header" is not given, but "api-enable-cors" is true, we set cors to "*"
+		// otherwise, all head values will be passed to HTTP handler
+		corsHeaders := s.cfg.CorsHeaders
+		if corsHeaders == "" && s.cfg.EnableCors {
+			corsHeaders = "*"
+		}
+
+		if corsHeaders != "" {
+			writeCorsHeaders(w, r, corsHeaders)
+		}
+		return handler(ctx, w, r, vars)
+	}
+}
+
+// versionMiddleware checks the api version requirements before passing the request to the server handler.
+func versionMiddleware(handler HTTPAPIFunc) HTTPAPIFunc {
+	return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
+		apiVersion := version.Version(vars["version"])
+		if apiVersion == "" {
+			apiVersion = api.Version
+		}
+
+		if apiVersion.GreaterThan(api.Version) {
+			return errors.ErrorCodeNewerClientVersion.WithArgs(apiVersion, api.Version)
+		}
+		if apiVersion.LessThan(api.MinVersion) {
+			return errors.ErrorCodeOldClientVersion.WithArgs(apiVersion, api.Version)
+		}
+
+		w.Header().Set("Server", "Docker/"+dockerversion.VERSION+" ("+runtime.GOOS+")")
+		ctx = context.WithValue(ctx, context.APIVersion, apiVersion)
+		return handler(ctx, w, r, vars)
+	}
+}
+
+// handleWithGlobalMiddlwares wraps the handler function for a request with
+// the server's global middlewares. The order of the middlewares is backwards,
+// meaning that the first in the list will be evaludated last.
+//
+// Example: handleWithGlobalMiddlewares(s.getContainersName)
+//
+// requestIDMiddlware(
+//	s.loggingMiddleware(
+//		s.userAgentMiddleware(
+//			s.corsMiddleware(
+//				versionMiddleware(s.getContainersName)
+//			)
+//		)
+//	)
+// )
+func (s *Server) handleWithGlobalMiddlewares(handler HTTPAPIFunc) HTTPAPIFunc {
+	middlewares := []middleware{
+		versionMiddleware,
+		s.corsMiddleware,
+		s.userAgentMiddleware,
+		s.loggingMiddleware,
+		requestIDMiddleware,
+	}
+
+	h := handler
+	for _, m := range middlewares {
+		h = m(h)
+	}
+	return h
+}

+ 74 - 0
api/server/middleware_test.go

@@ -0,0 +1,74 @@
+package server
+
+import (
+	"net/http"
+	"net/http/httptest"
+	"testing"
+
+	"github.com/docker/distribution/registry/api/errcode"
+	"github.com/docker/docker/context"
+	"github.com/docker/docker/errors"
+)
+
+func TestVersionMiddleware(t *testing.T) {
+	handler := func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
+		if ctx.Version() == "" {
+			t.Fatalf("Expected version, got empty string")
+		}
+		return nil
+	}
+
+	h := versionMiddleware(handler)
+
+	req, _ := http.NewRequest("GET", "/containers/json", nil)
+	resp := httptest.NewRecorder()
+	ctx := context.Background()
+	if err := h(ctx, resp, req, map[string]string{}); err != nil {
+		t.Fatal(err)
+	}
+}
+
+func TestVersionMiddlewareWithErrors(t *testing.T) {
+	handler := func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
+		if ctx.Version() == "" {
+			t.Fatalf("Expected version, got empty string")
+		}
+		return nil
+	}
+
+	h := versionMiddleware(handler)
+
+	req, _ := http.NewRequest("GET", "/containers/json", nil)
+	resp := httptest.NewRecorder()
+	ctx := context.Background()
+
+	vars := map[string]string{"version": "0.1"}
+	err := h(ctx, resp, req, vars)
+	if derr, ok := err.(errcode.Error); !ok || derr.ErrorCode() != errors.ErrorCodeOldClientVersion {
+		t.Fatalf("Expected ErrorCodeOldClientVersion, got %v", err)
+	}
+
+	vars["version"] = "100000"
+	err = h(ctx, resp, req, vars)
+	if derr, ok := err.(errcode.Error); !ok || derr.ErrorCode() != errors.ErrorCodeNewerClientVersion {
+		t.Fatalf("Expected ErrorCodeNewerClientVersion, got %v", err)
+	}
+}
+
+func TestRequestIDMiddleware(t *testing.T) {
+	handler := func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
+		if ctx.RequestID() == "" {
+			t.Fatalf("Expected request-id, got empty string")
+		}
+		return nil
+	}
+
+	h := requestIDMiddleware(handler)
+
+	req, _ := http.NewRequest("GET", "/containers/json", nil)
+	resp := httptest.NewRecorder()
+	ctx := context.Background()
+	if err := h(ctx, resp, req, map[string]string{}); err != nil {
+		t.Fatal(err)
+	}
+}

+ 8 - 59
api/server/server.go

@@ -8,7 +8,6 @@ import (
 	"net"
 	"net"
 	"net/http"
 	"net/http"
 	"os"
 	"os"
-	"runtime"
 	"strings"
 	"strings"
 
 
 	"github.com/gorilla/mux"
 	"github.com/gorilla/mux"
@@ -16,12 +15,9 @@ import (
 	"github.com/Sirupsen/logrus"
 	"github.com/Sirupsen/logrus"
 	"github.com/docker/distribution/registry/api/errcode"
 	"github.com/docker/distribution/registry/api/errcode"
 	"github.com/docker/docker/api"
 	"github.com/docker/docker/api"
-	"github.com/docker/docker/autogen/dockerversion"
 	"github.com/docker/docker/context"
 	"github.com/docker/docker/context"
 	"github.com/docker/docker/daemon"
 	"github.com/docker/docker/daemon"
 	"github.com/docker/docker/pkg/sockets"
 	"github.com/docker/docker/pkg/sockets"
-	"github.com/docker/docker/pkg/stringid"
-	"github.com/docker/docker/pkg/version"
 )
 )
 
 
 // Config provides the configuration for the API server
 // Config provides the configuration for the API server
@@ -49,8 +45,7 @@ func New(cfg *Config) *Server {
 		cfg:   cfg,
 		cfg:   cfg,
 		start: make(chan struct{}),
 		start: make(chan struct{}),
 	}
 	}
-	r := createRouter(srv)
-	srv.router = r
+	srv.router = createRouter(srv)
 	return srv
 	return srv
 }
 }
 
 
@@ -294,8 +289,11 @@ func (s *Server) initTCPSocket(addr string) (l net.Listener, err error) {
 	return
 	return
 }
 }
 
 
-func makeHTTPHandler(logging bool, localMethod string, localRoute string, handlerFunc HTTPAPIFunc, corsHeaders string, dockerVersion version.Version) http.HandlerFunc {
+func (s *Server) makeHTTPHandler(localMethod string, localRoute string, localHandler HTTPAPIFunc) http.HandlerFunc {
 	return func(w http.ResponseWriter, r *http.Request) {
 	return func(w http.ResponseWriter, r *http.Request) {
+		// log the handler generation
+		logrus.Debugf("Calling %s %s", localMethod, localRoute)
+
 		// Define the context that we'll pass around to share info
 		// Define the context that we'll pass around to share info
 		// like the docker-request-id.
 		// like the docker-request-id.
 		//
 		//
@@ -303,51 +301,8 @@ func makeHTTPHandler(logging bool, localMethod string, localRoute string, handle
 		// apply to all requests. Data that is specific to the
 		// apply to all requests. Data that is specific to the
 		// immediate function being called should still be passed
 		// immediate function being called should still be passed
 		// as 'args' on the function call.
 		// as 'args' on the function call.
-
-		reqID := stringid.TruncateID(stringid.GenerateNonCryptoID())
-		apiVersion := version.Version(mux.Vars(r)["version"])
-		if apiVersion == "" {
-			apiVersion = api.Version
-		}
-
 		ctx := context.Background()
 		ctx := context.Background()
-		ctx = context.WithValue(ctx, context.RequestID, reqID)
-		ctx = context.WithValue(ctx, context.APIVersion, apiVersion)
-
-		// log the request
-		logrus.Debugf("Calling %s %s", localMethod, localRoute)
-
-		if logging {
-			logrus.Infof("%s %s", r.Method, r.RequestURI)
-		}
-
-		if strings.Contains(r.Header.Get("User-Agent"), "Docker-Client/") {
-			userAgent := strings.Split(r.Header.Get("User-Agent"), "/")
-
-			// v1.20 onwards includes the GOOS of the client after the version
-			// such as Docker/1.7.0 (linux)
-			if len(userAgent) == 2 && strings.Contains(userAgent[1], " ") {
-				userAgent[1] = strings.Split(userAgent[1], " ")[0]
-			}
-
-			if len(userAgent) == 2 && !dockerVersion.Equal(version.Version(userAgent[1])) {
-				logrus.Debugf("Warning: client and server don't have the same version (client: %s, server: %s)", userAgent[1], dockerVersion)
-			}
-		}
-		if corsHeaders != "" {
-			writeCorsHeaders(w, r, corsHeaders)
-		}
-
-		if apiVersion.GreaterThan(api.Version) {
-			http.Error(w, fmt.Errorf("client is newer than server (client API version: %s, server API version: %s)", apiVersion, api.Version).Error(), http.StatusBadRequest)
-			return
-		}
-		if apiVersion.LessThan(api.MinVersion) {
-			http.Error(w, fmt.Errorf("client is too old, minimum supported API version is %s, please upgrade your client to a newer version", api.MinVersion).Error(), http.StatusBadRequest)
-			return
-		}
-
-		w.Header().Set("Server", "Docker/"+dockerversion.VERSION+" ("+runtime.GOOS+")")
+		handlerFunc := s.handleWithGlobalMiddlewares(localHandler)
 
 
 		if err := handlerFunc(ctx, w, r, mux.Vars(r)); err != nil {
 		if err := handlerFunc(ctx, w, r, mux.Vars(r)); err != nil {
 			logrus.Errorf("Handler for %s %s returned error: %s", localMethod, localRoute, err)
 			logrus.Errorf("Handler for %s %s returned error: %s", localMethod, localRoute, err)
@@ -356,6 +311,7 @@ func makeHTTPHandler(logging bool, localMethod string, localRoute string, handle
 	}
 	}
 }
 }
 
 
+// createRouter initializes the main router the server uses.
 // we keep enableCors just for legacy usage, need to be removed in the future
 // we keep enableCors just for legacy usage, need to be removed in the future
 func createRouter(s *Server) *mux.Router {
 func createRouter(s *Server) *mux.Router {
 	r := mux.NewRouter()
 	r := mux.NewRouter()
@@ -428,13 +384,6 @@ func createRouter(s *Server) *mux.Router {
 		},
 		},
 	}
 	}
 
 
-	// If "api-cors-header" is not given, but "api-enable-cors" is true, we set cors to "*"
-	// otherwise, all head values will be passed to HTTP handler
-	corsHeaders := s.cfg.CorsHeaders
-	if corsHeaders == "" && s.cfg.EnableCors {
-		corsHeaders = "*"
-	}
-
 	for method, routes := range m {
 	for method, routes := range m {
 		for route, fct := range routes {
 		for route, fct := range routes {
 			logrus.Debugf("Registering %s, %s", method, route)
 			logrus.Debugf("Registering %s, %s", method, route)
@@ -444,7 +393,7 @@ func createRouter(s *Server) *mux.Router {
 			localMethod := method
 			localMethod := method
 
 
 			// build the handler function
 			// build the handler function
-			f := makeHTTPHandler(s.cfg.Logging, localMethod, localRoute, localFct, corsHeaders, version.Version(s.cfg.Version))
+			f := s.makeHTTPHandler(localMethod, localRoute, localFct)
 
 
 			// add the new route
 			// add the new route
 			if localRoute == "" {
 			if localRoute == "" {

+ 35 - 0
api/server/server_test.go

@@ -0,0 +1,35 @@
+package server
+
+import (
+	"net/http"
+	"net/http/httptest"
+	"testing"
+
+	"github.com/docker/docker/context"
+)
+
+func TestMiddlewares(t *testing.T) {
+	cfg := &Config{}
+	srv := &Server{
+		cfg: cfg,
+	}
+
+	req, _ := http.NewRequest("GET", "/containers/json", nil)
+	resp := httptest.NewRecorder()
+	ctx := context.Background()
+
+	localHandler := func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
+		if ctx.Version() == "" {
+			t.Fatalf("Expected version, got empty string")
+		}
+		if ctx.RequestID() == "" {
+			t.Fatalf("Expected request-id, got empty string")
+		}
+		return nil
+	}
+
+	handlerFunc := srv.handleWithGlobalMiddlewares(localHandler)
+	if err := handlerFunc(ctx, resp, req, map[string]string{}); err != nil {
+		t.Fatal(err)
+	}
+}

+ 27 - 0
errors/server.go

@@ -0,0 +1,27 @@
+package errors
+
+import (
+	"net/http"
+
+	"github.com/docker/distribution/registry/api/errcode"
+)
+
+var (
+	// ErrorCodeNewerClientVersion is generated when a request from a client
+	// specifies a higher version than the server supports.
+	ErrorCodeNewerClientVersion = errcode.Register(errGroup, errcode.ErrorDescriptor{
+		Value:          "NEWERCLIENTVERSION",
+		Message:        "client is newer than server (client API version: %s, server API version: %s)",
+		Description:    "The client version is higher than the server version",
+		HTTPStatusCode: http.StatusBadRequest,
+	})
+
+	// ErrorCodeOldClientVersion is generated when a request from a client
+	// specifies a version lower than the minimum version supported by the server.
+	ErrorCodeOldClientVersion = errcode.Register(errGroup, errcode.ErrorDescriptor{
+		Value:          "OLDCLIENTVERSION",
+		Message:        "client version %s is too old. Minimum supported API version is %s, please upgrade your client to a newer version",
+		Description:    "The client version is too old for the server",
+		HTTPStatusCode: http.StatusBadRequest,
+	})
+)