瀏覽代碼

Always return version and server headers

If a 400 error is returned due to an API version mismatch, no
version and server-identification headers were returned by the API.

All information in these headers is "static", so there is no
reason to omit the information in case of an error being
returned.

This patch updates the version middleware to always
return the headers.

Signed-off-by: Sebastiaan van Stijn <github@gone.nl>
Sebastiaan van Stijn 7 年之前
父節點
當前提交
e9dac5ef5e
共有 2 個文件被更改,包括 34 次插入5 次删除
  1. 4 5
      api/server/middleware/version.go
  2. 30 0
      api/server/middleware/version_test.go

+ 4 - 5
api/server/middleware/version.go

@@ -43,6 +43,10 @@ func (e versionUnsupportedError) InvalidParameter() {}
 // WrapHandler returns a new handler function wrapping the previous one in the request chain.
 // WrapHandler returns a new handler function wrapping the previous one in the request chain.
 func (v VersionMiddleware) WrapHandler(handler func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error) func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
 func (v VersionMiddleware) WrapHandler(handler func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error) func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
 	return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
 	return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
+		w.Header().Set("Server", fmt.Sprintf("Docker/%s (%s)", v.serverVersion, runtime.GOOS))
+		w.Header().Set("API-Version", v.defaultVersion)
+		w.Header().Set("OSType", runtime.GOOS)
+
 		apiVersion := vars["version"]
 		apiVersion := vars["version"]
 		if apiVersion == "" {
 		if apiVersion == "" {
 			apiVersion = v.defaultVersion
 			apiVersion = v.defaultVersion
@@ -53,11 +57,6 @@ func (v VersionMiddleware) WrapHandler(handler func(ctx context.Context, w http.
 		if versions.GreaterThan(apiVersion, v.defaultVersion) {
 		if versions.GreaterThan(apiVersion, v.defaultVersion) {
 			return versionUnsupportedError{version: apiVersion, maxVersion: v.defaultVersion}
 			return versionUnsupportedError{version: apiVersion, maxVersion: v.defaultVersion}
 		}
 		}
-
-		header := fmt.Sprintf("Docker/%s (%s)", v.serverVersion, runtime.GOOS)
-		w.Header().Set("Server", header)
-		w.Header().Set("API-Version", v.defaultVersion)
-		w.Header().Set("OSType", runtime.GOOS)
 		// nolint: golint
 		// nolint: golint
 		ctx = context.WithValue(ctx, "api-version", apiVersion)
 		ctx = context.WithValue(ctx, "api-version", apiVersion)
 		return handler(ctx, w, r, vars)
 		return handler(ctx, w, r, vars)

+ 30 - 0
api/server/middleware/version_test.go

@@ -3,10 +3,12 @@ package middleware
 import (
 import (
 	"net/http"
 	"net/http"
 	"net/http/httptest"
 	"net/http/httptest"
+	"runtime"
 	"strings"
 	"strings"
 	"testing"
 	"testing"
 
 
 	"github.com/docker/docker/api/server/httputils"
 	"github.com/docker/docker/api/server/httputils"
+	"github.com/stretchr/testify/assert"
 	"golang.org/x/net/context"
 	"golang.org/x/net/context"
 )
 )
 
 
@@ -80,3 +82,31 @@ func TestVersionMiddlewareVersionTooNew(t *testing.T) {
 		t.Fatalf("Expected too new client error, got %v", err)
 		t.Fatalf("Expected too new client error, got %v", err)
 	}
 	}
 }
 }
+
+func TestVersionMiddlewareWithErrorsReturnsHeaders(t *testing.T) {
+	handler := func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
+		if httputils.VersionFromContext(ctx) == "" {
+			t.Fatal("Expected version, got empty string")
+		}
+		return nil
+	}
+
+	defaultVersion := "1.10.0"
+	minVersion := "1.2.0"
+	m := NewVersionMiddleware(defaultVersion, defaultVersion, minVersion)
+	h := m.WrapHandler(handler)
+
+	req, _ := http.NewRequest("GET", "/containers/json", nil)
+	resp := httptest.NewRecorder()
+	ctx := context.Background()
+
+	vars := map[string]string{"version": "0.1"}
+	err := h(ctx, resp, req, vars)
+
+	assert.Error(t, err)
+	hdr := resp.Result().Header
+	assert.Contains(t, hdr.Get("Server"), "Docker/"+defaultVersion)
+	assert.Contains(t, hdr.Get("Server"), runtime.GOOS)
+	assert.Equal(t, hdr.Get("API-Version"), defaultVersion)
+	assert.Equal(t, hdr.Get("OSType"), runtime.GOOS)
+}