Selaa lähdekoodia

Move httputils/reasumablerequestreader to the single consumer.

Signed-off-by: Daniel Nephin <dnephin@docker.com>
Daniel Nephin 8 vuotta sitten
vanhempi
commit
65515af075

+ 12 - 12
pkg/httputils/resumablerequestreader.go → registry/resumable/resumablerequestreader.go

@@ -1,4 +1,4 @@
-package httputils
+package resumable
 
 import (
 	"fmt"
@@ -9,7 +9,7 @@ import (
 	"github.com/Sirupsen/logrus"
 )
 
-type resumableRequestReader struct {
+type requestReader struct {
 	client          *http.Client
 	request         *http.Request
 	lastRange       int64
@@ -20,22 +20,22 @@ type resumableRequestReader struct {
 	waitDuration    time.Duration
 }
 
-// ResumableRequestReader makes it possible to resume reading a request's body transparently
+// NewRequestReader makes it possible to resume reading a request's body transparently
 // maxfail is the number of times we retry to make requests again (not resumes)
 // totalsize is the total length of the body; auto detect if not provided
-func ResumableRequestReader(c *http.Client, r *http.Request, maxfail uint32, totalsize int64) io.ReadCloser {
-	return &resumableRequestReader{client: c, request: r, maxFailures: maxfail, totalSize: totalsize, waitDuration: 5 * time.Second}
+func NewRequestReader(c *http.Client, r *http.Request, maxfail uint32, totalsize int64) io.ReadCloser {
+	return &requestReader{client: c, request: r, maxFailures: maxfail, totalSize: totalsize, waitDuration: 5 * time.Second}
 }
 
-// ResumableRequestReaderWithInitialResponse makes it possible to resume
+// NewRequestReaderWithInitialResponse makes it possible to resume
 // reading the body of an already initiated request.
-func ResumableRequestReaderWithInitialResponse(c *http.Client, r *http.Request, maxfail uint32, totalsize int64, initialResponse *http.Response) io.ReadCloser {
-	return &resumableRequestReader{client: c, request: r, maxFailures: maxfail, totalSize: totalsize, currentResponse: initialResponse, waitDuration: 5 * time.Second}
+func NewRequestReaderWithInitialResponse(c *http.Client, r *http.Request, maxfail uint32, totalsize int64, initialResponse *http.Response) io.ReadCloser {
+	return &requestReader{client: c, request: r, maxFailures: maxfail, totalSize: totalsize, currentResponse: initialResponse, waitDuration: 5 * time.Second}
 }
 
-func (r *resumableRequestReader) Read(p []byte) (n int, err error) {
+func (r *requestReader) Read(p []byte) (n int, err error) {
 	if r.client == nil || r.request == nil {
-		return 0, fmt.Errorf("client and request can't be nil\n")
+		return 0, fmt.Errorf("client and request can't be nil")
 	}
 	isFreshRequest := false
 	if r.lastRange != 0 && r.currentResponse == nil {
@@ -81,14 +81,14 @@ func (r *resumableRequestReader) Read(p []byte) (n int, err error) {
 	return n, err
 }
 
-func (r *resumableRequestReader) Close() error {
+func (r *requestReader) Close() error {
 	r.cleanUpResponse()
 	r.client = nil
 	r.request = nil
 	return nil
 }
 
-func (r *resumableRequestReader) cleanUpResponse() {
+func (r *requestReader) cleanUpResponse() {
 	if r.currentResponse != nil {
 		r.currentResponse.Body.Close()
 		r.currentResponse = nil

+ 39 - 93
pkg/httputils/resumablerequestreader_test.go → registry/resumable/resumablerequestreader_test.go

@@ -1,7 +1,9 @@
-package httputils
+package resumable
 
 import (
 	"fmt"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 	"io"
 	"io/ioutil"
 	"net/http"
@@ -21,28 +23,19 @@ func TestResumableRequestHeaderSimpleErrors(t *testing.T) {
 
 	var req *http.Request
 	req, err := http.NewRequest("GET", ts.URL, nil)
-	if err != nil {
-		t.Fatal(err)
-	}
+	require.NoError(t, err)
 
-	expectedError := "client and request can't be nil\n"
-	resreq := &resumableRequestReader{}
+	resreq := &requestReader{}
 	_, err = resreq.Read([]byte{})
-	if err == nil || err.Error() != expectedError {
-		t.Fatalf("Expected an error with '%s', got %v.", expectedError, err)
-	}
+	assert.EqualError(t, err, "client and request can't be nil")
 
-	resreq = &resumableRequestReader{
+	resreq = &requestReader{
 		client:    client,
 		request:   req,
 		totalSize: -1,
 	}
-	expectedError = "failed to auto detect content length"
 	_, err = resreq.Read([]byte{})
-	if err == nil || err.Error() != expectedError {
-		t.Fatalf("Expected an error with '%s', got %v.", expectedError, err)
-	}
-
+	assert.EqualError(t, err, "failed to auto detect content length")
 }
 
 // Not too much failures, bails out after some wait
@@ -51,11 +44,9 @@ func TestResumableRequestHeaderNotTooMuchFailures(t *testing.T) {
 
 	var badReq *http.Request
 	badReq, err := http.NewRequest("GET", "I'm not an url", nil)
-	if err != nil {
-		t.Fatal(err)
-	}
+	require.NoError(t, err)
 
-	resreq := &resumableRequestReader{
+	resreq := &requestReader{
 		client:       client,
 		request:      badReq,
 		failures:     0,
@@ -63,9 +54,8 @@ func TestResumableRequestHeaderNotTooMuchFailures(t *testing.T) {
 		waitDuration: 10 * time.Millisecond,
 	}
 	read, err := resreq.Read([]byte{})
-	if err != nil || read != 0 {
-		t.Fatalf("Expected no error and no byte read, got err:%v, read:%v.", err, read)
-	}
+	require.NoError(t, err)
+	assert.Equal(t, 0, read)
 }
 
 // Too much failures, returns the error
@@ -74,11 +64,9 @@ func TestResumableRequestHeaderTooMuchFailures(t *testing.T) {
 
 	var badReq *http.Request
 	badReq, err := http.NewRequest("GET", "I'm not an url", nil)
-	if err != nil {
-		t.Fatal(err)
-	}
+	require.NoError(t, err)
 
-	resreq := &resumableRequestReader{
+	resreq := &requestReader{
 		client:      client,
 		request:     badReq,
 		failures:    0,
@@ -88,9 +76,8 @@ func TestResumableRequestHeaderTooMuchFailures(t *testing.T) {
 
 	expectedError := `Get I%27m%20not%20an%20url: unsupported protocol scheme ""`
 	read, err := resreq.Read([]byte{})
-	if err == nil || err.Error() != expectedError || read != 0 {
-		t.Fatalf("Expected the error '%s', got err:%v, read:%v.", expectedError, err, read)
-	}
+	assert.EqualError(t, err, expectedError)
+	assert.Equal(t, 0, read)
 }
 
 type errorReaderCloser struct{}
@@ -105,9 +92,7 @@ func (errorReaderCloser) Read(p []byte) (n int, err error) {
 func TestResumableRequestReaderWithReadError(t *testing.T) {
 	var req *http.Request
 	req, err := http.NewRequest("GET", "", nil)
-	if err != nil {
-		t.Fatal(err)
-	}
+	require.NoError(t, err)
 
 	client := &http.Client{}
 
@@ -119,7 +104,7 @@ func TestResumableRequestReaderWithReadError(t *testing.T) {
 		Body:          errorReaderCloser{},
 	}
 
-	resreq := &resumableRequestReader{
+	resreq := &requestReader{
 		client:          client,
 		request:         req,
 		currentResponse: response,
@@ -130,21 +115,15 @@ func TestResumableRequestReaderWithReadError(t *testing.T) {
 
 	buf := make([]byte, 1)
 	read, err := resreq.Read(buf)
-	if err != nil {
-		t.Fatal(err)
-	}
+	require.NoError(t, err)
 
-	if read != 0 {
-		t.Fatalf("Expected to have read nothing, but read %v", read)
-	}
+	assert.Equal(t, 0, read)
 }
 
 func TestResumableRequestReaderWithEOFWith416Response(t *testing.T) {
 	var req *http.Request
 	req, err := http.NewRequest("GET", "", nil)
-	if err != nil {
-		t.Fatal(err)
-	}
+	require.NoError(t, err)
 
 	client := &http.Client{}
 
@@ -156,7 +135,7 @@ func TestResumableRequestReaderWithEOFWith416Response(t *testing.T) {
 		Body:          ioutil.NopCloser(strings.NewReader("")),
 	}
 
-	resreq := &resumableRequestReader{
+	resreq := &requestReader{
 		client:          client,
 		request:         req,
 		currentResponse: response,
@@ -167,9 +146,7 @@ func TestResumableRequestReaderWithEOFWith416Response(t *testing.T) {
 
 	buf := make([]byte, 1)
 	_, err = resreq.Read(buf)
-	if err == nil || err != io.EOF {
-		t.Fatalf("Expected an io.EOF error, got %v", err)
-	}
+	assert.EqualError(t, err, io.EOF.Error())
 }
 
 func TestResumableRequestReaderWithServerDoesntSupportByteRanges(t *testing.T) {
@@ -182,29 +159,23 @@ func TestResumableRequestReaderWithServerDoesntSupportByteRanges(t *testing.T) {
 
 	var req *http.Request
 	req, err := http.NewRequest("GET", ts.URL, nil)
-	if err != nil {
-		t.Fatal(err)
-	}
+	require.NoError(t, err)
 
 	client := &http.Client{}
 
-	resreq := &resumableRequestReader{
+	resreq := &requestReader{
 		client:    client,
 		request:   req,
 		lastRange: 1,
 	}
 	defer resreq.Close()
 
-	expectedError := "the server doesn't support byte ranges"
 	buf := make([]byte, 2)
 	_, err = resreq.Read(buf)
-	if err == nil || err.Error() != expectedError {
-		t.Fatalf("Expected an error '%s', got %v", expectedError, err)
-	}
+	assert.EqualError(t, err, "the server doesn't support byte ranges")
 }
 
 func TestResumableRequestReaderWithZeroTotalSize(t *testing.T) {
-
 	srvtxt := "some response text data"
 
 	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -214,30 +185,22 @@ func TestResumableRequestReaderWithZeroTotalSize(t *testing.T) {
 
 	var req *http.Request
 	req, err := http.NewRequest("GET", ts.URL, nil)
-	if err != nil {
-		t.Fatal(err)
-	}
+	require.NoError(t, err)
 
 	client := &http.Client{}
 	retries := uint32(5)
 
-	resreq := ResumableRequestReader(client, req, retries, 0)
+	resreq := NewRequestReader(client, req, retries, 0)
 	defer resreq.Close()
 
 	data, err := ioutil.ReadAll(resreq)
-	if err != nil {
-		t.Fatal(err)
-	}
+	require.NoError(t, err)
 
 	resstr := strings.TrimSuffix(string(data), "\n")
-
-	if resstr != srvtxt {
-		t.Error("resstr != srvtxt")
-	}
+	assert.Equal(t, srvtxt, resstr)
 }
 
 func TestResumableRequestReader(t *testing.T) {
-
 	srvtxt := "some response text data"
 
 	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -247,31 +210,23 @@ func TestResumableRequestReader(t *testing.T) {
 
 	var req *http.Request
 	req, err := http.NewRequest("GET", ts.URL, nil)
-	if err != nil {
-		t.Fatal(err)
-	}
+	require.NoError(t, err)
 
 	client := &http.Client{}
 	retries := uint32(5)
 	imgSize := int64(len(srvtxt))
 
-	resreq := ResumableRequestReader(client, req, retries, imgSize)
+	resreq := NewRequestReader(client, req, retries, imgSize)
 	defer resreq.Close()
 
 	data, err := ioutil.ReadAll(resreq)
-	if err != nil {
-		t.Fatal(err)
-	}
+	require.NoError(t, err)
 
 	resstr := strings.TrimSuffix(string(data), "\n")
-
-	if resstr != srvtxt {
-		t.Error("resstr != srvtxt")
-	}
+	assert.Equal(t, srvtxt, resstr)
 }
 
 func TestResumableRequestReaderWithInitialResponse(t *testing.T) {
-
 	srvtxt := "some response text data"
 
 	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -281,30 +236,21 @@ func TestResumableRequestReaderWithInitialResponse(t *testing.T) {
 
 	var req *http.Request
 	req, err := http.NewRequest("GET", ts.URL, nil)
-	if err != nil {
-		t.Fatal(err)
-	}
+	require.NoError(t, err)
 
 	client := &http.Client{}
 	retries := uint32(5)
 	imgSize := int64(len(srvtxt))
 
 	res, err := client.Do(req)
-	if err != nil {
-		t.Fatal(err)
-	}
+	require.NoError(t, err)
 
-	resreq := ResumableRequestReaderWithInitialResponse(client, req, retries, imgSize, res)
+	resreq := NewRequestReaderWithInitialResponse(client, req, retries, imgSize, res)
 	defer resreq.Close()
 
 	data, err := ioutil.ReadAll(resreq)
-	if err != nil {
-		t.Fatal(err)
-	}
+	require.NoError(t, err)
 
 	resstr := strings.TrimSuffix(string(data), "\n")
-
-	if resstr != srvtxt {
-		t.Error("resstr != srvtxt")
-	}
+	assert.Equal(t, srvtxt, resstr)
 }

+ 2 - 1
registry/session.go

@@ -27,6 +27,7 @@ import (
 	"github.com/docker/docker/pkg/ioutils"
 	"github.com/docker/docker/pkg/stringid"
 	"github.com/docker/docker/pkg/tarsum"
+	"github.com/docker/docker/registry/resumable"
 )
 
 var (
@@ -313,7 +314,7 @@ func (r *Session) GetRemoteImageLayer(imgID, registry string, imgSize int64) (io
 
 	if res.Header.Get("Accept-Ranges") == "bytes" && imgSize > 0 {
 		logrus.Debug("server supports resume")
-		return httputils.ResumableRequestReaderWithInitialResponse(r.client, req, 5, imgSize, res), nil
+		return resumable.NewRequestReaderWithInitialResponse(r.client, req, 5, imgSize, res), nil
 	}
 	logrus.Debug("server doesn't support resume")
 	return res.Body, nil