Browse code

cmd/dockerd: gracefully shut down the API server

As of Go 1.8, "net/http".Server provides facilities to close all
listeners, making the same facilities in server.Server redundant.
http.Server also improves upon server.Server by additionally providing a
facility to also wait for outstanding requests to complete after closing
all listeners. Leverage those facilities to give in-flight requests up
to five seconds to finish up after all containers have been shut down.

Signed-off-by: Cory Snider <csnider@mirantis.com>

Cory Snider authored on 2023/04/13 01:26:38
Showing 3 changed files
... ...
@@ -2,10 +2,7 @@ package server // import "github.com/docker/docker/api/server"
2 2
 
3 3
 import (
4 4
 	"context"
5
-	"net"
6 5
 	"net/http"
7
-	"strings"
8
-	"time"
9 6
 
10 7
 	"github.com/docker/docker/api/server/httpstatus"
11 8
 	"github.com/docker/docker/api/server/httputils"
... ...
@@ -23,8 +20,6 @@ const versionMatcher = "/v{version:[0-9.]+}"
23 23
 
24 24
 // Server contains instance details for the server
25 25
 type Server struct {
26
-	servers     []*HTTPServer
27
-	routers     []router.Router
28 26
 	middlewares []middleware.Middleware
29 27
 }
30 28
 
... ...
@@ -34,71 +29,6 @@ func (s *Server) UseMiddleware(m middleware.Middleware) {
34 34
 	s.middlewares = append(s.middlewares, m)
35 35
 }
36 36
 
37
-// Accept sets a listener the server accepts connections into.
38
-func (s *Server) Accept(addr string, listeners ...net.Listener) {
39
-	for _, listener := range listeners {
40
-		httpServer := &HTTPServer{
41
-			srv: &http.Server{
42
-				Addr:              addr,
43
-				ReadHeaderTimeout: 5 * time.Minute, // "G112: Potential Slowloris Attack (gosec)"; not a real concern for our use, so setting a long timeout.
44
-			},
45
-			l: listener,
46
-		}
47
-		s.servers = append(s.servers, httpServer)
48
-	}
49
-}
50
-
51
-// Close closes servers and thus stop receiving requests
52
-func (s *Server) Close() {
53
-	for _, srv := range s.servers {
54
-		if err := srv.Close(); err != nil {
55
-			logrus.Error(err)
56
-		}
57
-	}
58
-}
59
-
60
-// Serve starts listening for inbound requests.
61
-func (s *Server) Serve() error {
62
-	var chErrors = make(chan error, len(s.servers))
63
-	for _, srv := range s.servers {
64
-		srv.srv.Handler = s.createMux()
65
-		go func(srv *HTTPServer) {
66
-			var err error
67
-			logrus.Infof("API listen on %s", srv.l.Addr())
68
-			if err = srv.Serve(); err != nil && strings.Contains(err.Error(), "use of closed network connection") {
69
-				err = nil
70
-			}
71
-			chErrors <- err
72
-		}(srv)
73
-	}
74
-
75
-	for range s.servers {
76
-		err := <-chErrors
77
-		if err != nil {
78
-			return err
79
-		}
80
-	}
81
-	return nil
82
-}
83
-
84
-// HTTPServer contains an instance of http server and the listener.
85
-// srv *http.Server, contains configuration to create an http server and a mux router with all api end points.
86
-// l   net.Listener, is a TCP or Socket listener that dispatches incoming request to the router.
87
-type HTTPServer struct {
88
-	srv *http.Server
89
-	l   net.Listener
90
-}
91
-
92
-// Serve starts listening for inbound requests.
93
-func (s *HTTPServer) Serve() error {
94
-	return s.srv.Serve(s.l)
95
-}
96
-
97
-// Close closes the HTTPServer from listening for the inbound requests.
98
-func (s *HTTPServer) Close() error {
99
-	return s.l.Close()
100
-}
101
-
102 37
 func (s *Server) makeHTTPHandler(handler httputils.APIFunc) http.HandlerFunc {
103 38
 	return func(w http.ResponseWriter, r *http.Request) {
104 39
 		// Define the context that we'll pass around to share info
... ...
@@ -130,12 +60,6 @@ func (s *Server) makeHTTPHandler(handler httputils.APIFunc) http.HandlerFunc {
130 130
 	}
131 131
 }
132 132
 
133
-// InitRouter initializes the list of routers for the server.
134
-// This method also enables the Go profiler.
135
-func (s *Server) InitRouter(routers ...router.Router) {
136
-	s.routers = append(s.routers, routers...)
137
-}
138
-
139 133
 type pageNotFoundError struct{}
140 134
 
141 135
 func (pageNotFoundError) Error() string {
... ...
@@ -144,12 +68,12 @@ func (pageNotFoundError) Error() string {
144 144
 
145 145
 func (pageNotFoundError) NotFound() {}
146 146
 
147
-// createMux initializes the main router the server uses.
148
-func (s *Server) createMux() *mux.Router {
147
+// CreateMux returns a new mux with all the routers registered.
148
+func (s *Server) CreateMux(routers ...router.Router) *mux.Router {
149 149
 	m := mux.NewRouter()
150 150
 
151 151
 	logrus.Debug("Registering routers")
152
-	for _, apiRouter := range s.routers {
152
+	for _, apiRouter := range routers {
153 153
 		for _, r := range apiRouter.Routes() {
154 154
 			f := s.makeHTTPHandler(r.Handler())
155 155
 
... ...
@@ -160,7 +84,6 @@ func (s *Server) createMux() *mux.Router {
160 160
 	}
161 161
 
162 162
 	debugRouter := debug.NewRouter()
163
-	s.routers = append(s.routers, debugRouter)
164 163
 	for _, r := range debugRouter.Routes() {
165 164
 		f := s.makeHTTPHandler(r.Handler())
166 165
 		m.Path("/debug" + r.Path()).Handler(f)
... ...
@@ -5,11 +5,13 @@ import (
5 5
 	"crypto/tls"
6 6
 	"fmt"
7 7
 	"net"
8
+	"net/http"
8 9
 	"os"
9 10
 	"path/filepath"
10 11
 	"runtime"
11 12
 	"sort"
12 13
 	"strings"
14
+	"sync"
13 15
 	"time"
14 16
 
15 17
 	containerddefaults "github.com/containerd/containerd/defaults"
... ...
@@ -65,14 +67,18 @@ type DaemonCli struct {
65 65
 	configFile *string
66 66
 	flags      *pflag.FlagSet
67 67
 
68
-	api             apiserver.Server
69 68
 	d               *daemon.Daemon
70 69
 	authzMiddleware *authorization.Middleware // authzMiddleware enables to dynamically reload the authorization plugins
70
+
71
+	stopOnce    sync.Once
72
+	apiShutdown chan struct{}
71 73
 }
72 74
 
73 75
 // NewDaemonCli returns a daemon CLI
74 76
 func NewDaemonCli() *DaemonCli {
75
-	return &DaemonCli{}
77
+	return &DaemonCli{
78
+		apiShutdown: make(chan struct{}),
79
+	}
76 80
 }
77 81
 
78 82
 func (cli *DaemonCli) start(opts *daemonOptions) (err error) {
... ...
@@ -161,7 +167,7 @@ func (cli *DaemonCli) start(opts *daemonOptions) (err error) {
161 161
 		}
162 162
 	}
163 163
 
164
-	hosts, err := loadListeners(cli, tlsConfig)
164
+	lss, hosts, err := loadListeners(cli.Config, tlsConfig)
165 165
 	if err != nil {
166 166
 		return errors.Wrap(err, "failed to load listeners")
167 167
 	}
... ...
@@ -177,20 +183,51 @@ func (cli *DaemonCli) start(opts *daemonOptions) (err error) {
177 177
 	}
178 178
 	defer cancel()
179 179
 
180
-	stopc := make(chan bool)
181
-	defer close(stopc)
182
-
183
-	trap.Trap(func() {
184
-		cli.stop()
185
-		<-stopc // wait for daemonCli.start() to return
186
-	})
180
+	httpServer := &http.Server{
181
+		ReadHeaderTimeout: 5 * time.Minute, // "G112: Potential Slowloris Attack (gosec)"; not a real concern for our use, so setting a long timeout.
182
+	}
183
+	apiShutdownCtx, apiShutdownCancel := context.WithCancel(context.Background())
184
+	apiShutdownDone := make(chan struct{})
185
+	trap.Trap(cli.stop)
186
+	go func() {
187
+		// Block until cli.stop() has been called.
188
+		// It may have already been called, and that's okay.
189
+		// Any httpServer.Serve() calls made after
190
+		// httpServer.Shutdown() will return immediately,
191
+		// which is what we want.
192
+		<-cli.apiShutdown
193
+		err := httpServer.Shutdown(apiShutdownCtx)
194
+		if err != nil {
195
+			logrus.WithError(err).Error("Error shutting down http server")
196
+		}
197
+		close(apiShutdownDone)
198
+	}()
199
+	defer func() {
200
+		select {
201
+		case <-cli.apiShutdown:
202
+			// cli.stop() has been called and the daemon has completed
203
+			// shutting down. Give the HTTP server a little more time to
204
+			// finish handling any outstanding requests if needed.
205
+			tmr := time.AfterFunc(5*time.Second, apiShutdownCancel)
206
+			defer tmr.Stop()
207
+			<-apiShutdownDone
208
+		default:
209
+			// cli.start() has returned without cli.stop() being called,
210
+			// e.g. because the daemon failed to start.
211
+			// Stop the HTTP server with no grace period.
212
+			if closeErr := httpServer.Close(); closeErr != nil {
213
+				logrus.WithError(closeErr).Error("Error closing http server")
214
+			}
215
+		}
216
+	}()
187 217
 
188 218
 	// Notify that the API is active, but before daemon is set up.
189 219
 	preNotifyReady()
190 220
 
191 221
 	pluginStore := plugin.NewStore()
192 222
 
193
-	cli.authzMiddleware = initMiddlewares(&cli.api, cli.Config, pluginStore)
223
+	var apiServer apiserver.Server
224
+	cli.authzMiddleware = initMiddlewares(&apiServer, cli.Config, pluginStore)
194 225
 
195 226
 	d, err := daemon.NewDaemon(ctx, cli.Config, pluginStore, cli.authzMiddleware)
196 227
 	if err != nil {
... ...
@@ -229,10 +266,9 @@ func (cli *DaemonCli) start(opts *daemonOptions) (err error) {
229 229
 	if err != nil {
230 230
 		return err
231 231
 	}
232
-	routerOptions.api = &cli.api
233 232
 	routerOptions.cluster = c
234 233
 
235
-	initRouter(routerOptions)
234
+	httpServer.Handler = apiServer.CreateMux(routerOptions.Build()...)
236 235
 
237 236
 	go d.ProcessClusterNotifications(ctx, c.GetWatchStream())
238 237
 
... ...
@@ -243,10 +279,30 @@ func (cli *DaemonCli) start(opts *daemonOptions) (err error) {
243 243
 
244 244
 	// Daemon is fully initialized. Start handling API traffic
245 245
 	// and wait for serve API to complete.
246
-	errAPI := cli.api.Serve()
247
-	if errAPI != nil {
248
-		logrus.WithError(errAPI).Error("ServeAPI error")
246
+	var (
247
+		apiWG  sync.WaitGroup
248
+		errAPI = make(chan error, 1)
249
+	)
250
+	for _, ls := range lss {
251
+		apiWG.Add(1)
252
+		go func(ls net.Listener) {
253
+			defer apiWG.Done()
254
+			logrus.Infof("API listen on %s", ls.Addr())
255
+			if err := httpServer.Serve(ls); err != http.ErrServerClosed {
256
+				logrus.WithFields(logrus.Fields{
257
+					logrus.ErrorKey: err,
258
+					"listener":      ls.Addr(),
259
+				}).Error("ServeAPI error")
260
+
261
+				select {
262
+				case errAPI <- err:
263
+				default:
264
+				}
265
+			}
266
+		}(ls)
249 267
 	}
268
+	apiWG.Wait()
269
+	close(errAPI)
250 270
 
251 271
 	c.Cleanup()
252 272
 
... ...
@@ -257,8 +313,8 @@ func (cli *DaemonCli) start(opts *daemonOptions) (err error) {
257 257
 	// Stop notification processing and any background processes
258 258
 	cancel()
259 259
 
260
-	if errAPI != nil {
261
-		return errors.Wrap(errAPI, "shutting down due to ServeAPI error")
260
+	if err, ok := <-errAPI; ok {
261
+		return errors.Wrap(err, "shutting down due to ServeAPI error")
262 262
 	}
263 263
 
264 264
 	logrus.Info("Daemon shutdown complete")
... ...
@@ -271,7 +327,6 @@ type routerOptions struct {
271 271
 	features       *map[string]bool
272 272
 	buildkit       *buildkit.Builder
273 273
 	daemon         *daemon.Daemon
274
-	api            *apiserver.Server
275 274
 	cluster        *cluster.Cluster
276 275
 }
277 276
 
... ...
@@ -356,7 +411,14 @@ func (cli *DaemonCli) reloadConfig() {
356 356
 }
357 357
 
358 358
 func (cli *DaemonCli) stop() {
359
-	cli.api.Close()
359
+	// Signal that the API server should shut down as soon as possible.
360
+	// This construct is used rather than directly shutting down the HTTP
361
+	// server to avoid any issues if this method is called before the server
362
+	// has been instantiated in cli.start(). If this method is called first,
363
+	// the HTTP server will be shut down immediately upon instantiation.
364
+	cli.stopOnce.Do(func() {
365
+		close(cli.apiShutdown)
366
+	})
360 367
 }
361 368
 
362 369
 // shutdownDaemon just wraps daemon.Shutdown() to handle a timeout in case
... ...
@@ -498,7 +560,7 @@ func normalizeHosts(config *config.Config) error {
498 498
 	return nil
499 499
 }
500 500
 
501
-func initRouter(opts routerOptions) {
501
+func (opts routerOptions) Build() []router.Router {
502 502
 	decoder := runconfig.ContainerDecoder{
503 503
 		GetSysInfo: func() *sysinfo.SysInfo {
504 504
 			return opts.daemon.RawSysInfo()
... ...
@@ -543,7 +605,7 @@ func initRouter(opts routerOptions) {
543 543
 		}
544 544
 	}
545 545
 
546
-	opts.api.InitRouter(routers...)
546
+	return routers
547 547
 }
548 548
 
549 549
 func initMiddlewares(s *apiserver.Server, cfg *config.Config, pluginStore plugingetter.PluginGetter) *authorization.Middleware {
... ...
@@ -647,17 +709,20 @@ func checkTLSAuthOK(c *config.Config) bool {
647 647
 	return true
648 648
 }
649 649
 
650
-func loadListeners(cli *DaemonCli, tlsConfig *tls.Config) ([]string, error) {
651
-	if len(cli.Config.Hosts) == 0 {
652
-		return nil, errors.New("no hosts configured")
650
+func loadListeners(cfg *config.Config, tlsConfig *tls.Config) ([]net.Listener, []string, error) {
651
+	if len(cfg.Hosts) == 0 {
652
+		return nil, nil, errors.New("no hosts configured")
653 653
 	}
654
-	var hosts []string
654
+	var (
655
+		hosts []string
656
+		lss   []net.Listener
657
+	)
655 658
 
656
-	for i := 0; i < len(cli.Config.Hosts); i++ {
657
-		protoAddr := cli.Config.Hosts[i]
659
+	for i := 0; i < len(cfg.Hosts); i++ {
660
+		protoAddr := cfg.Hosts[i]
658 661
 		proto, addr, ok := strings.Cut(protoAddr, "://")
659 662
 		if !ok {
660
-			return nil, fmt.Errorf("bad format %s, expected PROTO://ADDR", protoAddr)
663
+			return nil, nil, fmt.Errorf("bad format %s, expected PROTO://ADDR", protoAddr)
661 664
 		}
662 665
 
663 666
 		// It's a bad idea to bind to TCP without tlsverify.
... ...
@@ -669,10 +734,10 @@ func loadListeners(cli *DaemonCli, tlsConfig *tls.Config) ([]string, error) {
669 669
 
670 670
 			// If TLSVerify is explicitly set to false we'll take that as "Please let me shoot myself in the foot"
671 671
 			// We do not want to continue to support a default mode where tls verification is disabled, so we do some extra warnings here and eventually remove support
672
-			if !checkTLSAuthOK(cli.Config) {
672
+			if !checkTLSAuthOK(cfg) {
673 673
 				ipAddr, _, err := net.SplitHostPort(addr)
674 674
 				if err != nil {
675
-					return nil, errors.Wrap(err, "error parsing tcp address")
675
+					return nil, nil, errors.Wrap(err, "error parsing tcp address")
676 676
 				}
677 677
 
678 678
 				// shortcut all this extra stuff for literal "localhost"
... ...
@@ -702,19 +767,19 @@ func loadListeners(cli *DaemonCli, tlsConfig *tls.Config) ([]string, error) {
702 702
 		// If we're binding to a TCP port, make sure that a container doesn't try to use it.
703 703
 		if proto == "tcp" {
704 704
 			if err := allocateDaemonPort(addr); err != nil {
705
-				return nil, err
705
+				return nil, nil, err
706 706
 			}
707 707
 		}
708
-		ls, err := listeners.Init(proto, addr, cli.Config.SocketGroup, tlsConfig)
708
+		ls, err := listeners.Init(proto, addr, cfg.SocketGroup, tlsConfig)
709 709
 		if err != nil {
710
-			return nil, err
710
+			return nil, nil, err
711 711
 		}
712 712
 		logrus.Debugf("Listener created for HTTP on %s (%s)", proto, addr)
713 713
 		hosts = append(hosts, addr)
714
-		cli.api.Accept(addr, ls...)
714
+		lss = append(lss, ls...)
715 715
 	}
716 716
 
717
-	return hosts, nil
717
+	return lss, hosts, nil
718 718
 }
719 719
 
720 720
 func createAndStartCluster(cli *DaemonCli, d *daemon.Daemon) (*cluster.Cluster, error) {
... ...
@@ -42,14 +42,12 @@ func initListenerTestPhase1() {
42 42
 }
43 43
 
44 44
 func initListenerTestPhase2() {
45
-	cli := &DaemonCli{
46
-		Config: &config.Config{
47
-			CommonConfig: config.CommonConfig{
48
-				Hosts: []string{"fd://"},
49
-			},
45
+	cfg := &config.Config{
46
+		CommonConfig: config.CommonConfig{
47
+			Hosts: []string{"fd://"},
50 48
 		},
51 49
 	}
52
-	_, err := loadListeners(cli, nil)
50
+	_, _, err := loadListeners(cfg, nil)
53 51
 	var resp listenerTestResponse
54 52
 	if err != nil {
55 53
 		resp.Err = err.Error()