Browse code

Add transport package to support CancelRequest

Signed-off-by: Tibor Vass <tibor@docker.com>

Tibor Vass authored on 2015/05/16 10:35:04
Showing 13 changed files
... ...
@@ -17,6 +17,7 @@ import (
17 17
 	"github.com/docker/docker/pkg/progressreader"
18 18
 	"github.com/docker/docker/pkg/streamformatter"
19 19
 	"github.com/docker/docker/pkg/stringid"
20
+	"github.com/docker/docker/pkg/transport"
20 21
 	"github.com/docker/docker/registry"
21 22
 	"github.com/docker/docker/utils"
22 23
 )
... ...
@@ -55,16 +56,17 @@ func (s *TagStore) Pull(image string, tag string, imagePullConfig *ImagePullConf
55 55
 	defer s.poolRemove("pull", utils.ImageReference(repoInfo.LocalName, tag))
56 56
 
57 57
 	logrus.Debugf("pulling image from host %q with remote name %q", repoInfo.Index.Name, repoInfo.RemoteName)
58
-	endpoint, err := repoInfo.GetEndpoint()
58
+
59
+	endpoint, err := repoInfo.GetEndpoint(imagePullConfig.MetaHeaders)
59 60
 	if err != nil {
60 61
 		return err
61 62
 	}
62
-
63
+	// TODO(tiborvass): reuse client from endpoint?
63 64
 	// Adds Docker-specific headers as well as user-specified headers (metaHeaders)
64
-	tr := &registry.DockerHeaders{
65
+	tr := transport.NewTransport(
65 66
 		registry.NewTransport(registry.ReceiveTimeout, endpoint.IsSecure),
66
-		imagePullConfig.MetaHeaders,
67
-	}
67
+		registry.DockerHeaders(imagePullConfig.MetaHeaders)...,
68
+	)
68 69
 	client := registry.HTTPClient(tr)
69 70
 	r, err := registry.NewSession(client, imagePullConfig.AuthConfig, endpoint)
70 71
 	if err != nil {
... ...
@@ -18,6 +18,7 @@ import (
18 18
 	"github.com/docker/docker/pkg/progressreader"
19 19
 	"github.com/docker/docker/pkg/streamformatter"
20 20
 	"github.com/docker/docker/pkg/stringid"
21
+	"github.com/docker/docker/pkg/transport"
21 22
 	"github.com/docker/docker/registry"
22 23
 	"github.com/docker/docker/runconfig"
23 24
 	"github.com/docker/docker/utils"
... ...
@@ -509,16 +510,17 @@ func (s *TagStore) Push(localName string, imagePushConfig *ImagePushConfig) erro
509 509
 	}
510 510
 	defer s.poolRemove("push", repoInfo.LocalName)
511 511
 
512
-	endpoint, err := repoInfo.GetEndpoint()
512
+	endpoint, err := repoInfo.GetEndpoint(imagePushConfig.MetaHeaders)
513 513
 	if err != nil {
514 514
 		return err
515 515
 	}
516
-
516
+	// TODO(tiborvass): reuse client from endpoint?
517 517
 	// Adds Docker-specific headers as well as user-specified headers (metaHeaders)
518
-	tr := &registry.DockerHeaders{
518
+	tr := transport.NewTransport(
519 519
 		registry.NewTransport(registry.NoTimeout, endpoint.IsSecure),
520
-		imagePushConfig.MetaHeaders,
521
-	}
520
+		registry.DockerHeaders(imagePushConfig.MetaHeaders)...,
521
+	)
522
+	client := registry.HTTPClient(tr)
522 523
 	r, err := registry.NewSession(client, imagePushConfig.AuthConfig, endpoint)
523 524
 	if err != nil {
524 525
 		return err
525 526
new file mode 100644
... ...
@@ -0,0 +1,27 @@
0
+Copyright (c) 2009 The oauth2 Authors. All rights reserved.
1
+
2
+Redistribution and use in source and binary forms, with or without
3
+modification, are permitted provided that the following conditions are
4
+met:
5
+
6
+   * Redistributions of source code must retain the above copyright
7
+notice, this list of conditions and the following disclaimer.
8
+   * Redistributions in binary form must reproduce the above
9
+copyright notice, this list of conditions and the following disclaimer
10
+in the documentation and/or other materials provided with the
11
+distribution.
12
+   * Neither the name of Google Inc. nor the names of its
13
+contributors may be used to endorse or promote products derived from
14
+this software without specific prior written permission.
15
+
16
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
20
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
21
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
22
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
23
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
24
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
25
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
26
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
0 27
new file mode 100644
... ...
@@ -0,0 +1,148 @@
0
+package transport
1
+
2
+import (
3
+	"io"
4
+	"net/http"
5
+	"sync"
6
+)
7
+
8
+type RequestModifier interface {
9
+	ModifyRequest(*http.Request) error
10
+}
11
+
12
+type headerModifier http.Header
13
+
14
+// NewHeaderRequestModifier returns a RequestModifier that merges the HTTP headers
15
+// passed as an argument, with the HTTP headers of a request.
16
+//
17
+// If the same key is present in both, the modifying header values for that key,
18
+// are appended to the values for that same key in the request header.
19
+func NewHeaderRequestModifier(header http.Header) RequestModifier {
20
+	return headerModifier(header)
21
+}
22
+
23
+func (h headerModifier) ModifyRequest(req *http.Request) error {
24
+	for k, s := range http.Header(h) {
25
+		req.Header[k] = append(req.Header[k], s...)
26
+	}
27
+
28
+	return nil
29
+}
30
+
31
+// NewTransport returns an http.RoundTripper that modifies requests according to
32
+// the RequestModifiers passed in the arguments, before sending the requests to
33
+// the base http.RoundTripper (which, if nil, defaults to http.DefaultTransport).
34
+func NewTransport(base http.RoundTripper, modifiers ...RequestModifier) http.RoundTripper {
35
+	return &transport{
36
+		Modifiers: modifiers,
37
+		Base:      base,
38
+	}
39
+}
40
+
41
+// transport is an http.RoundTripper that makes HTTP requests after
42
+// copying and modifying the request
43
+type transport struct {
44
+	Modifiers []RequestModifier
45
+	Base      http.RoundTripper
46
+
47
+	mu     sync.Mutex                      // guards modReq
48
+	modReq map[*http.Request]*http.Request // original -> modified
49
+}
50
+
51
+func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
52
+	req2 := CloneRequest(req)
53
+	for _, modifier := range t.Modifiers {
54
+		if err := modifier.ModifyRequest(req2); err != nil {
55
+			return nil, err
56
+		}
57
+	}
58
+
59
+	t.setModReq(req, req2)
60
+	res, err := t.base().RoundTrip(req2)
61
+	if err != nil {
62
+		t.setModReq(req, nil)
63
+		return nil, err
64
+	}
65
+	res.Body = &OnEOFReader{
66
+		Rc: res.Body,
67
+		Fn: func() { t.setModReq(req, nil) },
68
+	}
69
+	return res, nil
70
+}
71
+
72
+// CancelRequest cancels an in-flight request by closing its connection.
73
+func (t *transport) CancelRequest(req *http.Request) {
74
+	type canceler interface {
75
+		CancelRequest(*http.Request)
76
+	}
77
+	if cr, ok := t.base().(canceler); ok {
78
+		t.mu.Lock()
79
+		modReq := t.modReq[req]
80
+		delete(t.modReq, req)
81
+		t.mu.Unlock()
82
+		cr.CancelRequest(modReq)
83
+	}
84
+}
85
+
86
+func (t *transport) base() http.RoundTripper {
87
+	if t.Base != nil {
88
+		return t.Base
89
+	}
90
+	return http.DefaultTransport
91
+}
92
+
93
+func (t *transport) setModReq(orig, mod *http.Request) {
94
+	t.mu.Lock()
95
+	defer t.mu.Unlock()
96
+	if t.modReq == nil {
97
+		t.modReq = make(map[*http.Request]*http.Request)
98
+	}
99
+	if mod == nil {
100
+		delete(t.modReq, orig)
101
+	} else {
102
+		t.modReq[orig] = mod
103
+	}
104
+}
105
+
106
+// CloneRequest returns a clone of the provided *http.Request.
107
+// The clone is a shallow copy of the struct and its Header map.
108
+func CloneRequest(r *http.Request) *http.Request {
109
+	// shallow copy of the struct
110
+	r2 := new(http.Request)
111
+	*r2 = *r
112
+	// deep copy of the Header
113
+	r2.Header = make(http.Header, len(r.Header))
114
+	for k, s := range r.Header {
115
+		r2.Header[k] = append([]string(nil), s...)
116
+	}
117
+
118
+	return r2
119
+}
120
+
121
+// OnEOFReader ensures a callback function is called
122
+// on Close() and when the underlying Reader returns an io.EOF error
123
+type OnEOFReader struct {
124
+	Rc io.ReadCloser
125
+	Fn func()
126
+}
127
+
128
+func (r *OnEOFReader) Read(p []byte) (n int, err error) {
129
+	n, err = r.Rc.Read(p)
130
+	if err == io.EOF {
131
+		r.runFunc()
132
+	}
133
+	return
134
+}
135
+
136
+func (r *OnEOFReader) Close() error {
137
+	err := r.Rc.Close()
138
+	r.runFunc()
139
+	return err
140
+}
141
+
142
+func (r *OnEOFReader) runFunc() {
143
+	if fn := r.Fn; fn != nil {
144
+		fn()
145
+		r.Fn = nil
146
+	}
147
+}
... ...
@@ -44,8 +44,6 @@ func (auth *RequestAuthorization) getToken() (string, error) {
44 44
 		return auth.tokenCache, nil
45 45
 	}
46 46
 
47
-	client := auth.registryEndpoint.HTTPClient()
48
-
49 47
 	for _, challenge := range auth.registryEndpoint.AuthChallenges {
50 48
 		switch strings.ToLower(challenge.Scheme) {
51 49
 		case "basic":
... ...
@@ -57,7 +55,7 @@ func (auth *RequestAuthorization) getToken() (string, error) {
57 57
 				params[k] = v
58 58
 			}
59 59
 			params["scope"] = fmt.Sprintf("%s:%s:%s", auth.resource, auth.scope, strings.Join(auth.actions, ","))
60
-			token, err := getToken(auth.authConfig.Username, auth.authConfig.Password, params, auth.registryEndpoint, client)
60
+			token, err := getToken(auth.authConfig.Username, auth.authConfig.Password, params, auth.registryEndpoint)
61 61
 			if err != nil {
62 62
 				return "", err
63 63
 			}
... ...
@@ -104,7 +102,6 @@ func loginV1(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (stri
104 104
 		status        string
105 105
 		reqBody       []byte
106 106
 		err           error
107
-		client        = registryEndpoint.HTTPClient()
108 107
 		reqStatusCode = 0
109 108
 		serverAddress = authConfig.ServerAddress
110 109
 	)
... ...
@@ -128,7 +125,7 @@ func loginV1(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (stri
128 128
 
129 129
 	// using `bytes.NewReader(jsonBody)` here causes the server to respond with a 411 status.
130 130
 	b := strings.NewReader(string(jsonBody))
131
-	req1, err := client.Post(serverAddress+"users/", "application/json; charset=utf-8", b)
131
+	req1, err := registryEndpoint.client.Post(serverAddress+"users/", "application/json; charset=utf-8", b)
132 132
 	if err != nil {
133 133
 		return "", fmt.Errorf("Server Error: %s", err)
134 134
 	}
... ...
@@ -151,7 +148,7 @@ func loginV1(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (stri
151 151
 		if string(reqBody) == "\"Username or email already exists\"" {
152 152
 			req, err := http.NewRequest("GET", serverAddress+"users/", nil)
153 153
 			req.SetBasicAuth(authConfig.Username, authConfig.Password)
154
-			resp, err := client.Do(req)
154
+			resp, err := registryEndpoint.client.Do(req)
155 155
 			if err != nil {
156 156
 				return "", err
157 157
 			}
... ...
@@ -180,7 +177,7 @@ func loginV1(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (stri
180 180
 		// protected, so people can use `docker login` as an auth check.
181 181
 		req, err := http.NewRequest("GET", serverAddress+"users/", nil)
182 182
 		req.SetBasicAuth(authConfig.Username, authConfig.Password)
183
-		resp, err := client.Do(req)
183
+		resp, err := registryEndpoint.client.Do(req)
184 184
 		if err != nil {
185 185
 			return "", err
186 186
 		}
... ...
@@ -217,7 +214,6 @@ func loginV2(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (stri
217 217
 	var (
218 218
 		err       error
219 219
 		allErrors []error
220
-		client    = registryEndpoint.HTTPClient()
221 220
 	)
222 221
 
223 222
 	for _, challenge := range registryEndpoint.AuthChallenges {
... ...
@@ -225,9 +221,9 @@ func loginV2(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (stri
225 225
 
226 226
 		switch strings.ToLower(challenge.Scheme) {
227 227
 		case "basic":
228
-			err = tryV2BasicAuthLogin(authConfig, challenge.Parameters, registryEndpoint, client)
228
+			err = tryV2BasicAuthLogin(authConfig, challenge.Parameters, registryEndpoint)
229 229
 		case "bearer":
230
-			err = tryV2TokenAuthLogin(authConfig, challenge.Parameters, registryEndpoint, client)
230
+			err = tryV2TokenAuthLogin(authConfig, challenge.Parameters, registryEndpoint)
231 231
 		default:
232 232
 			// Unsupported challenge types are explicitly skipped.
233 233
 			err = fmt.Errorf("unsupported auth scheme: %q", challenge.Scheme)
... ...
@@ -245,7 +241,7 @@ func loginV2(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (stri
245 245
 	return "", fmt.Errorf("no successful auth challenge for %s - errors: %s", registryEndpoint, allErrors)
246 246
 }
247 247
 
248
-func tryV2BasicAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]string, registryEndpoint *Endpoint, client *http.Client) error {
248
+func tryV2BasicAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]string, registryEndpoint *Endpoint) error {
249 249
 	req, err := http.NewRequest("GET", registryEndpoint.Path(""), nil)
250 250
 	if err != nil {
251 251
 		return err
... ...
@@ -253,7 +249,7 @@ func tryV2BasicAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]str
253 253
 
254 254
 	req.SetBasicAuth(authConfig.Username, authConfig.Password)
255 255
 
256
-	resp, err := client.Do(req)
256
+	resp, err := registryEndpoint.client.Do(req)
257 257
 	if err != nil {
258 258
 		return err
259 259
 	}
... ...
@@ -266,8 +262,8 @@ func tryV2BasicAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]str
266 266
 	return nil
267 267
 }
268 268
 
269
-func tryV2TokenAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]string, registryEndpoint *Endpoint, client *http.Client) error {
270
-	token, err := getToken(authConfig.Username, authConfig.Password, params, registryEndpoint, client)
269
+func tryV2TokenAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]string, registryEndpoint *Endpoint) error {
270
+	token, err := getToken(authConfig.Username, authConfig.Password, params, registryEndpoint)
271 271
 	if err != nil {
272 272
 		return err
273 273
 	}
... ...
@@ -279,7 +275,7 @@ func tryV2TokenAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]str
279 279
 
280 280
 	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
281 281
 
282
-	resp, err := client.Do(req)
282
+	resp, err := registryEndpoint.client.Do(req)
283 283
 	if err != nil {
284 284
 		return err
285 285
 	}
... ...
@@ -11,6 +11,7 @@ import (
11 11
 
12 12
 	"github.com/Sirupsen/logrus"
13 13
 	"github.com/docker/distribution/registry/api/v2"
14
+	"github.com/docker/docker/pkg/transport"
14 15
 )
15 16
 
16 17
 // for mocking in unit tests
... ...
@@ -41,9 +42,9 @@ func scanForAPIVersion(address string) (string, APIVersion) {
41 41
 }
42 42
 
43 43
 // NewEndpoint parses the given address to return a registry endpoint.
44
-func NewEndpoint(index *IndexInfo) (*Endpoint, error) {
44
+func NewEndpoint(index *IndexInfo, metaHeaders http.Header) (*Endpoint, error) {
45 45
 	// *TODO: Allow per-registry configuration of endpoints.
46
-	endpoint, err := newEndpoint(index.GetAuthConfigKey(), index.Secure)
46
+	endpoint, err := newEndpoint(index.GetAuthConfigKey(), index.Secure, metaHeaders)
47 47
 	if err != nil {
48 48
 		return nil, err
49 49
 	}
... ...
@@ -81,7 +82,7 @@ func validateEndpoint(endpoint *Endpoint) error {
81 81
 	return nil
82 82
 }
83 83
 
84
-func newEndpoint(address string, secure bool) (*Endpoint, error) {
84
+func newEndpoint(address string, secure bool, metaHeaders http.Header) (*Endpoint, error) {
85 85
 	var (
86 86
 		endpoint       = new(Endpoint)
87 87
 		trimmedAddress string
... ...
@@ -98,11 +99,13 @@ func newEndpoint(address string, secure bool) (*Endpoint, error) {
98 98
 		return nil, err
99 99
 	}
100 100
 	endpoint.IsSecure = secure
101
+	tr := NewTransport(ConnectTimeout, endpoint.IsSecure)
102
+	endpoint.client = HTTPClient(transport.NewTransport(tr, DockerHeaders(metaHeaders)...))
101 103
 	return endpoint, nil
102 104
 }
103 105
 
104
-func (repoInfo *RepositoryInfo) GetEndpoint() (*Endpoint, error) {
105
-	return NewEndpoint(repoInfo.Index)
106
+func (repoInfo *RepositoryInfo) GetEndpoint(metaHeaders http.Header) (*Endpoint, error) {
107
+	return NewEndpoint(repoInfo.Index, metaHeaders)
106 108
 }
107 109
 
108 110
 // Endpoint stores basic information about a registry endpoint.
... ...
@@ -174,7 +177,7 @@ func (e *Endpoint) pingV1() (RegistryInfo, error) {
174 174
 		return RegistryInfo{Standalone: false}, err
175 175
 	}
176 176
 
177
-	resp, err := e.HTTPClient().Do(req)
177
+	resp, err := e.client.Do(req)
178 178
 	if err != nil {
179 179
 		return RegistryInfo{Standalone: false}, err
180 180
 	}
... ...
@@ -222,7 +225,7 @@ func (e *Endpoint) pingV2() (RegistryInfo, error) {
222 222
 		return RegistryInfo{}, err
223 223
 	}
224 224
 
225
-	resp, err := e.HTTPClient().Do(req)
225
+	resp, err := e.client.Do(req)
226 226
 	if err != nil {
227 227
 		return RegistryInfo{}, err
228 228
 	}
... ...
@@ -261,11 +264,3 @@ HeaderLoop:
261 261
 
262 262
 	return RegistryInfo{}, fmt.Errorf("v2 registry endpoint returned status %d: %q", resp.StatusCode, http.StatusText(resp.StatusCode))
263 263
 }
264
-
265
-func (e *Endpoint) HTTPClient() *http.Client {
266
-	if e.client == nil {
267
-		tr := NewTransport(ConnectTimeout, e.IsSecure)
268
-		e.client = HTTPClient(tr)
269
-	}
270
-	return e.client
271
-}
... ...
@@ -19,7 +19,7 @@ func TestEndpointParse(t *testing.T) {
19 19
 		{"0.0.0.0:5000", "https://0.0.0.0:5000/v0/"},
20 20
 	}
21 21
 	for _, td := range testData {
22
-		e, err := newEndpoint(td.str, false)
22
+		e, err := newEndpoint(td.str, false, nil)
23 23
 		if err != nil {
24 24
 			t.Errorf("%q: %s", td.str, err)
25 25
 		}
... ...
@@ -60,6 +60,7 @@ func TestValidateEndpointAmbiguousAPIVersion(t *testing.T) {
60 60
 	testEndpoint := Endpoint{
61 61
 		URL:     testServerURL,
62 62
 		Version: APIVersionUnknown,
63
+		client:  HTTPClient(NewTransport(ConnectTimeout, false)),
63 64
 	}
64 65
 
65 66
 	if err = validateEndpoint(&testEndpoint); err != nil {
... ...
@@ -19,6 +19,7 @@ import (
19 19
 	"github.com/docker/docker/autogen/dockerversion"
20 20
 	"github.com/docker/docker/pkg/parsers/kernel"
21 21
 	"github.com/docker/docker/pkg/timeoutconn"
22
+	"github.com/docker/docker/pkg/transport"
22 23
 	"github.com/docker/docker/pkg/useragent"
23 24
 )
24 25
 
... ...
@@ -36,17 +37,32 @@ const (
36 36
 	ConnectTimeout
37 37
 )
38 38
 
39
-type httpsTransport struct {
40
-	*http.Transport
39
+// dockerUserAgent is the User-Agent the Docker client uses to identify itself.
40
+// It is populated on init(), comprising version information of different components.
41
+var dockerUserAgent string
42
+
43
+func init() {
44
+	httpVersion := make([]useragent.VersionInfo, 0, 6)
45
+	httpVersion = append(httpVersion, useragent.VersionInfo{"docker", dockerversion.VERSION})
46
+	httpVersion = append(httpVersion, useragent.VersionInfo{"go", runtime.Version()})
47
+	httpVersion = append(httpVersion, useragent.VersionInfo{"git-commit", dockerversion.GITCOMMIT})
48
+	if kernelVersion, err := kernel.GetKernelVersion(); err == nil {
49
+		httpVersion = append(httpVersion, useragent.VersionInfo{"kernel", kernelVersion.String()})
50
+	}
51
+	httpVersion = append(httpVersion, useragent.VersionInfo{"os", runtime.GOOS})
52
+	httpVersion = append(httpVersion, useragent.VersionInfo{"arch", runtime.GOARCH})
53
+
54
+	dockerUserAgent = useragent.AppendVersions("", httpVersion...)
41 55
 }
42 56
 
57
+type httpsRequestModifier struct{ tlsConfig *tls.Config }
58
+
43 59
 // DRAGONS(tiborvass): If someone wonders why do we set tlsconfig in a roundtrip,
44 60
 // it's because it's so as to match the current behavior in master: we generate the
45 61
 // certpool on every-goddam-request. It's not great, but it allows people to just put
46 62
 // the certs in /etc/docker/certs.d/.../ and let docker "pick it up" immediately. Would
47 63
 // prefer an fsnotify implementation, but that was out of scope of my refactoring.
48
-// TODO: improve things
49
-func (tr *httpsTransport) RoundTrip(req *http.Request) (*http.Response, error) {
64
+func (m *httpsRequestModifier) ModifyRequest(req *http.Request) error {
50 65
 	var (
51 66
 		roots *x509.CertPool
52 67
 		certs []tls.Certificate
... ...
@@ -66,7 +82,7 @@ func (tr *httpsTransport) RoundTrip(req *http.Request) (*http.Response, error) {
66 66
 		logrus.Debugf("hostDir: %s", hostDir)
67 67
 		fs, err := ioutil.ReadDir(hostDir)
68 68
 		if err != nil && !os.IsNotExist(err) {
69
-			return nil, err
69
+			return nil
70 70
 		}
71 71
 
72 72
 		for _, f := range fs {
... ...
@@ -77,7 +93,7 @@ func (tr *httpsTransport) RoundTrip(req *http.Request) (*http.Response, error) {
77 77
 				logrus.Debugf("crt: %s", hostDir+"/"+f.Name())
78 78
 				data, err := ioutil.ReadFile(path.Join(hostDir, f.Name()))
79 79
 				if err != nil {
80
-					return nil, err
80
+					return err
81 81
 				}
82 82
 				roots.AppendCertsFromPEM(data)
83 83
 			}
... ...
@@ -86,11 +102,11 @@ func (tr *httpsTransport) RoundTrip(req *http.Request) (*http.Response, error) {
86 86
 				keyName := certName[:len(certName)-5] + ".key"
87 87
 				logrus.Debugf("cert: %s", hostDir+"/"+f.Name())
88 88
 				if !hasFile(fs, keyName) {
89
-					return nil, fmt.Errorf("Missing key %s for certificate %s", keyName, certName)
89
+					return fmt.Errorf("Missing key %s for certificate %s", keyName, certName)
90 90
 				}
91 91
 				cert, err := tls.LoadX509KeyPair(path.Join(hostDir, certName), path.Join(hostDir, keyName))
92 92
 				if err != nil {
93
-					return nil, err
93
+					return err
94 94
 				}
95 95
 				certs = append(certs, cert)
96 96
 			}
... ...
@@ -99,38 +115,32 @@ func (tr *httpsTransport) RoundTrip(req *http.Request) (*http.Response, error) {
99 99
 				certName := keyName[:len(keyName)-4] + ".cert"
100 100
 				logrus.Debugf("key: %s", hostDir+"/"+f.Name())
101 101
 				if !hasFile(fs, certName) {
102
-					return nil, fmt.Errorf("Missing certificate %s for key %s", certName, keyName)
102
+					return fmt.Errorf("Missing certificate %s for key %s", certName, keyName)
103 103
 				}
104 104
 			}
105 105
 		}
106
-		if tr.Transport.TLSClientConfig == nil {
107
-			tr.Transport.TLSClientConfig = &tls.Config{
108
-				// Avoid fallback to SSL protocols < TLS1.0
109
-				MinVersion: tls.VersionTLS10,
110
-			}
111
-		}
112
-		tr.Transport.TLSClientConfig.RootCAs = roots
113
-		tr.Transport.TLSClientConfig.Certificates = certs
106
+		m.tlsConfig.RootCAs = roots
107
+		m.tlsConfig.Certificates = certs
114 108
 	}
115
-	return tr.Transport.RoundTrip(req)
109
+	return nil
116 110
 }
117 111
 
118 112
 func NewTransport(timeout TimeoutType, secure bool) http.RoundTripper {
119
-	tlsConfig := tls.Config{
113
+	tlsConfig := &tls.Config{
120 114
 		// Avoid fallback to SSL protocols < TLS1.0
121 115
 		MinVersion:         tls.VersionTLS10,
122 116
 		InsecureSkipVerify: !secure,
123 117
 	}
124 118
 
125
-	transport := &http.Transport{
119
+	tr := &http.Transport{
126 120
 		DisableKeepAlives: true,
127 121
 		Proxy:             http.ProxyFromEnvironment,
128
-		TLSClientConfig:   &tlsConfig,
122
+		TLSClientConfig:   tlsConfig,
129 123
 	}
130 124
 
131 125
 	switch timeout {
132 126
 	case ConnectTimeout:
133
-		transport.Dial = func(proto string, addr string) (net.Conn, error) {
127
+		tr.Dial = func(proto string, addr string) (net.Conn, error) {
134 128
 			// Set the connect timeout to 30 seconds to allow for slower connection
135 129
 			// times...
136 130
 			d := net.Dialer{Timeout: 30 * time.Second, DualStack: true}
... ...
@@ -144,7 +154,7 @@ func NewTransport(timeout TimeoutType, secure bool) http.RoundTripper {
144 144
 			return conn, nil
145 145
 		}
146 146
 	case ReceiveTimeout:
147
-		transport.Dial = func(proto string, addr string) (net.Conn, error) {
147
+		tr.Dial = func(proto string, addr string) (net.Conn, error) {
148 148
 			d := net.Dialer{DualStack: true}
149 149
 
150 150
 			conn, err := d.Dial(proto, addr)
... ...
@@ -159,51 +169,23 @@ func NewTransport(timeout TimeoutType, secure bool) http.RoundTripper {
159 159
 	if secure {
160 160
 		// note: httpsTransport also handles http transport
161 161
 		// but for HTTPS, it sets up the certs
162
-		return &httpsTransport{transport}
162
+		return transport.NewTransport(tr, &httpsRequestModifier{tlsConfig})
163 163
 	}
164 164
 
165
-	return transport
165
+	return tr
166 166
 }
167 167
 
168
-type DockerHeaders struct {
169
-	http.RoundTripper
170
-	Headers http.Header
171
-}
172
-
173
-// cloneRequest returns a clone of the provided *http.Request.
174
-// The clone is a shallow copy of the struct and its Header map
175
-func cloneRequest(r *http.Request) *http.Request {
176
-	// shallow copy of the struct
177
-	r2 := new(http.Request)
178
-	*r2 = *r
179
-	// deep copy of the Header
180
-	r2.Header = make(http.Header, len(r.Header))
181
-	for k, s := range r.Header {
182
-		r2.Header[k] = append([]string(nil), s...)
168
+// DockerHeaders returns request modifiers that ensure requests have
169
+// the User-Agent header set to dockerUserAgent and that metaHeaders
170
+// are added.
171
+func DockerHeaders(metaHeaders http.Header) []transport.RequestModifier {
172
+	modifiers := []transport.RequestModifier{
173
+		transport.NewHeaderRequestModifier(http.Header{"User-Agent": []string{dockerUserAgent}}),
183 174
 	}
184
-	return r2
185
-}
186
-
187
-func (tr *DockerHeaders) RoundTrip(req *http.Request) (*http.Response, error) {
188
-	req = cloneRequest(req)
189
-	httpVersion := make([]useragent.VersionInfo, 0, 4)
190
-	httpVersion = append(httpVersion, useragent.VersionInfo{"docker", dockerversion.VERSION})
191
-	httpVersion = append(httpVersion, useragent.VersionInfo{"go", runtime.Version()})
192
-	httpVersion = append(httpVersion, useragent.VersionInfo{"git-commit", dockerversion.GITCOMMIT})
193
-	if kernelVersion, err := kernel.GetKernelVersion(); err == nil {
194
-		httpVersion = append(httpVersion, useragent.VersionInfo{"kernel", kernelVersion.String()})
195
-	}
196
-	httpVersion = append(httpVersion, useragent.VersionInfo{"os", runtime.GOOS})
197
-	httpVersion = append(httpVersion, useragent.VersionInfo{"arch", runtime.GOARCH})
198
-
199
-	userAgent := useragent.AppendVersions(req.UserAgent(), httpVersion...)
200
-
201
-	req.Header.Set("User-Agent", userAgent)
202
-
203
-	for k, v := range tr.Headers {
204
-		req.Header[k] = v
175
+	if metaHeaders != nil {
176
+		modifiers = append(modifiers, transport.NewHeaderRequestModifier(metaHeaders))
205 177
 	}
206
-	return tr.RoundTripper.RoundTrip(req)
178
+	return modifiers
207 179
 }
208 180
 
209 181
 type debugTransport struct{ http.RoundTripper }
... ...
@@ -8,6 +8,7 @@ import (
8 8
 	"testing"
9 9
 
10 10
 	"github.com/docker/docker/cliconfig"
11
+	"github.com/docker/docker/pkg/transport"
11 12
 )
12 13
 
13 14
 var (
... ...
@@ -21,12 +22,12 @@ const (
21 21
 
22 22
 func spawnTestRegistrySession(t *testing.T) *Session {
23 23
 	authConfig := &cliconfig.AuthConfig{}
24
-	endpoint, err := NewEndpoint(makeIndex("/v1/"))
24
+	endpoint, err := NewEndpoint(makeIndex("/v1/"), nil)
25 25
 	if err != nil {
26 26
 		t.Fatal(err)
27 27
 	}
28 28
 	var tr http.RoundTripper = debugTransport{NewTransport(ReceiveTimeout, endpoint.IsSecure)}
29
-	tr = &DockerHeaders{&authTransport{RoundTripper: tr, AuthConfig: authConfig}, nil}
29
+	tr = transport.NewTransport(AuthTransport(tr, authConfig, false), DockerHeaders(nil)...)
30 30
 	client := HTTPClient(tr)
31 31
 	r, err := NewSession(client, authConfig, endpoint)
32 32
 	if err != nil {
... ...
@@ -48,7 +49,7 @@ func spawnTestRegistrySession(t *testing.T) *Session {
48 48
 
49 49
 func TestPingRegistryEndpoint(t *testing.T) {
50 50
 	testPing := func(index *IndexInfo, expectedStandalone bool, assertMessage string) {
51
-		ep, err := NewEndpoint(index)
51
+		ep, err := NewEndpoint(index, nil)
52 52
 		if err != nil {
53 53
 			t.Fatal(err)
54 54
 		}
... ...
@@ -68,7 +69,7 @@ func TestPingRegistryEndpoint(t *testing.T) {
68 68
 func TestEndpoint(t *testing.T) {
69 69
 	// Simple wrapper to fail test if err != nil
70 70
 	expandEndpoint := func(index *IndexInfo) *Endpoint {
71
-		endpoint, err := NewEndpoint(index)
71
+		endpoint, err := NewEndpoint(index, nil)
72 72
 		if err != nil {
73 73
 			t.Fatal(err)
74 74
 		}
... ...
@@ -77,7 +78,7 @@ func TestEndpoint(t *testing.T) {
77 77
 
78 78
 	assertInsecureIndex := func(index *IndexInfo) {
79 79
 		index.Secure = true
80
-		_, err := NewEndpoint(index)
80
+		_, err := NewEndpoint(index, nil)
81 81
 		assertNotEqual(t, err, nil, index.Name+": Expected error for insecure index")
82 82
 		assertEqual(t, strings.Contains(err.Error(), "insecure-registry"), true, index.Name+": Expected insecure-registry  error for insecure index")
83 83
 		index.Secure = false
... ...
@@ -85,7 +86,7 @@ func TestEndpoint(t *testing.T) {
85 85
 
86 86
 	assertSecureIndex := func(index *IndexInfo) {
87 87
 		index.Secure = true
88
-		_, err := NewEndpoint(index)
88
+		_, err := NewEndpoint(index, nil)
89 89
 		assertNotEqual(t, err, nil, index.Name+": Expected cert error for secure index")
90 90
 		assertEqual(t, strings.Contains(err.Error(), "certificate signed by unknown authority"), true, index.Name+": Expected cert error for secure index")
91 91
 		index.Secure = false
... ...
@@ -151,7 +152,7 @@ func TestEndpoint(t *testing.T) {
151 151
 	}
152 152
 	for _, address := range badEndpoints {
153 153
 		index.Name = address
154
-		_, err := NewEndpoint(index)
154
+		_, err := NewEndpoint(index, nil)
155 155
 		checkNotEqual(t, err, nil, "Expected error while expanding bad endpoint")
156 156
 	}
157 157
 }
... ...
@@ -1,6 +1,10 @@
1 1
 package registry
2 2
 
3
-import "github.com/docker/docker/cliconfig"
3
+import (
4
+	"net/http"
5
+
6
+	"github.com/docker/docker/cliconfig"
7
+)
4 8
 
5 9
 type Service struct {
6 10
 	Config *ServiceConfig
... ...
@@ -27,7 +31,7 @@ func (s *Service) Auth(authConfig *cliconfig.AuthConfig) (string, error) {
27 27
 	if err != nil {
28 28
 		return "", err
29 29
 	}
30
-	endpoint, err := NewEndpoint(index)
30
+	endpoint, err := NewEndpoint(index, nil)
31 31
 	if err != nil {
32 32
 		return "", err
33 33
 	}
... ...
@@ -44,11 +48,11 @@ func (s *Service) Search(term string, authConfig *cliconfig.AuthConfig, headers
44 44
 	}
45 45
 
46 46
 	// *TODO: Search multiple indexes.
47
-	endpoint, err := repoInfo.GetEndpoint()
47
+	endpoint, err := repoInfo.GetEndpoint(http.Header(headers))
48 48
 	if err != nil {
49 49
 		return nil, err
50 50
 	}
51
-	r, err := NewSession(endpoint.HTTPClient(), authConfig, endpoint)
51
+	r, err := NewSession(endpoint.client, authConfig, endpoint)
52 52
 	if err != nil {
53 53
 		return nil, err
54 54
 	}
... ...
@@ -4,6 +4,7 @@ import (
4 4
 	"bytes"
5 5
 	"crypto/sha256"
6 6
 	"errors"
7
+	"sync"
7 8
 	// this is required for some certificates
8 9
 	_ "crypto/sha512"
9 10
 	"encoding/hex"
... ...
@@ -22,6 +23,7 @@ import (
22 22
 	"github.com/docker/docker/cliconfig"
23 23
 	"github.com/docker/docker/pkg/httputils"
24 24
 	"github.com/docker/docker/pkg/tarsum"
25
+	"github.com/docker/docker/pkg/transport"
25 26
 )
26 27
 
27 28
 type Session struct {
... ...
@@ -31,7 +33,18 @@ type Session struct {
31 31
 	authConfig *cliconfig.AuthConfig
32 32
 }
33 33
 
34
-// authTransport handles the auth layer when communicating with a v1 registry (private or official)
34
+type authTransport struct {
35
+	http.RoundTripper
36
+	*cliconfig.AuthConfig
37
+
38
+	alwaysSetBasicAuth bool
39
+	token              []string
40
+
41
+	mu     sync.Mutex                      // guards modReq
42
+	modReq map[*http.Request]*http.Request // original -> modified
43
+}
44
+
45
+// AuthTransport handles the auth layer when communicating with a v1 registry (private or official)
35 46
 //
36 47
 // For private v1 registries, set alwaysSetBasicAuth to true.
37 48
 //
... ...
@@ -44,16 +57,23 @@ type Session struct {
44 44
 // If the server sends a token without the client having requested it, it is ignored.
45 45
 //
46 46
 // This RoundTripper also has a CancelRequest method important for correct timeout handling.
47
-type authTransport struct {
48
-	http.RoundTripper
49
-	*cliconfig.AuthConfig
50
-
51
-	alwaysSetBasicAuth bool
52
-	token              []string
47
+func AuthTransport(base http.RoundTripper, authConfig *cliconfig.AuthConfig, alwaysSetBasicAuth bool) http.RoundTripper {
48
+	if base == nil {
49
+		base = http.DefaultTransport
50
+	}
51
+	return &authTransport{
52
+		RoundTripper:       base,
53
+		AuthConfig:         authConfig,
54
+		alwaysSetBasicAuth: alwaysSetBasicAuth,
55
+		modReq:             make(map[*http.Request]*http.Request),
56
+	}
53 57
 }
54 58
 
55
-func (tr *authTransport) RoundTrip(req *http.Request) (*http.Response, error) {
56
-	req = cloneRequest(req)
59
+func (tr *authTransport) RoundTrip(orig *http.Request) (*http.Response, error) {
60
+	req := transport.CloneRequest(orig)
61
+	tr.mu.Lock()
62
+	tr.modReq[orig] = req
63
+	tr.mu.Unlock()
57 64
 
58 65
 	if tr.alwaysSetBasicAuth {
59 66
 		req.SetBasicAuth(tr.Username, tr.Password)
... ...
@@ -73,14 +93,33 @@ func (tr *authTransport) RoundTrip(req *http.Request) (*http.Response, error) {
73 73
 	}
74 74
 	resp, err := tr.RoundTripper.RoundTrip(req)
75 75
 	if err != nil {
76
+		delete(tr.modReq, orig)
76 77
 		return nil, err
77 78
 	}
78 79
 	if askedForToken && len(resp.Header["X-Docker-Token"]) > 0 {
79 80
 		tr.token = resp.Header["X-Docker-Token"]
80 81
 	}
82
+	resp.Body = &transport.OnEOFReader{
83
+		Rc: resp.Body,
84
+		Fn: func() { delete(tr.modReq, orig) },
85
+	}
81 86
 	return resp, nil
82 87
 }
83 88
 
89
+// CancelRequest cancels an in-flight request by closing its connection.
90
+func (tr *authTransport) CancelRequest(req *http.Request) {
91
+	type canceler interface {
92
+		CancelRequest(*http.Request)
93
+	}
94
+	if cr, ok := tr.RoundTripper.(canceler); ok {
95
+		tr.mu.Lock()
96
+		modReq := tr.modReq[req]
97
+		delete(tr.modReq, req)
98
+		tr.mu.Unlock()
99
+		cr.CancelRequest(modReq)
100
+	}
101
+}
102
+
84 103
 // TODO(tiborvass): remove authConfig param once registry client v2 is vendored
85 104
 func NewSession(client *http.Client, authConfig *cliconfig.AuthConfig, endpoint *Endpoint) (r *Session, err error) {
86 105
 	r = &Session{
... ...
@@ -105,7 +144,7 @@ func NewSession(client *http.Client, authConfig *cliconfig.AuthConfig, endpoint
105 105
 		}
106 106
 	}
107 107
 
108
-	client.Transport = &authTransport{RoundTripper: client.Transport, AuthConfig: authConfig, alwaysSetBasicAuth: alwaysSetBasicAuth}
108
+	client.Transport = AuthTransport(client.Transport, authConfig, alwaysSetBasicAuth)
109 109
 
110 110
 	jar, err := cookiejar.New(nil)
111 111
 	if err != nil {
... ...
@@ -27,7 +27,7 @@ func getV2Builder(e *Endpoint) *v2.URLBuilder {
27 27
 func (r *Session) V2RegistryEndpoint(index *IndexInfo) (ep *Endpoint, err error) {
28 28
 	// TODO check if should use Mirror
29 29
 	if index.Official {
30
-		ep, err = newEndpoint(REGISTRYSERVER, true)
30
+		ep, err = newEndpoint(REGISTRYSERVER, true, nil)
31 31
 		if err != nil {
32 32
 			return
33 33
 		}
... ...
@@ -38,7 +38,7 @@ func (r *Session) V2RegistryEndpoint(index *IndexInfo) (ep *Endpoint, err error)
38 38
 	} else if r.indexEndpoint.String() == index.GetAuthConfigKey() {
39 39
 		ep = r.indexEndpoint
40 40
 	} else {
41
-		ep, err = NewEndpoint(index)
41
+		ep, err = NewEndpoint(index, nil)
42 42
 		if err != nil {
43 43
 			return
44 44
 		}
... ...
@@ -13,7 +13,7 @@ type tokenResponse struct {
13 13
 	Token string `json:"token"`
14 14
 }
15 15
 
16
-func getToken(username, password string, params map[string]string, registryEndpoint *Endpoint, client *http.Client) (token string, err error) {
16
+func getToken(username, password string, params map[string]string, registryEndpoint *Endpoint) (token string, err error) {
17 17
 	realm, ok := params["realm"]
18 18
 	if !ok {
19 19
 		return "", errors.New("no realm specified for token auth challenge")
... ...
@@ -56,7 +56,7 @@ func getToken(username, password string, params map[string]string, registryEndpo
56 56
 
57 57
 	req.URL.RawQuery = reqParams.Encode()
58 58
 
59
-	resp, err := client.Do(req)
59
+	resp, err := registryEndpoint.client.Do(req)
60 60
 	if err != nil {
61 61
 		return "", err
62 62
 	}