Merge pull request #16331 from calavera/server_middlewares
Organize server pre-func logic in middlewares.
This commit is contained in:
commit
9dc0973655
5 changed files with 274 additions and 59 deletions
130
api/server/middleware.go
Normal file
130
api/server/middleware.go
Normal file
|
@ -0,0 +1,130 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/Sirupsen/logrus"
|
||||
"github.com/docker/docker/api"
|
||||
"github.com/docker/docker/autogen/dockerversion"
|
||||
"github.com/docker/docker/context"
|
||||
"github.com/docker/docker/errors"
|
||||
"github.com/docker/docker/pkg/stringid"
|
||||
"github.com/docker/docker/pkg/version"
|
||||
)
|
||||
|
||||
// middleware is an adapter to allow the use of ordinary functions as Docker API filters.
|
||||
// Any function that has the appropriate signature can be register as a middleware.
|
||||
type middleware func(handler HTTPAPIFunc) HTTPAPIFunc
|
||||
|
||||
// loggingMiddleware logs each request when logging is enabled.
|
||||
func (s *Server) loggingMiddleware(handler HTTPAPIFunc) HTTPAPIFunc {
|
||||
return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
|
||||
if s.cfg.Logging {
|
||||
logrus.Infof("%s %s", r.Method, r.RequestURI)
|
||||
}
|
||||
return handler(ctx, w, r, vars)
|
||||
}
|
||||
}
|
||||
|
||||
// requestIDMiddleware generates a uniq ID for each request.
|
||||
// This ID travels inside the context for tracing purposes.
|
||||
func requestIDMiddleware(handler HTTPAPIFunc) HTTPAPIFunc {
|
||||
return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
|
||||
reqID := stringid.TruncateID(stringid.GenerateNonCryptoID())
|
||||
ctx = context.WithValue(ctx, context.RequestID, reqID)
|
||||
return handler(ctx, w, r, vars)
|
||||
}
|
||||
}
|
||||
|
||||
// userAgentMiddleware checks the User-Agent header looking for a valid docker client spec.
|
||||
func (s *Server) userAgentMiddleware(handler HTTPAPIFunc) HTTPAPIFunc {
|
||||
return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
|
||||
if strings.Contains(r.Header.Get("User-Agent"), "Docker-Client/") {
|
||||
dockerVersion := version.Version(s.cfg.Version)
|
||||
|
||||
userAgent := strings.Split(r.Header.Get("User-Agent"), "/")
|
||||
|
||||
// v1.20 onwards includes the GOOS of the client after the version
|
||||
// such as Docker/1.7.0 (linux)
|
||||
if len(userAgent) == 2 && strings.Contains(userAgent[1], " ") {
|
||||
userAgent[1] = strings.Split(userAgent[1], " ")[0]
|
||||
}
|
||||
|
||||
if len(userAgent) == 2 && !dockerVersion.Equal(version.Version(userAgent[1])) {
|
||||
logrus.Debugf("Warning: client and server don't have the same version (client: %s, server: %s)", userAgent[1], dockerVersion)
|
||||
}
|
||||
}
|
||||
return handler(ctx, w, r, vars)
|
||||
}
|
||||
}
|
||||
|
||||
// corsMiddleware sets the CORS header expectations in the server.
|
||||
func (s *Server) corsMiddleware(handler HTTPAPIFunc) HTTPAPIFunc {
|
||||
return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
|
||||
// If "api-cors-header" is not given, but "api-enable-cors" is true, we set cors to "*"
|
||||
// otherwise, all head values will be passed to HTTP handler
|
||||
corsHeaders := s.cfg.CorsHeaders
|
||||
if corsHeaders == "" && s.cfg.EnableCors {
|
||||
corsHeaders = "*"
|
||||
}
|
||||
|
||||
if corsHeaders != "" {
|
||||
writeCorsHeaders(w, r, corsHeaders)
|
||||
}
|
||||
return handler(ctx, w, r, vars)
|
||||
}
|
||||
}
|
||||
|
||||
// versionMiddleware checks the api version requirements before passing the request to the server handler.
|
||||
func versionMiddleware(handler HTTPAPIFunc) HTTPAPIFunc {
|
||||
return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
|
||||
apiVersion := version.Version(vars["version"])
|
||||
if apiVersion == "" {
|
||||
apiVersion = api.Version
|
||||
}
|
||||
|
||||
if apiVersion.GreaterThan(api.Version) {
|
||||
return errors.ErrorCodeNewerClientVersion.WithArgs(apiVersion, api.Version)
|
||||
}
|
||||
if apiVersion.LessThan(api.MinVersion) {
|
||||
return errors.ErrorCodeOldClientVersion.WithArgs(apiVersion, api.Version)
|
||||
}
|
||||
|
||||
w.Header().Set("Server", "Docker/"+dockerversion.VERSION+" ("+runtime.GOOS+")")
|
||||
ctx = context.WithValue(ctx, context.APIVersion, apiVersion)
|
||||
return handler(ctx, w, r, vars)
|
||||
}
|
||||
}
|
||||
|
||||
// handleWithGlobalMiddlwares wraps the handler function for a request with
|
||||
// the server's global middlewares. The order of the middlewares is backwards,
|
||||
// meaning that the first in the list will be evaludated last.
|
||||
//
|
||||
// Example: handleWithGlobalMiddlewares(s.getContainersName)
|
||||
//
|
||||
// requestIDMiddlware(
|
||||
// s.loggingMiddleware(
|
||||
// s.userAgentMiddleware(
|
||||
// s.corsMiddleware(
|
||||
// versionMiddleware(s.getContainersName)
|
||||
// )
|
||||
// )
|
||||
// )
|
||||
// )
|
||||
func (s *Server) handleWithGlobalMiddlewares(handler HTTPAPIFunc) HTTPAPIFunc {
|
||||
middlewares := []middleware{
|
||||
versionMiddleware,
|
||||
s.corsMiddleware,
|
||||
s.userAgentMiddleware,
|
||||
s.loggingMiddleware,
|
||||
requestIDMiddleware,
|
||||
}
|
||||
|
||||
h := handler
|
||||
for _, m := range middlewares {
|
||||
h = m(h)
|
||||
}
|
||||
return h
|
||||
}
|
74
api/server/middleware_test.go
Normal file
74
api/server/middleware_test.go
Normal file
|
@ -0,0 +1,74 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/docker/distribution/registry/api/errcode"
|
||||
"github.com/docker/docker/context"
|
||||
"github.com/docker/docker/errors"
|
||||
)
|
||||
|
||||
func TestVersionMiddleware(t *testing.T) {
|
||||
handler := func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
|
||||
if ctx.Version() == "" {
|
||||
t.Fatalf("Expected version, got empty string")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
h := versionMiddleware(handler)
|
||||
|
||||
req, _ := http.NewRequest("GET", "/containers/json", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
ctx := context.Background()
|
||||
if err := h(ctx, resp, req, map[string]string{}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVersionMiddlewareWithErrors(t *testing.T) {
|
||||
handler := func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
|
||||
if ctx.Version() == "" {
|
||||
t.Fatalf("Expected version, got empty string")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
h := versionMiddleware(handler)
|
||||
|
||||
req, _ := http.NewRequest("GET", "/containers/json", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
ctx := context.Background()
|
||||
|
||||
vars := map[string]string{"version": "0.1"}
|
||||
err := h(ctx, resp, req, vars)
|
||||
if derr, ok := err.(errcode.Error); !ok || derr.ErrorCode() != errors.ErrorCodeOldClientVersion {
|
||||
t.Fatalf("Expected ErrorCodeOldClientVersion, got %v", err)
|
||||
}
|
||||
|
||||
vars["version"] = "100000"
|
||||
err = h(ctx, resp, req, vars)
|
||||
if derr, ok := err.(errcode.Error); !ok || derr.ErrorCode() != errors.ErrorCodeNewerClientVersion {
|
||||
t.Fatalf("Expected ErrorCodeNewerClientVersion, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestIDMiddleware(t *testing.T) {
|
||||
handler := func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
|
||||
if ctx.RequestID() == "" {
|
||||
t.Fatalf("Expected request-id, got empty string")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
h := requestIDMiddleware(handler)
|
||||
|
||||
req, _ := http.NewRequest("GET", "/containers/json", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
ctx := context.Background()
|
||||
if err := h(ctx, resp, req, map[string]string{}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
|
@ -8,7 +8,6 @@ import (
|
|||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
@ -16,12 +15,9 @@ import (
|
|||
"github.com/Sirupsen/logrus"
|
||||
"github.com/docker/distribution/registry/api/errcode"
|
||||
"github.com/docker/docker/api"
|
||||
"github.com/docker/docker/autogen/dockerversion"
|
||||
"github.com/docker/docker/context"
|
||||
"github.com/docker/docker/daemon"
|
||||
"github.com/docker/docker/pkg/sockets"
|
||||
"github.com/docker/docker/pkg/stringid"
|
||||
"github.com/docker/docker/pkg/version"
|
||||
)
|
||||
|
||||
// Config provides the configuration for the API server
|
||||
|
@ -49,8 +45,7 @@ func New(cfg *Config) *Server {
|
|||
cfg: cfg,
|
||||
start: make(chan struct{}),
|
||||
}
|
||||
r := createRouter(srv)
|
||||
srv.router = r
|
||||
srv.router = createRouter(srv)
|
||||
return srv
|
||||
}
|
||||
|
||||
|
@ -294,8 +289,11 @@ func (s *Server) initTCPSocket(addr string) (l net.Listener, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func makeHTTPHandler(logging bool, localMethod string, localRoute string, handlerFunc HTTPAPIFunc, corsHeaders string, dockerVersion version.Version) http.HandlerFunc {
|
||||
func (s *Server) makeHTTPHandler(localMethod string, localRoute string, localHandler HTTPAPIFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// log the handler generation
|
||||
logrus.Debugf("Calling %s %s", localMethod, localRoute)
|
||||
|
||||
// Define the context that we'll pass around to share info
|
||||
// like the docker-request-id.
|
||||
//
|
||||
|
@ -303,51 +301,8 @@ func makeHTTPHandler(logging bool, localMethod string, localRoute string, handle
|
|||
// apply to all requests. Data that is specific to the
|
||||
// immediate function being called should still be passed
|
||||
// as 'args' on the function call.
|
||||
|
||||
reqID := stringid.TruncateID(stringid.GenerateNonCryptoID())
|
||||
apiVersion := version.Version(mux.Vars(r)["version"])
|
||||
if apiVersion == "" {
|
||||
apiVersion = api.Version
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
ctx = context.WithValue(ctx, context.RequestID, reqID)
|
||||
ctx = context.WithValue(ctx, context.APIVersion, apiVersion)
|
||||
|
||||
// log the request
|
||||
logrus.Debugf("Calling %s %s", localMethod, localRoute)
|
||||
|
||||
if logging {
|
||||
logrus.Infof("%s %s", r.Method, r.RequestURI)
|
||||
}
|
||||
|
||||
if strings.Contains(r.Header.Get("User-Agent"), "Docker-Client/") {
|
||||
userAgent := strings.Split(r.Header.Get("User-Agent"), "/")
|
||||
|
||||
// v1.20 onwards includes the GOOS of the client after the version
|
||||
// such as Docker/1.7.0 (linux)
|
||||
if len(userAgent) == 2 && strings.Contains(userAgent[1], " ") {
|
||||
userAgent[1] = strings.Split(userAgent[1], " ")[0]
|
||||
}
|
||||
|
||||
if len(userAgent) == 2 && !dockerVersion.Equal(version.Version(userAgent[1])) {
|
||||
logrus.Debugf("Warning: client and server don't have the same version (client: %s, server: %s)", userAgent[1], dockerVersion)
|
||||
}
|
||||
}
|
||||
if corsHeaders != "" {
|
||||
writeCorsHeaders(w, r, corsHeaders)
|
||||
}
|
||||
|
||||
if apiVersion.GreaterThan(api.Version) {
|
||||
http.Error(w, fmt.Errorf("client is newer than server (client API version: %s, server API version: %s)", apiVersion, api.Version).Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if apiVersion.LessThan(api.MinVersion) {
|
||||
http.Error(w, fmt.Errorf("client is too old, minimum supported API version is %s, please upgrade your client to a newer version", api.MinVersion).Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Server", "Docker/"+dockerversion.VERSION+" ("+runtime.GOOS+")")
|
||||
handlerFunc := s.handleWithGlobalMiddlewares(localHandler)
|
||||
|
||||
if err := handlerFunc(ctx, w, r, mux.Vars(r)); err != nil {
|
||||
logrus.Errorf("Handler for %s %s returned error: %s", localMethod, localRoute, err)
|
||||
|
@ -356,6 +311,7 @@ func makeHTTPHandler(logging bool, localMethod string, localRoute string, handle
|
|||
}
|
||||
}
|
||||
|
||||
// createRouter initializes the main router the server uses.
|
||||
// we keep enableCors just for legacy usage, need to be removed in the future
|
||||
func createRouter(s *Server) *mux.Router {
|
||||
r := mux.NewRouter()
|
||||
|
@ -428,13 +384,6 @@ func createRouter(s *Server) *mux.Router {
|
|||
},
|
||||
}
|
||||
|
||||
// If "api-cors-header" is not given, but "api-enable-cors" is true, we set cors to "*"
|
||||
// otherwise, all head values will be passed to HTTP handler
|
||||
corsHeaders := s.cfg.CorsHeaders
|
||||
if corsHeaders == "" && s.cfg.EnableCors {
|
||||
corsHeaders = "*"
|
||||
}
|
||||
|
||||
for method, routes := range m {
|
||||
for route, fct := range routes {
|
||||
logrus.Debugf("Registering %s, %s", method, route)
|
||||
|
@ -444,7 +393,7 @@ func createRouter(s *Server) *mux.Router {
|
|||
localMethod := method
|
||||
|
||||
// build the handler function
|
||||
f := makeHTTPHandler(s.cfg.Logging, localMethod, localRoute, localFct, corsHeaders, version.Version(s.cfg.Version))
|
||||
f := s.makeHTTPHandler(localMethod, localRoute, localFct)
|
||||
|
||||
// add the new route
|
||||
if localRoute == "" {
|
||||
|
|
35
api/server/server_test.go
Normal file
35
api/server/server_test.go
Normal file
|
@ -0,0 +1,35 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/docker/docker/context"
|
||||
)
|
||||
|
||||
func TestMiddlewares(t *testing.T) {
|
||||
cfg := &Config{}
|
||||
srv := &Server{
|
||||
cfg: cfg,
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest("GET", "/containers/json", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
ctx := context.Background()
|
||||
|
||||
localHandler := func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
|
||||
if ctx.Version() == "" {
|
||||
t.Fatalf("Expected version, got empty string")
|
||||
}
|
||||
if ctx.RequestID() == "" {
|
||||
t.Fatalf("Expected request-id, got empty string")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
handlerFunc := srv.handleWithGlobalMiddlewares(localHandler)
|
||||
if err := handlerFunc(ctx, resp, req, map[string]string{}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
27
errors/server.go
Normal file
27
errors/server.go
Normal file
|
@ -0,0 +1,27 @@
|
|||
package errors
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/docker/distribution/registry/api/errcode"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrorCodeNewerClientVersion is generated when a request from a client
|
||||
// specifies a higher version than the server supports.
|
||||
ErrorCodeNewerClientVersion = errcode.Register(errGroup, errcode.ErrorDescriptor{
|
||||
Value: "NEWERCLIENTVERSION",
|
||||
Message: "client is newer than server (client API version: %s, server API version: %s)",
|
||||
Description: "The client version is higher than the server version",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
})
|
||||
|
||||
// ErrorCodeOldClientVersion is generated when a request from a client
|
||||
// specifies a version lower than the minimum version supported by the server.
|
||||
ErrorCodeOldClientVersion = errcode.Register(errGroup, errcode.ErrorDescriptor{
|
||||
Value: "OLDCLIENTVERSION",
|
||||
Message: "client version %s is too old. Minimum supported API version is %s, please upgrade your client to a newer version",
|
||||
Description: "The client version is too old for the server",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
})
|
||||
)
|
Loading…
Reference in a new issue