Browse code

api/server/middleware: NewVersionMiddleware: add validation

Make sure the middleware cannot be initialized with out of range versions.

Signed-off-by: Sebastiaan van Stijn <github@gone.nl>

Sebastiaan van Stijn authored on 2024/01/23 02:38:59
Showing 4 changed files
... ...
@@ -6,6 +6,7 @@ import (
6 6
 	"net/http"
7 7
 	"runtime"
8 8
 
9
+	"github.com/docker/docker/api"
9 10
 	"github.com/docker/docker/api/server/httputils"
10 11
 	"github.com/docker/docker/api/types/versions"
11 12
 )
... ...
@@ -32,12 +33,21 @@ type VersionMiddleware struct {
32 32
 }
33 33
 
34 34
 // NewVersionMiddleware creates a VersionMiddleware with the given versions.
35
-func NewVersionMiddleware(serverVersion, defaultAPIVersion, minAPIVersion string) VersionMiddleware {
36
-	return VersionMiddleware{
35
+func NewVersionMiddleware(serverVersion, defaultAPIVersion, minAPIVersion string) (*VersionMiddleware, error) {
36
+	if versions.LessThan(defaultAPIVersion, api.MinSupportedAPIVersion) || versions.GreaterThan(defaultAPIVersion, api.DefaultVersion) {
37
+		return nil, fmt.Errorf("invalid default API version (%s): must be between %s and %s", defaultAPIVersion, api.MinSupportedAPIVersion, api.DefaultVersion)
38
+	}
39
+	if versions.LessThan(minAPIVersion, api.MinSupportedAPIVersion) || versions.GreaterThan(minAPIVersion, api.DefaultVersion) {
40
+		return nil, fmt.Errorf("invalid minimum API version (%s): must be between %s and %s", minAPIVersion, api.MinSupportedAPIVersion, api.DefaultVersion)
41
+	}
42
+	if versions.GreaterThan(minAPIVersion, defaultAPIVersion) {
43
+		return nil, fmt.Errorf("invalid API version: the minimum API version (%s) is higher than the default version (%s)", minAPIVersion, defaultAPIVersion)
44
+	}
45
+	return &VersionMiddleware{
37 46
 		serverVersion:     serverVersion,
38 47
 		defaultAPIVersion: defaultAPIVersion,
39 48
 		minAPIVersion:     minAPIVersion,
40
-	}
49
+	}, nil
41 50
 }
42 51
 
43 52
 type versionUnsupportedError struct {
... ...
@@ -14,6 +14,60 @@ import (
14 14
 	is "gotest.tools/v3/assert/cmp"
15 15
 )
16 16
 
17
+func TestNewVersionMiddlewareValidation(t *testing.T) {
18
+	tests := []struct {
19
+		doc, defaultVersion, minVersion, expectedErr string
20
+	}{
21
+		{
22
+			doc:            "defaults",
23
+			defaultVersion: api.DefaultVersion,
24
+			minVersion:     api.MinSupportedAPIVersion,
25
+		},
26
+		{
27
+			doc:            "invalid default lower than min",
28
+			defaultVersion: api.MinSupportedAPIVersion,
29
+			minVersion:     api.DefaultVersion,
30
+			expectedErr:    fmt.Sprintf("invalid API version: the minimum API version (%s) is higher than the default version (%s)", api.DefaultVersion, api.MinSupportedAPIVersion),
31
+		},
32
+		{
33
+			doc:            "invalid default too low",
34
+			defaultVersion: "0.1",
35
+			minVersion:     api.MinSupportedAPIVersion,
36
+			expectedErr:    fmt.Sprintf("invalid default API version (0.1): must be between %s and %s", api.MinSupportedAPIVersion, api.DefaultVersion),
37
+		},
38
+		{
39
+			doc:            "invalid default too high",
40
+			defaultVersion: "9999.9999",
41
+			minVersion:     api.DefaultVersion,
42
+			expectedErr:    fmt.Sprintf("invalid default API version (9999.9999): must be between %s and %s", api.MinSupportedAPIVersion, api.DefaultVersion),
43
+		},
44
+		{
45
+			doc:            "invalid minimum too low",
46
+			defaultVersion: api.MinSupportedAPIVersion,
47
+			minVersion:     "0.1",
48
+			expectedErr:    fmt.Sprintf("invalid minimum API version (0.1): must be between %s and %s", api.MinSupportedAPIVersion, api.DefaultVersion),
49
+		},
50
+		{
51
+			doc:            "invalid minimum too high",
52
+			defaultVersion: api.DefaultVersion,
53
+			minVersion:     "9999.9999",
54
+			expectedErr:    fmt.Sprintf("invalid minimum API version (9999.9999): must be between %s and %s", api.MinSupportedAPIVersion, api.DefaultVersion),
55
+		},
56
+	}
57
+
58
+	for _, tc := range tests {
59
+		tc := tc
60
+		t.Run(tc.doc, func(t *testing.T) {
61
+			_, err := NewVersionMiddleware("1.2.3", tc.defaultVersion, tc.minVersion)
62
+			if tc.expectedErr == "" {
63
+				assert.Check(t, err)
64
+			} else {
65
+				assert.Check(t, is.Error(err, tc.expectedErr))
66
+			}
67
+		})
68
+	}
69
+}
70
+
17 71
 func TestVersionMiddlewareVersion(t *testing.T) {
18 72
 	expectedVersion := "<not set>"
19 73
 	handler := func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
... ...
@@ -22,7 +76,8 @@ func TestVersionMiddlewareVersion(t *testing.T) {
22 22
 		return nil
23 23
 	}
24 24
 
25
-	m := NewVersionMiddleware("1.2.3", api.DefaultVersion, api.MinSupportedAPIVersion)
25
+	m, err := NewVersionMiddleware("1.2.3", api.DefaultVersion, api.MinSupportedAPIVersion)
26
+	assert.NilError(t, err)
26 27
 	h := m.WrapHandler(handler)
27 28
 
28 29
 	req, _ := http.NewRequest(http.MethodGet, "/containers/json", nil)
... ...
@@ -71,7 +126,8 @@ func TestVersionMiddlewareWithErrorsReturnsHeaders(t *testing.T) {
71 71
 		return nil
72 72
 	}
73 73
 
74
-	m := NewVersionMiddleware("1.2.3", api.DefaultVersion, api.MinSupportedAPIVersion)
74
+	m, err := NewVersionMiddleware("1.2.3", api.DefaultVersion, api.MinSupportedAPIVersion)
75
+	assert.NilError(t, err)
75 76
 	h := m.WrapHandler(handler)
76 77
 
77 78
 	req, _ := http.NewRequest(http.MethodGet, "/containers/json", nil)
... ...
@@ -79,7 +135,7 @@ func TestVersionMiddlewareWithErrorsReturnsHeaders(t *testing.T) {
79 79
 	ctx := context.Background()
80 80
 
81 81
 	vars := map[string]string{"version": "0.1"}
82
-	err := h(ctx, resp, req, vars)
82
+	err = h(ctx, resp, req, vars)
83 83
 	assert.Check(t, is.ErrorContains(err, ""))
84 84
 
85 85
 	hdr := resp.Result().Header
... ...
@@ -15,8 +15,11 @@ import (
15 15
 func TestMiddlewares(t *testing.T) {
16 16
 	srv := &Server{}
17 17
 
18
-	const apiMinVersion = "1.12"
19
-	srv.UseMiddleware(middleware.NewVersionMiddleware("0.1omega2", api.DefaultVersion, apiMinVersion))
18
+	m, err := middleware.NewVersionMiddleware("0.1omega2", api.DefaultVersion, api.MinSupportedAPIVersion)
19
+	if err != nil {
20
+		t.Fatal(err)
21
+	}
22
+	srv.UseMiddleware(*m)
20 23
 
21 24
 	req, _ := http.NewRequest(http.MethodGet, "/containers/json", nil)
22 25
 	resp := httptest.NewRecorder()
... ...
@@ -256,7 +256,10 @@ func (cli *DaemonCli) start(opts *daemonOptions) (err error) {
256 256
 	pluginStore := plugin.NewStore()
257 257
 
258 258
 	var apiServer apiserver.Server
259
-	cli.authzMiddleware = initMiddlewares(&apiServer, cli.Config, pluginStore)
259
+	cli.authzMiddleware, err = initMiddlewares(&apiServer, cli.Config, pluginStore)
260
+	if err != nil {
261
+		return errors.Wrap(err, "failed to start API server")
262
+	}
260 263
 
261 264
 	d, err := daemon.NewDaemon(ctx, cli.Config, pluginStore, cli.authzMiddleware)
262 265
 	if err != nil {
... ...
@@ -708,14 +711,15 @@ func (opts routerOptions) Build() []router.Router {
708 708
 	return routers
709 709
 }
710 710
 
711
-func initMiddlewares(s *apiserver.Server, cfg *config.Config, pluginStore plugingetter.PluginGetter) *authorization.Middleware {
712
-	v := dockerversion.Version
713
-
711
+func initMiddlewares(s *apiserver.Server, cfg *config.Config, pluginStore plugingetter.PluginGetter) (*authorization.Middleware, error) {
714 712
 	exp := middleware.NewExperimentalMiddleware(cfg.Experimental)
715 713
 	s.UseMiddleware(exp)
716 714
 
717
-	vm := middleware.NewVersionMiddleware(v, api.DefaultVersion, cfg.MinAPIVersion)
718
-	s.UseMiddleware(vm)
715
+	vm, err := middleware.NewVersionMiddleware(dockerversion.Version, api.DefaultVersion, cfg.MinAPIVersion)
716
+	if err != nil {
717
+		return nil, err
718
+	}
719
+	s.UseMiddleware(*vm)
719 720
 
720 721
 	if cfg.CorsHeaders != "" {
721 722
 		c := middleware.NewCORSMiddleware(cfg.CorsHeaders)
... ...
@@ -724,7 +728,7 @@ func initMiddlewares(s *apiserver.Server, cfg *config.Config, pluginStore plugin
724 724
 
725 725
 	authzMiddleware := authorization.NewMiddleware(cfg.AuthorizationPlugins, pluginStore)
726 726
 	s.UseMiddleware(authzMiddleware)
727
-	return authzMiddleware
727
+	return authzMiddleware, nil
728 728
 }
729 729
 
730 730
 func (cli *DaemonCli) getContainerdDaemonOpts() ([]supervisor.DaemonOpt, error) {