Signed-off-by: Tibor Vass <tibor@docker.com>
| ... | ... |
@@ -17,6 +17,7 @@ import ( |
| 17 | 17 |
"github.com/docker/docker/pkg/progressreader" |
| 18 | 18 |
"github.com/docker/docker/pkg/streamformatter" |
| 19 | 19 |
"github.com/docker/docker/pkg/stringid" |
| 20 |
+ "github.com/docker/docker/pkg/transport" |
|
| 20 | 21 |
"github.com/docker/docker/registry" |
| 21 | 22 |
"github.com/docker/docker/utils" |
| 22 | 23 |
) |
| ... | ... |
@@ -55,16 +56,17 @@ func (s *TagStore) Pull(image string, tag string, imagePullConfig *ImagePullConf |
| 55 | 55 |
defer s.poolRemove("pull", utils.ImageReference(repoInfo.LocalName, tag))
|
| 56 | 56 |
|
| 57 | 57 |
logrus.Debugf("pulling image from host %q with remote name %q", repoInfo.Index.Name, repoInfo.RemoteName)
|
| 58 |
- endpoint, err := repoInfo.GetEndpoint() |
|
| 58 |
+ |
|
| 59 |
+ endpoint, err := repoInfo.GetEndpoint(imagePullConfig.MetaHeaders) |
|
| 59 | 60 |
if err != nil {
|
| 60 | 61 |
return err |
| 61 | 62 |
} |
| 62 |
- |
|
| 63 |
+ // TODO(tiborvass): reuse client from endpoint? |
|
| 63 | 64 |
// Adds Docker-specific headers as well as user-specified headers (metaHeaders) |
| 64 |
- tr := ®istry.DockerHeaders{
|
|
| 65 |
+ tr := transport.NewTransport( |
|
| 65 | 66 |
registry.NewTransport(registry.ReceiveTimeout, endpoint.IsSecure), |
| 66 |
- imagePullConfig.MetaHeaders, |
|
| 67 |
- } |
|
| 67 |
+ registry.DockerHeaders(imagePullConfig.MetaHeaders)..., |
|
| 68 |
+ ) |
|
| 68 | 69 |
client := registry.HTTPClient(tr) |
| 69 | 70 |
r, err := registry.NewSession(client, imagePullConfig.AuthConfig, endpoint) |
| 70 | 71 |
if err != nil {
|
| ... | ... |
@@ -18,6 +18,7 @@ import ( |
| 18 | 18 |
"github.com/docker/docker/pkg/progressreader" |
| 19 | 19 |
"github.com/docker/docker/pkg/streamformatter" |
| 20 | 20 |
"github.com/docker/docker/pkg/stringid" |
| 21 |
+ "github.com/docker/docker/pkg/transport" |
|
| 21 | 22 |
"github.com/docker/docker/registry" |
| 22 | 23 |
"github.com/docker/docker/runconfig" |
| 23 | 24 |
"github.com/docker/docker/utils" |
| ... | ... |
@@ -509,16 +510,17 @@ func (s *TagStore) Push(localName string, imagePushConfig *ImagePushConfig) erro |
| 509 | 509 |
} |
| 510 | 510 |
defer s.poolRemove("push", repoInfo.LocalName)
|
| 511 | 511 |
|
| 512 |
- endpoint, err := repoInfo.GetEndpoint() |
|
| 512 |
+ endpoint, err := repoInfo.GetEndpoint(imagePushConfig.MetaHeaders) |
|
| 513 | 513 |
if err != nil {
|
| 514 | 514 |
return err |
| 515 | 515 |
} |
| 516 |
- |
|
| 516 |
+ // TODO(tiborvass): reuse client from endpoint? |
|
| 517 | 517 |
// Adds Docker-specific headers as well as user-specified headers (metaHeaders) |
| 518 |
- tr := ®istry.DockerHeaders{
|
|
| 518 |
+ tr := transport.NewTransport( |
|
| 519 | 519 |
registry.NewTransport(registry.NoTimeout, endpoint.IsSecure), |
| 520 |
- imagePushConfig.MetaHeaders, |
|
| 521 |
- } |
|
| 520 |
+ registry.DockerHeaders(imagePushConfig.MetaHeaders)..., |
|
| 521 |
+ ) |
|
| 522 |
+ client := registry.HTTPClient(tr) |
|
| 522 | 523 |
r, err := registry.NewSession(client, imagePushConfig.AuthConfig, endpoint) |
| 523 | 524 |
if err != nil {
|
| 524 | 525 |
return err |
| 525 | 526 |
new file mode 100644 |
| ... | ... |
@@ -0,0 +1,27 @@ |
| 0 |
+Copyright (c) 2009 The oauth2 Authors. All rights reserved. |
|
| 1 |
+ |
|
| 2 |
+Redistribution and use in source and binary forms, with or without |
|
| 3 |
+modification, are permitted provided that the following conditions are |
|
| 4 |
+met: |
|
| 5 |
+ |
|
| 6 |
+ * Redistributions of source code must retain the above copyright |
|
| 7 |
+notice, this list of conditions and the following disclaimer. |
|
| 8 |
+ * Redistributions in binary form must reproduce the above |
|
| 9 |
+copyright notice, this list of conditions and the following disclaimer |
|
| 10 |
+in the documentation and/or other materials provided with the |
|
| 11 |
+distribution. |
|
| 12 |
+ * Neither the name of Google Inc. nor the names of its |
|
| 13 |
+contributors may be used to endorse or promote products derived from |
|
| 14 |
+this software without specific prior written permission. |
|
| 15 |
+ |
|
| 16 |
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS |
|
| 17 |
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT |
|
| 18 |
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR |
|
| 19 |
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT |
|
| 20 |
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, |
|
| 21 |
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT |
|
| 22 |
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, |
|
| 23 |
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY |
|
| 24 |
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT |
|
| 25 |
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |
|
| 26 |
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
| 0 | 27 |
new file mode 100644 |
| ... | ... |
@@ -0,0 +1,148 @@ |
| 0 |
+package transport |
|
| 1 |
+ |
|
| 2 |
+import ( |
|
| 3 |
+ "io" |
|
| 4 |
+ "net/http" |
|
| 5 |
+ "sync" |
|
| 6 |
+) |
|
| 7 |
+ |
|
| 8 |
+type RequestModifier interface {
|
|
| 9 |
+ ModifyRequest(*http.Request) error |
|
| 10 |
+} |
|
| 11 |
+ |
|
| 12 |
+type headerModifier http.Header |
|
| 13 |
+ |
|
| 14 |
+// NewHeaderRequestModifier returns a RequestModifier that merges the HTTP headers |
|
| 15 |
+// passed as an argument, with the HTTP headers of a request. |
|
| 16 |
+// |
|
| 17 |
+// If the same key is present in both, the modifying header values for that key, |
|
| 18 |
+// are appended to the values for that same key in the request header. |
|
| 19 |
+func NewHeaderRequestModifier(header http.Header) RequestModifier {
|
|
| 20 |
+ return headerModifier(header) |
|
| 21 |
+} |
|
| 22 |
+ |
|
| 23 |
+func (h headerModifier) ModifyRequest(req *http.Request) error {
|
|
| 24 |
+ for k, s := range http.Header(h) {
|
|
| 25 |
+ req.Header[k] = append(req.Header[k], s...) |
|
| 26 |
+ } |
|
| 27 |
+ |
|
| 28 |
+ return nil |
|
| 29 |
+} |
|
| 30 |
+ |
|
| 31 |
+// NewTransport returns an http.RoundTripper that modifies requests according to |
|
| 32 |
+// the RequestModifiers passed in the arguments, before sending the requests to |
|
| 33 |
+// the base http.RoundTripper (which, if nil, defaults to http.DefaultTransport). |
|
| 34 |
+func NewTransport(base http.RoundTripper, modifiers ...RequestModifier) http.RoundTripper {
|
|
| 35 |
+ return &transport{
|
|
| 36 |
+ Modifiers: modifiers, |
|
| 37 |
+ Base: base, |
|
| 38 |
+ } |
|
| 39 |
+} |
|
| 40 |
+ |
|
| 41 |
+// transport is an http.RoundTripper that makes HTTP requests after |
|
| 42 |
+// copying and modifying the request |
|
| 43 |
+type transport struct {
|
|
| 44 |
+ Modifiers []RequestModifier |
|
| 45 |
+ Base http.RoundTripper |
|
| 46 |
+ |
|
| 47 |
+ mu sync.Mutex // guards modReq |
|
| 48 |
+ modReq map[*http.Request]*http.Request // original -> modified |
|
| 49 |
+} |
|
| 50 |
+ |
|
| 51 |
+func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
| 52 |
+ req2 := CloneRequest(req) |
|
| 53 |
+ for _, modifier := range t.Modifiers {
|
|
| 54 |
+ if err := modifier.ModifyRequest(req2); err != nil {
|
|
| 55 |
+ return nil, err |
|
| 56 |
+ } |
|
| 57 |
+ } |
|
| 58 |
+ |
|
| 59 |
+ t.setModReq(req, req2) |
|
| 60 |
+ res, err := t.base().RoundTrip(req2) |
|
| 61 |
+ if err != nil {
|
|
| 62 |
+ t.setModReq(req, nil) |
|
| 63 |
+ return nil, err |
|
| 64 |
+ } |
|
| 65 |
+ res.Body = &OnEOFReader{
|
|
| 66 |
+ Rc: res.Body, |
|
| 67 |
+ Fn: func() { t.setModReq(req, nil) },
|
|
| 68 |
+ } |
|
| 69 |
+ return res, nil |
|
| 70 |
+} |
|
| 71 |
+ |
|
| 72 |
+// CancelRequest cancels an in-flight request by closing its connection. |
|
| 73 |
+func (t *transport) CancelRequest(req *http.Request) {
|
|
| 74 |
+ type canceler interface {
|
|
| 75 |
+ CancelRequest(*http.Request) |
|
| 76 |
+ } |
|
| 77 |
+ if cr, ok := t.base().(canceler); ok {
|
|
| 78 |
+ t.mu.Lock() |
|
| 79 |
+ modReq := t.modReq[req] |
|
| 80 |
+ delete(t.modReq, req) |
|
| 81 |
+ t.mu.Unlock() |
|
| 82 |
+ cr.CancelRequest(modReq) |
|
| 83 |
+ } |
|
| 84 |
+} |
|
| 85 |
+ |
|
| 86 |
+func (t *transport) base() http.RoundTripper {
|
|
| 87 |
+ if t.Base != nil {
|
|
| 88 |
+ return t.Base |
|
| 89 |
+ } |
|
| 90 |
+ return http.DefaultTransport |
|
| 91 |
+} |
|
| 92 |
+ |
|
| 93 |
+func (t *transport) setModReq(orig, mod *http.Request) {
|
|
| 94 |
+ t.mu.Lock() |
|
| 95 |
+ defer t.mu.Unlock() |
|
| 96 |
+ if t.modReq == nil {
|
|
| 97 |
+ t.modReq = make(map[*http.Request]*http.Request) |
|
| 98 |
+ } |
|
| 99 |
+ if mod == nil {
|
|
| 100 |
+ delete(t.modReq, orig) |
|
| 101 |
+ } else {
|
|
| 102 |
+ t.modReq[orig] = mod |
|
| 103 |
+ } |
|
| 104 |
+} |
|
| 105 |
+ |
|
| 106 |
+// CloneRequest returns a clone of the provided *http.Request. |
|
| 107 |
+// The clone is a shallow copy of the struct and its Header map. |
|
| 108 |
+func CloneRequest(r *http.Request) *http.Request {
|
|
| 109 |
+ // shallow copy of the struct |
|
| 110 |
+ r2 := new(http.Request) |
|
| 111 |
+ *r2 = *r |
|
| 112 |
+ // deep copy of the Header |
|
| 113 |
+ r2.Header = make(http.Header, len(r.Header)) |
|
| 114 |
+ for k, s := range r.Header {
|
|
| 115 |
+ r2.Header[k] = append([]string(nil), s...) |
|
| 116 |
+ } |
|
| 117 |
+ |
|
| 118 |
+ return r2 |
|
| 119 |
+} |
|
| 120 |
+ |
|
| 121 |
+// OnEOFReader ensures a callback function is called |
|
| 122 |
+// on Close() and when the underlying Reader returns an io.EOF error |
|
| 123 |
+type OnEOFReader struct {
|
|
| 124 |
+ Rc io.ReadCloser |
|
| 125 |
+ Fn func() |
|
| 126 |
+} |
|
| 127 |
+ |
|
| 128 |
+func (r *OnEOFReader) Read(p []byte) (n int, err error) {
|
|
| 129 |
+ n, err = r.Rc.Read(p) |
|
| 130 |
+ if err == io.EOF {
|
|
| 131 |
+ r.runFunc() |
|
| 132 |
+ } |
|
| 133 |
+ return |
|
| 134 |
+} |
|
| 135 |
+ |
|
| 136 |
+func (r *OnEOFReader) Close() error {
|
|
| 137 |
+ err := r.Rc.Close() |
|
| 138 |
+ r.runFunc() |
|
| 139 |
+ return err |
|
| 140 |
+} |
|
| 141 |
+ |
|
| 142 |
+func (r *OnEOFReader) runFunc() {
|
|
| 143 |
+ if fn := r.Fn; fn != nil {
|
|
| 144 |
+ fn() |
|
| 145 |
+ r.Fn = nil |
|
| 146 |
+ } |
|
| 147 |
+} |
| ... | ... |
@@ -44,8 +44,6 @@ func (auth *RequestAuthorization) getToken() (string, error) {
|
| 44 | 44 |
return auth.tokenCache, nil |
| 45 | 45 |
} |
| 46 | 46 |
|
| 47 |
- client := auth.registryEndpoint.HTTPClient() |
|
| 48 |
- |
|
| 49 | 47 |
for _, challenge := range auth.registryEndpoint.AuthChallenges {
|
| 50 | 48 |
switch strings.ToLower(challenge.Scheme) {
|
| 51 | 49 |
case "basic": |
| ... | ... |
@@ -57,7 +55,7 @@ func (auth *RequestAuthorization) getToken() (string, error) {
|
| 57 | 57 |
params[k] = v |
| 58 | 58 |
} |
| 59 | 59 |
params["scope"] = fmt.Sprintf("%s:%s:%s", auth.resource, auth.scope, strings.Join(auth.actions, ","))
|
| 60 |
- token, err := getToken(auth.authConfig.Username, auth.authConfig.Password, params, auth.registryEndpoint, client) |
|
| 60 |
+ token, err := getToken(auth.authConfig.Username, auth.authConfig.Password, params, auth.registryEndpoint) |
|
| 61 | 61 |
if err != nil {
|
| 62 | 62 |
return "", err |
| 63 | 63 |
} |
| ... | ... |
@@ -104,7 +102,6 @@ func loginV1(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (stri |
| 104 | 104 |
status string |
| 105 | 105 |
reqBody []byte |
| 106 | 106 |
err error |
| 107 |
- client = registryEndpoint.HTTPClient() |
|
| 108 | 107 |
reqStatusCode = 0 |
| 109 | 108 |
serverAddress = authConfig.ServerAddress |
| 110 | 109 |
) |
| ... | ... |
@@ -128,7 +125,7 @@ func loginV1(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (stri |
| 128 | 128 |
|
| 129 | 129 |
// using `bytes.NewReader(jsonBody)` here causes the server to respond with a 411 status. |
| 130 | 130 |
b := strings.NewReader(string(jsonBody)) |
| 131 |
- req1, err := client.Post(serverAddress+"users/", "application/json; charset=utf-8", b) |
|
| 131 |
+ req1, err := registryEndpoint.client.Post(serverAddress+"users/", "application/json; charset=utf-8", b) |
|
| 132 | 132 |
if err != nil {
|
| 133 | 133 |
return "", fmt.Errorf("Server Error: %s", err)
|
| 134 | 134 |
} |
| ... | ... |
@@ -151,7 +148,7 @@ func loginV1(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (stri |
| 151 | 151 |
if string(reqBody) == "\"Username or email already exists\"" {
|
| 152 | 152 |
req, err := http.NewRequest("GET", serverAddress+"users/", nil)
|
| 153 | 153 |
req.SetBasicAuth(authConfig.Username, authConfig.Password) |
| 154 |
- resp, err := client.Do(req) |
|
| 154 |
+ resp, err := registryEndpoint.client.Do(req) |
|
| 155 | 155 |
if err != nil {
|
| 156 | 156 |
return "", err |
| 157 | 157 |
} |
| ... | ... |
@@ -180,7 +177,7 @@ func loginV1(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (stri |
| 180 | 180 |
// protected, so people can use `docker login` as an auth check. |
| 181 | 181 |
req, err := http.NewRequest("GET", serverAddress+"users/", nil)
|
| 182 | 182 |
req.SetBasicAuth(authConfig.Username, authConfig.Password) |
| 183 |
- resp, err := client.Do(req) |
|
| 183 |
+ resp, err := registryEndpoint.client.Do(req) |
|
| 184 | 184 |
if err != nil {
|
| 185 | 185 |
return "", err |
| 186 | 186 |
} |
| ... | ... |
@@ -217,7 +214,6 @@ func loginV2(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (stri |
| 217 | 217 |
var ( |
| 218 | 218 |
err error |
| 219 | 219 |
allErrors []error |
| 220 |
- client = registryEndpoint.HTTPClient() |
|
| 221 | 220 |
) |
| 222 | 221 |
|
| 223 | 222 |
for _, challenge := range registryEndpoint.AuthChallenges {
|
| ... | ... |
@@ -225,9 +221,9 @@ func loginV2(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (stri |
| 225 | 225 |
|
| 226 | 226 |
switch strings.ToLower(challenge.Scheme) {
|
| 227 | 227 |
case "basic": |
| 228 |
- err = tryV2BasicAuthLogin(authConfig, challenge.Parameters, registryEndpoint, client) |
|
| 228 |
+ err = tryV2BasicAuthLogin(authConfig, challenge.Parameters, registryEndpoint) |
|
| 229 | 229 |
case "bearer": |
| 230 |
- err = tryV2TokenAuthLogin(authConfig, challenge.Parameters, registryEndpoint, client) |
|
| 230 |
+ err = tryV2TokenAuthLogin(authConfig, challenge.Parameters, registryEndpoint) |
|
| 231 | 231 |
default: |
| 232 | 232 |
// Unsupported challenge types are explicitly skipped. |
| 233 | 233 |
err = fmt.Errorf("unsupported auth scheme: %q", challenge.Scheme)
|
| ... | ... |
@@ -245,7 +241,7 @@ func loginV2(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (stri |
| 245 | 245 |
return "", fmt.Errorf("no successful auth challenge for %s - errors: %s", registryEndpoint, allErrors)
|
| 246 | 246 |
} |
| 247 | 247 |
|
| 248 |
-func tryV2BasicAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]string, registryEndpoint *Endpoint, client *http.Client) error {
|
|
| 248 |
+func tryV2BasicAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]string, registryEndpoint *Endpoint) error {
|
|
| 249 | 249 |
req, err := http.NewRequest("GET", registryEndpoint.Path(""), nil)
|
| 250 | 250 |
if err != nil {
|
| 251 | 251 |
return err |
| ... | ... |
@@ -253,7 +249,7 @@ func tryV2BasicAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]str |
| 253 | 253 |
|
| 254 | 254 |
req.SetBasicAuth(authConfig.Username, authConfig.Password) |
| 255 | 255 |
|
| 256 |
- resp, err := client.Do(req) |
|
| 256 |
+ resp, err := registryEndpoint.client.Do(req) |
|
| 257 | 257 |
if err != nil {
|
| 258 | 258 |
return err |
| 259 | 259 |
} |
| ... | ... |
@@ -266,8 +262,8 @@ func tryV2BasicAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]str |
| 266 | 266 |
return nil |
| 267 | 267 |
} |
| 268 | 268 |
|
| 269 |
-func tryV2TokenAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]string, registryEndpoint *Endpoint, client *http.Client) error {
|
|
| 270 |
- token, err := getToken(authConfig.Username, authConfig.Password, params, registryEndpoint, client) |
|
| 269 |
+func tryV2TokenAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]string, registryEndpoint *Endpoint) error {
|
|
| 270 |
+ token, err := getToken(authConfig.Username, authConfig.Password, params, registryEndpoint) |
|
| 271 | 271 |
if err != nil {
|
| 272 | 272 |
return err |
| 273 | 273 |
} |
| ... | ... |
@@ -279,7 +275,7 @@ func tryV2TokenAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]str |
| 279 | 279 |
|
| 280 | 280 |
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
| 281 | 281 |
|
| 282 |
- resp, err := client.Do(req) |
|
| 282 |
+ resp, err := registryEndpoint.client.Do(req) |
|
| 283 | 283 |
if err != nil {
|
| 284 | 284 |
return err |
| 285 | 285 |
} |
| ... | ... |
@@ -11,6 +11,7 @@ import ( |
| 11 | 11 |
|
| 12 | 12 |
"github.com/Sirupsen/logrus" |
| 13 | 13 |
"github.com/docker/distribution/registry/api/v2" |
| 14 |
+ "github.com/docker/docker/pkg/transport" |
|
| 14 | 15 |
) |
| 15 | 16 |
|
| 16 | 17 |
// for mocking in unit tests |
| ... | ... |
@@ -41,9 +42,9 @@ func scanForAPIVersion(address string) (string, APIVersion) {
|
| 41 | 41 |
} |
| 42 | 42 |
|
| 43 | 43 |
// NewEndpoint parses the given address to return a registry endpoint. |
| 44 |
-func NewEndpoint(index *IndexInfo) (*Endpoint, error) {
|
|
| 44 |
+func NewEndpoint(index *IndexInfo, metaHeaders http.Header) (*Endpoint, error) {
|
|
| 45 | 45 |
// *TODO: Allow per-registry configuration of endpoints. |
| 46 |
- endpoint, err := newEndpoint(index.GetAuthConfigKey(), index.Secure) |
|
| 46 |
+ endpoint, err := newEndpoint(index.GetAuthConfigKey(), index.Secure, metaHeaders) |
|
| 47 | 47 |
if err != nil {
|
| 48 | 48 |
return nil, err |
| 49 | 49 |
} |
| ... | ... |
@@ -81,7 +82,7 @@ func validateEndpoint(endpoint *Endpoint) error {
|
| 81 | 81 |
return nil |
| 82 | 82 |
} |
| 83 | 83 |
|
| 84 |
-func newEndpoint(address string, secure bool) (*Endpoint, error) {
|
|
| 84 |
+func newEndpoint(address string, secure bool, metaHeaders http.Header) (*Endpoint, error) {
|
|
| 85 | 85 |
var ( |
| 86 | 86 |
endpoint = new(Endpoint) |
| 87 | 87 |
trimmedAddress string |
| ... | ... |
@@ -98,11 +99,13 @@ func newEndpoint(address string, secure bool) (*Endpoint, error) {
|
| 98 | 98 |
return nil, err |
| 99 | 99 |
} |
| 100 | 100 |
endpoint.IsSecure = secure |
| 101 |
+ tr := NewTransport(ConnectTimeout, endpoint.IsSecure) |
|
| 102 |
+ endpoint.client = HTTPClient(transport.NewTransport(tr, DockerHeaders(metaHeaders)...)) |
|
| 101 | 103 |
return endpoint, nil |
| 102 | 104 |
} |
| 103 | 105 |
|
| 104 |
-func (repoInfo *RepositoryInfo) GetEndpoint() (*Endpoint, error) {
|
|
| 105 |
- return NewEndpoint(repoInfo.Index) |
|
| 106 |
+func (repoInfo *RepositoryInfo) GetEndpoint(metaHeaders http.Header) (*Endpoint, error) {
|
|
| 107 |
+ return NewEndpoint(repoInfo.Index, metaHeaders) |
|
| 106 | 108 |
} |
| 107 | 109 |
|
| 108 | 110 |
// Endpoint stores basic information about a registry endpoint. |
| ... | ... |
@@ -174,7 +177,7 @@ func (e *Endpoint) pingV1() (RegistryInfo, error) {
|
| 174 | 174 |
return RegistryInfo{Standalone: false}, err
|
| 175 | 175 |
} |
| 176 | 176 |
|
| 177 |
- resp, err := e.HTTPClient().Do(req) |
|
| 177 |
+ resp, err := e.client.Do(req) |
|
| 178 | 178 |
if err != nil {
|
| 179 | 179 |
return RegistryInfo{Standalone: false}, err
|
| 180 | 180 |
} |
| ... | ... |
@@ -222,7 +225,7 @@ func (e *Endpoint) pingV2() (RegistryInfo, error) {
|
| 222 | 222 |
return RegistryInfo{}, err
|
| 223 | 223 |
} |
| 224 | 224 |
|
| 225 |
- resp, err := e.HTTPClient().Do(req) |
|
| 225 |
+ resp, err := e.client.Do(req) |
|
| 226 | 226 |
if err != nil {
|
| 227 | 227 |
return RegistryInfo{}, err
|
| 228 | 228 |
} |
| ... | ... |
@@ -261,11 +264,3 @@ HeaderLoop: |
| 261 | 261 |
|
| 262 | 262 |
return RegistryInfo{}, fmt.Errorf("v2 registry endpoint returned status %d: %q", resp.StatusCode, http.StatusText(resp.StatusCode))
|
| 263 | 263 |
} |
| 264 |
- |
|
| 265 |
-func (e *Endpoint) HTTPClient() *http.Client {
|
|
| 266 |
- if e.client == nil {
|
|
| 267 |
- tr := NewTransport(ConnectTimeout, e.IsSecure) |
|
| 268 |
- e.client = HTTPClient(tr) |
|
| 269 |
- } |
|
| 270 |
- return e.client |
|
| 271 |
-} |
| ... | ... |
@@ -19,7 +19,7 @@ func TestEndpointParse(t *testing.T) {
|
| 19 | 19 |
{"0.0.0.0:5000", "https://0.0.0.0:5000/v0/"},
|
| 20 | 20 |
} |
| 21 | 21 |
for _, td := range testData {
|
| 22 |
- e, err := newEndpoint(td.str, false) |
|
| 22 |
+ e, err := newEndpoint(td.str, false, nil) |
|
| 23 | 23 |
if err != nil {
|
| 24 | 24 |
t.Errorf("%q: %s", td.str, err)
|
| 25 | 25 |
} |
| ... | ... |
@@ -60,6 +60,7 @@ func TestValidateEndpointAmbiguousAPIVersion(t *testing.T) {
|
| 60 | 60 |
testEndpoint := Endpoint{
|
| 61 | 61 |
URL: testServerURL, |
| 62 | 62 |
Version: APIVersionUnknown, |
| 63 |
+ client: HTTPClient(NewTransport(ConnectTimeout, false)), |
|
| 63 | 64 |
} |
| 64 | 65 |
|
| 65 | 66 |
if err = validateEndpoint(&testEndpoint); err != nil {
|
| ... | ... |
@@ -19,6 +19,7 @@ import ( |
| 19 | 19 |
"github.com/docker/docker/autogen/dockerversion" |
| 20 | 20 |
"github.com/docker/docker/pkg/parsers/kernel" |
| 21 | 21 |
"github.com/docker/docker/pkg/timeoutconn" |
| 22 |
+ "github.com/docker/docker/pkg/transport" |
|
| 22 | 23 |
"github.com/docker/docker/pkg/useragent" |
| 23 | 24 |
) |
| 24 | 25 |
|
| ... | ... |
@@ -36,17 +37,32 @@ const ( |
| 36 | 36 |
ConnectTimeout |
| 37 | 37 |
) |
| 38 | 38 |
|
| 39 |
-type httpsTransport struct {
|
|
| 40 |
- *http.Transport |
|
| 39 |
+// dockerUserAgent is the User-Agent the Docker client uses to identify itself. |
|
| 40 |
+// It is populated on init(), comprising version information of different components. |
|
| 41 |
+var dockerUserAgent string |
|
| 42 |
+ |
|
| 43 |
+func init() {
|
|
| 44 |
+ httpVersion := make([]useragent.VersionInfo, 0, 6) |
|
| 45 |
+ httpVersion = append(httpVersion, useragent.VersionInfo{"docker", dockerversion.VERSION})
|
|
| 46 |
+ httpVersion = append(httpVersion, useragent.VersionInfo{"go", runtime.Version()})
|
|
| 47 |
+ httpVersion = append(httpVersion, useragent.VersionInfo{"git-commit", dockerversion.GITCOMMIT})
|
|
| 48 |
+ if kernelVersion, err := kernel.GetKernelVersion(); err == nil {
|
|
| 49 |
+ httpVersion = append(httpVersion, useragent.VersionInfo{"kernel", kernelVersion.String()})
|
|
| 50 |
+ } |
|
| 51 |
+ httpVersion = append(httpVersion, useragent.VersionInfo{"os", runtime.GOOS})
|
|
| 52 |
+ httpVersion = append(httpVersion, useragent.VersionInfo{"arch", runtime.GOARCH})
|
|
| 53 |
+ |
|
| 54 |
+ dockerUserAgent = useragent.AppendVersions("", httpVersion...)
|
|
| 41 | 55 |
} |
| 42 | 56 |
|
| 57 |
+type httpsRequestModifier struct{ tlsConfig *tls.Config }
|
|
| 58 |
+ |
|
| 43 | 59 |
// DRAGONS(tiborvass): If someone wonders why do we set tlsconfig in a roundtrip, |
| 44 | 60 |
// it's because it's so as to match the current behavior in master: we generate the |
| 45 | 61 |
// certpool on every-goddam-request. It's not great, but it allows people to just put |
| 46 | 62 |
// the certs in /etc/docker/certs.d/.../ and let docker "pick it up" immediately. Would |
| 47 | 63 |
// prefer an fsnotify implementation, but that was out of scope of my refactoring. |
| 48 |
-// TODO: improve things |
|
| 49 |
-func (tr *httpsTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
| 64 |
+func (m *httpsRequestModifier) ModifyRequest(req *http.Request) error {
|
|
| 50 | 65 |
var ( |
| 51 | 66 |
roots *x509.CertPool |
| 52 | 67 |
certs []tls.Certificate |
| ... | ... |
@@ -66,7 +82,7 @@ func (tr *httpsTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
| 66 | 66 |
logrus.Debugf("hostDir: %s", hostDir)
|
| 67 | 67 |
fs, err := ioutil.ReadDir(hostDir) |
| 68 | 68 |
if err != nil && !os.IsNotExist(err) {
|
| 69 |
- return nil, err |
|
| 69 |
+ return nil |
|
| 70 | 70 |
} |
| 71 | 71 |
|
| 72 | 72 |
for _, f := range fs {
|
| ... | ... |
@@ -77,7 +93,7 @@ func (tr *httpsTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
| 77 | 77 |
logrus.Debugf("crt: %s", hostDir+"/"+f.Name())
|
| 78 | 78 |
data, err := ioutil.ReadFile(path.Join(hostDir, f.Name())) |
| 79 | 79 |
if err != nil {
|
| 80 |
- return nil, err |
|
| 80 |
+ return err |
|
| 81 | 81 |
} |
| 82 | 82 |
roots.AppendCertsFromPEM(data) |
| 83 | 83 |
} |
| ... | ... |
@@ -86,11 +102,11 @@ func (tr *httpsTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
| 86 | 86 |
keyName := certName[:len(certName)-5] + ".key" |
| 87 | 87 |
logrus.Debugf("cert: %s", hostDir+"/"+f.Name())
|
| 88 | 88 |
if !hasFile(fs, keyName) {
|
| 89 |
- return nil, fmt.Errorf("Missing key %s for certificate %s", keyName, certName)
|
|
| 89 |
+ return fmt.Errorf("Missing key %s for certificate %s", keyName, certName)
|
|
| 90 | 90 |
} |
| 91 | 91 |
cert, err := tls.LoadX509KeyPair(path.Join(hostDir, certName), path.Join(hostDir, keyName)) |
| 92 | 92 |
if err != nil {
|
| 93 |
- return nil, err |
|
| 93 |
+ return err |
|
| 94 | 94 |
} |
| 95 | 95 |
certs = append(certs, cert) |
| 96 | 96 |
} |
| ... | ... |
@@ -99,38 +115,32 @@ func (tr *httpsTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
| 99 | 99 |
certName := keyName[:len(keyName)-4] + ".cert" |
| 100 | 100 |
logrus.Debugf("key: %s", hostDir+"/"+f.Name())
|
| 101 | 101 |
if !hasFile(fs, certName) {
|
| 102 |
- return nil, fmt.Errorf("Missing certificate %s for key %s", certName, keyName)
|
|
| 102 |
+ return fmt.Errorf("Missing certificate %s for key %s", certName, keyName)
|
|
| 103 | 103 |
} |
| 104 | 104 |
} |
| 105 | 105 |
} |
| 106 |
- if tr.Transport.TLSClientConfig == nil {
|
|
| 107 |
- tr.Transport.TLSClientConfig = &tls.Config{
|
|
| 108 |
- // Avoid fallback to SSL protocols < TLS1.0 |
|
| 109 |
- MinVersion: tls.VersionTLS10, |
|
| 110 |
- } |
|
| 111 |
- } |
|
| 112 |
- tr.Transport.TLSClientConfig.RootCAs = roots |
|
| 113 |
- tr.Transport.TLSClientConfig.Certificates = certs |
|
| 106 |
+ m.tlsConfig.RootCAs = roots |
|
| 107 |
+ m.tlsConfig.Certificates = certs |
|
| 114 | 108 |
} |
| 115 |
- return tr.Transport.RoundTrip(req) |
|
| 109 |
+ return nil |
|
| 116 | 110 |
} |
| 117 | 111 |
|
| 118 | 112 |
func NewTransport(timeout TimeoutType, secure bool) http.RoundTripper {
|
| 119 |
- tlsConfig := tls.Config{
|
|
| 113 |
+ tlsConfig := &tls.Config{
|
|
| 120 | 114 |
// Avoid fallback to SSL protocols < TLS1.0 |
| 121 | 115 |
MinVersion: tls.VersionTLS10, |
| 122 | 116 |
InsecureSkipVerify: !secure, |
| 123 | 117 |
} |
| 124 | 118 |
|
| 125 |
- transport := &http.Transport{
|
|
| 119 |
+ tr := &http.Transport{
|
|
| 126 | 120 |
DisableKeepAlives: true, |
| 127 | 121 |
Proxy: http.ProxyFromEnvironment, |
| 128 |
- TLSClientConfig: &tlsConfig, |
|
| 122 |
+ TLSClientConfig: tlsConfig, |
|
| 129 | 123 |
} |
| 130 | 124 |
|
| 131 | 125 |
switch timeout {
|
| 132 | 126 |
case ConnectTimeout: |
| 133 |
- transport.Dial = func(proto string, addr string) (net.Conn, error) {
|
|
| 127 |
+ tr.Dial = func(proto string, addr string) (net.Conn, error) {
|
|
| 134 | 128 |
// Set the connect timeout to 30 seconds to allow for slower connection |
| 135 | 129 |
// times... |
| 136 | 130 |
d := net.Dialer{Timeout: 30 * time.Second, DualStack: true}
|
| ... | ... |
@@ -144,7 +154,7 @@ func NewTransport(timeout TimeoutType, secure bool) http.RoundTripper {
|
| 144 | 144 |
return conn, nil |
| 145 | 145 |
} |
| 146 | 146 |
case ReceiveTimeout: |
| 147 |
- transport.Dial = func(proto string, addr string) (net.Conn, error) {
|
|
| 147 |
+ tr.Dial = func(proto string, addr string) (net.Conn, error) {
|
|
| 148 | 148 |
d := net.Dialer{DualStack: true}
|
| 149 | 149 |
|
| 150 | 150 |
conn, err := d.Dial(proto, addr) |
| ... | ... |
@@ -159,51 +169,23 @@ func NewTransport(timeout TimeoutType, secure bool) http.RoundTripper {
|
| 159 | 159 |
if secure {
|
| 160 | 160 |
// note: httpsTransport also handles http transport |
| 161 | 161 |
// but for HTTPS, it sets up the certs |
| 162 |
- return &httpsTransport{transport}
|
|
| 162 |
+ return transport.NewTransport(tr, &httpsRequestModifier{tlsConfig})
|
|
| 163 | 163 |
} |
| 164 | 164 |
|
| 165 |
- return transport |
|
| 165 |
+ return tr |
|
| 166 | 166 |
} |
| 167 | 167 |
|
| 168 |
-type DockerHeaders struct {
|
|
| 169 |
- http.RoundTripper |
|
| 170 |
- Headers http.Header |
|
| 171 |
-} |
|
| 172 |
- |
|
| 173 |
-// cloneRequest returns a clone of the provided *http.Request. |
|
| 174 |
-// The clone is a shallow copy of the struct and its Header map |
|
| 175 |
-func cloneRequest(r *http.Request) *http.Request {
|
|
| 176 |
- // shallow copy of the struct |
|
| 177 |
- r2 := new(http.Request) |
|
| 178 |
- *r2 = *r |
|
| 179 |
- // deep copy of the Header |
|
| 180 |
- r2.Header = make(http.Header, len(r.Header)) |
|
| 181 |
- for k, s := range r.Header {
|
|
| 182 |
- r2.Header[k] = append([]string(nil), s...) |
|
| 168 |
+// DockerHeaders returns request modifiers that ensure requests have |
|
| 169 |
+// the User-Agent header set to dockerUserAgent and that metaHeaders |
|
| 170 |
+// are added. |
|
| 171 |
+func DockerHeaders(metaHeaders http.Header) []transport.RequestModifier {
|
|
| 172 |
+ modifiers := []transport.RequestModifier{
|
|
| 173 |
+ transport.NewHeaderRequestModifier(http.Header{"User-Agent": []string{dockerUserAgent}}),
|
|
| 183 | 174 |
} |
| 184 |
- return r2 |
|
| 185 |
-} |
|
| 186 |
- |
|
| 187 |
-func (tr *DockerHeaders) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
| 188 |
- req = cloneRequest(req) |
|
| 189 |
- httpVersion := make([]useragent.VersionInfo, 0, 4) |
|
| 190 |
- httpVersion = append(httpVersion, useragent.VersionInfo{"docker", dockerversion.VERSION})
|
|
| 191 |
- httpVersion = append(httpVersion, useragent.VersionInfo{"go", runtime.Version()})
|
|
| 192 |
- httpVersion = append(httpVersion, useragent.VersionInfo{"git-commit", dockerversion.GITCOMMIT})
|
|
| 193 |
- if kernelVersion, err := kernel.GetKernelVersion(); err == nil {
|
|
| 194 |
- httpVersion = append(httpVersion, useragent.VersionInfo{"kernel", kernelVersion.String()})
|
|
| 195 |
- } |
|
| 196 |
- httpVersion = append(httpVersion, useragent.VersionInfo{"os", runtime.GOOS})
|
|
| 197 |
- httpVersion = append(httpVersion, useragent.VersionInfo{"arch", runtime.GOARCH})
|
|
| 198 |
- |
|
| 199 |
- userAgent := useragent.AppendVersions(req.UserAgent(), httpVersion...) |
|
| 200 |
- |
|
| 201 |
- req.Header.Set("User-Agent", userAgent)
|
|
| 202 |
- |
|
| 203 |
- for k, v := range tr.Headers {
|
|
| 204 |
- req.Header[k] = v |
|
| 175 |
+ if metaHeaders != nil {
|
|
| 176 |
+ modifiers = append(modifiers, transport.NewHeaderRequestModifier(metaHeaders)) |
|
| 205 | 177 |
} |
| 206 |
- return tr.RoundTripper.RoundTrip(req) |
|
| 178 |
+ return modifiers |
|
| 207 | 179 |
} |
| 208 | 180 |
|
| 209 | 181 |
type debugTransport struct{ http.RoundTripper }
|
| ... | ... |
@@ -8,6 +8,7 @@ import ( |
| 8 | 8 |
"testing" |
| 9 | 9 |
|
| 10 | 10 |
"github.com/docker/docker/cliconfig" |
| 11 |
+ "github.com/docker/docker/pkg/transport" |
|
| 11 | 12 |
) |
| 12 | 13 |
|
| 13 | 14 |
var ( |
| ... | ... |
@@ -21,12 +22,12 @@ const ( |
| 21 | 21 |
|
| 22 | 22 |
func spawnTestRegistrySession(t *testing.T) *Session {
|
| 23 | 23 |
authConfig := &cliconfig.AuthConfig{}
|
| 24 |
- endpoint, err := NewEndpoint(makeIndex("/v1/"))
|
|
| 24 |
+ endpoint, err := NewEndpoint(makeIndex("/v1/"), nil)
|
|
| 25 | 25 |
if err != nil {
|
| 26 | 26 |
t.Fatal(err) |
| 27 | 27 |
} |
| 28 | 28 |
var tr http.RoundTripper = debugTransport{NewTransport(ReceiveTimeout, endpoint.IsSecure)}
|
| 29 |
- tr = &DockerHeaders{&authTransport{RoundTripper: tr, AuthConfig: authConfig}, nil}
|
|
| 29 |
+ tr = transport.NewTransport(AuthTransport(tr, authConfig, false), DockerHeaders(nil)...) |
|
| 30 | 30 |
client := HTTPClient(tr) |
| 31 | 31 |
r, err := NewSession(client, authConfig, endpoint) |
| 32 | 32 |
if err != nil {
|
| ... | ... |
@@ -48,7 +49,7 @@ func spawnTestRegistrySession(t *testing.T) *Session {
|
| 48 | 48 |
|
| 49 | 49 |
func TestPingRegistryEndpoint(t *testing.T) {
|
| 50 | 50 |
testPing := func(index *IndexInfo, expectedStandalone bool, assertMessage string) {
|
| 51 |
- ep, err := NewEndpoint(index) |
|
| 51 |
+ ep, err := NewEndpoint(index, nil) |
|
| 52 | 52 |
if err != nil {
|
| 53 | 53 |
t.Fatal(err) |
| 54 | 54 |
} |
| ... | ... |
@@ -68,7 +69,7 @@ func TestPingRegistryEndpoint(t *testing.T) {
|
| 68 | 68 |
func TestEndpoint(t *testing.T) {
|
| 69 | 69 |
// Simple wrapper to fail test if err != nil |
| 70 | 70 |
expandEndpoint := func(index *IndexInfo) *Endpoint {
|
| 71 |
- endpoint, err := NewEndpoint(index) |
|
| 71 |
+ endpoint, err := NewEndpoint(index, nil) |
|
| 72 | 72 |
if err != nil {
|
| 73 | 73 |
t.Fatal(err) |
| 74 | 74 |
} |
| ... | ... |
@@ -77,7 +78,7 @@ func TestEndpoint(t *testing.T) {
|
| 77 | 77 |
|
| 78 | 78 |
assertInsecureIndex := func(index *IndexInfo) {
|
| 79 | 79 |
index.Secure = true |
| 80 |
- _, err := NewEndpoint(index) |
|
| 80 |
+ _, err := NewEndpoint(index, nil) |
|
| 81 | 81 |
assertNotEqual(t, err, nil, index.Name+": Expected error for insecure index") |
| 82 | 82 |
assertEqual(t, strings.Contains(err.Error(), "insecure-registry"), true, index.Name+": Expected insecure-registry error for insecure index") |
| 83 | 83 |
index.Secure = false |
| ... | ... |
@@ -85,7 +86,7 @@ func TestEndpoint(t *testing.T) {
|
| 85 | 85 |
|
| 86 | 86 |
assertSecureIndex := func(index *IndexInfo) {
|
| 87 | 87 |
index.Secure = true |
| 88 |
- _, err := NewEndpoint(index) |
|
| 88 |
+ _, err := NewEndpoint(index, nil) |
|
| 89 | 89 |
assertNotEqual(t, err, nil, index.Name+": Expected cert error for secure index") |
| 90 | 90 |
assertEqual(t, strings.Contains(err.Error(), "certificate signed by unknown authority"), true, index.Name+": Expected cert error for secure index") |
| 91 | 91 |
index.Secure = false |
| ... | ... |
@@ -151,7 +152,7 @@ func TestEndpoint(t *testing.T) {
|
| 151 | 151 |
} |
| 152 | 152 |
for _, address := range badEndpoints {
|
| 153 | 153 |
index.Name = address |
| 154 |
- _, err := NewEndpoint(index) |
|
| 154 |
+ _, err := NewEndpoint(index, nil) |
|
| 155 | 155 |
checkNotEqual(t, err, nil, "Expected error while expanding bad endpoint") |
| 156 | 156 |
} |
| 157 | 157 |
} |
| ... | ... |
@@ -1,6 +1,10 @@ |
| 1 | 1 |
package registry |
| 2 | 2 |
|
| 3 |
-import "github.com/docker/docker/cliconfig" |
|
| 3 |
+import ( |
|
| 4 |
+ "net/http" |
|
| 5 |
+ |
|
| 6 |
+ "github.com/docker/docker/cliconfig" |
|
| 7 |
+) |
|
| 4 | 8 |
|
| 5 | 9 |
type Service struct {
|
| 6 | 10 |
Config *ServiceConfig |
| ... | ... |
@@ -27,7 +31,7 @@ func (s *Service) Auth(authConfig *cliconfig.AuthConfig) (string, error) {
|
| 27 | 27 |
if err != nil {
|
| 28 | 28 |
return "", err |
| 29 | 29 |
} |
| 30 |
- endpoint, err := NewEndpoint(index) |
|
| 30 |
+ endpoint, err := NewEndpoint(index, nil) |
|
| 31 | 31 |
if err != nil {
|
| 32 | 32 |
return "", err |
| 33 | 33 |
} |
| ... | ... |
@@ -44,11 +48,11 @@ func (s *Service) Search(term string, authConfig *cliconfig.AuthConfig, headers |
| 44 | 44 |
} |
| 45 | 45 |
|
| 46 | 46 |
// *TODO: Search multiple indexes. |
| 47 |
- endpoint, err := repoInfo.GetEndpoint() |
|
| 47 |
+ endpoint, err := repoInfo.GetEndpoint(http.Header(headers)) |
|
| 48 | 48 |
if err != nil {
|
| 49 | 49 |
return nil, err |
| 50 | 50 |
} |
| 51 |
- r, err := NewSession(endpoint.HTTPClient(), authConfig, endpoint) |
|
| 51 |
+ r, err := NewSession(endpoint.client, authConfig, endpoint) |
|
| 52 | 52 |
if err != nil {
|
| 53 | 53 |
return nil, err |
| 54 | 54 |
} |
| ... | ... |
@@ -4,6 +4,7 @@ import ( |
| 4 | 4 |
"bytes" |
| 5 | 5 |
"crypto/sha256" |
| 6 | 6 |
"errors" |
| 7 |
+ "sync" |
|
| 7 | 8 |
// this is required for some certificates |
| 8 | 9 |
_ "crypto/sha512" |
| 9 | 10 |
"encoding/hex" |
| ... | ... |
@@ -22,6 +23,7 @@ import ( |
| 22 | 22 |
"github.com/docker/docker/cliconfig" |
| 23 | 23 |
"github.com/docker/docker/pkg/httputils" |
| 24 | 24 |
"github.com/docker/docker/pkg/tarsum" |
| 25 |
+ "github.com/docker/docker/pkg/transport" |
|
| 25 | 26 |
) |
| 26 | 27 |
|
| 27 | 28 |
type Session struct {
|
| ... | ... |
@@ -31,7 +33,18 @@ type Session struct {
|
| 31 | 31 |
authConfig *cliconfig.AuthConfig |
| 32 | 32 |
} |
| 33 | 33 |
|
| 34 |
-// authTransport handles the auth layer when communicating with a v1 registry (private or official) |
|
| 34 |
+type authTransport struct {
|
|
| 35 |
+ http.RoundTripper |
|
| 36 |
+ *cliconfig.AuthConfig |
|
| 37 |
+ |
|
| 38 |
+ alwaysSetBasicAuth bool |
|
| 39 |
+ token []string |
|
| 40 |
+ |
|
| 41 |
+ mu sync.Mutex // guards modReq |
|
| 42 |
+ modReq map[*http.Request]*http.Request // original -> modified |
|
| 43 |
+} |
|
| 44 |
+ |
|
| 45 |
+// AuthTransport handles the auth layer when communicating with a v1 registry (private or official) |
|
| 35 | 46 |
// |
| 36 | 47 |
// For private v1 registries, set alwaysSetBasicAuth to true. |
| 37 | 48 |
// |
| ... | ... |
@@ -44,16 +57,23 @@ type Session struct {
|
| 44 | 44 |
// If the server sends a token without the client having requested it, it is ignored. |
| 45 | 45 |
// |
| 46 | 46 |
// This RoundTripper also has a CancelRequest method important for correct timeout handling. |
| 47 |
-type authTransport struct {
|
|
| 48 |
- http.RoundTripper |
|
| 49 |
- *cliconfig.AuthConfig |
|
| 50 |
- |
|
| 51 |
- alwaysSetBasicAuth bool |
|
| 52 |
- token []string |
|
| 47 |
+func AuthTransport(base http.RoundTripper, authConfig *cliconfig.AuthConfig, alwaysSetBasicAuth bool) http.RoundTripper {
|
|
| 48 |
+ if base == nil {
|
|
| 49 |
+ base = http.DefaultTransport |
|
| 50 |
+ } |
|
| 51 |
+ return &authTransport{
|
|
| 52 |
+ RoundTripper: base, |
|
| 53 |
+ AuthConfig: authConfig, |
|
| 54 |
+ alwaysSetBasicAuth: alwaysSetBasicAuth, |
|
| 55 |
+ modReq: make(map[*http.Request]*http.Request), |
|
| 56 |
+ } |
|
| 53 | 57 |
} |
| 54 | 58 |
|
| 55 |
-func (tr *authTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
| 56 |
- req = cloneRequest(req) |
|
| 59 |
+func (tr *authTransport) RoundTrip(orig *http.Request) (*http.Response, error) {
|
|
| 60 |
+ req := transport.CloneRequest(orig) |
|
| 61 |
+ tr.mu.Lock() |
|
| 62 |
+ tr.modReq[orig] = req |
|
| 63 |
+ tr.mu.Unlock() |
|
| 57 | 64 |
|
| 58 | 65 |
if tr.alwaysSetBasicAuth {
|
| 59 | 66 |
req.SetBasicAuth(tr.Username, tr.Password) |
| ... | ... |
@@ -73,14 +93,33 @@ func (tr *authTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
| 73 | 73 |
} |
| 74 | 74 |
resp, err := tr.RoundTripper.RoundTrip(req) |
| 75 | 75 |
if err != nil {
|
| 76 |
+ delete(tr.modReq, orig) |
|
| 76 | 77 |
return nil, err |
| 77 | 78 |
} |
| 78 | 79 |
if askedForToken && len(resp.Header["X-Docker-Token"]) > 0 {
|
| 79 | 80 |
tr.token = resp.Header["X-Docker-Token"] |
| 80 | 81 |
} |
| 82 |
+ resp.Body = &transport.OnEOFReader{
|
|
| 83 |
+ Rc: resp.Body, |
|
| 84 |
+ Fn: func() { delete(tr.modReq, orig) },
|
|
| 85 |
+ } |
|
| 81 | 86 |
return resp, nil |
| 82 | 87 |
} |
| 83 | 88 |
|
| 89 |
+// CancelRequest cancels an in-flight request by closing its connection. |
|
| 90 |
+func (tr *authTransport) CancelRequest(req *http.Request) {
|
|
| 91 |
+ type canceler interface {
|
|
| 92 |
+ CancelRequest(*http.Request) |
|
| 93 |
+ } |
|
| 94 |
+ if cr, ok := tr.RoundTripper.(canceler); ok {
|
|
| 95 |
+ tr.mu.Lock() |
|
| 96 |
+ modReq := tr.modReq[req] |
|
| 97 |
+ delete(tr.modReq, req) |
|
| 98 |
+ tr.mu.Unlock() |
|
| 99 |
+ cr.CancelRequest(modReq) |
|
| 100 |
+ } |
|
| 101 |
+} |
|
| 102 |
+ |
|
| 84 | 103 |
// TODO(tiborvass): remove authConfig param once registry client v2 is vendored |
| 85 | 104 |
func NewSession(client *http.Client, authConfig *cliconfig.AuthConfig, endpoint *Endpoint) (r *Session, err error) {
|
| 86 | 105 |
r = &Session{
|
| ... | ... |
@@ -105,7 +144,7 @@ func NewSession(client *http.Client, authConfig *cliconfig.AuthConfig, endpoint |
| 105 | 105 |
} |
| 106 | 106 |
} |
| 107 | 107 |
|
| 108 |
- client.Transport = &authTransport{RoundTripper: client.Transport, AuthConfig: authConfig, alwaysSetBasicAuth: alwaysSetBasicAuth}
|
|
| 108 |
+ client.Transport = AuthTransport(client.Transport, authConfig, alwaysSetBasicAuth) |
|
| 109 | 109 |
|
| 110 | 110 |
jar, err := cookiejar.New(nil) |
| 111 | 111 |
if err != nil {
|
| ... | ... |
@@ -27,7 +27,7 @@ func getV2Builder(e *Endpoint) *v2.URLBuilder {
|
| 27 | 27 |
func (r *Session) V2RegistryEndpoint(index *IndexInfo) (ep *Endpoint, err error) {
|
| 28 | 28 |
// TODO check if should use Mirror |
| 29 | 29 |
if index.Official {
|
| 30 |
- ep, err = newEndpoint(REGISTRYSERVER, true) |
|
| 30 |
+ ep, err = newEndpoint(REGISTRYSERVER, true, nil) |
|
| 31 | 31 |
if err != nil {
|
| 32 | 32 |
return |
| 33 | 33 |
} |
| ... | ... |
@@ -38,7 +38,7 @@ func (r *Session) V2RegistryEndpoint(index *IndexInfo) (ep *Endpoint, err error) |
| 38 | 38 |
} else if r.indexEndpoint.String() == index.GetAuthConfigKey() {
|
| 39 | 39 |
ep = r.indexEndpoint |
| 40 | 40 |
} else {
|
| 41 |
- ep, err = NewEndpoint(index) |
|
| 41 |
+ ep, err = NewEndpoint(index, nil) |
|
| 42 | 42 |
if err != nil {
|
| 43 | 43 |
return |
| 44 | 44 |
} |
| ... | ... |
@@ -13,7 +13,7 @@ type tokenResponse struct {
|
| 13 | 13 |
Token string `json:"token"` |
| 14 | 14 |
} |
| 15 | 15 |
|
| 16 |
-func getToken(username, password string, params map[string]string, registryEndpoint *Endpoint, client *http.Client) (token string, err error) {
|
|
| 16 |
+func getToken(username, password string, params map[string]string, registryEndpoint *Endpoint) (token string, err error) {
|
|
| 17 | 17 |
realm, ok := params["realm"] |
| 18 | 18 |
if !ok {
|
| 19 | 19 |
return "", errors.New("no realm specified for token auth challenge")
|
| ... | ... |
@@ -56,7 +56,7 @@ func getToken(username, password string, params map[string]string, registryEndpo |
| 56 | 56 |
|
| 57 | 57 |
req.URL.RawQuery = reqParams.Encode() |
| 58 | 58 |
|
| 59 |
- resp, err := client.Do(req) |
|
| 59 |
+ resp, err := registryEndpoint.client.Do(req) |
|
| 60 | 60 |
if err != nil {
|
| 61 | 61 |
return "", err |
| 62 | 62 |
} |