Quellcode durchsuchen

api/server/middleware: NewVersionMiddleware: add validation

Make sure the middleware cannot be initialized with out of range versions.

Signed-off-by: Sebastiaan van Stijn <github@gone.nl>
Sebastiaan van Stijn vor 1 Jahr
Ursprung
Commit
14503ccebd

+ 13 - 3
api/server/middleware/version.go

@@ -6,6 +6,7 @@ import (
 	"net/http"
 	"runtime"
 
+	"github.com/docker/docker/api"
 	"github.com/docker/docker/api/server/httputils"
 	"github.com/docker/docker/api/types/versions"
 )
@@ -32,12 +33,21 @@ type VersionMiddleware struct {
 }
 
 // NewVersionMiddleware creates a VersionMiddleware with the given versions.
-func NewVersionMiddleware(serverVersion, defaultAPIVersion, minAPIVersion string) VersionMiddleware {
-	return VersionMiddleware{
+func NewVersionMiddleware(serverVersion, defaultAPIVersion, minAPIVersion string) (*VersionMiddleware, error) {
+	if versions.LessThan(defaultAPIVersion, api.MinSupportedAPIVersion) || versions.GreaterThan(defaultAPIVersion, api.DefaultVersion) {
+		return nil, fmt.Errorf("invalid default API version (%s): must be between %s and %s", defaultAPIVersion, api.MinSupportedAPIVersion, api.DefaultVersion)
+	}
+	if versions.LessThan(minAPIVersion, api.MinSupportedAPIVersion) || versions.GreaterThan(minAPIVersion, api.DefaultVersion) {
+		return nil, fmt.Errorf("invalid minimum API version (%s): must be between %s and %s", minAPIVersion, api.MinSupportedAPIVersion, api.DefaultVersion)
+	}
+	if versions.GreaterThan(minAPIVersion, defaultAPIVersion) {
+		return nil, fmt.Errorf("invalid API version: the minimum API version (%s) is higher than the default version (%s)", minAPIVersion, defaultAPIVersion)
+	}
+	return &VersionMiddleware{
 		serverVersion:     serverVersion,
 		defaultAPIVersion: defaultAPIVersion,
 		minAPIVersion:     minAPIVersion,
-	}
+	}, nil
 }
 
 type versionUnsupportedError struct {

+ 59 - 3
api/server/middleware/version_test.go

@@ -14,6 +14,60 @@ import (
 	is "gotest.tools/v3/assert/cmp"
 )
 
+func TestNewVersionMiddlewareValidation(t *testing.T) {
+	tests := []struct {
+		doc, defaultVersion, minVersion, expectedErr string
+	}{
+		{
+			doc:            "defaults",
+			defaultVersion: api.DefaultVersion,
+			minVersion:     api.MinSupportedAPIVersion,
+		},
+		{
+			doc:            "invalid default lower than min",
+			defaultVersion: api.MinSupportedAPIVersion,
+			minVersion:     api.DefaultVersion,
+			expectedErr:    fmt.Sprintf("invalid API version: the minimum API version (%s) is higher than the default version (%s)", api.DefaultVersion, api.MinSupportedAPIVersion),
+		},
+		{
+			doc:            "invalid default too low",
+			defaultVersion: "0.1",
+			minVersion:     api.MinSupportedAPIVersion,
+			expectedErr:    fmt.Sprintf("invalid default API version (0.1): must be between %s and %s", api.MinSupportedAPIVersion, api.DefaultVersion),
+		},
+		{
+			doc:            "invalid default too high",
+			defaultVersion: "9999.9999",
+			minVersion:     api.DefaultVersion,
+			expectedErr:    fmt.Sprintf("invalid default API version (9999.9999): must be between %s and %s", api.MinSupportedAPIVersion, api.DefaultVersion),
+		},
+		{
+			doc:            "invalid minimum too low",
+			defaultVersion: api.MinSupportedAPIVersion,
+			minVersion:     "0.1",
+			expectedErr:    fmt.Sprintf("invalid minimum API version (0.1): must be between %s and %s", api.MinSupportedAPIVersion, api.DefaultVersion),
+		},
+		{
+			doc:            "invalid minimum too high",
+			defaultVersion: api.DefaultVersion,
+			minVersion:     "9999.9999",
+			expectedErr:    fmt.Sprintf("invalid minimum API version (9999.9999): must be between %s and %s", api.MinSupportedAPIVersion, api.DefaultVersion),
+		},
+	}
+
+	for _, tc := range tests {
+		tc := tc
+		t.Run(tc.doc, func(t *testing.T) {
+			_, err := NewVersionMiddleware("1.2.3", tc.defaultVersion, tc.minVersion)
+			if tc.expectedErr == "" {
+				assert.Check(t, err)
+			} else {
+				assert.Check(t, is.Error(err, tc.expectedErr))
+			}
+		})
+	}
+}
+
 func TestVersionMiddlewareVersion(t *testing.T) {
 	expectedVersion := "<not set>"
 	handler := func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
@@ -22,7 +76,8 @@ func TestVersionMiddlewareVersion(t *testing.T) {
 		return nil
 	}
 
-	m := NewVersionMiddleware("1.2.3", api.DefaultVersion, api.MinSupportedAPIVersion)
+	m, err := NewVersionMiddleware("1.2.3", api.DefaultVersion, api.MinSupportedAPIVersion)
+	assert.NilError(t, err)
 	h := m.WrapHandler(handler)
 
 	req, _ := http.NewRequest(http.MethodGet, "/containers/json", nil)
@@ -71,7 +126,8 @@ func TestVersionMiddlewareWithErrorsReturnsHeaders(t *testing.T) {
 		return nil
 	}
 
-	m := NewVersionMiddleware("1.2.3", api.DefaultVersion, api.MinSupportedAPIVersion)
+	m, err := NewVersionMiddleware("1.2.3", api.DefaultVersion, api.MinSupportedAPIVersion)
+	assert.NilError(t, err)
 	h := m.WrapHandler(handler)
 
 	req, _ := http.NewRequest(http.MethodGet, "/containers/json", nil)
@@ -79,7 +135,7 @@ func TestVersionMiddlewareWithErrorsReturnsHeaders(t *testing.T) {
 	ctx := context.Background()
 
 	vars := map[string]string{"version": "0.1"}
-	err := h(ctx, resp, req, vars)
+	err = h(ctx, resp, req, vars)
 	assert.Check(t, is.ErrorContains(err, ""))
 
 	hdr := resp.Result().Header

+ 5 - 2
api/server/server_test.go

@@ -15,8 +15,11 @@ import (
 func TestMiddlewares(t *testing.T) {
 	srv := &Server{}
 
-	const apiMinVersion = "1.12"
-	srv.UseMiddleware(middleware.NewVersionMiddleware("0.1omega2", api.DefaultVersion, apiMinVersion))
+	m, err := middleware.NewVersionMiddleware("0.1omega2", api.DefaultVersion, api.MinSupportedAPIVersion)
+	if err != nil {
+		t.Fatal(err)
+	}
+	srv.UseMiddleware(*m)
 
 	req, _ := http.NewRequest(http.MethodGet, "/containers/json", nil)
 	resp := httptest.NewRecorder()

+ 11 - 7
cmd/dockerd/daemon.go

@@ -256,7 +256,10 @@ func (cli *DaemonCli) start(opts *daemonOptions) (err error) {
 	pluginStore := plugin.NewStore()
 
 	var apiServer apiserver.Server
-	cli.authzMiddleware = initMiddlewares(&apiServer, cli.Config, pluginStore)
+	cli.authzMiddleware, err = initMiddlewares(&apiServer, cli.Config, pluginStore)
+	if err != nil {
+		return errors.Wrap(err, "failed to start API server")
+	}
 
 	d, err := daemon.NewDaemon(ctx, cli.Config, pluginStore, cli.authzMiddleware)
 	if err != nil {
@@ -708,14 +711,15 @@ func (opts routerOptions) Build() []router.Router {
 	return routers
 }
 
-func initMiddlewares(s *apiserver.Server, cfg *config.Config, pluginStore plugingetter.PluginGetter) *authorization.Middleware {
-	v := dockerversion.Version
-
+func initMiddlewares(s *apiserver.Server, cfg *config.Config, pluginStore plugingetter.PluginGetter) (*authorization.Middleware, error) {
 	exp := middleware.NewExperimentalMiddleware(cfg.Experimental)
 	s.UseMiddleware(exp)
 
-	vm := middleware.NewVersionMiddleware(v, api.DefaultVersion, cfg.MinAPIVersion)
-	s.UseMiddleware(vm)
+	vm, err := middleware.NewVersionMiddleware(dockerversion.Version, api.DefaultVersion, cfg.MinAPIVersion)
+	if err != nil {
+		return nil, err
+	}
+	s.UseMiddleware(*vm)
 
 	if cfg.CorsHeaders != "" {
 		c := middleware.NewCORSMiddleware(cfg.CorsHeaders)
@@ -724,7 +728,7 @@ func initMiddlewares(s *apiserver.Server, cfg *config.Config, pluginStore plugin
 
 	authzMiddleware := authorization.NewMiddleware(cfg.AuthorizationPlugins, pluginStore)
 	s.UseMiddleware(authzMiddleware)
-	return authzMiddleware
+	return authzMiddleware, nil
 }
 
 func (cli *DaemonCli) getContainerdDaemonOpts() ([]supervisor.DaemonOpt, error) {