Browse code

Use stdlib TLS dialer

Since go1.8, the stdlib TLS net.Conn implementation implements the
`CloseWrite()` interface.

Signed-off-by: Brian Goff <cpuguy83@gmail.com>
Signed-off-by: Sebastiaan van Stijn <github@gone.nl>

Brian Goff authored on 2018/03/24 03:39:30
Showing 4 changed files
... ...
@@ -9,7 +9,6 @@ import (
9 9
 	"net/http"
10 10
 	"net/http/httputil"
11 11
 	"net/url"
12
-	"strings"
13 12
 	"time"
14 13
 
15 14
 	"github.com/docker/docker/api/types"
... ...
@@ -17,21 +16,6 @@ import (
17 17
 	"github.com/pkg/errors"
18 18
 )
19 19
 
20
-// tlsClientCon holds tls information and a dialed connection.
21
-type tlsClientCon struct {
22
-	*tls.Conn
23
-	rawConn net.Conn
24
-}
25
-
26
-func (c *tlsClientCon) CloseWrite() error {
27
-	// Go standard tls.Conn doesn't provide the CloseWrite() method so we do it
28
-	// on its underlying connection.
29
-	if conn, ok := c.rawConn.(types.CloseWriter); ok {
30
-		return conn.CloseWrite()
31
-	}
32
-	return nil
33
-}
34
-
35 20
 // postHijacked sends a POST request and hijacks the connection.
36 21
 func (cli *Client) postHijacked(ctx context.Context, path string, query url.Values, body interface{}, headers map[string][]string) (types.HijackedResponse, error) {
37 22
 	bodyEncoded, err := encodeData(body)
... ...
@@ -54,96 +38,9 @@ func (cli *Client) postHijacked(ctx context.Context, path string, query url.Valu
54 54
 	return types.HijackedResponse{Conn: conn, Reader: bufio.NewReader(conn)}, err
55 55
 }
56 56
 
57
-func tlsDial(network, addr string, config *tls.Config) (net.Conn, error) {
58
-	return tlsDialWithDialer(new(net.Dialer), network, addr, config)
59
-}
60
-
61
-// We need to copy Go's implementation of tls.Dial (pkg/cryptor/tls/tls.go) in
62
-// order to return our custom tlsClientCon struct which holds both the tls.Conn
63
-// object _and_ its underlying raw connection. The rationale for this is that
64
-// we need to be able to close the write end of the connection when attaching,
65
-// which tls.Conn does not provide.
66
-func tlsDialWithDialer(dialer *net.Dialer, network, addr string, config *tls.Config) (net.Conn, error) {
67
-	// We want the Timeout and Deadline values from dialer to cover the
68
-	// whole process: TCP connection and TLS handshake. This means that we
69
-	// also need to start our own timers now.
70
-	timeout := dialer.Timeout
71
-
72
-	if !dialer.Deadline.IsZero() {
73
-		deadlineTimeout := time.Until(dialer.Deadline)
74
-		if timeout == 0 || deadlineTimeout < timeout {
75
-			timeout = deadlineTimeout
76
-		}
77
-	}
78
-
79
-	var errChannel chan error
80
-
81
-	if timeout != 0 {
82
-		errChannel = make(chan error, 2)
83
-		time.AfterFunc(timeout, func() {
84
-			errChannel <- errors.New("")
85
-		})
86
-	}
87
-
88
-	proxyDialer, err := sockets.DialerFromEnvironment(dialer)
89
-	if err != nil {
90
-		return nil, err
91
-	}
92
-
93
-	rawConn, err := proxyDialer.Dial(network, addr)
94
-	if err != nil {
95
-		return nil, err
96
-	}
97
-	// When we set up a TCP connection for hijack, there could be long periods
98
-	// of inactivity (a long running command with no output) that in certain
99
-	// network setups may cause ECONNTIMEOUT, leaving the client in an unknown
100
-	// state. Setting TCP KeepAlive on the socket connection will prohibit
101
-	// ECONNTIMEOUT unless the socket connection truly is broken
102
-	if tcpConn, ok := rawConn.(*net.TCPConn); ok {
103
-		tcpConn.SetKeepAlive(true)
104
-		tcpConn.SetKeepAlivePeriod(30 * time.Second)
105
-	}
106
-
107
-	colonPos := strings.LastIndex(addr, ":")
108
-	if colonPos == -1 {
109
-		colonPos = len(addr)
110
-	}
111
-	hostname := addr[:colonPos]
112
-
113
-	// If no ServerName is set, infer the ServerName
114
-	// from the hostname we're connecting to.
115
-	if config.ServerName == "" {
116
-		// Make a copy to avoid polluting argument or default.
117
-		config = tlsConfigClone(config)
118
-		config.ServerName = hostname
119
-	}
120
-
121
-	conn := tls.Client(rawConn, config)
122
-
123
-	if timeout == 0 {
124
-		err = conn.Handshake()
125
-	} else {
126
-		go func() {
127
-			errChannel <- conn.Handshake()
128
-		}()
129
-
130
-		err = <-errChannel
131
-	}
132
-
133
-	if err != nil {
134
-		rawConn.Close()
135
-		return nil, err
136
-	}
137
-
138
-	// This is Docker difference with standard's crypto/tls package: returned a
139
-	// wrapper which holds both the TLS and raw connections.
140
-	return &tlsClientCon{conn, rawConn}, nil
141
-}
142
-
143 57
 func dial(proto, addr string, tlsConfig *tls.Config) (net.Conn, error) {
144 58
 	if tlsConfig != nil && proto != "unix" && proto != "npipe" {
145
-		// Notice this isn't Go standard's tls.Dial function
146
-		return tlsDial(proto, addr, tlsConfig)
59
+		return tls.Dial(proto, addr, tlsConfig)
147 60
 	}
148 61
 	if proto == "npipe" {
149 62
 		return sockets.DialPipe(addr, 32*time.Second)
150 63
new file mode 100644
... ...
@@ -0,0 +1,103 @@
0
+package client
1
+
2
+import (
3
+	"fmt"
4
+	"io/ioutil"
5
+	"net"
6
+	"net/http"
7
+	"net/http/httptest"
8
+	"net/url"
9
+	"testing"
10
+
11
+	"github.com/docker/docker/api/server/httputils"
12
+	"github.com/docker/docker/api/types"
13
+	"github.com/gotestyourself/gotestyourself/assert"
14
+	"github.com/pkg/errors"
15
+	"golang.org/x/net/context"
16
+)
17
+
18
+func TestTLSCloseWriter(t *testing.T) {
19
+	t.Parallel()
20
+
21
+	var chErr chan error
22
+	ts := &httptest.Server{Config: &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
23
+		chErr = make(chan error, 1)
24
+		defer close(chErr)
25
+		if err := httputils.ParseForm(req); err != nil {
26
+			chErr <- errors.Wrap(err, "error parsing form")
27
+			http.Error(w, err.Error(), 500)
28
+			return
29
+		}
30
+		r, rw, err := httputils.HijackConnection(w)
31
+		if err != nil {
32
+			chErr <- errors.Wrap(err, "error hijacking connection")
33
+			http.Error(w, err.Error(), 500)
34
+			return
35
+		}
36
+		defer r.Close()
37
+
38
+		fmt.Fprint(rw, "HTTP/1.1 101 UPGRADED\r\nContent-Type: application/vnd.docker.raw-stream\r\nConnection: Upgrade\r\nUpgrade: tcp\r\n\n")
39
+
40
+		buf := make([]byte, 5)
41
+		_, err = r.Read(buf)
42
+		if err != nil {
43
+			chErr <- errors.Wrap(err, "error reading from client")
44
+			return
45
+		}
46
+		_, err = rw.Write(buf)
47
+		if err != nil {
48
+			chErr <- errors.Wrap(err, "error writing to client")
49
+			return
50
+		}
51
+	})}}
52
+
53
+	var (
54
+		l   net.Listener
55
+		err error
56
+	)
57
+	for i := 1024; i < 10000; i++ {
58
+		l, err = net.Listen("tcp4", fmt.Sprintf("127.0.0.1:%d", i))
59
+		if err == nil {
60
+			break
61
+		}
62
+	}
63
+	assert.Assert(t, err)
64
+
65
+	ts.Listener = l
66
+	defer l.Close()
67
+
68
+	defer func() {
69
+		if chErr != nil {
70
+			assert.Assert(t, <-chErr)
71
+		}
72
+	}()
73
+
74
+	ts.StartTLS()
75
+	defer ts.Close()
76
+
77
+	serverURL, err := url.Parse(ts.URL)
78
+	assert.Assert(t, err)
79
+
80
+	client, err := NewClient("tcp://"+serverURL.Host, "", ts.Client(), nil)
81
+	assert.Assert(t, err)
82
+
83
+	resp, err := client.postHijacked(context.Background(), "/asdf", url.Values{}, nil, map[string][]string{"Content-Type": {"text/plain"}})
84
+	assert.Assert(t, err)
85
+	defer resp.Close()
86
+
87
+	if _, ok := resp.Conn.(types.CloseWriter); !ok {
88
+		t.Fatal("tls conn did not implement the CloseWrite interface")
89
+	}
90
+
91
+	_, err = resp.Conn.Write([]byte("hello"))
92
+	assert.Assert(t, err)
93
+
94
+	b, err := ioutil.ReadAll(resp.Reader)
95
+	assert.Assert(t, err)
96
+	assert.Assert(t, string(b) == "hello")
97
+	assert.Assert(t, resp.CloseWrite())
98
+
99
+	// This should error since writes are closed
100
+	_, err = resp.Conn.Write([]byte("no"))
101
+	assert.Assert(t, err != nil)
102
+}
0 103
deleted file mode 100644
... ...
@@ -1,11 +0,0 @@
1
-// +build go1.8
2
-
3
-package client // import "github.com/docker/docker/client"
4
-
5
-import "crypto/tls"
6
-
7
-// tlsConfigClone returns a clone of tls.Config. This function is provided for
8
-// compatibility for go1.7 that doesn't include this method in stdlib.
9
-func tlsConfigClone(c *tls.Config) *tls.Config {
10
-	return c.Clone()
11
-}
12 1
deleted file mode 100644
... ...
@@ -1,33 +0,0 @@
1
-// +build go1.7,!go1.8
2
-
3
-package client // import "github.com/docker/docker/client"
4
-
5
-import "crypto/tls"
6
-
7
-// tlsConfigClone returns a clone of tls.Config. This function is provided for
8
-// compatibility for go1.7 that doesn't include this method in stdlib.
9
-func tlsConfigClone(c *tls.Config) *tls.Config {
10
-	return &tls.Config{
11
-		Rand:                        c.Rand,
12
-		Time:                        c.Time,
13
-		Certificates:                c.Certificates,
14
-		NameToCertificate:           c.NameToCertificate,
15
-		GetCertificate:              c.GetCertificate,
16
-		RootCAs:                     c.RootCAs,
17
-		NextProtos:                  c.NextProtos,
18
-		ServerName:                  c.ServerName,
19
-		ClientAuth:                  c.ClientAuth,
20
-		ClientCAs:                   c.ClientCAs,
21
-		InsecureSkipVerify:          c.InsecureSkipVerify,
22
-		CipherSuites:                c.CipherSuites,
23
-		PreferServerCipherSuites:    c.PreferServerCipherSuites,
24
-		SessionTicketsDisabled:      c.SessionTicketsDisabled,
25
-		SessionTicketKey:            c.SessionTicketKey,
26
-		ClientSessionCache:          c.ClientSessionCache,
27
-		MinVersion:                  c.MinVersion,
28
-		MaxVersion:                  c.MaxVersion,
29
-		CurvePreferences:            c.CurvePreferences,
30
-		DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
31
-		Renegotiation:               c.Renegotiation,
32
-	}
33
-}