Kaynağa Gözat

Merge pull request #43291 from pete-woods/retry-image-schema-download

distribution: retry downloading schema config on retryable error
Sebastiaan van Stijn 3 yıl önce
ebeveyn
işleme
18e20d3f37

+ 36 - 1
distribution/pull_v2.go

@@ -7,6 +7,7 @@ import (
 	"io"
 	"os"
 	"runtime"
+	"time"
 
 	"github.com/containerd/containerd/log"
 	"github.com/containerd/containerd/platforms"
@@ -856,9 +857,17 @@ func (p *v2Puller) pullManifestList(ctx context.Context, ref reference.Named, mf
 	return id, manifestListDigest, err
 }
 
+const (
+	defaultSchemaPullBackoff     = 250 * time.Millisecond // delay before the first retry of the config download; retry doubles it after each wait
+	defaultMaxSchemaPullAttempts = 5                      // upper bound on config download attempts, counting the first one
+)
+
 func (p *v2Puller) pullSchema2Config(ctx context.Context, dgst digest.Digest) (configJSON []byte, err error) {
 	blobs := p.repo.Blobs(ctx)
-	configJSON, err = blobs.Get(ctx, dgst)
+	err = retry(ctx, defaultMaxSchemaPullAttempts, defaultSchemaPullBackoff, func(ctx context.Context) (err error) {
+		configJSON, err = blobs.Get(ctx, dgst)
+		return err
+	})
 	if err != nil {
 		return nil, err
 	}
@@ -877,6 +886,32 @@ func (p *v2Puller) pullSchema2Config(ctx context.Context, dgst digest.Digest) (c
 	return configJSON, nil
 }
 
+// retry calls f until it succeeds, fails with an error that must not be
+// retried, or has been called maxAttempts times. Between two attempts it waits
+// for sleep and doubles that delay after every wait.
+//
+// The error returned by f is passed through retryOnError, and the loop stops as
+// soon as xfer.IsDoNotRetryError reports true for the result. If ctx is done
+// while waiting, ctx.Err() is returned unwrapped. Any other failure is wrapped
+// with the number of attempts that were actually made.
+//
+// NOTE(review): with maxAttempts <= 0, f is never called and the result is
+// whatever errors.Wrapf yields for a nil error (presumably nil, assuming this
+// file uses github.com/pkg/errors — confirm).
+func retry(ctx context.Context, maxAttempts int, sleep time.Duration, f func(ctx context.Context) error) (err error) {
+	// attempt is the number of calls to f made so far (1-based once the loop runs),
+	// so it can be reported as-is on every exit path.
+	attempt := 0
+	for attempt < maxAttempts {
+		attempt++
+		err = retryOnError(f(ctx))
+		if err == nil {
+			return nil
+		}
+		// Stop without sleeping when retrying is pointless or the budget is spent.
+		if xfer.IsDoNotRetryError(err) || attempt == maxAttempts {
+			break
+		}
+
+		timer := time.NewTimer(sleep)
+		select {
+		case <-ctx.Done():
+			timer.Stop()
+			return ctx.Err()
+		case <-timer.C:
+			logrus.WithError(err).WithField("attempts", attempt).Debug("retrying after error")
+			sleep *= 2
+		}
+	}
+	return errors.Wrapf(err, "download failed after attempts=%d", attempt)
+}
+
 // schema2ManifestDigest computes the manifest digest, and, if pulling by
 // digest, ensures that it matches the requested digest.
 func schema2ManifestDigest(ref reference.Named, mfst distribution.Manifest) (digest.Digest, error) {

+ 166 - 0
distribution/pull_v2_test.go

@@ -1,17 +1,26 @@
 package distribution // import "github.com/docker/docker/distribution"
 
 import (
+	"context"
 	"encoding/json"
 	"fmt"
+	"net/http"
+	"net/http/httptest"
+	"net/url"
 	"os"
 	"reflect"
 	"regexp"
 	"runtime"
 	"strings"
+	"sync/atomic"
 	"testing"
 
 	"github.com/docker/distribution/manifest/schema1"
 	"github.com/docker/distribution/reference"
+	"github.com/docker/docker/api/types"
+	registrytypes "github.com/docker/docker/api/types/registry"
+	"github.com/docker/docker/image"
+	"github.com/docker/docker/registry"
 	"github.com/opencontainers/go-digest"
 	specs "github.com/opencontainers/image-spec/specs-go/v1"
 	"gotest.tools/v3/assert"
@@ -205,3 +214,160 @@ func TestFormatPlatform(t *testing.T) {
 		}
 	}
 }
+
+// TestPullSchema2Config checks how pullSchema2Config reacts to failures of the
+// config blob endpoint: each case serves the blob from a local test server and
+// asserts both the outcome and the number of requests made to that endpoint.
+func TestPullSchema2Config(t *testing.T) {
+	ctx := context.Background()
+
+	// expectedDigest below must stay in sync with the exact bytes of imageJSON.
+	const imageJSON = `{
+	"architecture": "amd64",
+	"os": "linux",
+	"config": {},
+	"rootfs": {
+		"type": "layers",
+		"diff_ids": []
+	}
+}`
+	expectedDigest := digest.Digest("sha256:66ad98165d38f53ee73868f82bd4eed60556ddfee824810a4062c4f777b20a5b")
+
+	tests := []struct {
+		name string
+		// handler answers a request for the config blob; callCount is 1 for the first request.
+		handler func(callCount int, w http.ResponseWriter)
+		// expectError, when non-empty, is a substring the returned error must contain.
+		expectError string
+		// expectAttempts is the number of blob requests the server should have seen.
+		expectAttempts int64
+	}{
+		{
+			name: "success first time",
+			handler: func(callCount int, w http.ResponseWriter) {
+				w.WriteHeader(http.StatusOK)
+				_, _ = w.Write([]byte(imageJSON))
+			},
+			expectAttempts: 1,
+		},
+		{
+			name: "500 status",
+			handler: func(callCount int, w http.ResponseWriter) {
+				if callCount == 1 {
+					w.WriteHeader(http.StatusInternalServerError)
+					return
+				}
+				w.WriteHeader(http.StatusOK)
+				_, _ = w.Write([]byte(imageJSON))
+			},
+			expectAttempts: 2,
+		},
+		{
+			name: "EOF",
+			handler: func(callCount int, w http.ResponseWriter) {
+				if callCount == 1 {
+					// Abort the response so the client sees an interrupted connection;
+					// ErrAbortHandler keeps the server from logging a stack trace.
+					panic(http.ErrAbortHandler)
+				}
+				w.WriteHeader(http.StatusOK)
+				_, _ = w.Write([]byte(imageJSON))
+			},
+			expectAttempts: 2,
+		},
+		{
+			name: "unauthorized",
+			handler: func(callCount int, w http.ResponseWriter) {
+				w.WriteHeader(http.StatusUnauthorized)
+			},
+			expectError:    "unauthorized: authentication required",
+			expectAttempts: 1,
+		},
+	}
+
+	for _, tt := range tests {
+		tt := tt
+		t.Run(tt.name, func(t *testing.T) {
+			// callCount is incremented from the server's handler goroutine.
+			var callCount int64
+			ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+				t.Logf("HTTP %s %s", r.Method, r.URL.Path)
+				defer r.Body.Close()
+				switch {
+				case r.Method == "GET" && r.URL.Path == "/v2":
+					w.WriteHeader(http.StatusOK)
+				case r.Method == "GET" && r.URL.Path == "/v2/docker.io/library/testremotename/blobs/"+expectedDigest.String():
+					tt.handler(int(atomic.AddInt64(&callCount, 1)), w)
+				default:
+					w.WriteHeader(http.StatusNotFound)
+				}
+			}))
+			defer ts.Close()
+
+			p := testNewPuller(t, ts.URL)
+
+			config, err := p.pullSchema2Config(ctx, expectedDigest)
+			if tt.expectError == "" {
+				if err != nil {
+					t.Fatal(err)
+				}
+
+				_, err = image.NewFromJSON(config)
+				if err != nil {
+					t.Fatal(err)
+				}
+			} else {
+				if err == nil {
+					t.Fatalf("expected error to contain %q", tt.expectError)
+				}
+				if !strings.Contains(err.Error(), tt.expectError) {
+					t.Fatalf("expected error=%q to contain %q", err, tt.expectError)
+				}
+			}
+
+			// The handler updates callCount atomically, so read it the same way.
+			if got := atomic.LoadInt64(&callCount); got != tt.expectAttempts {
+				t.Fatalf("got callCount=%d but expected=%d", got, tt.expectAttempts)
+			}
+		})
+	}
+}
+
+// testNewPuller builds a *v2Puller for the repository "testremotename" whose
+// endpoint URL is rawurl, and attaches a repository client created with
+// NewV2Repository. Any setup failure aborts the calling test.
+func testNewPuller(t *testing.T, rawurl string) *v2Puller {
+	t.Helper()
+
+	uri, err := url.Parse(rawurl)
+	if err != nil {
+		t.Fatalf("could not parse url from test server: %v", err)
+	}
+
+	endpoint := registry.APIEndpoint{
+		Mirror:       false,
+		URL:          uri,
+		Version:      2,
+		Official:     false,
+		TrimHostname: false,
+		TLSConfig:    nil,
+	}
+	// Fail loudly instead of continuing with a nil reference if the name is rejected.
+	n, err := reference.ParseNormalizedNamed("testremotename")
+	if err != nil {
+		t.Fatalf("could not parse repository name: %v", err)
+	}
+	repoInfo := &registry.RepositoryInfo{
+		Name: n,
+		Index: &registrytypes.IndexInfo{
+			Name:     "testrepo",
+			Mirrors:  nil,
+			Secure:   false,
+			Official: false,
+		},
+		Official: false,
+	}
+	imagePullConfig := &ImagePullConfig{
+		Config: Config{
+			MetaHeaders: http.Header{},
+			AuthConfig: &types.AuthConfig{
+				RegistryToken: secretRegistryToken,
+			},
+		},
+		Schema2Types: ImageTypes,
+	}
+
+	puller, err := newPuller(endpoint, repoInfo, imagePullConfig, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	p := puller.(*v2Puller)
+
+	// pullSchema2Config reads p.repo, so the repository client is set up here.
+	p.repo, err = NewV2Repository(context.Background(), p.repoInfo, p.endpoint, p.config.MetaHeaders, p.config.AuthConfig, "pull")
+	if err != nil {
+		t.Fatal(err)
+	}
+	return p
+}

+ 8 - 0
distribution/xfer/transfer.go

@@ -6,6 +6,7 @@ import (
 	"sync"
 
 	"github.com/docker/docker/pkg/progress"
+	"github.com/pkg/errors"
 )
 
 // DoNotRetry is an error wrapper indicating that the error cannot be resolved
@@ -19,6 +20,13 @@ func (e DoNotRetry) Error() string {
 	return e.Err.Error()
 }
 
+// IsDoNotRetryError returns true if the error is caused by DoNotRetry error,
+// and the transfer should not be retried.
+func IsDoNotRetryError(err error) bool {
+	var dnr DoNotRetry
+	return errors.As(err, &dnr)
+}
+
 // watcher is returned by Watch and can be passed to Release to stop watching.
 type watcher struct {
 	// signalChan is used to signal to the watcher goroutine that