Browse code

client: add WithDialContext() and client.Dialer()

WithDialContext() allows specifying custom dialer for hijacking and supposed to
replace WithDialer().
WithDialer() is also updated to use WithDialContext().

client.Dialer() returns the dialer configured with WithDialContext().

Signed-off-by: Akihiro Suda <suda.akihiro@lab.ntt.co.jp>

Akihiro Suda authored on 2018/03/19 17:33:06
Showing 4 changed files
... ...
@@ -173,10 +173,17 @@ func WithTLSClientConfig(cacertPath, certPath, keyPath string) func(*Client) err
173 173
 
174 174
 // WithDialer applies the dialer.DialContext to the client transport. This can be
175 175
 // used to set the Timeout and KeepAlive settings of the client.
176
+// Deprecated: use WithDialContext
176 177
 func WithDialer(dialer *net.Dialer) func(*Client) error {
178
+	return WithDialContext(dialer.DialContext)
179
+}
180
+
181
+// WithDialContext applies the dialer to the client transport. This can be
182
+// used to set the Timeout and KeepAlive settings of the client.
183
+func WithDialContext(dialContext func(ctx context.Context, network, addr string) (net.Conn, error)) func(*Client) error {
177 184
 	return func(c *Client) error {
178 185
 		if transport, ok := c.client.Transport.(*http.Transport); ok {
179
-			transport.DialContext = dialer.DialContext
186
+			transport.DialContext = dialContext
180 187
 			return nil
181 188
 		}
182 189
 		return errors.Errorf("cannot apply dialer to transport: %T", c.client.Transport)
... ...
@@ -400,3 +407,16 @@ func (cli *Client) CustomHTTPHeaders() map[string]string {
400 400
 func (cli *Client) SetCustomHTTPHeaders(headers map[string]string) {
401 401
 	cli.customHTTPHeaders = headers
402 402
 }
403
+
404
+// Dialer returns a dialer for a raw stream connection, with HTTP/1.1 header, that can be used for proxying the daemon connection.
405
+// Used by `docker dial-stdio` (docker/cli#889).
406
+func (cli *Client) Dialer() func(context.Context) (net.Conn, error) {
407
+	return func(ctx context.Context) (net.Conn, error) {
408
+		if transport, ok := cli.client.Transport.(*http.Transport); ok {
409
+			if transport.DialContext != nil {
410
+				return transport.DialContext(ctx, cli.proto, cli.addr)
411
+			}
412
+		}
413
+		return fallbackDial(cli.proto, cli.addr, resolveTLSConfig(cli.client.Transport))
414
+	}
415
+}
... ...
@@ -30,7 +30,7 @@ func (cli *Client) postHijacked(ctx context.Context, path string, query url.Valu
30 30
 	}
31 31
 	req = cli.addHeaders(req, headers)
32 32
 
33
-	conn, err := cli.setupHijackConn(req, "tcp")
33
+	conn, err := cli.setupHijackConn(ctx, req, "tcp")
34 34
 	if err != nil {
35 35
 		return types.HijackedResponse{}, err
36 36
 	}
... ...
@@ -38,7 +38,9 @@ func (cli *Client) postHijacked(ctx context.Context, path string, query url.Valu
38 38
 	return types.HijackedResponse{Conn: conn, Reader: bufio.NewReader(conn)}, err
39 39
 }
40 40
 
41
-func dial(proto, addr string, tlsConfig *tls.Config) (net.Conn, error) {
41
+// fallbackDial is used when WithDialer() was not called.
42
+// See cli.Dialer().
43
+func fallbackDial(proto, addr string, tlsConfig *tls.Config) (net.Conn, error) {
42 44
 	if tlsConfig != nil && proto != "unix" && proto != "npipe" {
43 45
 		return tls.Dial(proto, addr, tlsConfig)
44 46
 	}
... ...
@@ -48,12 +50,13 @@ func dial(proto, addr string, tlsConfig *tls.Config) (net.Conn, error) {
48 48
 	return net.Dial(proto, addr)
49 49
 }
50 50
 
51
-func (cli *Client) setupHijackConn(req *http.Request, proto string) (net.Conn, error) {
51
+func (cli *Client) setupHijackConn(ctx context.Context, req *http.Request, proto string) (net.Conn, error) {
52 52
 	req.Host = cli.addr
53 53
 	req.Header.Set("Connection", "Upgrade")
54 54
 	req.Header.Set("Upgrade", proto)
55 55
 
56
-	conn, err := dial(cli.proto, cli.addr, resolveTLSConfig(cli.client.Transport))
56
+	dialer := cli.Dialer()
57
+	conn, err := dialer(ctx)
57 58
 	if err != nil {
58 59
 		return nil, errors.Wrap(err, "cannot connect to the Docker daemon. Is 'docker daemon' running on this host?")
59 60
 	}
... ...
@@ -39,6 +39,7 @@ type CommonAPIClient interface {
39 39
 	NegotiateAPIVersion(ctx context.Context)
40 40
 	NegotiateAPIVersionPing(types.Ping)
41 41
 	DialSession(ctx context.Context, proto string, meta map[string][]string) (net.Conn, error)
42
+	Dialer() func(context.Context) (net.Conn, error)
42 43
 	Close() error
43 44
 }
44 45
 
... ...
@@ -14,5 +14,5 @@ func (cli *Client) DialSession(ctx context.Context, proto string, meta map[strin
14 14
 	}
15 15
 	req = cli.addHeaders(req, meta)
16 16
 
17
-	return cli.setupHijackConn(req, proto)
17
+	return cli.setupHijackConn(ctx, req, proto)
18 18
 }