Browse Source

resume pulling the layer on disconnect

Docker-DCO-1.1-Signed-off-by: Cristian Staretu <cristian.staretu@gmail.com> (github: unclejack)
unclejack 11 years ago
parent
commit
2a1b7f222a
4 changed files with 132 additions and 6 deletions
  1. 42 3
      registry/registry.go
  2. 2 2
      registry/registry_test.go
  3. 1 1
      server/server.go
  4. 87 0
      utils/resumablerequestreader.go

+ 42 - 3
registry/registry.go

@@ -256,12 +256,43 @@ func (r *Registry) GetRemoteImageJSON(imgID, registry string, token []string) ([
 	return jsonString, imageSize, nil
 }
 
-func (r *Registry) GetRemoteImageLayer(imgID, registry string, token []string) (io.ReadCloser, error) {
-	req, err := r.reqFactory.NewRequest("GET", registry+"images/"+imgID+"/layer", nil)
+func (r *Registry) GetRemoteImageLayer(imgID, registry string, token []string, imgSize int64) (io.ReadCloser, error) {
+	var (
+		retries   = 5
+		headRes   *http.Response
+		hasResume bool = false
+		imageURL       = fmt.Sprintf("%simages/%s/layer", registry, imgID)
+	)
+	headReq, err := r.reqFactory.NewRequest("HEAD", imageURL, nil)
+	if err != nil {
+		return nil, fmt.Errorf("Error while getting from the server: %s\n", err)
+	}
+	setTokenAuth(headReq, token)
+	for i := 1; i <= retries; i++ {
+		headRes, err = r.client.Do(headReq)
+		if err != nil && i == retries {
+			return nil, fmt.Errorf("Eror while making head request: %s\n", err)
+		} else if err != nil {
+			time.Sleep(time.Duration(i) * 5 * time.Second)
+			continue
+		}
+		break
+	}
+
+	if headRes.Header.Get("Accept-Ranges") == "bytes" && imgSize > 0 {
+		hasResume = true
+	}
+
+	req, err := r.reqFactory.NewRequest("GET", imageURL, nil)
 	if err != nil {
 		return nil, fmt.Errorf("Error while getting from the server: %s\n", err)
 	}
 	setTokenAuth(req, token)
+	if hasResume {
+		utils.Debugf("server supports resume")
+		return utils.ResumableRequestReader(r.client, req, 5, imgSize), nil
+	}
+	utils.Debugf("server doesn't support resume")
 	res, err := r.client.Do(req)
 	if err != nil {
 		return nil, err
@@ -725,6 +756,13 @@ type Registry struct {
 	indexEndpoint string
 }
 
+func AddRequiredHeadersToRedirectedRequests(req *http.Request, via []*http.Request) error {
+	if via != nil && via[0] != nil {
+		req.Header = via[0].Header
+	}
+	return nil
+}
+
 func NewRegistry(authConfig *AuthConfig, factory *utils.HTTPRequestFactory, indexEndpoint string) (r *Registry, err error) {
 	httpDial := func(proto string, addr string) (net.Conn, error) {
 		conn, err := net.Dial(proto, addr)
@@ -744,7 +782,8 @@ func NewRegistry(authConfig *AuthConfig, factory *utils.HTTPRequestFactory, inde
 	r = &Registry{
 		authConfig: authConfig,
 		client: &http.Client{
-			Transport: httpTransport,
+			Transport:     httpTransport,
+			CheckRedirect: AddRequiredHeadersToRedirectedRequests,
 		},
 		indexEndpoint: indexEndpoint,
 	}

+ 2 - 2
registry/registry_test.go

@@ -70,7 +70,7 @@ func TestGetRemoteImageJSON(t *testing.T) {
 
 func TestGetRemoteImageLayer(t *testing.T) {
 	r := spawnTestRegistry(t)
-	data, err := r.GetRemoteImageLayer(IMAGE_ID, makeURL("/v1/"), TOKEN)
+	data, err := r.GetRemoteImageLayer(IMAGE_ID, makeURL("/v1/"), TOKEN, 0)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -78,7 +78,7 @@ func TestGetRemoteImageLayer(t *testing.T) {
 		t.Fatal("Expected non-nil data result")
 	}
 
-	_, err = r.GetRemoteImageLayer("abcdef", makeURL("/v1/"), TOKEN)
+	_, err = r.GetRemoteImageLayer("abcdef", makeURL("/v1/"), TOKEN, 0)
 	if err == nil {
 		t.Fatal("Expected image not found error")
 	}

+ 1 - 1
server/server.go

@@ -1137,7 +1137,7 @@ func (srv *Server) pullImage(r *registry.Registry, out io.Writer, imgID, endpoin
 					status = fmt.Sprintf("Pulling fs layer [retries: %d]", j)
 				}
 				out.Write(sf.FormatProgress(utils.TruncateID(id), status, nil))
-				layer, err := r.GetRemoteImageLayer(img.ID, endpoint, token)
+				layer, err := r.GetRemoteImageLayer(img.ID, endpoint, token, int64(imgSize))
 				if uerr, ok := err.(*url.Error); ok {
 					err = uerr.Err
 				}

+ 87 - 0
utils/resumablerequestreader.go

@@ -0,0 +1,87 @@
+package utils
+
+import (
+	"fmt"
+	"io"
+	"net/http"
+	"time"
+)
+
+type resumableRequestReader struct {
+	client          *http.Client
+	request         *http.Request
+	lastRange       int64
+	totalSize       int64
+	currentResponse *http.Response
+	failures        uint32
+	maxFailures     uint32
+}
+
+// ResumableRequestReader makes it possible to resume reading a request's body transparently
+// maxfail is the number of times we retry to make requests again (not resumes)
+// totalsize is the total length of the body; auto detect if not provided
+func ResumableRequestReader(c *http.Client, r *http.Request, maxfail uint32, totalsize int64) io.ReadCloser {
+	return &resumableRequestReader{client: c, request: r, maxFailures: maxfail, totalSize: totalsize}
+}
+
+func (r *resumableRequestReader) Read(p []byte) (n int, err error) {
+	if r.client == nil || r.request == nil {
+		return 0, fmt.Errorf("client and request can't be nil\n")
+	}
+	isFreshRequest := false
+	if r.lastRange != 0 && r.currentResponse == nil {
+		readRange := fmt.Sprintf("bytes=%d-%d", r.lastRange, r.totalSize)
+		r.request.Header.Set("Range", readRange)
+		time.Sleep(5 * time.Second)
+	}
+	if r.currentResponse == nil {
+		r.currentResponse, err = r.client.Do(r.request)
+		isFreshRequest = true
+	}
+	if err != nil && r.failures+1 != r.maxFailures {
+		r.cleanUpResponse()
+		r.failures += 1
+		time.Sleep(5 * time.Duration(r.failures) * time.Second)
+		return 0, nil
+	} else if err != nil {
+		r.cleanUpResponse()
+		return 0, err
+	}
+	if r.currentResponse.StatusCode == 416 && r.lastRange == r.totalSize && r.currentResponse.ContentLength == 0 {
+		r.cleanUpResponse()
+		return 0, io.EOF
+	} else if r.currentResponse.StatusCode != 206 && r.lastRange != 0 && isFreshRequest {
+		r.cleanUpResponse()
+		return 0, fmt.Errorf("the server doesn't support byte ranges")
+	}
+	if r.totalSize == 0 {
+		r.totalSize = r.currentResponse.ContentLength
+	} else if r.totalSize <= 0 {
+		r.cleanUpResponse()
+		return 0, fmt.Errorf("failed to auto detect content length")
+	}
+	n, err = r.currentResponse.Body.Read(p)
+	r.lastRange += int64(n)
+	if err != nil {
+		r.cleanUpResponse()
+	}
+	if err != nil && err != io.EOF {
+		Debugf("encountered error during pull and clearing it before resume: %s", err)
+		err = nil
+	}
+	return n, err
+}
+
+func (r *resumableRequestReader) Close() error {
+	r.cleanUpResponse()
+	r.client = nil
+	r.request = nil
+	return nil
+}
+
+func (r *resumableRequestReader) cleanUpResponse() {
+	if r.currentResponse != nil {
+		r.currentResponse.Body.Close()
+		r.currentResponse = nil
+	}
+}