Browse code

resume pulling the layer on disconnect

Docker-DCO-1.1-Signed-off-by: Cristian Staretu <cristian.staretu@gmail.com> (github: unclejack)

unclejack authored on 2014/03/26 09:33:17
Showing 4 changed files
... ...
@@ -256,12 +256,43 @@ func (r *Registry) GetRemoteImageJSON(imgID, registry string, token []string) ([
256 256
 	return jsonString, imageSize, nil
257 257
 }
258 258
 
259
-func (r *Registry) GetRemoteImageLayer(imgID, registry string, token []string) (io.ReadCloser, error) {
260
-	req, err := r.reqFactory.NewRequest("GET", registry+"images/"+imgID+"/layer", nil)
259
+func (r *Registry) GetRemoteImageLayer(imgID, registry string, token []string, imgSize int64) (io.ReadCloser, error) {
260
+	var (
261
+		retries   = 5
262
+		headRes   *http.Response
263
+		hasResume bool = false
264
+		imageURL       = fmt.Sprintf("%simages/%s/layer", registry, imgID)
265
+	)
266
+	headReq, err := r.reqFactory.NewRequest("HEAD", imageURL, nil)
267
+	if err != nil {
268
+		return nil, fmt.Errorf("Error while getting from the server: %s\n", err)
269
+	}
270
+	setTokenAuth(headReq, token)
271
+	for i := 1; i <= retries; i++ {
272
+		headRes, err = r.client.Do(headReq)
273
+		if err != nil && i == retries {
274
+			return nil, fmt.Errorf("Eror while making head request: %s\n", err)
275
+		} else if err != nil {
276
+			time.Sleep(time.Duration(i) * 5 * time.Second)
277
+			continue
278
+		}
279
+		break
280
+	}
281
+
282
+	if headRes.Header.Get("Accept-Ranges") == "bytes" && imgSize > 0 {
283
+		hasResume = true
284
+	}
285
+
286
+	req, err := r.reqFactory.NewRequest("GET", imageURL, nil)
261 287
 	if err != nil {
262 288
 		return nil, fmt.Errorf("Error while getting from the server: %s\n", err)
263 289
 	}
264 290
 	setTokenAuth(req, token)
291
+	if hasResume {
292
+		utils.Debugf("server supports resume")
293
+		return utils.ResumableRequestReader(r.client, req, 5, imgSize), nil
294
+	}
295
+	utils.Debugf("server doesn't support resume")
265 296
 	res, err := r.client.Do(req)
266 297
 	if err != nil {
267 298
 		return nil, err
... ...
@@ -725,6 +756,13 @@ type Registry struct {
725 725
 	indexEndpoint string
726 726
 }
727 727
 
728
+func AddRequiredHeadersToRedirectedRequests(req *http.Request, via []*http.Request) error {
729
+	if via != nil && via[0] != nil {
730
+		req.Header = via[0].Header
731
+	}
732
+	return nil
733
+}
734
+
728 735
 func NewRegistry(authConfig *AuthConfig, factory *utils.HTTPRequestFactory, indexEndpoint string) (r *Registry, err error) {
729 736
 	httpDial := func(proto string, addr string) (net.Conn, error) {
730 737
 		conn, err := net.Dial(proto, addr)
... ...
@@ -744,7 +782,8 @@ func NewRegistry(authConfig *AuthConfig, factory *utils.HTTPRequestFactory, inde
744 744
 	r = &Registry{
745 745
 		authConfig: authConfig,
746 746
 		client: &http.Client{
747
-			Transport: httpTransport,
747
+			Transport:     httpTransport,
748
+			CheckRedirect: AddRequiredHeadersToRedirectedRequests,
748 749
 		},
749 750
 		indexEndpoint: indexEndpoint,
750 751
 	}
... ...
@@ -70,7 +70,7 @@ func TestGetRemoteImageJSON(t *testing.T) {
70 70
 
71 71
 func TestGetRemoteImageLayer(t *testing.T) {
72 72
 	r := spawnTestRegistry(t)
73
-	data, err := r.GetRemoteImageLayer(IMAGE_ID, makeURL("/v1/"), TOKEN)
73
+	data, err := r.GetRemoteImageLayer(IMAGE_ID, makeURL("/v1/"), TOKEN, 0)
74 74
 	if err != nil {
75 75
 		t.Fatal(err)
76 76
 	}
... ...
@@ -78,7 +78,7 @@ func TestGetRemoteImageLayer(t *testing.T) {
78 78
 		t.Fatal("Expected non-nil data result")
79 79
 	}
80 80
 
81
-	_, err = r.GetRemoteImageLayer("abcdef", makeURL("/v1/"), TOKEN)
81
+	_, err = r.GetRemoteImageLayer("abcdef", makeURL("/v1/"), TOKEN, 0)
82 82
 	if err == nil {
83 83
 		t.Fatal("Expected image not found error")
84 84
 	}
... ...
@@ -1137,7 +1137,7 @@ func (srv *Server) pullImage(r *registry.Registry, out io.Writer, imgID, endpoin
1137 1137
 					status = fmt.Sprintf("Pulling fs layer [retries: %d]", j)
1138 1138
 				}
1139 1139
 				out.Write(sf.FormatProgress(utils.TruncateID(id), status, nil))
1140
-				layer, err := r.GetRemoteImageLayer(img.ID, endpoint, token)
1140
+				layer, err := r.GetRemoteImageLayer(img.ID, endpoint, token, int64(imgSize))
1141 1141
 				if uerr, ok := err.(*url.Error); ok {
1142 1142
 					err = uerr.Err
1143 1143
 				}
1144 1144
new file mode 100644
... ...
@@ -0,0 +1,87 @@
0
+package utils
1
+
2
+import (
3
+	"fmt"
4
+	"io"
5
+	"net/http"
6
+	"time"
7
+)
8
+
9
+type resumableRequestReader struct {
10
+	client          *http.Client
11
+	request         *http.Request
12
+	lastRange       int64
13
+	totalSize       int64
14
+	currentResponse *http.Response
15
+	failures        uint32
16
+	maxFailures     uint32
17
+}
18
+
19
+// ResumableRequestReader makes it possible to resume reading a request's body transparently
20
+// maxfail is the number of times we retry to make requests again (not resumes)
21
+// totalsize is the total length of the body; auto detect if not provided
22
+func ResumableRequestReader(c *http.Client, r *http.Request, maxfail uint32, totalsize int64) io.ReadCloser {
23
+	return &resumableRequestReader{client: c, request: r, maxFailures: maxfail, totalSize: totalsize}
24
+}
25
+
26
+func (r *resumableRequestReader) Read(p []byte) (n int, err error) {
27
+	if r.client == nil || r.request == nil {
28
+		return 0, fmt.Errorf("client and request can't be nil\n")
29
+	}
30
+	isFreshRequest := false
31
+	if r.lastRange != 0 && r.currentResponse == nil {
32
+		readRange := fmt.Sprintf("bytes=%d-%d", r.lastRange, r.totalSize)
33
+		r.request.Header.Set("Range", readRange)
34
+		time.Sleep(5 * time.Second)
35
+	}
36
+	if r.currentResponse == nil {
37
+		r.currentResponse, err = r.client.Do(r.request)
38
+		isFreshRequest = true
39
+	}
40
+	if err != nil && r.failures+1 != r.maxFailures {
41
+		r.cleanUpResponse()
42
+		r.failures += 1
43
+		time.Sleep(5 * time.Duration(r.failures) * time.Second)
44
+		return 0, nil
45
+	} else if err != nil {
46
+		r.cleanUpResponse()
47
+		return 0, err
48
+	}
49
+	if r.currentResponse.StatusCode == 416 && r.lastRange == r.totalSize && r.currentResponse.ContentLength == 0 {
50
+		r.cleanUpResponse()
51
+		return 0, io.EOF
52
+	} else if r.currentResponse.StatusCode != 206 && r.lastRange != 0 && isFreshRequest {
53
+		r.cleanUpResponse()
54
+		return 0, fmt.Errorf("the server doesn't support byte ranges")
55
+	}
56
+	if r.totalSize == 0 {
57
+		r.totalSize = r.currentResponse.ContentLength
58
+	} else if r.totalSize <= 0 {
59
+		r.cleanUpResponse()
60
+		return 0, fmt.Errorf("failed to auto detect content length")
61
+	}
62
+	n, err = r.currentResponse.Body.Read(p)
63
+	r.lastRange += int64(n)
64
+	if err != nil {
65
+		r.cleanUpResponse()
66
+	}
67
+	if err != nil && err != io.EOF {
68
+		Debugf("encountered error during pull and clearing it before resume: %s", err)
69
+		err = nil
70
+	}
71
+	return n, err
72
+}
73
+
74
+func (r *resumableRequestReader) Close() error {
75
+	r.cleanUpResponse()
76
+	r.client = nil
77
+	r.request = nil
78
+	return nil
79
+}
80
+
81
+func (r *resumableRequestReader) cleanUpResponse() {
82
+	if r.currentResponse != nil {
83
+		r.currentResponse.Body.Close()
84
+		r.currentResponse = nil
85
+	}
86
+}