api/server/middleware: NewVersionMiddleware: add validation

Make sure the middleware cannot be initialized with out of range versions.

Signed-off-by: Sebastiaan van Stijn <github@gone.nl>
This commit is contained in:
Sebastiaan van Stijn 2024-01-22 18:38:59 +01:00
parent e1897cbde4
commit 14503ccebd
No known key found for this signature in database
GPG key ID: 76698F39D527CE8C
4 changed files with 88 additions and 15 deletions

View file

@ -6,6 +6,7 @@ import (
"net/http"
"runtime"
"github.com/docker/docker/api"
"github.com/docker/docker/api/server/httputils"
"github.com/docker/docker/api/types/versions"
)
@ -32,12 +33,21 @@ type VersionMiddleware struct {
}
// NewVersionMiddleware creates a VersionMiddleware with the given versions.
func NewVersionMiddleware(serverVersion, defaultAPIVersion, minAPIVersion string) VersionMiddleware {
return VersionMiddleware{
func NewVersionMiddleware(serverVersion, defaultAPIVersion, minAPIVersion string) (*VersionMiddleware, error) {
if versions.LessThan(defaultAPIVersion, api.MinSupportedAPIVersion) || versions.GreaterThan(defaultAPIVersion, api.DefaultVersion) {
return nil, fmt.Errorf("invalid default API version (%s): must be between %s and %s", defaultAPIVersion, api.MinSupportedAPIVersion, api.DefaultVersion)
}
if versions.LessThan(minAPIVersion, api.MinSupportedAPIVersion) || versions.GreaterThan(minAPIVersion, api.DefaultVersion) {
return nil, fmt.Errorf("invalid minimum API version (%s): must be between %s and %s", minAPIVersion, api.MinSupportedAPIVersion, api.DefaultVersion)
}
if versions.GreaterThan(minAPIVersion, defaultAPIVersion) {
return nil, fmt.Errorf("invalid API version: the minimum API version (%s) is higher than the default version (%s)", minAPIVersion, defaultAPIVersion)
}
return &VersionMiddleware{
serverVersion: serverVersion,
defaultAPIVersion: defaultAPIVersion,
minAPIVersion: minAPIVersion,
}
}, nil
}
type versionUnsupportedError struct {

View file

@ -14,6 +14,60 @@ import (
is "gotest.tools/v3/assert/cmp"
)
func TestNewVersionMiddlewareValidation(t *testing.T) {
tests := []struct {
doc, defaultVersion, minVersion, expectedErr string
}{
{
doc: "defaults",
defaultVersion: api.DefaultVersion,
minVersion: api.MinSupportedAPIVersion,
},
{
doc: "invalid default lower than min",
defaultVersion: api.MinSupportedAPIVersion,
minVersion: api.DefaultVersion,
expectedErr: fmt.Sprintf("invalid API version: the minimum API version (%s) is higher than the default version (%s)", api.DefaultVersion, api.MinSupportedAPIVersion),
},
{
doc: "invalid default too low",
defaultVersion: "0.1",
minVersion: api.MinSupportedAPIVersion,
expectedErr: fmt.Sprintf("invalid default API version (0.1): must be between %s and %s", api.MinSupportedAPIVersion, api.DefaultVersion),
},
{
doc: "invalid default too high",
defaultVersion: "9999.9999",
minVersion: api.DefaultVersion,
expectedErr: fmt.Sprintf("invalid default API version (9999.9999): must be between %s and %s", api.MinSupportedAPIVersion, api.DefaultVersion),
},
{
doc: "invalid minimum too low",
defaultVersion: api.MinSupportedAPIVersion,
minVersion: "0.1",
expectedErr: fmt.Sprintf("invalid minimum API version (0.1): must be between %s and %s", api.MinSupportedAPIVersion, api.DefaultVersion),
},
{
doc: "invalid minimum too high",
defaultVersion: api.DefaultVersion,
minVersion: "9999.9999",
expectedErr: fmt.Sprintf("invalid minimum API version (9999.9999): must be between %s and %s", api.MinSupportedAPIVersion, api.DefaultVersion),
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.doc, func(t *testing.T) {
_, err := NewVersionMiddleware("1.2.3", tc.defaultVersion, tc.minVersion)
if tc.expectedErr == "" {
assert.Check(t, err)
} else {
assert.Check(t, is.Error(err, tc.expectedErr))
}
})
}
}
func TestVersionMiddlewareVersion(t *testing.T) {
expectedVersion := "<not set>"
handler := func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
@ -22,7 +76,8 @@ func TestVersionMiddlewareVersion(t *testing.T) {
return nil
}
m := NewVersionMiddleware("1.2.3", api.DefaultVersion, api.MinSupportedAPIVersion)
m, err := NewVersionMiddleware("1.2.3", api.DefaultVersion, api.MinSupportedAPIVersion)
assert.NilError(t, err)
h := m.WrapHandler(handler)
req, _ := http.NewRequest(http.MethodGet, "/containers/json", nil)
@ -71,7 +126,8 @@ func TestVersionMiddlewareWithErrorsReturnsHeaders(t *testing.T) {
return nil
}
m := NewVersionMiddleware("1.2.3", api.DefaultVersion, api.MinSupportedAPIVersion)
m, err := NewVersionMiddleware("1.2.3", api.DefaultVersion, api.MinSupportedAPIVersion)
assert.NilError(t, err)
h := m.WrapHandler(handler)
req, _ := http.NewRequest(http.MethodGet, "/containers/json", nil)
@ -79,7 +135,7 @@ func TestVersionMiddlewareWithErrorsReturnsHeaders(t *testing.T) {
ctx := context.Background()
vars := map[string]string{"version": "0.1"}
err := h(ctx, resp, req, vars)
err = h(ctx, resp, req, vars)
assert.Check(t, is.ErrorContains(err, ""))
hdr := resp.Result().Header

View file

@ -15,8 +15,11 @@ import (
func TestMiddlewares(t *testing.T) {
srv := &Server{}
const apiMinVersion = "1.12"
srv.UseMiddleware(middleware.NewVersionMiddleware("0.1omega2", api.DefaultVersion, apiMinVersion))
m, err := middleware.NewVersionMiddleware("0.1omega2", api.DefaultVersion, api.MinSupportedAPIVersion)
if err != nil {
t.Fatal(err)
}
srv.UseMiddleware(*m)
req, _ := http.NewRequest(http.MethodGet, "/containers/json", nil)
resp := httptest.NewRecorder()

View file

@ -256,7 +256,10 @@ func (cli *DaemonCli) start(opts *daemonOptions) (err error) {
pluginStore := plugin.NewStore()
var apiServer apiserver.Server
cli.authzMiddleware = initMiddlewares(&apiServer, cli.Config, pluginStore)
cli.authzMiddleware, err = initMiddlewares(&apiServer, cli.Config, pluginStore)
if err != nil {
return errors.Wrap(err, "failed to start API server")
}
d, err := daemon.NewDaemon(ctx, cli.Config, pluginStore, cli.authzMiddleware)
if err != nil {
@ -708,14 +711,15 @@ func (opts routerOptions) Build() []router.Router {
return routers
}
func initMiddlewares(s *apiserver.Server, cfg *config.Config, pluginStore plugingetter.PluginGetter) *authorization.Middleware {
v := dockerversion.Version
func initMiddlewares(s *apiserver.Server, cfg *config.Config, pluginStore plugingetter.PluginGetter) (*authorization.Middleware, error) {
exp := middleware.NewExperimentalMiddleware(cfg.Experimental)
s.UseMiddleware(exp)
vm := middleware.NewVersionMiddleware(v, api.DefaultVersion, cfg.MinAPIVersion)
s.UseMiddleware(vm)
vm, err := middleware.NewVersionMiddleware(dockerversion.Version, api.DefaultVersion, cfg.MinAPIVersion)
if err != nil {
return nil, err
}
s.UseMiddleware(*vm)
if cfg.CorsHeaders != "" {
c := middleware.NewCORSMiddleware(cfg.CorsHeaders)
@ -724,7 +728,7 @@ func initMiddlewares(s *apiserver.Server, cfg *config.Config, pluginStore plugin
authzMiddleware := authorization.NewMiddleware(cfg.AuthorizationPlugins, pluginStore)
s.UseMiddleware(authzMiddleware)
return authzMiddleware
return authzMiddleware, nil
}
func (cli *DaemonCli) getContainerdDaemonOpts() ([]supervisor.DaemonOpt, error) {