
Fix race condition when waiting for a concurrent layer pull

Before, this only waited for the download to complete. There was no
guarantee that the layer had been registered in the graph and was ready
use. This is especially problematic with v2 pulls, which wait for all
downloads before extracting layers.

Change Broadcaster to allow an error value to be propagated from Close
to the waiters.
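
A minimal sketch of the resulting pattern, using a trimmed-down broadcaster
rather than the real one in pkg/progressreader (names are illustrative only):

package main

import (
	"errors"
	"fmt"
	"sync"
)

// broadcaster is a stand-in for progressreader.Broadcaster, reduced to the
// parts relevant to error propagation.
type broadcaster struct {
	sync.Mutex
	c        chan struct{}
	isClosed bool
	result   error // set by CloseWithError, returned by Wait
}

func newBroadcaster() *broadcaster {
	return &broadcaster{c: make(chan struct{})}
}

// CloseWithError records the operation's result and unblocks all waiters.
func (b *broadcaster) CloseWithError(result error) {
	b.Lock()
	defer b.Unlock()
	if b.isClosed {
		return
	}
	b.isClosed = true
	b.result = result
	close(b.c)
}

// Wait blocks until the operation finishes and returns its result.
func (b *broadcaster) Wait() error {
	<-b.c
	return b.result
}

func main() {
	b := newBroadcaster()
	var wg sync.WaitGroup
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			// A goroutine that joined a concurrent pull now sees its error.
			fmt.Printf("waiter %d got: %v\n", i, b.Wait())
		}(i)
	}
	b.CloseWithError(errors.New("layer extraction failed"))
	wg.Wait()
}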

Make the wait stop when the extraction is finished, rather than just the
download.
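
As a sketch of how an error travels from a failed pull back to the waiters,
the pullers use a named error return plus a deferred closure; the helper name
poolRemoveWithError matches the diff below, the rest is illustrative:

package main

import (
	"errors"
	"fmt"
)

// poolRemoveWithError stands in for the TagStore method of the same name:
// it closes the pool entry's broadcaster with the given result.
func poolRemoveWithError(kind, key string, result error) {
	fmt.Printf("closing %s/%s with result: %v\n", kind, key, result)
}

// pullLayer uses a named return value so the deferred closure observes the
// error the function actually returns, not the value err had when the
// 'defer' statement was evaluated.
func pullLayer(id string) (err error) {
	poolKey := "layer:" + id
	defer func() {
		poolRemoveWithError("pull", poolKey, err)
	}()

	// ... download and extract the layer ...
	return errors.New("extraction failed")
}

func main() {
	_ = pullLayer("abc123")
}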

This also fixes v2 layer downloads to prefix the pool key with "layer:"
instead of "img:". "img:" is the wrong prefix, because this is what v1
uses for entire images. A v1 pull waiting for one of these operations to
finish would only wait for that particular layer, not all its
dependencies.

Signed-off-by: Aaron Lehmann <aaron.lehmann@docker.com>

Aaron Lehmann authored on 2015/08/26 06:23:52
Showing 6 changed files
... ...
@@ -106,13 +106,14 @@ func (s *TagStore) recursiveLoad(address, tmpImageDir string) error {
 		}
 
 		// ensure no two downloads of the same layer happen at the same time
-		if ps, found := s.poolAdd("pull", "layer:"+img.ID); found {
+		poolKey := "layer:" + img.ID
+		broadcaster, found := s.poolAdd("pull", poolKey)
+		if found {
 			logrus.Debugf("Image (id: %s) load is already running, waiting", img.ID)
-			ps.Wait()
-			return nil
+			return broadcaster.Wait()
 		}
 
-		defer s.poolRemove("pull", "layer:"+img.ID)
+		defer s.poolRemove("pull", poolKey)
 
 		if img.Parent != "" {
 			if !s.graph.Exists(img.Parent) {
... ...
@@ -138,16 +138,14 @@ func (p *v1Puller) pullRepository(askedTag string) error {
 			}
 
 			// ensure no two downloads of the same image happen at the same time
-			broadcaster, found := p.poolAdd("pull", "img:"+img.ID)
+			poolKey := "img:" + img.ID
+			broadcaster, found := p.poolAdd("pull", poolKey)
+			broadcaster.Add(out)
 			if found {
-				broadcaster.Add(out)
-				broadcaster.Wait()
-				out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), "Download complete", nil))
-				errors <- nil
+				errors <- broadcaster.Wait()
 				return
 			}
-			broadcaster.Add(out)
-			defer p.poolRemove("pull", "img:"+img.ID)
+			defer p.poolRemove("pull", poolKey)
 
 			// we need to retain it until tagging
 			p.graph.Retain(sessionID, img.ID)
... ...
@@ -188,6 +186,7 @@ func (p *v1Puller) pullRepository(askedTag string) error {
 				err := fmt.Errorf("Error pulling image (%s) from %s, %v", img.Tag, p.repoInfo.CanonicalName, lastErr)
 				broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), err.Error(), nil))
 				errors <- err
+				broadcaster.CloseWithError(err)
 				return
 			}
 			broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), "Download complete", nil))
... ...
@@ -225,8 +224,9 @@ func (p *v1Puller) pullRepository(askedTag string) error {
 	return nil
 }
 
-func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []string) (bool, error) {
-	history, err := p.session.GetRemoteHistory(imgID, endpoint)
+func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []string) (layersDownloaded bool, err error) {
+	var history []string
+	history, err = p.session.GetRemoteHistory(imgID, endpoint)
 	if err != nil {
 		return false, err
 	}
... ...
@@ -239,20 +239,28 @@ func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []stri
 	p.graph.Retain(sessionID, history[1:]...)
 	defer p.graph.Release(sessionID, history[1:]...)
 
-	layersDownloaded := false
+	layersDownloaded = false
 	for i := len(history) - 1; i >= 0; i-- {
 		id := history[i]
 
 		// ensure no two downloads of the same layer happen at the same time
-		broadcaster, found := p.poolAdd("pull", "layer:"+id)
+		poolKey := "layer:" + id
+		broadcaster, found := p.poolAdd("pull", poolKey)
+		broadcaster.Add(out)
 		if found {
 			logrus.Debugf("Image (id: %s) pull is already running, skipping", id)
-			broadcaster.Add(out)
-			broadcaster.Wait()
-		} else {
-			broadcaster.Add(out)
+			err = broadcaster.Wait()
+			if err != nil {
+				return layersDownloaded, err
+			}
+			continue
 		}
-		defer p.poolRemove("pull", "layer:"+id)
+
+		// This must use a closure so it captures the value of err when
+		// the function returns, not when the 'defer' is evaluated.
+		defer func() {
+			p.poolRemoveWithError("pull", poolKey, err)
+		}()
 
 		if !p.graph.Exists(id) {
 			broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Pulling metadata", nil))
... ...
@@ -328,6 +336,7 @@ func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []stri
 			}
 		}
 		broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Download complete", nil))
+		broadcaster.Close()
 	}
 	return layersDownloaded, nil
 }
... ...
@@ -74,14 +74,17 @@ func (p *v2Puller) pullV2Repository(tag string) (err error) {
 	}
 
 	broadcaster, found := p.poolAdd("pull", taggedName)
+	broadcaster.Add(p.config.OutStream)
 	if found {
 		// Another pull of the same repository is already taking place; just wait for it to finish
-		broadcaster.Add(p.config.OutStream)
-		broadcaster.Wait()
-		return nil
+		return broadcaster.Wait()
 	}
-	defer p.poolRemove("pull", taggedName)
-	broadcaster.Add(p.config.OutStream)
+
+	// This must use a closure so it captures the value of err when the
+	// function returns, not when the 'defer' is evaluated.
+	defer func() {
+		p.poolRemoveWithError("pull", taggedName, err)
+	}()
 
 	var layersDownloaded bool
 	for _, tag := range tags {
... ...
@@ -101,13 +104,15 @@ func (p *v2Puller) pullV2Repository(tag string) (err error) {
 
 // downloadInfo is used to pass information from download to extractor
 type downloadInfo struct {
-	img     *image.Image
-	tmpFile *os.File
-	digest  digest.Digest
-	layer   distribution.ReadSeekCloser
-	size    int64
-	err     chan error
-	out     io.Writer // Download progress is written here.
+	img         *image.Image
+	tmpFile     *os.File
+	digest      digest.Digest
+	layer       distribution.ReadSeekCloser
+	size        int64
+	err         chan error
+	out         io.Writer // Download progress is written here.
+	poolKey     string
+	broadcaster *progressreader.Broadcaster
 }
 
 type errVerification struct{}
... ...
@@ -117,19 +122,15 @@ func (errVerification) Error() string { return "verification failed" }
 func (p *v2Puller) download(di *downloadInfo) {
 	logrus.Debugf("pulling blob %q to %s", di.digest, di.img.ID)
 
-	out := di.out
-
-	broadcaster, found := p.poolAdd("pull", "img:"+di.img.ID)
+	di.poolKey = "layer:" + di.img.ID
+	broadcaster, found := p.poolAdd("pull", di.poolKey)
+	broadcaster.Add(di.out)
+	di.broadcaster = broadcaster
 	if found {
-		broadcaster.Add(out)
-		broadcaster.Wait()
-		out.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Download complete", nil))
 		di.err <- nil
 		return
 	}
 
-	broadcaster.Add(out)
-	defer p.poolRemove("pull", "img:"+di.img.ID)
 	tmpFile, err := ioutil.TempFile("", "GetImageBlob")
 	if err != nil {
 		di.err <- err
... ...
@@ -279,6 +280,7 @@ func (p *v2Puller) pullV2Tag(out io.Writer, tag, taggedName string) (verified bo
 	// run clean for all downloads to prevent leftovers
 	for _, d := range downloads {
 		defer func(d *downloadInfo) {
+			p.poolRemoveWithError("pull", d.poolKey, err)
 			if d.tmpFile != nil {
 				d.tmpFile.Close()
 				if err := os.RemoveAll(d.tmpFile.Name()); err != nil {
... ...
@@ -293,14 +295,21 @@ func (p *v2Puller) pullV2Tag(out io.Writer, tag, taggedName string) (verified bo
 		if err := <-d.err; err != nil {
 			return false, err
 		}
+
 		if d.layer == nil {
+			// Wait for a different pull to download and extract
+			// this layer.
+			err = d.broadcaster.Wait()
+			if err != nil {
+				return false, err
+			}
 			continue
 		}
-		// if tmpFile is empty assume download and extracted elsewhere
+
 		d.tmpFile.Seek(0, 0)
 		reader := progressreader.New(progressreader.Config{
 			In:        d.tmpFile,
-			Out:       out,
+			Out:       d.broadcaster,
 			Formatter: p.sf,
 			Size:      d.size,
 			NewLines:  false,
... ...
@@ -317,8 +326,8 @@ func (p *v2Puller) pullV2Tag(out io.Writer, tag, taggedName string) (verified bo
 			return false, err
 		}
 
-		// FIXME: Pool release here for parallel tag pull (ensures any downloads block until fully extracted)
-		out.Write(p.sf.FormatProgress(stringid.TruncateID(d.img.ID), "Pull complete", nil))
+		d.broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(d.img.ID), "Pull complete", nil))
+		d.broadcaster.Close()
 		tagUpdated = true
 	}
 
... ...
@@ -462,18 +462,18 @@ func (store *TagStore) poolAdd(kind, key string) (*progressreader.Broadcaster, b
 	return broadcaster, false
 }
 
-func (store *TagStore) poolRemove(kind, key string) error {
+func (store *TagStore) poolRemoveWithError(kind, key string, broadcasterResult error) error {
 	store.Lock()
 	defer store.Unlock()
 	switch kind {
 	case "pull":
-		if ps, exists := store.pullingPool[key]; exists {
-			ps.Close()
+		if broadcaster, exists := store.pullingPool[key]; exists {
+			broadcaster.CloseWithError(broadcasterResult)
 			delete(store.pullingPool, key)
 		}
 	case "push":
-		if ps, exists := store.pushingPool[key]; exists {
-			ps.Close()
+		if broadcaster, exists := store.pushingPool[key]; exists {
+			broadcaster.CloseWithError(broadcasterResult)
 			delete(store.pushingPool, key)
 		}
 	default:
... ...
@@ -481,3 +481,7 @@ func (store *TagStore) poolRemove(kind, key string) error {
 	}
 	return nil
 }
+
+func (store *TagStore) poolRemove(kind, key string) error {
+	return store.poolRemoveWithError(kind, key, nil)
+}
... ...
@@ -2,6 +2,8 @@ package main
 
 import (
 	"fmt"
+	"os/exec"
+	"strings"
 
 	"github.com/go-check/check"
 )
... ...
@@ -37,3 +39,134 @@ func (s *DockerRegistrySuite) TestPullImageWithAliases(c *check.C) {
 		}
 	}
 }
+
+// TestConcurrentPullWholeRepo pulls the same repo concurrently.
+func (s *DockerRegistrySuite) TestConcurrentPullWholeRepo(c *check.C) {
+	repoName := fmt.Sprintf("%v/dockercli/busybox", privateRegistryURL)
+
+	repos := []string{}
+	for _, tag := range []string{"recent", "fresh", "todays"} {
+		repo := fmt.Sprintf("%v:%v", repoName, tag)
+		_, err := buildImage(repo, fmt.Sprintf(`
+		    FROM busybox
+		    ENTRYPOINT ["/bin/echo"]
+		    ENV FOO foo
+		    ENV BAR bar
+		    CMD echo %s
+		`, repo), true)
+		if err != nil {
+			c.Fatal(err)
+		}
+		dockerCmd(c, "push", repo)
+		repos = append(repos, repo)
+	}
+
+	// Clear local images store.
+	args := append([]string{"rmi"}, repos...)
+	dockerCmd(c, args...)
+
+	// Run multiple re-pulls concurrently
+	results := make(chan error)
+	numPulls := 3
+
+	for i := 0; i != numPulls; i++ {
+		go func() {
+			_, _, err := runCommandWithOutput(exec.Command(dockerBinary, "pull", "-a", repoName))
+			results <- err
+		}()
+	}
+
+	// These checks are separate from the loop above because the check
+	// package is not goroutine-safe.
+	for i := 0; i != numPulls; i++ {
+		err := <-results
+		c.Assert(err, check.IsNil, check.Commentf("concurrent pull failed with error: %v", err))
+	}
+
+	// Ensure all tags were pulled successfully
+	for _, repo := range repos {
+		dockerCmd(c, "inspect", repo)
+		out, _ := dockerCmd(c, "run", "--rm", repo)
+		if strings.TrimSpace(out) != "/bin/sh -c echo "+repo {
+			c.Fatalf("CMD did not contain /bin/sh -c echo %s: %s", repo, out)
+		}
+	}
+}
+
+// TestConcurrentFailingPull tries a concurrent pull that doesn't succeed.
+func (s *DockerRegistrySuite) TestConcurrentFailingPull(c *check.C) {
+	repoName := fmt.Sprintf("%v/dockercli/busybox", privateRegistryURL)
+
+	// Run multiple pulls concurrently
+	results := make(chan error)
+	numPulls := 3
+
+	for i := 0; i != numPulls; i++ {
+		go func() {
+			_, _, err := runCommandWithOutput(exec.Command(dockerBinary, "pull", repoName+":asdfasdf"))
+			results <- err
+		}()
+	}
+
+	// These checks are separate from the loop above because the check
+	// package is not goroutine-safe.
+	for i := 0; i != numPulls; i++ {
+		err := <-results
+		if err == nil {
+			c.Fatal("expected pull to fail")
+		}
+	}
+}
+
+// TestConcurrentPullMultipleTags pulls multiple tags from the same repo
+// concurrently.
+func (s *DockerRegistrySuite) TestConcurrentPullMultipleTags(c *check.C) {
+	repoName := fmt.Sprintf("%v/dockercli/busybox", privateRegistryURL)
+
+	repos := []string{}
+	for _, tag := range []string{"recent", "fresh", "todays"} {
+		repo := fmt.Sprintf("%v:%v", repoName, tag)
+		_, err := buildImage(repo, fmt.Sprintf(`
+		    FROM busybox
+		    ENTRYPOINT ["/bin/echo"]
+		    ENV FOO foo
+		    ENV BAR bar
+		    CMD echo %s
+		`, repo), true)
+		if err != nil {
+			c.Fatal(err)
+		}
+		dockerCmd(c, "push", repo)
+		repos = append(repos, repo)
+	}
+
+	// Clear local images store.
+	args := append([]string{"rmi"}, repos...)
+	dockerCmd(c, args...)
+
+	// Re-pull individual tags, in parallel
+	results := make(chan error)
+
+	for _, repo := range repos {
+		go func(repo string) {
+			_, _, err := runCommandWithOutput(exec.Command(dockerBinary, "pull", repo))
+			results <- err
+		}(repo)
+	}
+
+	// These checks are separate from the loop above because the check
+	// package is not goroutine-safe.
+	for range repos {
+		err := <-results
+		c.Assert(err, check.IsNil, check.Commentf("concurrent pull failed with error: %v", err))
+	}
+
+	// Ensure all tags were pulled successfully
+	for _, repo := range repos {
+		dockerCmd(c, "inspect", repo)
+		out, _ := dockerCmd(c, "run", "--rm", repo)
+		if strings.TrimSpace(out) != "/bin/sh -c echo "+repo {
+			c.Fatalf("CMD did not contain /bin/sh -c echo %s: %s", repo, out)
+		}
+	}
+}
... ...
@@ -27,6 +27,9 @@ type Broadcaster struct {
 	// isClosed is set to true when Close is called to avoid closing c
 	// multiple times.
 	isClosed bool
+	// result is the argument passed to the first call of Close, and
+	// returned to callers of Wait
+	result error
 }
 
 // NewBroadcaster returns a Broadcaster structure
... ...
@@ -134,23 +137,33 @@ func (broadcaster *Broadcaster) Add(w io.Writer) error {
 	return nil
 }
 
-// Close signals to all observers that the operation has finished.
-func (broadcaster *Broadcaster) Close() {
+// CloseWithError signals to all observers that the operation has finished. Its
+// argument is a result that should be returned to waiters blocking on Wait.
+func (broadcaster *Broadcaster) CloseWithError(result error) {
 	broadcaster.Lock()
 	if broadcaster.isClosed {
 		broadcaster.Unlock()
 		return
 	}
 	broadcaster.isClosed = true
+	broadcaster.result = result
 	close(broadcaster.c)
 	broadcaster.cond.Broadcast()
 	broadcaster.Unlock()
 
-	// Don't return from Close until all writers have caught up.
+	// Don't return until all writers have caught up.
 	broadcaster.wg.Wait()
 }
 
+// Close signals to all observers that the operation has finished. It causes
+// all calls to Wait to return nil.
+func (broadcaster *Broadcaster) Close() {
+	broadcaster.CloseWithError(nil)
+}
+
 // Wait blocks until the operation is marked as completed by the Done method.
-func (broadcaster *Broadcaster) Wait() {
+// It returns the argument that was passed to Close.
+func (broadcaster *Broadcaster) Wait() error {
 	<-broadcaster.c
+	return broadcaster.result
 }