Browse code

Cleanup api server creation

Current implementation is hard to reason about because of trying to mix
unix/tcp server implementations, even though they are quite different.
This cleans that up.

Also makes it possible to create and manage a new API server easily,
e.g. for adding an introspection socket to a container.

Built in such a way as to allow a non-HTTP server to work as well, such
as libchan.

Signed-off-by: Brian Goff <cpuguy83@gmail.com>

Brian Goff authored on 2014/11/08 05:21:19
Showing 1 changed files
... ...
@@ -3,8 +3,7 @@ package server
3 3
 import (
4 4
 	"bufio"
5 5
 	"bytes"
6
-	"crypto/tls"
7
-	"crypto/x509"
6
+
8 7
 	"encoding/base64"
9 8
 	"encoding/json"
10 9
 	"expvar"
... ...
@@ -19,6 +18,9 @@ import (
19 19
 	"strings"
20 20
 	"syscall"
21 21
 
22
+	"crypto/tls"
23
+	"crypto/x509"
24
+
22 25
 	"code.google.com/p/go.net/websocket"
23 26
 	"github.com/docker/libcontainer/user"
24 27
 	"github.com/gorilla/mux"
... ...
@@ -39,6 +41,18 @@ var (
39 39
 	activationLock chan struct{}
40 40
 )
41 41
 
42
+type HttpServer struct {
43
+	srv *http.Server
44
+	l   net.Listener
45
+}
46
+
47
+func (s *HttpServer) Serve() error {
48
+	return s.srv.Serve(s.l)
49
+}
50
+func (s *HttpServer) Close() error {
51
+	return s.l.Close()
52
+}
53
+
42 54
 type HttpApiFunc func(eng *engine.Engine, version version.Version, w http.ResponseWriter, r *http.Request, vars map[string]string) error
43 55
 
44 56
 func hijackServer(w http.ResponseWriter) (io.ReadCloser, io.Writer, error) {
... ...
@@ -1334,9 +1348,14 @@ func ServeRequest(eng *engine.Engine, apiversion version.Version, w http.Respons
1334 1334
 	return nil
1335 1335
 }
1336 1336
 
1337
-// ServeFD creates an http.Server and sets it up to serve given a socket activated
1337
+// serveFd creates an http.Server and sets it up to serve given a socket activated
1338 1338
 // argument.
1339
-func ServeFd(addr string, handle http.Handler) error {
1339
+func serveFd(addr string, job *engine.Job) error {
1340
+	r, err := createRouter(job.Eng, job.GetenvBool("Logging"), job.GetenvBool("EnableCors"), job.Getenv("Version"))
1341
+	if err != nil {
1342
+		return err
1343
+	}
1344
+
1340 1345
 	ls, e := systemd.ListenFD(addr)
1341 1346
 	if e != nil {
1342 1347
 		return e
... ...
@@ -1354,7 +1373,7 @@ func ServeFd(addr string, handle http.Handler) error {
1354 1354
 	for i := range ls {
1355 1355
 		listener := ls[i]
1356 1356
 		go func() {
1357
-			httpSrv := http.Server{Handler: handle}
1357
+			httpSrv := http.Server{Handler: r}
1358 1358
 			chErrors <- httpSrv.Serve(listener)
1359 1359
 		}()
1360 1360
 	}
... ...
@@ -1382,6 +1401,41 @@ func lookupGidByName(nameOrGid string) (int, error) {
1382 1382
 	return -1, fmt.Errorf("Group %s not found", nameOrGid)
1383 1383
 }
1384 1384
 
1385
+func setupTls(cert, key, ca string, l net.Listener) (net.Listener, error) {
1386
+	tlsCert, err := tls.LoadX509KeyPair(cert, key)
1387
+	if err != nil {
1388
+		return nil, fmt.Errorf("Couldn't load X509 key pair (%s, %s): %s. Key encrypted?",
1389
+			cert, key, err)
1390
+	}
1391
+	tlsConfig := &tls.Config{
1392
+		NextProtos:   []string{"http/1.1"},
1393
+		Certificates: []tls.Certificate{tlsCert},
1394
+		// Avoid fallback on insecure SSL protocols
1395
+		MinVersion: tls.VersionTLS10,
1396
+	}
1397
+
1398
+	if ca != "" {
1399
+		certPool := x509.NewCertPool()
1400
+		file, err := ioutil.ReadFile(ca)
1401
+		if err != nil {
1402
+			return nil, fmt.Errorf("Couldn't read CA certificate: %s", err)
1403
+		}
1404
+		certPool.AppendCertsFromPEM(file)
1405
+		tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
1406
+		tlsConfig.ClientCAs = certPool
1407
+	}
1408
+
1409
+	return tls.NewListener(l, tlsConfig), nil
1410
+}
1411
+
1412
+func newListener(proto, addr string, bufferRequests bool) (net.Listener, error) {
1413
+	if bufferRequests {
1414
+		return listenbuffer.NewListenBuffer(proto, addr, activationLock)
1415
+	}
1416
+
1417
+	return net.Listen(proto, addr)
1418
+}
1419
+
1385 1420
 func changeGroup(addr string, nameOrGid string) error {
1386 1421
 	gid, err := lookupGidByName(nameOrGid)
1387 1422
 	if err != nil {
... ...
@@ -1392,99 +1446,95 @@ func changeGroup(addr string, nameOrGid string) error {
1392 1392
 	return os.Chown(addr, 0, gid)
1393 1393
 }
1394 1394
 
1395
-// ListenAndServe sets up the required http.Server and gets it listening for
1396
-// each addr passed in and does protocol specific checking.
1397
-func ListenAndServe(proto, addr string, job *engine.Job) error {
1398
-	var l net.Listener
1395
+func setSocketGroup(addr, group string) error {
1396
+	if group == "" {
1397
+		return nil
1398
+	}
1399
+
1400
+	if err := changeGroup(addr, group); err != nil {
1401
+		if group != "docker" {
1402
+			return err
1403
+		}
1404
+		log.Debugf("Warning: could not chgrp %s to docker: %v", addr, err)
1405
+	}
1406
+
1407
+	return nil
1408
+}
1409
+
1410
+func setupUnixHttp(addr string, job *engine.Job) (*HttpServer, error) {
1399 1411
 	r, err := createRouter(job.Eng, job.GetenvBool("Logging"), job.GetenvBool("EnableCors"), job.Getenv("Version"))
1400 1412
 	if err != nil {
1401
-		return err
1413
+		return nil, err
1402 1414
 	}
1403 1415
 
1404
-	if proto == "fd" {
1405
-		return ServeFd(addr, r)
1416
+	if err := syscall.Unlink(addr); err != nil && !os.IsNotExist(err) {
1417
+		return nil, err
1406 1418
 	}
1419
+	mask := syscall.Umask(0777)
1420
+	defer syscall.Umask(mask)
1407 1421
 
1408
-	if proto == "unix" {
1409
-		if err := syscall.Unlink(addr); err != nil && !os.IsNotExist(err) {
1410
-			return err
1411
-		}
1422
+	l, err := newListener("unix", addr, job.GetenvBool("BufferRequests"))
1423
+	if err != nil {
1424
+		return nil, err
1412 1425
 	}
1413 1426
 
1414
-	var oldmask int
1415
-	if proto == "unix" {
1416
-		oldmask = syscall.Umask(0777)
1427
+	if err := setSocketGroup(addr, job.Getenv("SocketGroup")); err != nil {
1428
+		return nil, err
1417 1429
 	}
1418 1430
 
1419
-	if job.GetenvBool("BufferRequests") {
1420
-		l, err = listenbuffer.NewListenBuffer(proto, addr, activationLock)
1421
-	} else {
1422
-		l, err = net.Listen(proto, addr)
1431
+	if err := os.Chmod(addr, 0660); err != nil {
1432
+		return nil, err
1423 1433
 	}
1424 1434
 
1425
-	if proto == "unix" {
1426
-		syscall.Umask(oldmask)
1435
+	return &HttpServer{&http.Server{Addr: addr, Handler: r}, l}, nil
1436
+}
1437
+
1438
+func setupTcpHttp(addr string, job *engine.Job) (*HttpServer, error) {
1439
+	if !strings.HasPrefix(addr, "127.0.0.1") && !job.GetenvBool("TlsVerify") {
1440
+		log.Infof("/!\\ DON'T BIND ON ANOTHER IP ADDRESS THAN 127.0.0.1 IF YOU DON'T KNOW WHAT YOU'RE DOING /!\\")
1427 1441
 	}
1442
+
1443
+	r, err := createRouter(job.Eng, job.GetenvBool("Logging"), job.GetenvBool("EnableCors"), job.Getenv("Version"))
1428 1444
 	if err != nil {
1429
-		return err
1445
+		return nil, err
1430 1446
 	}
1431 1447
 
1432
-	if proto != "unix" && (job.GetenvBool("Tls") || job.GetenvBool("TlsVerify")) {
1433
-		tlsCert := job.Getenv("TlsCert")
1434
-		tlsKey := job.Getenv("TlsKey")
1435
-		cert, err := tls.LoadX509KeyPair(tlsCert, tlsKey)
1436
-		if err != nil {
1437
-			return fmt.Errorf("Couldn't load X509 key pair (%s, %s): %s. Key encrypted?",
1438
-				tlsCert, tlsKey, err)
1439
-		}
1440
-		tlsConfig := &tls.Config{
1441
-			NextProtos:   []string{"http/1.1"},
1442
-			Certificates: []tls.Certificate{cert},
1443
-			// Avoid fallback on insecure SSL protocols
1444
-			MinVersion: tls.VersionTLS10,
1445
-		}
1446
-		if job.GetenvBool("TlsVerify") {
1447
-			certPool := x509.NewCertPool()
1448
-			file, err := ioutil.ReadFile(job.Getenv("TlsCa"))
1449
-			if err != nil {
1450
-				return fmt.Errorf("Couldn't read CA certificate: %s", err)
1451
-			}
1452
-			certPool.AppendCertsFromPEM(file)
1448
+	l, err := newListener("tcp", addr, job.GetenvBool("BufferRequests"))
1449
+	if err != nil {
1450
+		return nil, err
1451
+	}
1453 1452
 
1454
-			tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
1455
-			tlsConfig.ClientCAs = certPool
1453
+	if job.GetenvBool("Tls") || job.GetenvBool("TlsVerify") {
1454
+		var tlsCa string
1455
+		if job.GetenvBool("TlsVerify") {
1456
+			tlsCa = job.Getenv("TlsCa")
1457
+		}
1458
+		l, err = setupTls(job.Getenv("TlsCert"), job.Getenv("TlsKey"), tlsCa, l)
1459
+		if err != nil {
1460
+			return nil, err
1456 1461
 		}
1457
-		l = tls.NewListener(l, tlsConfig)
1458 1462
 	}
1463
+	return &HttpServer{&http.Server{Addr: addr, Handler: r}, l}, nil
1464
+}
1459 1465
 
1466
+// NewServer sets up the required Server and does protocol specific checking.
1467
+func NewServer(proto, addr string, job *engine.Job) (Server, error) {
1460 1468
 	// Basic error and sanity checking
1461 1469
 	switch proto {
1470
+	case "fd":
1471
+		return nil, serveFd(addr, job)
1462 1472
 	case "tcp":
1463
-		if !strings.HasPrefix(addr, "127.0.0.1") && !job.GetenvBool("TlsVerify") {
1464
-			log.Infof("/!\\ DON'T BIND ON ANOTHER IP ADDRESS THAN 127.0.0.1 IF YOU DON'T KNOW WHAT YOU'RE DOING /!\\")
1465
-		}
1473
+		return setupTcpHttp(addr, job)
1466 1474
 	case "unix":
1467
-		socketGroup := job.Getenv("SocketGroup")
1468
-		if socketGroup != "" {
1469
-			if err := changeGroup(addr, socketGroup); err != nil {
1470
-				if socketGroup == "docker" {
1471
-					// if the user hasn't explicitly specified the group ownership, don't fail on errors.
1472
-					log.Debugf("Warning: could not chgrp %s to docker: %s", addr, err.Error())
1473
-				} else {
1474
-					return err
1475
-				}
1476
-			}
1477
-
1478
-		}
1479
-		if err := os.Chmod(addr, 0660); err != nil {
1480
-			return err
1481
-		}
1475
+		return setupUnixHttp(addr, job)
1482 1476
 	default:
1483
-		return fmt.Errorf("Invalid protocol format.")
1477
+		return nil, fmt.Errorf("Invalid protocol format.")
1484 1478
 	}
1479
+}
1485 1480
 
1486
-	httpSrv := http.Server{Addr: addr, Handler: r}
1487
-	return httpSrv.Serve(l)
1481
+type Server interface {
1482
+	Serve() error
1483
+	Close() error
1488 1484
 }
1489 1485
 
1490 1486
 // ServeApi loops through all of the protocols sent in to docker and spawns
... ...
@@ -1506,7 +1556,12 @@ func ServeApi(job *engine.Job) engine.Status {
1506 1506
 		}
1507 1507
 		go func() {
1508 1508
 			log.Infof("Listening for HTTP on %s (%s)", protoAddrParts[0], protoAddrParts[1])
1509
-			chErrors <- ListenAndServe(protoAddrParts[0], protoAddrParts[1], job)
1509
+			srv, err := NewServer(protoAddrParts[0], protoAddrParts[1], job)
1510
+			if err != nil {
1511
+				chErrors <- err
1512
+				return
1513
+			}
1514
+			chErrors <- srv.Serve()
1510 1515
 		}()
1511 1516
 	}
1512 1517