Browse code

Organize server pre-func logic in middlewares.

It defines global middlewares for every request.
This makes the server slightly more composable.

Signed-off-by: David Calavera <david.calavera@gmail.com>

David Calavera authored on 2015/09/16 08:01:49
Showing 5 changed files
1 1
new file mode 100644
... ...
@@ -0,0 +1,130 @@
0
+package server
1
+
2
+import (
3
+	"net/http"
4
+	"runtime"
5
+	"strings"
6
+
7
+	"github.com/Sirupsen/logrus"
8
+	"github.com/docker/docker/api"
9
+	"github.com/docker/docker/autogen/dockerversion"
10
+	"github.com/docker/docker/context"
11
+	"github.com/docker/docker/errors"
12
+	"github.com/docker/docker/pkg/stringid"
13
+	"github.com/docker/docker/pkg/version"
14
+)
15
+
16
+// middleware is an adapter to allow the use of ordinary functions as Docker API filters.
17
+// Any function that has the appropriate signature can be register as a middleware.
18
+type middleware func(handler HTTPAPIFunc) HTTPAPIFunc
19
+
20
+// loggingMiddleware logs each request when logging is enabled.
21
+func (s *Server) loggingMiddleware(handler HTTPAPIFunc) HTTPAPIFunc {
22
+	return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
23
+		if s.cfg.Logging {
24
+			logrus.Infof("%s %s", r.Method, r.RequestURI)
25
+		}
26
+		return handler(ctx, w, r, vars)
27
+	}
28
+}
29
+
30
+// requestIDMiddleware generates a uniq ID for each request.
31
+// This ID travels inside the context for tracing purposes.
32
+func requestIDMiddleware(handler HTTPAPIFunc) HTTPAPIFunc {
33
+	return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
34
+		reqID := stringid.TruncateID(stringid.GenerateNonCryptoID())
35
+		ctx = context.WithValue(ctx, context.RequestID, reqID)
36
+		return handler(ctx, w, r, vars)
37
+	}
38
+}
39
+
40
+// userAgentMiddleware checks the User-Agent header looking for a valid docker client spec.
41
+func (s *Server) userAgentMiddleware(handler HTTPAPIFunc) HTTPAPIFunc {
42
+	return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
43
+		if strings.Contains(r.Header.Get("User-Agent"), "Docker-Client/") {
44
+			dockerVersion := version.Version(s.cfg.Version)
45
+
46
+			userAgent := strings.Split(r.Header.Get("User-Agent"), "/")
47
+
48
+			// v1.20 onwards includes the GOOS of the client after the version
49
+			// such as Docker/1.7.0 (linux)
50
+			if len(userAgent) == 2 && strings.Contains(userAgent[1], " ") {
51
+				userAgent[1] = strings.Split(userAgent[1], " ")[0]
52
+			}
53
+
54
+			if len(userAgent) == 2 && !dockerVersion.Equal(version.Version(userAgent[1])) {
55
+				logrus.Debugf("Warning: client and server don't have the same version (client: %s, server: %s)", userAgent[1], dockerVersion)
56
+			}
57
+		}
58
+		return handler(ctx, w, r, vars)
59
+	}
60
+}
61
+
62
+// corsMiddleware sets the CORS header expectations in the server.
63
+func (s *Server) corsMiddleware(handler HTTPAPIFunc) HTTPAPIFunc {
64
+	return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
65
+		// If "api-cors-header" is not given, but "api-enable-cors" is true, we set cors to "*"
66
+		// otherwise, all head values will be passed to HTTP handler
67
+		corsHeaders := s.cfg.CorsHeaders
68
+		if corsHeaders == "" && s.cfg.EnableCors {
69
+			corsHeaders = "*"
70
+		}
71
+
72
+		if corsHeaders != "" {
73
+			writeCorsHeaders(w, r, corsHeaders)
74
+		}
75
+		return handler(ctx, w, r, vars)
76
+	}
77
+}
78
+
79
+// versionMiddleware checks the api version requirements before passing the request to the server handler.
80
+func versionMiddleware(handler HTTPAPIFunc) HTTPAPIFunc {
81
+	return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
82
+		apiVersion := version.Version(vars["version"])
83
+		if apiVersion == "" {
84
+			apiVersion = api.Version
85
+		}
86
+
87
+		if apiVersion.GreaterThan(api.Version) {
88
+			return errors.ErrorCodeNewerClientVersion.WithArgs(apiVersion, api.Version)
89
+		}
90
+		if apiVersion.LessThan(api.MinVersion) {
91
+			return errors.ErrorCodeOldClientVersion.WithArgs(apiVersion, api.Version)
92
+		}
93
+
94
+		w.Header().Set("Server", "Docker/"+dockerversion.VERSION+" ("+runtime.GOOS+")")
95
+		ctx = context.WithValue(ctx, context.APIVersion, apiVersion)
96
+		return handler(ctx, w, r, vars)
97
+	}
98
+}
99
+
100
+// handleWithGlobalMiddlwares wraps the handler function for a request with
101
+// the server's global middlewares. The order of the middlewares is backwards,
102
+// meaning that the first in the list will be evaludated last.
103
+//
104
+// Example: handleWithGlobalMiddlewares(s.getContainersName)
105
+//
106
+// requestIDMiddlware(
107
+//	s.loggingMiddleware(
108
+//		s.userAgentMiddleware(
109
+//			s.corsMiddleware(
110
+//				versionMiddleware(s.getContainersName)
111
+//			)
112
+//		)
113
+//	)
114
+// )
115
+func (s *Server) handleWithGlobalMiddlewares(handler HTTPAPIFunc) HTTPAPIFunc {
116
+	middlewares := []middleware{
117
+		versionMiddleware,
118
+		s.corsMiddleware,
119
+		s.userAgentMiddleware,
120
+		s.loggingMiddleware,
121
+		requestIDMiddleware,
122
+	}
123
+
124
+	h := handler
125
+	for _, m := range middlewares {
126
+		h = m(h)
127
+	}
128
+	return h
129
+}
0 130
new file mode 100644
... ...
@@ -0,0 +1,74 @@
0
+package server
1
+
2
+import (
3
+	"net/http"
4
+	"net/http/httptest"
5
+	"testing"
6
+
7
+	"github.com/docker/distribution/registry/api/errcode"
8
+	"github.com/docker/docker/context"
9
+	"github.com/docker/docker/errors"
10
+)
11
+
12
+func TestVersionMiddleware(t *testing.T) {
13
+	handler := func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
14
+		if ctx.Version() == "" {
15
+			t.Fatalf("Expected version, got empty string")
16
+		}
17
+		return nil
18
+	}
19
+
20
+	h := versionMiddleware(handler)
21
+
22
+	req, _ := http.NewRequest("GET", "/containers/json", nil)
23
+	resp := httptest.NewRecorder()
24
+	ctx := context.Background()
25
+	if err := h(ctx, resp, req, map[string]string{}); err != nil {
26
+		t.Fatal(err)
27
+	}
28
+}
29
+
30
+func TestVersionMiddlewareWithErrors(t *testing.T) {
31
+	handler := func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
32
+		if ctx.Version() == "" {
33
+			t.Fatalf("Expected version, got empty string")
34
+		}
35
+		return nil
36
+	}
37
+
38
+	h := versionMiddleware(handler)
39
+
40
+	req, _ := http.NewRequest("GET", "/containers/json", nil)
41
+	resp := httptest.NewRecorder()
42
+	ctx := context.Background()
43
+
44
+	vars := map[string]string{"version": "0.1"}
45
+	err := h(ctx, resp, req, vars)
46
+	if derr, ok := err.(errcode.Error); !ok || derr.ErrorCode() != errors.ErrorCodeOldClientVersion {
47
+		t.Fatalf("Expected ErrorCodeOldClientVersion, got %v", err)
48
+	}
49
+
50
+	vars["version"] = "100000"
51
+	err = h(ctx, resp, req, vars)
52
+	if derr, ok := err.(errcode.Error); !ok || derr.ErrorCode() != errors.ErrorCodeNewerClientVersion {
53
+		t.Fatalf("Expected ErrorCodeNewerClientVersion, got %v", err)
54
+	}
55
+}
56
+
57
+func TestRequestIDMiddleware(t *testing.T) {
58
+	handler := func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
59
+		if ctx.RequestID() == "" {
60
+			t.Fatalf("Expected request-id, got empty string")
61
+		}
62
+		return nil
63
+	}
64
+
65
+	h := requestIDMiddleware(handler)
66
+
67
+	req, _ := http.NewRequest("GET", "/containers/json", nil)
68
+	resp := httptest.NewRecorder()
69
+	ctx := context.Background()
70
+	if err := h(ctx, resp, req, map[string]string{}); err != nil {
71
+		t.Fatal(err)
72
+	}
73
+}
... ...
@@ -8,7 +8,6 @@ import (
8 8
 	"net"
9 9
 	"net/http"
10 10
 	"os"
11
-	"runtime"
12 11
 	"strings"
13 12
 
14 13
 	"github.com/gorilla/mux"
... ...
@@ -16,12 +15,9 @@ import (
16 16
 	"github.com/Sirupsen/logrus"
17 17
 	"github.com/docker/distribution/registry/api/errcode"
18 18
 	"github.com/docker/docker/api"
19
-	"github.com/docker/docker/autogen/dockerversion"
20 19
 	"github.com/docker/docker/context"
21 20
 	"github.com/docker/docker/daemon"
22 21
 	"github.com/docker/docker/pkg/sockets"
23
-	"github.com/docker/docker/pkg/stringid"
24
-	"github.com/docker/docker/pkg/version"
25 22
 )
26 23
 
27 24
 // Config provides the configuration for the API server
... ...
@@ -49,8 +45,7 @@ func New(cfg *Config) *Server {
49 49
 		cfg:   cfg,
50 50
 		start: make(chan struct{}),
51 51
 	}
52
-	r := createRouter(srv)
53
-	srv.router = r
52
+	srv.router = createRouter(srv)
54 53
 	return srv
55 54
 }
56 55
 
... ...
@@ -294,8 +289,11 @@ func (s *Server) initTCPSocket(addr string) (l net.Listener, err error) {
294 294
 	return
295 295
 }
296 296
 
297
-func makeHTTPHandler(logging bool, localMethod string, localRoute string, handlerFunc HTTPAPIFunc, corsHeaders string, dockerVersion version.Version) http.HandlerFunc {
297
+func (s *Server) makeHTTPHandler(localMethod string, localRoute string, localHandler HTTPAPIFunc) http.HandlerFunc {
298 298
 	return func(w http.ResponseWriter, r *http.Request) {
299
+		// log the handler generation
300
+		logrus.Debugf("Calling %s %s", localMethod, localRoute)
301
+
299 302
 		// Define the context that we'll pass around to share info
300 303
 		// like the docker-request-id.
301 304
 		//
... ...
@@ -303,51 +301,8 @@ func makeHTTPHandler(logging bool, localMethod string, localRoute string, handle
303 303
 		// apply to all requests. Data that is specific to the
304 304
 		// immediate function being called should still be passed
305 305
 		// as 'args' on the function call.
306
-
307
-		reqID := stringid.TruncateID(stringid.GenerateNonCryptoID())
308
-		apiVersion := version.Version(mux.Vars(r)["version"])
309
-		if apiVersion == "" {
310
-			apiVersion = api.Version
311
-		}
312
-
313 306
 		ctx := context.Background()
314
-		ctx = context.WithValue(ctx, context.RequestID, reqID)
315
-		ctx = context.WithValue(ctx, context.APIVersion, apiVersion)
316
-
317
-		// log the request
318
-		logrus.Debugf("Calling %s %s", localMethod, localRoute)
319
-
320
-		if logging {
321
-			logrus.Infof("%s %s", r.Method, r.RequestURI)
322
-		}
323
-
324
-		if strings.Contains(r.Header.Get("User-Agent"), "Docker-Client/") {
325
-			userAgent := strings.Split(r.Header.Get("User-Agent"), "/")
326
-
327
-			// v1.20 onwards includes the GOOS of the client after the version
328
-			// such as Docker/1.7.0 (linux)
329
-			if len(userAgent) == 2 && strings.Contains(userAgent[1], " ") {
330
-				userAgent[1] = strings.Split(userAgent[1], " ")[0]
331
-			}
332
-
333
-			if len(userAgent) == 2 && !dockerVersion.Equal(version.Version(userAgent[1])) {
334
-				logrus.Debugf("Warning: client and server don't have the same version (client: %s, server: %s)", userAgent[1], dockerVersion)
335
-			}
336
-		}
337
-		if corsHeaders != "" {
338
-			writeCorsHeaders(w, r, corsHeaders)
339
-		}
340
-
341
-		if apiVersion.GreaterThan(api.Version) {
342
-			http.Error(w, fmt.Errorf("client is newer than server (client API version: %s, server API version: %s)", apiVersion, api.Version).Error(), http.StatusBadRequest)
343
-			return
344
-		}
345
-		if apiVersion.LessThan(api.MinVersion) {
346
-			http.Error(w, fmt.Errorf("client is too old, minimum supported API version is %s, please upgrade your client to a newer version", api.MinVersion).Error(), http.StatusBadRequest)
347
-			return
348
-		}
349
-
350
-		w.Header().Set("Server", "Docker/"+dockerversion.VERSION+" ("+runtime.GOOS+")")
307
+		handlerFunc := s.handleWithGlobalMiddlewares(localHandler)
351 308
 
352 309
 		if err := handlerFunc(ctx, w, r, mux.Vars(r)); err != nil {
353 310
 			logrus.Errorf("Handler for %s %s returned error: %s", localMethod, localRoute, err)
... ...
@@ -356,6 +311,7 @@ func makeHTTPHandler(logging bool, localMethod string, localRoute string, handle
356 356
 	}
357 357
 }
358 358
 
359
+// createRouter initializes the main router the server uses.
359 360
 // we keep enableCors just for legacy usage, need to be removed in the future
360 361
 func createRouter(s *Server) *mux.Router {
361 362
 	r := mux.NewRouter()
... ...
@@ -428,13 +384,6 @@ func createRouter(s *Server) *mux.Router {
428 428
 		},
429 429
 	}
430 430
 
431
-	// If "api-cors-header" is not given, but "api-enable-cors" is true, we set cors to "*"
432
-	// otherwise, all head values will be passed to HTTP handler
433
-	corsHeaders := s.cfg.CorsHeaders
434
-	if corsHeaders == "" && s.cfg.EnableCors {
435
-		corsHeaders = "*"
436
-	}
437
-
438 431
 	for method, routes := range m {
439 432
 		for route, fct := range routes {
440 433
 			logrus.Debugf("Registering %s, %s", method, route)
... ...
@@ -444,7 +393,7 @@ func createRouter(s *Server) *mux.Router {
444 444
 			localMethod := method
445 445
 
446 446
 			// build the handler function
447
-			f := makeHTTPHandler(s.cfg.Logging, localMethod, localRoute, localFct, corsHeaders, version.Version(s.cfg.Version))
447
+			f := s.makeHTTPHandler(localMethod, localRoute, localFct)
448 448
 
449 449
 			// add the new route
450 450
 			if localRoute == "" {
451 451
new file mode 100644
... ...
@@ -0,0 +1,35 @@
0
+package server
1
+
2
+import (
3
+	"net/http"
4
+	"net/http/httptest"
5
+	"testing"
6
+
7
+	"github.com/docker/docker/context"
8
+)
9
+
10
+func TestMiddlewares(t *testing.T) {
11
+	cfg := &Config{}
12
+	srv := &Server{
13
+		cfg: cfg,
14
+	}
15
+
16
+	req, _ := http.NewRequest("GET", "/containers/json", nil)
17
+	resp := httptest.NewRecorder()
18
+	ctx := context.Background()
19
+
20
+	localHandler := func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error {
21
+		if ctx.Version() == "" {
22
+			t.Fatalf("Expected version, got empty string")
23
+		}
24
+		if ctx.RequestID() == "" {
25
+			t.Fatalf("Expected request-id, got empty string")
26
+		}
27
+		return nil
28
+	}
29
+
30
+	handlerFunc := srv.handleWithGlobalMiddlewares(localHandler)
31
+	if err := handlerFunc(ctx, resp, req, map[string]string{}); err != nil {
32
+		t.Fatal(err)
33
+	}
34
+}
0 35
new file mode 100644
... ...
@@ -0,0 +1,27 @@
0
+package errors
1
+
2
+import (
3
+	"net/http"
4
+
5
+	"github.com/docker/distribution/registry/api/errcode"
6
+)
7
+
8
+var (
9
+	// ErrorCodeNewerClientVersion is generated when a request from a client
10
+	// specifies a higher version than the server supports.
11
+	ErrorCodeNewerClientVersion = errcode.Register(errGroup, errcode.ErrorDescriptor{
12
+		Value:          "NEWERCLIENTVERSION",
13
+		Message:        "client is newer than server (client API version: %s, server API version: %s)",
14
+		Description:    "The client version is higher than the server version",
15
+		HTTPStatusCode: http.StatusBadRequest,
16
+	})
17
+
18
+	// ErrorCodeOldClientVersion is generated when a request from a client
19
+	// specifies a version lower than the minimum version supported by the server.
20
+	ErrorCodeOldClientVersion = errcode.Register(errGroup, errcode.ErrorDescriptor{
21
+		Value:          "OLDCLIENTVERSION",
22
+		Message:        "client version %s is too old. Minimum supported API version is %s, please upgrade your client to a newer version",
23
+		Description:    "The client version is too old for the server",
24
+		HTTPStatusCode: http.StatusBadRequest,
25
+	})
26
+)