diff --git a/api/server/middleware/version.go b/api/server/middleware/version.go index 18689fdece..6bd181ffeb 100644 --- a/api/server/middleware/version.go +++ b/api/server/middleware/version.go @@ -6,6 +6,7 @@ import ( "net/http" "runtime" + "github.com/docker/docker/api" "github.com/docker/docker/api/server/httputils" "github.com/docker/docker/api/types/versions" ) @@ -32,12 +33,21 @@ type VersionMiddleware struct { } // NewVersionMiddleware creates a VersionMiddleware with the given versions. -func NewVersionMiddleware(serverVersion, defaultAPIVersion, minAPIVersion string) VersionMiddleware { - return VersionMiddleware{ +func NewVersionMiddleware(serverVersion, defaultAPIVersion, minAPIVersion string) (*VersionMiddleware, error) { + if versions.LessThan(defaultAPIVersion, api.MinSupportedAPIVersion) || versions.GreaterThan(defaultAPIVersion, api.DefaultVersion) { + return nil, fmt.Errorf("invalid default API version (%s): must be between %s and %s", defaultAPIVersion, api.MinSupportedAPIVersion, api.DefaultVersion) + } + if versions.LessThan(minAPIVersion, api.MinSupportedAPIVersion) || versions.GreaterThan(minAPIVersion, api.DefaultVersion) { + return nil, fmt.Errorf("invalid minimum API version (%s): must be between %s and %s", minAPIVersion, api.MinSupportedAPIVersion, api.DefaultVersion) + } + if versions.GreaterThan(minAPIVersion, defaultAPIVersion) { + return nil, fmt.Errorf("invalid API version: the minimum API version (%s) is higher than the default version (%s)", minAPIVersion, defaultAPIVersion) + } + return &VersionMiddleware{ serverVersion: serverVersion, defaultAPIVersion: defaultAPIVersion, minAPIVersion: minAPIVersion, - } + }, nil } type versionUnsupportedError struct { diff --git a/api/server/middleware/version_test.go b/api/server/middleware/version_test.go index 74c358b359..1c7888bd95 100644 --- a/api/server/middleware/version_test.go +++ b/api/server/middleware/version_test.go @@ -14,6 +14,60 @@ import ( is "gotest.tools/v3/assert/cmp" ) +func TestNewVersionMiddlewareValidation(t *testing.T) { + tests := []struct { + doc, defaultVersion, minVersion, expectedErr string + }{ + { + doc: "defaults", + defaultVersion: api.DefaultVersion, + minVersion: api.MinSupportedAPIVersion, + }, + { + doc: "invalid default lower than min", + defaultVersion: api.MinSupportedAPIVersion, + minVersion: api.DefaultVersion, + expectedErr: fmt.Sprintf("invalid API version: the minimum API version (%s) is higher than the default version (%s)", api.DefaultVersion, api.MinSupportedAPIVersion), + }, + { + doc: "invalid default too low", + defaultVersion: "0.1", + minVersion: api.MinSupportedAPIVersion, + expectedErr: fmt.Sprintf("invalid default API version (0.1): must be between %s and %s", api.MinSupportedAPIVersion, api.DefaultVersion), + }, + { + doc: "invalid default too high", + defaultVersion: "9999.9999", + minVersion: api.DefaultVersion, + expectedErr: fmt.Sprintf("invalid default API version (9999.9999): must be between %s and %s", api.MinSupportedAPIVersion, api.DefaultVersion), + }, + { + doc: "invalid minimum too low", + defaultVersion: api.MinSupportedAPIVersion, + minVersion: "0.1", + expectedErr: fmt.Sprintf("invalid minimum API version (0.1): must be between %s and %s", api.MinSupportedAPIVersion, api.DefaultVersion), + }, + { + doc: "invalid minimum too high", + defaultVersion: api.DefaultVersion, + minVersion: "9999.9999", + expectedErr: fmt.Sprintf("invalid minimum API version (9999.9999): must be between %s and %s", api.MinSupportedAPIVersion, api.DefaultVersion), + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.doc, func(t *testing.T) { + _, err := NewVersionMiddleware("1.2.3", tc.defaultVersion, tc.minVersion) + if tc.expectedErr == "" { + assert.Check(t, err) + } else { + assert.Check(t, is.Error(err, tc.expectedErr)) + } + }) + } +} + func TestVersionMiddlewareVersion(t *testing.T) { expectedVersion := "" handler := func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error { @@ -22,7 +76,8 @@ func TestVersionMiddlewareVersion(t *testing.T) { return nil } - m := NewVersionMiddleware("1.2.3", api.DefaultVersion, api.MinSupportedAPIVersion) + m, err := NewVersionMiddleware("1.2.3", api.DefaultVersion, api.MinSupportedAPIVersion) + assert.NilError(t, err) h := m.WrapHandler(handler) req, _ := http.NewRequest(http.MethodGet, "/containers/json", nil) @@ -71,7 +126,8 @@ func TestVersionMiddlewareWithErrorsReturnsHeaders(t *testing.T) { return nil } - m := NewVersionMiddleware("1.2.3", api.DefaultVersion, api.MinSupportedAPIVersion) + m, err := NewVersionMiddleware("1.2.3", api.DefaultVersion, api.MinSupportedAPIVersion) + assert.NilError(t, err) h := m.WrapHandler(handler) req, _ := http.NewRequest(http.MethodGet, "/containers/json", nil) @@ -79,7 +135,7 @@ func TestVersionMiddlewareWithErrorsReturnsHeaders(t *testing.T) { ctx := context.Background() vars := map[string]string{"version": "0.1"} - err := h(ctx, resp, req, vars) + err = h(ctx, resp, req, vars) assert.Check(t, is.ErrorContains(err, "")) hdr := resp.Result().Header diff --git a/api/server/server_test.go b/api/server/server_test.go index 65dda95c9c..cf8b0568aa 100644 --- a/api/server/server_test.go +++ b/api/server/server_test.go @@ -15,8 +15,11 @@ import ( func TestMiddlewares(t *testing.T) { srv := &Server{} - const apiMinVersion = "1.12" - srv.UseMiddleware(middleware.NewVersionMiddleware("0.1omega2", api.DefaultVersion, apiMinVersion)) + m, err := middleware.NewVersionMiddleware("0.1omega2", api.DefaultVersion, api.MinSupportedAPIVersion) + if err != nil { + t.Fatal(err) + } + srv.UseMiddleware(*m) req, _ := http.NewRequest(http.MethodGet, "/containers/json", nil) resp := httptest.NewRecorder() diff --git a/cmd/dockerd/daemon.go b/cmd/dockerd/daemon.go index 9a8ffe0b0c..99c5aabbf8 100644 --- a/cmd/dockerd/daemon.go +++ b/cmd/dockerd/daemon.go @@ -256,7 +256,10 @@ func (cli *DaemonCli) start(opts *daemonOptions) (err error) { pluginStore := plugin.NewStore() var apiServer apiserver.Server - cli.authzMiddleware = initMiddlewares(&apiServer, cli.Config, pluginStore) + cli.authzMiddleware, err = initMiddlewares(&apiServer, cli.Config, pluginStore) + if err != nil { + return errors.Wrap(err, "failed to start API server") + } d, err := daemon.NewDaemon(ctx, cli.Config, pluginStore, cli.authzMiddleware) if err != nil { @@ -708,14 +711,15 @@ func (opts routerOptions) Build() []router.Router { return routers } -func initMiddlewares(s *apiserver.Server, cfg *config.Config, pluginStore plugingetter.PluginGetter) *authorization.Middleware { - v := dockerversion.Version - +func initMiddlewares(s *apiserver.Server, cfg *config.Config, pluginStore plugingetter.PluginGetter) (*authorization.Middleware, error) { exp := middleware.NewExperimentalMiddleware(cfg.Experimental) s.UseMiddleware(exp) - vm := middleware.NewVersionMiddleware(v, api.DefaultVersion, cfg.MinAPIVersion) - s.UseMiddleware(vm) + vm, err := middleware.NewVersionMiddleware(dockerversion.Version, api.DefaultVersion, cfg.MinAPIVersion) + if err != nil { + return nil, err + } + s.UseMiddleware(*vm) if cfg.CorsHeaders != "" { c := middleware.NewCORSMiddleware(cfg.CorsHeaders) @@ -724,7 +728,7 @@ func initMiddlewares(s *apiserver.Server, cfg *config.Config, pluginStore plugin authzMiddleware := authorization.NewMiddleware(cfg.AuthorizationPlugins, pluginStore) s.UseMiddleware(authzMiddleware) - return authzMiddleware + return authzMiddleware, nil } func (cli *DaemonCli) getContainerdDaemonOpts() ([]supervisor.DaemonOpt, error) {