
Improved push and pull with upload manager and download manager

This commit adds a transfer manager which deduplicates and schedules
transfers, and also an upload manager and download manager that build on
top of the transfer manager to provide high-level interfaces for uploads
and downloads. The push and pull code is modified to use these building
blocks.

Some benefits of the changes:

- Simplification of push/pull code
- Pushes can upload layers concurrently
- Failed downloads and uploads are retried after backoff delays
- Cancellation is supported, but an individual transfer is only cancelled
once every push or pull using it has been cancelled.
- The distribution code is decoupled from Docker Engine packages and API
conventions (i.e. streamformatter), which will make it easier to split
out.

This commit also includes unit tests for the new distribution/xfer
package. The tests cover 87.8% of the statements in the package.
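
As a rough illustration of the new building blocks: each rewritten puller
constructs one download descriptor per layer and hands the whole list to the
download manager, which deduplicates, schedules, and retries the transfers.
The sketch below is inferred from how v1LayerDescriptor and v2LayerDescriptor
are used later in this diff; the authoritative definitions live in the new
distribution/xfer package, which is not part of the excerpt shown here, so
treat the exact shape as approximate rather than verbatim.

    package xfer // sketch only; the real package is docker/docker/distribution/xfer

    import (
        "io"

        "github.com/docker/docker/layer"
        "github.com/docker/docker/pkg/progress"
        "golang.org/x/net/context"
    )

    // DownloadDescriptor describes one layer to fetch (approximate shape).
    type DownloadDescriptor interface {
        // Key identifies the blob so the transfer manager can deduplicate
        // concurrent downloads of the same layer.
        Key() string
        // ID is the short identifier shown in progress messages.
        ID() string
        // DiffID returns the cached diff ID for this blob, if already known.
        DiffID() (layer.DiffID, error)
        // Download fetches the layer content, reporting progress and
        // honoring cancellation via ctx; retries with backoff are the
        // manager's responsibility.
        Download(ctx context.Context, progressOutput progress.Output) (io.ReadCloser, int64, error)
        // Registered is called once the layer has been registered in the
        // layer store, letting the descriptor cache the blob-to-DiffID mapping.
        Registered(diffID layer.DiffID)
    }

A puller then builds descriptors bottom-most layer first and hands them to the
manager, as the rewritten v1 puller does in the diff:

    resultRootFS, release, err := p.config.DownloadManager.Download(ctx, *rootFS, descriptors, p.config.ProgressOutput)
    if err != nil {
        return err
    }
    defer release() // drops the references taken on the downloaded layers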

Signed-off-by: Aaron Lehmann <aaron.lehmann@docker.com>

Aaron Lehmann authored on 2015/11/14 09:59:01
Showing 36 changed files
... ...
@@ -23,7 +23,7 @@ import (
23 23
 	"github.com/docker/docker/pkg/httputils"
24 24
 	"github.com/docker/docker/pkg/jsonmessage"
25 25
 	flag "github.com/docker/docker/pkg/mflag"
26
-	"github.com/docker/docker/pkg/progressreader"
26
+	"github.com/docker/docker/pkg/progress"
27 27
 	"github.com/docker/docker/pkg/streamformatter"
28 28
 	"github.com/docker/docker/pkg/ulimit"
29 29
 	"github.com/docker/docker/pkg/units"
... ...
@@ -169,16 +169,9 @@ func (cli *DockerCli) CmdBuild(args ...string) error {
169 169
 	context = replaceDockerfileTarWrapper(context, newDockerfile, relDockerfile)
170 170
 
171 171
 	// Setup an upload progress bar
172
-	// FIXME: ProgressReader shouldn't be this annoying to use
173
-	sf := streamformatter.NewStreamFormatter()
174
-	var body io.Reader = progressreader.New(progressreader.Config{
175
-		In:        context,
176
-		Out:       cli.out,
177
-		Formatter: sf,
178
-		NewLines:  true,
179
-		ID:        "",
180
-		Action:    "Sending build context to Docker daemon",
181
-	})
172
+	progressOutput := streamformatter.NewStreamFormatter().NewProgressOutput(cli.out, true)
173
+
174
+	var body io.Reader = progress.NewProgressReader(context, progressOutput, 0, "", "Sending build context to Docker daemon")
182 175
 
183 176
 	var memory int64
184 177
 	if *flMemoryString != "" {
... ...
@@ -447,17 +440,10 @@ func getContextFromURL(out io.Writer, remoteURL, dockerfileName string) (absCont
447 447
 		return "", "", fmt.Errorf("unable to download remote context %s: %v", remoteURL, err)
448 448
 	}
449 449
 	defer response.Body.Close()
450
+	progressOutput := streamformatter.NewStreamFormatter().NewProgressOutput(out, true)
450 451
 
451 452
 	// Pass the response body through a progress reader.
452
-	progReader := &progressreader.Config{
453
-		In:        response.Body,
454
-		Out:       out,
455
-		Formatter: streamformatter.NewStreamFormatter(),
456
-		Size:      response.ContentLength,
457
-		NewLines:  true,
458
-		ID:        "",
459
-		Action:    fmt.Sprintf("Downloading build context from remote url: %s", remoteURL),
460
-	}
453
+	progReader := progress.NewProgressReader(response.Body, progressOutput, response.ContentLength, "", fmt.Sprintf("Downloading build context from remote url: %s", remoteURL))
461 454
 
462 455
 	return getContextFromReader(progReader, dockerfileName)
463 456
 }
... ...
@@ -23,7 +23,7 @@ import (
23 23
 	"github.com/docker/docker/pkg/archive"
24 24
 	"github.com/docker/docker/pkg/chrootarchive"
25 25
 	"github.com/docker/docker/pkg/ioutils"
26
-	"github.com/docker/docker/pkg/progressreader"
26
+	"github.com/docker/docker/pkg/progress"
27 27
 	"github.com/docker/docker/pkg/streamformatter"
28 28
 	"github.com/docker/docker/pkg/ulimit"
29 29
 	"github.com/docker/docker/runconfig"
... ...
@@ -325,7 +325,7 @@ func (s *router) postBuild(ctx context.Context, w http.ResponseWriter, r *http.R
325 325
 	sf := streamformatter.NewJSONStreamFormatter()
326 326
 	errf := func(err error) error {
327 327
 		// Do not write the error in the http output if it's still empty.
328
-		// This prevents from writing a 200(OK) when there is an interal error.
328
+		// This prevents from writing a 200(OK) when there is an internal error.
329 329
 		if !output.Flushed() {
330 330
 			return err
331 331
 		}
... ...
@@ -401,23 +401,17 @@ func (s *router) postBuild(ctx context.Context, w http.ResponseWriter, r *http.R
401 401
 	remoteURL := r.FormValue("remote")
402 402
 
403 403
 	// Currently, only used if context is from a remote url.
404
-	// The field `In` is set by DetectContextFromRemoteURL.
405 404
 	// Look at code in DetectContextFromRemoteURL for more information.
406
-	pReader := &progressreader.Config{
407
-		// TODO: make progressreader streamformatter-agnostic
408
-		Out:       output,
409
-		Formatter: sf,
410
-		Size:      r.ContentLength,
411
-		NewLines:  true,
412
-		ID:        "Downloading context",
413
-		Action:    remoteURL,
405
+	createProgressReader := func(in io.ReadCloser) io.ReadCloser {
406
+		progressOutput := sf.NewProgressOutput(output, true)
407
+		return progress.NewProgressReader(in, progressOutput, r.ContentLength, "Downloading context", remoteURL)
414 408
 	}
415 409
 
416 410
 	var (
417 411
 		context        builder.ModifiableContext
418 412
 		dockerfileName string
419 413
 	)
420
-	context, dockerfileName, err = daemonbuilder.DetectContextFromRemoteURL(r.Body, remoteURL, pReader)
414
+	context, dockerfileName, err = daemonbuilder.DetectContextFromRemoteURL(r.Body, remoteURL, createProgressReader)
421 415
 	if err != nil {
422 416
 		return errf(err)
423 417
 	}
... ...
@@ -29,7 +29,7 @@ import (
29 29
 	"github.com/docker/docker/pkg/httputils"
30 30
 	"github.com/docker/docker/pkg/ioutils"
31 31
 	"github.com/docker/docker/pkg/jsonmessage"
32
-	"github.com/docker/docker/pkg/progressreader"
32
+	"github.com/docker/docker/pkg/progress"
33 33
 	"github.com/docker/docker/pkg/streamformatter"
34 34
 	"github.com/docker/docker/pkg/stringid"
35 35
 	"github.com/docker/docker/pkg/stringutils"
... ...
@@ -264,17 +264,11 @@ func (b *Builder) download(srcURL string) (fi builder.FileInfo, err error) {
264 264
 		return
265 265
 	}
266 266
 
267
+	stdoutFormatter := b.Stdout.(*streamformatter.StdoutFormatter)
268
+	progressOutput := stdoutFormatter.StreamFormatter.NewProgressOutput(stdoutFormatter.Writer, true)
269
+	progressReader := progress.NewProgressReader(resp.Body, progressOutput, resp.ContentLength, "", "Downloading")
267 270
 	// Download and dump result to tmp file
268
-	if _, err = io.Copy(tmpFile, progressreader.New(progressreader.Config{
269
-		In: resp.Body,
270
-		// TODO: make progressreader streamformatter agnostic
271
-		Out:       b.Stdout.(*streamformatter.StdoutFormatter).Writer,
272
-		Formatter: b.Stdout.(*streamformatter.StdoutFormatter).StreamFormatter,
273
-		Size:      resp.ContentLength,
274
-		NewLines:  true,
275
-		ID:        "",
276
-		Action:    "Downloading",
277
-	})); err != nil {
271
+	if _, err = io.Copy(tmpFile, progressReader); err != nil {
278 272
 		tmpFile.Close()
279 273
 		return
280 274
 	}
... ...
@@ -34,6 +34,7 @@ import (
34 34
 	"github.com/docker/docker/daemon/network"
35 35
 	"github.com/docker/docker/distribution"
36 36
 	dmetadata "github.com/docker/docker/distribution/metadata"
37
+	"github.com/docker/docker/distribution/xfer"
37 38
 	derr "github.com/docker/docker/errors"
38 39
 	"github.com/docker/docker/image"
39 40
 	"github.com/docker/docker/image/tarexport"
... ...
@@ -49,7 +50,9 @@ import (
49 49
 	"github.com/docker/docker/pkg/namesgenerator"
50 50
 	"github.com/docker/docker/pkg/nat"
51 51
 	"github.com/docker/docker/pkg/parsers/filters"
52
+	"github.com/docker/docker/pkg/progress"
52 53
 	"github.com/docker/docker/pkg/signal"
54
+	"github.com/docker/docker/pkg/streamformatter"
53 55
 	"github.com/docker/docker/pkg/stringid"
54 56
 	"github.com/docker/docker/pkg/stringutils"
55 57
 	"github.com/docker/docker/pkg/sysinfo"
... ...
@@ -66,6 +69,16 @@ import (
66 66
 	lntypes "github.com/docker/libnetwork/types"
67 67
 	"github.com/docker/libtrust"
68 68
 	"github.com/opencontainers/runc/libcontainer"
69
+	"golang.org/x/net/context"
70
+)
71
+
72
+const (
73
+	// maxDownloadConcurrency is the maximum number of downloads that
74
+	// may take place at a time for each pull.
75
+	maxDownloadConcurrency = 3
76
+	// maxUploadConcurrency is the maximum number of uploads that
77
+	// may take place at a time for each push.
78
+	maxUploadConcurrency = 5
69 79
 )
70 80
 
71 81
 var (
... ...
@@ -126,7 +139,8 @@ type Daemon struct {
126 126
 	containers                *contStore
127 127
 	execCommands              *exec.Store
128 128
 	tagStore                  tag.Store
129
-	distributionPool          *distribution.Pool
129
+	downloadManager           *xfer.LayerDownloadManager
130
+	uploadManager             *xfer.LayerUploadManager
130 131
 	distributionMetadataStore dmetadata.Store
131 132
 	trustKey                  libtrust.PrivateKey
132 133
 	idIndex                   *truncindex.TruncIndex
... ...
@@ -738,7 +752,8 @@ func NewDaemon(config *Config, registryService *registry.Service) (daemon *Daemo
738 738
 		return nil, err
739 739
 	}
740 740
 
741
-	distributionPool := distribution.NewPool()
741
+	d.downloadManager = xfer.NewLayerDownloadManager(d.layerStore, maxDownloadConcurrency)
742
+	d.uploadManager = xfer.NewLayerUploadManager(maxUploadConcurrency)
742 743
 
743 744
 	ifs, err := image.NewFSStoreBackend(filepath.Join(imageRoot, "imagedb"))
744 745
 	if err != nil {
... ...
@@ -834,7 +849,6 @@ func NewDaemon(config *Config, registryService *registry.Service) (daemon *Daemo
834 834
 	d.containers = &contStore{s: make(map[string]*container.Container)}
835 835
 	d.execCommands = exec.NewStore()
836 836
 	d.tagStore = tagStore
837
-	d.distributionPool = distributionPool
838 837
 	d.distributionMetadataStore = distributionMetadataStore
839 838
 	d.trustKey = trustKey
840 839
 	d.idIndex = truncindex.NewTruncIndex([]string{})
... ...
@@ -1038,23 +1052,53 @@ func (daemon *Daemon) TagImage(newTag reference.Named, imageName string) error {
1038 1038
 	return nil
1039 1039
 }
1040 1040
 
1041
+func writeDistributionProgress(cancelFunc func(), outStream io.Writer, progressChan <-chan progress.Progress) {
1042
+	progressOutput := streamformatter.NewJSONStreamFormatter().NewProgressOutput(outStream, false)
1043
+	operationCancelled := false
1044
+
1045
+	for prog := range progressChan {
1046
+		if err := progressOutput.WriteProgress(prog); err != nil && !operationCancelled {
1047
+			logrus.Errorf("error writing progress to client: %v", err)
1048
+			cancelFunc()
1049
+			operationCancelled = true
1050
+			// Don't return, because we need to continue draining
1051
+			// progressChan until it's closed to avoid a deadlock.
1052
+		}
1053
+	}
1054
+}
1055
+
1041 1056
 // PullImage initiates a pull operation. image is the repository name to pull, and
1042 1057
 // tag may be either empty, or indicate a specific tag to pull.
1043 1058
 func (daemon *Daemon) PullImage(ref reference.Named, metaHeaders map[string][]string, authConfig *cliconfig.AuthConfig, outStream io.Writer) error {
1059
+	// Include a buffer so that slow client connections don't affect
1060
+	// transfer performance.
1061
+	progressChan := make(chan progress.Progress, 100)
1062
+
1063
+	writesDone := make(chan struct{})
1064
+
1065
+	ctx, cancelFunc := context.WithCancel(context.Background())
1066
+
1067
+	go func() {
1068
+		writeDistributionProgress(cancelFunc, outStream, progressChan)
1069
+		close(writesDone)
1070
+	}()
1071
+
1044 1072
 	imagePullConfig := &distribution.ImagePullConfig{
1045 1073
 		MetaHeaders:     metaHeaders,
1046 1074
 		AuthConfig:      authConfig,
1047
-		OutStream:       outStream,
1075
+		ProgressOutput:  progress.ChanOutput(progressChan),
1048 1076
 		RegistryService: daemon.RegistryService,
1049 1077
 		EventsService:   daemon.EventsService,
1050 1078
 		MetadataStore:   daemon.distributionMetadataStore,
1051
-		LayerStore:      daemon.layerStore,
1052 1079
 		ImageStore:      daemon.imageStore,
1053 1080
 		TagStore:        daemon.tagStore,
1054
-		Pool:            daemon.distributionPool,
1081
+		DownloadManager: daemon.downloadManager,
1055 1082
 	}
1056 1083
 
1057
-	return distribution.Pull(ref, imagePullConfig)
1084
+	err := distribution.Pull(ctx, ref, imagePullConfig)
1085
+	close(progressChan)
1086
+	<-writesDone
1087
+	return err
1058 1088
 }
1059 1089
 
1060 1090
 // ExportImage exports a list of images to the given output stream. The
... ...
@@ -1069,10 +1113,23 @@ func (daemon *Daemon) ExportImage(names []string, outStream io.Writer) error {
1069 1069
 
1070 1070
 // PushImage initiates a push operation on the repository named localName.
1071 1071
 func (daemon *Daemon) PushImage(ref reference.Named, metaHeaders map[string][]string, authConfig *cliconfig.AuthConfig, outStream io.Writer) error {
1072
+	// Include a buffer so that slow client connections don't affect
1073
+	// transfer performance.
1074
+	progressChan := make(chan progress.Progress, 100)
1075
+
1076
+	writesDone := make(chan struct{})
1077
+
1078
+	ctx, cancelFunc := context.WithCancel(context.Background())
1079
+
1080
+	go func() {
1081
+		writeDistributionProgress(cancelFunc, outStream, progressChan)
1082
+		close(writesDone)
1083
+	}()
1084
+
1072 1085
 	imagePushConfig := &distribution.ImagePushConfig{
1073 1086
 		MetaHeaders:     metaHeaders,
1074 1087
 		AuthConfig:      authConfig,
1075
-		OutStream:       outStream,
1088
+		ProgressOutput:  progress.ChanOutput(progressChan),
1076 1089
 		RegistryService: daemon.RegistryService,
1077 1090
 		EventsService:   daemon.EventsService,
1078 1091
 		MetadataStore:   daemon.distributionMetadataStore,
... ...
@@ -1080,9 +1137,13 @@ func (daemon *Daemon) PushImage(ref reference.Named, metaHeaders map[string][]st
1080 1080
 		ImageStore:      daemon.imageStore,
1081 1081
 		TagStore:        daemon.tagStore,
1082 1082
 		TrustKey:        daemon.trustKey,
1083
+		UploadManager:   daemon.uploadManager,
1083 1084
 	}
1084 1085
 
1085
-	return distribution.Push(ref, imagePushConfig)
1086
+	err := distribution.Push(ctx, ref, imagePushConfig)
1087
+	close(progressChan)
1088
+	<-writesDone
1089
+	return err
1086 1090
 }
1087 1091
 
1088 1092
 // LookupImage looks up an image by name and returns it as an ImageInspect
... ...
@@ -21,7 +21,6 @@ import (
21 21
 	"github.com/docker/docker/pkg/httputils"
22 22
 	"github.com/docker/docker/pkg/idtools"
23 23
 	"github.com/docker/docker/pkg/ioutils"
24
-	"github.com/docker/docker/pkg/progressreader"
25 24
 	"github.com/docker/docker/pkg/urlutil"
26 25
 	"github.com/docker/docker/registry"
27 26
 	"github.com/docker/docker/runconfig"
... ...
@@ -239,7 +238,7 @@ func (d Docker) Start(c *container.Container) error {
239 239
 // DetectContextFromRemoteURL returns a context and in certain cases the name of the dockerfile to be used
240 240
 // irrespective of user input.
241 241
 // progressReader is only used if remoteURL is actually a URL (not empty, and not a Git endpoint).
242
-func DetectContextFromRemoteURL(r io.ReadCloser, remoteURL string, progressReader *progressreader.Config) (context builder.ModifiableContext, dockerfileName string, err error) {
242
+func DetectContextFromRemoteURL(r io.ReadCloser, remoteURL string, createProgressReader func(in io.ReadCloser) io.ReadCloser) (context builder.ModifiableContext, dockerfileName string, err error) {
243 243
 	switch {
244 244
 	case remoteURL == "":
245 245
 		context, err = builder.MakeTarSumContext(r)
... ...
@@ -262,8 +261,7 @@ func DetectContextFromRemoteURL(r io.ReadCloser, remoteURL string, progressReade
262 262
 			},
263 263
 			// fallback handler (tar context)
264 264
 			"": func(rc io.ReadCloser) (io.ReadCloser, error) {
265
-				progressReader.In = rc
266
-				return progressReader, nil
265
+				return createProgressReader(rc), nil
267 266
 			},
268 267
 		})
269 268
 	default:
... ...
@@ -13,7 +13,7 @@ import (
13 13
 	"github.com/docker/docker/image"
14 14
 	"github.com/docker/docker/layer"
15 15
 	"github.com/docker/docker/pkg/httputils"
16
-	"github.com/docker/docker/pkg/progressreader"
16
+	"github.com/docker/docker/pkg/progress"
17 17
 	"github.com/docker/docker/pkg/streamformatter"
18 18
 	"github.com/docker/docker/runconfig"
19 19
 )
... ...
@@ -47,16 +47,8 @@ func (daemon *Daemon) ImportImage(src string, newRef reference.Named, msg string
47 47
 		if err != nil {
48 48
 			return err
49 49
 		}
50
-		progressReader := progressreader.New(progressreader.Config{
51
-			In:        resp.Body,
52
-			Out:       outStream,
53
-			Formatter: sf,
54
-			Size:      resp.ContentLength,
55
-			NewLines:  true,
56
-			ID:        "",
57
-			Action:    "Importing",
58
-		})
59
-		archive = progressReader
50
+		progressOutput := sf.NewProgressOutput(outStream, true)
51
+		archive = progress.NewProgressReader(resp.Body, progressOutput, resp.ContentLength, "", "Importing")
60 52
 	}
61 53
 
62 54
 	defer archive.Close()
... ...
@@ -23,20 +23,20 @@ func (idserv *V1IDService) namespace() string {
23 23
 }
24 24
 
25 25
 // Get finds a layer by its V1 ID.
26
-func (idserv *V1IDService) Get(v1ID, registry string) (layer.ChainID, error) {
26
+func (idserv *V1IDService) Get(v1ID, registry string) (layer.DiffID, error) {
27 27
 	if err := v1.ValidateID(v1ID); err != nil {
28
-		return layer.ChainID(""), err
28
+		return layer.DiffID(""), err
29 29
 	}
30 30
 
31 31
 	idBytes, err := idserv.store.Get(idserv.namespace(), registry+","+v1ID)
32 32
 	if err != nil {
33
-		return layer.ChainID(""), err
33
+		return layer.DiffID(""), err
34 34
 	}
35
-	return layer.ChainID(idBytes), nil
35
+	return layer.DiffID(idBytes), nil
36 36
 }
37 37
 
38 38
 // Set associates an image with a V1 ID.
39
-func (idserv *V1IDService) Set(v1ID, registry string, id layer.ChainID) error {
39
+func (idserv *V1IDService) Set(v1ID, registry string, id layer.DiffID) error {
40 40
 	if err := v1.ValidateID(v1ID); err != nil {
41 41
 		return err
42 42
 	}
... ...
@@ -24,22 +24,22 @@ func TestV1IDService(t *testing.T) {
24 24
 	testVectors := []struct {
25 25
 		registry string
26 26
 		v1ID     string
27
-		layerID  layer.ChainID
27
+		layerID  layer.DiffID
28 28
 	}{
29 29
 		{
30 30
 			registry: "registry1",
31 31
 			v1ID:     "f0cd5ca10b07f35512fc2f1cbf9a6cefbdb5cba70ac6b0c9e5988f4497f71937",
32
-			layerID:  layer.ChainID("sha256:a3ed95caeb02ffe68cdd9fd84406680ae93d633cb16422d00e8a7c22955b46d4"),
32
+			layerID:  layer.DiffID("sha256:a3ed95caeb02ffe68cdd9fd84406680ae93d633cb16422d00e8a7c22955b46d4"),
33 33
 		},
34 34
 		{
35 35
 			registry: "registry2",
36 36
 			v1ID:     "9e3447ca24cb96d86ebd5960cb34d1299b07e0a0e03801d90b9969a2c187dd6e",
37
-			layerID:  layer.ChainID("sha256:86e0e091d0da6bde2456dbb48306f3956bbeb2eae1b5b9a43045843f69fe4aaa"),
37
+			layerID:  layer.DiffID("sha256:86e0e091d0da6bde2456dbb48306f3956bbeb2eae1b5b9a43045843f69fe4aaa"),
38 38
 		},
39 39
 		{
40 40
 			registry: "registry1",
41 41
 			v1ID:     "9e3447ca24cb96d86ebd5960cb34d1299b07e0a0e03801d90b9969a2c187dd6e",
42
-			layerID:  layer.ChainID("sha256:03f4658f8b782e12230c1783426bd3bacce651ce582a4ffb6fbbfa2079428ecb"),
42
+			layerID:  layer.DiffID("sha256:03f4658f8b782e12230c1783426bd3bacce651ce582a4ffb6fbbfa2079428ecb"),
43 43
 		},
44 44
 	}
45 45
 
46 46
deleted file mode 100644
... ...
@@ -1,51 +0,0 @@
1
-package distribution
2
-
3
-import (
4
-	"sync"
5
-
6
-	"github.com/docker/docker/pkg/broadcaster"
7
-)
8
-
9
-// A Pool manages concurrent pulls. It deduplicates in-progress downloads.
10
-type Pool struct {
11
-	sync.Mutex
12
-	pullingPool map[string]*broadcaster.Buffered
13
-}
14
-
15
-// NewPool creates a new Pool.
16
-func NewPool() *Pool {
17
-	return &Pool{
18
-		pullingPool: make(map[string]*broadcaster.Buffered),
19
-	}
20
-}
21
-
22
-// add checks if a pull is already running, and returns (broadcaster, true)
23
-// if a running operation is found. Otherwise, it creates a new one and returns
24
-// (broadcaster, false).
25
-func (pool *Pool) add(key string) (*broadcaster.Buffered, bool) {
26
-	pool.Lock()
27
-	defer pool.Unlock()
28
-
29
-	if p, exists := pool.pullingPool[key]; exists {
30
-		return p, true
31
-	}
32
-
33
-	broadcaster := broadcaster.NewBuffered()
34
-	pool.pullingPool[key] = broadcaster
35
-
36
-	return broadcaster, false
37
-}
38
-
39
-func (pool *Pool) removeWithError(key string, broadcasterResult error) error {
40
-	pool.Lock()
41
-	defer pool.Unlock()
42
-	if broadcaster, exists := pool.pullingPool[key]; exists {
43
-		broadcaster.CloseWithError(broadcasterResult)
44
-		delete(pool.pullingPool, key)
45
-	}
46
-	return nil
47
-}
48
-
49
-func (pool *Pool) remove(key string) error {
50
-	return pool.removeWithError(key, nil)
51
-}
52 1
deleted file mode 100644
... ...
@@ -1,28 +0,0 @@
1
-package distribution
2
-
3
-import (
4
-	"testing"
5
-)
6
-
7
-func TestPools(t *testing.T) {
8
-	p := NewPool()
9
-
10
-	if _, found := p.add("test1"); found {
11
-		t.Fatal("Expected pull test1 not to be in progress")
12
-	}
13
-	if _, found := p.add("test2"); found {
14
-		t.Fatal("Expected pull test2 not to be in progress")
15
-	}
16
-	if _, found := p.add("test1"); !found {
17
-		t.Fatalf("Expected pull test1 to be in progress`")
18
-	}
19
-	if err := p.remove("test2"); err != nil {
20
-		t.Fatal(err)
21
-	}
22
-	if err := p.remove("test2"); err != nil {
23
-		t.Fatal(err)
24
-	}
25
-	if err := p.remove("test1"); err != nil {
26
-		t.Fatal(err)
27
-	}
28
-}
... ...
@@ -2,7 +2,7 @@ package distribution
2 2
 
3 3
 import (
4 4
 	"fmt"
5
-	"io"
5
+	"os"
6 6
 	"strings"
7 7
 
8 8
 	"github.com/Sirupsen/logrus"
... ...
@@ -10,11 +10,12 @@ import (
10 10
 	"github.com/docker/docker/cliconfig"
11 11
 	"github.com/docker/docker/daemon/events"
12 12
 	"github.com/docker/docker/distribution/metadata"
13
+	"github.com/docker/docker/distribution/xfer"
13 14
 	"github.com/docker/docker/image"
14
-	"github.com/docker/docker/layer"
15
-	"github.com/docker/docker/pkg/streamformatter"
15
+	"github.com/docker/docker/pkg/progress"
16 16
 	"github.com/docker/docker/registry"
17 17
 	"github.com/docker/docker/tag"
18
+	"golang.org/x/net/context"
18 19
 )
19 20
 
20 21
 // ImagePullConfig stores pull configuration.
... ...
@@ -25,9 +26,9 @@ type ImagePullConfig struct {
25 25
 	// AuthConfig holds authentication credentials for authenticating with
26 26
 	// the registry.
27 27
 	AuthConfig *cliconfig.AuthConfig
28
-	// OutStream is the output writer for showing the status of the pull
28
+	// ProgressOutput is the interface for showing the status of the pull
29 29
 	// operation.
30
-	OutStream io.Writer
30
+	ProgressOutput progress.Output
31 31
 	// RegistryService is the registry service to use for TLS configuration
32 32
 	// and endpoint lookup.
33 33
 	RegistryService *registry.Service
... ...
@@ -36,14 +37,12 @@ type ImagePullConfig struct {
36 36
 	// MetadataStore is the storage backend for distribution-specific
37 37
 	// metadata.
38 38
 	MetadataStore metadata.Store
39
-	// LayerStore manages layers.
40
-	LayerStore layer.Store
41 39
 	// ImageStore manages images.
42 40
 	ImageStore image.Store
43 41
 	// TagStore manages tags.
44 42
 	TagStore tag.Store
45
-	// Pool manages concurrent pulls.
46
-	Pool *Pool
43
+	// DownloadManager manages concurrent pulls.
44
+	DownloadManager *xfer.LayerDownloadManager
47 45
 }
48 46
 
49 47
 // Puller is an interface that abstracts pulling for different API versions.
... ...
@@ -51,7 +50,7 @@ type Puller interface {
51 51
 	// Pull tries to pull the image referenced by `tag`
52 52
 	// Pull returns an error if any, as well as a boolean that determines whether to retry Pull on the next configured endpoint.
53 53
 	//
54
-	Pull(ref reference.Named) (fallback bool, err error)
54
+	Pull(ctx context.Context, ref reference.Named) (fallback bool, err error)
55 55
 }
56 56
 
57 57
 // newPuller returns a Puller interface that will pull from either a v1 or v2
... ...
@@ -59,14 +58,13 @@ type Puller interface {
59 59
 // whether a v1 or v2 puller will be created. The other parameters are passed
60 60
 // through to the underlying puller implementation for use during the actual
61 61
 // pull operation.
62
-func newPuller(endpoint registry.APIEndpoint, repoInfo *registry.RepositoryInfo, imagePullConfig *ImagePullConfig, sf *streamformatter.StreamFormatter) (Puller, error) {
62
+func newPuller(endpoint registry.APIEndpoint, repoInfo *registry.RepositoryInfo, imagePullConfig *ImagePullConfig) (Puller, error) {
63 63
 	switch endpoint.Version {
64 64
 	case registry.APIVersion2:
65 65
 		return &v2Puller{
66 66
 			blobSumService: metadata.NewBlobSumService(imagePullConfig.MetadataStore),
67 67
 			endpoint:       endpoint,
68 68
 			config:         imagePullConfig,
69
-			sf:             sf,
70 69
 			repoInfo:       repoInfo,
71 70
 		}, nil
72 71
 	case registry.APIVersion1:
... ...
@@ -74,7 +72,6 @@ func newPuller(endpoint registry.APIEndpoint, repoInfo *registry.RepositoryInfo,
74 74
 			v1IDService: metadata.NewV1IDService(imagePullConfig.MetadataStore),
75 75
 			endpoint:    endpoint,
76 76
 			config:      imagePullConfig,
77
-			sf:          sf,
78 77
 			repoInfo:    repoInfo,
79 78
 		}, nil
80 79
 	}
... ...
@@ -83,9 +80,7 @@ func newPuller(endpoint registry.APIEndpoint, repoInfo *registry.RepositoryInfo,
83 83
 
84 84
 // Pull initiates a pull operation. image is the repository name to pull, and
85 85
 // tag may be either empty, or indicate a specific tag to pull.
86
-func Pull(ref reference.Named, imagePullConfig *ImagePullConfig) error {
87
-	var sf = streamformatter.NewJSONStreamFormatter()
88
-
86
+func Pull(ctx context.Context, ref reference.Named, imagePullConfig *ImagePullConfig) error {
89 87
 	// Resolve the Repository name from fqn to RepositoryInfo
90 88
 	repoInfo, err := imagePullConfig.RegistryService.ResolveRepository(ref)
91 89
 	if err != nil {
... ...
@@ -120,12 +115,19 @@ func Pull(ref reference.Named, imagePullConfig *ImagePullConfig) error {
120 120
 	for _, endpoint := range endpoints {
121 121
 		logrus.Debugf("Trying to pull %s from %s %s", repoInfo.LocalName, endpoint.URL, endpoint.Version)
122 122
 
123
-		puller, err := newPuller(endpoint, repoInfo, imagePullConfig, sf)
123
+		puller, err := newPuller(endpoint, repoInfo, imagePullConfig)
124 124
 		if err != nil {
125 125
 			errors = append(errors, err.Error())
126 126
 			continue
127 127
 		}
128
-		if fallback, err := puller.Pull(ref); err != nil {
128
+		if fallback, err := puller.Pull(ctx, ref); err != nil {
129
+			// Was this pull cancelled? If so, don't try to fall
130
+			// back.
131
+			select {
132
+			case <-ctx.Done():
133
+				fallback = false
134
+			default:
135
+			}
129 136
 			if fallback {
130 137
 				if _, ok := err.(registry.ErrNoSupport); !ok {
131 138
 					// Because we found an error that's not ErrNoSupport, discard all subsequent ErrNoSupport errors.
... ...
@@ -165,11 +167,11 @@ func Pull(ref reference.Named, imagePullConfig *ImagePullConfig) error {
165 165
 // status message indicates that a newer image was downloaded. Otherwise, it
166 166
 // indicates that the image is up to date. requestedTag is the tag the message
167 167
 // will refer to.
168
-func writeStatus(requestedTag string, out io.Writer, sf *streamformatter.StreamFormatter, layersDownloaded bool) {
168
+func writeStatus(requestedTag string, out progress.Output, layersDownloaded bool) {
169 169
 	if layersDownloaded {
170
-		out.Write(sf.FormatStatus("", "Status: Downloaded newer image for %s", requestedTag))
170
+		progress.Message(out, "", "Status: Downloaded newer image for "+requestedTag)
171 171
 	} else {
172
-		out.Write(sf.FormatStatus("", "Status: Image is up to date for %s", requestedTag))
172
+		progress.Message(out, "", "Status: Image is up to date for "+requestedTag)
173 173
 	}
174 174
 }
175 175
 
... ...
@@ -183,3 +185,16 @@ func validateRepoName(name string) error {
183 183
 	}
184 184
 	return nil
185 185
 }
186
+
187
+	// tmpFileCloser creates a closer function for a temporary file that closes the file
188
+// and also deletes it.
189
+func tmpFileCloser(tmpFile *os.File) func() error {
190
+	return func() error {
191
+		tmpFile.Close()
192
+		if err := os.RemoveAll(tmpFile.Name()); err != nil {
193
+			logrus.Errorf("Failed to remove temp file: %s", tmpFile.Name())
194
+		}
195
+
196
+		return nil
197
+	}
198
+}
... ...
@@ -1,43 +1,42 @@
1 1
 package distribution
2 2
 
3 3
 import (
4
-	"encoding/json"
5 4
 	"errors"
6 5
 	"fmt"
7 6
 	"io"
7
+	"io/ioutil"
8 8
 	"net"
9 9
 	"net/url"
10 10
 	"strings"
11
-	"sync"
12 11
 	"time"
13 12
 
14 13
 	"github.com/Sirupsen/logrus"
15 14
 	"github.com/docker/distribution/reference"
16 15
 	"github.com/docker/distribution/registry/client/transport"
17 16
 	"github.com/docker/docker/distribution/metadata"
17
+	"github.com/docker/docker/distribution/xfer"
18 18
 	"github.com/docker/docker/image"
19 19
 	"github.com/docker/docker/image/v1"
20 20
 	"github.com/docker/docker/layer"
21
-	"github.com/docker/docker/pkg/archive"
22
-	"github.com/docker/docker/pkg/progressreader"
23
-	"github.com/docker/docker/pkg/streamformatter"
21
+	"github.com/docker/docker/pkg/ioutils"
22
+	"github.com/docker/docker/pkg/progress"
24 23
 	"github.com/docker/docker/pkg/stringid"
25 24
 	"github.com/docker/docker/registry"
25
+	"golang.org/x/net/context"
26 26
 )
27 27
 
28 28
 type v1Puller struct {
29 29
 	v1IDService *metadata.V1IDService
30 30
 	endpoint    registry.APIEndpoint
31 31
 	config      *ImagePullConfig
32
-	sf          *streamformatter.StreamFormatter
33 32
 	repoInfo    *registry.RepositoryInfo
34 33
 	session     *registry.Session
35 34
 }
36 35
 
37
-func (p *v1Puller) Pull(ref reference.Named) (fallback bool, err error) {
36
+func (p *v1Puller) Pull(ctx context.Context, ref reference.Named) (fallback bool, err error) {
38 37
 	if _, isDigested := ref.(reference.Digested); isDigested {
39 38
 		// Allowing fallback, because HTTPS v1 is before HTTP v2
40
-		return true, registry.ErrNoSupport{errors.New("Cannot pull by digest with v1 registry")}
39
+		return true, registry.ErrNoSupport{Err: errors.New("Cannot pull by digest with v1 registry")}
41 40
 	}
42 41
 
43 42
 	tlsConfig, err := p.config.RegistryService.TLSConfig(p.repoInfo.Index.Name)
... ...
@@ -62,19 +61,17 @@ func (p *v1Puller) Pull(ref reference.Named) (fallback bool, err error) {
62 62
 		logrus.Debugf("Fallback from error: %s", err)
63 63
 		return true, err
64 64
 	}
65
-	if err := p.pullRepository(ref); err != nil {
65
+	if err := p.pullRepository(ctx, ref); err != nil {
66 66
 		// TODO(dmcgowan): Check if should fallback
67 67
 		return false, err
68 68
 	}
69
-	out := p.config.OutStream
70
-	out.Write(p.sf.FormatStatus("", "%s: this image was pulled from a legacy registry.  Important: This registry version will not be supported in future versions of docker.", p.repoInfo.CanonicalName.Name()))
69
+	progress.Message(p.config.ProgressOutput, "", p.repoInfo.CanonicalName.Name()+": this image was pulled from a legacy registry.  Important: This registry version will not be supported in future versions of docker.")
71 70
 
72 71
 	return false, nil
73 72
 }
74 73
 
75
-func (p *v1Puller) pullRepository(ref reference.Named) error {
76
-	out := p.config.OutStream
77
-	out.Write(p.sf.FormatStatus("", "Pulling repository %s", p.repoInfo.CanonicalName.Name()))
74
+func (p *v1Puller) pullRepository(ctx context.Context, ref reference.Named) error {
75
+	progress.Message(p.config.ProgressOutput, "", "Pulling repository "+p.repoInfo.CanonicalName.Name())
78 76
 
79 77
 	repoData, err := p.session.GetRepositoryData(p.repoInfo.RemoteName)
80 78
 	if err != nil {
... ...
@@ -112,46 +109,18 @@ func (p *v1Puller) pullRepository(ref reference.Named) error {
112 112
 		}
113 113
 	}
114 114
 
115
-	errors := make(chan error)
116
-	layerDownloaded := make(chan struct{})
117
-
118 115
 	layersDownloaded := false
119
-	var wg sync.WaitGroup
120 116
 	for _, imgData := range repoData.ImgList {
121 117
 		if isTagged && imgData.Tag != tagged.Tag() {
122 118
 			continue
123 119
 		}
124 120
 
125
-		wg.Add(1)
126
-		go func(img *registry.ImgData) {
127
-			p.downloadImage(out, repoData, img, layerDownloaded, errors)
128
-			wg.Done()
129
-		}(imgData)
130
-	}
131
-
132
-	go func() {
133
-		wg.Wait()
134
-		close(errors)
135
-	}()
136
-
137
-	var lastError error
138
-selectLoop:
139
-	for {
140
-		select {
141
-		case err, ok := <-errors:
142
-			if !ok {
143
-				break selectLoop
144
-			}
145
-			lastError = err
146
-		case <-layerDownloaded:
147
-			layersDownloaded = true
121
+		err := p.downloadImage(ctx, repoData, imgData, &layersDownloaded)
122
+		if err != nil {
123
+			return err
148 124
 		}
149 125
 	}
150 126
 
151
-	if lastError != nil {
152
-		return lastError
153
-	}
154
-
155 127
 	localNameRef := p.repoInfo.LocalName
156 128
 	if isTagged {
157 129
 		localNameRef, err = reference.WithTag(localNameRef, tagged.Tag())
... ...
@@ -159,194 +128,143 @@ selectLoop:
159 159
 			localNameRef = p.repoInfo.LocalName
160 160
 		}
161 161
 	}
162
-	writeStatus(localNameRef.String(), out, p.sf, layersDownloaded)
162
+	writeStatus(localNameRef.String(), p.config.ProgressOutput, layersDownloaded)
163 163
 	return nil
164 164
 }
165 165
 
166
-func (p *v1Puller) downloadImage(out io.Writer, repoData *registry.RepositoryData, img *registry.ImgData, layerDownloaded chan struct{}, errors chan error) {
166
+func (p *v1Puller) downloadImage(ctx context.Context, repoData *registry.RepositoryData, img *registry.ImgData, layersDownloaded *bool) error {
167 167
 	if img.Tag == "" {
168 168
 		logrus.Debugf("Image (id: %s) present in this repository but untagged, skipping", img.ID)
169
-		return
169
+		return nil
170 170
 	}
171 171
 
172 172
 	localNameRef, err := reference.WithTag(p.repoInfo.LocalName, img.Tag)
173 173
 	if err != nil {
174 174
 		retErr := fmt.Errorf("Image (id: %s) has invalid tag: %s", img.ID, img.Tag)
175 175
 		logrus.Debug(retErr.Error())
176
-		errors <- retErr
176
+		return retErr
177 177
 	}
178 178
 
179 179
 	if err := v1.ValidateID(img.ID); err != nil {
180
-		errors <- err
181
-		return
180
+		return err
182 181
 	}
183 182
 
184
-	out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s", img.Tag, p.repoInfo.CanonicalName.Name()), nil))
183
+	progress.Updatef(p.config.ProgressOutput, stringid.TruncateID(img.ID), "Pulling image (%s) from %s", img.Tag, p.repoInfo.CanonicalName.Name())
185 184
 	success := false
186 185
 	var lastErr error
187
-	var isDownloaded bool
188 186
 	for _, ep := range p.repoInfo.Index.Mirrors {
189 187
 		ep += "v1/"
190
-		out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s, mirror: %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep), nil))
191
-		if isDownloaded, err = p.pullImage(out, img.ID, ep, localNameRef); err != nil {
188
+		progress.Updatef(p.config.ProgressOutput, stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s, mirror: %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep))
189
+		if err = p.pullImage(ctx, img.ID, ep, localNameRef, layersDownloaded); err != nil {
192 190
 			// Don't report errors when pulling from mirrors.
193 191
 			logrus.Debugf("Error pulling image (%s) from %s, mirror: %s, %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep, err)
194 192
 			continue
195 193
 		}
196
-		if isDownloaded {
197
-			layerDownloaded <- struct{}{}
198
-		}
199 194
 		success = true
200 195
 		break
201 196
 	}
202 197
 	if !success {
203 198
 		for _, ep := range repoData.Endpoints {
204
-			out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s, endpoint: %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep), nil))
205
-			if isDownloaded, err = p.pullImage(out, img.ID, ep, localNameRef); err != nil {
199
+			progress.Updatef(p.config.ProgressOutput, stringid.TruncateID(img.ID), "Pulling image (%s) from %s, endpoint: %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep)
200
+			if err = p.pullImage(ctx, img.ID, ep, localNameRef, layersDownloaded); err != nil {
206 201
 				// It's not ideal that only the last error is returned, it would be better to concatenate the errors.
207 202
 				// As the error is also given to the output stream the user will see the error.
208 203
 				lastErr = err
209
-				out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Error pulling image (%s) from %s, endpoint: %s, %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep, err), nil))
204
+				progress.Updatef(p.config.ProgressOutput, stringid.TruncateID(img.ID), "Error pulling image (%s) from %s, endpoint: %s, %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep, err)
210 205
 				continue
211 206
 			}
212
-			if isDownloaded {
213
-				layerDownloaded <- struct{}{}
214
-			}
215 207
 			success = true
216 208
 			break
217 209
 		}
218 210
 	}
219 211
 	if !success {
220 212
 		err := fmt.Errorf("Error pulling image (%s) from %s, %v", img.Tag, p.repoInfo.CanonicalName.Name(), lastErr)
221
-		out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), err.Error(), nil))
222
-		errors <- err
223
-		return
213
+		progress.Update(p.config.ProgressOutput, stringid.TruncateID(img.ID), err.Error())
214
+		return err
224 215
 	}
225
-	out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), "Download complete", nil))
216
+	progress.Update(p.config.ProgressOutput, stringid.TruncateID(img.ID), "Download complete")
217
+	return nil
226 218
 }
227 219
 
228
-func (p *v1Puller) pullImage(out io.Writer, v1ID, endpoint string, localNameRef reference.Named) (layersDownloaded bool, err error) {
220
+func (p *v1Puller) pullImage(ctx context.Context, v1ID, endpoint string, localNameRef reference.Named, layersDownloaded *bool) (err error) {
229 221
 	var history []string
230 222
 	history, err = p.session.GetRemoteHistory(v1ID, endpoint)
231 223
 	if err != nil {
232
-		return false, err
224
+		return err
233 225
 	}
234 226
 	if len(history) < 1 {
235
-		return false, fmt.Errorf("empty history for image %s", v1ID)
227
+		return fmt.Errorf("empty history for image %s", v1ID)
236 228
 	}
237
-	out.Write(p.sf.FormatProgress(stringid.TruncateID(v1ID), "Pulling dependent layers", nil))
238
-	// FIXME: Try to stream the images?
239
-	// FIXME: Launch the getRemoteImage() in goroutines
229
+	progress.Update(p.config.ProgressOutput, stringid.TruncateID(v1ID), "Pulling dependent layers")
240 230
 
241 231
 	var (
242
-		referencedLayers []layer.Layer
243
-		parentID         layer.ChainID
244
-		newHistory       []image.History
245
-		img              *image.V1Image
246
-		imgJSON          []byte
247
-		imgSize          int64
232
+		descriptors []xfer.DownloadDescriptor
233
+		newHistory  []image.History
234
+		imgJSON     []byte
235
+		imgSize     int64
248 236
 	)
249 237
 
250
-	defer func() {
251
-		for _, l := range referencedLayers {
252
-			layer.ReleaseAndLog(p.config.LayerStore, l)
253
-		}
254
-	}()
255
-
256
-	layersDownloaded = false
257
-
258
-	// Iterate over layers from top-most to bottom-most, checking if any
259
-	// already exist on disk.
260
-	var i int
261
-	for i = 0; i != len(history); i++ {
262
-		v1LayerID := history[i]
263
-		// Do we have a mapping for this particular v1 ID on this
264
-		// registry?
265
-		if layerID, err := p.v1IDService.Get(v1LayerID, p.repoInfo.Index.Name); err == nil {
266
-			// Does the layer actually exist
267
-			if l, err := p.config.LayerStore.Get(layerID); err == nil {
268
-				for j := i; j >= 0; j-- {
269
-					logrus.Debugf("Layer already exists: %s", history[j])
270
-					out.Write(p.sf.FormatProgress(stringid.TruncateID(history[j]), "Already exists", nil))
271
-				}
272
-				referencedLayers = append(referencedLayers, l)
273
-				parentID = layerID
274
-				break
275
-			}
276
-		}
277
-	}
278
-
279
-	needsDownload := i
280
-
281 238
 	// Iterate over layers, in order from bottom-most to top-most. Download
282
-	// config for all layers, and download actual layer data if needed.
283
-	for i = len(history) - 1; i >= 0; i-- {
239
+	// config for all layers and create descriptors.
240
+	for i := len(history) - 1; i >= 0; i-- {
284 241
 		v1LayerID := history[i]
285
-		imgJSON, imgSize, err = p.downloadLayerConfig(out, v1LayerID, endpoint)
242
+		imgJSON, imgSize, err = p.downloadLayerConfig(v1LayerID, endpoint)
286 243
 		if err != nil {
287
-			return layersDownloaded, err
288
-		}
289
-
290
-		img = &image.V1Image{}
291
-		if err := json.Unmarshal(imgJSON, img); err != nil {
292
-			return layersDownloaded, err
293
-		}
294
-
295
-		if i < needsDownload {
296
-			l, err := p.downloadLayer(out, v1LayerID, endpoint, parentID, imgSize, &layersDownloaded)
297
-
298
-			// Note: This needs to be done even in the error case to avoid
299
-			// stale references to the layer.
300
-			if l != nil {
301
-				referencedLayers = append(referencedLayers, l)
302
-			}
303
-			if err != nil {
304
-				return layersDownloaded, err
305
-			}
306
-
307
-			parentID = l.ChainID()
244
+			return err
308 245
 		}
309 246
 
310 247
 		// Create a new-style config from the legacy configs
311 248
 		h, err := v1.HistoryFromConfig(imgJSON, false)
312 249
 		if err != nil {
313
-			return layersDownloaded, err
250
+			return err
314 251
 		}
315 252
 		newHistory = append(newHistory, h)
253
+
254
+		layerDescriptor := &v1LayerDescriptor{
255
+			v1LayerID:        v1LayerID,
256
+			indexName:        p.repoInfo.Index.Name,
257
+			endpoint:         endpoint,
258
+			v1IDService:      p.v1IDService,
259
+			layersDownloaded: layersDownloaded,
260
+			layerSize:        imgSize,
261
+			session:          p.session,
262
+		}
263
+
264
+		descriptors = append(descriptors, layerDescriptor)
316 265
 	}
317 266
 
318 267
 	rootFS := image.NewRootFS()
319
-	l := referencedLayers[len(referencedLayers)-1]
320
-	for l != nil {
321
-		rootFS.DiffIDs = append([]layer.DiffID{l.DiffID()}, rootFS.DiffIDs...)
322
-		l = l.Parent()
268
+	resultRootFS, release, err := p.config.DownloadManager.Download(ctx, *rootFS, descriptors, p.config.ProgressOutput)
269
+	if err != nil {
270
+		return err
323 271
 	}
272
+	defer release()
324 273
 
325
-	config, err := v1.MakeConfigFromV1Config(imgJSON, rootFS, newHistory)
274
+	config, err := v1.MakeConfigFromV1Config(imgJSON, &resultRootFS, newHistory)
326 275
 	if err != nil {
327
-		return layersDownloaded, err
276
+		return err
328 277
 	}
329 278
 
330 279
 	imageID, err := p.config.ImageStore.Create(config)
331 280
 	if err != nil {
332
-		return layersDownloaded, err
281
+		return err
333 282
 	}
334 283
 
335 284
 	if err := p.config.TagStore.AddTag(localNameRef, imageID, true); err != nil {
336
-		return layersDownloaded, err
285
+		return err
337 286
 	}
338 287
 
339
-	return layersDownloaded, nil
288
+	return nil
340 289
 }
341 290
 
342
-func (p *v1Puller) downloadLayerConfig(out io.Writer, v1LayerID, endpoint string) (imgJSON []byte, imgSize int64, err error) {
343
-	out.Write(p.sf.FormatProgress(stringid.TruncateID(v1LayerID), "Pulling metadata", nil))
291
+func (p *v1Puller) downloadLayerConfig(v1LayerID, endpoint string) (imgJSON []byte, imgSize int64, err error) {
292
+	progress.Update(p.config.ProgressOutput, stringid.TruncateID(v1LayerID), "Pulling metadata")
344 293
 
345 294
 	retries := 5
346 295
 	for j := 1; j <= retries; j++ {
347 296
 		imgJSON, imgSize, err := p.session.GetRemoteImageJSON(v1LayerID, endpoint)
348 297
 		if err != nil && j == retries {
349
-			out.Write(p.sf.FormatProgress(stringid.TruncateID(v1LayerID), "Error pulling layer metadata", nil))
298
+			progress.Update(p.config.ProgressOutput, stringid.TruncateID(v1LayerID), "Error pulling layer metadata")
350 299
 			return nil, 0, err
351 300
 		} else if err != nil {
352 301
 			time.Sleep(time.Duration(j) * 500 * time.Millisecond)
... ...
@@ -360,95 +278,66 @@ func (p *v1Puller) downloadLayerConfig(out io.Writer, v1LayerID, endpoint string
360 360
 	return nil, 0, nil
361 361
 }
362 362
 
363
-func (p *v1Puller) downloadLayer(out io.Writer, v1LayerID, endpoint string, parentID layer.ChainID, layerSize int64, layersDownloaded *bool) (l layer.Layer, err error) {
364
-	// ensure no two downloads of the same layer happen at the same time
365
-	poolKey := "layer:" + v1LayerID
366
-	broadcaster, found := p.config.Pool.add(poolKey)
367
-	broadcaster.Add(out)
368
-	if found {
369
-		logrus.Debugf("Image (id: %s) pull is already running, skipping", v1LayerID)
370
-		if err = broadcaster.Wait(); err != nil {
371
-			return nil, err
372
-		}
373
-		layerID, err := p.v1IDService.Get(v1LayerID, p.repoInfo.Index.Name)
374
-		if err != nil {
375
-			return nil, err
376
-		}
377
-		// Does the layer actually exist
378
-		l, err := p.config.LayerStore.Get(layerID)
379
-		if err != nil {
380
-			return nil, err
381
-		}
382
-		return l, nil
383
-	}
363
+type v1LayerDescriptor struct {
364
+	v1LayerID        string
365
+	indexName        string
366
+	endpoint         string
367
+	v1IDService      *metadata.V1IDService
368
+	layersDownloaded *bool
369
+	layerSize        int64
370
+	session          *registry.Session
371
+}
384 372
 
385
-	// This must use a closure so it captures the value of err when
386
-	// the function returns, not when the 'defer' is evaluated.
387
-	defer func() {
388
-		p.config.Pool.removeWithError(poolKey, err)
389
-	}()
373
+func (ld *v1LayerDescriptor) Key() string {
374
+	return "v1:" + ld.v1LayerID
375
+}
390 376
 
391
-	retries := 5
392
-	for j := 1; j <= retries; j++ {
393
-		// Get the layer
394
-		status := "Pulling fs layer"
395
-		if j > 1 {
396
-			status = fmt.Sprintf("Pulling fs layer [retries: %d]", j)
397
-		}
398
-		broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(v1LayerID), status, nil))
399
-		layerReader, err := p.session.GetRemoteImageLayer(v1LayerID, endpoint, layerSize)
377
+func (ld *v1LayerDescriptor) ID() string {
378
+	return stringid.TruncateID(ld.v1LayerID)
379
+}
380
+
381
+func (ld *v1LayerDescriptor) DiffID() (layer.DiffID, error) {
382
+	return ld.v1IDService.Get(ld.v1LayerID, ld.indexName)
383
+}
384
+
385
+func (ld *v1LayerDescriptor) Download(ctx context.Context, progressOutput progress.Output) (io.ReadCloser, int64, error) {
386
+	progress.Update(progressOutput, ld.ID(), "Pulling fs layer")
387
+	layerReader, err := ld.session.GetRemoteImageLayer(ld.v1LayerID, ld.endpoint, ld.layerSize)
388
+	if err != nil {
389
+		progress.Update(progressOutput, ld.ID(), "Error pulling dependent layers")
400 390
 		if uerr, ok := err.(*url.Error); ok {
401 391
 			err = uerr.Err
402 392
 		}
403
-		if terr, ok := err.(net.Error); ok && terr.Timeout() && j < retries {
404
-			time.Sleep(time.Duration(j) * 500 * time.Millisecond)
405
-			continue
406
-		} else if err != nil {
407
-			broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(v1LayerID), "Error pulling dependent layers", nil))
408
-			return nil, err
409
-		}
410
-		*layersDownloaded = true
411
-		defer layerReader.Close()
412
-
413
-		reader := progressreader.New(progressreader.Config{
414
-			In:        layerReader,
415
-			Out:       broadcaster,
416
-			Formatter: p.sf,
417
-			Size:      layerSize,
418
-			NewLines:  false,
419
-			ID:        stringid.TruncateID(v1LayerID),
420
-			Action:    "Downloading",
421
-		})
422
-
423
-		inflatedLayerData, err := archive.DecompressStream(reader)
424
-		if err != nil {
425
-			return nil, fmt.Errorf("could not get decompression stream: %v", err)
426
-		}
427
-
428
-		l, err := p.config.LayerStore.Register(inflatedLayerData, parentID)
429
-		if err != nil {
430
-			return nil, fmt.Errorf("failed to register layer: %v", err)
393
+		if terr, ok := err.(net.Error); ok && terr.Timeout() {
394
+			return nil, 0, err
431 395
 		}
432
-		logrus.Debugf("layer %s registered successfully", l.DiffID())
396
+		return nil, 0, xfer.DoNotRetry{Err: err}
397
+	}
398
+	*ld.layersDownloaded = true
433 399
 
434
-		if terr, ok := err.(net.Error); ok && terr.Timeout() && j < retries {
435
-			time.Sleep(time.Duration(j) * 500 * time.Millisecond)
436
-			continue
437
-		} else if err != nil {
438
-			broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(v1LayerID), "Error downloading dependent layers", nil))
439
-			return nil, err
440
-		}
400
+	tmpFile, err := ioutil.TempFile("", "GetImageBlob")
401
+	if err != nil {
402
+		layerReader.Close()
403
+		return nil, 0, err
404
+	}
441 405
 
442
-		// Cache mapping from this v1 ID to content-addressable layer ID
443
-		if err := p.v1IDService.Set(v1LayerID, p.repoInfo.Index.Name, l.ChainID()); err != nil {
444
-			return nil, err
445
-		}
406
+	reader := progress.NewProgressReader(ioutils.NewCancelReadCloser(ctx, layerReader), progressOutput, ld.layerSize, ld.ID(), "Downloading")
407
+	defer reader.Close()
446 408
 
447
-		broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(v1LayerID), "Download complete", nil))
448
-		broadcaster.Close()
449
-		return l, nil
409
+	_, err = io.Copy(tmpFile, reader)
410
+	if err != nil {
411
+		return nil, 0, err
450 412
 	}
451 413
 
452
-	// not reached
453
-	return nil, nil
414
+	progress.Update(progressOutput, ld.ID(), "Download complete")
415
+
416
+	logrus.Debugf("Downloaded %s to tempfile %s", ld.ID(), tmpFile.Name())
417
+
418
+	tmpFile.Seek(0, 0)
419
+	return ioutils.NewReadCloserWrapper(tmpFile, tmpFileCloser(tmpFile)), ld.layerSize, nil
420
+}
421
+
422
+func (ld *v1LayerDescriptor) Registered(diffID layer.DiffID) {
423
+	// Cache mapping from this layer's DiffID to the blobsum
424
+	ld.v1IDService.Set(ld.v1LayerID, ld.indexName, diffID)
454 425
 }
... ...
@@ -15,13 +15,12 @@ import (
15 15
 	"github.com/docker/distribution/manifest/schema1"
16 16
 	"github.com/docker/distribution/reference"
17 17
 	"github.com/docker/docker/distribution/metadata"
18
+	"github.com/docker/docker/distribution/xfer"
18 19
 	"github.com/docker/docker/image"
19 20
 	"github.com/docker/docker/image/v1"
20 21
 	"github.com/docker/docker/layer"
21
-	"github.com/docker/docker/pkg/archive"
22
-	"github.com/docker/docker/pkg/broadcaster"
23
-	"github.com/docker/docker/pkg/progressreader"
24
-	"github.com/docker/docker/pkg/streamformatter"
22
+	"github.com/docker/docker/pkg/ioutils"
23
+	"github.com/docker/docker/pkg/progress"
25 24
 	"github.com/docker/docker/pkg/stringid"
26 25
 	"github.com/docker/docker/registry"
27 26
 	"golang.org/x/net/context"
... ...
@@ -31,23 +30,19 @@ type v2Puller struct {
31 31
 	blobSumService *metadata.BlobSumService
32 32
 	endpoint       registry.APIEndpoint
33 33
 	config         *ImagePullConfig
34
-	sf             *streamformatter.StreamFormatter
35 34
 	repoInfo       *registry.RepositoryInfo
36 35
 	repo           distribution.Repository
37
-	sessionID      string
38 36
 }
39 37
 
40
-func (p *v2Puller) Pull(ref reference.Named) (fallback bool, err error) {
38
+func (p *v2Puller) Pull(ctx context.Context, ref reference.Named) (fallback bool, err error) {
41 39
 	// TODO(tiborvass): was ReceiveTimeout
42 40
 	p.repo, err = NewV2Repository(p.repoInfo, p.endpoint, p.config.MetaHeaders, p.config.AuthConfig, "pull")
43 41
 	if err != nil {
44
-		logrus.Debugf("Error getting v2 registry: %v", err)
42
+		logrus.Warnf("Error getting v2 registry: %v", err)
45 43
 		return true, err
46 44
 	}
47 45
 
48
-	p.sessionID = stringid.GenerateRandomID()
49
-
50
-	if err := p.pullV2Repository(ref); err != nil {
46
+	if err := p.pullV2Repository(ctx, ref); err != nil {
51 47
 		if registry.ContinueOnError(err) {
52 48
 			logrus.Debugf("Error trying v2 registry: %v", err)
53 49
 			return true, err
... ...
@@ -57,7 +52,7 @@ func (p *v2Puller) Pull(ref reference.Named) (fallback bool, err error) {
57 57
 	return false, nil
58 58
 }
59 59
 
60
-func (p *v2Puller) pullV2Repository(ref reference.Named) (err error) {
60
+func (p *v2Puller) pullV2Repository(ctx context.Context, ref reference.Named) (err error) {
61 61
 	var refs []reference.Named
62 62
 	taggedName := p.repoInfo.LocalName
63 63
 	if tagged, isTagged := ref.(reference.Tagged); isTagged {
... ...
@@ -73,7 +68,7 @@ func (p *v2Puller) pullV2Repository(ref reference.Named) (err error) {
73 73
 		}
74 74
 		refs = []reference.Named{taggedName}
75 75
 	} else {
76
-		manSvc, err := p.repo.Manifests(context.Background())
76
+		manSvc, err := p.repo.Manifests(ctx)
77 77
 		if err != nil {
78 78
 			return err
79 79
 		}
... ...
@@ -98,98 +93,109 @@ func (p *v2Puller) pullV2Repository(ref reference.Named) (err error) {
98 98
 	for _, pullRef := range refs {
99 99
 		// pulledNew is true if either new layers were downloaded OR if existing images were newly tagged
100 100
 		// TODO(tiborvass): should we change the name of `layersDownload`? What about message in WriteStatus?
101
-		pulledNew, err := p.pullV2Tag(p.config.OutStream, pullRef)
101
+		pulledNew, err := p.pullV2Tag(ctx, pullRef)
102 102
 		if err != nil {
103 103
 			return err
104 104
 		}
105 105
 		layersDownloaded = layersDownloaded || pulledNew
106 106
 	}
107 107
 
108
-	writeStatus(taggedName.String(), p.config.OutStream, p.sf, layersDownloaded)
108
+	writeStatus(taggedName.String(), p.config.ProgressOutput, layersDownloaded)
109 109
 
110 110
 	return nil
111 111
 }
112 112
 
113
-// downloadInfo is used to pass information from download to extractor
114
-type downloadInfo struct {
115
-	tmpFile     *os.File
116
-	digest      digest.Digest
117
-	layer       distribution.ReadSeekCloser
118
-	size        int64
119
-	err         chan error
120
-	poolKey     string
121
-	broadcaster *broadcaster.Buffered
113
+type v2LayerDescriptor struct {
114
+	digest         digest.Digest
115
+	repo           distribution.Repository
116
+	blobSumService *metadata.BlobSumService
122 117
 }
123 118
 
124
-type errVerification struct{}
119
+func (ld *v2LayerDescriptor) Key() string {
120
+	return "v2:" + ld.digest.String()
121
+}
125 122
 
126
-func (errVerification) Error() string { return "verification failed" }
123
+func (ld *v2LayerDescriptor) ID() string {
124
+	return stringid.TruncateID(ld.digest.String())
125
+}
127 126
 
128
-func (p *v2Puller) download(di *downloadInfo) {
129
-	logrus.Debugf("pulling blob %q", di.digest)
127
+func (ld *v2LayerDescriptor) DiffID() (layer.DiffID, error) {
128
+	return ld.blobSumService.GetDiffID(ld.digest)
129
+}
130
+
131
+func (ld *v2LayerDescriptor) Download(ctx context.Context, progressOutput progress.Output) (io.ReadCloser, int64, error) {
132
+	logrus.Debugf("pulling blob %q", ld.digest)
130 133
 
131
-	blobs := p.repo.Blobs(context.Background())
134
+	blobs := ld.repo.Blobs(ctx)
132 135
 
133
-	layerDownload, err := blobs.Open(context.Background(), di.digest)
136
+	layerDownload, err := blobs.Open(ctx, ld.digest)
134 137
 	if err != nil {
135
-		logrus.Debugf("Error fetching layer: %v", err)
136
-		di.err <- err
137
-		return
138
+		logrus.Debugf("Error statting layer: %v", err)
139
+		if err == distribution.ErrBlobUnknown {
140
+			return nil, 0, xfer.DoNotRetry{Err: err}
141
+		}
142
+		return nil, 0, retryOnError(err)
138 143
 	}
139
-	defer layerDownload.Close()
140 144
 
141
-	di.size, err = layerDownload.Seek(0, os.SEEK_END)
145
+	size, err := layerDownload.Seek(0, os.SEEK_END)
142 146
 	if err != nil {
143 147
 		// Seek failed, perhaps because there was no Content-Length
144 148
 		// header. This shouldn't fail the download, because we can
145 149
 		// still continue without a progress bar.
146
-		di.size = 0
150
+		size = 0
147 151
 	} else {
148 152
 		// Restore the seek offset at the beginning of the stream.
149 153
 		_, err = layerDownload.Seek(0, os.SEEK_SET)
150 154
 		if err != nil {
151
-			di.err <- err
152
-			return
155
+			return nil, 0, err
153 156
 		}
154 157
 	}
155 158
 
156
-	verifier, err := digest.NewDigestVerifier(di.digest)
159
+	reader := progress.NewProgressReader(ioutils.NewCancelReadCloser(ctx, layerDownload), progressOutput, size, ld.ID(), "Downloading")
160
+	defer reader.Close()
161
+
162
+	verifier, err := digest.NewDigestVerifier(ld.digest)
157 163
 	if err != nil {
158
-		di.err <- err
159
-		return
164
+		return nil, 0, xfer.DoNotRetry{Err: err}
160 165
 	}
161 166
 
162
-	digestStr := di.digest.String()
167
+	tmpFile, err := ioutil.TempFile("", "GetImageBlob")
168
+	if err != nil {
169
+		return nil, 0, xfer.DoNotRetry{Err: err}
170
+	}
163 171
 
164
-	reader := progressreader.New(progressreader.Config{
165
-		In:        ioutil.NopCloser(io.TeeReader(layerDownload, verifier)),
166
-		Out:       di.broadcaster,
167
-		Formatter: p.sf,
168
-		Size:      di.size,
169
-		NewLines:  false,
170
-		ID:        stringid.TruncateID(digestStr),
171
-		Action:    "Downloading",
172
-	})
173
-	io.Copy(di.tmpFile, reader)
172
+	_, err = io.Copy(tmpFile, io.TeeReader(reader, verifier))
173
+	if err != nil {
174
+		return nil, 0, retryOnError(err)
175
+	}
174 176
 
175
-	di.broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(digestStr), "Verifying Checksum", nil))
177
+	progress.Update(progressOutput, ld.ID(), "Verifying Checksum")
176 178
 
177 179
 	if !verifier.Verified() {
178
-		err = fmt.Errorf("filesystem layer verification failed for digest %s", di.digest)
180
+		err = fmt.Errorf("filesystem layer verification failed for digest %s", ld.digest)
179 181
 		logrus.Error(err)
180
-		di.err <- err
181
-		return
182
+		tmpFile.Close()
183
+		if err := os.RemoveAll(tmpFile.Name()); err != nil {
184
+			logrus.Errorf("Failed to remove temp file: %s", tmpFile.Name())
185
+		}
186
+
187
+		return nil, 0, xfer.DoNotRetry{Err: err}
182 188
 	}
183 189
 
184
-	di.broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(digestStr), "Download complete", nil))
190
+	progress.Update(progressOutput, ld.ID(), "Download complete")
185 191
 
186
-	logrus.Debugf("Downloaded %s to tempfile %s", digestStr, di.tmpFile.Name())
187
-	di.layer = layerDownload
192
+	logrus.Debugf("Downloaded %s to tempfile %s", ld.ID(), tmpFile.Name())
193
+
194
+	tmpFile.Seek(0, 0)
195
+	return ioutils.NewReadCloserWrapper(tmpFile, tmpFileCloser(tmpFile)), size, nil
196
+}
188 197
 
189
-	di.err <- nil
198
+func (ld *v2LayerDescriptor) Registered(diffID layer.DiffID) {
199
+	// Cache mapping from this layer's DiffID to the blobsum
200
+	ld.blobSumService.Add(diffID, ld.digest)
190 201
 }
191 202
 
192
-func (p *v2Puller) pullV2Tag(out io.Writer, ref reference.Named) (tagUpdated bool, err error) {
203
+func (p *v2Puller) pullV2Tag(ctx context.Context, ref reference.Named) (tagUpdated bool, err error) {
193 204
 	tagOrDigest := ""
194 205
 	if tagged, isTagged := ref.(reference.Tagged); isTagged {
195 206
 		tagOrDigest = tagged.Tag()
... ...
@@ -201,7 +207,7 @@ func (p *v2Puller) pullV2Tag(out io.Writer, ref reference.Named) (tagUpdated boo
201 201
 
202 202
 	logrus.Debugf("Pulling ref from V2 registry: %q", tagOrDigest)
203 203
 
204
-	manSvc, err := p.repo.Manifests(context.Background())
204
+	manSvc, err := p.repo.Manifests(ctx)
205 205
 	if err != nil {
206 206
 		return false, err
207 207
 	}
... ...
@@ -231,33 +237,17 @@ func (p *v2Puller) pullV2Tag(out io.Writer, ref reference.Named) (tagUpdated boo
231 231
 		return false, err
232 232
 	}
233 233
 
234
-	out.Write(p.sf.FormatStatus(tagOrDigest, "Pulling from %s", p.repo.Name()))
234
+	progress.Message(p.config.ProgressOutput, tagOrDigest, "Pulling from "+p.repo.Name())
235 235
 
236
-	var downloads []*downloadInfo
237
-
238
-	defer func() {
239
-		for _, d := range downloads {
240
-			p.config.Pool.removeWithError(d.poolKey, err)
241
-			if d.tmpFile != nil {
242
-				d.tmpFile.Close()
243
-				if err := os.RemoveAll(d.tmpFile.Name()); err != nil {
244
-					logrus.Errorf("Failed to remove temp file: %s", d.tmpFile.Name())
245
-				}
246
-			}
247
-		}
248
-	}()
236
+	var descriptors []xfer.DownloadDescriptor
249 237
 
250 238
 	// Image history converted to the new format
251 239
 	var history []image.History
252 240
 
253
-	poolKey := "v2layer:"
254
-	notFoundLocally := false
255
-
256 241
 	// Note that the order of this loop is in the direction of bottom-most
257 242
 	// to top-most, so that the downloads slice gets ordered correctly.
258 243
 	for i := len(verifiedManifest.FSLayers) - 1; i >= 0; i-- {
259 244
 		blobSum := verifiedManifest.FSLayers[i].BlobSum
260
-		poolKey += blobSum.String()
261 245
 
262 246
 		var throwAway struct {
263 247
 			ThrowAway bool `json:"throwaway,omitempty"`
... ...
@@ -276,119 +266,22 @@ func (p *v2Puller) pullV2Tag(out io.Writer, ref reference.Named) (tagUpdated boo
276 276
 			continue
277 277
 		}
278 278
 
279
-		// Do we have a layer on disk corresponding to the set of
280
-		// blobsums up to this point?
281
-		if !notFoundLocally {
282
-			notFoundLocally = true
283
-			diffID, err := p.blobSumService.GetDiffID(blobSum)
284
-			if err == nil {
285
-				rootFS.Append(diffID)
286
-				if l, err := p.config.LayerStore.Get(rootFS.ChainID()); err == nil {
287
-					notFoundLocally = false
288
-					logrus.Debugf("Layer already exists: %s", blobSum.String())
289
-					out.Write(p.sf.FormatProgress(stringid.TruncateID(blobSum.String()), "Already exists", nil))
290
-					defer layer.ReleaseAndLog(p.config.LayerStore, l)
291
-					continue
292
-				} else {
293
-					rootFS.DiffIDs = rootFS.DiffIDs[:len(rootFS.DiffIDs)-1]
294
-				}
295
-			}
279
+		layerDescriptor := &v2LayerDescriptor{
280
+			digest:         blobSum,
281
+			repo:           p.repo,
282
+			blobSumService: p.blobSumService,
296 283
 		}
297 284
 
298
-		out.Write(p.sf.FormatProgress(stringid.TruncateID(blobSum.String()), "Pulling fs layer", nil))
299
-
300
-		tmpFile, err := ioutil.TempFile("", "GetImageBlob")
301
-		if err != nil {
302
-			return false, err
303
-		}
304
-
305
-		d := &downloadInfo{
306
-			poolKey: poolKey,
307
-			digest:  blobSum,
308
-			tmpFile: tmpFile,
309
-			// TODO: seems like this chan buffer solved hanging problem in go1.5,
310
-			// this can indicate some deeper problem that somehow we never take
311
-			// error from channel in loop below
312
-			err: make(chan error, 1),
313
-		}
314
-
315
-		downloads = append(downloads, d)
316
-
317
-		broadcaster, found := p.config.Pool.add(d.poolKey)
318
-		broadcaster.Add(out)
319
-		d.broadcaster = broadcaster
320
-		if found {
321
-			d.err <- nil
322
-		} else {
323
-			go p.download(d)
324
-		}
285
+		descriptors = append(descriptors, layerDescriptor)
325 286
 	}
326 287
 
327
-	for _, d := range downloads {
328
-		if err := <-d.err; err != nil {
329
-			return false, err
330
-		}
331
-
332
-		if d.layer == nil {
333
-			// Wait for a different pull to download and extract
334
-			// this layer.
335
-			err = d.broadcaster.Wait()
336
-			if err != nil {
337
-				return false, err
338
-			}
339
-
340
-			diffID, err := p.blobSumService.GetDiffID(d.digest)
341
-			if err != nil {
342
-				return false, err
343
-			}
344
-			rootFS.Append(diffID)
345
-
346
-			l, err := p.config.LayerStore.Get(rootFS.ChainID())
347
-			if err != nil {
348
-				return false, err
349
-			}
350
-
351
-			defer layer.ReleaseAndLog(p.config.LayerStore, l)
352
-
353
-			continue
354
-		}
355
-
356
-		d.tmpFile.Seek(0, 0)
357
-		reader := progressreader.New(progressreader.Config{
358
-			In:        d.tmpFile,
359
-			Out:       d.broadcaster,
360
-			Formatter: p.sf,
361
-			Size:      d.size,
362
-			NewLines:  false,
363
-			ID:        stringid.TruncateID(d.digest.String()),
364
-			Action:    "Extracting",
365
-		})
366
-
367
-		inflatedLayerData, err := archive.DecompressStream(reader)
368
-		if err != nil {
369
-			return false, fmt.Errorf("could not get decompression stream: %v", err)
370
-		}
371
-
372
-		l, err := p.config.LayerStore.Register(inflatedLayerData, rootFS.ChainID())
373
-		if err != nil {
374
-			return false, fmt.Errorf("failed to register layer: %v", err)
375
-		}
376
-		logrus.Debugf("layer %s registered successfully", l.DiffID())
377
-		rootFS.Append(l.DiffID())
378
-
379
-		// Cache mapping from this layer's DiffID to the blobsum
380
-		if err := p.blobSumService.Add(l.DiffID(), d.digest); err != nil {
381
-			return false, err
382
-		}
383
-
384
-		defer layer.ReleaseAndLog(p.config.LayerStore, l)
385
-
386
-		d.broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(d.digest.String()), "Pull complete", nil))
387
-		d.broadcaster.Close()
388
-		tagUpdated = true
288
+	resultRootFS, release, err := p.config.DownloadManager.Download(ctx, *rootFS, descriptors, p.config.ProgressOutput)
289
+	if err != nil {
290
+		return false, err
389 291
 	}
292
+	defer release()
390 293
 
391
-	config, err := v1.MakeConfigFromV1Config([]byte(verifiedManifest.History[0].V1Compatibility), rootFS, history)
294
+	config, err := v1.MakeConfigFromV1Config([]byte(verifiedManifest.History[0].V1Compatibility), &resultRootFS, history)
392 295
 	if err != nil {
393 296
 		return false, err
394 297
 	}
... ...
@@ -403,30 +296,24 @@ func (p *v2Puller) pullV2Tag(out io.Writer, ref reference.Named) (tagUpdated boo
403 403
 		return false, err
404 404
 	}
405 405
 
406
-	// Check for new tag if no layers downloaded
407
-	var oldTagImageID image.ID
408
-	if !tagUpdated {
409
-		oldTagImageID, err = p.config.TagStore.Get(ref)
410
-		if err != nil || oldTagImageID != imageID {
411
-			tagUpdated = true
412
-		}
406
+	if manifestDigest != "" {
407
+		progress.Message(p.config.ProgressOutput, "", "Digest: "+manifestDigest.String())
413 408
 	}
414 409
 
415
-	if tagUpdated {
416
-		if canonical, ok := ref.(reference.Canonical); ok {
417
-			if err = p.config.TagStore.AddDigest(canonical, imageID, true); err != nil {
418
-				return false, err
419
-			}
420
-		} else if err = p.config.TagStore.AddTag(ref, imageID, true); err != nil {
421
-			return false, err
422
-		}
410
+	oldTagImageID, err := p.config.TagStore.Get(ref)
411
+	if err == nil && oldTagImageID == imageID {
412
+		return false, nil
423 413
 	}
424 414
 
425
-	if manifestDigest != "" {
426
-		out.Write(p.sf.FormatStatus("", "Digest: %s", manifestDigest))
415
+	if canonical, ok := ref.(reference.Canonical); ok {
416
+		if err = p.config.TagStore.AddDigest(canonical, imageID, true); err != nil {
417
+			return false, err
418
+		}
419
+	} else if err = p.config.TagStore.AddTag(ref, imageID, true); err != nil {
420
+		return false, err
427 421
 	}
428 422
 
429
-	return tagUpdated, nil
423
+	return true, nil
430 424
 }
431 425
 
432 426
 func verifyManifest(signedManifest *schema1.SignedManifest, ref reference.Reference) (m *schema1.Manifest, err error) {
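
For context, v2LayerDescriptor.Download above spools the blob to a temporary file while feeding every byte through a digest verifier, so the checksum is known the moment the copy finishes. Below is a minimal, self-contained sketch of that verify-while-downloading pattern; it is not part of this commit, spoolAndVerify is a hypothetical helper, and error handling is trimmed to the essentials.

package distribution

import (
	"fmt"
	"io"
	"io/ioutil"
	"os"

	"github.com/docker/distribution/digest"
)

// spoolAndVerify copies blob to a temp file while hashing it, and fails if
// the content does not match the expected digest.
func spoolAndVerify(blob io.Reader, expected digest.Digest) (*os.File, error) {
	verifier, err := digest.NewDigestVerifier(expected)
	if err != nil {
		return nil, err
	}
	tmpFile, err := ioutil.TempFile("", "GetImageBlob")
	if err != nil {
		return nil, err
	}
	cleanup := func() {
		tmpFile.Close()
		os.Remove(tmpFile.Name())
	}
	// Every byte written to the temp file is also fed to the verifier.
	if _, err := io.Copy(tmpFile, io.TeeReader(blob, verifier)); err != nil {
		cleanup()
		return nil, err
	}
	if !verifier.Verified() {
		cleanup()
		return nil, fmt.Errorf("filesystem layer verification failed for digest %s", expected)
	}
	// Rewind so the caller can re-read the verified contents.
	if _, err := tmpFile.Seek(0, os.SEEK_SET); err != nil {
		cleanup()
		return nil, err
	}
	return tmpFile, nil
}
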
... ...
@@ -12,12 +12,14 @@ import (
12 12
 	"github.com/docker/docker/cliconfig"
13 13
 	"github.com/docker/docker/daemon/events"
14 14
 	"github.com/docker/docker/distribution/metadata"
15
+	"github.com/docker/docker/distribution/xfer"
15 16
 	"github.com/docker/docker/image"
16 17
 	"github.com/docker/docker/layer"
17
-	"github.com/docker/docker/pkg/streamformatter"
18
+	"github.com/docker/docker/pkg/progress"
18 19
 	"github.com/docker/docker/registry"
19 20
 	"github.com/docker/docker/tag"
20 21
 	"github.com/docker/libtrust"
22
+	"golang.org/x/net/context"
21 23
 )
22 24
 
23 25
 // ImagePushConfig stores push configuration.
... ...
@@ -28,9 +30,9 @@ type ImagePushConfig struct {
28 28
 	// AuthConfig holds authentication credentials for authenticating with
29 29
 	// the registry.
30 30
 	AuthConfig *cliconfig.AuthConfig
31
-	// OutStream is the output writer for showing the status of the push
31
+	// ProgressOutput is the interface for showing the status of the push
32 32
 	// operation.
33
-	OutStream io.Writer
33
+	ProgressOutput progress.Output
34 34
 	// RegistryService is the registry service to use for TLS configuration
35 35
 	// and endpoint lookup.
36 36
 	RegistryService *registry.Service
... ...
@@ -48,6 +50,8 @@ type ImagePushConfig struct {
48 48
 	// TrustKey is the private key for legacy signatures. This is typically
49 49
 	// an ephemeral key, since these signatures are no longer verified.
50 50
 	TrustKey libtrust.PrivateKey
51
+	// UploadManager dispatches uploads.
52
+	UploadManager *xfer.LayerUploadManager
51 53
 }
52 54
 
53 55
 // Pusher is an interface that abstracts pushing for different API versions.
... ...
@@ -56,7 +60,7 @@ type Pusher interface {
56 56
 	// Push returns an error if any, as well as a boolean that determines whether to retry Push on the next configured endpoint.
57 57
 	//
58 58
 	// TODO(tiborvass): have Push() take a reference to repository + tag, so that the pusher itself is repository-agnostic.
59
-	Push() (fallback bool, err error)
59
+	Push(ctx context.Context) (fallback bool, err error)
60 60
 }
61 61
 
62 62
 const compressionBufSize = 32768
... ...
@@ -66,7 +70,7 @@ const compressionBufSize = 32768
66 66
 // whether a v1 or v2 pusher will be created. The other parameters are passed
67 67
 // through to the underlying pusher implementation for use during the actual
68 68
 // push operation.
69
-func NewPusher(ref reference.Named, endpoint registry.APIEndpoint, repoInfo *registry.RepositoryInfo, imagePushConfig *ImagePushConfig, sf *streamformatter.StreamFormatter) (Pusher, error) {
69
+func NewPusher(ref reference.Named, endpoint registry.APIEndpoint, repoInfo *registry.RepositoryInfo, imagePushConfig *ImagePushConfig) (Pusher, error) {
70 70
 	switch endpoint.Version {
71 71
 	case registry.APIVersion2:
72 72
 		return &v2Pusher{
... ...
@@ -75,8 +79,7 @@ func NewPusher(ref reference.Named, endpoint registry.APIEndpoint, repoInfo *reg
75 75
 			endpoint:       endpoint,
76 76
 			repoInfo:       repoInfo,
77 77
 			config:         imagePushConfig,
78
-			sf:             sf,
79
-			layersPushed:   make(map[digest.Digest]bool),
78
+			layersPushed:   pushMap{layersPushed: make(map[digest.Digest]bool)},
80 79
 		}, nil
81 80
 	case registry.APIVersion1:
82 81
 		return &v1Pusher{
... ...
@@ -85,7 +88,6 @@ func NewPusher(ref reference.Named, endpoint registry.APIEndpoint, repoInfo *reg
85 85
 			endpoint:    endpoint,
86 86
 			repoInfo:    repoInfo,
87 87
 			config:      imagePushConfig,
88
-			sf:          sf,
89 88
 		}, nil
90 89
 	}
91 90
 	return nil, fmt.Errorf("unknown version %d for registry %s", endpoint.Version, endpoint.URL)
... ...
@@ -94,11 +96,9 @@ func NewPusher(ref reference.Named, endpoint registry.APIEndpoint, repoInfo *reg
94 94
 // Push initiates a push operation on the repository named localName.
95 95
 // ref is the specific variant of the image to be pushed.
96 96
 // If no tag is provided, all tags will be pushed.
97
-func Push(ref reference.Named, imagePushConfig *ImagePushConfig) error {
97
+func Push(ctx context.Context, ref reference.Named, imagePushConfig *ImagePushConfig) error {
98 98
 	// FIXME: Allow to interrupt current push when new push of same image is done.
99 99
 
100
-	var sf = streamformatter.NewJSONStreamFormatter()
101
-
102 100
 	// Resolve the Repository name from fqn to RepositoryInfo
103 101
 	repoInfo, err := imagePushConfig.RegistryService.ResolveRepository(ref)
104 102
 	if err != nil {
... ...
@@ -110,7 +110,7 @@ func Push(ref reference.Named, imagePushConfig *ImagePushConfig) error {
110 110
 		return err
111 111
 	}
112 112
 
113
-	imagePushConfig.OutStream.Write(sf.FormatStatus("", "The push refers to a repository [%s]", repoInfo.CanonicalName))
113
+	progress.Messagef(imagePushConfig.ProgressOutput, "", "The push refers to a repository [%s]", repoInfo.CanonicalName.String())
114 114
 
115 115
 	associations := imagePushConfig.TagStore.ReferencesByName(repoInfo.LocalName)
116 116
 	if len(associations) == 0 {
... ...
@@ -121,12 +121,20 @@ func Push(ref reference.Named, imagePushConfig *ImagePushConfig) error {
121 121
 	for _, endpoint := range endpoints {
122 122
 		logrus.Debugf("Trying to push %s to %s %s", repoInfo.CanonicalName, endpoint.URL, endpoint.Version)
123 123
 
124
-		pusher, err := NewPusher(ref, endpoint, repoInfo, imagePushConfig, sf)
124
+		pusher, err := NewPusher(ref, endpoint, repoInfo, imagePushConfig)
125 125
 		if err != nil {
126 126
 			lastErr = err
127 127
 			continue
128 128
 		}
129
-		if fallback, err := pusher.Push(); err != nil {
129
+		if fallback, err := pusher.Push(ctx); err != nil {
130
+			// Was this push cancelled? If so, don't try to fall
131
+			// back.
132
+			select {
133
+			case <-ctx.Done():
134
+				fallback = false
135
+			default:
136
+			}
137
+
130 138
 			if fallback {
131 139
 				lastErr = err
132 140
 				continue
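
The select block added above is the standard non-blocking check for a cancelled context: if ctx.Done() is already closed, the push gives up instead of falling back to the next endpoint. A minimal sketch of the same idiom follows; it is not from this commit and shouldFallBack is a hypothetical helper.

package distribution

import "golang.org/x/net/context"

// shouldFallBack reports whether a failed push may be retried against the
// next endpoint. A cancelled context always suppresses fallback.
func shouldFallBack(ctx context.Context, fallback bool) bool {
	select {
	case <-ctx.Done():
		// The push was cancelled; do not try other endpoints.
		return false
	default:
		return fallback
	}
}
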
... ...
@@ -2,8 +2,6 @@ package distribution
2 2
 
3 3
 import (
4 4
 	"fmt"
5
-	"io"
6
-	"io/ioutil"
7 5
 	"sync"
8 6
 
9 7
 	"github.com/Sirupsen/logrus"
... ...
@@ -15,25 +13,23 @@ import (
15 15
 	"github.com/docker/docker/image/v1"
16 16
 	"github.com/docker/docker/layer"
17 17
 	"github.com/docker/docker/pkg/ioutils"
18
-	"github.com/docker/docker/pkg/progressreader"
19
-	"github.com/docker/docker/pkg/streamformatter"
18
+	"github.com/docker/docker/pkg/progress"
20 19
 	"github.com/docker/docker/pkg/stringid"
21 20
 	"github.com/docker/docker/registry"
21
+	"golang.org/x/net/context"
22 22
 )
23 23
 
24 24
 type v1Pusher struct {
25
+	ctx         context.Context
25 26
 	v1IDService *metadata.V1IDService
26 27
 	endpoint    registry.APIEndpoint
27 28
 	ref         reference.Named
28 29
 	repoInfo    *registry.RepositoryInfo
29 30
 	config      *ImagePushConfig
30
-	sf          *streamformatter.StreamFormatter
31 31
 	session     *registry.Session
32
-
33
-	out io.Writer
34 32
 }
35 33
 
36
-func (p *v1Pusher) Push() (fallback bool, err error) {
34
+func (p *v1Pusher) Push(ctx context.Context) (fallback bool, err error) {
37 35
 	tlsConfig, err := p.config.RegistryService.TLSConfig(p.repoInfo.Index.Name)
38 36
 	if err != nil {
39 37
 		return false, err
... ...
@@ -55,7 +51,7 @@ func (p *v1Pusher) Push() (fallback bool, err error) {
55 55
 		// TODO(dmcgowan): Check if should fallback
56 56
 		return true, err
57 57
 	}
58
-	if err := p.pushRepository(); err != nil {
58
+	if err := p.pushRepository(ctx); err != nil {
59 59
 		// TODO(dmcgowan): Check if should fallback
60 60
 		return false, err
61 61
 	}
... ...
@@ -306,12 +302,12 @@ func (p *v1Pusher) lookupImageOnEndpoint(wg *sync.WaitGroup, endpoint string, im
306 306
 			logrus.Errorf("Error in LookupRemoteImage: %s", err)
307 307
 			imagesToPush <- v1ID
308 308
 		} else {
309
-			p.out.Write(p.sf.FormatStatus("", "Image %s already pushed, skipping", stringid.TruncateID(v1ID)))
309
+			progress.Messagef(p.config.ProgressOutput, "", "Image %s already pushed, skipping", stringid.TruncateID(v1ID))
310 310
 		}
311 311
 	}
312 312
 }
313 313
 
314
-func (p *v1Pusher) pushImageToEndpoint(endpoint string, imageList []v1Image, tags map[image.ID][]string, repo *registry.RepositoryData) error {
314
+func (p *v1Pusher) pushImageToEndpoint(ctx context.Context, endpoint string, imageList []v1Image, tags map[image.ID][]string, repo *registry.RepositoryData) error {
315 315
 	workerCount := len(imageList)
316 316
 	// start a maximum of 5 workers to check if images exist on the specified endpoint.
317 317
 	if workerCount > 5 {
... ...
@@ -349,14 +345,14 @@ func (p *v1Pusher) pushImageToEndpoint(endpoint string, imageList []v1Image, tag
349 349
 	for _, img := range imageList {
350 350
 		v1ID := img.V1ID()
351 351
 		if _, push := shouldPush[v1ID]; push {
352
-			if _, err := p.pushImage(img, endpoint); err != nil {
352
+			if _, err := p.pushImage(ctx, img, endpoint); err != nil {
353 353
 				// FIXME: Continue on error?
354 354
 				return err
355 355
 			}
356 356
 		}
357 357
 		if topImage, isTopImage := img.(*v1TopImage); isTopImage {
358 358
 			for _, tag := range tags[topImage.imageID] {
359
-				p.out.Write(p.sf.FormatStatus("", "Pushing tag for rev [%s] on {%s}", stringid.TruncateID(v1ID), endpoint+"repositories/"+p.repoInfo.RemoteName.Name()+"/tags/"+tag))
359
+				progress.Messagef(p.config.ProgressOutput, "", "Pushing tag for rev [%s] on {%s}", stringid.TruncateID(v1ID), endpoint+"repositories/"+p.repoInfo.RemoteName.Name()+"/tags/"+tag)
360 360
 				if err := p.session.PushRegistryTag(p.repoInfo.RemoteName, v1ID, tag, endpoint); err != nil {
361 361
 					return err
362 362
 				}
... ...
@@ -367,8 +363,7 @@ func (p *v1Pusher) pushImageToEndpoint(endpoint string, imageList []v1Image, tag
367 367
 }
368 368
 
369 369
 // pushRepository pushes layers that do not already exist on the registry.
370
-func (p *v1Pusher) pushRepository() error {
371
-	p.out = ioutils.NewWriteFlusher(p.config.OutStream)
370
+func (p *v1Pusher) pushRepository(ctx context.Context) error {
372 371
 	imgList, tags, referencedLayers, err := p.getImageList()
373 372
 	defer func() {
374 373
 		for _, l := range referencedLayers {
... ...
@@ -378,7 +373,7 @@ func (p *v1Pusher) pushRepository() error {
378 378
 	if err != nil {
379 379
 		return err
380 380
 	}
381
-	p.out.Write(p.sf.FormatStatus("", "Sending image list"))
381
+	progress.Message(p.config.ProgressOutput, "", "Sending image list")
382 382
 
383 383
 	imageIndex := createImageIndex(imgList, tags)
384 384
 	for _, data := range imageIndex {
... ...
@@ -391,10 +386,10 @@ func (p *v1Pusher) pushRepository() error {
391 391
 	if err != nil {
392 392
 		return err
393 393
 	}
394
-	p.out.Write(p.sf.FormatStatus("", "Pushing repository %s", p.repoInfo.CanonicalName))
394
+	progress.Message(p.config.ProgressOutput, "", "Pushing repository "+p.repoInfo.CanonicalName.String())
395 395
 	// push the repository to each of the endpoints only if it does not exist.
396 396
 	for _, endpoint := range repoData.Endpoints {
397
-		if err := p.pushImageToEndpoint(endpoint, imgList, tags, repoData); err != nil {
397
+		if err := p.pushImageToEndpoint(ctx, endpoint, imgList, tags, repoData); err != nil {
398 398
 			return err
399 399
 		}
400 400
 	}
... ...
@@ -402,11 +397,11 @@ func (p *v1Pusher) pushRepository() error {
402 402
 	return err
403 403
 }
404 404
 
405
-func (p *v1Pusher) pushImage(v1Image v1Image, ep string) (checksum string, err error) {
405
+func (p *v1Pusher) pushImage(ctx context.Context, v1Image v1Image, ep string) (checksum string, err error) {
406 406
 	v1ID := v1Image.V1ID()
407 407
 
408 408
 	jsonRaw := v1Image.Config()
409
-	p.out.Write(p.sf.FormatProgress(stringid.TruncateID(v1ID), "Pushing", nil))
409
+	progress.Update(p.config.ProgressOutput, stringid.TruncateID(v1ID), "Pushing")
410 410
 
411 411
 	// General rule is to use ID for graph accesses and compatibilityID for
412 412
 	// calls to session.registry()
... ...
@@ -417,7 +412,7 @@ func (p *v1Pusher) pushImage(v1Image v1Image, ep string) (checksum string, err e
417 417
 	// Send the json
418 418
 	if err := p.session.PushImageJSONRegistry(imgData, jsonRaw, ep); err != nil {
419 419
 		if err == registry.ErrAlreadyExists {
420
-			p.out.Write(p.sf.FormatProgress(stringid.TruncateID(v1ID), "Image already pushed, skipping", nil))
420
+			progress.Update(p.config.ProgressOutput, stringid.TruncateID(v1ID), "Image already pushed, skipping")
421 421
 			return "", nil
422 422
 		}
423 423
 		return "", err
... ...
@@ -437,15 +432,8 @@ func (p *v1Pusher) pushImage(v1Image v1Image, ep string) (checksum string, err e
437 437
 	// Send the layer
438 438
 	logrus.Debugf("rendered layer for %s of [%d] size", v1ID, size)
439 439
 
440
-	reader := progressreader.New(progressreader.Config{
441
-		In:        ioutil.NopCloser(arch),
442
-		Out:       p.out,
443
-		Formatter: p.sf,
444
-		Size:      size,
445
-		NewLines:  false,
446
-		ID:        stringid.TruncateID(v1ID),
447
-		Action:    "Pushing",
448
-	})
440
+	reader := progress.NewProgressReader(ioutils.NewCancelReadCloser(ctx, arch), p.config.ProgressOutput, size, stringid.TruncateID(v1ID), "Pushing")
441
+	defer reader.Close()
449 442
 
450 443
 	checksum, checksumPayload, err := p.session.PushImageLayerRegistry(v1ID, reader, ep, jsonRaw)
451 444
 	if err != nil {
... ...
@@ -458,10 +446,10 @@ func (p *v1Pusher) pushImage(v1Image v1Image, ep string) (checksum string, err e
458 458
 		return "", err
459 459
 	}
460 460
 
461
-	if err := p.v1IDService.Set(v1ID, p.repoInfo.Index.Name, l.ChainID()); err != nil {
461
+	if err := p.v1IDService.Set(v1ID, p.repoInfo.Index.Name, l.DiffID()); err != nil {
462 462
 		logrus.Warnf("Could not set v1 ID mapping: %v", err)
463 463
 	}
464 464
 
465
-	p.out.Write(p.sf.FormatProgress(stringid.TruncateID(v1ID), "Image successfully pushed", nil))
465
+	progress.Update(p.config.ProgressOutput, stringid.TruncateID(v1ID), "Image successfully pushed")
466 466
 	return imgData.Checksum, nil
467 467
 }
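
The v1 pusher above (and the v2 push descriptor later in this diff) wraps the layer tar stream the same way: ioutils.NewCancelReadCloser ties the stream to the context, and progress.NewProgressReader emits byte-count updates under the layer's short ID. A compact sketch of that composition follows; it is not part of this commit and wrapForPush is a hypothetical helper.

package distribution

import (
	"io"

	"github.com/docker/docker/pkg/ioutils"
	"github.com/docker/docker/pkg/progress"
	"golang.org/x/net/context"
)

// wrapForPush returns a reader that aborts when ctx is cancelled and reports
// "Pushing" progress for the given id as it is consumed. The caller must
// close the returned reader.
func wrapForPush(ctx context.Context, layerData io.ReadCloser, out progress.Output, size int64, id string) io.ReadCloser {
	return progress.NewProgressReader(ioutils.NewCancelReadCloser(ctx, layerData), out, size, id, "Pushing")
}
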
... ...
@@ -5,7 +5,7 @@ import (
5 5
 	"errors"
6 6
 	"fmt"
7 7
 	"io"
8
-	"io/ioutil"
8
+	"sync"
9 9
 	"time"
10 10
 
11 11
 	"github.com/Sirupsen/logrus"
... ...
@@ -15,11 +15,12 @@ import (
15 15
 	"github.com/docker/distribution/manifest/schema1"
16 16
 	"github.com/docker/distribution/reference"
17 17
 	"github.com/docker/docker/distribution/metadata"
18
+	"github.com/docker/docker/distribution/xfer"
18 19
 	"github.com/docker/docker/image"
19 20
 	"github.com/docker/docker/image/v1"
20 21
 	"github.com/docker/docker/layer"
21
-	"github.com/docker/docker/pkg/progressreader"
22
-	"github.com/docker/docker/pkg/streamformatter"
22
+	"github.com/docker/docker/pkg/ioutils"
23
+	"github.com/docker/docker/pkg/progress"
23 24
 	"github.com/docker/docker/pkg/stringid"
24 25
 	"github.com/docker/docker/registry"
25 26
 	"github.com/docker/docker/tag"
... ...
@@ -32,16 +33,20 @@ type v2Pusher struct {
32 32
 	endpoint       registry.APIEndpoint
33 33
 	repoInfo       *registry.RepositoryInfo
34 34
 	config         *ImagePushConfig
35
-	sf             *streamformatter.StreamFormatter
36 35
 	repo           distribution.Repository
37 36
 
38 37
 	// layersPushed is the set of layers known to exist on the remote side.
39 38
 	// This avoids redundant queries when pushing multiple tags that
40 39
 	// involve the same layers.
40
+	layersPushed pushMap
41
+}
42
+
43
+type pushMap struct {
44
+	sync.Mutex
41 45
 	layersPushed map[digest.Digest]bool
42 46
 }
43 47
 
44
-func (p *v2Pusher) Push() (fallback bool, err error) {
48
+func (p *v2Pusher) Push(ctx context.Context) (fallback bool, err error) {
45 49
 	p.repo, err = NewV2Repository(p.repoInfo, p.endpoint, p.config.MetaHeaders, p.config.AuthConfig, "push", "pull")
46 50
 	if err != nil {
47 51
 		logrus.Debugf("Error getting v2 registry: %v", err)
... ...
@@ -75,7 +80,7 @@ func (p *v2Pusher) Push() (fallback bool, err error) {
75 75
 	}
76 76
 
77 77
 	for _, association := range associations {
78
-		if err := p.pushV2Tag(association); err != nil {
78
+		if err := p.pushV2Tag(ctx, association); err != nil {
79 79
 			return false, err
80 80
 		}
81 81
 	}
... ...
@@ -83,7 +88,7 @@ func (p *v2Pusher) Push() (fallback bool, err error) {
83 83
 	return false, nil
84 84
 }
85 85
 
86
-func (p *v2Pusher) pushV2Tag(association tag.Association) error {
86
+func (p *v2Pusher) pushV2Tag(ctx context.Context, association tag.Association) error {
87 87
 	ref := association.Ref
88 88
 	logrus.Debugf("Pushing repository: %s", ref.String())
89 89
 
... ...
@@ -92,8 +97,6 @@ func (p *v2Pusher) pushV2Tag(association tag.Association) error {
92 92
 		return fmt.Errorf("could not find image from tag %s: %v", ref.String(), err)
93 93
 	}
94 94
 
95
-	out := p.config.OutStream
96
-
97 95
 	var l layer.Layer
98 96
 
99 97
 	topLayerID := img.RootFS.ChainID()
... ...
@@ -107,33 +110,41 @@ func (p *v2Pusher) pushV2Tag(association tag.Association) error {
107 107
 		defer layer.ReleaseAndLog(p.config.LayerStore, l)
108 108
 	}
109 109
 
110
-	fsLayers := make(map[layer.DiffID]schema1.FSLayer)
110
+	var descriptors []xfer.UploadDescriptor
111 111
 
112 112
 	// Push empty layer if necessary
113 113
 	for _, h := range img.History {
114 114
 		if h.EmptyLayer {
115
-			dgst, err := p.pushLayerIfNecessary(out, layer.EmptyLayer)
116
-			if err != nil {
117
-				return err
115
+			descriptors = []xfer.UploadDescriptor{
116
+				&v2PushDescriptor{
117
+					layer:          layer.EmptyLayer,
118
+					blobSumService: p.blobSumService,
119
+					repo:           p.repo,
120
+					layersPushed:   &p.layersPushed,
121
+				},
118 122
 			}
119
-			p.layersPushed[dgst] = true
120
-			fsLayers[layer.EmptyLayer.DiffID()] = schema1.FSLayer{BlobSum: dgst}
121 123
 			break
122 124
 		}
123 125
 	}
124 126
 
127
+	// Loop bounds condition is to avoid pushing the base layer on Windows.
125 128
 	for i := 0; i < len(img.RootFS.DiffIDs); i++ {
126
-		dgst, err := p.pushLayerIfNecessary(out, l)
127
-		if err != nil {
128
-			return err
129
+		descriptor := &v2PushDescriptor{
130
+			layer:          l,
131
+			blobSumService: p.blobSumService,
132
+			repo:           p.repo,
133
+			layersPushed:   &p.layersPushed,
129 134
 		}
130
-
131
-		p.layersPushed[dgst] = true
132
-		fsLayers[l.DiffID()] = schema1.FSLayer{BlobSum: dgst}
135
+		descriptors = append(descriptors, descriptor)
133 136
 
134 137
 		l = l.Parent()
135 138
 	}
136 139
 
140
+	fsLayers, err := p.config.UploadManager.Upload(ctx, descriptors, p.config.ProgressOutput)
141
+	if err != nil {
142
+		return err
143
+	}
144
+
137 145
 	var tag string
138 146
 	if tagged, isTagged := ref.(reference.Tagged); isTagged {
139 147
 		tag = tagged.Tag()
... ...
@@ -157,59 +168,124 @@ func (p *v2Pusher) pushV2Tag(association tag.Association) error {
157 157
 		if tagged, isTagged := ref.(reference.Tagged); isTagged {
158 158
 			// NOTE: do not change this format without first changing the trust client
159 159
 			// code. This information is used to determine what was pushed and should be signed.
160
-			out.Write(p.sf.FormatStatus("", "%s: digest: %s size: %d", tagged.Tag(), manifestDigest, manifestSize))
160
+			progress.Messagef(p.config.ProgressOutput, "", "%s: digest: %s size: %d", tagged.Tag(), manifestDigest, manifestSize)
161 161
 		}
162 162
 	}
163 163
 
164
-	manSvc, err := p.repo.Manifests(context.Background())
164
+	manSvc, err := p.repo.Manifests(ctx)
165 165
 	if err != nil {
166 166
 		return err
167 167
 	}
168 168
 	return manSvc.Put(signed)
169 169
 }
170 170
 
171
-func (p *v2Pusher) pushLayerIfNecessary(out io.Writer, l layer.Layer) (digest.Digest, error) {
172
-	logrus.Debugf("Pushing layer: %s", l.DiffID())
171
+type v2PushDescriptor struct {
172
+	layer          layer.Layer
173
+	blobSumService *metadata.BlobSumService
174
+	repo           distribution.Repository
175
+	layersPushed   *pushMap
176
+}
177
+
178
+func (pd *v2PushDescriptor) Key() string {
179
+	return "v2push:" + pd.repo.Name() + " " + pd.layer.DiffID().String()
180
+}
181
+
182
+func (pd *v2PushDescriptor) ID() string {
183
+	return stringid.TruncateID(pd.layer.DiffID().String())
184
+}
185
+
186
+func (pd *v2PushDescriptor) DiffID() layer.DiffID {
187
+	return pd.layer.DiffID()
188
+}
189
+
190
+func (pd *v2PushDescriptor) Upload(ctx context.Context, progressOutput progress.Output) (digest.Digest, error) {
191
+	diffID := pd.DiffID()
192
+
193
+	logrus.Debugf("Pushing layer: %s", diffID)
173 194
 
174 195
 	// Do we have any blobsums associated with this layer's DiffID?
175
-	possibleBlobsums, err := p.blobSumService.GetBlobSums(l.DiffID())
196
+	possibleBlobsums, err := pd.blobSumService.GetBlobSums(diffID)
176 197
 	if err == nil {
177
-		dgst, exists, err := p.blobSumAlreadyExists(possibleBlobsums)
198
+		dgst, exists, err := blobSumAlreadyExists(ctx, possibleBlobsums, pd.repo, pd.layersPushed)
178 199
 		if err != nil {
179
-			out.Write(p.sf.FormatProgress(stringid.TruncateID(string(l.DiffID())), "Image push failed", nil))
180
-			return "", err
200
+			progress.Update(progressOutput, pd.ID(), "Image push failed")
201
+			return "", retryOnError(err)
181 202
 		}
182 203
 		if exists {
183
-			out.Write(p.sf.FormatProgress(stringid.TruncateID(string(l.DiffID())), "Layer already exists", nil))
204
+			progress.Update(progressOutput, pd.ID(), "Layer already exists")
184 205
 			return dgst, nil
185 206
 		}
186 207
 	}
187 208
 
188 209
 	// if digest was empty or not saved, or if blob does not exist on the remote repository,
189 210
 	// then push the blob.
190
-	pushDigest, err := p.pushV2Layer(p.repo.Blobs(context.Background()), l)
211
+	bs := pd.repo.Blobs(ctx)
212
+
213
+	// Send the layer
214
+	layerUpload, err := bs.Create(ctx)
215
+	if err != nil {
216
+		return "", retryOnError(err)
217
+	}
218
+	defer layerUpload.Close()
219
+
220
+	arch, err := pd.layer.TarStream()
191 221
 	if err != nil {
192
-		return "", err
222
+		return "", xfer.DoNotRetry{Err: err}
193 223
 	}
224
+
225
+	// don't care if this fails; best effort
226
+	size, _ := pd.layer.DiffSize()
227
+
228
+	reader := progress.NewProgressReader(ioutils.NewCancelReadCloser(ctx, arch), progressOutput, size, pd.ID(), "Pushing")
229
+	defer reader.Close()
230
+	compressedReader := compress(reader)
231
+
232
+	digester := digest.Canonical.New()
233
+	tee := io.TeeReader(compressedReader, digester.Hash())
234
+
235
+	nn, err := layerUpload.ReadFrom(tee)
236
+	compressedReader.Close()
237
+	if err != nil {
238
+		return "", retryOnError(err)
239
+	}
240
+
241
+	pushDigest := digester.Digest()
242
+	if _, err := layerUpload.Commit(ctx, distribution.Descriptor{Digest: pushDigest}); err != nil {
243
+		return "", retryOnError(err)
244
+	}
245
+
246
+	logrus.Debugf("uploaded layer %s (%s), %d bytes", diffID, pushDigest, nn)
247
+	progress.Update(progressOutput, pd.ID(), "Pushed")
248
+
194 249
 	// Cache mapping from this layer's DiffID to the blobsum
195
-	if err := p.blobSumService.Add(l.DiffID(), pushDigest); err != nil {
196
-		return "", err
250
+	if err := pd.blobSumService.Add(diffID, pushDigest); err != nil {
251
+		return "", xfer.DoNotRetry{Err: err}
197 252
 	}
198 253
 
254
+	pd.layersPushed.Lock()
255
+	pd.layersPushed.layersPushed[pushDigest] = true
256
+	pd.layersPushed.Unlock()
257
+
199 258
 	return pushDigest, nil
200 259
 }
201 260
 
202 261
 // blobSumAlreadyExists checks if the registry already know about any of the
203 262
 // blobsums passed in the "blobsums" slice. If it finds one that the registry
204 263
 // knows about, it returns the known digest and "true".
205
-func (p *v2Pusher) blobSumAlreadyExists(blobsums []digest.Digest) (digest.Digest, bool, error) {
264
+func blobSumAlreadyExists(ctx context.Context, blobsums []digest.Digest, repo distribution.Repository, layersPushed *pushMap) (digest.Digest, bool, error) {
265
+	layersPushed.Lock()
206 266
 	for _, dgst := range blobsums {
207
-		if p.layersPushed[dgst] {
267
+		if layersPushed.layersPushed[dgst] {
208 268
 			// it is already known that the push is not needed and
209 269
 			// therefore doing a stat is unnecessary
270
+			layersPushed.Unlock()
210 271
 			return dgst, true, nil
211 272
 		}
212
-		_, err := p.repo.Blobs(context.Background()).Stat(context.Background(), dgst)
273
+	}
274
+	layersPushed.Unlock()
275
+
276
+	for _, dgst := range blobsums {
277
+		_, err := repo.Blobs(ctx).Stat(ctx, dgst)
213 278
 		switch err {
214 279
 		case nil:
215 280
 			return dgst, true, nil
... ...
@@ -226,7 +302,7 @@ func (p *v2Pusher) blobSumAlreadyExists(blobsums []digest.Digest) (digest.Digest
226 226
 // FSLayer digests.
227 227
 // FIXME: This should be moved to the distribution repo, since it will also
228 228
 // be useful for converting new manifests to the old format.
229
-func CreateV2Manifest(name, tag string, img *image.Image, fsLayers map[layer.DiffID]schema1.FSLayer) (*schema1.Manifest, error) {
229
+func CreateV2Manifest(name, tag string, img *image.Image, fsLayers map[layer.DiffID]digest.Digest) (*schema1.Manifest, error) {
230 230
 	if len(img.History) == 0 {
231 231
 		return nil, errors.New("empty history when trying to create V2 manifest")
232 232
 	}
... ...
@@ -271,7 +347,7 @@ func CreateV2Manifest(name, tag string, img *image.Image, fsLayers map[layer.Dif
271 271
 		if !present {
272 272
 			return nil, fmt.Errorf("missing layer in CreateV2Manifest: %s", diffID.String())
273 273
 		}
274
-		dgst, err := digest.FromBytes([]byte(fsLayer.BlobSum.Hex() + " " + parent))
274
+		dgst, err := digest.FromBytes([]byte(fsLayer.Hex() + " " + parent))
275 275
 		if err != nil {
276 276
 			return nil, err
277 277
 		}
... ...
@@ -294,7 +370,7 @@ func CreateV2Manifest(name, tag string, img *image.Image, fsLayers map[layer.Dif
294 294
 
295 295
 		reversedIndex := len(img.History) - i - 1
296 296
 		history[reversedIndex].V1Compatibility = string(jsonBytes)
297
-		fsLayerList[reversedIndex] = fsLayer
297
+		fsLayerList[reversedIndex] = schema1.FSLayer{BlobSum: fsLayer}
298 298
 
299 299
 		parent = v1ID
300 300
 	}
... ...
@@ -315,11 +391,11 @@ func CreateV2Manifest(name, tag string, img *image.Image, fsLayers map[layer.Dif
315 315
 		return nil, fmt.Errorf("missing layer in CreateV2Manifest: %s", diffID.String())
316 316
 	}
317 317
 
318
-	dgst, err := digest.FromBytes([]byte(fsLayer.BlobSum.Hex() + " " + parent + " " + string(img.RawJSON())))
318
+	dgst, err := digest.FromBytes([]byte(fsLayer.Hex() + " " + parent + " " + string(img.RawJSON())))
319 319
 	if err != nil {
320 320
 		return nil, err
321 321
 	}
322
-	fsLayerList[0] = fsLayer
322
+	fsLayerList[0] = schema1.FSLayer{BlobSum: fsLayer}
323 323
 
324 324
 	// Top-level v1compatibility string should be a modified version of the
325 325
 	// image config.
... ...
@@ -346,66 +422,3 @@ func CreateV2Manifest(name, tag string, img *image.Image, fsLayers map[layer.Dif
346 346
 		History:      history,
347 347
 	}, nil
348 348
 }
349
-
350
-func rawJSON(value interface{}) *json.RawMessage {
351
-	jsonval, err := json.Marshal(value)
352
-	if err != nil {
353
-		return nil
354
-	}
355
-	return (*json.RawMessage)(&jsonval)
356
-}
357
-
358
-func (p *v2Pusher) pushV2Layer(bs distribution.BlobService, l layer.Layer) (digest.Digest, error) {
359
-	out := p.config.OutStream
360
-	displayID := stringid.TruncateID(string(l.DiffID()))
361
-
362
-	out.Write(p.sf.FormatProgress(displayID, "Preparing", nil))
363
-
364
-	arch, err := l.TarStream()
365
-	if err != nil {
366
-		return "", err
367
-	}
368
-	defer arch.Close()
369
-
370
-	// Send the layer
371
-	layerUpload, err := bs.Create(context.Background())
372
-	if err != nil {
373
-		return "", err
374
-	}
375
-	defer layerUpload.Close()
376
-
377
-	// don't care if this fails; best effort
378
-	size, _ := l.DiffSize()
379
-
380
-	reader := progressreader.New(progressreader.Config{
381
-		In:        ioutil.NopCloser(arch), // we'll take care of close here.
382
-		Out:       out,
383
-		Formatter: p.sf,
384
-		Size:      size,
385
-		NewLines:  false,
386
-		ID:        displayID,
387
-		Action:    "Pushing",
388
-	})
389
-
390
-	compressedReader := compress(reader)
391
-
392
-	digester := digest.Canonical.New()
393
-	tee := io.TeeReader(compressedReader, digester.Hash())
394
-
395
-	out.Write(p.sf.FormatProgress(displayID, "Pushing", nil))
396
-	nn, err := layerUpload.ReadFrom(tee)
397
-	compressedReader.Close()
398
-	if err != nil {
399
-		return "", err
400
-	}
401
-
402
-	dgst := digester.Digest()
403
-	if _, err := layerUpload.Commit(context.Background(), distribution.Descriptor{Digest: dgst}); err != nil {
404
-		return "", err
405
-	}
406
-
407
-	logrus.Debugf("uploaded layer %s (%s), %d bytes", l.DiffID(), dgst, nn)
408
-	out.Write(p.sf.FormatProgress(displayID, "Pushed", nil))
409
-
410
-	return dgst, nil
411
-}
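
v2PushDescriptor.Upload above streams the compressed layer into the registry while teeing it through a digester, so the blob digest needed for Commit is computed in a single pass over the data. A minimal sketch of that pattern follows; it is not from this commit, pushBlob is a hypothetical helper, and progress reporting and retries are omitted.

package distribution

import (
	"io"

	"github.com/docker/distribution"
	"github.com/docker/distribution/digest"
	"golang.org/x/net/context"
)

// pushBlob uploads an already-compressed stream and commits it under the
// digest computed while uploading.
func pushBlob(ctx context.Context, bs distribution.BlobService, compressed io.Reader) (digest.Digest, error) {
	layerUpload, err := bs.Create(ctx)
	if err != nil {
		return "", err
	}
	defer layerUpload.Close()

	digester := digest.Canonical.New()
	// Everything the registry reads is also hashed locally.
	if _, err := layerUpload.ReadFrom(io.TeeReader(compressed, digester.Hash())); err != nil {
		return "", err
	}

	dgst := digester.Digest()
	if _, err := layerUpload.Commit(ctx, distribution.Descriptor{Digest: dgst}); err != nil {
		return "", err
	}
	return dgst, nil
}
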
... ...
@@ -116,10 +116,10 @@ func TestCreateV2Manifest(t *testing.T) {
116 116
 		t.Fatalf("json decoding failed: %v", err)
117 117
 	}
118 118
 
119
-	fsLayers := map[layer.DiffID]schema1.FSLayer{
120
-		layer.DiffID("sha256:c6f988f4874bb0add23a778f753c65efe992244e148a1d2ec2a8b664fb66bbd1"): {BlobSum: digest.Digest("sha256:a3ed95caeb02ffe68cdd9fd84406680ae93d633cb16422d00e8a7c22955b46d4")},
121
-		layer.DiffID("sha256:5f70bf18a086007016e948b04aed3b82103a36bea41755b6cddfaf10ace3c6ef"): {BlobSum: digest.Digest("sha256:86e0e091d0da6bde2456dbb48306f3956bbeb2eae1b5b9a43045843f69fe4aaa")},
122
-		layer.DiffID("sha256:13f53e08df5a220ab6d13c58b2bf83a59cbdc2e04d0a3f041ddf4b0ba4112d49"): {BlobSum: digest.Digest("sha256:b4ed95caeb02ffe68cdd9fd84406680ae93d633cb16422d00e8a7c22955b46d4")},
119
+	fsLayers := map[layer.DiffID]digest.Digest{
120
+		layer.DiffID("sha256:c6f988f4874bb0add23a778f753c65efe992244e148a1d2ec2a8b664fb66bbd1"): digest.Digest("sha256:a3ed95caeb02ffe68cdd9fd84406680ae93d633cb16422d00e8a7c22955b46d4"),
121
+		layer.DiffID("sha256:5f70bf18a086007016e948b04aed3b82103a36bea41755b6cddfaf10ace3c6ef"): digest.Digest("sha256:86e0e091d0da6bde2456dbb48306f3956bbeb2eae1b5b9a43045843f69fe4aaa"),
122
+		layer.DiffID("sha256:13f53e08df5a220ab6d13c58b2bf83a59cbdc2e04d0a3f041ddf4b0ba4112d49"): digest.Digest("sha256:b4ed95caeb02ffe68cdd9fd84406680ae93d633cb16422d00e8a7c22955b46d4"),
123 123
 	}
124 124
 
125 125
 	manifest, err := CreateV2Manifest("testrepo", "testtag", img, fsLayers)
... ...
@@ -13,10 +13,12 @@ import (
13 13
 	"github.com/docker/distribution"
14 14
 	"github.com/docker/distribution/digest"
15 15
 	"github.com/docker/distribution/manifest/schema1"
16
+	"github.com/docker/distribution/registry/api/errcode"
16 17
 	"github.com/docker/distribution/registry/client"
17 18
 	"github.com/docker/distribution/registry/client/auth"
18 19
 	"github.com/docker/distribution/registry/client/transport"
19 20
 	"github.com/docker/docker/cliconfig"
21
+	"github.com/docker/docker/distribution/xfer"
20 22
 	"github.com/docker/docker/registry"
21 23
 	"golang.org/x/net/context"
22 24
 )
... ...
@@ -59,7 +61,7 @@ func NewV2Repository(repoInfo *registry.RepositoryInfo, endpoint registry.APIEnd
59 59
 	authTransport := transport.NewTransport(base, modifiers...)
60 60
 	pingClient := &http.Client{
61 61
 		Transport: authTransport,
62
-		Timeout:   5 * time.Second,
62
+		Timeout:   15 * time.Second,
63 63
 	}
64 64
 	endpointStr := strings.TrimRight(endpoint.URL, "/") + "/v2/"
65 65
 	req, err := http.NewRequest("GET", endpointStr, nil)
... ...
@@ -132,3 +134,23 @@ func (th *existingTokenHandler) AuthorizeRequest(req *http.Request, params map[s
132 132
 	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", th.token))
133 133
 	return nil
134 134
 }
135
+
136
+// retryOnError wraps the error in xfer.DoNotRetry if we should not retry the
137
+// operation after this error.
138
+func retryOnError(err error) error {
139
+	switch v := err.(type) {
140
+	case errcode.Errors:
141
+		return retryOnError(v[0])
142
+	case errcode.Error:
143
+		switch v.Code {
144
+		case errcode.ErrorCodeUnauthorized, errcode.ErrorCodeUnsupported, errcode.ErrorCodeDenied:
145
+			return xfer.DoNotRetry{Err: err}
146
+		}
147
+
148
+	}
149
+	// let's be nice and fall back if the error is a completely
150
+	// unexpected one.
151
+	// If new errors have to be handled in some way, please
152
+	// add them to the switch above.
153
+	return err
154
+}
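
retryOnError feeds into the backoff logic in the new distribution/xfer package (shown later in this diff): any error wrapped in xfer.DoNotRetry aborts a transfer immediately, while everything else is retried up to the attempt limit. A simplified sketch of the consuming side follows; it is not from this commit, attemptWithRetry is a hypothetical helper, and the progress countdown the real download manager prints is omitted.

package distribution

import "github.com/docker/docker/distribution/xfer"

// attemptWithRetry runs op until it succeeds, fails permanently, or the
// attempt limit is reached.
func attemptWithRetry(op func() error, maxAttempts int) error {
	var err error
	for attempt := 1; attempt <= maxAttempts; attempt++ {
		if err = op(); err == nil {
			return nil
		}
		if _, isDNR := err.(xfer.DoNotRetry); isDNR {
			// Authorization and similar errors are not transient.
			return err
		}
	}
	return err
}
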
... ...
@@ -11,9 +11,9 @@ import (
11 11
 	"github.com/docker/distribution/reference"
12 12
 	"github.com/docker/distribution/registry/client/auth"
13 13
 	"github.com/docker/docker/cliconfig"
14
-	"github.com/docker/docker/pkg/streamformatter"
15 14
 	"github.com/docker/docker/registry"
16 15
 	"github.com/docker/docker/utils"
16
+	"golang.org/x/net/context"
17 17
 )
18 18
 
19 19
 func TestTokenPassThru(t *testing.T) {
... ...
@@ -72,8 +72,7 @@ func TestTokenPassThru(t *testing.T) {
72 72
 		MetaHeaders: http.Header{},
73 73
 		AuthConfig:  authConfig,
74 74
 	}
75
-	sf := streamformatter.NewJSONStreamFormatter()
76
-	puller, err := newPuller(endpoint, repoInfo, imagePullConfig, sf)
75
+	puller, err := newPuller(endpoint, repoInfo, imagePullConfig)
77 76
 	if err != nil {
78 77
 		t.Fatal(err)
79 78
 	}
... ...
@@ -86,7 +85,7 @@ func TestTokenPassThru(t *testing.T) {
86 86
 	logrus.Debug("About to pull")
87 87
 	// We expect it to fail, since we haven't mock'd the full registry exchange in our handler above
88 88
 	tag, _ := reference.WithTag(n, "tag_goes_here")
89
-	_ = p.pullV2Repository(tag)
89
+	_ = p.pullV2Repository(context.Background(), tag)
90 90
 
91 91
 	if !gotToken {
92 92
 		t.Fatal("Failed to receive registry token")
93 93
new file mode 100644
... ...
@@ -0,0 +1,420 @@
0
+package xfer
1
+
2
+import (
3
+	"errors"
4
+	"fmt"
5
+	"io"
6
+	"time"
7
+
8
+	"github.com/Sirupsen/logrus"
9
+	"github.com/docker/docker/image"
10
+	"github.com/docker/docker/layer"
11
+	"github.com/docker/docker/pkg/archive"
12
+	"github.com/docker/docker/pkg/ioutils"
13
+	"github.com/docker/docker/pkg/progress"
14
+	"golang.org/x/net/context"
15
+)
16
+
17
+const maxDownloadAttempts = 5
18
+
19
+// LayerDownloadManager figures out which layers need to be downloaded, then
20
+// registers and downloads those, taking into account dependencies between
21
+// layers.
22
+type LayerDownloadManager struct {
23
+	layerStore layer.Store
24
+	tm         TransferManager
25
+}
26
+
27
+// NewLayerDownloadManager returns a new LayerDownloadManager.
28
+func NewLayerDownloadManager(layerStore layer.Store, concurrencyLimit int) *LayerDownloadManager {
29
+	return &LayerDownloadManager{
30
+		layerStore: layerStore,
31
+		tm:         NewTransferManager(concurrencyLimit),
32
+	}
33
+}
34
+
35
+type downloadTransfer struct {
36
+	Transfer
37
+
38
+	layerStore layer.Store
39
+	layer      layer.Layer
40
+	err        error
41
+}
42
+
43
+// result returns the layer resulting from the download, if the download
44
+// and registration were successful.
45
+func (d *downloadTransfer) result() (layer.Layer, error) {
46
+	return d.layer, d.err
47
+}
48
+
49
+// A DownloadDescriptor references a layer that may need to be downloaded.
50
+type DownloadDescriptor interface {
51
+	// Key returns the key used to deduplicate downloads.
52
+	Key() string
53
+	// ID returns the ID for display purposes.
54
+	ID() string
55
+	// DiffID should return the DiffID for this layer, or an error
56
+	// if it is unknown (for example, if it has not been downloaded
57
+	// before).
58
+	DiffID() (layer.DiffID, error)
59
+	// Download is called to perform the download.
60
+	Download(ctx context.Context, progressOutput progress.Output) (io.ReadCloser, int64, error)
61
+}
62
+
63
+// DownloadDescriptorWithRegistered is a DownloadDescriptor that has an
64
+// additional Registered method which gets called after a downloaded layer is
65
+// registered. This allows the user of the download manager to know the DiffID
66
+// of each registered layer. This method is called if a cast to
67
+// DownloadDescriptorWithRegistered is successful.
68
+type DownloadDescriptorWithRegistered interface {
69
+	DownloadDescriptor
70
+	Registered(diffID layer.DiffID)
71
+}
72
+
73
+// Download is a blocking function which ensures the requested layers are
74
+// present in the layer store. It uses the string returned by the Key method to
75
+// deduplicate downloads. If a given layer is not already known to be present in
76
+// the layer store, and the key is not used by an in-progress download, the
77
+// Download method is called to get the layer tar data. Layers are then
78
+// registered in the appropriate order.  The caller must call the returned
79
+// release function once it is done with the returned RootFS object.
80
+func (ldm *LayerDownloadManager) Download(ctx context.Context, initialRootFS image.RootFS, layers []DownloadDescriptor, progressOutput progress.Output) (image.RootFS, func(), error) {
81
+	var (
82
+		topLayer       layer.Layer
83
+		topDownload    *downloadTransfer
84
+		watcher        *Watcher
85
+		missingLayer   bool
86
+		transferKey    = ""
87
+		downloadsByKey = make(map[string]*downloadTransfer)
88
+	)
89
+
90
+	rootFS := initialRootFS
91
+	for _, descriptor := range layers {
92
+		key := descriptor.Key()
93
+		transferKey += key
94
+
95
+		if !missingLayer {
96
+			missingLayer = true
97
+			diffID, err := descriptor.DiffID()
98
+			if err == nil {
99
+				getRootFS := rootFS
100
+				getRootFS.Append(diffID)
101
+				l, err := ldm.layerStore.Get(getRootFS.ChainID())
102
+				if err == nil {
103
+					// Layer already exists.
104
+					logrus.Debugf("Layer already exists: %s", descriptor.ID())
105
+					progress.Update(progressOutput, descriptor.ID(), "Already exists")
106
+					if topLayer != nil {
107
+						layer.ReleaseAndLog(ldm.layerStore, topLayer)
108
+					}
109
+					topLayer = l
110
+					missingLayer = false
111
+					rootFS.Append(diffID)
112
+					continue
113
+				}
114
+			}
115
+		}
116
+
117
+		// Does this layer have the same data as a previous layer in
118
+		// the stack? If so, avoid downloading it more than once.
119
+		var topDownloadUncasted Transfer
120
+		if existingDownload, ok := downloadsByKey[key]; ok {
121
+			xferFunc := ldm.makeDownloadFuncFromDownload(descriptor, existingDownload, topDownload)
122
+			defer topDownload.Transfer.Release(watcher)
123
+			topDownloadUncasted, watcher = ldm.tm.Transfer(transferKey, xferFunc, progressOutput)
124
+			topDownload = topDownloadUncasted.(*downloadTransfer)
125
+			continue
126
+		}
127
+
128
+		// Layer is not known to exist - download and register it.
129
+		progress.Update(progressOutput, descriptor.ID(), "Pulling fs layer")
130
+
131
+		var xferFunc DoFunc
132
+		if topDownload != nil {
133
+			xferFunc = ldm.makeDownloadFunc(descriptor, "", topDownload)
134
+			defer topDownload.Transfer.Release(watcher)
135
+		} else {
136
+			xferFunc = ldm.makeDownloadFunc(descriptor, rootFS.ChainID(), nil)
137
+		}
138
+		topDownloadUncasted, watcher = ldm.tm.Transfer(transferKey, xferFunc, progressOutput)
139
+		topDownload = topDownloadUncasted.(*downloadTransfer)
140
+		downloadsByKey[key] = topDownload
141
+	}
142
+
143
+	if topDownload == nil {
144
+		return rootFS, func() { layer.ReleaseAndLog(ldm.layerStore, topLayer) }, nil
145
+	}
146
+
147
+	// Won't be using the list built up so far - will generate it
148
+	// from downloaded layers instead.
149
+	rootFS.DiffIDs = []layer.DiffID{}
150
+
151
+	defer func() {
152
+		if topLayer != nil {
153
+			layer.ReleaseAndLog(ldm.layerStore, topLayer)
154
+		}
155
+	}()
156
+
157
+	select {
158
+	case <-ctx.Done():
159
+		topDownload.Transfer.Release(watcher)
160
+		return rootFS, func() {}, ctx.Err()
161
+	case <-topDownload.Done():
162
+		break
163
+	}
164
+
165
+	l, err := topDownload.result()
166
+	if err != nil {
167
+		topDownload.Transfer.Release(watcher)
168
+		return rootFS, func() {}, err
169
+	}
170
+
171
+	// Must do this exactly len(layers) times, so we don't include the
172
+	// base layer on Windows.
173
+	for range layers {
174
+		if l == nil {
175
+			topDownload.Transfer.Release(watcher)
176
+			return rootFS, func() {}, errors.New("internal error: too few parent layers")
177
+		}
178
+		rootFS.DiffIDs = append([]layer.DiffID{l.DiffID()}, rootFS.DiffIDs...)
179
+		l = l.Parent()
180
+	}
181
+	return rootFS, func() { topDownload.Transfer.Release(watcher) }, err
182
+}
183
+
184
+// makeDownloadFunc returns a function that performs the layer download and
185
+// registration. If parentDownload is non-nil, it waits for that download to
186
+// complete before the registration step, and registers the downloaded data
187
+// on top of parentDownload's resulting layer. Otherwise, it registers the
188
+// layer on top of the ChainID given by parentLayer.
189
+func (ldm *LayerDownloadManager) makeDownloadFunc(descriptor DownloadDescriptor, parentLayer layer.ChainID, parentDownload *downloadTransfer) DoFunc {
190
+	return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer {
191
+		d := &downloadTransfer{
192
+			Transfer:   NewTransfer(),
193
+			layerStore: ldm.layerStore,
194
+		}
195
+
196
+		go func() {
197
+			defer func() {
198
+				close(progressChan)
199
+			}()
200
+
201
+			progressOutput := progress.ChanOutput(progressChan)
202
+
203
+			select {
204
+			case <-start:
205
+			default:
206
+				progress.Update(progressOutput, descriptor.ID(), "Waiting")
207
+				<-start
208
+			}
209
+
210
+			if parentDownload != nil {
211
+				// Did the parent download already fail or get
212
+				// cancelled?
213
+				select {
214
+				case <-parentDownload.Done():
215
+					_, err := parentDownload.result()
216
+					if err != nil {
217
+						d.err = err
218
+						return
219
+					}
220
+				default:
221
+				}
222
+			}
223
+
224
+			var (
225
+				downloadReader io.ReadCloser
226
+				size           int64
227
+				err            error
228
+				retries        int
229
+			)
230
+
231
+			for {
232
+				downloadReader, size, err = descriptor.Download(d.Transfer.Context(), progressOutput)
233
+				if err == nil {
234
+					break
235
+				}
236
+
237
+				// If an error was returned because the context
238
+				// was cancelled, we shouldn't retry.
239
+				select {
240
+				case <-d.Transfer.Context().Done():
241
+					d.err = err
242
+					return
243
+				default:
244
+				}
245
+
246
+				retries++
247
+				if _, isDNR := err.(DoNotRetry); isDNR || retries == maxDownloadAttempts {
248
+					logrus.Errorf("Download failed: %v", err)
249
+					d.err = err
250
+					return
251
+				}
252
+
253
+				logrus.Errorf("Download failed, retrying: %v", err)
254
+				delay := retries * 5
255
+				ticker := time.NewTicker(time.Second)
256
+
257
+			selectLoop:
258
+				for {
259
+					progress.Updatef(progressOutput, descriptor.ID(), "Retrying in %d seconds", delay)
260
+					select {
261
+					case <-ticker.C:
262
+						delay--
263
+						if delay == 0 {
264
+							ticker.Stop()
265
+							break selectLoop
266
+						}
267
+					case <-d.Transfer.Context().Done():
268
+						ticker.Stop()
269
+						d.err = errors.New("download cancelled during retry delay")
270
+						return
271
+					}
272
+
273
+				}
274
+			}
275
+
276
+			close(inactive)
277
+
278
+			if parentDownload != nil {
279
+				select {
280
+				case <-d.Transfer.Context().Done():
281
+					d.err = errors.New("layer registration cancelled")
282
+					downloadReader.Close()
283
+					return
284
+				case <-parentDownload.Done():
285
+				}
286
+
287
+				l, err := parentDownload.result()
288
+				if err != nil {
289
+					d.err = err
290
+					downloadReader.Close()
291
+					return
292
+				}
293
+				parentLayer = l.ChainID()
294
+			}
295
+
296
+			reader := progress.NewProgressReader(ioutils.NewCancelReadCloser(d.Transfer.Context(), downloadReader), progressOutput, size, descriptor.ID(), "Extracting")
297
+			defer reader.Close()
298
+
299
+			inflatedLayerData, err := archive.DecompressStream(reader)
300
+			if err != nil {
301
+				d.err = fmt.Errorf("could not get decompression stream: %v", err)
302
+				return
303
+			}
304
+
305
+			d.layer, err = d.layerStore.Register(inflatedLayerData, parentLayer)
306
+			if err != nil {
307
+				select {
308
+				case <-d.Transfer.Context().Done():
309
+					d.err = errors.New("layer registration cancelled")
310
+				default:
311
+					d.err = fmt.Errorf("failed to register layer: %v", err)
312
+				}
313
+				return
314
+			}
315
+
316
+			progress.Update(progressOutput, descriptor.ID(), "Pull complete")
317
+			withRegistered, hasRegistered := descriptor.(DownloadDescriptorWithRegistered)
318
+			if hasRegistered {
319
+				withRegistered.Registered(d.layer.DiffID())
320
+			}
321
+
322
+			// Doesn't actually need to be its own goroutine, but
323
+			// done like this so we can defer close(c).
324
+			go func() {
325
+				<-d.Transfer.Released()
326
+				if d.layer != nil {
327
+					layer.ReleaseAndLog(d.layerStore, d.layer)
328
+				}
329
+			}()
330
+		}()
331
+
332
+		return d
333
+	}
334
+}
335
+
336
+// makeDownloadFuncFromDownload returns a function that performs the layer
337
+// registration when the layer data is coming from an existing download. It
338
+// waits for sourceDownload and parentDownload to complete, and then
339
+// reregisters the data from sourceDownload's top layer on top of
340
+// parentDownload. This function does not log progress output because it would
341
+// interfere with the progress reporting for sourceDownload, which has the same
342
+// Key.
343
+func (ldm *LayerDownloadManager) makeDownloadFuncFromDownload(descriptor DownloadDescriptor, sourceDownload *downloadTransfer, parentDownload *downloadTransfer) DoFunc {
344
+	return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer {
345
+		d := &downloadTransfer{
346
+			Transfer:   NewTransfer(),
347
+			layerStore: ldm.layerStore,
348
+		}
349
+
350
+		go func() {
351
+			defer func() {
352
+				close(progressChan)
353
+			}()
354
+
355
+			<-start
356
+
357
+			close(inactive)
358
+
359
+			select {
360
+			case <-d.Transfer.Context().Done():
361
+				d.err = errors.New("layer registration cancelled")
362
+				return
363
+			case <-parentDownload.Done():
364
+			}
365
+
366
+			l, err := parentDownload.result()
367
+			if err != nil {
368
+				d.err = err
369
+				return
370
+			}
371
+			parentLayer := l.ChainID()
372
+
373
+			// sourceDownload should have already finished if
374
+			// parentDownload finished, but wait for it explicitly
375
+			// to be sure.
376
+			select {
377
+			case <-d.Transfer.Context().Done():
378
+				d.err = errors.New("layer registration cancelled")
379
+				return
380
+			case <-sourceDownload.Done():
381
+			}
382
+
383
+			l, err = sourceDownload.result()
384
+			if err != nil {
385
+				d.err = err
386
+				return
387
+			}
388
+
389
+			layerReader, err := l.TarStream()
390
+			if err != nil {
391
+				d.err = err
392
+				return
393
+			}
394
+			defer layerReader.Close()
395
+
396
+			d.layer, err = d.layerStore.Register(layerReader, parentLayer)
397
+			if err != nil {
398
+				d.err = fmt.Errorf("failed to register layer: %v", err)
399
+				return
400
+			}
401
+
402
+			withRegistered, hasRegistered := descriptor.(DownloadDescriptorWithRegistered)
403
+			if hasRegistered {
404
+				withRegistered.Registered(d.layer.DiffID())
405
+			}
406
+
407
+			// Release the layer reference once all watchers have released
408
+			// the transfer. Runs in its own goroutine so this one can finish.
409
+			go func() {
410
+				<-d.Transfer.Released()
411
+				if d.layer != nil {
412
+					layer.ReleaseAndLog(d.layerStore, d.layer)
413
+				}
414
+			}()
415
+		}()
416
+
417
+		return d
418
+	}
419
+}
0 420
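
To make the flow above concrete, here is a minimal sketch (not part of the diff) of how pull code might drive the download manager. The pullLayers helper, the layer.Store parameter, and the concurrency limit of 3 are assumptions for illustration; the Download call pattern follows the code and tests in this commit.

package example // illustrative only

import (
	"fmt"

	"github.com/docker/docker/distribution/xfer"
	"github.com/docker/docker/image"
	"github.com/docker/docker/layer"
	"github.com/docker/docker/pkg/progress"
	"golang.org/x/net/context"
)

// pullLayers is a hypothetical helper showing the call pattern.
func pullLayers(ctx context.Context, store layer.Store, descriptors []xfer.DownloadDescriptor) error {
	// All transfers multiplex their progress events onto one channel.
	progressChan := make(chan progress.Progress)
	progressDone := make(chan struct{})
	go func() {
		for p := range progressChan {
			fmt.Printf("%s: %s %d/%d\n", p.ID, p.Action, p.Current, p.Total)
		}
		close(progressDone)
	}()

	ldm := xfer.NewLayerDownloadManager(store, 3) // at most 3 downloads run at once

	rootFS, release, err := ldm.Download(ctx, *image.NewRootFS(), descriptors, progress.ChanOutput(progressChan))
	close(progressChan)
	<-progressDone
	if err != nil {
		return err
	}
	// release drops the references taken on the registered layers; call it
	// once the image that uses them has been created.
	defer release()

	fmt.Println("downloaded diff IDs:", rootFS.DiffIDs)
	return nil
}
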
new file mode 100644
... ...
@@ -0,0 +1,332 @@
0
+package xfer
1
+
2
+import (
3
+	"bytes"
4
+	"errors"
5
+	"io"
6
+	"io/ioutil"
7
+	"sync/atomic"
8
+	"testing"
9
+	"time"
10
+
11
+	"github.com/docker/distribution/digest"
12
+	"github.com/docker/docker/image"
13
+	"github.com/docker/docker/layer"
14
+	"github.com/docker/docker/pkg/archive"
15
+	"github.com/docker/docker/pkg/progress"
16
+	"golang.org/x/net/context"
17
+)
18
+
19
+const maxDownloadConcurrency = 3
20
+
21
+type mockLayer struct {
22
+	layerData bytes.Buffer
23
+	diffID    layer.DiffID
24
+	chainID   layer.ChainID
25
+	parent    layer.Layer
26
+}
27
+
28
+func (ml *mockLayer) TarStream() (io.ReadCloser, error) {
29
+	return ioutil.NopCloser(bytes.NewBuffer(ml.layerData.Bytes())), nil
30
+}
31
+
32
+func (ml *mockLayer) ChainID() layer.ChainID {
33
+	return ml.chainID
34
+}
35
+
36
+func (ml *mockLayer) DiffID() layer.DiffID {
37
+	return ml.diffID
38
+}
39
+
40
+func (ml *mockLayer) Parent() layer.Layer {
41
+	return ml.parent
42
+}
43
+
44
+func (ml *mockLayer) Size() (size int64, err error) {
45
+	return 0, nil
46
+}
47
+
48
+func (ml *mockLayer) DiffSize() (size int64, err error) {
49
+	return 0, nil
50
+}
51
+
52
+func (ml *mockLayer) Metadata() (map[string]string, error) {
53
+	return make(map[string]string), nil
54
+}
55
+
56
+type mockLayerStore struct {
57
+	layers map[layer.ChainID]*mockLayer
58
+}
59
+
60
+func createChainIDFromParent(parent layer.ChainID, dgsts ...layer.DiffID) layer.ChainID {
61
+	if len(dgsts) == 0 {
62
+		return parent
63
+	}
64
+	if parent == "" {
65
+		return createChainIDFromParent(layer.ChainID(dgsts[0]), dgsts[1:]...)
66
+	}
67
+	// H = "H(n-1) SHA256(n)"
68
+	dgst, err := digest.FromBytes([]byte(string(parent) + " " + string(dgsts[0])))
69
+	if err != nil {
70
+		// Digest calculation is not expected to throw an error,
71
+		// any error at this point is a program error
72
+		panic(err)
73
+	}
74
+	return createChainIDFromParent(layer.ChainID(dgst), dgsts[1:]...)
75
+}
76
+
77
+func (ls *mockLayerStore) Register(reader io.Reader, parentID layer.ChainID) (layer.Layer, error) {
78
+	var (
79
+		parent layer.Layer
80
+		err    error
81
+	)
82
+
83
+	if parentID != "" {
84
+		parent, err = ls.Get(parentID)
85
+		if err != nil {
86
+			return nil, err
87
+		}
88
+	}
89
+
90
+	l := &mockLayer{parent: parent}
91
+	_, err = l.layerData.ReadFrom(reader)
92
+	if err != nil {
93
+		return nil, err
94
+	}
95
+	diffID, err := digest.FromBytes(l.layerData.Bytes())
96
+	if err != nil {
97
+		return nil, err
98
+	}
99
+	l.diffID = layer.DiffID(diffID)
100
+	l.chainID = createChainIDFromParent(parentID, l.diffID)
101
+
102
+	ls.layers[l.chainID] = l
103
+	return l, nil
104
+}
105
+
106
+func (ls *mockLayerStore) Get(chainID layer.ChainID) (layer.Layer, error) {
107
+	l, ok := ls.layers[chainID]
108
+	if !ok {
109
+		return nil, layer.ErrLayerDoesNotExist
110
+	}
111
+	return l, nil
112
+}
113
+
114
+func (ls *mockLayerStore) Release(l layer.Layer) ([]layer.Metadata, error) {
115
+	return []layer.Metadata{}, nil
116
+}
117
+
118
+func (ls *mockLayerStore) Mount(id string, parent layer.ChainID, label string, init layer.MountInit) (layer.RWLayer, error) {
119
+	return nil, errors.New("not implemented")
120
+}
121
+
122
+func (ls *mockLayerStore) Unmount(id string) error {
123
+	return errors.New("not implemented")
124
+}
125
+
126
+func (ls *mockLayerStore) DeleteMount(id string) ([]layer.Metadata, error) {
127
+	return nil, errors.New("not implemented")
128
+}
129
+
130
+func (ls *mockLayerStore) Changes(id string) ([]archive.Change, error) {
131
+	return nil, errors.New("not implemented")
132
+}
133
+
134
+type mockDownloadDescriptor struct {
135
+	currentDownloads *int32
136
+	id               string
137
+	diffID           layer.DiffID
138
+	registeredDiffID layer.DiffID
139
+	expectedDiffID   layer.DiffID
140
+	simulateRetries  int
141
+}
142
+
143
+// Key returns the key used to deduplicate downloads.
144
+func (d *mockDownloadDescriptor) Key() string {
145
+	return d.id
146
+}
147
+
148
+// ID returns the ID for display purposes.
149
+func (d *mockDownloadDescriptor) ID() string {
150
+	return d.id
151
+}
152
+
153
+// DiffID should return the DiffID for this layer, or an error
154
+// if it is unknown (for example, if it has not been downloaded
155
+// before).
156
+func (d *mockDownloadDescriptor) DiffID() (layer.DiffID, error) {
157
+	if d.diffID != "" {
158
+		return d.diffID, nil
159
+	}
160
+	return "", errors.New("no diffID available")
161
+}
162
+
163
+func (d *mockDownloadDescriptor) Registered(diffID layer.DiffID) {
164
+	d.registeredDiffID = diffID
165
+}
166
+
167
+func (d *mockDownloadDescriptor) mockTarStream() io.ReadCloser {
168
+	// The mock implementation returns the ID repeated 5 times as a tar
169
+	// stream instead of actual tar data. The data is ignored except for
170
+	// computing IDs.
171
+	return ioutil.NopCloser(bytes.NewBuffer([]byte(d.id + d.id + d.id + d.id + d.id)))
172
+}
173
+
174
+// Download is called to perform the download.
175
+func (d *mockDownloadDescriptor) Download(ctx context.Context, progressOutput progress.Output) (io.ReadCloser, int64, error) {
176
+	if d.currentDownloads != nil {
177
+		defer atomic.AddInt32(d.currentDownloads, -1)
178
+
179
+		if atomic.AddInt32(d.currentDownloads, 1) > maxDownloadConcurrency {
180
+			return nil, 0, errors.New("concurrency limit exceeded")
181
+		}
182
+	}
183
+
184
+	// Sleep a bit to simulate a time-consuming download.
185
+	for i := int64(0); i <= 10; i++ {
186
+		select {
187
+		case <-ctx.Done():
188
+			return nil, 0, ctx.Err()
189
+		case <-time.After(10 * time.Millisecond):
190
+			progressOutput.WriteProgress(progress.Progress{ID: d.ID(), Action: "Downloading", Current: i, Total: 10})
191
+		}
192
+	}
193
+
194
+	if d.simulateRetries != 0 {
195
+		d.simulateRetries--
196
+		return nil, 0, errors.New("simulating retry")
197
+	}
198
+
199
+	return d.mockTarStream(), 0, nil
200
+}
201
+
202
+func downloadDescriptors(currentDownloads *int32) []DownloadDescriptor {
203
+	return []DownloadDescriptor{
204
+		&mockDownloadDescriptor{
205
+			currentDownloads: currentDownloads,
206
+			id:               "id1",
207
+			expectedDiffID:   layer.DiffID("sha256:68e2c75dc5c78ea9240689c60d7599766c213ae210434c53af18470ae8c53ec1"),
208
+		},
209
+		&mockDownloadDescriptor{
210
+			currentDownloads: currentDownloads,
211
+			id:               "id2",
212
+			expectedDiffID:   layer.DiffID("sha256:64a636223116aa837973a5d9c2bdd17d9b204e4f95ac423e20e65dfbb3655473"),
213
+		},
214
+		&mockDownloadDescriptor{
215
+			currentDownloads: currentDownloads,
216
+			id:               "id3",
217
+			expectedDiffID:   layer.DiffID("sha256:58745a8bbd669c25213e9de578c4da5c8ee1c836b3581432c2b50e38a6753300"),
218
+		},
219
+		&mockDownloadDescriptor{
220
+			currentDownloads: currentDownloads,
221
+			id:               "id2",
222
+			expectedDiffID:   layer.DiffID("sha256:64a636223116aa837973a5d9c2bdd17d9b204e4f95ac423e20e65dfbb3655473"),
223
+		},
224
+		&mockDownloadDescriptor{
225
+			currentDownloads: currentDownloads,
226
+			id:               "id4",
227
+			expectedDiffID:   layer.DiffID("sha256:0dfb5b9577716cc173e95af7c10289322c29a6453a1718addc00c0c5b1330936"),
228
+			simulateRetries:  1,
229
+		},
230
+		&mockDownloadDescriptor{
231
+			currentDownloads: currentDownloads,
232
+			id:               "id5",
233
+			expectedDiffID:   layer.DiffID("sha256:0a5f25fa1acbc647f6112a6276735d0fa01e4ee2aa7ec33015e337350e1ea23d"),
234
+		},
235
+	}
236
+}
237
+
238
+func TestSuccessfulDownload(t *testing.T) {
239
+	layerStore := &mockLayerStore{make(map[layer.ChainID]*mockLayer)}
240
+	ldm := NewLayerDownloadManager(layerStore, maxDownloadConcurrency)
241
+
242
+	progressChan := make(chan progress.Progress)
243
+	progressDone := make(chan struct{})
244
+	receivedProgress := make(map[string]int64)
245
+
246
+	go func() {
247
+		for p := range progressChan {
248
+			if p.Action == "Downloading" {
249
+				receivedProgress[p.ID] = p.Current
250
+			} else if p.Action == "Already exists" {
251
+				receivedProgress[p.ID] = -1
252
+			}
253
+		}
254
+		close(progressDone)
255
+	}()
256
+
257
+	var currentDownloads int32
258
+	descriptors := downloadDescriptors(&currentDownloads)
259
+
260
+	firstDescriptor := descriptors[0].(*mockDownloadDescriptor)
261
+
262
+	// Pre-register the first layer to simulate an already-existing layer
263
+	l, err := layerStore.Register(firstDescriptor.mockTarStream(), "")
264
+	if err != nil {
265
+		t.Fatal(err)
266
+	}
267
+	firstDescriptor.diffID = l.DiffID()
268
+
269
+	rootFS, releaseFunc, err := ldm.Download(context.Background(), *image.NewRootFS(), descriptors, progress.ChanOutput(progressChan))
270
+	if err != nil {
271
+		t.Fatalf("download error: %v", err)
272
+	}
273
+
274
+	releaseFunc()
275
+
276
+	close(progressChan)
277
+	<-progressDone
278
+
279
+	if len(rootFS.DiffIDs) != len(descriptors) {
280
+		t.Fatal("got wrong number of diffIDs in rootfs")
281
+	}
282
+
283
+	for i, d := range descriptors {
284
+		descriptor := d.(*mockDownloadDescriptor)
285
+
286
+		if descriptor.diffID != "" {
287
+			if receivedProgress[d.ID()] != -1 {
288
+				t.Fatalf("did not get 'already exists' message for %v", d.ID())
289
+			}
290
+		} else if receivedProgress[d.ID()] != 10 {
291
+			t.Fatalf("missing or wrong progress output for %v (got: %d)", d.ID(), receivedProgress[d.ID()])
292
+		}
293
+
294
+		if rootFS.DiffIDs[i] != descriptor.expectedDiffID {
295
+			t.Fatalf("rootFS item %d has the wrong diffID (expected: %v got: %v)", i, descriptor.expectedDiffID, rootFS.DiffIDs[i])
296
+		}
297
+
298
+		if descriptor.diffID == "" && descriptor.registeredDiffID != rootFS.DiffIDs[i] {
299
+			t.Fatal("diffID mismatch between rootFS and Registered callback")
300
+		}
301
+	}
302
+}
303
+
304
+func TestCancelledDownload(t *testing.T) {
305
+	ldm := NewLayerDownloadManager(&mockLayerStore{make(map[layer.ChainID]*mockLayer)}, maxDownloadConcurrency)
306
+
307
+	progressChan := make(chan progress.Progress)
308
+	progressDone := make(chan struct{})
309
+
310
+	go func() {
311
+		for range progressChan {
312
+		}
313
+		close(progressDone)
314
+	}()
315
+
316
+	ctx, cancel := context.WithCancel(context.Background())
317
+
318
+	go func() {
319
+		<-time.After(time.Millisecond)
320
+		cancel()
321
+	}()
322
+
323
+	descriptors := downloadDescriptors(nil)
324
+	_, _, err := ldm.Download(ctx, *image.NewRootFS(), descriptors, progress.ChanOutput(progressChan))
325
+	if err != context.Canceled {
326
+		t.Fatal("expected download to be cancelled")
327
+	}
328
+
329
+	close(progressChan)
330
+	<-progressDone
331
+}
0 332
new file mode 100644
... ...
@@ -0,0 +1,343 @@
0
+package xfer
1
+
2
+import (
3
+	"sync"
4
+
5
+	"github.com/docker/docker/pkg/progress"
6
+	"golang.org/x/net/context"
7
+)
8
+
9
+// DoNotRetry is an error wrapper indicating that the error cannot be resolved
10
+// with a retry.
11
+type DoNotRetry struct {
12
+	Err error
13
+}
14
+
15
+// Error returns the stringified representation of the encapsulated error.
16
+func (e DoNotRetry) Error() string {
17
+	return e.Err.Error()
18
+}
19
+
20
+// Watcher is returned by Watch and can be passed to Release to stop watching.
21
+type Watcher struct {
22
+	// signalChan is used to signal to the watcher goroutine that
23
+	// new progress information is available, or that the transfer
24
+	// has finished.
25
+	signalChan chan struct{}
26
+	// releaseChan signals to the watcher goroutine that the watcher
27
+	// should be detached.
28
+	releaseChan chan struct{}
29
+	// running remains open as long as the watcher is watching the
30
+	// transfer. It gets closed if the transfer finishes or the
31
+	// watcher is detached.
32
+	running chan struct{}
33
+}
34
+
35
+// Transfer represents an in-progress transfer.
36
+type Transfer interface {
37
+	Watch(progressOutput progress.Output) *Watcher
38
+	Release(*Watcher)
39
+	Context() context.Context
40
+	Cancel()
41
+	Done() <-chan struct{}
42
+	Released() <-chan struct{}
43
+	Broadcast(masterProgressChan <-chan progress.Progress)
44
+}
45
+
46
+type transfer struct {
47
+	mu sync.Mutex
48
+
49
+	ctx    context.Context
50
+	cancel context.CancelFunc
51
+
52
+	// watchers keeps track of the goroutines monitoring progress output,
53
+	// indexed by the channels that release them.
54
+	watchers map[chan struct{}]*Watcher
55
+
56
+	// lastProgress is the most recently received progress event.
57
+	lastProgress progress.Progress
58
+	// hasLastProgress is true when lastProgress has been set.
59
+	hasLastProgress bool
60
+
61
+	// running remains open as long as the transfer is in progress.
62
+	running chan struct{}
63
+	// hasWatchers stays open until all watchers release the transfer.
64
+	hasWatchers chan struct{}
65
+
66
+	// broadcastDone is true if the master progress channel has closed.
67
+	broadcastDone bool
68
+	// broadcastSyncChan allows watchers to "ping" the broadcasting
69
+	// goroutine to wait for it to deplete its input channel. This ensures
70
+	// a detaching watcher won't miss an event that was sent before it
71
+	// started detaching.
72
+	broadcastSyncChan chan struct{}
73
+}
74
+
75
+// NewTransfer creates a new transfer.
76
+func NewTransfer() Transfer {
77
+	t := &transfer{
78
+		watchers:          make(map[chan struct{}]*Watcher),
79
+		running:           make(chan struct{}),
80
+		hasWatchers:       make(chan struct{}),
81
+		broadcastSyncChan: make(chan struct{}),
82
+	}
83
+
84
+	// This uses context.Background instead of a caller-supplied context
85
+	// so that a transfer won't be cancelled automatically if the client
86
+	// which requested it is ^C'd (there could be other viewers).
87
+	t.ctx, t.cancel = context.WithCancel(context.Background())
88
+
89
+	return t
90
+}
91
+
92
+// Broadcast copies the progress and error output to all viewers.
93
+func (t *transfer) Broadcast(masterProgressChan <-chan progress.Progress) {
94
+	for {
95
+		var (
96
+			p  progress.Progress
97
+			ok bool
98
+		)
99
+		select {
100
+		case p, ok = <-masterProgressChan:
101
+		default:
102
+			// We've depleted the channel, so now we can handle
103
+			// reads on broadcastSyncChan to let detaching watchers
104
+			// know we're caught up.
105
+			select {
106
+			case <-t.broadcastSyncChan:
107
+				continue
108
+			case p, ok = <-masterProgressChan:
109
+			}
110
+		}
111
+
112
+		t.mu.Lock()
113
+		if ok {
114
+			t.lastProgress = p
115
+			t.hasLastProgress = true
116
+			for _, w := range t.watchers {
117
+				select {
118
+				case w.signalChan <- struct{}{}:
119
+				default:
120
+				}
121
+			}
122
+
123
+		} else {
124
+			t.broadcastDone = true
125
+		}
126
+		t.mu.Unlock()
127
+		if !ok {
128
+			close(t.running)
129
+			return
130
+		}
131
+	}
132
+}
133
+
134
+// Watch adds a watcher to the transfer. The supplied progress.Output gets
135
+// progress updates until the transfer finishes.
136
+func (t *transfer) Watch(progressOutput progress.Output) *Watcher {
137
+	t.mu.Lock()
138
+	defer t.mu.Unlock()
139
+
140
+	w := &Watcher{
141
+		releaseChan: make(chan struct{}),
142
+		signalChan:  make(chan struct{}),
143
+		running:     make(chan struct{}),
144
+	}
145
+
146
+	if t.broadcastDone {
147
+		close(w.running)
148
+		return w
149
+	}
150
+
151
+	t.watchers[w.releaseChan] = w
152
+
153
+	go func() {
154
+		defer func() {
155
+			close(w.running)
156
+		}()
157
+		done := false
158
+		for {
159
+			t.mu.Lock()
160
+			hasLastProgress := t.hasLastProgress
161
+			lastProgress := t.lastProgress
162
+			t.mu.Unlock()
163
+
164
+			// This might write the last progress item a
165
+			// second time (since channel closure also gets
166
+			// us here), but that's fine.
167
+			if hasLastProgress {
168
+				progressOutput.WriteProgress(lastProgress)
169
+			}
170
+
171
+			if done {
172
+				return
173
+			}
174
+
175
+			select {
176
+			case <-w.signalChan:
177
+			case <-w.releaseChan:
178
+				done = true
179
+				// Since the watcher is going to detach, make
180
+				// sure the broadcaster is caught up so we
181
+				// don't miss anything.
182
+				select {
183
+				case t.broadcastSyncChan <- struct{}{}:
184
+				case <-t.running:
185
+				}
186
+			case <-t.running:
187
+				done = true
188
+			}
189
+		}
190
+	}()
191
+
192
+	return w
193
+}
194
+
195
+// Release is the inverse of Watch, indicating that the watcher no longer wants
196
+// to be notified about the progress of the transfer. All calls to Watch must
197
+// be paired with later calls to Release so that the lifecycle of the transfer
198
+// is properly managed.
199
+func (t *transfer) Release(watcher *Watcher) {
200
+	t.mu.Lock()
201
+	delete(t.watchers, watcher.releaseChan)
202
+
203
+	if len(t.watchers) == 0 {
204
+		close(t.hasWatchers)
205
+		t.cancel()
206
+	}
207
+	t.mu.Unlock()
208
+
209
+	close(watcher.releaseChan)
210
+	// Block until the watcher goroutine completes
211
+	<-watcher.running
212
+}
213
+
214
+// Done returns a channel which is closed if the transfer completes or is
215
+// cancelled. Note that having 0 watchers causes a transfer to be cancelled.
216
+func (t *transfer) Done() <-chan struct{} {
217
+	// Note that this doesn't return t.ctx.Done() because that channel will
218
+	// be closed the moment Cancel is called, and we need to return a
219
+	// channel that blocks until a cancellation is actually acknowledged by
220
+	// the transfer function.
221
+	return t.running
222
+}
223
+
224
+// Released returns a channel which is closed once all watchers release the
225
+// transfer.
226
+func (t *transfer) Released() <-chan struct{} {
227
+	return t.hasWatchers
228
+}
229
+
230
+// Context returns the context associated with the transfer.
231
+func (t *transfer) Context() context.Context {
232
+	return t.ctx
233
+}
234
+
235
+// Cancel cancels the context associated with the transfer.
236
+func (t *transfer) Cancel() {
237
+	t.cancel()
238
+}
239
+
240
+// DoFunc is a function called by the transfer manager to actually perform
241
+// a transfer. It should be non-blocking. It should wait until the start channel
242
+// is closed before transferring any data. If the function closes inactive, that
243
+// signals to the transfer manager that the job is no longer actively moving
244
+// data - for example, it may be waiting for a dependent transfer to finish.
245
+// This prevents it from taking up a slot.
246
+type DoFunc func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer
247
+
248
+// TransferManager is used by LayerDownloadManager and LayerUploadManager to
249
+// schedule and deduplicate transfers. It is up to the TransferManager
250
+// implementation to make the scheduling and concurrency decisions.
251
+type TransferManager interface {
252
+	// Transfer checks if a transfer with the given key is in progress. If
253
+	// so, it returns progress and error output from that transfer.
254
+	// Otherwise, it will call xferFunc to initiate the transfer.
255
+	Transfer(key string, xferFunc DoFunc, progressOutput progress.Output) (Transfer, *Watcher)
256
+}
257
+
258
+type transferManager struct {
259
+	mu sync.Mutex
260
+
261
+	concurrencyLimit int
262
+	activeTransfers  int
263
+	transfers        map[string]Transfer
264
+	waitingTransfers []chan struct{}
265
+}
266
+
267
+// NewTransferManager returns a new TransferManager.
268
+func NewTransferManager(concurrencyLimit int) TransferManager {
269
+	return &transferManager{
270
+		concurrencyLimit: concurrencyLimit,
271
+		transfers:        make(map[string]Transfer),
272
+	}
273
+}
274
+
275
+// Transfer checks if a transfer matching the given key is in progress. If not,
276
+// it starts one by calling xferFunc. The caller supplies a progress.Output that
277
+// receives progress output from the transfer.
278
+func (tm *transferManager) Transfer(key string, xferFunc DoFunc, progressOutput progress.Output) (Transfer, *Watcher) {
279
+	tm.mu.Lock()
280
+	defer tm.mu.Unlock()
281
+
282
+	if xfer, present := tm.transfers[key]; present {
283
+		// Transfer is already in progress.
284
+		watcher := xfer.Watch(progressOutput)
285
+		return xfer, watcher
286
+	}
287
+
288
+	start := make(chan struct{})
289
+	inactive := make(chan struct{})
290
+
291
+	if tm.activeTransfers < tm.concurrencyLimit {
292
+		close(start)
293
+		tm.activeTransfers++
294
+	} else {
295
+		tm.waitingTransfers = append(tm.waitingTransfers, start)
296
+	}
297
+
298
+	masterProgressChan := make(chan progress.Progress)
299
+	xfer := xferFunc(masterProgressChan, start, inactive)
300
+	watcher := xfer.Watch(progressOutput)
301
+	go xfer.Broadcast(masterProgressChan)
302
+	tm.transfers[key] = xfer
303
+
304
+	// When the transfer is finished, remove from the map.
305
+	go func() {
306
+		for {
307
+			select {
308
+			case <-inactive:
309
+				tm.mu.Lock()
310
+				tm.inactivate(start)
311
+				tm.mu.Unlock()
312
+				inactive = nil
313
+			case <-xfer.Done():
314
+				tm.mu.Lock()
315
+				if inactive != nil {
316
+					tm.inactivate(start)
317
+				}
318
+				delete(tm.transfers, key)
319
+				tm.mu.Unlock()
320
+				return
321
+			}
322
+		}
323
+	}()
324
+
325
+	return xfer, watcher
326
+}
327
+
328
+func (tm *transferManager) inactivate(start chan struct{}) {
329
+	// If the transfer was started, remove it from the activeTransfers
330
+	// count.
331
+	select {
332
+	case <-start:
333
+		// Start next transfer if any are waiting
334
+		if len(tm.waitingTransfers) != 0 {
335
+			close(tm.waitingTransfers[0])
336
+			tm.waitingTransfers = tm.waitingTransfers[1:]
337
+		} else {
338
+			tm.activeTransfers--
339
+		}
340
+	default:
341
+	}
342
+}
0 343
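
To make the DoFunc contract concrete, here is a compressed sketch (not part of the diff) of a transfer function and of a caller that schedules it through the manager; makeCopyFunc, runOne and the key name are illustrative, and the Transfer/Done/Release pattern mirrors the tests in the next file.

package example // illustrative only

import (
	"bytes"
	"fmt"
	"io"
	"strings"

	"github.com/docker/docker/distribution/xfer"
	"github.com/docker/docker/pkg/progress"
)

// makeCopyFunc returns a DoFunc that copies src to dst and emits a single
// progress event when the copy completes. Names here are illustrative.
func makeCopyFunc(id string, src io.Reader, dst io.Writer) xfer.DoFunc {
	return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) xfer.Transfer {
		t := xfer.NewTransfer()
		go func() {
			defer close(progressChan) // ends Broadcast, which closes Done()
			<-start                   // wait for the manager to grant a concurrency slot
			n, _ := io.Copy(dst, src)
			progressChan <- progress.Progress{ID: id, Action: "Copying", Current: n, Total: n}
		}()
		return t
	}
}

func runOne() {
	tm := xfer.NewTransferManager(3)

	out := make(chan progress.Progress)
	consumed := make(chan struct{})
	go func() {
		for p := range out {
			fmt.Printf("%s: %s %d/%d\n", p.ID, p.Action, p.Current, p.Total)
		}
		close(consumed)
	}()

	var buf bytes.Buffer
	// A second Transfer call with the same key would reuse this transfer and
	// only attach another watcher.
	t, w := tm.Transfer("layer-key", makeCopyFunc("layer-key", strings.NewReader("data"), &buf), progress.ChanOutput(out))
	<-t.Done()   // the transfer function has finished (or was cancelled)
	t.Release(w) // every Transfer/Watch must be paired with a Release

	close(out)
	<-consumed
}
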
new file mode 100644
... ...
@@ -0,0 +1,385 @@
0
+package xfer
1
+
2
+import (
3
+	"sync/atomic"
4
+	"testing"
5
+	"time"
6
+
7
+	"github.com/docker/docker/pkg/progress"
8
+)
9
+
10
+func TestTransfer(t *testing.T) {
11
+	makeXferFunc := func(id string) DoFunc {
12
+		return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer {
13
+			select {
14
+			case <-start:
15
+			default:
16
+				t.Fatalf("transfer function not started even though concurrency limit not reached")
17
+			}
18
+
19
+			xfer := NewTransfer()
20
+			go func() {
21
+				for i := 0; i <= 10; i++ {
22
+					progressChan <- progress.Progress{ID: id, Action: "testing", Current: int64(i), Total: 10}
23
+					time.Sleep(10 * time.Millisecond)
24
+				}
25
+				close(progressChan)
26
+			}()
27
+			return xfer
28
+		}
29
+	}
30
+
31
+	tm := NewTransferManager(5)
32
+	progressChan := make(chan progress.Progress)
33
+	progressDone := make(chan struct{})
34
+	receivedProgress := make(map[string]int64)
35
+
36
+	go func() {
37
+		for p := range progressChan {
38
+			val, present := receivedProgress[p.ID]
39
+			if !present {
40
+				if p.Current != 0 {
41
+					t.Fatalf("got unexpected progress value: %d (expected 0)", p.Current)
42
+				}
43
+			} else if p.Current == 10 {
44
+				// Special case: last progress output may be
45
+				// repeated because the transfer finishing
46
+				// causes the latest progress output to be
47
+				// written to the channel (in case the watcher
48
+				// missed it).
49
+				if p.Current != 9 && p.Current != 10 {
50
+					t.Fatalf("got unexpected progress value: %d (expected %d)", p.Current, val+1)
51
+				}
52
+			} else if p.Current != val+1 {
53
+				t.Fatalf("got unexpected progress value: %d (expected %d)", p.Current, val+1)
54
+			}
55
+			receivedProgress[p.ID] = p.Current
56
+		}
57
+		close(progressDone)
58
+	}()
59
+
60
+	// Start a few transfers
61
+	ids := []string{"id1", "id2", "id3"}
62
+	xfers := make([]Transfer, len(ids))
63
+	watchers := make([]*Watcher, len(ids))
64
+	for i, id := range ids {
65
+		xfers[i], watchers[i] = tm.Transfer(id, makeXferFunc(id), progress.ChanOutput(progressChan))
66
+	}
67
+
68
+	for i, xfer := range xfers {
69
+		<-xfer.Done()
70
+		xfer.Release(watchers[i])
71
+	}
72
+	close(progressChan)
73
+	<-progressDone
74
+
75
+	for _, id := range ids {
76
+		if receivedProgress[id] != 10 {
77
+			t.Fatalf("final progress value %d instead of 10", receivedProgress[id])
78
+		}
79
+	}
80
+}
81
+
82
+func TestConcurrencyLimit(t *testing.T) {
83
+	concurrencyLimit := 3
84
+	var runningJobs int32
85
+
86
+	makeXferFunc := func(id string) DoFunc {
87
+		return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer {
88
+			xfer := NewTransfer()
89
+			go func() {
90
+				<-start
91
+				totalJobs := atomic.AddInt32(&runningJobs, 1)
92
+				if int(totalJobs) > concurrencyLimit {
93
+					t.Fatalf("too many jobs running")
94
+				}
95
+				for i := 0; i <= 10; i++ {
96
+					progressChan <- progress.Progress{ID: id, Action: "testing", Current: int64(i), Total: 10}
97
+					time.Sleep(10 * time.Millisecond)
98
+				}
99
+				atomic.AddInt32(&runningJobs, -1)
100
+				close(progressChan)
101
+			}()
102
+			return xfer
103
+		}
104
+	}
105
+
106
+	tm := NewTransferManager(concurrencyLimit)
107
+	progressChan := make(chan progress.Progress)
108
+	progressDone := make(chan struct{})
109
+	receivedProgress := make(map[string]int64)
110
+
111
+	go func() {
112
+		for p := range progressChan {
113
+			receivedProgress[p.ID] = p.Current
114
+		}
115
+		close(progressDone)
116
+	}()
117
+
118
+	// Start more transfers than the concurrency limit
119
+	ids := []string{"id1", "id2", "id3", "id4", "id5", "id6", "id7", "id8"}
120
+	xfers := make([]Transfer, len(ids))
121
+	watchers := make([]*Watcher, len(ids))
122
+	for i, id := range ids {
123
+		xfers[i], watchers[i] = tm.Transfer(id, makeXferFunc(id), progress.ChanOutput(progressChan))
124
+	}
125
+
126
+	for i, xfer := range xfers {
127
+		<-xfer.Done()
128
+		xfer.Release(watchers[i])
129
+	}
130
+	close(progressChan)
131
+	<-progressDone
132
+
133
+	for _, id := range ids {
134
+		if receivedProgress[id] != 10 {
135
+			t.Fatalf("final progress value %d instead of 10", receivedProgress[id])
136
+		}
137
+	}
138
+}
139
+
140
+func TestInactiveJobs(t *testing.T) {
141
+	concurrencyLimit := 3
142
+	var runningJobs int32
143
+	testDone := make(chan struct{})
144
+
145
+	makeXferFunc := func(id string) DoFunc {
146
+		return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer {
147
+			xfer := NewTransfer()
148
+			go func() {
149
+				<-start
150
+				totalJobs := atomic.AddInt32(&runningJobs, 1)
151
+				if int(totalJobs) > concurrencyLimit {
152
+					t.Fatalf("too many jobs running")
153
+				}
154
+				for i := 0; i <= 10; i++ {
155
+					progressChan <- progress.Progress{ID: id, Action: "testing", Current: int64(i), Total: 10}
156
+					time.Sleep(10 * time.Millisecond)
157
+				}
158
+				atomic.AddInt32(&runningJobs, -1)
159
+				close(inactive)
160
+				<-testDone
161
+				close(progressChan)
162
+			}()
163
+			return xfer
164
+		}
165
+	}
166
+
167
+	tm := NewTransferManager(concurrencyLimit)
168
+	progressChan := make(chan progress.Progress)
169
+	progressDone := make(chan struct{})
170
+	receivedProgress := make(map[string]int64)
171
+
172
+	go func() {
173
+		for p := range progressChan {
174
+			receivedProgress[p.ID] = p.Current
175
+		}
176
+		close(progressDone)
177
+	}()
178
+
179
+	// Start more transfers than the concurrency limit
180
+	ids := []string{"id1", "id2", "id3", "id4", "id5", "id6", "id7", "id8"}
181
+	xfers := make([]Transfer, len(ids))
182
+	watchers := make([]*Watcher, len(ids))
183
+	for i, id := range ids {
184
+		xfers[i], watchers[i] = tm.Transfer(id, makeXferFunc(id), progress.ChanOutput(progressChan))
185
+	}
186
+
187
+	close(testDone)
188
+	for i, xfer := range xfers {
189
+		<-xfer.Done()
190
+		xfer.Release(watchers[i])
191
+	}
192
+	close(progressChan)
193
+	<-progressDone
194
+
195
+	for _, id := range ids {
196
+		if receivedProgress[id] != 10 {
197
+			t.Fatalf("final progress value %d instead of 10", receivedProgress[id])
198
+		}
199
+	}
200
+}
201
+
202
+func TestWatchRelease(t *testing.T) {
203
+	ready := make(chan struct{})
204
+
205
+	makeXferFunc := func(id string) DoFunc {
206
+		return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer {
207
+			xfer := NewTransfer()
208
+			go func() {
209
+				defer func() {
210
+					close(progressChan)
211
+				}()
212
+				<-ready
213
+				for i := int64(0); ; i++ {
214
+					select {
215
+					case <-time.After(10 * time.Millisecond):
216
+					case <-xfer.Context().Done():
217
+						return
218
+					}
219
+					progressChan <- progress.Progress{ID: id, Action: "testing", Current: i, Total: 10}
220
+				}
221
+			}()
222
+			return xfer
223
+		}
224
+	}
225
+
226
+	tm := NewTransferManager(5)
227
+
228
+	type watcherInfo struct {
229
+		watcher               *Watcher
230
+		progressChan          chan progress.Progress
231
+		progressDone          chan struct{}
232
+		receivedFirstProgress chan struct{}
233
+	}
234
+
235
+	progressConsumer := func(w watcherInfo) {
236
+		first := true
237
+		for range w.progressChan {
238
+			if first {
239
+				close(w.receivedFirstProgress)
240
+			}
241
+			first = false
242
+		}
243
+		close(w.progressDone)
244
+	}
245
+
246
+	// Start a transfer
247
+	watchers := make([]watcherInfo, 5)
248
+	var xfer Transfer
249
+	watchers[0].progressChan = make(chan progress.Progress)
250
+	watchers[0].progressDone = make(chan struct{})
251
+	watchers[0].receivedFirstProgress = make(chan struct{})
252
+	xfer, watchers[0].watcher = tm.Transfer("id1", makeXferFunc("id1"), progress.ChanOutput(watchers[0].progressChan))
253
+	go progressConsumer(watchers[0])
254
+
255
+	// Give it multiple watchers
256
+	for i := 1; i != len(watchers); i++ {
257
+		watchers[i].progressChan = make(chan progress.Progress)
258
+		watchers[i].progressDone = make(chan struct{})
259
+		watchers[i].receivedFirstProgress = make(chan struct{})
260
+		watchers[i].watcher = xfer.Watch(progress.ChanOutput(watchers[i].progressChan))
261
+		go progressConsumer(watchers[i])
262
+	}
263
+
264
+	// Now that the watchers are set up, allow the transfer goroutine to
265
+	// proceed.
266
+	close(ready)
267
+
268
+	// Confirm that each watcher gets progress output.
269
+	for _, w := range watchers {
270
+		<-w.receivedFirstProgress
271
+	}
272
+
273
+	// Release one watcher every 5ms
274
+	for _, w := range watchers {
275
+		xfer.Release(w.watcher)
276
+		<-time.After(5 * time.Millisecond)
277
+	}
278
+
279
+	// Now that all watchers have been released, Released() should
280
+	// return a closed channel.
281
+	<-xfer.Released()
282
+
283
+	// Done() should return a closed channel because the xfer func returned
284
+	// due to cancellation.
285
+	<-xfer.Done()
286
+
287
+	for _, w := range watchers {
288
+		close(w.progressChan)
289
+		<-w.progressDone
290
+	}
291
+}
292
+
293
+func TestDuplicateTransfer(t *testing.T) {
294
+	ready := make(chan struct{})
295
+
296
+	var xferFuncCalls int32
297
+
298
+	makeXferFunc := func(id string) DoFunc {
299
+		return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer {
300
+			atomic.AddInt32(&xferFuncCalls, 1)
301
+			xfer := NewTransfer()
302
+			go func() {
303
+				defer func() {
304
+					close(progressChan)
305
+				}()
306
+				<-ready
307
+				for i := int64(0); ; i++ {
308
+					select {
309
+					case <-time.After(10 * time.Millisecond):
310
+					case <-xfer.Context().Done():
311
+						return
312
+					}
313
+					progressChan <- progress.Progress{ID: id, Action: "testing", Current: i, Total: 10}
314
+				}
315
+			}()
316
+			return xfer
317
+		}
318
+	}
319
+
320
+	tm := NewTransferManager(5)
321
+
322
+	type transferInfo struct {
323
+		xfer                  Transfer
324
+		watcher               *Watcher
325
+		progressChan          chan progress.Progress
326
+		progressDone          chan struct{}
327
+		receivedFirstProgress chan struct{}
328
+	}
329
+
330
+	progressConsumer := func(t transferInfo) {
331
+		first := true
332
+		for range t.progressChan {
333
+			if first {
334
+				close(t.receivedFirstProgress)
335
+			}
336
+			first = false
337
+		}
338
+		close(t.progressDone)
339
+	}
340
+
341
+	// Try to start multiple transfers with the same ID
342
+	transfers := make([]transferInfo, 5)
343
+	for i := range transfers {
344
+		t := &transfers[i]
345
+		t.progressChan = make(chan progress.Progress)
346
+		t.progressDone = make(chan struct{})
347
+		t.receivedFirstProgress = make(chan struct{})
348
+		t.xfer, t.watcher = tm.Transfer("id1", makeXferFunc("id1"), progress.ChanOutput(t.progressChan))
349
+		go progressConsumer(*t)
350
+	}
351
+
352
+	// Allow the transfer goroutine to proceed.
353
+	close(ready)
354
+
355
+	// Confirm that each watcher gets progress output.
356
+	for _, t := range transfers {
357
+		<-t.receivedFirstProgress
358
+	}
359
+
360
+	// Confirm that the transfer function was called exactly once.
361
+	if xferFuncCalls != 1 {
362
+		t.Fatal("transfer function wasn't called exactly once")
363
+	}
364
+
365
+	// Release one watcher every 5ms
366
+	for _, t := range transfers {
367
+		t.xfer.Release(t.watcher)
368
+		<-time.After(5 * time.Millisecond)
369
+	}
370
+
371
+	for _, t := range transfers {
372
+		// Now that all watchers have been released, Released() should
373
+		// return a closed channel.
374
+		<-t.xfer.Released()
375
+		// Done() should return a closed channel because the xfer func returned
376
+		// due to cancellation.
377
+		<-t.xfer.Done()
378
+	}
379
+
380
+	for _, t := range transfers {
381
+		close(t.progressChan)
382
+		<-t.progressDone
383
+	}
384
+}
0 385
new file mode 100644
... ...
@@ -0,0 +1,159 @@
0
+package xfer
1
+
2
+import (
3
+	"errors"
4
+	"time"
5
+
6
+	"github.com/Sirupsen/logrus"
7
+	"github.com/docker/distribution/digest"
8
+	"github.com/docker/docker/layer"
9
+	"github.com/docker/docker/pkg/progress"
10
+	"golang.org/x/net/context"
11
+)
12
+
13
+const maxUploadAttempts = 5
14
+
15
+// LayerUploadManager provides task management and progress reporting for
16
+// uploads.
17
+type LayerUploadManager struct {
18
+	tm TransferManager
19
+}
20
+
21
+// NewLayerUploadManager returns a new LayerUploadManager.
22
+func NewLayerUploadManager(concurrencyLimit int) *LayerUploadManager {
23
+	return &LayerUploadManager{
24
+		tm: NewTransferManager(concurrencyLimit),
25
+	}
26
+}
27
+
28
+type uploadTransfer struct {
29
+	Transfer
30
+
31
+	diffID layer.DiffID
32
+	digest digest.Digest
33
+	err    error
34
+}
35
+
36
+// An UploadDescriptor references a layer that may need to be uploaded.
37
+type UploadDescriptor interface {
38
+	// Key returns the key used to deduplicate uploads.
39
+	Key() string
40
+	// ID returns the ID for display purposes.
41
+	ID() string
42
+	// DiffID should return the DiffID for this layer.
43
+	DiffID() layer.DiffID
44
+	// Upload is called to perform the upload.
45
+	Upload(ctx context.Context, progressOutput progress.Output) (digest.Digest, error)
46
+}
47
+
48
+// Upload is a blocking function which ensures the listed layers are present on
49
+// the remote registry. It uses the string returned by the Key method to
50
+// deduplicate uploads.
51
+func (lum *LayerUploadManager) Upload(ctx context.Context, layers []UploadDescriptor, progressOutput progress.Output) (map[layer.DiffID]digest.Digest, error) {
52
+	var (
53
+		uploads          []*uploadTransfer
54
+		digests          = make(map[layer.DiffID]digest.Digest)
55
+		dedupDescriptors = make(map[string]struct{})
56
+	)
57
+
58
+	for _, descriptor := range layers {
59
+		progress.Update(progressOutput, descriptor.ID(), "Preparing")
60
+
61
+		key := descriptor.Key()
62
+		if _, present := dedupDescriptors[key]; present {
63
+			continue
64
+		}
65
+		dedupDescriptors[key] = struct{}{}
66
+
67
+		xferFunc := lum.makeUploadFunc(descriptor)
68
+		upload, watcher := lum.tm.Transfer(descriptor.Key(), xferFunc, progressOutput)
69
+		defer upload.Release(watcher)
70
+		uploads = append(uploads, upload.(*uploadTransfer))
71
+	}
72
+
73
+	for _, upload := range uploads {
74
+		select {
75
+		case <-ctx.Done():
76
+			return nil, ctx.Err()
77
+		case <-upload.Transfer.Done():
78
+			if upload.err != nil {
79
+				return nil, upload.err
80
+			}
81
+			digests[upload.diffID] = upload.digest
82
+		}
83
+	}
84
+
85
+	return digests, nil
86
+}
87
+
88
+func (lum *LayerUploadManager) makeUploadFunc(descriptor UploadDescriptor) DoFunc {
89
+	return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer {
90
+		u := &uploadTransfer{
91
+			Transfer: NewTransfer(),
92
+			diffID:   descriptor.DiffID(),
93
+		}
94
+
95
+		go func() {
96
+			defer func() {
97
+				close(progressChan)
98
+			}()
99
+
100
+			progressOutput := progress.ChanOutput(progressChan)
101
+
102
+			select {
103
+			case <-start:
104
+			default:
105
+				progress.Update(progressOutput, descriptor.ID(), "Waiting")
106
+				<-start
107
+			}
108
+
109
+			retries := 0
110
+			for {
111
+				digest, err := descriptor.Upload(u.Transfer.Context(), progressOutput)
112
+				if err == nil {
113
+					u.digest = digest
114
+					break
115
+				}
116
+
117
+				// If an error was returned because the context
118
+				// was cancelled, we shouldn't retry.
119
+				select {
120
+				case <-u.Transfer.Context().Done():
121
+					u.err = err
122
+					return
123
+				default:
124
+				}
125
+
126
+				retries++
127
+				if _, isDNR := err.(DoNotRetry); isDNR || retries == maxUploadAttempts {
128
+					logrus.Errorf("Upload failed: %v", err)
129
+					u.err = err
130
+					return
131
+				}
132
+
133
+				logrus.Errorf("Upload failed, retrying: %v", err)
134
+				delay := retries * 5
135
+				ticker := time.NewTicker(time.Second)
136
+
137
+			selectLoop:
138
+				for {
139
+					progress.Updatef(progressOutput, descriptor.ID(), "Retrying in %d seconds", delay)
140
+					select {
141
+					case <-ticker.C:
142
+						delay--
143
+						if delay == 0 {
144
+							ticker.Stop()
145
+							break selectLoop
146
+						}
147
+					case <-u.Transfer.Context().Done():
148
+						ticker.Stop()
149
+						u.err = errors.New("upload cancelled during retry delay")
150
+						return
151
+					}
152
+				}
153
+			}
154
+		}()
155
+
156
+		return u
157
+	}
158
+}
0 159
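
A caller-side sketch (not part of the diff) of how push code might use the upload manager; retries and backoff happen inside makeUploadFunc, so the caller only sees the final error. The pushLayers helper and the concurrency limit of 5 are assumptions for illustration.

package example // illustrative only

import (
	"fmt"

	"github.com/docker/docker/distribution/xfer"
	"github.com/docker/docker/pkg/progress"
	"golang.org/x/net/context"
)

// pushLayers is a hypothetical helper showing the call pattern.
func pushLayers(ctx context.Context, descriptors []xfer.UploadDescriptor) error {
	lum := xfer.NewLayerUploadManager(5) // at most 5 uploads run at once

	progressChan := make(chan progress.Progress)
	progressDone := make(chan struct{})
	go func() {
		for p := range progressChan {
			fmt.Printf("%s: %s\n", p.ID, p.Action)
		}
		close(progressDone)
	}()

	digests, err := lum.Upload(ctx, descriptors, progress.ChanOutput(progressChan))
	close(progressChan)
	<-progressDone
	if err != nil {
		return err
	}

	for diffID, dgst := range digests {
		fmt.Printf("layer %s pushed as %s\n", diffID, dgst)
	}
	return nil
}
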
new file mode 100644
... ...
@@ -0,0 +1,153 @@
0
+package xfer
1
+
2
+import (
3
+	"errors"
4
+	"sync/atomic"
5
+	"testing"
6
+	"time"
7
+
8
+	"github.com/docker/distribution/digest"
9
+	"github.com/docker/docker/layer"
10
+	"github.com/docker/docker/pkg/progress"
11
+	"golang.org/x/net/context"
12
+)
13
+
14
+const maxUploadConcurrency = 3
15
+
16
+type mockUploadDescriptor struct {
17
+	currentUploads  *int32
18
+	diffID          layer.DiffID
19
+	simulateRetries int
20
+}
21
+
22
+// Key returns the key used to deduplicate uploads.
23
+func (u *mockUploadDescriptor) Key() string {
24
+	return u.diffID.String()
25
+}
26
+
27
+// ID returns the ID for display purposes.
28
+func (u *mockUploadDescriptor) ID() string {
29
+	return u.diffID.String()
30
+}
31
+
32
+// DiffID should return the DiffID for this layer.
33
+func (u *mockUploadDescriptor) DiffID() layer.DiffID {
34
+	return u.diffID
35
+}
36
+
37
+// Upload is called to perform the upload.
38
+func (u *mockUploadDescriptor) Upload(ctx context.Context, progressOutput progress.Output) (digest.Digest, error) {
39
+	if u.currentUploads != nil {
40
+		defer atomic.AddInt32(u.currentUploads, -1)
41
+
42
+		if atomic.AddInt32(u.currentUploads, 1) > maxUploadConcurrency {
43
+			return "", errors.New("concurrency limit exceeded")
44
+		}
45
+	}
46
+
47
+	// Sleep a bit to simulate a time-consuming upload.
48
+	for i := int64(0); i <= 10; i++ {
49
+		select {
50
+		case <-ctx.Done():
51
+			return "", ctx.Err()
52
+		case <-time.After(10 * time.Millisecond):
53
+			progressOutput.WriteProgress(progress.Progress{ID: u.ID(), Current: i, Total: 10})
54
+		}
55
+	}
56
+
57
+	if u.simulateRetries != 0 {
58
+		u.simulateRetries--
59
+		return "", errors.New("simulating retry")
60
+	}
61
+
62
+	// For the mock implementation, use SHA256(DiffID) as the returned
63
+	// digest.
64
+	return digest.FromBytes([]byte(u.diffID.String()))
65
+}
66
+
67
+func uploadDescriptors(currentUploads *int32) []UploadDescriptor {
68
+	return []UploadDescriptor{
69
+		&mockUploadDescriptor{currentUploads, layer.DiffID("sha256:cbbf2f9a99b47fc460d422812b6a5adff7dfee951d8fa2e4a98caa0382cfbdbf"), 0},
70
+		&mockUploadDescriptor{currentUploads, layer.DiffID("sha256:1515325234325236634634608943609283523908626098235490238423902343"), 0},
71
+		&mockUploadDescriptor{currentUploads, layer.DiffID("sha256:6929356290463485374960346430698374523437683470934634534953453453"), 0},
72
+		&mockUploadDescriptor{currentUploads, layer.DiffID("sha256:cbbf2f9a99b47fc460d422812b6a5adff7dfee951d8fa2e4a98caa0382cfbdbf"), 0},
73
+		&mockUploadDescriptor{currentUploads, layer.DiffID("sha256:8159352387436803946235346346368745389534789534897538734598734987"), 1},
74
+		&mockUploadDescriptor{currentUploads, layer.DiffID("sha256:4637863963478346897346987346987346789346789364879364897364987346"), 0},
75
+	}
76
+}
77
+
78
+var expectedDigests = map[layer.DiffID]digest.Digest{
79
+	layer.DiffID("sha256:cbbf2f9a99b47fc460d422812b6a5adff7dfee951d8fa2e4a98caa0382cfbdbf"): digest.Digest("sha256:c5095d6cf7ee42b7b064371dcc1dc3fb4af197f04d01a60009d484bd432724fc"),
80
+	layer.DiffID("sha256:1515325234325236634634608943609283523908626098235490238423902343"): digest.Digest("sha256:968cbfe2ff5269ea1729b3804767a1f57ffbc442d3bc86f47edbf7e688a4f36e"),
81
+	layer.DiffID("sha256:6929356290463485374960346430698374523437683470934634534953453453"): digest.Digest("sha256:8a5e56ab4b477a400470a7d5d4c1ca0c91235fd723ab19cc862636a06f3a735d"),
82
+	layer.DiffID("sha256:8159352387436803946235346346368745389534789534897538734598734987"): digest.Digest("sha256:5e733e5cd3688512fc240bd5c178e72671c9915947d17bb8451750d827944cb2"),
83
+	layer.DiffID("sha256:4637863963478346897346987346987346789346789364879364897364987346"): digest.Digest("sha256:ec4bb98d15e554a9f66c3ef9296cf46772c0ded3b1592bd8324d96e2f60f460c"),
84
+}
85
+
86
+func TestSuccessfulUpload(t *testing.T) {
87
+	lum := NewLayerUploadManager(maxUploadConcurrency)
88
+
89
+	progressChan := make(chan progress.Progress)
90
+	progressDone := make(chan struct{})
91
+	receivedProgress := make(map[string]int64)
92
+
93
+	go func() {
94
+		for p := range progressChan {
95
+			receivedProgress[p.ID] = p.Current
96
+		}
97
+		close(progressDone)
98
+	}()
99
+
100
+	var currentUploads int32
101
+	descriptors := uploadDescriptors(&currentUploads)
102
+
103
+	digests, err := lum.Upload(context.Background(), descriptors, progress.ChanOutput(progressChan))
104
+	if err != nil {
105
+		t.Fatalf("upload error: %v", err)
106
+	}
107
+
108
+	close(progressChan)
109
+	<-progressDone
110
+
111
+	if len(digests) != len(expectedDigests) {
112
+		t.Fatal("wrong number of keys in digests map")
113
+	}
114
+
115
+	for key, val := range expectedDigests {
116
+		if digests[key] != val {
117
+			t.Fatalf("mismatch in digest array for key %v (expected %v, got %v)", key, val, digests[key])
118
+		}
119
+		if receivedProgress[key.String()] != 10 {
120
+			t.Fatalf("missing or wrong progress output for %v", key)
121
+		}
122
+	}
123
+}
124
+
125
+func TestCancelledUpload(t *testing.T) {
126
+	lum := NewLayerUploadManager(maxUploadConcurrency)
127
+
128
+	progressChan := make(chan progress.Progress)
129
+	progressDone := make(chan struct{})
130
+
131
+	go func() {
132
+		for range progressChan {
133
+		}
134
+		close(progressDone)
135
+	}()
136
+
137
+	ctx, cancel := context.WithCancel(context.Background())
138
+
139
+	go func() {
140
+		<-time.After(time.Millisecond)
141
+		cancel()
142
+	}()
143
+
144
+	descriptors := uploadDescriptors(nil)
145
+	_, err := lum.Upload(ctx, descriptors, progress.ChanOutput(progressChan))
146
+	if err != context.Canceled {
147
+		t.Fatal("expected upload to be cancelled")
148
+	}
149
+
150
+	close(progressChan)
151
+	<-progressDone
152
+}
... ...
@@ -140,7 +140,7 @@ func (s *DockerHubPullSuite) TestPullAllTagsFromCentralRegistry(c *check.C) {
140 140
 }
141 141
 
142 142
 // TestPullClientDisconnect kills the client during a pull operation and verifies that the operation
143
-// still succesfully completes on the daemon side.
143
+// gets cancelled.
144 144
 //
145 145
 // Ref: docker/docker#15589
146 146
 func (s *DockerHubPullSuite) TestPullClientDisconnect(c *check.C) {
... ...
@@ -161,14 +161,8 @@ func (s *DockerHubPullSuite) TestPullClientDisconnect(c *check.C) {
161 161
 	err = pullCmd.Process.Kill()
162 162
 	c.Assert(err, checker.IsNil)
163 163
 
164
-	maxAttempts := 20
165
-	for i := 0; ; i++ {
166
-		if _, err := s.CmdWithError("inspect", repoName); err == nil {
167
-			break
168
-		}
169
-		if i >= maxAttempts {
170
-			c.Fatal("timeout reached: image was not pulled after client disconnected")
171
-		}
172
-		time.Sleep(500 * time.Millisecond)
164
+	time.Sleep(2 * time.Second)
165
+	if _, err := s.CmdWithError("inspect", repoName); err == nil {
166
+		c.Fatal("image was pulled after client disconnected")
173 167
 	}
174 168
 }
175 169
deleted file mode 100644
... ...
@@ -1,167 +0,0 @@
1
-package broadcaster
2
-
3
-import (
4
-	"errors"
5
-	"io"
6
-	"sync"
7
-)
8
-
9
-// Buffered keeps track of one or more observers watching the progress
10
-// of an operation. For example, if multiple clients are trying to pull an
11
-// image, they share a Buffered struct for the download operation.
12
-type Buffered struct {
13
-	sync.Mutex
14
-	// c is a channel that observers block on, waiting for the operation
15
-	// to finish.
16
-	c chan struct{}
17
-	// cond is a condition variable used to wake up observers when there's
18
-	// new data available.
19
-	cond *sync.Cond
20
-	// history is a buffer of the progress output so far, so a new observer
21
-	// can catch up. The history is stored as a slice of separate byte
22
-	// slices, so that if the writer is a WriteFlusher, the flushes will
23
-	// happen in the right places.
24
-	history [][]byte
25
-	// wg is a WaitGroup used to wait for all writes to finish on Close
26
-	wg sync.WaitGroup
27
-	// result is the argument passed to the first call of Close, and
28
-	// returned to callers of Wait
29
-	result error
30
-}
31
-
32
-// NewBuffered returns an initialized Buffered structure.
33
-func NewBuffered() *Buffered {
34
-	b := &Buffered{
35
-		c: make(chan struct{}),
36
-	}
37
-	b.cond = sync.NewCond(b)
38
-	return b
39
-}
40
-
41
-// closed returns true if and only if the broadcaster has been closed
42
-func (broadcaster *Buffered) closed() bool {
43
-	select {
44
-	case <-broadcaster.c:
45
-		return true
46
-	default:
47
-		return false
48
-	}
49
-}
50
-
51
-// receiveWrites runs as a goroutine so that writes don't block the Write
52
-// function. It writes the new data in broadcaster.history each time there's
53
-// activity on the broadcaster.cond condition variable.
54
-func (broadcaster *Buffered) receiveWrites(observer io.Writer) {
55
-	n := 0
56
-
57
-	broadcaster.Lock()
58
-
59
-	// The condition variable wait is at the end of this loop, so that the
60
-	// first iteration will write the history so far.
61
-	for {
62
-		newData := broadcaster.history[n:]
63
-		// Make a copy of newData so we can release the lock
64
-		sendData := make([][]byte, len(newData), len(newData))
65
-		copy(sendData, newData)
66
-		broadcaster.Unlock()
67
-
68
-		for len(sendData) > 0 {
69
-			_, err := observer.Write(sendData[0])
70
-			if err != nil {
71
-				broadcaster.wg.Done()
72
-				return
73
-			}
74
-			n++
75
-			sendData = sendData[1:]
76
-		}
77
-
78
-		broadcaster.Lock()
79
-
80
-		// If we are behind, we need to catch up instead of waiting
81
-		// or handling a closure.
82
-		if len(broadcaster.history) != n {
83
-			continue
84
-		}
85
-
86
-		// detect closure of the broadcast writer
87
-		if broadcaster.closed() {
88
-			broadcaster.Unlock()
89
-			broadcaster.wg.Done()
90
-			return
91
-		}
92
-
93
-		broadcaster.cond.Wait()
94
-
95
-		// Mutex is still locked as the loop continues
96
-	}
97
-}
98
-
99
-// Write adds data to the history buffer, and also writes it to all current
100
-// observers.
101
-func (broadcaster *Buffered) Write(p []byte) (n int, err error) {
102
-	broadcaster.Lock()
103
-	defer broadcaster.Unlock()
104
-
105
-	// Is the broadcaster closed? If so, the write should fail.
106
-	if broadcaster.closed() {
107
-		return 0, errors.New("attempted write to a closed broadcaster.Buffered")
108
-	}
109
-
110
-	// Add message in p to the history slice
111
-	newEntry := make([]byte, len(p), len(p))
112
-	copy(newEntry, p)
113
-	broadcaster.history = append(broadcaster.history, newEntry)
114
-
115
-	broadcaster.cond.Broadcast()
116
-
117
-	return len(p), nil
118
-}
119
-
120
-// Add adds an observer to the broadcaster. The new observer receives the
121
-// data from the history buffer, and also all subsequent data.
122
-func (broadcaster *Buffered) Add(w io.Writer) error {
123
-	// The lock is acquired here so that Add can't race with Close
124
-	broadcaster.Lock()
125
-	defer broadcaster.Unlock()
126
-
127
-	if broadcaster.closed() {
128
-		return errors.New("attempted to add observer to a closed broadcaster.Buffered")
129
-	}
130
-
131
-	broadcaster.wg.Add(1)
132
-	go broadcaster.receiveWrites(w)
133
-
134
-	return nil
135
-}
136
-
137
-// CloseWithError signals to all observers that the operation has finished. Its
138
-// argument is a result that should be returned to waiters blocking on Wait.
139
-func (broadcaster *Buffered) CloseWithError(result error) {
140
-	broadcaster.Lock()
141
-	if broadcaster.closed() {
142
-		broadcaster.Unlock()
143
-		return
144
-	}
145
-	broadcaster.result = result
146
-	close(broadcaster.c)
147
-	broadcaster.cond.Broadcast()
148
-	broadcaster.Unlock()
149
-
150
-	// Don't return until all writers have caught up.
151
-	broadcaster.wg.Wait()
152
-}
153
-
154
-// Close signals to all observers that the operation has finished. It causes
155
-// all calls to Wait to return nil.
156
-func (broadcaster *Buffered) Close() {
157
-	broadcaster.CloseWithError(nil)
158
-}
159
-
160
-// Wait blocks until the operation is marked as completed by the Close method,
161
-// and all writer goroutines have completed. It returns the argument that was
162
-// passed to Close.
163
-func (broadcaster *Buffered) Wait() error {
164
-	<-broadcaster.c
165
-	broadcaster.wg.Wait()
166
-	return broadcaster.result
167
-}
... ...
@@ -4,6 +4,8 @@ import (
4 4
 	"crypto/sha256"
5 5
 	"encoding/hex"
6 6
 	"io"
7
+
8
+	"golang.org/x/net/context"
7 9
 )
8 10
 
9 11
 type readCloserWrapper struct {
... ...
@@ -81,3 +83,72 @@ func (r *OnEOFReader) runFunc() {
81 81
 		r.Fn = nil
82 82
 	}
83 83
 }
84
+
85
+// cancelReadCloser wraps an io.ReadCloser with a context for cancelling read
86
+// operations.
87
+type cancelReadCloser struct {
88
+	cancel func()
89
+	pR     *io.PipeReader // Stream to read from
90
+	pW     *io.PipeWriter
91
+}
92
+
93
+// NewCancelReadCloser creates a wrapper that closes the ReadCloser when the
94
+// context is cancelled. The returned io.ReadCloser must be closed when it is
95
+// no longer needed.
96
+func NewCancelReadCloser(ctx context.Context, in io.ReadCloser) io.ReadCloser {
97
+	pR, pW := io.Pipe()
98
+
99
+	// Create a context used to signal when the pipe is closed
100
+	doneCtx, cancel := context.WithCancel(context.Background())
101
+
102
+	p := &cancelReadCloser{
103
+		cancel: cancel,
104
+		pR:     pR,
105
+		pW:     pW,
106
+	}
107
+
108
+	go func() {
109
+		_, err := io.Copy(pW, in)
110
+		select {
111
+		case <-ctx.Done():
112
+			// If the context was closed, p.closeWithError
113
+			// was already called. Calling it again would
114
+			// change the error that Read returns.
115
+		default:
116
+			p.closeWithError(err)
117
+		}
118
+		in.Close()
119
+	}()
120
+	go func() {
121
+		for {
122
+			select {
123
+			case <-ctx.Done():
124
+				p.closeWithError(ctx.Err())
125
+			case <-doneCtx.Done():
126
+				return
127
+			}
128
+		}
129
+	}()
130
+
131
+	return p
132
+}
133
+
134
+// Read wraps the Read method of the pipe that provides data from the wrapped
135
+// ReadCloser.
136
+func (p *cancelReadCloser) Read(buf []byte) (n int, err error) {
137
+	return p.pR.Read(buf)
138
+}
139
+
140
+// closeWithError closes the wrapper and its underlying reader. It will
141
+// cause future calls to Read to return err.
142
+func (p *cancelReadCloser) closeWithError(err error) {
143
+	p.pW.CloseWithError(err)
144
+	p.cancel()
145
+}
146
+
147
+// Close closes the wrapper and its underlying reader. It will cause
148
+// future calls to Read to return io.EOF.
149
+func (p *cancelReadCloser) Close() error {
150
+	p.closeWithError(io.EOF)
151
+	return nil
152
+}
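
A short sketch (not part of the diff) of the intended use of NewCancelReadCloser: aborting a blocking read on a network body once its context is cancelled. The HTTP download and the fetchWithTimeout helper are only illustrative sources of an io.ReadCloser.

package example // illustrative only

import (
	"io/ioutil"
	"net/http"
	"time"

	"github.com/docker/docker/pkg/ioutils"
	"golang.org/x/net/context"
)

// fetchWithTimeout reads a URL but gives up mid-read once the context expires.
func fetchWithTimeout(url string) ([]byte, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	resp, err := http.Get(url)
	if err != nil {
		return nil, err
	}

	// Wrap the body so that Read returns ctx.Err() as soon as the context is
	// cancelled, even if the server has stopped sending data.
	body := ioutils.NewCancelReadCloser(ctx, resp.Body)
	defer body.Close()

	return ioutil.ReadAll(body)
}
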
... ...
@@ -2,8 +2,12 @@ package ioutils
2 2
 
3 3
 import (
4 4
 	"fmt"
5
+	"io/ioutil"
5 6
 	"strings"
6 7
 	"testing"
8
+	"time"
9
+
10
+	"golang.org/x/net/context"
7 11
 )
8 12
 
9 13
 // Implement io.Reader
... ...
@@ -65,3 +69,26 @@ func TestHashData(t *testing.T) {
65 65
 		t.Fatalf("Expecting %s, got %s", expected, actual)
66 66
 	}
67 67
 }
68
+
69
+type perpetualReader struct{}
70
+
71
+func (p *perpetualReader) Read(buf []byte) (n int, err error) {
72
+	for i := 0; i != len(buf); i++ {
73
+		buf[i] = 'a'
74
+	}
75
+	return len(buf), nil
76
+}
77
+
78
+func TestCancelReadCloser(t *testing.T) {
79
+	ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond)
80
+	cancelReadCloser := NewCancelReadCloser(ctx, ioutil.NopCloser(&perpetualReader{}))
81
+	for {
82
+		var buf [128]byte
83
+		_, err := cancelReadCloser.Read(buf[:])
84
+		if err == context.DeadlineExceeded {
85
+			break
86
+		} else if err != nil {
87
+			t.Fatalf("got unexpected error: %v", err)
88
+		}
89
+	}
90
+}
68 91
new file mode 100644
... ...
@@ -0,0 +1,63 @@
0
+package progress
1
+
2
+import (
3
+	"fmt"
4
+)
5
+
6
+// Progress represents the progress of a transfer.
7
+type Progress struct {
8
+	ID string
9
+
10
+	// Progress contains a Message or...
11
+	Message string
12
+
13
+	// ...progress of an action
14
+	Action  string
15
+	Current int64
16
+	Total   int64
17
+
18
+	LastUpdate bool
19
+}
20
+
21
+// Output is an interface for writing progress information. It's
22
+// like a writer for progress, but we don't call it Writer because
23
+// that would be confusing next to ProgressReader (also, because it
24
+// doesn't implement the io.Writer interface).
25
+type Output interface {
26
+	WriteProgress(Progress) error
27
+}
28
+
29
+type chanOutput chan<- Progress
30
+
31
+func (out chanOutput) WriteProgress(p Progress) error {
32
+	out <- p
33
+	return nil
34
+}
35
+
36
+// ChanOutput returns an Output that writes progress updates to the
37
+// supplied channel.
38
+func ChanOutput(progressChan chan<- Progress) Output {
39
+	return chanOutput(progressChan)
40
+}
41
+
42
+// Update is a convenience function to write a progress update to the output.
43
+func Update(out Output, id, action string) {
44
+	out.WriteProgress(Progress{ID: id, Action: action})
45
+}
46
+
47
+// Updatef is a convenience function to write a printf-formatted progress update
48
+// to the channel.
49
+func Updatef(out Output, id, format string, a ...interface{}) {
50
+	Update(out, id, fmt.Sprintf(format, a...))
51
+}
52
+
53
+// Message is a convenience function to write a progress message to the output.
54
+func Message(out Output, id, message string) {
55
+	out.WriteProgress(Progress{ID: id, Message: message})
56
+}
57
+
58
+// Messagef is a convenience function to write a printf-formatted progress
59
+// message to the output.
60
+func Messagef(out Output, id, format string, a ...interface{}) {
61
+	Message(out, id, fmt.Sprintf(format, a...))
62
+}
0 63
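
Output is deliberately a one-method interface, so sinks other than a channel are easy to provide. Below is a sketch (not part of the diff) of a logging Output together with the convenience helpers; the logOutput type is hypothetical.

package example // illustrative only

import (
	"log"

	"github.com/docker/docker/pkg/progress"
)

// logOutput is a hypothetical Output that forwards progress events to the
// standard logger instead of a channel.
type logOutput struct{}

func (logOutput) WriteProgress(p progress.Progress) error {
	if p.Message != "" {
		log.Printf("%s %s", p.ID, p.Message)
		return nil
	}
	log.Printf("%s %s %d/%d", p.ID, p.Action, p.Current, p.Total)
	return nil
}

func emitSomeProgress() {
	var out progress.Output = logOutput{}
	progress.Update(out, "abcd1234", "Preparing")
	progress.Updatef(out, "abcd1234", "Retrying in %d seconds", 5)
	progress.Message(out, "", "Pushed all layers")
}
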
new file mode 100644
... ...
@@ -0,0 +1,59 @@
0
+package progress
1
+
2
+import (
3
+	"io"
4
+)
5
+
6
+// Reader is a Reader with a progress bar.
7
+type Reader struct {
8
+	in         io.ReadCloser // Stream to read from
9
+	out        Output        // Where to send progress bar to
10
+	size       int64
11
+	current    int64
12
+	lastUpdate int64
13
+	id         string
14
+	action     string
15
+}
16
+
17
+// NewProgressReader creates a new Reader that reports progress to out.
18
+func NewProgressReader(in io.ReadCloser, out Output, size int64, id, action string) *Reader {
19
+	return &Reader{
20
+		in:     in,
21
+		out:    out,
22
+		size:   size,
23
+		id:     id,
24
+		action: action,
25
+	}
26
+}
27
+
28
+func (p *Reader) Read(buf []byte) (n int, err error) {
29
+	read, err := p.in.Read(buf)
30
+	p.current += int64(read)
31
+	updateEvery := int64(1024 * 512) // 512kB
32
+	if p.size > 0 {
33
+		// Update progress for every 1% read if 1% < 512kB
34
+		if increment := int64(0.01 * float64(p.size)); increment < updateEvery {
35
+			updateEvery = increment
36
+		}
37
+	}
38
+	if p.current-p.lastUpdate > updateEvery || err != nil {
39
+		p.updateProgress(err != nil && read == 0)
40
+		p.lastUpdate = p.current
41
+	}
42
+
43
+	return read, err
44
+}
45
+
46
+// Close closes the progress reader and its underlying reader.
47
+func (p *Reader) Close() error {
48
+	if p.current < p.size {
49
+		// print a full progress bar when closing prematurely
50
+		p.current = p.size
51
+		p.updateProgress(false)
52
+	}
53
+	return p.in.Close()
54
+}
55
+
56
+func (p *Reader) updateProgress(last bool) {
57
+	p.out.WriteProgress(Progress{ID: p.id, Action: p.action, Current: p.current, Total: p.size, LastUpdate: last})
58
+}
0 59
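Reader counts bytes as they pass through and emits a Progress record roughly every 512kB, or every 1% of the declared size if that is smaller, plus a final full-bar update if it is closed early. A small sketch of wrapping a stream with it, again using a channel-backed Output; the payload size, ID and action strings are placeholders.

    package main

    import (
    	"bytes"
    	"io"
    	"io/ioutil"

    	"github.com/docker/docker/pkg/progress"
    )

    func main() {
    	payload := bytes.Repeat([]byte("x"), 4*1024*1024) // stand-in for a layer
    	src := ioutil.NopCloser(bytes.NewReader(payload))

    	progressChan := make(chan progress.Progress, 100)
    	go func() {
    		for range progressChan {
    			// A real consumer would render these; here we just drain them.
    		}
    	}()

    	// Every Read advances Current and periodically emits a
    	// Progress{ID, Action, Current, Total} on the channel.
    	pr := progress.NewProgressReader(src, progress.ChanOutput(progressChan), int64(len(payload)), "layerid", "Extracting")

    	if _, err := io.Copy(ioutil.Discard, pr); err != nil {
    		pr.Close()
    		panic(err)
    	}
    	pr.Close()
    	close(progressChan)
    }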
new file mode 100644
... ...
@@ -0,0 +1,75 @@
0
+package progress
1
+
2
+import (
3
+	"bytes"
4
+	"io"
5
+	"io/ioutil"
6
+	"testing"
7
+)
8
+
9
+func TestOutputOnPrematureClose(t *testing.T) {
10
+	content := []byte("TESTING")
11
+	reader := ioutil.NopCloser(bytes.NewReader(content))
12
+	progressChan := make(chan Progress, 10)
13
+
14
+	pr := NewProgressReader(reader, ChanOutput(progressChan), int64(len(content)), "Test", "Read")
15
+
16
+	part := make([]byte, 4, 4)
17
+	_, err := io.ReadFull(pr, part)
18
+	if err != nil {
19
+		pr.Close()
20
+		t.Fatal(err)
21
+	}
22
+
23
+drainLoop:
24
+	for {
25
+		select {
26
+		case <-progressChan:
27
+		default:
28
+			break drainLoop
29
+		}
30
+	}
31
+
32
+	pr.Close()
33
+
34
+	select {
35
+	case <-progressChan:
36
+	default:
37
+		t.Fatalf("Expected some output when closing prematurely")
38
+	}
39
+}
40
+
41
+func TestCompleteSilently(t *testing.T) {
42
+	content := []byte("TESTING")
43
+	reader := ioutil.NopCloser(bytes.NewReader(content))
44
+	progressChan := make(chan Progress, 10)
45
+
46
+	pr := NewProgressReader(reader, ChanOutput(progressChan), int64(len(content)), "Test", "Read")
47
+
48
+	out, err := ioutil.ReadAll(pr)
49
+	if err != nil {
50
+		pr.Close()
51
+		t.Fatal(err)
52
+	}
53
+	if string(out) != "TESTING" {
54
+		pr.Close()
55
+		t.Fatalf("Unexpected output %q from reader", string(out))
56
+	}
57
+
58
+drainLoop:
59
+	for {
60
+		select {
61
+		case <-progressChan:
62
+		default:
63
+			break drainLoop
64
+		}
65
+	}
66
+
67
+	pr.Close()
68
+
69
+	select {
70
+	case <-progressChan:
71
+		t.Fatalf("Should have closed silently when read is complete")
72
+	default:
73
+	}
74
+}
0 75
deleted file mode 100644
... ...
@@ -1,68 +0,0 @@
1
-// Package progressreader provides a Reader with a progress bar that can be
2
-// printed out using the streamformatter package.
3
-package progressreader
4
-
5
-import (
6
-	"io"
7
-
8
-	"github.com/docker/docker/pkg/jsonmessage"
9
-	"github.com/docker/docker/pkg/streamformatter"
10
-)
11
-
12
-// Config contains the configuration for a Reader with progress bar.
13
-type Config struct {
14
-	In         io.ReadCloser // Stream to read from
15
-	Out        io.Writer     // Where to send progress bar to
16
-	Formatter  *streamformatter.StreamFormatter
17
-	Size       int64
18
-	Current    int64
19
-	LastUpdate int64
20
-	NewLines   bool
21
-	ID         string
22
-	Action     string
23
-}
24
-
25
-// New creates a new Config.
26
-func New(newReader Config) *Config {
27
-	return &newReader
28
-}
29
-
30
-func (config *Config) Read(p []byte) (n int, err error) {
31
-	read, err := config.In.Read(p)
32
-	config.Current += int64(read)
33
-	updateEvery := int64(1024 * 512) //512kB
34
-	if config.Size > 0 {
35
-		// Update progress for every 1% read if 1% < 512kB
36
-		if increment := int64(0.01 * float64(config.Size)); increment < updateEvery {
37
-			updateEvery = increment
38
-		}
39
-	}
40
-	if config.Current-config.LastUpdate > updateEvery || err != nil {
41
-		updateProgress(config)
42
-		config.LastUpdate = config.Current
43
-	}
44
-
45
-	if err != nil && read == 0 {
46
-		updateProgress(config)
47
-		if config.NewLines {
48
-			config.Out.Write(config.Formatter.FormatStatus("", ""))
49
-		}
50
-	}
51
-	return read, err
52
-}
53
-
54
-// Close closes the reader (Config).
55
-func (config *Config) Close() error {
56
-	if config.Current < config.Size {
57
-		//print a full progress bar when closing prematurely
58
-		config.Current = config.Size
59
-		updateProgress(config)
60
-	}
61
-	return config.In.Close()
62
-}
63
-
64
-func updateProgress(config *Config) {
65
-	progress := jsonmessage.JSONProgress{Current: config.Current, Total: config.Size}
66
-	fmtMessage := config.Formatter.FormatProgress(config.ID, config.Action, &progress)
67
-	config.Out.Write(fmtMessage)
68
-}
69 1
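For code that used the deleted package, the migration appears mechanical: the Out/Formatter/NewLines fields collapse into a progress.Output built by the stream formatter, and the remaining fields become arguments to progress.NewProgressReader. A hedged sketch of the new shape; the helper name and the "Uploading" action are illustrative, not taken from this commit.

    package example

    import (
    	"io"

    	"github.com/docker/docker/pkg/progress"
    	"github.com/docker/docker/pkg/streamformatter"
    )

    // wrapWithProgress replaces the old progressreader.Config pattern: build a
    // progress.Output from the stream formatter once, then hand it to
    // progress.NewProgressReader along with the values that used to be struct
    // fields (Size, ID, Action).
    func wrapWithProgress(rc io.ReadCloser, w io.Writer, size int64) io.ReadCloser {
    	progressOutput := streamformatter.NewStreamFormatter().NewProgressOutput(w, true)
    	return progress.NewProgressReader(rc, progressOutput, size, "", "Uploading")
    }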
deleted file mode 100644
... ...
@@ -1,94 +0,0 @@
1
-package progressreader
2
-
3
-import (
4
-	"bufio"
5
-	"bytes"
6
-	"io"
7
-	"io/ioutil"
8
-	"testing"
9
-
10
-	"github.com/docker/docker/pkg/streamformatter"
11
-)
12
-
13
-func TestOutputOnPrematureClose(t *testing.T) {
14
-	var outBuf bytes.Buffer
15
-	content := []byte("TESTING")
16
-	reader := ioutil.NopCloser(bytes.NewReader(content))
17
-	writer := bufio.NewWriter(&outBuf)
18
-
19
-	prCfg := Config{
20
-		In:        reader,
21
-		Out:       writer,
22
-		Formatter: streamformatter.NewStreamFormatter(),
23
-		Size:      int64(len(content)),
24
-		NewLines:  true,
25
-		ID:        "Test",
26
-		Action:    "Read",
27
-	}
28
-	pr := New(prCfg)
29
-
30
-	part := make([]byte, 4, 4)
31
-	_, err := io.ReadFull(pr, part)
32
-	if err != nil {
33
-		pr.Close()
34
-		t.Fatal(err)
35
-	}
36
-
37
-	if err := writer.Flush(); err != nil {
38
-		pr.Close()
39
-		t.Fatal(err)
40
-	}
41
-
42
-	tlen := outBuf.Len()
43
-	pr.Close()
44
-	if err := writer.Flush(); err != nil {
45
-		t.Fatal(err)
46
-	}
47
-
48
-	if outBuf.Len() == tlen {
49
-		t.Fatalf("Expected some output when closing prematurely")
50
-	}
51
-}
52
-
53
-func TestCompleteSilently(t *testing.T) {
54
-	var outBuf bytes.Buffer
55
-	content := []byte("TESTING")
56
-	reader := ioutil.NopCloser(bytes.NewReader(content))
57
-	writer := bufio.NewWriter(&outBuf)
58
-
59
-	prCfg := Config{
60
-		In:        reader,
61
-		Out:       writer,
62
-		Formatter: streamformatter.NewStreamFormatter(),
63
-		Size:      int64(len(content)),
64
-		NewLines:  true,
65
-		ID:        "Test",
66
-		Action:    "Read",
67
-	}
68
-	pr := New(prCfg)
69
-
70
-	out, err := ioutil.ReadAll(pr)
71
-	if err != nil {
72
-		pr.Close()
73
-		t.Fatal(err)
74
-	}
75
-	if string(out) != "TESTING" {
76
-		pr.Close()
77
-		t.Fatalf("Unexpected output %q from reader", string(out))
78
-	}
79
-
80
-	if err := writer.Flush(); err != nil {
81
-		pr.Close()
82
-		t.Fatal(err)
83
-	}
84
-
85
-	tlen := outBuf.Len()
86
-	pr.Close()
87
-	if err := writer.Flush(); err != nil {
88
-		t.Fatal(err)
89
-	}
90
-
91
-	if outBuf.Len() > tlen {
92
-		t.Fatalf("Should have closed silently when read is complete")
93
-	}
94
-}
... ...
@@ -7,6 +7,7 @@ import (
7 7
 	"io"
8 8
 
9 9
 	"github.com/docker/docker/pkg/jsonmessage"
10
+	"github.com/docker/docker/pkg/progress"
10 11
 )
11 12
 
12 13
 // StreamFormatter formats a stream, optionally using JSON.
... ...
@@ -92,6 +93,44 @@ func (sf *StreamFormatter) FormatProgress(id, action string, progress *jsonmessa
92 92
 	return []byte(action + " " + progress.String() + endl)
93 93
 }
94 94
 
95
+// NewProgressOutput returns a progress.Output object that can be passed to
96
+// progress.NewProgressReader.
97
+func (sf *StreamFormatter) NewProgressOutput(out io.Writer, newLines bool) progress.Output {
98
+	return &progressOutput{
99
+		sf:       sf,
100
+		out:      out,
101
+		newLines: newLines,
102
+	}
103
+}
104
+
105
+type progressOutput struct {
106
+	sf       *StreamFormatter
107
+	out      io.Writer
108
+	newLines bool
109
+}
110
+
111
+// WriteProgress formats the given progress information and writes it to the wrapped stream.
112
+func (out *progressOutput) WriteProgress(prog progress.Progress) error {
113
+	var formatted []byte
114
+	if prog.Message != "" {
115
+		formatted = out.sf.FormatStatus(prog.ID, prog.Message)
116
+	} else {
117
+		jsonProgress := jsonmessage.JSONProgress{Current: prog.Current, Total: prog.Total}
118
+		formatted = out.sf.FormatProgress(prog.ID, prog.Action, &jsonProgress)
119
+	}
120
+	_, err := out.out.Write(formatted)
121
+	if err != nil {
122
+		return err
123
+	}
124
+
125
+	if out.newLines && prog.LastUpdate {
126
+		_, err = out.out.Write(out.sf.FormatStatus("", ""))
127
+		return err
128
+	}
129
+
130
+	return nil
131
+}
132
+
95 133
 // StdoutFormatter is a streamFormatter that writes to the standard output.
96 134
 type StdoutFormatter struct {
97 135
 	io.Writer
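progressOutput is the adapter between the engine-agnostic progress package and the existing stream formatter: a Progress carrying a Message is rendered through FormatStatus, one carrying an Action goes through FormatProgress, and when newLines is set the LastUpdate flag appends a blank status line. A short sketch of driving it directly; the layer ID is a placeholder.

    package main

    import (
    	"os"

    	"github.com/docker/docker/pkg/progress"
    	"github.com/docker/docker/pkg/streamformatter"
    )

    func main() {
    	// Build a progress.Output that formats updates and writes them to
    	// stdout; true asks for a newline after the final update.
    	out := streamformatter.NewStreamFormatter().NewProgressOutput(os.Stdout, true)

    	// A plain status line (Message path).
    	progress.Message(out, "deadbeef0001", "Pulling fs layer")

    	// A progress bar (Action path); LastUpdate on the final record
    	// triggers the trailing blank status line.
    	for cur := int64(0); cur <= 1024; cur += 256 {
    		out.WriteProgress(progress.Progress{
    			ID:         "deadbeef0001",
    			Action:     "Downloading",
    			Current:    cur,
    			Total:      1024,
    			LastUpdate: cur == 1024,
    		})
    	}
    }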
... ...
@@ -17,7 +17,6 @@ import (
17 17
 	"net/url"
18 18
 	"strconv"
19 19
 	"strings"
20
-	"time"
21 20
 
22 21
 	"github.com/Sirupsen/logrus"
23 22
 	"github.com/docker/distribution/reference"
... ...
@@ -270,7 +269,6 @@ func (r *Session) GetRemoteImageJSON(imgID, registry string) ([]byte, int64, err
270 270
 // GetRemoteImageLayer retrieves an image layer from the registry
271 271
 func (r *Session) GetRemoteImageLayer(imgID, registry string, imgSize int64) (io.ReadCloser, error) {
272 272
 	var (
273
-		retries    = 5
274 273
 		statusCode = 0
275 274
 		res        *http.Response
276 275
 		err        error
... ...
@@ -281,14 +279,9 @@ func (r *Session) GetRemoteImageLayer(imgID, registry string, imgSize int64) (io
281 281
 	if err != nil {
282 282
 		return nil, fmt.Errorf("Error while getting from the server: %v", err)
283 283
 	}
284
-	// TODO(tiborvass): why are we doing retries at this level?
285
-	// These retries should be generic to both v1 and v2
286
-	for i := 1; i <= retries; i++ {
287
-		statusCode = 0
288
-		res, err = r.client.Do(req)
289
-		if err == nil {
290
-			break
291
-		}
284
+	statusCode = 0
285
+	res, err = r.client.Do(req)
286
+	if err != nil {
292 287
 		logrus.Debugf("Error contacting registry %s: %v", registry, err)
293 288
 		if res != nil {
294 289
 			if res.Body != nil {
... ...
@@ -296,11 +289,8 @@ func (r *Session) GetRemoteImageLayer(imgID, registry string, imgSize int64) (io
296 296
 			}
297 297
 			statusCode = res.StatusCode
298 298
 		}
299
-		if i == retries {
300
-			return nil, fmt.Errorf("Server error: Status %d while fetching image layer (%s)",
301
-				statusCode, imgID)
302
-		}
303
-		time.Sleep(time.Duration(i) * 5 * time.Second)
299
+		return nil, fmt.Errorf("Server error: Status %d while fetching image layer (%s)",
300
+			statusCode, imgID)
304 301
 	}
305 302
 
306 303
 	if res.StatusCode != 200 {