Since go1.8, the stdlib TLS net.Conn implementation implements the
`CloseWrite()` interface.
Signed-off-by: Brian Goff <cpuguy83@gmail.com>
Signed-off-by: Sebastiaan van Stijn <github@gone.nl>
| ... | ... |
@@ -9,7 +9,6 @@ import ( |
| 9 | 9 |
"net/http" |
| 10 | 10 |
"net/http/httputil" |
| 11 | 11 |
"net/url" |
| 12 |
- "strings" |
|
| 13 | 12 |
"time" |
| 14 | 13 |
|
| 15 | 14 |
"github.com/docker/docker/api/types" |
| ... | ... |
@@ -17,21 +16,6 @@ import ( |
| 17 | 17 |
"github.com/pkg/errors" |
| 18 | 18 |
) |
| 19 | 19 |
|
| 20 |
-// tlsClientCon holds tls information and a dialed connection. |
|
| 21 |
-type tlsClientCon struct {
|
|
| 22 |
- *tls.Conn |
|
| 23 |
- rawConn net.Conn |
|
| 24 |
-} |
|
| 25 |
- |
|
| 26 |
-func (c *tlsClientCon) CloseWrite() error {
|
|
| 27 |
- // Go standard tls.Conn doesn't provide the CloseWrite() method so we do it |
|
| 28 |
- // on its underlying connection. |
|
| 29 |
- if conn, ok := c.rawConn.(types.CloseWriter); ok {
|
|
| 30 |
- return conn.CloseWrite() |
|
| 31 |
- } |
|
| 32 |
- return nil |
|
| 33 |
-} |
|
| 34 |
- |
|
| 35 | 20 |
// postHijacked sends a POST request and hijacks the connection. |
| 36 | 21 |
func (cli *Client) postHijacked(ctx context.Context, path string, query url.Values, body interface{}, headers map[string][]string) (types.HijackedResponse, error) {
|
| 37 | 22 |
bodyEncoded, err := encodeData(body) |
| ... | ... |
@@ -54,96 +38,9 @@ func (cli *Client) postHijacked(ctx context.Context, path string, query url.Valu |
| 54 | 54 |
return types.HijackedResponse{Conn: conn, Reader: bufio.NewReader(conn)}, err
|
| 55 | 55 |
} |
| 56 | 56 |
|
| 57 |
-func tlsDial(network, addr string, config *tls.Config) (net.Conn, error) {
|
|
| 58 |
- return tlsDialWithDialer(new(net.Dialer), network, addr, config) |
|
| 59 |
-} |
|
| 60 |
- |
|
| 61 |
-// We need to copy Go's implementation of tls.Dial (pkg/cryptor/tls/tls.go) in |
|
| 62 |
-// order to return our custom tlsClientCon struct which holds both the tls.Conn |
|
| 63 |
-// object _and_ its underlying raw connection. The rationale for this is that |
|
| 64 |
-// we need to be able to close the write end of the connection when attaching, |
|
| 65 |
-// which tls.Conn does not provide. |
|
| 66 |
-func tlsDialWithDialer(dialer *net.Dialer, network, addr string, config *tls.Config) (net.Conn, error) {
|
|
| 67 |
- // We want the Timeout and Deadline values from dialer to cover the |
|
| 68 |
- // whole process: TCP connection and TLS handshake. This means that we |
|
| 69 |
- // also need to start our own timers now. |
|
| 70 |
- timeout := dialer.Timeout |
|
| 71 |
- |
|
| 72 |
- if !dialer.Deadline.IsZero() {
|
|
| 73 |
- deadlineTimeout := time.Until(dialer.Deadline) |
|
| 74 |
- if timeout == 0 || deadlineTimeout < timeout {
|
|
| 75 |
- timeout = deadlineTimeout |
|
| 76 |
- } |
|
| 77 |
- } |
|
| 78 |
- |
|
| 79 |
- var errChannel chan error |
|
| 80 |
- |
|
| 81 |
- if timeout != 0 {
|
|
| 82 |
- errChannel = make(chan error, 2) |
|
| 83 |
- time.AfterFunc(timeout, func() {
|
|
| 84 |
- errChannel <- errors.New("")
|
|
| 85 |
- }) |
|
| 86 |
- } |
|
| 87 |
- |
|
| 88 |
- proxyDialer, err := sockets.DialerFromEnvironment(dialer) |
|
| 89 |
- if err != nil {
|
|
| 90 |
- return nil, err |
|
| 91 |
- } |
|
| 92 |
- |
|
| 93 |
- rawConn, err := proxyDialer.Dial(network, addr) |
|
| 94 |
- if err != nil {
|
|
| 95 |
- return nil, err |
|
| 96 |
- } |
|
| 97 |
- // When we set up a TCP connection for hijack, there could be long periods |
|
| 98 |
- // of inactivity (a long running command with no output) that in certain |
|
| 99 |
- // network setups may cause ECONNTIMEOUT, leaving the client in an unknown |
|
| 100 |
- // state. Setting TCP KeepAlive on the socket connection will prohibit |
|
| 101 |
- // ECONNTIMEOUT unless the socket connection truly is broken |
|
| 102 |
- if tcpConn, ok := rawConn.(*net.TCPConn); ok {
|
|
| 103 |
- tcpConn.SetKeepAlive(true) |
|
| 104 |
- tcpConn.SetKeepAlivePeriod(30 * time.Second) |
|
| 105 |
- } |
|
| 106 |
- |
|
| 107 |
- colonPos := strings.LastIndex(addr, ":") |
|
| 108 |
- if colonPos == -1 {
|
|
| 109 |
- colonPos = len(addr) |
|
| 110 |
- } |
|
| 111 |
- hostname := addr[:colonPos] |
|
| 112 |
- |
|
| 113 |
- // If no ServerName is set, infer the ServerName |
|
| 114 |
- // from the hostname we're connecting to. |
|
| 115 |
- if config.ServerName == "" {
|
|
| 116 |
- // Make a copy to avoid polluting argument or default. |
|
| 117 |
- config = tlsConfigClone(config) |
|
| 118 |
- config.ServerName = hostname |
|
| 119 |
- } |
|
| 120 |
- |
|
| 121 |
- conn := tls.Client(rawConn, config) |
|
| 122 |
- |
|
| 123 |
- if timeout == 0 {
|
|
| 124 |
- err = conn.Handshake() |
|
| 125 |
- } else {
|
|
| 126 |
- go func() {
|
|
| 127 |
- errChannel <- conn.Handshake() |
|
| 128 |
- }() |
|
| 129 |
- |
|
| 130 |
- err = <-errChannel |
|
| 131 |
- } |
|
| 132 |
- |
|
| 133 |
- if err != nil {
|
|
| 134 |
- rawConn.Close() |
|
| 135 |
- return nil, err |
|
| 136 |
- } |
|
| 137 |
- |
|
| 138 |
- // This is Docker difference with standard's crypto/tls package: returned a |
|
| 139 |
- // wrapper which holds both the TLS and raw connections. |
|
| 140 |
- return &tlsClientCon{conn, rawConn}, nil
|
|
| 141 |
-} |
|
| 142 |
- |
|
| 143 | 57 |
func dial(proto, addr string, tlsConfig *tls.Config) (net.Conn, error) {
|
| 144 | 58 |
if tlsConfig != nil && proto != "unix" && proto != "npipe" {
|
| 145 |
- // Notice this isn't Go standard's tls.Dial function |
|
| 146 |
- return tlsDial(proto, addr, tlsConfig) |
|
| 59 |
+ return tls.Dial(proto, addr, tlsConfig) |
|
| 147 | 60 |
} |
| 148 | 61 |
if proto == "npipe" {
|
| 149 | 62 |
return sockets.DialPipe(addr, 32*time.Second) |
| 150 | 63 |
new file mode 100644 |
| ... | ... |
@@ -0,0 +1,103 @@ |
| 0 |
+package client |
|
| 1 |
+ |
|
| 2 |
+import ( |
|
| 3 |
+ "fmt" |
|
| 4 |
+ "io/ioutil" |
|
| 5 |
+ "net" |
|
| 6 |
+ "net/http" |
|
| 7 |
+ "net/http/httptest" |
|
| 8 |
+ "net/url" |
|
| 9 |
+ "testing" |
|
| 10 |
+ |
|
| 11 |
+ "github.com/docker/docker/api/server/httputils" |
|
| 12 |
+ "github.com/docker/docker/api/types" |
|
| 13 |
+ "github.com/gotestyourself/gotestyourself/assert" |
|
| 14 |
+ "github.com/pkg/errors" |
|
| 15 |
+ "golang.org/x/net/context" |
|
| 16 |
+) |
|
| 17 |
+ |
|
| 18 |
+func TestTLSCloseWriter(t *testing.T) {
|
|
| 19 |
+ t.Parallel() |
|
| 20 |
+ |
|
| 21 |
+ var chErr chan error |
|
| 22 |
+ ts := &httptest.Server{Config: &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
|
| 23 |
+ chErr = make(chan error, 1) |
|
| 24 |
+ defer close(chErr) |
|
| 25 |
+ if err := httputils.ParseForm(req); err != nil {
|
|
| 26 |
+ chErr <- errors.Wrap(err, "error parsing form") |
|
| 27 |
+ http.Error(w, err.Error(), 500) |
|
| 28 |
+ return |
|
| 29 |
+ } |
|
| 30 |
+ r, rw, err := httputils.HijackConnection(w) |
|
| 31 |
+ if err != nil {
|
|
| 32 |
+ chErr <- errors.Wrap(err, "error hijacking connection") |
|
| 33 |
+ http.Error(w, err.Error(), 500) |
|
| 34 |
+ return |
|
| 35 |
+ } |
|
| 36 |
+ defer r.Close() |
|
| 37 |
+ |
|
| 38 |
+ fmt.Fprint(rw, "HTTP/1.1 101 UPGRADED\r\nContent-Type: application/vnd.docker.raw-stream\r\nConnection: Upgrade\r\nUpgrade: tcp\r\n\n") |
|
| 39 |
+ |
|
| 40 |
+ buf := make([]byte, 5) |
|
| 41 |
+ _, err = r.Read(buf) |
|
| 42 |
+ if err != nil {
|
|
| 43 |
+ chErr <- errors.Wrap(err, "error reading from client") |
|
| 44 |
+ return |
|
| 45 |
+ } |
|
| 46 |
+ _, err = rw.Write(buf) |
|
| 47 |
+ if err != nil {
|
|
| 48 |
+ chErr <- errors.Wrap(err, "error writing to client") |
|
| 49 |
+ return |
|
| 50 |
+ } |
|
| 51 |
+ })}} |
|
| 52 |
+ |
|
| 53 |
+ var ( |
|
| 54 |
+ l net.Listener |
|
| 55 |
+ err error |
|
| 56 |
+ ) |
|
| 57 |
+ for i := 1024; i < 10000; i++ {
|
|
| 58 |
+ l, err = net.Listen("tcp4", fmt.Sprintf("127.0.0.1:%d", i))
|
|
| 59 |
+ if err == nil {
|
|
| 60 |
+ break |
|
| 61 |
+ } |
|
| 62 |
+ } |
|
| 63 |
+ assert.Assert(t, err) |
|
| 64 |
+ |
|
| 65 |
+ ts.Listener = l |
|
| 66 |
+ defer l.Close() |
|
| 67 |
+ |
|
| 68 |
+ defer func() {
|
|
| 69 |
+ if chErr != nil {
|
|
| 70 |
+ assert.Assert(t, <-chErr) |
|
| 71 |
+ } |
|
| 72 |
+ }() |
|
| 73 |
+ |
|
| 74 |
+ ts.StartTLS() |
|
| 75 |
+ defer ts.Close() |
|
| 76 |
+ |
|
| 77 |
+ serverURL, err := url.Parse(ts.URL) |
|
| 78 |
+ assert.Assert(t, err) |
|
| 79 |
+ |
|
| 80 |
+ client, err := NewClient("tcp://"+serverURL.Host, "", ts.Client(), nil)
|
|
| 81 |
+ assert.Assert(t, err) |
|
| 82 |
+ |
|
| 83 |
+ resp, err := client.postHijacked(context.Background(), "/asdf", url.Values{}, nil, map[string][]string{"Content-Type": {"text/plain"}})
|
|
| 84 |
+ assert.Assert(t, err) |
|
| 85 |
+ defer resp.Close() |
|
| 86 |
+ |
|
| 87 |
+ if _, ok := resp.Conn.(types.CloseWriter); !ok {
|
|
| 88 |
+ t.Fatal("tls conn did not implement the CloseWrite interface")
|
|
| 89 |
+ } |
|
| 90 |
+ |
|
| 91 |
+ _, err = resp.Conn.Write([]byte("hello"))
|
|
| 92 |
+ assert.Assert(t, err) |
|
| 93 |
+ |
|
| 94 |
+ b, err := ioutil.ReadAll(resp.Reader) |
|
| 95 |
+ assert.Assert(t, err) |
|
| 96 |
+ assert.Assert(t, string(b) == "hello") |
|
| 97 |
+ assert.Assert(t, resp.CloseWrite()) |
|
| 98 |
+ |
|
| 99 |
+ // This should error since writes are closed |
|
| 100 |
+ _, err = resp.Conn.Write([]byte("no"))
|
|
| 101 |
+ assert.Assert(t, err != nil) |
|
| 102 |
+} |
| 0 | 103 |
deleted file mode 100644 |
| ... | ... |
@@ -1,11 +0,0 @@ |
| 1 |
-// +build go1.8 |
|
| 2 |
- |
|
| 3 |
-package client // import "github.com/docker/docker/client" |
|
| 4 |
- |
|
| 5 |
-import "crypto/tls" |
|
| 6 |
- |
|
| 7 |
-// tlsConfigClone returns a clone of tls.Config. This function is provided for |
|
| 8 |
-// compatibility for go1.7 that doesn't include this method in stdlib. |
|
| 9 |
-func tlsConfigClone(c *tls.Config) *tls.Config {
|
|
| 10 |
- return c.Clone() |
|
| 11 |
-} |
| 12 | 1 |
deleted file mode 100644 |
| ... | ... |
@@ -1,33 +0,0 @@ |
| 1 |
-// +build go1.7,!go1.8 |
|
| 2 |
- |
|
| 3 |
-package client // import "github.com/docker/docker/client" |
|
| 4 |
- |
|
| 5 |
-import "crypto/tls" |
|
| 6 |
- |
|
| 7 |
-// tlsConfigClone returns a clone of tls.Config. This function is provided for |
|
| 8 |
-// compatibility for go1.7 that doesn't include this method in stdlib. |
|
| 9 |
-func tlsConfigClone(c *tls.Config) *tls.Config {
|
|
| 10 |
- return &tls.Config{
|
|
| 11 |
- Rand: c.Rand, |
|
| 12 |
- Time: c.Time, |
|
| 13 |
- Certificates: c.Certificates, |
|
| 14 |
- NameToCertificate: c.NameToCertificate, |
|
| 15 |
- GetCertificate: c.GetCertificate, |
|
| 16 |
- RootCAs: c.RootCAs, |
|
| 17 |
- NextProtos: c.NextProtos, |
|
| 18 |
- ServerName: c.ServerName, |
|
| 19 |
- ClientAuth: c.ClientAuth, |
|
| 20 |
- ClientCAs: c.ClientCAs, |
|
| 21 |
- InsecureSkipVerify: c.InsecureSkipVerify, |
|
| 22 |
- CipherSuites: c.CipherSuites, |
|
| 23 |
- PreferServerCipherSuites: c.PreferServerCipherSuites, |
|
| 24 |
- SessionTicketsDisabled: c.SessionTicketsDisabled, |
|
| 25 |
- SessionTicketKey: c.SessionTicketKey, |
|
| 26 |
- ClientSessionCache: c.ClientSessionCache, |
|
| 27 |
- MinVersion: c.MinVersion, |
|
| 28 |
- MaxVersion: c.MaxVersion, |
|
| 29 |
- CurvePreferences: c.CurvePreferences, |
|
| 30 |
- DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled, |
|
| 31 |
- Renegotiation: c.Renegotiation, |
|
| 32 |
- } |
|
| 33 |
-} |