From 0fea04d27ee91d7b57e0a77b110db1c861768c74 Mon Sep 17 00:00:00 2001 From: David Calavera Date: Tue, 15 Sep 2015 19:01:49 -0400 Subject: [PATCH] Organize server pre-func logic in middlewares. It defines global middlewares for every request. This makes the server slightly more composable. Signed-off-by: David Calavera --- api/server/middleware.go | 130 ++++++++++++++++++++++++++++++++++ api/server/middleware_test.go | 74 +++++++++++++++++++ api/server/server.go | 67 +++--------------- api/server/server_test.go | 35 +++++++++ errors/server.go | 27 +++++++ 5 files changed, 274 insertions(+), 59 deletions(-) create mode 100644 api/server/middleware.go create mode 100644 api/server/middleware_test.go create mode 100644 api/server/server_test.go create mode 100644 errors/server.go diff --git a/api/server/middleware.go b/api/server/middleware.go new file mode 100644 index 0000000000..10683f39aa --- /dev/null +++ b/api/server/middleware.go @@ -0,0 +1,130 @@ +package server + +import ( + "net/http" + "runtime" + "strings" + + "github.com/Sirupsen/logrus" + "github.com/docker/docker/api" + "github.com/docker/docker/autogen/dockerversion" + "github.com/docker/docker/context" + "github.com/docker/docker/errors" + "github.com/docker/docker/pkg/stringid" + "github.com/docker/docker/pkg/version" +) + +// middleware is an adapter to allow the use of ordinary functions as Docker API filters. +// Any function that has the appropriate signature can be register as a middleware. +type middleware func(handler HTTPAPIFunc) HTTPAPIFunc + +// loggingMiddleware logs each request when logging is enabled. +func (s *Server) loggingMiddleware(handler HTTPAPIFunc) HTTPAPIFunc { + return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error { + if s.cfg.Logging { + logrus.Infof("%s %s", r.Method, r.RequestURI) + } + return handler(ctx, w, r, vars) + } +} + +// requestIDMiddleware generates a uniq ID for each request. +// This ID travels inside the context for tracing purposes. +func requestIDMiddleware(handler HTTPAPIFunc) HTTPAPIFunc { + return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error { + reqID := stringid.TruncateID(stringid.GenerateNonCryptoID()) + ctx = context.WithValue(ctx, context.RequestID, reqID) + return handler(ctx, w, r, vars) + } +} + +// userAgentMiddleware checks the User-Agent header looking for a valid docker client spec. +func (s *Server) userAgentMiddleware(handler HTTPAPIFunc) HTTPAPIFunc { + return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error { + if strings.Contains(r.Header.Get("User-Agent"), "Docker-Client/") { + dockerVersion := version.Version(s.cfg.Version) + + userAgent := strings.Split(r.Header.Get("User-Agent"), "/") + + // v1.20 onwards includes the GOOS of the client after the version + // such as Docker/1.7.0 (linux) + if len(userAgent) == 2 && strings.Contains(userAgent[1], " ") { + userAgent[1] = strings.Split(userAgent[1], " ")[0] + } + + if len(userAgent) == 2 && !dockerVersion.Equal(version.Version(userAgent[1])) { + logrus.Debugf("Warning: client and server don't have the same version (client: %s, server: %s)", userAgent[1], dockerVersion) + } + } + return handler(ctx, w, r, vars) + } +} + +// corsMiddleware sets the CORS header expectations in the server. +func (s *Server) corsMiddleware(handler HTTPAPIFunc) HTTPAPIFunc { + return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error { + // If "api-cors-header" is not given, but "api-enable-cors" is true, we set cors to "*" + // otherwise, all head values will be passed to HTTP handler + corsHeaders := s.cfg.CorsHeaders + if corsHeaders == "" && s.cfg.EnableCors { + corsHeaders = "*" + } + + if corsHeaders != "" { + writeCorsHeaders(w, r, corsHeaders) + } + return handler(ctx, w, r, vars) + } +} + +// versionMiddleware checks the api version requirements before passing the request to the server handler. +func versionMiddleware(handler HTTPAPIFunc) HTTPAPIFunc { + return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error { + apiVersion := version.Version(vars["version"]) + if apiVersion == "" { + apiVersion = api.Version + } + + if apiVersion.GreaterThan(api.Version) { + return errors.ErrorCodeNewerClientVersion.WithArgs(apiVersion, api.Version) + } + if apiVersion.LessThan(api.MinVersion) { + return errors.ErrorCodeOldClientVersion.WithArgs(apiVersion, api.Version) + } + + w.Header().Set("Server", "Docker/"+dockerversion.VERSION+" ("+runtime.GOOS+")") + ctx = context.WithValue(ctx, context.APIVersion, apiVersion) + return handler(ctx, w, r, vars) + } +} + +// handleWithGlobalMiddlwares wraps the handler function for a request with +// the server's global middlewares. The order of the middlewares is backwards, +// meaning that the first in the list will be evaludated last. +// +// Example: handleWithGlobalMiddlewares(s.getContainersName) +// +// requestIDMiddlware( +// s.loggingMiddleware( +// s.userAgentMiddleware( +// s.corsMiddleware( +// versionMiddleware(s.getContainersName) +// ) +// ) +// ) +// ) +func (s *Server) handleWithGlobalMiddlewares(handler HTTPAPIFunc) HTTPAPIFunc { + middlewares := []middleware{ + versionMiddleware, + s.corsMiddleware, + s.userAgentMiddleware, + s.loggingMiddleware, + requestIDMiddleware, + } + + h := handler + for _, m := range middlewares { + h = m(h) + } + return h +} diff --git a/api/server/middleware_test.go b/api/server/middleware_test.go new file mode 100644 index 0000000000..eda4588b9e --- /dev/null +++ b/api/server/middleware_test.go @@ -0,0 +1,74 @@ +package server + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/docker/distribution/registry/api/errcode" + "github.com/docker/docker/context" + "github.com/docker/docker/errors" +) + +func TestVersionMiddleware(t *testing.T) { + handler := func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error { + if ctx.Version() == "" { + t.Fatalf("Expected version, got empty string") + } + return nil + } + + h := versionMiddleware(handler) + + req, _ := http.NewRequest("GET", "/containers/json", nil) + resp := httptest.NewRecorder() + ctx := context.Background() + if err := h(ctx, resp, req, map[string]string{}); err != nil { + t.Fatal(err) + } +} + +func TestVersionMiddlewareWithErrors(t *testing.T) { + handler := func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error { + if ctx.Version() == "" { + t.Fatalf("Expected version, got empty string") + } + return nil + } + + h := versionMiddleware(handler) + + req, _ := http.NewRequest("GET", "/containers/json", nil) + resp := httptest.NewRecorder() + ctx := context.Background() + + vars := map[string]string{"version": "0.1"} + err := h(ctx, resp, req, vars) + if derr, ok := err.(errcode.Error); !ok || derr.ErrorCode() != errors.ErrorCodeOldClientVersion { + t.Fatalf("Expected ErrorCodeOldClientVersion, got %v", err) + } + + vars["version"] = "100000" + err = h(ctx, resp, req, vars) + if derr, ok := err.(errcode.Error); !ok || derr.ErrorCode() != errors.ErrorCodeNewerClientVersion { + t.Fatalf("Expected ErrorCodeNewerClientVersion, got %v", err) + } +} + +func TestRequestIDMiddleware(t *testing.T) { + handler := func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error { + if ctx.RequestID() == "" { + t.Fatalf("Expected request-id, got empty string") + } + return nil + } + + h := requestIDMiddleware(handler) + + req, _ := http.NewRequest("GET", "/containers/json", nil) + resp := httptest.NewRecorder() + ctx := context.Background() + if err := h(ctx, resp, req, map[string]string{}); err != nil { + t.Fatal(err) + } +} diff --git a/api/server/server.go b/api/server/server.go index b47bc580ed..8795838d0f 100644 --- a/api/server/server.go +++ b/api/server/server.go @@ -8,7 +8,6 @@ import ( "net" "net/http" "os" - "runtime" "strings" "github.com/gorilla/mux" @@ -16,12 +15,9 @@ import ( "github.com/Sirupsen/logrus" "github.com/docker/distribution/registry/api/errcode" "github.com/docker/docker/api" - "github.com/docker/docker/autogen/dockerversion" "github.com/docker/docker/context" "github.com/docker/docker/daemon" "github.com/docker/docker/pkg/sockets" - "github.com/docker/docker/pkg/stringid" - "github.com/docker/docker/pkg/version" ) // Config provides the configuration for the API server @@ -49,8 +45,7 @@ func New(cfg *Config) *Server { cfg: cfg, start: make(chan struct{}), } - r := createRouter(srv) - srv.router = r + srv.router = createRouter(srv) return srv } @@ -294,8 +289,11 @@ func (s *Server) initTCPSocket(addr string) (l net.Listener, err error) { return } -func makeHTTPHandler(logging bool, localMethod string, localRoute string, handlerFunc HTTPAPIFunc, corsHeaders string, dockerVersion version.Version) http.HandlerFunc { +func (s *Server) makeHTTPHandler(localMethod string, localRoute string, localHandler HTTPAPIFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { + // log the handler generation + logrus.Debugf("Calling %s %s", localMethod, localRoute) + // Define the context that we'll pass around to share info // like the docker-request-id. // @@ -303,51 +301,8 @@ func makeHTTPHandler(logging bool, localMethod string, localRoute string, handle // apply to all requests. Data that is specific to the // immediate function being called should still be passed // as 'args' on the function call. - - reqID := stringid.TruncateID(stringid.GenerateNonCryptoID()) - apiVersion := version.Version(mux.Vars(r)["version"]) - if apiVersion == "" { - apiVersion = api.Version - } - ctx := context.Background() - ctx = context.WithValue(ctx, context.RequestID, reqID) - ctx = context.WithValue(ctx, context.APIVersion, apiVersion) - - // log the request - logrus.Debugf("Calling %s %s", localMethod, localRoute) - - if logging { - logrus.Infof("%s %s", r.Method, r.RequestURI) - } - - if strings.Contains(r.Header.Get("User-Agent"), "Docker-Client/") { - userAgent := strings.Split(r.Header.Get("User-Agent"), "/") - - // v1.20 onwards includes the GOOS of the client after the version - // such as Docker/1.7.0 (linux) - if len(userAgent) == 2 && strings.Contains(userAgent[1], " ") { - userAgent[1] = strings.Split(userAgent[1], " ")[0] - } - - if len(userAgent) == 2 && !dockerVersion.Equal(version.Version(userAgent[1])) { - logrus.Debugf("Warning: client and server don't have the same version (client: %s, server: %s)", userAgent[1], dockerVersion) - } - } - if corsHeaders != "" { - writeCorsHeaders(w, r, corsHeaders) - } - - if apiVersion.GreaterThan(api.Version) { - http.Error(w, fmt.Errorf("client is newer than server (client API version: %s, server API version: %s)", apiVersion, api.Version).Error(), http.StatusBadRequest) - return - } - if apiVersion.LessThan(api.MinVersion) { - http.Error(w, fmt.Errorf("client is too old, minimum supported API version is %s, please upgrade your client to a newer version", api.MinVersion).Error(), http.StatusBadRequest) - return - } - - w.Header().Set("Server", "Docker/"+dockerversion.VERSION+" ("+runtime.GOOS+")") + handlerFunc := s.handleWithGlobalMiddlewares(localHandler) if err := handlerFunc(ctx, w, r, mux.Vars(r)); err != nil { logrus.Errorf("Handler for %s %s returned error: %s", localMethod, localRoute, err) @@ -356,6 +311,7 @@ func makeHTTPHandler(logging bool, localMethod string, localRoute string, handle } } +// createRouter initializes the main router the server uses. // we keep enableCors just for legacy usage, need to be removed in the future func createRouter(s *Server) *mux.Router { r := mux.NewRouter() @@ -428,13 +384,6 @@ func createRouter(s *Server) *mux.Router { }, } - // If "api-cors-header" is not given, but "api-enable-cors" is true, we set cors to "*" - // otherwise, all head values will be passed to HTTP handler - corsHeaders := s.cfg.CorsHeaders - if corsHeaders == "" && s.cfg.EnableCors { - corsHeaders = "*" - } - for method, routes := range m { for route, fct := range routes { logrus.Debugf("Registering %s, %s", method, route) @@ -444,7 +393,7 @@ func createRouter(s *Server) *mux.Router { localMethod := method // build the handler function - f := makeHTTPHandler(s.cfg.Logging, localMethod, localRoute, localFct, corsHeaders, version.Version(s.cfg.Version)) + f := s.makeHTTPHandler(localMethod, localRoute, localFct) // add the new route if localRoute == "" { diff --git a/api/server/server_test.go b/api/server/server_test.go new file mode 100644 index 0000000000..3d1a613976 --- /dev/null +++ b/api/server/server_test.go @@ -0,0 +1,35 @@ +package server + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/docker/docker/context" +) + +func TestMiddlewares(t *testing.T) { + cfg := &Config{} + srv := &Server{ + cfg: cfg, + } + + req, _ := http.NewRequest("GET", "/containers/json", nil) + resp := httptest.NewRecorder() + ctx := context.Background() + + localHandler := func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error { + if ctx.Version() == "" { + t.Fatalf("Expected version, got empty string") + } + if ctx.RequestID() == "" { + t.Fatalf("Expected request-id, got empty string") + } + return nil + } + + handlerFunc := srv.handleWithGlobalMiddlewares(localHandler) + if err := handlerFunc(ctx, resp, req, map[string]string{}); err != nil { + t.Fatal(err) + } +} diff --git a/errors/server.go b/errors/server.go new file mode 100644 index 0000000000..9dfcc02b50 --- /dev/null +++ b/errors/server.go @@ -0,0 +1,27 @@ +package errors + +import ( + "net/http" + + "github.com/docker/distribution/registry/api/errcode" +) + +var ( + // ErrorCodeNewerClientVersion is generated when a request from a client + // specifies a higher version than the server supports. + ErrorCodeNewerClientVersion = errcode.Register(errGroup, errcode.ErrorDescriptor{ + Value: "NEWERCLIENTVERSION", + Message: "client is newer than server (client API version: %s, server API version: %s)", + Description: "The client version is higher than the server version", + HTTPStatusCode: http.StatusBadRequest, + }) + + // ErrorCodeOldClientVersion is generated when a request from a client + // specifies a version lower than the minimum version supported by the server. + ErrorCodeOldClientVersion = errcode.Register(errGroup, errcode.ErrorDescriptor{ + Value: "OLDCLIENTVERSION", + Message: "client version %s is too old. Minimum supported API version is %s, please upgrade your client to a newer version", + Description: "The client version is too old for the server", + HTTPStatusCode: http.StatusBadRequest, + }) +)