Procházet zdrojové kódy

Merge pull request #18353 from aaronlehmann/transfer-manager

Improved push and pull with upload manager and download manager
Alexander Morozov před 9 roky
rodič
revize
ac453a310b
41 změnil soubory, kde provedl 2686 přidání a 1129 odebrání
  1. 6 20
      api/client/build.go
  2. 6 12
      api/server/router/local/image.go
  3. 5 11
      builder/dockerfile/internals.go
  4. 70 9
      daemon/daemon.go
  5. 2 4
      daemon/daemonbuilder/builder.go
  6. 3 11
      daemon/import.go
  7. 5 5
      distribution/metadata/v1_id_service.go
  8. 4 4
      distribution/metadata/v1_id_service_test.go
  9. 0 51
      distribution/pool.go
  10. 0 28
      distribution/pool_test.go
  11. 36 21
      distribution/pull.go
  12. 117 228
      distribution/pull_v1.go
  13. 92 205
      distribution/pull_v2.go
  14. 22 14
      distribution/push.go
  15. 20 32
      distribution/push_v1.go
  16. 118 105
      distribution/push_v2.go
  17. 4 4
      distribution/push_v2_test.go
  18. 23 1
      distribution/registry.go
  19. 3 4
      distribution/registry_unit_test.go
  20. 420 0
      distribution/xfer/download.go
  21. 332 0
      distribution/xfer/download_test.go
  22. 343 0
      distribution/xfer/transfer.go
  23. 385 0
      distribution/xfer/transfer_test.go
  24. 159 0
      distribution/xfer/upload.go
  25. 153 0
      distribution/xfer/upload_test.go
  26. 3 0
      docs/reference/api/docker_remote_api.md
  27. 3 0
      docs/reference/api/docker_remote_api_v1.22.md
  28. 3 6
      docs/reference/commandline/build.md
  29. 3 0
      docs/reference/commandline/pull.md
  30. 3 0
      docs/reference/commandline/push.md
  31. 4 10
      integration-cli/docker_cli_pull_test.go
  32. 0 167
      pkg/broadcaster/buffered.go
  33. 71 0
      pkg/ioutils/readers.go
  34. 27 0
      pkg/ioutils/readers_test.go
  35. 63 0
      pkg/progress/progress.go
  36. 59 0
      pkg/progress/progressreader.go
  37. 75 0
      pkg/progress/progressreader_test.go
  38. 0 68
      pkg/progressreader/progressreader.go
  39. 0 94
      pkg/progressreader/progressreader_test.go
  40. 39 0
      pkg/streamformatter/streamformatter.go
  41. 5 15
      registry/session.go

+ 6 - 20
api/client/build.go

@@ -23,7 +23,7 @@ import (
 	"github.com/docker/docker/pkg/httputils"
 	"github.com/docker/docker/pkg/jsonmessage"
 	flag "github.com/docker/docker/pkg/mflag"
-	"github.com/docker/docker/pkg/progressreader"
+	"github.com/docker/docker/pkg/progress"
 	"github.com/docker/docker/pkg/streamformatter"
 	"github.com/docker/docker/pkg/ulimit"
 	"github.com/docker/docker/pkg/units"
@@ -169,16 +169,9 @@ func (cli *DockerCli) CmdBuild(args ...string) error {
 	context = replaceDockerfileTarWrapper(context, newDockerfile, relDockerfile)
 
 	// Setup an upload progress bar
-	// FIXME: ProgressReader shouldn't be this annoying to use
-	sf := streamformatter.NewStreamFormatter()
-	var body io.Reader = progressreader.New(progressreader.Config{
-		In:        context,
-		Out:       cli.out,
-		Formatter: sf,
-		NewLines:  true,
-		ID:        "",
-		Action:    "Sending build context to Docker daemon",
-	})
+	progressOutput := streamformatter.NewStreamFormatter().NewProgressOutput(cli.out, true)
+
+	var body io.Reader = progress.NewProgressReader(context, progressOutput, 0, "", "Sending build context to Docker daemon")
 
 	var memory int64
 	if *flMemoryString != "" {
@@ -447,17 +440,10 @@ func getContextFromURL(out io.Writer, remoteURL, dockerfileName string) (absCont
 		return "", "", fmt.Errorf("unable to download remote context %s: %v", remoteURL, err)
 	}
 	defer response.Body.Close()
+	progressOutput := streamformatter.NewStreamFormatter().NewProgressOutput(out, true)
 
 	// Pass the response body through a progress reader.
-	progReader := &progressreader.Config{
-		In:        response.Body,
-		Out:       out,
-		Formatter: streamformatter.NewStreamFormatter(),
-		Size:      response.ContentLength,
-		NewLines:  true,
-		ID:        "",
-		Action:    fmt.Sprintf("Downloading build context from remote url: %s", remoteURL),
-	}
+	progReader := progress.NewProgressReader(response.Body, progressOutput, response.ContentLength, "", fmt.Sprintf("Downloading build context from remote url: %s", remoteURL))
 
 	return getContextFromReader(progReader, dockerfileName)
 }

+ 6 - 12
api/server/router/local/image.go

@@ -23,7 +23,7 @@ import (
 	"github.com/docker/docker/pkg/archive"
 	"github.com/docker/docker/pkg/chrootarchive"
 	"github.com/docker/docker/pkg/ioutils"
-	"github.com/docker/docker/pkg/progressreader"
+	"github.com/docker/docker/pkg/progress"
 	"github.com/docker/docker/pkg/streamformatter"
 	"github.com/docker/docker/pkg/ulimit"
 	"github.com/docker/docker/runconfig"
@@ -325,7 +325,7 @@ func (s *router) postBuild(ctx context.Context, w http.ResponseWriter, r *http.R
 	sf := streamformatter.NewJSONStreamFormatter()
 	errf := func(err error) error {
 		// Do not write the error in the http output if it's still empty.
-		// This prevents from writing a 200(OK) when there is an interal error.
+		// This prevents from writing a 200(OK) when there is an internal error.
 		if !output.Flushed() {
 			return err
 		}
@@ -401,23 +401,17 @@ func (s *router) postBuild(ctx context.Context, w http.ResponseWriter, r *http.R
 	remoteURL := r.FormValue("remote")
 
 	// Currently, only used if context is from a remote url.
-	// The field `In` is set by DetectContextFromRemoteURL.
 	// Look at code in DetectContextFromRemoteURL for more information.
-	pReader := &progressreader.Config{
-		// TODO: make progressreader streamformatter-agnostic
-		Out:       output,
-		Formatter: sf,
-		Size:      r.ContentLength,
-		NewLines:  true,
-		ID:        "Downloading context",
-		Action:    remoteURL,
+	createProgressReader := func(in io.ReadCloser) io.ReadCloser {
+		progressOutput := sf.NewProgressOutput(output, true)
+		return progress.NewProgressReader(in, progressOutput, r.ContentLength, "Downloading context", remoteURL)
 	}
 
 	var (
 		context        builder.ModifiableContext
 		dockerfileName string
 	)
-	context, dockerfileName, err = daemonbuilder.DetectContextFromRemoteURL(r.Body, remoteURL, pReader)
+	context, dockerfileName, err = daemonbuilder.DetectContextFromRemoteURL(r.Body, remoteURL, createProgressReader)
 	if err != nil {
 		return errf(err)
 	}

+ 5 - 11
builder/dockerfile/internals.go

@@ -29,7 +29,7 @@ import (
 	"github.com/docker/docker/pkg/httputils"
 	"github.com/docker/docker/pkg/ioutils"
 	"github.com/docker/docker/pkg/jsonmessage"
-	"github.com/docker/docker/pkg/progressreader"
+	"github.com/docker/docker/pkg/progress"
 	"github.com/docker/docker/pkg/streamformatter"
 	"github.com/docker/docker/pkg/stringid"
 	"github.com/docker/docker/pkg/stringutils"
@@ -264,17 +264,11 @@ func (b *Builder) download(srcURL string) (fi builder.FileInfo, err error) {
 		return
 	}
 
+	stdoutFormatter := b.Stdout.(*streamformatter.StdoutFormatter)
+	progressOutput := stdoutFormatter.StreamFormatter.NewProgressOutput(stdoutFormatter.Writer, true)
+	progressReader := progress.NewProgressReader(resp.Body, progressOutput, resp.ContentLength, "", "Downloading")
 	// Download and dump result to tmp file
-	if _, err = io.Copy(tmpFile, progressreader.New(progressreader.Config{
-		In: resp.Body,
-		// TODO: make progressreader streamformatter agnostic
-		Out:       b.Stdout.(*streamformatter.StdoutFormatter).Writer,
-		Formatter: b.Stdout.(*streamformatter.StdoutFormatter).StreamFormatter,
-		Size:      resp.ContentLength,
-		NewLines:  true,
-		ID:        "",
-		Action:    "Downloading",
-	})); err != nil {
+	if _, err = io.Copy(tmpFile, progressReader); err != nil {
 		tmpFile.Close()
 		return
 	}

+ 70 - 9
daemon/daemon.go

@@ -34,6 +34,7 @@ import (
 	"github.com/docker/docker/daemon/network"
 	"github.com/docker/docker/distribution"
 	dmetadata "github.com/docker/docker/distribution/metadata"
+	"github.com/docker/docker/distribution/xfer"
 	derr "github.com/docker/docker/errors"
 	"github.com/docker/docker/image"
 	"github.com/docker/docker/image/tarexport"
@@ -49,7 +50,9 @@ import (
 	"github.com/docker/docker/pkg/namesgenerator"
 	"github.com/docker/docker/pkg/nat"
 	"github.com/docker/docker/pkg/parsers/filters"
+	"github.com/docker/docker/pkg/progress"
 	"github.com/docker/docker/pkg/signal"
+	"github.com/docker/docker/pkg/streamformatter"
 	"github.com/docker/docker/pkg/stringid"
 	"github.com/docker/docker/pkg/stringutils"
 	"github.com/docker/docker/pkg/sysinfo"
@@ -66,6 +69,16 @@ import (
 	lntypes "github.com/docker/libnetwork/types"
 	"github.com/docker/libtrust"
 	"github.com/opencontainers/runc/libcontainer"
+	"golang.org/x/net/context"
+)
+
+const (
+	// maxDownloadConcurrency is the maximum number of downloads that
+	// may take place at a time for each pull.
+	maxDownloadConcurrency = 3
+	// maxUploadConcurrency is the maximum number of uploads that
+	// may take place at a time for each push.
+	maxUploadConcurrency = 5
 )
 
 var (
@@ -126,7 +139,8 @@ type Daemon struct {
 	containers                *contStore
 	execCommands              *exec.Store
 	tagStore                  tag.Store
-	distributionPool          *distribution.Pool
+	downloadManager           *xfer.LayerDownloadManager
+	uploadManager             *xfer.LayerUploadManager
 	distributionMetadataStore dmetadata.Store
 	trustKey                  libtrust.PrivateKey
 	idIndex                   *truncindex.TruncIndex
@@ -738,7 +752,8 @@ func NewDaemon(config *Config, registryService *registry.Service) (daemon *Daemo
 		return nil, err
 	}
 
-	distributionPool := distribution.NewPool()
+	d.downloadManager = xfer.NewLayerDownloadManager(d.layerStore, maxDownloadConcurrency)
+	d.uploadManager = xfer.NewLayerUploadManager(maxUploadConcurrency)
 
 	ifs, err := image.NewFSStoreBackend(filepath.Join(imageRoot, "imagedb"))
 	if err != nil {
@@ -834,7 +849,6 @@ func NewDaemon(config *Config, registryService *registry.Service) (daemon *Daemo
 	d.containers = &contStore{s: make(map[string]*container.Container)}
 	d.execCommands = exec.NewStore()
 	d.tagStore = tagStore
-	d.distributionPool = distributionPool
 	d.distributionMetadataStore = distributionMetadataStore
 	d.trustKey = trustKey
 	d.idIndex = truncindex.NewTruncIndex([]string{})
@@ -1038,23 +1052,53 @@ func (daemon *Daemon) TagImage(newTag reference.Named, imageName string) error {
 	return nil
 }
 
+func writeDistributionProgress(cancelFunc func(), outStream io.Writer, progressChan <-chan progress.Progress) {
+	progressOutput := streamformatter.NewJSONStreamFormatter().NewProgressOutput(outStream, false)
+	operationCancelled := false
+
+	for prog := range progressChan {
+		if err := progressOutput.WriteProgress(prog); err != nil && !operationCancelled {
+			logrus.Errorf("error writing progress to client: %v", err)
+			cancelFunc()
+			operationCancelled = true
+			// Don't return, because we need to continue draining
+			// progressChan until it's closed to avoid a deadlock.
+		}
+	}
+}
+
 // PullImage initiates a pull operation. image is the repository name to pull, and
 // tag may be either empty, or indicate a specific tag to pull.
 func (daemon *Daemon) PullImage(ref reference.Named, metaHeaders map[string][]string, authConfig *cliconfig.AuthConfig, outStream io.Writer) error {
+	// Include a buffer so that slow client connections don't affect
+	// transfer performance.
+	progressChan := make(chan progress.Progress, 100)
+
+	writesDone := make(chan struct{})
+
+	ctx, cancelFunc := context.WithCancel(context.Background())
+
+	go func() {
+		writeDistributionProgress(cancelFunc, outStream, progressChan)
+		close(writesDone)
+	}()
+
 	imagePullConfig := &distribution.ImagePullConfig{
 		MetaHeaders:     metaHeaders,
 		AuthConfig:      authConfig,
-		OutStream:       outStream,
+		ProgressOutput:  progress.ChanOutput(progressChan),
 		RegistryService: daemon.RegistryService,
 		EventsService:   daemon.EventsService,
 		MetadataStore:   daemon.distributionMetadataStore,
-		LayerStore:      daemon.layerStore,
 		ImageStore:      daemon.imageStore,
 		TagStore:        daemon.tagStore,
-		Pool:            daemon.distributionPool,
+		DownloadManager: daemon.downloadManager,
 	}
 
-	return distribution.Pull(ref, imagePullConfig)
+	err := distribution.Pull(ctx, ref, imagePullConfig)
+	close(progressChan)
+	<-writesDone
+	return err
 }
 
 // ExportImage exports a list of images to the given output stream. The
@@ -1069,10 +1113,23 @@ func (daemon *Daemon) ExportImage(names []string, outStream io.Writer) error {
 
 // PushImage initiates a push operation on the repository named localName.
 func (daemon *Daemon) PushImage(ref reference.Named, metaHeaders map[string][]string, authConfig *cliconfig.AuthConfig, outStream io.Writer) error {
+	// Include a buffer so that slow client connections don't affect
+	// transfer performance.
+	progressChan := make(chan progress.Progress, 100)
+
+	writesDone := make(chan struct{})
+
+	ctx, cancelFunc := context.WithCancel(context.Background())
+
+	go func() {
+		writeDistributionProgress(cancelFunc, outStream, progressChan)
+		close(writesDone)
+	}()
+
 	imagePushConfig := &distribution.ImagePushConfig{
 		MetaHeaders:     metaHeaders,
 		AuthConfig:      authConfig,
-		OutStream:       outStream,
+		ProgressOutput:  progress.ChanOutput(progressChan),
 		RegistryService: daemon.RegistryService,
 		EventsService:   daemon.EventsService,
 		MetadataStore:   daemon.distributionMetadataStore,
@@ -1080,9 +1137,13 @@ func (daemon *Daemon) PushImage(ref reference.Named, metaHeaders map[string][]st
 		ImageStore:      daemon.imageStore,
 		TagStore:        daemon.tagStore,
 		TrustKey:        daemon.trustKey,
+		UploadManager:   daemon.uploadManager,
 	}
 
-	return distribution.Push(ref, imagePushConfig)
+	err := distribution.Push(ctx, ref, imagePushConfig)
+	close(progressChan)
+	<-writesDone
+	return err
 }
 
 // LookupImage looks up an image by name and returns it as an ImageInspect

+ 2 - 4
daemon/daemonbuilder/builder.go

@@ -21,7 +21,6 @@ import (
 	"github.com/docker/docker/pkg/httputils"
 	"github.com/docker/docker/pkg/idtools"
 	"github.com/docker/docker/pkg/ioutils"
-	"github.com/docker/docker/pkg/progressreader"
 	"github.com/docker/docker/pkg/urlutil"
 	"github.com/docker/docker/registry"
 	"github.com/docker/docker/runconfig"
@@ -239,7 +238,7 @@ func (d Docker) Start(c *container.Container) error {
 // DetectContextFromRemoteURL returns a context and in certain cases the name of the dockerfile to be used
 // irrespective of user input.
 // progressReader is only used if remoteURL is actually a URL (not empty, and not a Git endpoint).
-func DetectContextFromRemoteURL(r io.ReadCloser, remoteURL string, progressReader *progressreader.Config) (context builder.ModifiableContext, dockerfileName string, err error) {
+func DetectContextFromRemoteURL(r io.ReadCloser, remoteURL string, createProgressReader func(in io.ReadCloser) io.ReadCloser) (context builder.ModifiableContext, dockerfileName string, err error) {
 	switch {
 	case remoteURL == "":
 		context, err = builder.MakeTarSumContext(r)
@@ -262,8 +261,7 @@ func DetectContextFromRemoteURL(r io.ReadCloser, remoteURL string, progressReade
 			},
 			// fallback handler (tar context)
 			"": func(rc io.ReadCloser) (io.ReadCloser, error) {
-				progressReader.In = rc
-				return progressReader, nil
+				return createProgressReader(rc), nil
 			},
 		})
 	default:

+ 3 - 11
daemon/import.go

@@ -13,7 +13,7 @@ import (
 	"github.com/docker/docker/image"
 	"github.com/docker/docker/layer"
 	"github.com/docker/docker/pkg/httputils"
-	"github.com/docker/docker/pkg/progressreader"
+	"github.com/docker/docker/pkg/progress"
 	"github.com/docker/docker/pkg/streamformatter"
 	"github.com/docker/docker/runconfig"
 )
@@ -47,16 +47,8 @@ func (daemon *Daemon) ImportImage(src string, newRef reference.Named, msg string
 		if err != nil {
 			return err
 		}
-		progressReader := progressreader.New(progressreader.Config{
-			In:        resp.Body,
-			Out:       outStream,
-			Formatter: sf,
-			Size:      resp.ContentLength,
-			NewLines:  true,
-			ID:        "",
-			Action:    "Importing",
-		})
-		archive = progressReader
+		progressOutput := sf.NewProgressOutput(outStream, true)
+		archive = progress.NewProgressReader(resp.Body, progressOutput, resp.ContentLength, "", "Importing")
 	}
 
 	defer archive.Close()

+ 5 - 5
distribution/metadata/v1_id_service.go

@@ -23,20 +23,20 @@ func (idserv *V1IDService) namespace() string {
 }
 
 // Get finds a layer by its V1 ID.
-func (idserv *V1IDService) Get(v1ID, registry string) (layer.ChainID, error) {
+func (idserv *V1IDService) Get(v1ID, registry string) (layer.DiffID, error) {
 	if err := v1.ValidateID(v1ID); err != nil {
-		return layer.ChainID(""), err
+		return layer.DiffID(""), err
 	}
 
 	idBytes, err := idserv.store.Get(idserv.namespace(), registry+","+v1ID)
 	if err != nil {
-		return layer.ChainID(""), err
+		return layer.DiffID(""), err
 	}
-	return layer.ChainID(idBytes), nil
+	return layer.DiffID(idBytes), nil
 }
 
 // Set associates an image with a V1 ID.
-func (idserv *V1IDService) Set(v1ID, registry string, id layer.ChainID) error {
+func (idserv *V1IDService) Set(v1ID, registry string, id layer.DiffID) error {
 	if err := v1.ValidateID(v1ID); err != nil {
 		return err
 	}

+ 4 - 4
distribution/metadata/v1_id_service_test.go

@@ -24,22 +24,22 @@ func TestV1IDService(t *testing.T) {
 	testVectors := []struct {
 		registry string
 		v1ID     string
-		layerID  layer.ChainID
+		layerID  layer.DiffID
 	}{
 		{
 			registry: "registry1",
 			v1ID:     "f0cd5ca10b07f35512fc2f1cbf9a6cefbdb5cba70ac6b0c9e5988f4497f71937",
-			layerID:  layer.ChainID("sha256:a3ed95caeb02ffe68cdd9fd84406680ae93d633cb16422d00e8a7c22955b46d4"),
+			layerID:  layer.DiffID("sha256:a3ed95caeb02ffe68cdd9fd84406680ae93d633cb16422d00e8a7c22955b46d4"),
 		},
 		{
 			registry: "registry2",
 			v1ID:     "9e3447ca24cb96d86ebd5960cb34d1299b07e0a0e03801d90b9969a2c187dd6e",
-			layerID:  layer.ChainID("sha256:86e0e091d0da6bde2456dbb48306f3956bbeb2eae1b5b9a43045843f69fe4aaa"),
+			layerID:  layer.DiffID("sha256:86e0e091d0da6bde2456dbb48306f3956bbeb2eae1b5b9a43045843f69fe4aaa"),
 		},
 		{
 			registry: "registry1",
 			v1ID:     "9e3447ca24cb96d86ebd5960cb34d1299b07e0a0e03801d90b9969a2c187dd6e",
-			layerID:  layer.ChainID("sha256:03f4658f8b782e12230c1783426bd3bacce651ce582a4ffb6fbbfa2079428ecb"),
+			layerID:  layer.DiffID("sha256:03f4658f8b782e12230c1783426bd3bacce651ce582a4ffb6fbbfa2079428ecb"),
 		},
 	}
 

+ 0 - 51
distribution/pool.go

@@ -1,51 +0,0 @@
-package distribution
-
-import (
-	"sync"
-
-	"github.com/docker/docker/pkg/broadcaster"
-)
-
-// A Pool manages concurrent pulls. It deduplicates in-progress downloads.
-type Pool struct {
-	sync.Mutex
-	pullingPool map[string]*broadcaster.Buffered
-}
-
-// NewPool creates a new Pool.
-func NewPool() *Pool {
-	return &Pool{
-		pullingPool: make(map[string]*broadcaster.Buffered),
-	}
-}
-
-// add checks if a pull is already running, and returns (broadcaster, true)
-// if a running operation is found. Otherwise, it creates a new one and returns
-// (broadcaster, false).
-func (pool *Pool) add(key string) (*broadcaster.Buffered, bool) {
-	pool.Lock()
-	defer pool.Unlock()
-
-	if p, exists := pool.pullingPool[key]; exists {
-		return p, true
-	}
-
-	broadcaster := broadcaster.NewBuffered()
-	pool.pullingPool[key] = broadcaster
-
-	return broadcaster, false
-}
-
-func (pool *Pool) removeWithError(key string, broadcasterResult error) error {
-	pool.Lock()
-	defer pool.Unlock()
-	if broadcaster, exists := pool.pullingPool[key]; exists {
-		broadcaster.CloseWithError(broadcasterResult)
-		delete(pool.pullingPool, key)
-	}
-	return nil
-}
-
-func (pool *Pool) remove(key string) error {
-	return pool.removeWithError(key, nil)
-}

+ 0 - 28
distribution/pool_test.go

@@ -1,28 +0,0 @@
-package distribution
-
-import (
-	"testing"
-)
-
-func TestPools(t *testing.T) {
-	p := NewPool()
-
-	if _, found := p.add("test1"); found {
-		t.Fatal("Expected pull test1 not to be in progress")
-	}
-	if _, found := p.add("test2"); found {
-		t.Fatal("Expected pull test2 not to be in progress")
-	}
-	if _, found := p.add("test1"); !found {
-		t.Fatalf("Expected pull test1 to be in progress`")
-	}
-	if err := p.remove("test2"); err != nil {
-		t.Fatal(err)
-	}
-	if err := p.remove("test2"); err != nil {
-		t.Fatal(err)
-	}
-	if err := p.remove("test1"); err != nil {
-		t.Fatal(err)
-	}
-}

+ 36 - 21
distribution/pull.go

@@ -2,7 +2,7 @@ package distribution
 
 import (
 	"fmt"
-	"io"
+	"os"
 	"strings"
 
 	"github.com/Sirupsen/logrus"
@@ -10,11 +10,12 @@ import (
 	"github.com/docker/docker/cliconfig"
 	"github.com/docker/docker/daemon/events"
 	"github.com/docker/docker/distribution/metadata"
+	"github.com/docker/docker/distribution/xfer"
 	"github.com/docker/docker/image"
-	"github.com/docker/docker/layer"
-	"github.com/docker/docker/pkg/streamformatter"
+	"github.com/docker/docker/pkg/progress"
 	"github.com/docker/docker/registry"
 	"github.com/docker/docker/tag"
+	"golang.org/x/net/context"
 )
 
 // ImagePullConfig stores pull configuration.
@@ -25,9 +26,9 @@ type ImagePullConfig struct {
 	// AuthConfig holds authentication credentials for authenticating with
 	// the registry.
 	AuthConfig *cliconfig.AuthConfig
-	// OutStream is the output writer for showing the status of the pull
+	// ProgressOutput is the interface for showing the status of the pull
 	// operation.
-	OutStream io.Writer
+	ProgressOutput progress.Output
 	// RegistryService is the registry service to use for TLS configuration
 	// and endpoint lookup.
 	RegistryService *registry.Service
@@ -36,14 +37,12 @@ type ImagePullConfig struct {
 	// MetadataStore is the storage backend for distribution-specific
 	// metadata.
 	MetadataStore metadata.Store
-	// LayerStore manages layers.
-	LayerStore layer.Store
 	// ImageStore manages images.
 	ImageStore image.Store
 	// TagStore manages tags.
 	TagStore tag.Store
-	// Pool manages concurrent pulls.
-	Pool *Pool
+	// DownloadManager manages concurrent pulls.
+	DownloadManager *xfer.LayerDownloadManager
 }
 
 // Puller is an interface that abstracts pulling for different API versions.
@@ -51,7 +50,7 @@ type Puller interface {
 	// Pull tries to pull the image referenced by `tag`
 	// Pull returns an error if any, as well as a boolean that determines whether to retry Pull on the next configured endpoint.
 	//
-	Pull(ref reference.Named) (fallback bool, err error)
+	Pull(ctx context.Context, ref reference.Named) (fallback bool, err error)
 }
 
 // newPuller returns a Puller interface that will pull from either a v1 or v2
@@ -59,14 +58,13 @@ type Puller interface {
 // whether a v1 or v2 puller will be created. The other parameters are passed
 // through to the underlying puller implementation for use during the actual
 // pull operation.
-func newPuller(endpoint registry.APIEndpoint, repoInfo *registry.RepositoryInfo, imagePullConfig *ImagePullConfig, sf *streamformatter.StreamFormatter) (Puller, error) {
+func newPuller(endpoint registry.APIEndpoint, repoInfo *registry.RepositoryInfo, imagePullConfig *ImagePullConfig) (Puller, error) {
 	switch endpoint.Version {
 	case registry.APIVersion2:
 		return &v2Puller{
 			blobSumService: metadata.NewBlobSumService(imagePullConfig.MetadataStore),
 			endpoint:       endpoint,
 			config:         imagePullConfig,
-			sf:             sf,
 			repoInfo:       repoInfo,
 		}, nil
 	case registry.APIVersion1:
@@ -74,7 +72,6 @@ func newPuller(endpoint registry.APIEndpoint, repoInfo *registry.RepositoryInfo,
 			v1IDService: metadata.NewV1IDService(imagePullConfig.MetadataStore),
 			endpoint:    endpoint,
 			config:      imagePullConfig,
-			sf:          sf,
 			repoInfo:    repoInfo,
 		}, nil
 	}
@@ -83,9 +80,7 @@ func newPuller(endpoint registry.APIEndpoint, repoInfo *registry.RepositoryInfo,
 
 // Pull initiates a pull operation. image is the repository name to pull, and
 // tag may be either empty, or indicate a specific tag to pull.
-func Pull(ref reference.Named, imagePullConfig *ImagePullConfig) error {
-	var sf = streamformatter.NewJSONStreamFormatter()
-
+func Pull(ctx context.Context, ref reference.Named, imagePullConfig *ImagePullConfig) error {
 	// Resolve the Repository name from fqn to RepositoryInfo
 	repoInfo, err := imagePullConfig.RegistryService.ResolveRepository(ref)
 	if err != nil {
@@ -120,12 +115,19 @@ func Pull(ref reference.Named, imagePullConfig *ImagePullConfig) error {
 	for _, endpoint := range endpoints {
 		logrus.Debugf("Trying to pull %s from %s %s", repoInfo.LocalName, endpoint.URL, endpoint.Version)
 
-		puller, err := newPuller(endpoint, repoInfo, imagePullConfig, sf)
+		puller, err := newPuller(endpoint, repoInfo, imagePullConfig)
 		if err != nil {
 			errors = append(errors, err.Error())
 			continue
 		}
-		if fallback, err := puller.Pull(ref); err != nil {
+		if fallback, err := puller.Pull(ctx, ref); err != nil {
+			// Was this pull cancelled? If so, don't try to fall
+			// back.
+			select {
+			case <-ctx.Done():
+				fallback = false
+			default:
+			}
 			if fallback {
 				if _, ok := err.(registry.ErrNoSupport); !ok {
 					// Because we found an error that's not ErrNoSupport, discard all subsequent ErrNoSupport errors.
@@ -165,11 +167,11 @@ func Pull(ref reference.Named, imagePullConfig *ImagePullConfig) error {
 // status message indicates that a newer image was downloaded. Otherwise, it
 // indicates that the image is up to date. requestedTag is the tag the message
 // will refer to.
-func writeStatus(requestedTag string, out io.Writer, sf *streamformatter.StreamFormatter, layersDownloaded bool) {
+func writeStatus(requestedTag string, out progress.Output, layersDownloaded bool) {
 	if layersDownloaded {
-		out.Write(sf.FormatStatus("", "Status: Downloaded newer image for %s", requestedTag))
+		progress.Message(out, "", "Status: Downloaded newer image for "+requestedTag)
 	} else {
-		out.Write(sf.FormatStatus("", "Status: Image is up to date for %s", requestedTag))
+		progress.Message(out, "", "Status: Image is up to date for "+requestedTag)
 	}
 }
 
@@ -183,3 +185,16 @@ func validateRepoName(name string) error {
 	}
 	return nil
 }
+
+// tmpFileClose creates a closer function for a temporary file that closes the file
+// and also deletes it.
+func tmpFileCloser(tmpFile *os.File) func() error {
+	return func() error {
+		tmpFile.Close()
+		if err := os.RemoveAll(tmpFile.Name()); err != nil {
+			logrus.Errorf("Failed to remove temp file: %s", tmpFile.Name())
+		}
+
+		return nil
+	}
+}

+ 117 - 228
distribution/pull_v1.go

@@ -1,43 +1,42 @@
 package distribution
 
 import (
-	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
+	"io/ioutil"
 	"net"
 	"net/url"
 	"strings"
-	"sync"
 	"time"
 
 	"github.com/Sirupsen/logrus"
 	"github.com/docker/distribution/reference"
 	"github.com/docker/distribution/registry/client/transport"
 	"github.com/docker/docker/distribution/metadata"
+	"github.com/docker/docker/distribution/xfer"
 	"github.com/docker/docker/image"
 	"github.com/docker/docker/image/v1"
 	"github.com/docker/docker/layer"
-	"github.com/docker/docker/pkg/archive"
-	"github.com/docker/docker/pkg/progressreader"
-	"github.com/docker/docker/pkg/streamformatter"
+	"github.com/docker/docker/pkg/ioutils"
+	"github.com/docker/docker/pkg/progress"
 	"github.com/docker/docker/pkg/stringid"
 	"github.com/docker/docker/registry"
+	"golang.org/x/net/context"
 )
 
 type v1Puller struct {
 	v1IDService *metadata.V1IDService
 	endpoint    registry.APIEndpoint
 	config      *ImagePullConfig
-	sf          *streamformatter.StreamFormatter
 	repoInfo    *registry.RepositoryInfo
 	session     *registry.Session
 }
 
-func (p *v1Puller) Pull(ref reference.Named) (fallback bool, err error) {
+func (p *v1Puller) Pull(ctx context.Context, ref reference.Named) (fallback bool, err error) {
 	if _, isDigested := ref.(reference.Digested); isDigested {
 		// Allowing fallback, because HTTPS v1 is before HTTP v2
-		return true, registry.ErrNoSupport{errors.New("Cannot pull by digest with v1 registry")}
+		return true, registry.ErrNoSupport{Err: errors.New("Cannot pull by digest with v1 registry")}
 	}
 
 	tlsConfig, err := p.config.RegistryService.TLSConfig(p.repoInfo.Index.Name)
@@ -62,19 +61,17 @@ func (p *v1Puller) Pull(ref reference.Named) (fallback bool, err error) {
 		logrus.Debugf("Fallback from error: %s", err)
 		return true, err
 	}
-	if err := p.pullRepository(ref); err != nil {
+	if err := p.pullRepository(ctx, ref); err != nil {
 		// TODO(dmcgowan): Check if should fallback
 		return false, err
 	}
-	out := p.config.OutStream
-	out.Write(p.sf.FormatStatus("", "%s: this image was pulled from a legacy registry.  Important: This registry version will not be supported in future versions of docker.", p.repoInfo.CanonicalName.Name()))
+	progress.Message(p.config.ProgressOutput, "", p.repoInfo.CanonicalName.Name()+": this image was pulled from a legacy registry.  Important: This registry version will not be supported in future versions of docker.")
 
 	return false, nil
 }
 
-func (p *v1Puller) pullRepository(ref reference.Named) error {
-	out := p.config.OutStream
-	out.Write(p.sf.FormatStatus("", "Pulling repository %s", p.repoInfo.CanonicalName.Name()))
+func (p *v1Puller) pullRepository(ctx context.Context, ref reference.Named) error {
+	progress.Message(p.config.ProgressOutput, "", "Pulling repository "+p.repoInfo.CanonicalName.Name())
 
 	repoData, err := p.session.GetRepositoryData(p.repoInfo.RemoteName)
 	if err != nil {
@@ -112,46 +109,18 @@ func (p *v1Puller) pullRepository(ref reference.Named) error {
 		}
 	}
 
-	errors := make(chan error)
-	layerDownloaded := make(chan struct{})
-
 	layersDownloaded := false
-	var wg sync.WaitGroup
 	for _, imgData := range repoData.ImgList {
 		if isTagged && imgData.Tag != tagged.Tag() {
 			continue
 		}
 
-		wg.Add(1)
-		go func(img *registry.ImgData) {
-			p.downloadImage(out, repoData, img, layerDownloaded, errors)
-			wg.Done()
-		}(imgData)
-	}
-
-	go func() {
-		wg.Wait()
-		close(errors)
-	}()
-
-	var lastError error
-selectLoop:
-	for {
-		select {
-		case err, ok := <-errors:
-			if !ok {
-				break selectLoop
-			}
-			lastError = err
-		case <-layerDownloaded:
-			layersDownloaded = true
+		err := p.downloadImage(ctx, repoData, imgData, &layersDownloaded)
+		if err != nil {
+			return err
 		}
 	}
 
-	if lastError != nil {
-		return lastError
-	}
-
 	localNameRef := p.repoInfo.LocalName
 	if isTagged {
 		localNameRef, err = reference.WithTag(localNameRef, tagged.Tag())
@@ -159,194 +128,143 @@ selectLoop:
 			localNameRef = p.repoInfo.LocalName
 		}
 	}
-	writeStatus(localNameRef.String(), out, p.sf, layersDownloaded)
+	writeStatus(localNameRef.String(), p.config.ProgressOutput, layersDownloaded)
 	return nil
 }
 
-func (p *v1Puller) downloadImage(out io.Writer, repoData *registry.RepositoryData, img *registry.ImgData, layerDownloaded chan struct{}, errors chan error) {
+func (p *v1Puller) downloadImage(ctx context.Context, repoData *registry.RepositoryData, img *registry.ImgData, layersDownloaded *bool) error {
 	if img.Tag == "" {
 		logrus.Debugf("Image (id: %s) present in this repository but untagged, skipping", img.ID)
-		return
+		return nil
 	}
 
 	localNameRef, err := reference.WithTag(p.repoInfo.LocalName, img.Tag)
 	if err != nil {
 		retErr := fmt.Errorf("Image (id: %s) has invalid tag: %s", img.ID, img.Tag)
 		logrus.Debug(retErr.Error())
-		errors <- retErr
+		return retErr
 	}
 
 	if err := v1.ValidateID(img.ID); err != nil {
-		errors <- err
-		return
+		return err
 	}
 
-	out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s", img.Tag, p.repoInfo.CanonicalName.Name()), nil))
+	progress.Updatef(p.config.ProgressOutput, stringid.TruncateID(img.ID), "Pulling image (%s) from %s", img.Tag, p.repoInfo.CanonicalName.Name())
 	success := false
 	var lastErr error
-	var isDownloaded bool
 	for _, ep := range p.repoInfo.Index.Mirrors {
 		ep += "v1/"
-		out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s, mirror: %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep), nil))
-		if isDownloaded, err = p.pullImage(out, img.ID, ep, localNameRef); err != nil {
+		progress.Updatef(p.config.ProgressOutput, stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s, mirror: %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep))
+		if err = p.pullImage(ctx, img.ID, ep, localNameRef, layersDownloaded); err != nil {
 			// Don't report errors when pulling from mirrors.
 			logrus.Debugf("Error pulling image (%s) from %s, mirror: %s, %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep, err)
 			continue
 		}
-		if isDownloaded {
-			layerDownloaded <- struct{}{}
-		}
 		success = true
 		break
 	}
 	if !success {
 		for _, ep := range repoData.Endpoints {
-			out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s, endpoint: %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep), nil))
-			if isDownloaded, err = p.pullImage(out, img.ID, ep, localNameRef); err != nil {
+			progress.Updatef(p.config.ProgressOutput, stringid.TruncateID(img.ID), "Pulling image (%s) from %s, endpoint: %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep)
+			if err = p.pullImage(ctx, img.ID, ep, localNameRef, layersDownloaded); err != nil {
 				// It's not ideal that only the last error is returned, it would be better to concatenate the errors.
 				// As the error is also given to the output stream the user will see the error.
 				lastErr = err
-				out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Error pulling image (%s) from %s, endpoint: %s, %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep, err), nil))
+				progress.Updatef(p.config.ProgressOutput, stringid.TruncateID(img.ID), "Error pulling image (%s) from %s, endpoint: %s, %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep, err)
 				continue
 			}
-			if isDownloaded {
-				layerDownloaded <- struct{}{}
-			}
 			success = true
 			break
 		}
 	}
 	if !success {
 		err := fmt.Errorf("Error pulling image (%s) from %s, %v", img.Tag, p.repoInfo.CanonicalName.Name(), lastErr)
-		out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), err.Error(), nil))
-		errors <- err
-		return
+		progress.Update(p.config.ProgressOutput, stringid.TruncateID(img.ID), err.Error())
+		return err
 	}
-	out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), "Download complete", nil))
+	progress.Update(p.config.ProgressOutput, stringid.TruncateID(img.ID), "Download complete")
+	return nil
 }
 
-func (p *v1Puller) pullImage(out io.Writer, v1ID, endpoint string, localNameRef reference.Named) (layersDownloaded bool, err error) {
+func (p *v1Puller) pullImage(ctx context.Context, v1ID, endpoint string, localNameRef reference.Named, layersDownloaded *bool) (err error) {
 	var history []string
 	history, err = p.session.GetRemoteHistory(v1ID, endpoint)
 	if err != nil {
-		return false, err
+		return err
 	}
 	if len(history) < 1 {
-		return false, fmt.Errorf("empty history for image %s", v1ID)
+		return fmt.Errorf("empty history for image %s", v1ID)
 	}
-	out.Write(p.sf.FormatProgress(stringid.TruncateID(v1ID), "Pulling dependent layers", nil))
-	// FIXME: Try to stream the images?
-	// FIXME: Launch the getRemoteImage() in goroutines
+	progress.Update(p.config.ProgressOutput, stringid.TruncateID(v1ID), "Pulling dependent layers")
 
 	var (
-		referencedLayers []layer.Layer
-		parentID         layer.ChainID
-		newHistory       []image.History
-		img              *image.V1Image
-		imgJSON          []byte
-		imgSize          int64
+		descriptors []xfer.DownloadDescriptor
+		newHistory  []image.History
+		imgJSON     []byte
+		imgSize     int64
 	)
 
-	defer func() {
-		for _, l := range referencedLayers {
-			layer.ReleaseAndLog(p.config.LayerStore, l)
-		}
-	}()
-
-	layersDownloaded = false
-
-	// Iterate over layers from top-most to bottom-most, checking if any
-	// already exist on disk.
-	var i int
-	for i = 0; i != len(history); i++ {
-		v1LayerID := history[i]
-		// Do we have a mapping for this particular v1 ID on this
-		// registry?
-		if layerID, err := p.v1IDService.Get(v1LayerID, p.repoInfo.Index.Name); err == nil {
-			// Does the layer actually exist
-			if l, err := p.config.LayerStore.Get(layerID); err == nil {
-				for j := i; j >= 0; j-- {
-					logrus.Debugf("Layer already exists: %s", history[j])
-					out.Write(p.sf.FormatProgress(stringid.TruncateID(history[j]), "Already exists", nil))
-				}
-				referencedLayers = append(referencedLayers, l)
-				parentID = layerID
-				break
-			}
-		}
-	}
-
-	needsDownload := i
-
 	// Iterate over layers, in order from bottom-most to top-most. Download
-	// config for all layers, and download actual layer data if needed.
-	for i = len(history) - 1; i >= 0; i-- {
+	// config for all layers and create descriptors.
+	for i := len(history) - 1; i >= 0; i-- {
 		v1LayerID := history[i]
-		imgJSON, imgSize, err = p.downloadLayerConfig(out, v1LayerID, endpoint)
+		imgJSON, imgSize, err = p.downloadLayerConfig(v1LayerID, endpoint)
 		if err != nil {
-			return layersDownloaded, err
-		}
-
-		img = &image.V1Image{}
-		if err := json.Unmarshal(imgJSON, img); err != nil {
-			return layersDownloaded, err
-		}
-
-		if i < needsDownload {
-			l, err := p.downloadLayer(out, v1LayerID, endpoint, parentID, imgSize, &layersDownloaded)
-
-			// Note: This needs to be done even in the error case to avoid
-			// stale references to the layer.
-			if l != nil {
-				referencedLayers = append(referencedLayers, l)
-			}
-			if err != nil {
-				return layersDownloaded, err
-			}
-
-			parentID = l.ChainID()
+			return err
 		}
 
 		// Create a new-style config from the legacy configs
 		h, err := v1.HistoryFromConfig(imgJSON, false)
 		if err != nil {
-			return layersDownloaded, err
+			return err
 		}
 		newHistory = append(newHistory, h)
+
+		layerDescriptor := &v1LayerDescriptor{
+			v1LayerID:        v1LayerID,
+			indexName:        p.repoInfo.Index.Name,
+			endpoint:         endpoint,
+			v1IDService:      p.v1IDService,
+			layersDownloaded: layersDownloaded,
+			layerSize:        imgSize,
+			session:          p.session,
+		}
+
+		descriptors = append(descriptors, layerDescriptor)
 	}
 
 	rootFS := image.NewRootFS()
-	l := referencedLayers[len(referencedLayers)-1]
-	for l != nil {
-		rootFS.DiffIDs = append([]layer.DiffID{l.DiffID()}, rootFS.DiffIDs...)
-		l = l.Parent()
+	resultRootFS, release, err := p.config.DownloadManager.Download(ctx, *rootFS, descriptors, p.config.ProgressOutput)
+	if err != nil {
+		return err
 	}
+	defer release()
 
-	config, err := v1.MakeConfigFromV1Config(imgJSON, rootFS, newHistory)
+	config, err := v1.MakeConfigFromV1Config(imgJSON, &resultRootFS, newHistory)
 	if err != nil {
-		return layersDownloaded, err
+		return err
 	}
 
 	imageID, err := p.config.ImageStore.Create(config)
 	if err != nil {
-		return layersDownloaded, err
+		return err
 	}
 
 	if err := p.config.TagStore.AddTag(localNameRef, imageID, true); err != nil {
-		return layersDownloaded, err
+		return err
 	}
 
-	return layersDownloaded, nil
+	return nil
 }
 
-func (p *v1Puller) downloadLayerConfig(out io.Writer, v1LayerID, endpoint string) (imgJSON []byte, imgSize int64, err error) {
-	out.Write(p.sf.FormatProgress(stringid.TruncateID(v1LayerID), "Pulling metadata", nil))
+func (p *v1Puller) downloadLayerConfig(v1LayerID, endpoint string) (imgJSON []byte, imgSize int64, err error) {
+	progress.Update(p.config.ProgressOutput, stringid.TruncateID(v1LayerID), "Pulling metadata")
 
 	retries := 5
 	for j := 1; j <= retries; j++ {
 		imgJSON, imgSize, err := p.session.GetRemoteImageJSON(v1LayerID, endpoint)
 		if err != nil && j == retries {
-			out.Write(p.sf.FormatProgress(stringid.TruncateID(v1LayerID), "Error pulling layer metadata", nil))
+			progress.Update(p.config.ProgressOutput, stringid.TruncateID(v1LayerID), "Error pulling layer metadata")
 			return nil, 0, err
 		} else if err != nil {
 			time.Sleep(time.Duration(j) * 500 * time.Millisecond)
@@ -360,95 +278,66 @@ func (p *v1Puller) downloadLayerConfig(out io.Writer, v1LayerID, endpoint string
 	return nil, 0, nil
 }
 
-func (p *v1Puller) downloadLayer(out io.Writer, v1LayerID, endpoint string, parentID layer.ChainID, layerSize int64, layersDownloaded *bool) (l layer.Layer, err error) {
-	// ensure no two downloads of the same layer happen at the same time
-	poolKey := "layer:" + v1LayerID
-	broadcaster, found := p.config.Pool.add(poolKey)
-	broadcaster.Add(out)
-	if found {
-		logrus.Debugf("Image (id: %s) pull is already running, skipping", v1LayerID)
-		if err = broadcaster.Wait(); err != nil {
-			return nil, err
-		}
-		layerID, err := p.v1IDService.Get(v1LayerID, p.repoInfo.Index.Name)
-		if err != nil {
-			return nil, err
-		}
-		// Does the layer actually exist
-		l, err := p.config.LayerStore.Get(layerID)
-		if err != nil {
-			return nil, err
-		}
-		return l, nil
-	}
+type v1LayerDescriptor struct {
+	v1LayerID        string
+	indexName        string
+	endpoint         string
+	v1IDService      *metadata.V1IDService
+	layersDownloaded *bool
+	layerSize        int64
+	session          *registry.Session
+}
 
-	// This must use a closure so it captures the value of err when
-	// the function returns, not when the 'defer' is evaluated.
-	defer func() {
-		p.config.Pool.removeWithError(poolKey, err)
-	}()
+func (ld *v1LayerDescriptor) Key() string {
+	return "v1:" + ld.v1LayerID
+}
 
-	retries := 5
-	for j := 1; j <= retries; j++ {
-		// Get the layer
-		status := "Pulling fs layer"
-		if j > 1 {
-			status = fmt.Sprintf("Pulling fs layer [retries: %d]", j)
-		}
-		broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(v1LayerID), status, nil))
-		layerReader, err := p.session.GetRemoteImageLayer(v1LayerID, endpoint, layerSize)
+func (ld *v1LayerDescriptor) ID() string {
+	return stringid.TruncateID(ld.v1LayerID)
+}
+
+func (ld *v1LayerDescriptor) DiffID() (layer.DiffID, error) {
+	return ld.v1IDService.Get(ld.v1LayerID, ld.indexName)
+}
+
+func (ld *v1LayerDescriptor) Download(ctx context.Context, progressOutput progress.Output) (io.ReadCloser, int64, error) {
+	progress.Update(progressOutput, ld.ID(), "Pulling fs layer")
+	layerReader, err := ld.session.GetRemoteImageLayer(ld.v1LayerID, ld.endpoint, ld.layerSize)
+	if err != nil {
+		progress.Update(progressOutput, ld.ID(), "Error pulling dependent layers")
 		if uerr, ok := err.(*url.Error); ok {
 			err = uerr.Err
 		}
-		if terr, ok := err.(net.Error); ok && terr.Timeout() && j < retries {
-			time.Sleep(time.Duration(j) * 500 * time.Millisecond)
-			continue
-		} else if err != nil {
-			broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(v1LayerID), "Error pulling dependent layers", nil))
-			return nil, err
-		}
-		*layersDownloaded = true
-		defer layerReader.Close()
-
-		reader := progressreader.New(progressreader.Config{
-			In:        layerReader,
-			Out:       broadcaster,
-			Formatter: p.sf,
-			Size:      layerSize,
-			NewLines:  false,
-			ID:        stringid.TruncateID(v1LayerID),
-			Action:    "Downloading",
-		})
-
-		inflatedLayerData, err := archive.DecompressStream(reader)
-		if err != nil {
-			return nil, fmt.Errorf("could not get decompression stream: %v", err)
-		}
-
-		l, err := p.config.LayerStore.Register(inflatedLayerData, parentID)
-		if err != nil {
-			return nil, fmt.Errorf("failed to register layer: %v", err)
+		if terr, ok := err.(net.Error); ok && terr.Timeout() {
+			return nil, 0, err
 		}
-		logrus.Debugf("layer %s registered successfully", l.DiffID())
+		return nil, 0, xfer.DoNotRetry{Err: err}
+	}
+	*ld.layersDownloaded = true
 
-		if terr, ok := err.(net.Error); ok && terr.Timeout() && j < retries {
-			time.Sleep(time.Duration(j) * 500 * time.Millisecond)
-			continue
-		} else if err != nil {
-			broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(v1LayerID), "Error downloading dependent layers", nil))
-			return nil, err
-		}
+	tmpFile, err := ioutil.TempFile("", "GetImageBlob")
+	if err != nil {
+		layerReader.Close()
+		return nil, 0, err
+	}
 
-		// Cache mapping from this v1 ID to content-addressable layer ID
-		if err := p.v1IDService.Set(v1LayerID, p.repoInfo.Index.Name, l.ChainID()); err != nil {
-			return nil, err
-		}
+	reader := progress.NewProgressReader(ioutils.NewCancelReadCloser(ctx, layerReader), progressOutput, ld.layerSize, ld.ID(), "Downloading")
+	defer reader.Close()
 
-		broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(v1LayerID), "Download complete", nil))
-		broadcaster.Close()
-		return l, nil
+	_, err = io.Copy(tmpFile, reader)
+	if err != nil {
+		return nil, 0, err
 	}
 
-	// not reached
-	return nil, nil
+	progress.Update(progressOutput, ld.ID(), "Download complete")
+
+	logrus.Debugf("Downloaded %s to tempfile %s", ld.ID(), tmpFile.Name())
+
+	tmpFile.Seek(0, 0)
+	return ioutils.NewReadCloserWrapper(tmpFile, tmpFileCloser(tmpFile)), ld.layerSize, nil
+}
+
+func (ld *v1LayerDescriptor) Registered(diffID layer.DiffID) {
+	// Cache mapping from this layer's DiffID to the blobsum
+	ld.v1IDService.Set(ld.v1LayerID, ld.indexName, diffID)
 }

+ 92 - 205
distribution/pull_v2.go

@@ -15,13 +15,12 @@ import (
 	"github.com/docker/distribution/manifest/schema1"
 	"github.com/docker/distribution/reference"
 	"github.com/docker/docker/distribution/metadata"
+	"github.com/docker/docker/distribution/xfer"
 	"github.com/docker/docker/image"
 	"github.com/docker/docker/image/v1"
 	"github.com/docker/docker/layer"
-	"github.com/docker/docker/pkg/archive"
-	"github.com/docker/docker/pkg/broadcaster"
-	"github.com/docker/docker/pkg/progressreader"
-	"github.com/docker/docker/pkg/streamformatter"
+	"github.com/docker/docker/pkg/ioutils"
+	"github.com/docker/docker/pkg/progress"
 	"github.com/docker/docker/pkg/stringid"
 	"github.com/docker/docker/registry"
 	"golang.org/x/net/context"
@@ -31,23 +30,19 @@ type v2Puller struct {
 	blobSumService *metadata.BlobSumService
 	endpoint       registry.APIEndpoint
 	config         *ImagePullConfig
-	sf             *streamformatter.StreamFormatter
 	repoInfo       *registry.RepositoryInfo
 	repo           distribution.Repository
-	sessionID      string
 }
 
-func (p *v2Puller) Pull(ref reference.Named) (fallback bool, err error) {
+func (p *v2Puller) Pull(ctx context.Context, ref reference.Named) (fallback bool, err error) {
 	// TODO(tiborvass): was ReceiveTimeout
 	p.repo, err = NewV2Repository(p.repoInfo, p.endpoint, p.config.MetaHeaders, p.config.AuthConfig, "pull")
 	if err != nil {
-		logrus.Debugf("Error getting v2 registry: %v", err)
+		logrus.Warnf("Error getting v2 registry: %v", err)
 		return true, err
 	}
 
-	p.sessionID = stringid.GenerateRandomID()
-
-	if err := p.pullV2Repository(ref); err != nil {
+	if err := p.pullV2Repository(ctx, ref); err != nil {
 		if registry.ContinueOnError(err) {
 			logrus.Debugf("Error trying v2 registry: %v", err)
 			return true, err
@@ -57,7 +52,7 @@ func (p *v2Puller) Pull(ref reference.Named) (fallback bool, err error) {
 	return false, nil
 }
 
-func (p *v2Puller) pullV2Repository(ref reference.Named) (err error) {
+func (p *v2Puller) pullV2Repository(ctx context.Context, ref reference.Named) (err error) {
 	var refs []reference.Named
 	taggedName := p.repoInfo.LocalName
 	if tagged, isTagged := ref.(reference.Tagged); isTagged {
@@ -73,7 +68,7 @@ func (p *v2Puller) pullV2Repository(ref reference.Named) (err error) {
 		}
 		refs = []reference.Named{taggedName}
 	} else {
-		manSvc, err := p.repo.Manifests(context.Background())
+		manSvc, err := p.repo.Manifests(ctx)
 		if err != nil {
 			return err
 		}
@@ -98,98 +93,109 @@ func (p *v2Puller) pullV2Repository(ref reference.Named) (err error) {
 	for _, pullRef := range refs {
 		// pulledNew is true if either new layers were downloaded OR if existing images were newly tagged
 		// TODO(tiborvass): should we change the name of `layersDownload`? What about message in WriteStatus?
-		pulledNew, err := p.pullV2Tag(p.config.OutStream, pullRef)
+		pulledNew, err := p.pullV2Tag(ctx, pullRef)
 		if err != nil {
 			return err
 		}
 		layersDownloaded = layersDownloaded || pulledNew
 	}
 
-	writeStatus(taggedName.String(), p.config.OutStream, p.sf, layersDownloaded)
+	writeStatus(taggedName.String(), p.config.ProgressOutput, layersDownloaded)
 
 	return nil
 }
 
-// downloadInfo is used to pass information from download to extractor
-type downloadInfo struct {
-	tmpFile     *os.File
-	digest      digest.Digest
-	layer       distribution.ReadSeekCloser
-	size        int64
-	err         chan error
-	poolKey     string
-	broadcaster *broadcaster.Buffered
+type v2LayerDescriptor struct {
+	digest         digest.Digest
+	repo           distribution.Repository
+	blobSumService *metadata.BlobSumService
 }
 
-type errVerification struct{}
+func (ld *v2LayerDescriptor) Key() string {
+	return "v2:" + ld.digest.String()
+}
 
-func (errVerification) Error() string { return "verification failed" }
+func (ld *v2LayerDescriptor) ID() string {
+	return stringid.TruncateID(ld.digest.String())
+}
 
-func (p *v2Puller) download(di *downloadInfo) {
-	logrus.Debugf("pulling blob %q", di.digest)
+func (ld *v2LayerDescriptor) DiffID() (layer.DiffID, error) {
+	return ld.blobSumService.GetDiffID(ld.digest)
+}
+
+func (ld *v2LayerDescriptor) Download(ctx context.Context, progressOutput progress.Output) (io.ReadCloser, int64, error) {
+	logrus.Debugf("pulling blob %q", ld.digest)
 
-	blobs := p.repo.Blobs(context.Background())
+	blobs := ld.repo.Blobs(ctx)
 
-	layerDownload, err := blobs.Open(context.Background(), di.digest)
+	layerDownload, err := blobs.Open(ctx, ld.digest)
 	if err != nil {
-		logrus.Debugf("Error fetching layer: %v", err)
-		di.err <- err
-		return
+		logrus.Debugf("Error statting layer: %v", err)
+		if err == distribution.ErrBlobUnknown {
+			return nil, 0, xfer.DoNotRetry{Err: err}
+		}
+		return nil, 0, retryOnError(err)
 	}
-	defer layerDownload.Close()
 
-	di.size, err = layerDownload.Seek(0, os.SEEK_END)
+	size, err := layerDownload.Seek(0, os.SEEK_END)
 	if err != nil {
 		// Seek failed, perhaps because there was no Content-Length
 		// header. This shouldn't fail the download, because we can
 		// still continue without a progress bar.
-		di.size = 0
+		size = 0
 	} else {
 		// Restore the seek offset at the beginning of the stream.
 		_, err = layerDownload.Seek(0, os.SEEK_SET)
 		if err != nil {
-			di.err <- err
-			return
+			return nil, 0, err
 		}
 	}
 
-	verifier, err := digest.NewDigestVerifier(di.digest)
+	reader := progress.NewProgressReader(ioutils.NewCancelReadCloser(ctx, layerDownload), progressOutput, size, ld.ID(), "Downloading")
+	defer reader.Close()
+
+	verifier, err := digest.NewDigestVerifier(ld.digest)
 	if err != nil {
-		di.err <- err
-		return
+		return nil, 0, xfer.DoNotRetry{Err: err}
 	}
 
-	digestStr := di.digest.String()
+	tmpFile, err := ioutil.TempFile("", "GetImageBlob")
+	if err != nil {
+		return nil, 0, xfer.DoNotRetry{Err: err}
+	}
 
-	reader := progressreader.New(progressreader.Config{
-		In:        ioutil.NopCloser(io.TeeReader(layerDownload, verifier)),
-		Out:       di.broadcaster,
-		Formatter: p.sf,
-		Size:      di.size,
-		NewLines:  false,
-		ID:        stringid.TruncateID(digestStr),
-		Action:    "Downloading",
-	})
-	io.Copy(di.tmpFile, reader)
+	_, err = io.Copy(tmpFile, io.TeeReader(reader, verifier))
+	if err != nil {
+		return nil, 0, retryOnError(err)
+	}
 
-	di.broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(digestStr), "Verifying Checksum", nil))
+	progress.Update(progressOutput, ld.ID(), "Verifying Checksum")
 
 	if !verifier.Verified() {
-		err = fmt.Errorf("filesystem layer verification failed for digest %s", di.digest)
+		err = fmt.Errorf("filesystem layer verification failed for digest %s", ld.digest)
 		logrus.Error(err)
-		di.err <- err
-		return
+		tmpFile.Close()
+		if err := os.RemoveAll(tmpFile.Name()); err != nil {
+			logrus.Errorf("Failed to remove temp file: %s", tmpFile.Name())
+		}
+
+		return nil, 0, xfer.DoNotRetry{Err: err}
 	}
 
-	di.broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(digestStr), "Download complete", nil))
+	progress.Update(progressOutput, ld.ID(), "Download complete")
 
-	logrus.Debugf("Downloaded %s to tempfile %s", digestStr, di.tmpFile.Name())
-	di.layer = layerDownload
+	logrus.Debugf("Downloaded %s to tempfile %s", ld.ID(), tmpFile.Name())
+
+	tmpFile.Seek(0, 0)
+	return ioutils.NewReadCloserWrapper(tmpFile, tmpFileCloser(tmpFile)), size, nil
+}
 
-	di.err <- nil
+func (ld *v2LayerDescriptor) Registered(diffID layer.DiffID) {
+	// Cache mapping from this layer's DiffID to the blobsum
+	ld.blobSumService.Add(diffID, ld.digest)
 }
 
-func (p *v2Puller) pullV2Tag(out io.Writer, ref reference.Named) (tagUpdated bool, err error) {
+func (p *v2Puller) pullV2Tag(ctx context.Context, ref reference.Named) (tagUpdated bool, err error) {
 	tagOrDigest := ""
 	if tagged, isTagged := ref.(reference.Tagged); isTagged {
 		tagOrDigest = tagged.Tag()
@@ -201,7 +207,7 @@ func (p *v2Puller) pullV2Tag(out io.Writer, ref reference.Named) (tagUpdated boo
 
 	logrus.Debugf("Pulling ref from V2 registry: %q", tagOrDigest)
 
-	manSvc, err := p.repo.Manifests(context.Background())
+	manSvc, err := p.repo.Manifests(ctx)
 	if err != nil {
 		return false, err
 	}
@@ -231,33 +237,17 @@ func (p *v2Puller) pullV2Tag(out io.Writer, ref reference.Named) (tagUpdated boo
 		return false, err
 	}
 
-	out.Write(p.sf.FormatStatus(tagOrDigest, "Pulling from %s", p.repo.Name()))
+	progress.Message(p.config.ProgressOutput, tagOrDigest, "Pulling from "+p.repo.Name())
 
-	var downloads []*downloadInfo
-
-	defer func() {
-		for _, d := range downloads {
-			p.config.Pool.removeWithError(d.poolKey, err)
-			if d.tmpFile != nil {
-				d.tmpFile.Close()
-				if err := os.RemoveAll(d.tmpFile.Name()); err != nil {
-					logrus.Errorf("Failed to remove temp file: %s", d.tmpFile.Name())
-				}
-			}
-		}
-	}()
+	var descriptors []xfer.DownloadDescriptor
 
 	// Image history converted to the new format
 	var history []image.History
 
-	poolKey := "v2layer:"
-	notFoundLocally := false
-
 	// Note that the order of this loop is in the direction of bottom-most
 	// to top-most, so that the downloads slice gets ordered correctly.
 	for i := len(verifiedManifest.FSLayers) - 1; i >= 0; i-- {
 		blobSum := verifiedManifest.FSLayers[i].BlobSum
-		poolKey += blobSum.String()
 
 		var throwAway struct {
 			ThrowAway bool `json:"throwaway,omitempty"`
@@ -276,119 +266,22 @@ func (p *v2Puller) pullV2Tag(out io.Writer, ref reference.Named) (tagUpdated boo
 			continue
 		}
 
-		// Do we have a layer on disk corresponding to the set of
-		// blobsums up to this point?
-		if !notFoundLocally {
-			notFoundLocally = true
-			diffID, err := p.blobSumService.GetDiffID(blobSum)
-			if err == nil {
-				rootFS.Append(diffID)
-				if l, err := p.config.LayerStore.Get(rootFS.ChainID()); err == nil {
-					notFoundLocally = false
-					logrus.Debugf("Layer already exists: %s", blobSum.String())
-					out.Write(p.sf.FormatProgress(stringid.TruncateID(blobSum.String()), "Already exists", nil))
-					defer layer.ReleaseAndLog(p.config.LayerStore, l)
-					continue
-				} else {
-					rootFS.DiffIDs = rootFS.DiffIDs[:len(rootFS.DiffIDs)-1]
-				}
-			}
+		layerDescriptor := &v2LayerDescriptor{
+			digest:         blobSum,
+			repo:           p.repo,
+			blobSumService: p.blobSumService,
 		}
 
-		out.Write(p.sf.FormatProgress(stringid.TruncateID(blobSum.String()), "Pulling fs layer", nil))
-
-		tmpFile, err := ioutil.TempFile("", "GetImageBlob")
-		if err != nil {
-			return false, err
-		}
-
-		d := &downloadInfo{
-			poolKey: poolKey,
-			digest:  blobSum,
-			tmpFile: tmpFile,
-			// TODO: seems like this chan buffer solved hanging problem in go1.5,
-			// this can indicate some deeper problem that somehow we never take
-			// error from channel in loop below
-			err: make(chan error, 1),
-		}
-
-		downloads = append(downloads, d)
-
-		broadcaster, found := p.config.Pool.add(d.poolKey)
-		broadcaster.Add(out)
-		d.broadcaster = broadcaster
-		if found {
-			d.err <- nil
-		} else {
-			go p.download(d)
-		}
+		descriptors = append(descriptors, layerDescriptor)
 	}
 
-	for _, d := range downloads {
-		if err := <-d.err; err != nil {
-			return false, err
-		}
-
-		if d.layer == nil {
-			// Wait for a different pull to download and extract
-			// this layer.
-			err = d.broadcaster.Wait()
-			if err != nil {
-				return false, err
-			}
-
-			diffID, err := p.blobSumService.GetDiffID(d.digest)
-			if err != nil {
-				return false, err
-			}
-			rootFS.Append(diffID)
-
-			l, err := p.config.LayerStore.Get(rootFS.ChainID())
-			if err != nil {
-				return false, err
-			}
-
-			defer layer.ReleaseAndLog(p.config.LayerStore, l)
-
-			continue
-		}
-
-		d.tmpFile.Seek(0, 0)
-		reader := progressreader.New(progressreader.Config{
-			In:        d.tmpFile,
-			Out:       d.broadcaster,
-			Formatter: p.sf,
-			Size:      d.size,
-			NewLines:  false,
-			ID:        stringid.TruncateID(d.digest.String()),
-			Action:    "Extracting",
-		})
-
-		inflatedLayerData, err := archive.DecompressStream(reader)
-		if err != nil {
-			return false, fmt.Errorf("could not get decompression stream: %v", err)
-		}
-
-		l, err := p.config.LayerStore.Register(inflatedLayerData, rootFS.ChainID())
-		if err != nil {
-			return false, fmt.Errorf("failed to register layer: %v", err)
-		}
-		logrus.Debugf("layer %s registered successfully", l.DiffID())
-		rootFS.Append(l.DiffID())
-
-		// Cache mapping from this layer's DiffID to the blobsum
-		if err := p.blobSumService.Add(l.DiffID(), d.digest); err != nil {
-			return false, err
-		}
-
-		defer layer.ReleaseAndLog(p.config.LayerStore, l)
-
-		d.broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(d.digest.String()), "Pull complete", nil))
-		d.broadcaster.Close()
-		tagUpdated = true
+	resultRootFS, release, err := p.config.DownloadManager.Download(ctx, *rootFS, descriptors, p.config.ProgressOutput)
+	if err != nil {
+		return false, err
 	}
+	defer release()
 
-	config, err := v1.MakeConfigFromV1Config([]byte(verifiedManifest.History[0].V1Compatibility), rootFS, history)
+	config, err := v1.MakeConfigFromV1Config([]byte(verifiedManifest.History[0].V1Compatibility), &resultRootFS, history)
 	if err != nil {
 		return false, err
 	}
@@ -403,30 +296,24 @@ func (p *v2Puller) pullV2Tag(out io.Writer, ref reference.Named) (tagUpdated boo
 		return false, err
 	}
 
-	// Check for new tag if no layers downloaded
-	var oldTagImageID image.ID
-	if !tagUpdated {
-		oldTagImageID, err = p.config.TagStore.Get(ref)
-		if err != nil || oldTagImageID != imageID {
-			tagUpdated = true
-		}
+	if manifestDigest != "" {
+		progress.Message(p.config.ProgressOutput, "", "Digest: "+manifestDigest.String())
 	}
 
-	if tagUpdated {
-		if canonical, ok := ref.(reference.Canonical); ok {
-			if err = p.config.TagStore.AddDigest(canonical, imageID, true); err != nil {
-				return false, err
-			}
-		} else if err = p.config.TagStore.AddTag(ref, imageID, true); err != nil {
-			return false, err
-		}
+	oldTagImageID, err := p.config.TagStore.Get(ref)
+	if err == nil && oldTagImageID == imageID {
+		return false, nil
 	}
 
-	if manifestDigest != "" {
-		out.Write(p.sf.FormatStatus("", "Digest: %s", manifestDigest))
+	if canonical, ok := ref.(reference.Canonical); ok {
+		if err = p.config.TagStore.AddDigest(canonical, imageID, true); err != nil {
+			return false, err
+		}
+	} else if err = p.config.TagStore.AddTag(ref, imageID, true); err != nil {
+		return false, err
 	}
 
-	return tagUpdated, nil
+	return true, nil
 }
 
 func verifyManifest(signedManifest *schema1.SignedManifest, ref reference.Reference) (m *schema1.Manifest, err error) {

+ 22 - 14
distribution/push.go

@@ -12,12 +12,14 @@ import (
 	"github.com/docker/docker/cliconfig"
 	"github.com/docker/docker/daemon/events"
 	"github.com/docker/docker/distribution/metadata"
+	"github.com/docker/docker/distribution/xfer"
 	"github.com/docker/docker/image"
 	"github.com/docker/docker/layer"
-	"github.com/docker/docker/pkg/streamformatter"
+	"github.com/docker/docker/pkg/progress"
 	"github.com/docker/docker/registry"
 	"github.com/docker/docker/tag"
 	"github.com/docker/libtrust"
+	"golang.org/x/net/context"
 )
 
 // ImagePushConfig stores push configuration.
@@ -28,9 +30,9 @@ type ImagePushConfig struct {
 	// AuthConfig holds authentication credentials for authenticating with
 	// the registry.
 	AuthConfig *cliconfig.AuthConfig
-	// OutStream is the output writer for showing the status of the push
+	// ProgressOutput is the interface for showing the status of the push
 	// operation.
-	OutStream io.Writer
+	ProgressOutput progress.Output
 	// RegistryService is the registry service to use for TLS configuration
 	// and endpoint lookup.
 	RegistryService *registry.Service
@@ -48,6 +50,8 @@ type ImagePushConfig struct {
 	// TrustKey is the private key for legacy signatures. This is typically
 	// an ephemeral key, since these signatures are no longer verified.
 	TrustKey libtrust.PrivateKey
+	// UploadManager dispatches uploads.
+	UploadManager *xfer.LayerUploadManager
 }
 
 // Pusher is an interface that abstracts pushing for different API versions.
@@ -56,7 +60,7 @@ type Pusher interface {
 	// Push returns an error if any, as well as a boolean that determines whether to retry Push on the next configured endpoint.
 	//
 	// TODO(tiborvass): have Push() take a reference to repository + tag, so that the pusher itself is repository-agnostic.
-	Push() (fallback bool, err error)
+	Push(ctx context.Context) (fallback bool, err error)
 }
 
 const compressionBufSize = 32768
@@ -66,7 +70,7 @@ const compressionBufSize = 32768
 // whether a v1 or v2 pusher will be created. The other parameters are passed
 // through to the underlying pusher implementation for use during the actual
 // push operation.
-func NewPusher(ref reference.Named, endpoint registry.APIEndpoint, repoInfo *registry.RepositoryInfo, imagePushConfig *ImagePushConfig, sf *streamformatter.StreamFormatter) (Pusher, error) {
+func NewPusher(ref reference.Named, endpoint registry.APIEndpoint, repoInfo *registry.RepositoryInfo, imagePushConfig *ImagePushConfig) (Pusher, error) {
 	switch endpoint.Version {
 	case registry.APIVersion2:
 		return &v2Pusher{
@@ -75,8 +79,7 @@ func NewPusher(ref reference.Named, endpoint registry.APIEndpoint, repoInfo *reg
 			endpoint:       endpoint,
 			repoInfo:       repoInfo,
 			config:         imagePushConfig,
-			sf:             sf,
-			layersPushed:   make(map[digest.Digest]bool),
+			layersPushed:   pushMap{layersPushed: make(map[digest.Digest]bool)},
 		}, nil
 	case registry.APIVersion1:
 		return &v1Pusher{
@@ -85,7 +88,6 @@ func NewPusher(ref reference.Named, endpoint registry.APIEndpoint, repoInfo *reg
 			endpoint:    endpoint,
 			repoInfo:    repoInfo,
 			config:      imagePushConfig,
-			sf:          sf,
 		}, nil
 	}
 	return nil, fmt.Errorf("unknown version %d for registry %s", endpoint.Version, endpoint.URL)
@@ -94,11 +96,9 @@ func NewPusher(ref reference.Named, endpoint registry.APIEndpoint, repoInfo *reg
 // Push initiates a push operation on the repository named localName.
 // ref is the specific variant of the image to be pushed.
 // If no tag is provided, all tags will be pushed.
-func Push(ref reference.Named, imagePushConfig *ImagePushConfig) error {
+func Push(ctx context.Context, ref reference.Named, imagePushConfig *ImagePushConfig) error {
 	// FIXME: Allow to interrupt current push when new push of same image is done.
 
-	var sf = streamformatter.NewJSONStreamFormatter()
-
 	// Resolve the Repository name from fqn to RepositoryInfo
 	repoInfo, err := imagePushConfig.RegistryService.ResolveRepository(ref)
 	if err != nil {
@@ -110,7 +110,7 @@ func Push(ref reference.Named, imagePushConfig *ImagePushConfig) error {
 		return err
 	}
 
-	imagePushConfig.OutStream.Write(sf.FormatStatus("", "The push refers to a repository [%s]", repoInfo.CanonicalName))
+	progress.Messagef(imagePushConfig.ProgressOutput, "", "The push refers to a repository [%s]", repoInfo.CanonicalName.String())
 
 	associations := imagePushConfig.TagStore.ReferencesByName(repoInfo.LocalName)
 	if len(associations) == 0 {
@@ -121,12 +121,20 @@ func Push(ref reference.Named, imagePushConfig *ImagePushConfig) error {
 	for _, endpoint := range endpoints {
 		logrus.Debugf("Trying to push %s to %s %s", repoInfo.CanonicalName, endpoint.URL, endpoint.Version)
 
-		pusher, err := NewPusher(ref, endpoint, repoInfo, imagePushConfig, sf)
+		pusher, err := NewPusher(ref, endpoint, repoInfo, imagePushConfig)
 		if err != nil {
 			lastErr = err
 			continue
 		}
-		if fallback, err := pusher.Push(); err != nil {
+		if fallback, err := pusher.Push(ctx); err != nil {
+			// Was this push cancelled? If so, don't try to fall
+			// back.
+			select {
+			case <-ctx.Done():
+				fallback = false
+			default:
+			}
+
 			if fallback {
 				lastErr = err
 				continue

+ 20 - 32
distribution/push_v1.go

@@ -2,8 +2,6 @@ package distribution
 
 import (
 	"fmt"
-	"io"
-	"io/ioutil"
 	"sync"
 
 	"github.com/Sirupsen/logrus"
@@ -15,25 +13,23 @@ import (
 	"github.com/docker/docker/image/v1"
 	"github.com/docker/docker/layer"
 	"github.com/docker/docker/pkg/ioutils"
-	"github.com/docker/docker/pkg/progressreader"
-	"github.com/docker/docker/pkg/streamformatter"
+	"github.com/docker/docker/pkg/progress"
 	"github.com/docker/docker/pkg/stringid"
 	"github.com/docker/docker/registry"
+	"golang.org/x/net/context"
 )
 
 type v1Pusher struct {
+	ctx         context.Context
 	v1IDService *metadata.V1IDService
 	endpoint    registry.APIEndpoint
 	ref         reference.Named
 	repoInfo    *registry.RepositoryInfo
 	config      *ImagePushConfig
-	sf          *streamformatter.StreamFormatter
 	session     *registry.Session
-
-	out io.Writer
 }
 
-func (p *v1Pusher) Push() (fallback bool, err error) {
+func (p *v1Pusher) Push(ctx context.Context) (fallback bool, err error) {
 	tlsConfig, err := p.config.RegistryService.TLSConfig(p.repoInfo.Index.Name)
 	if err != nil {
 		return false, err
@@ -55,7 +51,7 @@ func (p *v1Pusher) Push() (fallback bool, err error) {
 		// TODO(dmcgowan): Check if should fallback
 		return true, err
 	}
-	if err := p.pushRepository(); err != nil {
+	if err := p.pushRepository(ctx); err != nil {
 		// TODO(dmcgowan): Check if should fallback
 		return false, err
 	}
@@ -306,12 +302,12 @@ func (p *v1Pusher) lookupImageOnEndpoint(wg *sync.WaitGroup, endpoint string, im
 			logrus.Errorf("Error in LookupRemoteImage: %s", err)
 			imagesToPush <- v1ID
 		} else {
-			p.out.Write(p.sf.FormatStatus("", "Image %s already pushed, skipping", stringid.TruncateID(v1ID)))
+			progress.Messagef(p.config.ProgressOutput, "", "Image %s already pushed, skipping", stringid.TruncateID(v1ID))
 		}
 	}
 }
 
-func (p *v1Pusher) pushImageToEndpoint(endpoint string, imageList []v1Image, tags map[image.ID][]string, repo *registry.RepositoryData) error {
+func (p *v1Pusher) pushImageToEndpoint(ctx context.Context, endpoint string, imageList []v1Image, tags map[image.ID][]string, repo *registry.RepositoryData) error {
 	workerCount := len(imageList)
 	// start a maximum of 5 workers to check if images exist on the specified endpoint.
 	if workerCount > 5 {
@@ -349,14 +345,14 @@ func (p *v1Pusher) pushImageToEndpoint(endpoint string, imageList []v1Image, tag
 	for _, img := range imageList {
 		v1ID := img.V1ID()
 		if _, push := shouldPush[v1ID]; push {
-			if _, err := p.pushImage(img, endpoint); err != nil {
+			if _, err := p.pushImage(ctx, img, endpoint); err != nil {
 				// FIXME: Continue on error?
 				return err
 			}
 		}
 		if topImage, isTopImage := img.(*v1TopImage); isTopImage {
 			for _, tag := range tags[topImage.imageID] {
-				p.out.Write(p.sf.FormatStatus("", "Pushing tag for rev [%s] on {%s}", stringid.TruncateID(v1ID), endpoint+"repositories/"+p.repoInfo.RemoteName.Name()+"/tags/"+tag))
+				progress.Messagef(p.config.ProgressOutput, "", "Pushing tag for rev [%s] on {%s}", stringid.TruncateID(v1ID), endpoint+"repositories/"+p.repoInfo.RemoteName.Name()+"/tags/"+tag)
 				if err := p.session.PushRegistryTag(p.repoInfo.RemoteName, v1ID, tag, endpoint); err != nil {
 					return err
 				}
@@ -367,8 +363,7 @@ func (p *v1Pusher) pushImageToEndpoint(endpoint string, imageList []v1Image, tag
 }
 
 // pushRepository pushes layers that do not already exist on the registry.
-func (p *v1Pusher) pushRepository() error {
-	p.out = ioutils.NewWriteFlusher(p.config.OutStream)
+func (p *v1Pusher) pushRepository(ctx context.Context) error {
 	imgList, tags, referencedLayers, err := p.getImageList()
 	defer func() {
 		for _, l := range referencedLayers {
@@ -378,7 +373,7 @@ func (p *v1Pusher) pushRepository() error {
 	if err != nil {
 		return err
 	}
-	p.out.Write(p.sf.FormatStatus("", "Sending image list"))
+	progress.Message(p.config.ProgressOutput, "", "Sending image list")
 
 	imageIndex := createImageIndex(imgList, tags)
 	for _, data := range imageIndex {
@@ -391,10 +386,10 @@ func (p *v1Pusher) pushRepository() error {
 	if err != nil {
 		return err
 	}
-	p.out.Write(p.sf.FormatStatus("", "Pushing repository %s", p.repoInfo.CanonicalName))
+	progress.Message(p.config.ProgressOutput, "", "Pushing repository "+p.repoInfo.CanonicalName.String())
 	// push the repository to each of the endpoints only if it does not exist.
 	for _, endpoint := range repoData.Endpoints {
-		if err := p.pushImageToEndpoint(endpoint, imgList, tags, repoData); err != nil {
+		if err := p.pushImageToEndpoint(ctx, endpoint, imgList, tags, repoData); err != nil {
 			return err
 		}
 	}
@@ -402,11 +397,11 @@ func (p *v1Pusher) pushRepository() error {
 	return err
 }
 
-func (p *v1Pusher) pushImage(v1Image v1Image, ep string) (checksum string, err error) {
+func (p *v1Pusher) pushImage(ctx context.Context, v1Image v1Image, ep string) (checksum string, err error) {
 	v1ID := v1Image.V1ID()
 
 	jsonRaw := v1Image.Config()
-	p.out.Write(p.sf.FormatProgress(stringid.TruncateID(v1ID), "Pushing", nil))
+	progress.Update(p.config.ProgressOutput, stringid.TruncateID(v1ID), "Pushing")
 
 	// General rule is to use ID for graph accesses and compatibilityID for
 	// calls to session.registry()
@@ -417,7 +412,7 @@ func (p *v1Pusher) pushImage(v1Image v1Image, ep string) (checksum string, err e
 	// Send the json
 	if err := p.session.PushImageJSONRegistry(imgData, jsonRaw, ep); err != nil {
 		if err == registry.ErrAlreadyExists {
-			p.out.Write(p.sf.FormatProgress(stringid.TruncateID(v1ID), "Image already pushed, skipping", nil))
+			progress.Update(p.config.ProgressOutput, stringid.TruncateID(v1ID), "Image already pushed, skipping")
 			return "", nil
 		}
 		return "", err
@@ -437,15 +432,8 @@ func (p *v1Pusher) pushImage(v1Image v1Image, ep string) (checksum string, err e
 	// Send the layer
 	logrus.Debugf("rendered layer for %s of [%d] size", v1ID, size)
 
-	reader := progressreader.New(progressreader.Config{
-		In:        ioutil.NopCloser(arch),
-		Out:       p.out,
-		Formatter: p.sf,
-		Size:      size,
-		NewLines:  false,
-		ID:        stringid.TruncateID(v1ID),
-		Action:    "Pushing",
-	})
+	reader := progress.NewProgressReader(ioutils.NewCancelReadCloser(ctx, arch), p.config.ProgressOutput, size, stringid.TruncateID(v1ID), "Pushing")
+	defer reader.Close()
 
 	checksum, checksumPayload, err := p.session.PushImageLayerRegistry(v1ID, reader, ep, jsonRaw)
 	if err != nil {
@@ -458,10 +446,10 @@ func (p *v1Pusher) pushImage(v1Image v1Image, ep string) (checksum string, err e
 		return "", err
 	}
 
-	if err := p.v1IDService.Set(v1ID, p.repoInfo.Index.Name, l.ChainID()); err != nil {
+	if err := p.v1IDService.Set(v1ID, p.repoInfo.Index.Name, l.DiffID()); err != nil {
 		logrus.Warnf("Could not set v1 ID mapping: %v", err)
 	}
 
-	p.out.Write(p.sf.FormatProgress(stringid.TruncateID(v1ID), "Image successfully pushed", nil))
+	progress.Update(p.config.ProgressOutput, stringid.TruncateID(v1ID), "Image successfully pushed")
 	return imgData.Checksum, nil
 }

+ 118 - 105
distribution/push_v2.go

@@ -5,7 +5,7 @@ import (
 	"errors"
 	"fmt"
 	"io"
-	"io/ioutil"
+	"sync"
 	"time"
 
 	"github.com/Sirupsen/logrus"
@@ -15,11 +15,12 @@ import (
 	"github.com/docker/distribution/manifest/schema1"
 	"github.com/docker/distribution/reference"
 	"github.com/docker/docker/distribution/metadata"
+	"github.com/docker/docker/distribution/xfer"
 	"github.com/docker/docker/image"
 	"github.com/docker/docker/image/v1"
 	"github.com/docker/docker/layer"
-	"github.com/docker/docker/pkg/progressreader"
-	"github.com/docker/docker/pkg/streamformatter"
+	"github.com/docker/docker/pkg/ioutils"
+	"github.com/docker/docker/pkg/progress"
 	"github.com/docker/docker/pkg/stringid"
 	"github.com/docker/docker/registry"
 	"github.com/docker/docker/tag"
@@ -32,16 +33,20 @@ type v2Pusher struct {
 	endpoint       registry.APIEndpoint
 	repoInfo       *registry.RepositoryInfo
 	config         *ImagePushConfig
-	sf             *streamformatter.StreamFormatter
 	repo           distribution.Repository
 
 	// layersPushed is the set of layers known to exist on the remote side.
 	// This avoids redundant queries when pushing multiple tags that
 	// involve the same layers.
+	layersPushed pushMap
+}
+
+type pushMap struct {
+	sync.Mutex
 	layersPushed map[digest.Digest]bool
 }
 
-func (p *v2Pusher) Push() (fallback bool, err error) {
+func (p *v2Pusher) Push(ctx context.Context) (fallback bool, err error) {
 	p.repo, err = NewV2Repository(p.repoInfo, p.endpoint, p.config.MetaHeaders, p.config.AuthConfig, "push", "pull")
 	if err != nil {
 		logrus.Debugf("Error getting v2 registry: %v", err)
@@ -75,7 +80,7 @@ func (p *v2Pusher) Push() (fallback bool, err error) {
 	}
 
 	for _, association := range associations {
-		if err := p.pushV2Tag(association); err != nil {
+		if err := p.pushV2Tag(ctx, association); err != nil {
 			return false, err
 		}
 	}
@@ -83,7 +88,7 @@ func (p *v2Pusher) Push() (fallback bool, err error) {
 	return false, nil
 }
 
-func (p *v2Pusher) pushV2Tag(association tag.Association) error {
+func (p *v2Pusher) pushV2Tag(ctx context.Context, association tag.Association) error {
 	ref := association.Ref
 	logrus.Debugf("Pushing repository: %s", ref.String())
 
@@ -92,8 +97,6 @@ func (p *v2Pusher) pushV2Tag(association tag.Association) error {
 		return fmt.Errorf("could not find image from tag %s: %v", ref.String(), err)
 	}
 
-	out := p.config.OutStream
-
 	var l layer.Layer
 
 	topLayerID := img.RootFS.ChainID()
@@ -107,33 +110,41 @@ func (p *v2Pusher) pushV2Tag(association tag.Association) error {
 		defer layer.ReleaseAndLog(p.config.LayerStore, l)
 	}
 
-	fsLayers := make(map[layer.DiffID]schema1.FSLayer)
+	var descriptors []xfer.UploadDescriptor
 
 	// Push empty layer if necessary
 	for _, h := range img.History {
 		if h.EmptyLayer {
-			dgst, err := p.pushLayerIfNecessary(out, layer.EmptyLayer)
-			if err != nil {
-				return err
+			descriptors = []xfer.UploadDescriptor{
+				&v2PushDescriptor{
+					layer:          layer.EmptyLayer,
+					blobSumService: p.blobSumService,
+					repo:           p.repo,
+					layersPushed:   &p.layersPushed,
+				},
 			}
-			p.layersPushed[dgst] = true
-			fsLayers[layer.EmptyLayer.DiffID()] = schema1.FSLayer{BlobSum: dgst}
 			break
 		}
 	}
 
+	// Loop bounds condition is to avoid pushing the base layer on Windows.
 	for i := 0; i < len(img.RootFS.DiffIDs); i++ {
-		dgst, err := p.pushLayerIfNecessary(out, l)
-		if err != nil {
-			return err
+		descriptor := &v2PushDescriptor{
+			layer:          l,
+			blobSumService: p.blobSumService,
+			repo:           p.repo,
+			layersPushed:   &p.layersPushed,
 		}
-
-		p.layersPushed[dgst] = true
-		fsLayers[l.DiffID()] = schema1.FSLayer{BlobSum: dgst}
+		descriptors = append(descriptors, descriptor)
 
 		l = l.Parent()
 	}
 
+	fsLayers, err := p.config.UploadManager.Upload(ctx, descriptors, p.config.ProgressOutput)
+	if err != nil {
+		return err
+	}
+
 	var tag string
 	if tagged, isTagged := ref.(reference.Tagged); isTagged {
 		tag = tagged.Tag()
@@ -157,59 +168,124 @@ func (p *v2Pusher) pushV2Tag(association tag.Association) error {
 		if tagged, isTagged := ref.(reference.Tagged); isTagged {
 			// NOTE: do not change this format without first changing the trust client
 			// code. This information is used to determine what was pushed and should be signed.
-			out.Write(p.sf.FormatStatus("", "%s: digest: %s size: %d", tagged.Tag(), manifestDigest, manifestSize))
+			progress.Messagef(p.config.ProgressOutput, "", "%s: digest: %s size: %d", tagged.Tag(), manifestDigest, manifestSize)
 		}
 	}
 
-	manSvc, err := p.repo.Manifests(context.Background())
+	manSvc, err := p.repo.Manifests(ctx)
 	if err != nil {
 		return err
 	}
 	return manSvc.Put(signed)
 }
 
-func (p *v2Pusher) pushLayerIfNecessary(out io.Writer, l layer.Layer) (digest.Digest, error) {
-	logrus.Debugf("Pushing layer: %s", l.DiffID())
+type v2PushDescriptor struct {
+	layer          layer.Layer
+	blobSumService *metadata.BlobSumService
+	repo           distribution.Repository
+	layersPushed   *pushMap
+}
+
+func (pd *v2PushDescriptor) Key() string {
+	return "v2push:" + pd.repo.Name() + " " + pd.layer.DiffID().String()
+}
+
+func (pd *v2PushDescriptor) ID() string {
+	return stringid.TruncateID(pd.layer.DiffID().String())
+}
+
+func (pd *v2PushDescriptor) DiffID() layer.DiffID {
+	return pd.layer.DiffID()
+}
+
+func (pd *v2PushDescriptor) Upload(ctx context.Context, progressOutput progress.Output) (digest.Digest, error) {
+	diffID := pd.DiffID()
+
+	logrus.Debugf("Pushing layer: %s", diffID)
 
 	// Do we have any blobsums associated with this layer's DiffID?
-	possibleBlobsums, err := p.blobSumService.GetBlobSums(l.DiffID())
+	possibleBlobsums, err := pd.blobSumService.GetBlobSums(diffID)
 	if err == nil {
-		dgst, exists, err := p.blobSumAlreadyExists(possibleBlobsums)
+		dgst, exists, err := blobSumAlreadyExists(ctx, possibleBlobsums, pd.repo, pd.layersPushed)
 		if err != nil {
-			out.Write(p.sf.FormatProgress(stringid.TruncateID(string(l.DiffID())), "Image push failed", nil))
-			return "", err
+			progress.Update(progressOutput, pd.ID(), "Image push failed")
+			return "", retryOnError(err)
 		}
 		if exists {
-			out.Write(p.sf.FormatProgress(stringid.TruncateID(string(l.DiffID())), "Layer already exists", nil))
+			progress.Update(progressOutput, pd.ID(), "Layer already exists")
 			return dgst, nil
 		}
 	}
 
 	// if digest was empty or not saved, or if blob does not exist on the remote repository,
 	// then push the blob.
-	pushDigest, err := p.pushV2Layer(p.repo.Blobs(context.Background()), l)
+	bs := pd.repo.Blobs(ctx)
+
+	// Send the layer
+	layerUpload, err := bs.Create(ctx)
+	if err != nil {
+		return "", retryOnError(err)
+	}
+	defer layerUpload.Close()
+
+	arch, err := pd.layer.TarStream()
 	if err != nil {
-		return "", err
+		return "", xfer.DoNotRetry{Err: err}
 	}
+
+	// don't care if this fails; best effort
+	size, _ := pd.layer.DiffSize()
+
+	reader := progress.NewProgressReader(ioutils.NewCancelReadCloser(ctx, arch), progressOutput, size, pd.ID(), "Pushing")
+	defer reader.Close()
+	compressedReader := compress(reader)
+
+	digester := digest.Canonical.New()
+	tee := io.TeeReader(compressedReader, digester.Hash())
+
+	nn, err := layerUpload.ReadFrom(tee)
+	compressedReader.Close()
+	if err != nil {
+		return "", retryOnError(err)
+	}
+
+	pushDigest := digester.Digest()
+	if _, err := layerUpload.Commit(ctx, distribution.Descriptor{Digest: pushDigest}); err != nil {
+		return "", retryOnError(err)
+	}
+
+	logrus.Debugf("uploaded layer %s (%s), %d bytes", diffID, pushDigest, nn)
+	progress.Update(progressOutput, pd.ID(), "Pushed")
+
 	// Cache mapping from this layer's DiffID to the blobsum
-	if err := p.blobSumService.Add(l.DiffID(), pushDigest); err != nil {
-		return "", err
+	if err := pd.blobSumService.Add(diffID, pushDigest); err != nil {
+		return "", xfer.DoNotRetry{Err: err}
 	}
 
+	pd.layersPushed.Lock()
+	pd.layersPushed.layersPushed[pushDigest] = true
+	pd.layersPushed.Unlock()
+
 	return pushDigest, nil
 }
 
 // blobSumAlreadyExists checks if the registry already know about any of the
 // blobsums passed in the "blobsums" slice. If it finds one that the registry
 // knows about, it returns the known digest and "true".
-func (p *v2Pusher) blobSumAlreadyExists(blobsums []digest.Digest) (digest.Digest, bool, error) {
+func blobSumAlreadyExists(ctx context.Context, blobsums []digest.Digest, repo distribution.Repository, layersPushed *pushMap) (digest.Digest, bool, error) {
+	layersPushed.Lock()
 	for _, dgst := range blobsums {
-		if p.layersPushed[dgst] {
+		if layersPushed.layersPushed[dgst] {
 			// it is already known that the push is not needed and
 			// therefore doing a stat is unnecessary
+			layersPushed.Unlock()
 			return dgst, true, nil
 		}
-		_, err := p.repo.Blobs(context.Background()).Stat(context.Background(), dgst)
+	}
+	layersPushed.Unlock()
+
+	for _, dgst := range blobsums {
+		_, err := repo.Blobs(ctx).Stat(ctx, dgst)
 		switch err {
 		case nil:
 			return dgst, true, nil
@@ -226,7 +302,7 @@ func (p *v2Pusher) blobSumAlreadyExists(blobsums []digest.Digest) (digest.Digest
 // FSLayer digests.
 // FIXME: This should be moved to the distribution repo, since it will also
 // be useful for converting new manifests to the old format.
-func CreateV2Manifest(name, tag string, img *image.Image, fsLayers map[layer.DiffID]schema1.FSLayer) (*schema1.Manifest, error) {
+func CreateV2Manifest(name, tag string, img *image.Image, fsLayers map[layer.DiffID]digest.Digest) (*schema1.Manifest, error) {
 	if len(img.History) == 0 {
 		return nil, errors.New("empty history when trying to create V2 manifest")
 	}
@@ -271,7 +347,7 @@ func CreateV2Manifest(name, tag string, img *image.Image, fsLayers map[layer.Dif
 		if !present {
 			return nil, fmt.Errorf("missing layer in CreateV2Manifest: %s", diffID.String())
 		}
-		dgst, err := digest.FromBytes([]byte(fsLayer.BlobSum.Hex() + " " + parent))
+		dgst, err := digest.FromBytes([]byte(fsLayer.Hex() + " " + parent))
 		if err != nil {
 			return nil, err
 		}
@@ -294,7 +370,7 @@ func CreateV2Manifest(name, tag string, img *image.Image, fsLayers map[layer.Dif
 
 		reversedIndex := len(img.History) - i - 1
 		history[reversedIndex].V1Compatibility = string(jsonBytes)
-		fsLayerList[reversedIndex] = fsLayer
+		fsLayerList[reversedIndex] = schema1.FSLayer{BlobSum: fsLayer}
 
 		parent = v1ID
 	}
@@ -315,11 +391,11 @@ func CreateV2Manifest(name, tag string, img *image.Image, fsLayers map[layer.Dif
 		return nil, fmt.Errorf("missing layer in CreateV2Manifest: %s", diffID.String())
 	}
 
-	dgst, err := digest.FromBytes([]byte(fsLayer.BlobSum.Hex() + " " + parent + " " + string(img.RawJSON())))
+	dgst, err := digest.FromBytes([]byte(fsLayer.Hex() + " " + parent + " " + string(img.RawJSON())))
 	if err != nil {
 		return nil, err
 	}
-	fsLayerList[0] = fsLayer
+	fsLayerList[0] = schema1.FSLayer{BlobSum: fsLayer}
 
 	// Top-level v1compatibility string should be a modified version of the
 	// image config.
@@ -346,66 +422,3 @@ func CreateV2Manifest(name, tag string, img *image.Image, fsLayers map[layer.Dif
 		History:      history,
 	}, nil
 }
-
-func rawJSON(value interface{}) *json.RawMessage {
-	jsonval, err := json.Marshal(value)
-	if err != nil {
-		return nil
-	}
-	return (*json.RawMessage)(&jsonval)
-}
-
-func (p *v2Pusher) pushV2Layer(bs distribution.BlobService, l layer.Layer) (digest.Digest, error) {
-	out := p.config.OutStream
-	displayID := stringid.TruncateID(string(l.DiffID()))
-
-	out.Write(p.sf.FormatProgress(displayID, "Preparing", nil))
-
-	arch, err := l.TarStream()
-	if err != nil {
-		return "", err
-	}
-	defer arch.Close()
-
-	// Send the layer
-	layerUpload, err := bs.Create(context.Background())
-	if err != nil {
-		return "", err
-	}
-	defer layerUpload.Close()
-
-	// don't care if this fails; best effort
-	size, _ := l.DiffSize()
-
-	reader := progressreader.New(progressreader.Config{
-		In:        ioutil.NopCloser(arch), // we'll take care of close here.
-		Out:       out,
-		Formatter: p.sf,
-		Size:      size,
-		NewLines:  false,
-		ID:        displayID,
-		Action:    "Pushing",
-	})
-
-	compressedReader := compress(reader)
-
-	digester := digest.Canonical.New()
-	tee := io.TeeReader(compressedReader, digester.Hash())
-
-	out.Write(p.sf.FormatProgress(displayID, "Pushing", nil))
-	nn, err := layerUpload.ReadFrom(tee)
-	compressedReader.Close()
-	if err != nil {
-		return "", err
-	}
-
-	dgst := digester.Digest()
-	if _, err := layerUpload.Commit(context.Background(), distribution.Descriptor{Digest: dgst}); err != nil {
-		return "", err
-	}
-
-	logrus.Debugf("uploaded layer %s (%s), %d bytes", l.DiffID(), dgst, nn)
-	out.Write(p.sf.FormatProgress(displayID, "Pushed", nil))
-
-	return dgst, nil
-}

+ 4 - 4
distribution/push_v2_test.go

@@ -116,10 +116,10 @@ func TestCreateV2Manifest(t *testing.T) {
 		t.Fatalf("json decoding failed: %v", err)
 	}
 
-	fsLayers := map[layer.DiffID]schema1.FSLayer{
-		layer.DiffID("sha256:c6f988f4874bb0add23a778f753c65efe992244e148a1d2ec2a8b664fb66bbd1"): {BlobSum: digest.Digest("sha256:a3ed95caeb02ffe68cdd9fd84406680ae93d633cb16422d00e8a7c22955b46d4")},
-		layer.DiffID("sha256:5f70bf18a086007016e948b04aed3b82103a36bea41755b6cddfaf10ace3c6ef"): {BlobSum: digest.Digest("sha256:86e0e091d0da6bde2456dbb48306f3956bbeb2eae1b5b9a43045843f69fe4aaa")},
-		layer.DiffID("sha256:13f53e08df5a220ab6d13c58b2bf83a59cbdc2e04d0a3f041ddf4b0ba4112d49"): {BlobSum: digest.Digest("sha256:b4ed95caeb02ffe68cdd9fd84406680ae93d633cb16422d00e8a7c22955b46d4")},
+	fsLayers := map[layer.DiffID]digest.Digest{
+		layer.DiffID("sha256:c6f988f4874bb0add23a778f753c65efe992244e148a1d2ec2a8b664fb66bbd1"): digest.Digest("sha256:a3ed95caeb02ffe68cdd9fd84406680ae93d633cb16422d00e8a7c22955b46d4"),
+		layer.DiffID("sha256:5f70bf18a086007016e948b04aed3b82103a36bea41755b6cddfaf10ace3c6ef"): digest.Digest("sha256:86e0e091d0da6bde2456dbb48306f3956bbeb2eae1b5b9a43045843f69fe4aaa"),
+		layer.DiffID("sha256:13f53e08df5a220ab6d13c58b2bf83a59cbdc2e04d0a3f041ddf4b0ba4112d49"): digest.Digest("sha256:b4ed95caeb02ffe68cdd9fd84406680ae93d633cb16422d00e8a7c22955b46d4"),
 	}
 
 	manifest, err := CreateV2Manifest("testrepo", "testtag", img, fsLayers)

+ 23 - 1
distribution/registry.go

@@ -13,10 +13,12 @@ import (
 	"github.com/docker/distribution"
 	"github.com/docker/distribution/digest"
 	"github.com/docker/distribution/manifest/schema1"
+	"github.com/docker/distribution/registry/api/errcode"
 	"github.com/docker/distribution/registry/client"
 	"github.com/docker/distribution/registry/client/auth"
 	"github.com/docker/distribution/registry/client/transport"
 	"github.com/docker/docker/cliconfig"
+	"github.com/docker/docker/distribution/xfer"
 	"github.com/docker/docker/registry"
 	"golang.org/x/net/context"
 )
@@ -59,7 +61,7 @@ func NewV2Repository(repoInfo *registry.RepositoryInfo, endpoint registry.APIEnd
 	authTransport := transport.NewTransport(base, modifiers...)
 	pingClient := &http.Client{
 		Transport: authTransport,
-		Timeout:   5 * time.Second,
+		Timeout:   15 * time.Second,
 	}
 	endpointStr := strings.TrimRight(endpoint.URL, "/") + "/v2/"
 	req, err := http.NewRequest("GET", endpointStr, nil)
@@ -132,3 +134,23 @@ func (th *existingTokenHandler) AuthorizeRequest(req *http.Request, params map[s
 	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", th.token))
 	return nil
 }
+
+// retryOnError wraps the error in xfer.DoNotRetry if we should not retry the
+// operation after this error.
+func retryOnError(err error) error {
+	switch v := err.(type) {
+	case errcode.Errors:
+		return retryOnError(v[0])
+	case errcode.Error:
+		switch v.Code {
+		case errcode.ErrorCodeUnauthorized, errcode.ErrorCodeUnsupported, errcode.ErrorCodeDenied:
+			return xfer.DoNotRetry{Err: err}
+		}
+
+	}
+	// let's be nice and fallback if the error is a completely
+	// unexpected one.
+	// If new errors have to be handled in some way, please
+	// add them to the switch above.
+	return err
+}

+ 3 - 4
distribution/registry_unit_test.go

@@ -11,9 +11,9 @@ import (
 	"github.com/docker/distribution/reference"
 	"github.com/docker/distribution/registry/client/auth"
 	"github.com/docker/docker/cliconfig"
-	"github.com/docker/docker/pkg/streamformatter"
 	"github.com/docker/docker/registry"
 	"github.com/docker/docker/utils"
+	"golang.org/x/net/context"
 )
 
 func TestTokenPassThru(t *testing.T) {
@@ -72,8 +72,7 @@ func TestTokenPassThru(t *testing.T) {
 		MetaHeaders: http.Header{},
 		AuthConfig:  authConfig,
 	}
-	sf := streamformatter.NewJSONStreamFormatter()
-	puller, err := newPuller(endpoint, repoInfo, imagePullConfig, sf)
+	puller, err := newPuller(endpoint, repoInfo, imagePullConfig)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -86,7 +85,7 @@ func TestTokenPassThru(t *testing.T) {
 	logrus.Debug("About to pull")
 	// We expect it to fail, since we haven't mock'd the full registry exchange in our handler above
 	tag, _ := reference.WithTag(n, "tag_goes_here")
-	_ = p.pullV2Repository(tag)
+	_ = p.pullV2Repository(context.Background(), tag)
 
 	if !gotToken {
 		t.Fatal("Failed to receive registry token")

+ 420 - 0
distribution/xfer/download.go

@@ -0,0 +1,420 @@
+package xfer
+
+import (
+	"errors"
+	"fmt"
+	"io"
+	"time"
+
+	"github.com/Sirupsen/logrus"
+	"github.com/docker/docker/image"
+	"github.com/docker/docker/layer"
+	"github.com/docker/docker/pkg/archive"
+	"github.com/docker/docker/pkg/ioutils"
+	"github.com/docker/docker/pkg/progress"
+	"golang.org/x/net/context"
+)
+
+const maxDownloadAttempts = 5
+
+// LayerDownloadManager figures out which layers need to be downloaded, then
+// registers and downloads those, taking into account dependencies between
+// layers.
+type LayerDownloadManager struct {
+	layerStore layer.Store
+	tm         TransferManager
+}
+
+// NewLayerDownloadManager returns a new LayerDownloadManager.
+func NewLayerDownloadManager(layerStore layer.Store, concurrencyLimit int) *LayerDownloadManager {
+	return &LayerDownloadManager{
+		layerStore: layerStore,
+		tm:         NewTransferManager(concurrencyLimit),
+	}
+}
+
+type downloadTransfer struct {
+	Transfer
+
+	layerStore layer.Store
+	layer      layer.Layer
+	err        error
+}
+
+// result returns the layer resulting from the download, if the download
+// and registration were successful.
+func (d *downloadTransfer) result() (layer.Layer, error) {
+	return d.layer, d.err
+}
+
+// A DownloadDescriptor references a layer that may need to be downloaded.
+type DownloadDescriptor interface {
+	// Key returns the key used to deduplicate downloads.
+	Key() string
+	// ID returns the ID for display purposes.
+	ID() string
+	// DiffID should return the DiffID for this layer, or an error
+	// if it is unknown (for example, if it has not been downloaded
+	// before).
+	DiffID() (layer.DiffID, error)
+	// Download is called to perform the download.
+	Download(ctx context.Context, progressOutput progress.Output) (io.ReadCloser, int64, error)
+}
+
+// DownloadDescriptorWithRegistered is a DownloadDescriptor that has an
+// additional Registered method which gets called after a downloaded layer is
+// registered. This allows the user of the download manager to know the DiffID
+// of each registered layer. This method is called if a cast to
+// DownloadDescriptorWithRegistered is successful.
+type DownloadDescriptorWithRegistered interface {
+	DownloadDescriptor
+	Registered(diffID layer.DiffID)
+}
+
+// Download is a blocking function which ensures the requested layers are
+// present in the layer store. It uses the string returned by the Key method to
+// deduplicate downloads. If a given layer is not already known to present in
+// the layer store, and the key is not used by an in-progress download, the
+// Download method is called to get the layer tar data. Layers are then
+// registered in the appropriate order.  The caller must call the returned
+// release function once it is is done with the returned RootFS object.
+func (ldm *LayerDownloadManager) Download(ctx context.Context, initialRootFS image.RootFS, layers []DownloadDescriptor, progressOutput progress.Output) (image.RootFS, func(), error) {
+	var (
+		topLayer       layer.Layer
+		topDownload    *downloadTransfer
+		watcher        *Watcher
+		missingLayer   bool
+		transferKey    = ""
+		downloadsByKey = make(map[string]*downloadTransfer)
+	)
+
+	rootFS := initialRootFS
+	for _, descriptor := range layers {
+		key := descriptor.Key()
+		transferKey += key
+
+		if !missingLayer {
+			missingLayer = true
+			diffID, err := descriptor.DiffID()
+			if err == nil {
+				getRootFS := rootFS
+				getRootFS.Append(diffID)
+				l, err := ldm.layerStore.Get(getRootFS.ChainID())
+				if err == nil {
+					// Layer already exists.
+					logrus.Debugf("Layer already exists: %s", descriptor.ID())
+					progress.Update(progressOutput, descriptor.ID(), "Already exists")
+					if topLayer != nil {
+						layer.ReleaseAndLog(ldm.layerStore, topLayer)
+					}
+					topLayer = l
+					missingLayer = false
+					rootFS.Append(diffID)
+					continue
+				}
+			}
+		}
+
+		// Does this layer have the same data as a previous layer in
+		// the stack? If so, avoid downloading it more than once.
+		var topDownloadUncasted Transfer
+		if existingDownload, ok := downloadsByKey[key]; ok {
+			xferFunc := ldm.makeDownloadFuncFromDownload(descriptor, existingDownload, topDownload)
+			defer topDownload.Transfer.Release(watcher)
+			topDownloadUncasted, watcher = ldm.tm.Transfer(transferKey, xferFunc, progressOutput)
+			topDownload = topDownloadUncasted.(*downloadTransfer)
+			continue
+		}
+
+		// Layer is not known to exist - download and register it.
+		progress.Update(progressOutput, descriptor.ID(), "Pulling fs layer")
+
+		var xferFunc DoFunc
+		if topDownload != nil {
+			xferFunc = ldm.makeDownloadFunc(descriptor, "", topDownload)
+			defer topDownload.Transfer.Release(watcher)
+		} else {
+			xferFunc = ldm.makeDownloadFunc(descriptor, rootFS.ChainID(), nil)
+		}
+		topDownloadUncasted, watcher = ldm.tm.Transfer(transferKey, xferFunc, progressOutput)
+		topDownload = topDownloadUncasted.(*downloadTransfer)
+		downloadsByKey[key] = topDownload
+	}
+
+	if topDownload == nil {
+		return rootFS, func() { layer.ReleaseAndLog(ldm.layerStore, topLayer) }, nil
+	}
+
+	// Won't be using the list built up so far - will generate it
+	// from downloaded layers instead.
+	rootFS.DiffIDs = []layer.DiffID{}
+
+	defer func() {
+		if topLayer != nil {
+			layer.ReleaseAndLog(ldm.layerStore, topLayer)
+		}
+	}()
+
+	select {
+	case <-ctx.Done():
+		topDownload.Transfer.Release(watcher)
+		return rootFS, func() {}, ctx.Err()
+	case <-topDownload.Done():
+		break
+	}
+
+	l, err := topDownload.result()
+	if err != nil {
+		topDownload.Transfer.Release(watcher)
+		return rootFS, func() {}, err
+	}
+
+	// Must do this exactly len(layers) times, so we don't include the
+	// base layer on Windows.
+	for range layers {
+		if l == nil {
+			topDownload.Transfer.Release(watcher)
+			return rootFS, func() {}, errors.New("internal error: too few parent layers")
+		}
+		rootFS.DiffIDs = append([]layer.DiffID{l.DiffID()}, rootFS.DiffIDs...)
+		l = l.Parent()
+	}
+	return rootFS, func() { topDownload.Transfer.Release(watcher) }, err
+}
+
+// makeDownloadFunc returns a function that performs the layer download and
+// registration. If parentDownload is non-nil, it waits for that download to
+// complete before the registration step, and registers the downloaded data
+// on top of parentDownload's resulting layer. Otherwise, it registers the
+// layer on top of the ChainID given by parentLayer.
+func (ldm *LayerDownloadManager) makeDownloadFunc(descriptor DownloadDescriptor, parentLayer layer.ChainID, parentDownload *downloadTransfer) DoFunc {
+	return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer {
+		d := &downloadTransfer{
+			Transfer:   NewTransfer(),
+			layerStore: ldm.layerStore,
+		}
+
+		go func() {
+			defer func() {
+				close(progressChan)
+			}()
+
+			progressOutput := progress.ChanOutput(progressChan)
+
+			select {
+			case <-start:
+			default:
+				progress.Update(progressOutput, descriptor.ID(), "Waiting")
+				<-start
+			}
+
+			if parentDownload != nil {
+				// Did the parent download already fail or get
+				// cancelled?
+				select {
+				case <-parentDownload.Done():
+					_, err := parentDownload.result()
+					if err != nil {
+						d.err = err
+						return
+					}
+				default:
+				}
+			}
+
+			var (
+				downloadReader io.ReadCloser
+				size           int64
+				err            error
+				retries        int
+			)
+
+			for {
+				downloadReader, size, err = descriptor.Download(d.Transfer.Context(), progressOutput)
+				if err == nil {
+					break
+				}
+
+				// If an error was returned because the context
+				// was cancelled, we shouldn't retry.
+				select {
+				case <-d.Transfer.Context().Done():
+					d.err = err
+					return
+				default:
+				}
+
+				retries++
+				if _, isDNR := err.(DoNotRetry); isDNR || retries == maxDownloadAttempts {
+					logrus.Errorf("Download failed: %v", err)
+					d.err = err
+					return
+				}
+
+				logrus.Errorf("Download failed, retrying: %v", err)
+				delay := retries * 5
+				ticker := time.NewTicker(time.Second)
+
+			selectLoop:
+				for {
+					progress.Updatef(progressOutput, descriptor.ID(), "Retrying in %d seconds", delay)
+					select {
+					case <-ticker.C:
+						delay--
+						if delay == 0 {
+							ticker.Stop()
+							break selectLoop
+						}
+					case <-d.Transfer.Context().Done():
+						ticker.Stop()
+						d.err = errors.New("download cancelled during retry delay")
+						return
+					}
+
+				}
+			}
+
+			close(inactive)
+
+			if parentDownload != nil {
+				select {
+				case <-d.Transfer.Context().Done():
+					d.err = errors.New("layer registration cancelled")
+					downloadReader.Close()
+					return
+				case <-parentDownload.Done():
+				}
+
+				l, err := parentDownload.result()
+				if err != nil {
+					d.err = err
+					downloadReader.Close()
+					return
+				}
+				parentLayer = l.ChainID()
+			}
+
+			reader := progress.NewProgressReader(ioutils.NewCancelReadCloser(d.Transfer.Context(), downloadReader), progressOutput, size, descriptor.ID(), "Extracting")
+			defer reader.Close()
+
+			inflatedLayerData, err := archive.DecompressStream(reader)
+			if err != nil {
+				d.err = fmt.Errorf("could not get decompression stream: %v", err)
+				return
+			}
+
+			d.layer, err = d.layerStore.Register(inflatedLayerData, parentLayer)
+			if err != nil {
+				select {
+				case <-d.Transfer.Context().Done():
+					d.err = errors.New("layer registration cancelled")
+				default:
+					d.err = fmt.Errorf("failed to register layer: %v", err)
+				}
+				return
+			}
+
+			progress.Update(progressOutput, descriptor.ID(), "Pull complete")
+			withRegistered, hasRegistered := descriptor.(DownloadDescriptorWithRegistered)
+			if hasRegistered {
+				withRegistered.Registered(d.layer.DiffID())
+			}
+
+			// Doesn't actually need to be its own goroutine, but
+			// done like this so we can defer close(c).
+			go func() {
+				<-d.Transfer.Released()
+				if d.layer != nil {
+					layer.ReleaseAndLog(d.layerStore, d.layer)
+				}
+			}()
+		}()
+
+		return d
+	}
+}
+
+// makeDownloadFuncFromDownload returns a function that performs the layer
+// registration when the layer data is coming from an existing download. It
+// waits for sourceDownload and parentDownload to complete, and then
+// reregisters the data from sourceDownload's top layer on top of
+// parentDownload. This function does not log progress output because it would
+// interfere with the progress reporting for sourceDownload, which has the same
+// Key.
+func (ldm *LayerDownloadManager) makeDownloadFuncFromDownload(descriptor DownloadDescriptor, sourceDownload *downloadTransfer, parentDownload *downloadTransfer) DoFunc {
+	return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer {
+		d := &downloadTransfer{
+			Transfer:   NewTransfer(),
+			layerStore: ldm.layerStore,
+		}
+
+		go func() {
+			defer func() {
+				close(progressChan)
+			}()
+
+			<-start
+
+			close(inactive)
+
+			select {
+			case <-d.Transfer.Context().Done():
+				d.err = errors.New("layer registration cancelled")
+				return
+			case <-parentDownload.Done():
+			}
+
+			l, err := parentDownload.result()
+			if err != nil {
+				d.err = err
+				return
+			}
+			parentLayer := l.ChainID()
+
+			// sourceDownload should have already finished if
+			// parentDownload finished, but wait for it explicitly
+			// to be sure.
+			select {
+			case <-d.Transfer.Context().Done():
+				d.err = errors.New("layer registration cancelled")
+				return
+			case <-sourceDownload.Done():
+			}
+
+			l, err = sourceDownload.result()
+			if err != nil {
+				d.err = err
+				return
+			}
+
+			layerReader, err := l.TarStream()
+			if err != nil {
+				d.err = err
+				return
+			}
+			defer layerReader.Close()
+
+			d.layer, err = d.layerStore.Register(layerReader, parentLayer)
+			if err != nil {
+				d.err = fmt.Errorf("failed to register layer: %v", err)
+				return
+			}
+
+			withRegistered, hasRegistered := descriptor.(DownloadDescriptorWithRegistered)
+			if hasRegistered {
+				withRegistered.Registered(d.layer.DiffID())
+			}
+
+			// Doesn't actually need to be its own goroutine, but
+			// done like this so we can defer close(c).
+			go func() {
+				<-d.Transfer.Released()
+				if d.layer != nil {
+					layer.ReleaseAndLog(d.layerStore, d.layer)
+				}
+			}()
+		}()
+
+		return d
+	}
+}

+ 332 - 0
distribution/xfer/download_test.go

@@ -0,0 +1,332 @@
+package xfer
+
+import (
+	"bytes"
+	"errors"
+	"io"
+	"io/ioutil"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/docker/distribution/digest"
+	"github.com/docker/docker/image"
+	"github.com/docker/docker/layer"
+	"github.com/docker/docker/pkg/archive"
+	"github.com/docker/docker/pkg/progress"
+	"golang.org/x/net/context"
+)
+
+const maxDownloadConcurrency = 3
+
+type mockLayer struct {
+	layerData bytes.Buffer
+	diffID    layer.DiffID
+	chainID   layer.ChainID
+	parent    layer.Layer
+}
+
+func (ml *mockLayer) TarStream() (io.ReadCloser, error) {
+	return ioutil.NopCloser(bytes.NewBuffer(ml.layerData.Bytes())), nil
+}
+
+func (ml *mockLayer) ChainID() layer.ChainID {
+	return ml.chainID
+}
+
+func (ml *mockLayer) DiffID() layer.DiffID {
+	return ml.diffID
+}
+
+func (ml *mockLayer) Parent() layer.Layer {
+	return ml.parent
+}
+
+func (ml *mockLayer) Size() (size int64, err error) {
+	return 0, nil
+}
+
+func (ml *mockLayer) DiffSize() (size int64, err error) {
+	return 0, nil
+}
+
+func (ml *mockLayer) Metadata() (map[string]string, error) {
+	return make(map[string]string), nil
+}
+
+type mockLayerStore struct {
+	layers map[layer.ChainID]*mockLayer
+}
+
+func createChainIDFromParent(parent layer.ChainID, dgsts ...layer.DiffID) layer.ChainID {
+	if len(dgsts) == 0 {
+		return parent
+	}
+	if parent == "" {
+		return createChainIDFromParent(layer.ChainID(dgsts[0]), dgsts[1:]...)
+	}
+	// H = "H(n-1) SHA256(n)"
+	dgst, err := digest.FromBytes([]byte(string(parent) + " " + string(dgsts[0])))
+	if err != nil {
+		// Digest calculation is not expected to throw an error,
+		// any error at this point is a program error
+		panic(err)
+	}
+	return createChainIDFromParent(layer.ChainID(dgst), dgsts[1:]...)
+}
+
+func (ls *mockLayerStore) Register(reader io.Reader, parentID layer.ChainID) (layer.Layer, error) {
+	var (
+		parent layer.Layer
+		err    error
+	)
+
+	if parentID != "" {
+		parent, err = ls.Get(parentID)
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	l := &mockLayer{parent: parent}
+	_, err = l.layerData.ReadFrom(reader)
+	if err != nil {
+		return nil, err
+	}
+	diffID, err := digest.FromBytes(l.layerData.Bytes())
+	if err != nil {
+		return nil, err
+	}
+	l.diffID = layer.DiffID(diffID)
+	l.chainID = createChainIDFromParent(parentID, l.diffID)
+
+	ls.layers[l.chainID] = l
+	return l, nil
+}
+
+func (ls *mockLayerStore) Get(chainID layer.ChainID) (layer.Layer, error) {
+	l, ok := ls.layers[chainID]
+	if !ok {
+		return nil, layer.ErrLayerDoesNotExist
+	}
+	return l, nil
+}
+
+func (ls *mockLayerStore) Release(l layer.Layer) ([]layer.Metadata, error) {
+	return []layer.Metadata{}, nil
+}
+
+func (ls *mockLayerStore) Mount(id string, parent layer.ChainID, label string, init layer.MountInit) (layer.RWLayer, error) {
+	return nil, errors.New("not implemented")
+}
+
+func (ls *mockLayerStore) Unmount(id string) error {
+	return errors.New("not implemented")
+}
+
+func (ls *mockLayerStore) DeleteMount(id string) ([]layer.Metadata, error) {
+	return nil, errors.New("not implemented")
+}
+
+func (ls *mockLayerStore) Changes(id string) ([]archive.Change, error) {
+	return nil, errors.New("not implemented")
+}
+
+type mockDownloadDescriptor struct {
+	currentDownloads *int32
+	id               string
+	diffID           layer.DiffID
+	registeredDiffID layer.DiffID
+	expectedDiffID   layer.DiffID
+	simulateRetries  int
+}
+
+// Key returns the key used to deduplicate downloads.
+func (d *mockDownloadDescriptor) Key() string {
+	return d.id
+}
+
+// ID returns the ID for display purposes.
+func (d *mockDownloadDescriptor) ID() string {
+	return d.id
+}
+
+// DiffID should return the DiffID for this layer, or an error
+// if it is unknown (for example, if it has not been downloaded
+// before).
+func (d *mockDownloadDescriptor) DiffID() (layer.DiffID, error) {
+	if d.diffID != "" {
+		return d.diffID, nil
+	}
+	return "", errors.New("no diffID available")
+}
+
+func (d *mockDownloadDescriptor) Registered(diffID layer.DiffID) {
+	d.registeredDiffID = diffID
+}
+
+func (d *mockDownloadDescriptor) mockTarStream() io.ReadCloser {
+	// The mock implementation returns the ID repeated 5 times as a tar
+	// stream instead of actual tar data. The data is ignored except for
+	// computing IDs.
+	return ioutil.NopCloser(bytes.NewBuffer([]byte(d.id + d.id + d.id + d.id + d.id)))
+}
+
+// Download is called to perform the download.
+func (d *mockDownloadDescriptor) Download(ctx context.Context, progressOutput progress.Output) (io.ReadCloser, int64, error) {
+	if d.currentDownloads != nil {
+		defer atomic.AddInt32(d.currentDownloads, -1)
+
+		if atomic.AddInt32(d.currentDownloads, 1) > maxDownloadConcurrency {
+			return nil, 0, errors.New("concurrency limit exceeded")
+		}
+	}
+
+	// Sleep a bit to simulate a time-consuming download.
+	for i := int64(0); i <= 10; i++ {
+		select {
+		case <-ctx.Done():
+			return nil, 0, ctx.Err()
+		case <-time.After(10 * time.Millisecond):
+			progressOutput.WriteProgress(progress.Progress{ID: d.ID(), Action: "Downloading", Current: i, Total: 10})
+		}
+	}
+
+	if d.simulateRetries != 0 {
+		d.simulateRetries--
+		return nil, 0, errors.New("simulating retry")
+	}
+
+	return d.mockTarStream(), 0, nil
+}
+
+func downloadDescriptors(currentDownloads *int32) []DownloadDescriptor {
+	return []DownloadDescriptor{
+		&mockDownloadDescriptor{
+			currentDownloads: currentDownloads,
+			id:               "id1",
+			expectedDiffID:   layer.DiffID("sha256:68e2c75dc5c78ea9240689c60d7599766c213ae210434c53af18470ae8c53ec1"),
+		},
+		&mockDownloadDescriptor{
+			currentDownloads: currentDownloads,
+			id:               "id2",
+			expectedDiffID:   layer.DiffID("sha256:64a636223116aa837973a5d9c2bdd17d9b204e4f95ac423e20e65dfbb3655473"),
+		},
+		&mockDownloadDescriptor{
+			currentDownloads: currentDownloads,
+			id:               "id3",
+			expectedDiffID:   layer.DiffID("sha256:58745a8bbd669c25213e9de578c4da5c8ee1c836b3581432c2b50e38a6753300"),
+		},
+		&mockDownloadDescriptor{
+			currentDownloads: currentDownloads,
+			id:               "id2",
+			expectedDiffID:   layer.DiffID("sha256:64a636223116aa837973a5d9c2bdd17d9b204e4f95ac423e20e65dfbb3655473"),
+		},
+		&mockDownloadDescriptor{
+			currentDownloads: currentDownloads,
+			id:               "id4",
+			expectedDiffID:   layer.DiffID("sha256:0dfb5b9577716cc173e95af7c10289322c29a6453a1718addc00c0c5b1330936"),
+			simulateRetries:  1,
+		},
+		&mockDownloadDescriptor{
+			currentDownloads: currentDownloads,
+			id:               "id5",
+			expectedDiffID:   layer.DiffID("sha256:0a5f25fa1acbc647f6112a6276735d0fa01e4ee2aa7ec33015e337350e1ea23d"),
+		},
+	}
+}
+
+func TestSuccessfulDownload(t *testing.T) {
+	layerStore := &mockLayerStore{make(map[layer.ChainID]*mockLayer)}
+	ldm := NewLayerDownloadManager(layerStore, maxDownloadConcurrency)
+
+	progressChan := make(chan progress.Progress)
+	progressDone := make(chan struct{})
+	receivedProgress := make(map[string]int64)
+
+	go func() {
+		for p := range progressChan {
+			if p.Action == "Downloading" {
+				receivedProgress[p.ID] = p.Current
+			} else if p.Action == "Already exists" {
+				receivedProgress[p.ID] = -1
+			}
+		}
+		close(progressDone)
+	}()
+
+	var currentDownloads int32
+	descriptors := downloadDescriptors(&currentDownloads)
+
+	firstDescriptor := descriptors[0].(*mockDownloadDescriptor)
+
+	// Pre-register the first layer to simulate an already-existing layer
+	l, err := layerStore.Register(firstDescriptor.mockTarStream(), "")
+	if err != nil {
+		t.Fatal(err)
+	}
+	firstDescriptor.diffID = l.DiffID()
+
+	rootFS, releaseFunc, err := ldm.Download(context.Background(), *image.NewRootFS(), descriptors, progress.ChanOutput(progressChan))
+	if err != nil {
+		t.Fatalf("download error: %v", err)
+	}
+
+	releaseFunc()
+
+	close(progressChan)
+	<-progressDone
+
+	if len(rootFS.DiffIDs) != len(descriptors) {
+		t.Fatal("got wrong number of diffIDs in rootfs")
+	}
+
+	for i, d := range descriptors {
+		descriptor := d.(*mockDownloadDescriptor)
+
+		if descriptor.diffID != "" {
+			if receivedProgress[d.ID()] != -1 {
+				t.Fatalf("did not get 'already exists' message for %v", d.ID())
+			}
+		} else if receivedProgress[d.ID()] != 10 {
+			t.Fatalf("missing or wrong progress output for %v (got: %d)", d.ID(), receivedProgress[d.ID()])
+		}
+
+		if rootFS.DiffIDs[i] != descriptor.expectedDiffID {
+			t.Fatalf("rootFS item %d has the wrong diffID (expected: %v got: %v)", i, descriptor.expectedDiffID, rootFS.DiffIDs[i])
+		}
+
+		if descriptor.diffID == "" && descriptor.registeredDiffID != rootFS.DiffIDs[i] {
+			t.Fatal("diffID mismatch between rootFS and Registered callback")
+		}
+	}
+}
+
+func TestCancelledDownload(t *testing.T) {
+	ldm := NewLayerDownloadManager(&mockLayerStore{make(map[layer.ChainID]*mockLayer)}, maxDownloadConcurrency)
+
+	progressChan := make(chan progress.Progress)
+	progressDone := make(chan struct{})
+
+	go func() {
+		for range progressChan {
+		}
+		close(progressDone)
+	}()
+
+	ctx, cancel := context.WithCancel(context.Background())
+
+	go func() {
+		<-time.After(time.Millisecond)
+		cancel()
+	}()
+
+	descriptors := downloadDescriptors(nil)
+	_, _, err := ldm.Download(ctx, *image.NewRootFS(), descriptors, progress.ChanOutput(progressChan))
+	if err != context.Canceled {
+		t.Fatal("expected download to be cancelled")
+	}
+
+	close(progressChan)
+	<-progressDone
+}

+ 343 - 0
distribution/xfer/transfer.go

@@ -0,0 +1,343 @@
+package xfer
+
+import (
+	"sync"
+
+	"github.com/docker/docker/pkg/progress"
+	"golang.org/x/net/context"
+)
+
+// DoNotRetry is an error wrapper indicating that the error cannot be resolved
+// with a retry.
+type DoNotRetry struct {
+	Err error
+}
+
+// Error returns the stringified representation of the encapsulated error.
+func (e DoNotRetry) Error() string {
+	return e.Err.Error()
+}
+
+// Watcher is returned by Watch and can be passed to Release to stop watching.
+type Watcher struct {
+	// signalChan is used to signal to the watcher goroutine that
+	// new progress information is available, or that the transfer
+	// has finished.
+	signalChan chan struct{}
+	// releaseChan signals to the watcher goroutine that the watcher
+	// should be detached.
+	releaseChan chan struct{}
+	// running remains open as long as the watcher is watching the
+	// transfer. It gets closed if the transfer finishes or the
+	// watcher is detached.
+	running chan struct{}
+}
+
+// Transfer represents an in-progress transfer.
+type Transfer interface {
+	Watch(progressOutput progress.Output) *Watcher
+	Release(*Watcher)
+	Context() context.Context
+	Cancel()
+	Done() <-chan struct{}
+	Released() <-chan struct{}
+	Broadcast(masterProgressChan <-chan progress.Progress)
+}
+
+type transfer struct {
+	mu sync.Mutex
+
+	ctx    context.Context
+	cancel context.CancelFunc
+
+	// watchers keeps track of the goroutines monitoring progress output,
+	// indexed by the channels that release them.
+	watchers map[chan struct{}]*Watcher
+
+	// lastProgress is the most recently received progress event.
+	lastProgress progress.Progress
+	// hasLastProgress is true when lastProgress has been set.
+	hasLastProgress bool
+
+	// running remains open as long as the transfer is in progress.
+	running chan struct{}
+	// hasWatchers stays open until all watchers release the trasnfer.
+	hasWatchers chan struct{}
+
+	// broadcastDone is true if the master progress channel has closed.
+	broadcastDone bool
+	// broadcastSyncChan allows watchers to "ping" the broadcasting
+	// goroutine to wait for it for deplete its input channel. This ensures
+	// a detaching watcher won't miss an event that was sent before it
+	// started detaching.
+	broadcastSyncChan chan struct{}
+}
+
+// NewTransfer creates a new transfer.
+func NewTransfer() Transfer {
+	t := &transfer{
+		watchers:          make(map[chan struct{}]*Watcher),
+		running:           make(chan struct{}),
+		hasWatchers:       make(chan struct{}),
+		broadcastSyncChan: make(chan struct{}),
+	}
+
+	// This uses context.Background instead of a caller-supplied context
+	// so that a transfer won't be cancelled automatically if the client
+	// which requested it is ^C'd (there could be other viewers).
+	t.ctx, t.cancel = context.WithCancel(context.Background())
+
+	return t
+}
+
+// Broadcast copies the progress and error output to all viewers.
+func (t *transfer) Broadcast(masterProgressChan <-chan progress.Progress) {
+	for {
+		var (
+			p  progress.Progress
+			ok bool
+		)
+		select {
+		case p, ok = <-masterProgressChan:
+		default:
+			// We've depleted the channel, so now we can handle
+			// reads on broadcastSyncChan to let detaching watchers
+			// know we're caught up.
+			select {
+			case <-t.broadcastSyncChan:
+				continue
+			case p, ok = <-masterProgressChan:
+			}
+		}
+
+		t.mu.Lock()
+		if ok {
+			t.lastProgress = p
+			t.hasLastProgress = true
+			for _, w := range t.watchers {
+				select {
+				case w.signalChan <- struct{}{}:
+				default:
+				}
+			}
+
+		} else {
+			t.broadcastDone = true
+		}
+		t.mu.Unlock()
+		if !ok {
+			close(t.running)
+			return
+		}
+	}
+}
+
+// Watch adds a watcher to the transfer. The supplied channel gets progress
+// updates and is closed when the transfer finishes.
+func (t *transfer) Watch(progressOutput progress.Output) *Watcher {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+
+	w := &Watcher{
+		releaseChan: make(chan struct{}),
+		signalChan:  make(chan struct{}),
+		running:     make(chan struct{}),
+	}
+
+	if t.broadcastDone {
+		close(w.running)
+		return w
+	}
+
+	t.watchers[w.releaseChan] = w
+
+	go func() {
+		defer func() {
+			close(w.running)
+		}()
+		done := false
+		for {
+			t.mu.Lock()
+			hasLastProgress := t.hasLastProgress
+			lastProgress := t.lastProgress
+			t.mu.Unlock()
+
+			// This might write the last progress item a
+			// second time (since channel closure also gets
+			// us here), but that's fine.
+			if hasLastProgress {
+				progressOutput.WriteProgress(lastProgress)
+			}
+
+			if done {
+				return
+			}
+
+			select {
+			case <-w.signalChan:
+			case <-w.releaseChan:
+				done = true
+				// Since the watcher is going to detach, make
+				// sure the broadcaster is caught up so we
+				// don't miss anything.
+				select {
+				case t.broadcastSyncChan <- struct{}{}:
+				case <-t.running:
+				}
+			case <-t.running:
+				done = true
+			}
+		}
+	}()
+
+	return w
+}
+
+// Release is the inverse of Watch; indicating that the watcher no longer wants
+// to be notified about the progress of the transfer. All calls to Watch must
+// be paired with later calls to Release so that the lifecycle of the transfer
+// is properly managed.
+func (t *transfer) Release(watcher *Watcher) {
+	t.mu.Lock()
+	delete(t.watchers, watcher.releaseChan)
+
+	if len(t.watchers) == 0 {
+		close(t.hasWatchers)
+		t.cancel()
+	}
+	t.mu.Unlock()
+
+	close(watcher.releaseChan)
+	// Block until the watcher goroutine completes
+	<-watcher.running
+}
+
+// Done returns a channel which is closed if the transfer completes or is
+// cancelled. Note that having 0 watchers causes a transfer to be cancelled.
+func (t *transfer) Done() <-chan struct{} {
+	// Note that this doesn't return t.ctx.Done() because that channel will
+	// be closed the moment Cancel is called, and we need to return a
+	// channel that blocks until a cancellation is actually acknowledged by
+	// the transfer function.
+	return t.running
+}
+
+// Released returns a channel which is closed once all watchers release the
+// transfer.
+func (t *transfer) Released() <-chan struct{} {
+	return t.hasWatchers
+}
+
+// Context returns the context associated with the transfer.
+func (t *transfer) Context() context.Context {
+	return t.ctx
+}
+
+// Cancel cancels the context associated with the transfer.
+func (t *transfer) Cancel() {
+	t.cancel()
+}
+
+// DoFunc is a function called by the transfer manager to actually perform
+// a transfer. It should be non-blocking. It should wait until the start channel
+// is closed before transfering any data. If the function closes inactive, that
+// signals to the transfer manager that the job is no longer actively moving
+// data - for example, it may be waiting for a dependent tranfer to finish.
+// This prevents it from taking up a slot.
+type DoFunc func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer
+
+// TransferManager is used by LayerDownloadManager and LayerUploadManager to
+// schedule and deduplicate transfers. It is up to the TransferManager
+// implementation to make the scheduling and concurrency decisions.
+type TransferManager interface {
+	// Transfer checks if a transfer with the given key is in progress. If
+	// so, it returns progress and error output from that transfer.
+	// Otherwise, it will call xferFunc to initiate the transfer.
+	Transfer(key string, xferFunc DoFunc, progressOutput progress.Output) (Transfer, *Watcher)
+}
+
+type transferManager struct {
+	mu sync.Mutex
+
+	concurrencyLimit int
+	activeTransfers  int
+	transfers        map[string]Transfer
+	waitingTransfers []chan struct{}
+}
+
+// NewTransferManager returns a new TransferManager.
+func NewTransferManager(concurrencyLimit int) TransferManager {
+	return &transferManager{
+		concurrencyLimit: concurrencyLimit,
+		transfers:        make(map[string]Transfer),
+	}
+}
+
+// Transfer checks if a transfer matching the given key is in progress. If not,
+// it starts one by calling xferFunc. The caller supplies a channel which
+// receives progress output from the transfer.
+func (tm *transferManager) Transfer(key string, xferFunc DoFunc, progressOutput progress.Output) (Transfer, *Watcher) {
+	tm.mu.Lock()
+	defer tm.mu.Unlock()
+
+	if xfer, present := tm.transfers[key]; present {
+		// Transfer is already in progress.
+		watcher := xfer.Watch(progressOutput)
+		return xfer, watcher
+	}
+
+	start := make(chan struct{})
+	inactive := make(chan struct{})
+
+	if tm.activeTransfers < tm.concurrencyLimit {
+		close(start)
+		tm.activeTransfers++
+	} else {
+		tm.waitingTransfers = append(tm.waitingTransfers, start)
+	}
+
+	masterProgressChan := make(chan progress.Progress)
+	xfer := xferFunc(masterProgressChan, start, inactive)
+	watcher := xfer.Watch(progressOutput)
+	go xfer.Broadcast(masterProgressChan)
+	tm.transfers[key] = xfer
+
+	// When the transfer is finished, remove from the map.
+	go func() {
+		for {
+			select {
+			case <-inactive:
+				tm.mu.Lock()
+				tm.inactivate(start)
+				tm.mu.Unlock()
+				inactive = nil
+			case <-xfer.Done():
+				tm.mu.Lock()
+				if inactive != nil {
+					tm.inactivate(start)
+				}
+				delete(tm.transfers, key)
+				tm.mu.Unlock()
+				return
+			}
+		}
+	}()
+
+	return xfer, watcher
+}
+
+func (tm *transferManager) inactivate(start chan struct{}) {
+	// If the transfer was started, remove it from the activeTransfers
+	// count.
+	select {
+	case <-start:
+		// Start next transfer if any are waiting
+		if len(tm.waitingTransfers) != 0 {
+			close(tm.waitingTransfers[0])
+			tm.waitingTransfers = tm.waitingTransfers[1:]
+		} else {
+			tm.activeTransfers--
+		}
+	default:
+	}
+}

+ 385 - 0
distribution/xfer/transfer_test.go

@@ -0,0 +1,385 @@
+package xfer
+
+import (
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/docker/docker/pkg/progress"
+)
+
+func TestTransfer(t *testing.T) {
+	makeXferFunc := func(id string) DoFunc {
+		return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer {
+			select {
+			case <-start:
+			default:
+				t.Fatalf("transfer function not started even though concurrency limit not reached")
+			}
+
+			xfer := NewTransfer()
+			go func() {
+				for i := 0; i <= 10; i++ {
+					progressChan <- progress.Progress{ID: id, Action: "testing", Current: int64(i), Total: 10}
+					time.Sleep(10 * time.Millisecond)
+				}
+				close(progressChan)
+			}()
+			return xfer
+		}
+	}
+
+	tm := NewTransferManager(5)
+	progressChan := make(chan progress.Progress)
+	progressDone := make(chan struct{})
+	receivedProgress := make(map[string]int64)
+
+	go func() {
+		for p := range progressChan {
+			val, present := receivedProgress[p.ID]
+			if !present {
+				if p.Current != 0 {
+					t.Fatalf("got unexpected progress value: %d (expected 0)", p.Current)
+				}
+			} else if p.Current == 10 {
+				// Special case: last progress output may be
+				// repeated because the transfer finishing
+				// causes the latest progress output to be
+				// written to the channel (in case the watcher
+				// missed it).
+				if p.Current != 9 && p.Current != 10 {
+					t.Fatalf("got unexpected progress value: %d (expected %d)", p.Current, val+1)
+				}
+			} else if p.Current != val+1 {
+				t.Fatalf("got unexpected progress value: %d (expected %d)", p.Current, val+1)
+			}
+			receivedProgress[p.ID] = p.Current
+		}
+		close(progressDone)
+	}()
+
+	// Start a few transfers
+	ids := []string{"id1", "id2", "id3"}
+	xfers := make([]Transfer, len(ids))
+	watchers := make([]*Watcher, len(ids))
+	for i, id := range ids {
+		xfers[i], watchers[i] = tm.Transfer(id, makeXferFunc(id), progress.ChanOutput(progressChan))
+	}
+
+	for i, xfer := range xfers {
+		<-xfer.Done()
+		xfer.Release(watchers[i])
+	}
+	close(progressChan)
+	<-progressDone
+
+	for _, id := range ids {
+		if receivedProgress[id] != 10 {
+			t.Fatalf("final progress value %d instead of 10", receivedProgress[id])
+		}
+	}
+}
+
+func TestConcurrencyLimit(t *testing.T) {
+	concurrencyLimit := 3
+	var runningJobs int32
+
+	makeXferFunc := func(id string) DoFunc {
+		return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer {
+			xfer := NewTransfer()
+			go func() {
+				<-start
+				totalJobs := atomic.AddInt32(&runningJobs, 1)
+				if int(totalJobs) > concurrencyLimit {
+					t.Fatalf("too many jobs running")
+				}
+				for i := 0; i <= 10; i++ {
+					progressChan <- progress.Progress{ID: id, Action: "testing", Current: int64(i), Total: 10}
+					time.Sleep(10 * time.Millisecond)
+				}
+				atomic.AddInt32(&runningJobs, -1)
+				close(progressChan)
+			}()
+			return xfer
+		}
+	}
+
+	tm := NewTransferManager(concurrencyLimit)
+	progressChan := make(chan progress.Progress)
+	progressDone := make(chan struct{})
+	receivedProgress := make(map[string]int64)
+
+	go func() {
+		for p := range progressChan {
+			receivedProgress[p.ID] = p.Current
+		}
+		close(progressDone)
+	}()
+
+	// Start more transfers than the concurrency limit
+	ids := []string{"id1", "id2", "id3", "id4", "id5", "id6", "id7", "id8"}
+	xfers := make([]Transfer, len(ids))
+	watchers := make([]*Watcher, len(ids))
+	for i, id := range ids {
+		xfers[i], watchers[i] = tm.Transfer(id, makeXferFunc(id), progress.ChanOutput(progressChan))
+	}
+
+	for i, xfer := range xfers {
+		<-xfer.Done()
+		xfer.Release(watchers[i])
+	}
+	close(progressChan)
+	<-progressDone
+
+	for _, id := range ids {
+		if receivedProgress[id] != 10 {
+			t.Fatalf("final progress value %d instead of 10", receivedProgress[id])
+		}
+	}
+}
+
+func TestInactiveJobs(t *testing.T) {
+	concurrencyLimit := 3
+	var runningJobs int32
+	testDone := make(chan struct{})
+
+	makeXferFunc := func(id string) DoFunc {
+		return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer {
+			xfer := NewTransfer()
+			go func() {
+				<-start
+				totalJobs := atomic.AddInt32(&runningJobs, 1)
+				if int(totalJobs) > concurrencyLimit {
+					t.Fatalf("too many jobs running")
+				}
+				for i := 0; i <= 10; i++ {
+					progressChan <- progress.Progress{ID: id, Action: "testing", Current: int64(i), Total: 10}
+					time.Sleep(10 * time.Millisecond)
+				}
+				atomic.AddInt32(&runningJobs, -1)
+				close(inactive)
+				<-testDone
+				close(progressChan)
+			}()
+			return xfer
+		}
+	}
+
+	tm := NewTransferManager(concurrencyLimit)
+	progressChan := make(chan progress.Progress)
+	progressDone := make(chan struct{})
+	receivedProgress := make(map[string]int64)
+
+	go func() {
+		for p := range progressChan {
+			receivedProgress[p.ID] = p.Current
+		}
+		close(progressDone)
+	}()
+
+	// Start more transfers than the concurrency limit
+	ids := []string{"id1", "id2", "id3", "id4", "id5", "id6", "id7", "id8"}
+	xfers := make([]Transfer, len(ids))
+	watchers := make([]*Watcher, len(ids))
+	for i, id := range ids {
+		xfers[i], watchers[i] = tm.Transfer(id, makeXferFunc(id), progress.ChanOutput(progressChan))
+	}
+
+	close(testDone)
+	for i, xfer := range xfers {
+		<-xfer.Done()
+		xfer.Release(watchers[i])
+	}
+	close(progressChan)
+	<-progressDone
+
+	for _, id := range ids {
+		if receivedProgress[id] != 10 {
+			t.Fatalf("final progress value %d instead of 10", receivedProgress[id])
+		}
+	}
+}
+
+func TestWatchRelease(t *testing.T) {
+	ready := make(chan struct{})
+
+	makeXferFunc := func(id string) DoFunc {
+		return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer {
+			xfer := NewTransfer()
+			go func() {
+				defer func() {
+					close(progressChan)
+				}()
+				<-ready
+				for i := int64(0); ; i++ {
+					select {
+					case <-time.After(10 * time.Millisecond):
+					case <-xfer.Context().Done():
+						return
+					}
+					progressChan <- progress.Progress{ID: id, Action: "testing", Current: i, Total: 10}
+				}
+			}()
+			return xfer
+		}
+	}
+
+	tm := NewTransferManager(5)
+
+	type watcherInfo struct {
+		watcher               *Watcher
+		progressChan          chan progress.Progress
+		progressDone          chan struct{}
+		receivedFirstProgress chan struct{}
+	}
+
+	progressConsumer := func(w watcherInfo) {
+		first := true
+		for range w.progressChan {
+			if first {
+				close(w.receivedFirstProgress)
+			}
+			first = false
+		}
+		close(w.progressDone)
+	}
+
+	// Start a transfer
+	watchers := make([]watcherInfo, 5)
+	var xfer Transfer
+	watchers[0].progressChan = make(chan progress.Progress)
+	watchers[0].progressDone = make(chan struct{})
+	watchers[0].receivedFirstProgress = make(chan struct{})
+	xfer, watchers[0].watcher = tm.Transfer("id1", makeXferFunc("id1"), progress.ChanOutput(watchers[0].progressChan))
+	go progressConsumer(watchers[0])
+
+	// Give it multiple watchers
+	for i := 1; i != len(watchers); i++ {
+		watchers[i].progressChan = make(chan progress.Progress)
+		watchers[i].progressDone = make(chan struct{})
+		watchers[i].receivedFirstProgress = make(chan struct{})
+		watchers[i].watcher = xfer.Watch(progress.ChanOutput(watchers[i].progressChan))
+		go progressConsumer(watchers[i])
+	}
+
+	// Now that the watchers are set up, allow the transfer goroutine to
+	// proceed.
+	close(ready)
+
+	// Confirm that each watcher gets progress output.
+	for _, w := range watchers {
+		<-w.receivedFirstProgress
+	}
+
+	// Release one watcher every 5ms
+	for _, w := range watchers {
+		xfer.Release(w.watcher)
+		<-time.After(5 * time.Millisecond)
+	}
+
+	// Now that all watchers have been released, Released() should
+	// return a closed channel.
+	<-xfer.Released()
+
+	// Done() should return a closed channel because the xfer func returned
+	// due to cancellation.
+	<-xfer.Done()
+
+	for _, w := range watchers {
+		close(w.progressChan)
+		<-w.progressDone
+	}
+}
+
+func TestDuplicateTransfer(t *testing.T) {
+	ready := make(chan struct{})
+
+	var xferFuncCalls int32
+
+	makeXferFunc := func(id string) DoFunc {
+		return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer {
+			atomic.AddInt32(&xferFuncCalls, 1)
+			xfer := NewTransfer()
+			go func() {
+				defer func() {
+					close(progressChan)
+				}()
+				<-ready
+				for i := int64(0); ; i++ {
+					select {
+					case <-time.After(10 * time.Millisecond):
+					case <-xfer.Context().Done():
+						return
+					}
+					progressChan <- progress.Progress{ID: id, Action: "testing", Current: i, Total: 10}
+				}
+			}()
+			return xfer
+		}
+	}
+
+	tm := NewTransferManager(5)
+
+	type transferInfo struct {
+		xfer                  Transfer
+		watcher               *Watcher
+		progressChan          chan progress.Progress
+		progressDone          chan struct{}
+		receivedFirstProgress chan struct{}
+	}
+
+	progressConsumer := func(t transferInfo) {
+		first := true
+		for range t.progressChan {
+			if first {
+				close(t.receivedFirstProgress)
+			}
+			first = false
+		}
+		close(t.progressDone)
+	}
+
+	// Try to start multiple transfers with the same ID
+	transfers := make([]transferInfo, 5)
+	for i := range transfers {
+		t := &transfers[i]
+		t.progressChan = make(chan progress.Progress)
+		t.progressDone = make(chan struct{})
+		t.receivedFirstProgress = make(chan struct{})
+		t.xfer, t.watcher = tm.Transfer("id1", makeXferFunc("id1"), progress.ChanOutput(t.progressChan))
+		go progressConsumer(*t)
+	}
+
+	// Allow the transfer goroutine to proceed.
+	close(ready)
+
+	// Confirm that each watcher gets progress output.
+	for _, t := range transfers {
+		<-t.receivedFirstProgress
+	}
+
+	// Confirm that the transfer function was called exactly once.
+	if xferFuncCalls != 1 {
+		t.Fatal("transfer function wasn't called exactly once")
+	}
+
+	// Release one watcher every 5ms
+	for _, t := range transfers {
+		t.xfer.Release(t.watcher)
+		<-time.After(5 * time.Millisecond)
+	}
+
+	for _, t := range transfers {
+		// Now that all watchers have been released, Released() should
+		// return a closed channel.
+		<-t.xfer.Released()
+		// Done() should return a closed channel because the xfer func returned
+		// due to cancellation.
+		<-t.xfer.Done()
+	}
+
+	for _, t := range transfers {
+		close(t.progressChan)
+		<-t.progressDone
+	}
+}

+ 159 - 0
distribution/xfer/upload.go

@@ -0,0 +1,159 @@
+package xfer
+
+import (
+	"errors"
+	"time"
+
+	"github.com/Sirupsen/logrus"
+	"github.com/docker/distribution/digest"
+	"github.com/docker/docker/layer"
+	"github.com/docker/docker/pkg/progress"
+	"golang.org/x/net/context"
+)
+
+const maxUploadAttempts = 5
+
+// LayerUploadManager provides task management and progress reporting for
+// uploads.
+type LayerUploadManager struct {
+	tm TransferManager
+}
+
+// NewLayerUploadManager returns a new LayerUploadManager.
+func NewLayerUploadManager(concurrencyLimit int) *LayerUploadManager {
+	return &LayerUploadManager{
+		tm: NewTransferManager(concurrencyLimit),
+	}
+}
+
+type uploadTransfer struct {
+	Transfer
+
+	diffID layer.DiffID
+	digest digest.Digest
+	err    error
+}
+
+// An UploadDescriptor references a layer that may need to be uploaded.
+type UploadDescriptor interface {
+	// Key returns the key used to deduplicate uploads.
+	Key() string
+	// ID returns the ID for display purposes.
+	ID() string
+	// DiffID should return the DiffID for this layer.
+	DiffID() layer.DiffID
+	// Upload is called to perform the Upload.
+	Upload(ctx context.Context, progressOutput progress.Output) (digest.Digest, error)
+}
+
+// Upload is a blocking function which ensures the listed layers are present on
+// the remote registry. It uses the string returned by the Key method to
+// deduplicate uploads.
+func (lum *LayerUploadManager) Upload(ctx context.Context, layers []UploadDescriptor, progressOutput progress.Output) (map[layer.DiffID]digest.Digest, error) {
+	var (
+		uploads          []*uploadTransfer
+		digests          = make(map[layer.DiffID]digest.Digest)
+		dedupDescriptors = make(map[string]struct{})
+	)
+
+	for _, descriptor := range layers {
+		progress.Update(progressOutput, descriptor.ID(), "Preparing")
+
+		key := descriptor.Key()
+		if _, present := dedupDescriptors[key]; present {
+			continue
+		}
+		dedupDescriptors[key] = struct{}{}
+
+		xferFunc := lum.makeUploadFunc(descriptor)
+		upload, watcher := lum.tm.Transfer(descriptor.Key(), xferFunc, progressOutput)
+		defer upload.Release(watcher)
+		uploads = append(uploads, upload.(*uploadTransfer))
+	}
+
+	for _, upload := range uploads {
+		select {
+		case <-ctx.Done():
+			return nil, ctx.Err()
+		case <-upload.Transfer.Done():
+			if upload.err != nil {
+				return nil, upload.err
+			}
+			digests[upload.diffID] = upload.digest
+		}
+	}
+
+	return digests, nil
+}
+
+func (lum *LayerUploadManager) makeUploadFunc(descriptor UploadDescriptor) DoFunc {
+	return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer {
+		u := &uploadTransfer{
+			Transfer: NewTransfer(),
+			diffID:   descriptor.DiffID(),
+		}
+
+		go func() {
+			defer func() {
+				close(progressChan)
+			}()
+
+			progressOutput := progress.ChanOutput(progressChan)
+
+			select {
+			case <-start:
+			default:
+				progress.Update(progressOutput, descriptor.ID(), "Waiting")
+				<-start
+			}
+
+			retries := 0
+			for {
+				digest, err := descriptor.Upload(u.Transfer.Context(), progressOutput)
+				if err == nil {
+					u.digest = digest
+					break
+				}
+
+				// If an error was returned because the context
+				// was cancelled, we shouldn't retry.
+				select {
+				case <-u.Transfer.Context().Done():
+					u.err = err
+					return
+				default:
+				}
+
+				retries++
+				if _, isDNR := err.(DoNotRetry); isDNR || retries == maxUploadAttempts {
+					logrus.Errorf("Upload failed: %v", err)
+					u.err = err
+					return
+				}
+
+				logrus.Errorf("Upload failed, retrying: %v", err)
+				delay := retries * 5
+				ticker := time.NewTicker(time.Second)
+
+			selectLoop:
+				for {
+					progress.Updatef(progressOutput, descriptor.ID(), "Retrying in %d seconds", delay)
+					select {
+					case <-ticker.C:
+						delay--
+						if delay == 0 {
+							ticker.Stop()
+							break selectLoop
+						}
+					case <-u.Transfer.Context().Done():
+						ticker.Stop()
+						u.err = errors.New("upload cancelled during retry delay")
+						return
+					}
+				}
+			}
+		}()
+
+		return u
+	}
+}

+ 153 - 0
distribution/xfer/upload_test.go

@@ -0,0 +1,153 @@
+package xfer
+
+import (
+	"errors"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/docker/distribution/digest"
+	"github.com/docker/docker/layer"
+	"github.com/docker/docker/pkg/progress"
+	"golang.org/x/net/context"
+)
+
+const maxUploadConcurrency = 3
+
+type mockUploadDescriptor struct {
+	currentUploads  *int32
+	diffID          layer.DiffID
+	simulateRetries int
+}
+
+// Key returns the key used to deduplicate downloads.
+func (u *mockUploadDescriptor) Key() string {
+	return u.diffID.String()
+}
+
+// ID returns the ID for display purposes.
+func (u *mockUploadDescriptor) ID() string {
+	return u.diffID.String()
+}
+
+// DiffID should return the DiffID for this layer.
+func (u *mockUploadDescriptor) DiffID() layer.DiffID {
+	return u.diffID
+}
+
+// Upload is called to perform the upload.
+func (u *mockUploadDescriptor) Upload(ctx context.Context, progressOutput progress.Output) (digest.Digest, error) {
+	if u.currentUploads != nil {
+		defer atomic.AddInt32(u.currentUploads, -1)
+
+		if atomic.AddInt32(u.currentUploads, 1) > maxUploadConcurrency {
+			return "", errors.New("concurrency limit exceeded")
+		}
+	}
+
+	// Sleep a bit to simulate a time-consuming upload.
+	for i := int64(0); i <= 10; i++ {
+		select {
+		case <-ctx.Done():
+			return "", ctx.Err()
+		case <-time.After(10 * time.Millisecond):
+			progressOutput.WriteProgress(progress.Progress{ID: u.ID(), Current: i, Total: 10})
+		}
+	}
+
+	if u.simulateRetries != 0 {
+		u.simulateRetries--
+		return "", errors.New("simulating retry")
+	}
+
+	// For the mock implementation, use SHA256(DiffID) as the returned
+	// digest.
+	return digest.FromBytes([]byte(u.diffID.String()))
+}
+
+func uploadDescriptors(currentUploads *int32) []UploadDescriptor {
+	return []UploadDescriptor{
+		&mockUploadDescriptor{currentUploads, layer.DiffID("sha256:cbbf2f9a99b47fc460d422812b6a5adff7dfee951d8fa2e4a98caa0382cfbdbf"), 0},
+		&mockUploadDescriptor{currentUploads, layer.DiffID("sha256:1515325234325236634634608943609283523908626098235490238423902343"), 0},
+		&mockUploadDescriptor{currentUploads, layer.DiffID("sha256:6929356290463485374960346430698374523437683470934634534953453453"), 0},
+		&mockUploadDescriptor{currentUploads, layer.DiffID("sha256:cbbf2f9a99b47fc460d422812b6a5adff7dfee951d8fa2e4a98caa0382cfbdbf"), 0},
+		&mockUploadDescriptor{currentUploads, layer.DiffID("sha256:8159352387436803946235346346368745389534789534897538734598734987"), 1},
+		&mockUploadDescriptor{currentUploads, layer.DiffID("sha256:4637863963478346897346987346987346789346789364879364897364987346"), 0},
+	}
+}
+
+var expectedDigests = map[layer.DiffID]digest.Digest{
+	layer.DiffID("sha256:cbbf2f9a99b47fc460d422812b6a5adff7dfee951d8fa2e4a98caa0382cfbdbf"): digest.Digest("sha256:c5095d6cf7ee42b7b064371dcc1dc3fb4af197f04d01a60009d484bd432724fc"),
+	layer.DiffID("sha256:1515325234325236634634608943609283523908626098235490238423902343"): digest.Digest("sha256:968cbfe2ff5269ea1729b3804767a1f57ffbc442d3bc86f47edbf7e688a4f36e"),
+	layer.DiffID("sha256:6929356290463485374960346430698374523437683470934634534953453453"): digest.Digest("sha256:8a5e56ab4b477a400470a7d5d4c1ca0c91235fd723ab19cc862636a06f3a735d"),
+	layer.DiffID("sha256:8159352387436803946235346346368745389534789534897538734598734987"): digest.Digest("sha256:5e733e5cd3688512fc240bd5c178e72671c9915947d17bb8451750d827944cb2"),
+	layer.DiffID("sha256:4637863963478346897346987346987346789346789364879364897364987346"): digest.Digest("sha256:ec4bb98d15e554a9f66c3ef9296cf46772c0ded3b1592bd8324d96e2f60f460c"),
+}
+
+func TestSuccessfulUpload(t *testing.T) {
+	lum := NewLayerUploadManager(maxUploadConcurrency)
+
+	progressChan := make(chan progress.Progress)
+	progressDone := make(chan struct{})
+	receivedProgress := make(map[string]int64)
+
+	go func() {
+		for p := range progressChan {
+			receivedProgress[p.ID] = p.Current
+		}
+		close(progressDone)
+	}()
+
+	var currentUploads int32
+	descriptors := uploadDescriptors(&currentUploads)
+
+	digests, err := lum.Upload(context.Background(), descriptors, progress.ChanOutput(progressChan))
+	if err != nil {
+		t.Fatalf("upload error: %v", err)
+	}
+
+	close(progressChan)
+	<-progressDone
+
+	if len(digests) != len(expectedDigests) {
+		t.Fatal("wrong number of keys in digests map")
+	}
+
+	for key, val := range expectedDigests {
+		if digests[key] != val {
+			t.Fatalf("mismatch in digest array for key %v (expected %v, got %v)", key, val, digests[key])
+		}
+		if receivedProgress[key.String()] != 10 {
+			t.Fatalf("missing or wrong progress output for %v", key)
+		}
+	}
+}
+
+func TestCancelledUpload(t *testing.T) {
+	lum := NewLayerUploadManager(maxUploadConcurrency)
+
+	progressChan := make(chan progress.Progress)
+	progressDone := make(chan struct{})
+
+	go func() {
+		for range progressChan {
+		}
+		close(progressDone)
+	}()
+
+	ctx, cancel := context.WithCancel(context.Background())
+
+	go func() {
+		<-time.After(time.Millisecond)
+		cancel()
+	}()
+
+	descriptors := uploadDescriptors(nil)
+	_, err := lum.Upload(ctx, descriptors, progress.ChanOutput(progressChan))
+	if err != context.Canceled {
+		t.Fatal("expected upload to be cancelled")
+	}
+
+	close(progressChan)
+	<-progressDone
+}

+ 3 - 0
docs/reference/api/docker_remote_api.md

@@ -103,6 +103,9 @@ This section lists each version from latest to oldest.  Each listing includes a
   consistent with other date/time values returned by the API.
 * `AuthConfig` now supports a `registrytoken` for token based authentication
 * `POST /containers/create` now has a 4M minimum value limit for `HostConfig.KernelMemory`
+* Pushes initated with `POST /images/(name)/push` and pulls initiated with `POST /images/create`
+  will be cancelled if the HTTP connection making the API request is closed before
+  the push or pull completes.
 
 ### v1.21 API changes
 

+ 3 - 0
docs/reference/api/docker_remote_api_v1.22.md

@@ -1530,6 +1530,7 @@ Query Parameters:
 
 -   **fromImage** – Name of the image to pull. The name may include a tag or
         digest. This parameter may only be used when pulling an image.
+        The pull is cancelled if the HTTP connection is closed.
 -   **fromSrc** – Source to import.  The value may be a URL from which the image
         can be retrieved or `-` to read the image from the request body.
         This parameter may only be used when importing an image.
@@ -1755,6 +1756,8 @@ If you wish to push an image on to a private registry, that image must already h
 into a repository which references that registry `hostname` and `port`.  This repository name should
 then be used in the URL. This duplicates the command line's flow.
 
+The push is cancelled if the HTTP connection is closed.
+
 **Example request**:
 
     POST /images/registry.acme.com:5000/test/push HTTP/1.1

+ 3 - 6
docs/reference/commandline/build.md

@@ -98,12 +98,9 @@ adding a `.dockerignore` file to that directory as well. For information on
 creating one, see the [.dockerignore file](../builder.md#dockerignore-file).
 
 If the Docker client loses connection to the daemon, the build is canceled.
-This happens if you interrupt the Docker client with `ctrl-c` or if the Docker
-client is killed for any reason.
-
-> **Note:**
-> Currently only the "run" phase of the build can be canceled until pull
-> cancellation is implemented).
+This happens if you interrupt the Docker client with `CTRL-c` or if the Docker
+client is killed for any reason. If the build initiated a pull which is still
+running at the time the build is cancelled, the pull is cancelled as well.
 
 ## Return code
 

+ 3 - 0
docs/reference/commandline/pull.md

@@ -49,3 +49,6 @@ use `docker pull`:
     # manually specifies the path to the default Docker registry. This could
     # be replaced with the path to a local registry to pull from another source.
     # sudo docker pull myhub.com:8080/test-image
+
+Killing the `docker pull` process, for example by pressing `CTRL-c` while it is
+running in a terminal, will terminate the pull operation.

+ 3 - 0
docs/reference/commandline/push.md

@@ -19,3 +19,6 @@ parent = "smn_cli"
 
 Use `docker push` to share your images to the [Docker Hub](https://hub.docker.com)
 registry or to a self-hosted one.
+
+Killing the `docker push` process, for example by pressing `CTRL-c` while it is
+running in a terminal, will terminate the push operation.

+ 4 - 10
integration-cli/docker_cli_pull_test.go

@@ -140,7 +140,7 @@ func (s *DockerHubPullSuite) TestPullAllTagsFromCentralRegistry(c *check.C) {
 }
 
 // TestPullClientDisconnect kills the client during a pull operation and verifies that the operation
-// still succesfully completes on the daemon side.
+// gets cancelled.
 //
 // Ref: docker/docker#15589
 func (s *DockerHubPullSuite) TestPullClientDisconnect(c *check.C) {
@@ -161,14 +161,8 @@ func (s *DockerHubPullSuite) TestPullClientDisconnect(c *check.C) {
 	err = pullCmd.Process.Kill()
 	c.Assert(err, checker.IsNil)
 
-	maxAttempts := 20
-	for i := 0; ; i++ {
-		if _, err := s.CmdWithError("inspect", repoName); err == nil {
-			break
-		}
-		if i >= maxAttempts {
-			c.Fatal("timeout reached: image was not pulled after client disconnected")
-		}
-		time.Sleep(500 * time.Millisecond)
+	time.Sleep(2 * time.Second)
+	if _, err := s.CmdWithError("inspect", repoName); err == nil {
+		c.Fatal("image was pulled after client disconnected")
 	}
 }

+ 0 - 167
pkg/broadcaster/buffered.go

@@ -1,167 +0,0 @@
-package broadcaster
-
-import (
-	"errors"
-	"io"
-	"sync"
-)
-
-// Buffered keeps track of one or more observers watching the progress
-// of an operation. For example, if multiple clients are trying to pull an
-// image, they share a Buffered struct for the download operation.
-type Buffered struct {
-	sync.Mutex
-	// c is a channel that observers block on, waiting for the operation
-	// to finish.
-	c chan struct{}
-	// cond is a condition variable used to wake up observers when there's
-	// new data available.
-	cond *sync.Cond
-	// history is a buffer of the progress output so far, so a new observer
-	// can catch up. The history is stored as a slice of separate byte
-	// slices, so that if the writer is a WriteFlusher, the flushes will
-	// happen in the right places.
-	history [][]byte
-	// wg is a WaitGroup used to wait for all writes to finish on Close
-	wg sync.WaitGroup
-	// result is the argument passed to the first call of Close, and
-	// returned to callers of Wait
-	result error
-}
-
-// NewBuffered returns an initialized Buffered structure.
-func NewBuffered() *Buffered {
-	b := &Buffered{
-		c: make(chan struct{}),
-	}
-	b.cond = sync.NewCond(b)
-	return b
-}
-
-// closed returns true if and only if the broadcaster has been closed
-func (broadcaster *Buffered) closed() bool {
-	select {
-	case <-broadcaster.c:
-		return true
-	default:
-		return false
-	}
-}
-
-// receiveWrites runs as a goroutine so that writes don't block the Write
-// function. It writes the new data in broadcaster.history each time there's
-// activity on the broadcaster.cond condition variable.
-func (broadcaster *Buffered) receiveWrites(observer io.Writer) {
-	n := 0
-
-	broadcaster.Lock()
-
-	// The condition variable wait is at the end of this loop, so that the
-	// first iteration will write the history so far.
-	for {
-		newData := broadcaster.history[n:]
-		// Make a copy of newData so we can release the lock
-		sendData := make([][]byte, len(newData), len(newData))
-		copy(sendData, newData)
-		broadcaster.Unlock()
-
-		for len(sendData) > 0 {
-			_, err := observer.Write(sendData[0])
-			if err != nil {
-				broadcaster.wg.Done()
-				return
-			}
-			n++
-			sendData = sendData[1:]
-		}
-
-		broadcaster.Lock()
-
-		// If we are behind, we need to catch up instead of waiting
-		// or handling a closure.
-		if len(broadcaster.history) != n {
-			continue
-		}
-
-		// detect closure of the broadcast writer
-		if broadcaster.closed() {
-			broadcaster.Unlock()
-			broadcaster.wg.Done()
-			return
-		}
-
-		broadcaster.cond.Wait()
-
-		// Mutex is still locked as the loop continues
-	}
-}
-
-// Write adds data to the history buffer, and also writes it to all current
-// observers.
-func (broadcaster *Buffered) Write(p []byte) (n int, err error) {
-	broadcaster.Lock()
-	defer broadcaster.Unlock()
-
-	// Is the broadcaster closed? If so, the write should fail.
-	if broadcaster.closed() {
-		return 0, errors.New("attempted write to a closed broadcaster.Buffered")
-	}
-
-	// Add message in p to the history slice
-	newEntry := make([]byte, len(p), len(p))
-	copy(newEntry, p)
-	broadcaster.history = append(broadcaster.history, newEntry)
-
-	broadcaster.cond.Broadcast()
-
-	return len(p), nil
-}
-
-// Add adds an observer to the broadcaster. The new observer receives the
-// data from the history buffer, and also all subsequent data.
-func (broadcaster *Buffered) Add(w io.Writer) error {
-	// The lock is acquired here so that Add can't race with Close
-	broadcaster.Lock()
-	defer broadcaster.Unlock()
-
-	if broadcaster.closed() {
-		return errors.New("attempted to add observer to a closed broadcaster.Buffered")
-	}
-
-	broadcaster.wg.Add(1)
-	go broadcaster.receiveWrites(w)
-
-	return nil
-}
-
-// CloseWithError signals to all observers that the operation has finished. Its
-// argument is a result that should be returned to waiters blocking on Wait.
-func (broadcaster *Buffered) CloseWithError(result error) {
-	broadcaster.Lock()
-	if broadcaster.closed() {
-		broadcaster.Unlock()
-		return
-	}
-	broadcaster.result = result
-	close(broadcaster.c)
-	broadcaster.cond.Broadcast()
-	broadcaster.Unlock()
-
-	// Don't return until all writers have caught up.
-	broadcaster.wg.Wait()
-}
-
-// Close signals to all observers that the operation has finished. It causes
-// all calls to Wait to return nil.
-func (broadcaster *Buffered) Close() {
-	broadcaster.CloseWithError(nil)
-}
-
-// Wait blocks until the operation is marked as completed by the Close method,
-// and all writer goroutines have completed. It returns the argument that was
-// passed to Close.
-func (broadcaster *Buffered) Wait() error {
-	<-broadcaster.c
-	broadcaster.wg.Wait()
-	return broadcaster.result
-}

+ 71 - 0
pkg/ioutils/readers.go

@@ -4,6 +4,8 @@ import (
 	"crypto/sha256"
 	"encoding/hex"
 	"io"
+
+	"golang.org/x/net/context"
 )
 
 type readCloserWrapper struct {
@@ -81,3 +83,72 @@ func (r *OnEOFReader) runFunc() {
 		r.Fn = nil
 	}
 }
+
+// cancelReadCloser wraps an io.ReadCloser with a context for cancelling read
+// operations.
+type cancelReadCloser struct {
+	cancel func()
+	pR     *io.PipeReader // Stream to read from
+	pW     *io.PipeWriter
+}
+
+// NewCancelReadCloser creates a wrapper that closes the ReadCloser when the
+// context is cancelled. The returned io.ReadCloser must be closed when it is
+// no longer needed.
+func NewCancelReadCloser(ctx context.Context, in io.ReadCloser) io.ReadCloser {
+	pR, pW := io.Pipe()
+
+	// Create a context used to signal when the pipe is closed
+	doneCtx, cancel := context.WithCancel(context.Background())
+
+	p := &cancelReadCloser{
+		cancel: cancel,
+		pR:     pR,
+		pW:     pW,
+	}
+
+	go func() {
+		_, err := io.Copy(pW, in)
+		select {
+		case <-ctx.Done():
+			// If the context was closed, p.closeWithError
+			// was already called. Calling it again would
+			// change the error that Read returns.
+		default:
+			p.closeWithError(err)
+		}
+		in.Close()
+	}()
+	go func() {
+		for {
+			select {
+			case <-ctx.Done():
+				p.closeWithError(ctx.Err())
+			case <-doneCtx.Done():
+				return
+			}
+		}
+	}()
+
+	return p
+}
+
+// Read wraps the Read method of the pipe that provides data from the wrapped
+// ReadCloser.
+func (p *cancelReadCloser) Read(buf []byte) (n int, err error) {
+	return p.pR.Read(buf)
+}
+
+// closeWithError closes the wrapper and its underlying reader. It will
+// cause future calls to Read to return err.
+func (p *cancelReadCloser) closeWithError(err error) {
+	p.pW.CloseWithError(err)
+	p.cancel()
+}
+
+// Close closes the wrapper its underlying reader. It will cause
+// future calls to Read to return io.EOF.
+func (p *cancelReadCloser) Close() error {
+	p.closeWithError(io.EOF)
+	return nil
+}

+ 27 - 0
pkg/ioutils/readers_test.go

@@ -2,8 +2,12 @@ package ioutils
 
 import (
 	"fmt"
+	"io/ioutil"
 	"strings"
 	"testing"
+	"time"
+
+	"golang.org/x/net/context"
 )
 
 // Implement io.Reader
@@ -65,3 +69,26 @@ func TestHashData(t *testing.T) {
 		t.Fatalf("Expecting %s, got %s", expected, actual)
 	}
 }
+
+type perpetualReader struct{}
+
+func (p *perpetualReader) Read(buf []byte) (n int, err error) {
+	for i := 0; i != len(buf); i++ {
+		buf[i] = 'a'
+	}
+	return len(buf), nil
+}
+
+func TestCancelReadCloser(t *testing.T) {
+	ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond)
+	cancelReadCloser := NewCancelReadCloser(ctx, ioutil.NopCloser(&perpetualReader{}))
+	for {
+		var buf [128]byte
+		_, err := cancelReadCloser.Read(buf[:])
+		if err == context.DeadlineExceeded {
+			break
+		} else if err != nil {
+			t.Fatalf("got unexpected error: %v", err)
+		}
+	}
+}

+ 63 - 0
pkg/progress/progress.go

@@ -0,0 +1,63 @@
+package progress
+
+import (
+	"fmt"
+)
+
+// Progress represents the progress of a transfer.
+type Progress struct {
+	ID string
+
+	// Progress contains a Message or...
+	Message string
+
+	// ...progress of an action
+	Action  string
+	Current int64
+	Total   int64
+
+	LastUpdate bool
+}
+
+// Output is an interface for writing progress information. It's
+// like a writer for progress, but we don't call it Writer because
+// that would be confusing next to ProgressReader (also, because it
+// doesn't implement the io.Writer interface).
+type Output interface {
+	WriteProgress(Progress) error
+}
+
+type chanOutput chan<- Progress
+
+func (out chanOutput) WriteProgress(p Progress) error {
+	out <- p
+	return nil
+}
+
+// ChanOutput returns a Output that writes progress updates to the
+// supplied channel.
+func ChanOutput(progressChan chan<- Progress) Output {
+	return chanOutput(progressChan)
+}
+
+// Update is a convenience function to write a progress update to the channel.
+func Update(out Output, id, action string) {
+	out.WriteProgress(Progress{ID: id, Action: action})
+}
+
+// Updatef is a convenience function to write a printf-formatted progress update
+// to the channel.
+func Updatef(out Output, id, format string, a ...interface{}) {
+	Update(out, id, fmt.Sprintf(format, a...))
+}
+
+// Message is a convenience function to write a progress message to the channel.
+func Message(out Output, id, message string) {
+	out.WriteProgress(Progress{ID: id, Message: message})
+}
+
+// Messagef is a convenience function to write a printf-formatted progress
+// message to the channel.
+func Messagef(out Output, id, format string, a ...interface{}) {
+	Message(out, id, fmt.Sprintf(format, a...))
+}

+ 59 - 0
pkg/progress/progressreader.go

@@ -0,0 +1,59 @@
+package progress
+
+import (
+	"io"
+)
+
+// Reader is a Reader with progress bar.
+type Reader struct {
+	in         io.ReadCloser // Stream to read from
+	out        Output        // Where to send progress bar to
+	size       int64
+	current    int64
+	lastUpdate int64
+	id         string
+	action     string
+}
+
+// NewProgressReader creates a new ProgressReader.
+func NewProgressReader(in io.ReadCloser, out Output, size int64, id, action string) *Reader {
+	return &Reader{
+		in:     in,
+		out:    out,
+		size:   size,
+		id:     id,
+		action: action,
+	}
+}
+
+func (p *Reader) Read(buf []byte) (n int, err error) {
+	read, err := p.in.Read(buf)
+	p.current += int64(read)
+	updateEvery := int64(1024 * 512) //512kB
+	if p.size > 0 {
+		// Update progress for every 1% read if 1% < 512kB
+		if increment := int64(0.01 * float64(p.size)); increment < updateEvery {
+			updateEvery = increment
+		}
+	}
+	if p.current-p.lastUpdate > updateEvery || err != nil {
+		p.updateProgress(err != nil && read == 0)
+		p.lastUpdate = p.current
+	}
+
+	return read, err
+}
+
+// Close closes the progress reader and its underlying reader.
+func (p *Reader) Close() error {
+	if p.current < p.size {
+		// print a full progress bar when closing prematurely
+		p.current = p.size
+		p.updateProgress(false)
+	}
+	return p.in.Close()
+}
+
+func (p *Reader) updateProgress(last bool) {
+	p.out.WriteProgress(Progress{ID: p.id, Action: p.action, Current: p.current, Total: p.size, LastUpdate: last})
+}

+ 75 - 0
pkg/progress/progressreader_test.go

@@ -0,0 +1,75 @@
+package progress
+
+import (
+	"bytes"
+	"io"
+	"io/ioutil"
+	"testing"
+)
+
+func TestOutputOnPrematureClose(t *testing.T) {
+	content := []byte("TESTING")
+	reader := ioutil.NopCloser(bytes.NewReader(content))
+	progressChan := make(chan Progress, 10)
+
+	pr := NewProgressReader(reader, ChanOutput(progressChan), int64(len(content)), "Test", "Read")
+
+	part := make([]byte, 4, 4)
+	_, err := io.ReadFull(pr, part)
+	if err != nil {
+		pr.Close()
+		t.Fatal(err)
+	}
+
+drainLoop:
+	for {
+		select {
+		case <-progressChan:
+		default:
+			break drainLoop
+		}
+	}
+
+	pr.Close()
+
+	select {
+	case <-progressChan:
+	default:
+		t.Fatalf("Expected some output when closing prematurely")
+	}
+}
+
+func TestCompleteSilently(t *testing.T) {
+	content := []byte("TESTING")
+	reader := ioutil.NopCloser(bytes.NewReader(content))
+	progressChan := make(chan Progress, 10)
+
+	pr := NewProgressReader(reader, ChanOutput(progressChan), int64(len(content)), "Test", "Read")
+
+	out, err := ioutil.ReadAll(pr)
+	if err != nil {
+		pr.Close()
+		t.Fatal(err)
+	}
+	if string(out) != "TESTING" {
+		pr.Close()
+		t.Fatalf("Unexpected output %q from reader", string(out))
+	}
+
+drainLoop:
+	for {
+		select {
+		case <-progressChan:
+		default:
+			break drainLoop
+		}
+	}
+
+	pr.Close()
+
+	select {
+	case <-progressChan:
+		t.Fatalf("Should have closed silently when read is complete")
+	default:
+	}
+}

+ 0 - 68
pkg/progressreader/progressreader.go

@@ -1,68 +0,0 @@
-// Package progressreader provides a Reader with a progress bar that can be
-// printed out using the streamformatter package.
-package progressreader
-
-import (
-	"io"
-
-	"github.com/docker/docker/pkg/jsonmessage"
-	"github.com/docker/docker/pkg/streamformatter"
-)
-
-// Config contains the configuration for a Reader with progress bar.
-type Config struct {
-	In         io.ReadCloser // Stream to read from
-	Out        io.Writer     // Where to send progress bar to
-	Formatter  *streamformatter.StreamFormatter
-	Size       int64
-	Current    int64
-	LastUpdate int64
-	NewLines   bool
-	ID         string
-	Action     string
-}
-
-// New creates a new Config.
-func New(newReader Config) *Config {
-	return &newReader
-}
-
-func (config *Config) Read(p []byte) (n int, err error) {
-	read, err := config.In.Read(p)
-	config.Current += int64(read)
-	updateEvery := int64(1024 * 512) //512kB
-	if config.Size > 0 {
-		// Update progress for every 1% read if 1% < 512kB
-		if increment := int64(0.01 * float64(config.Size)); increment < updateEvery {
-			updateEvery = increment
-		}
-	}
-	if config.Current-config.LastUpdate > updateEvery || err != nil {
-		updateProgress(config)
-		config.LastUpdate = config.Current
-	}
-
-	if err != nil && read == 0 {
-		updateProgress(config)
-		if config.NewLines {
-			config.Out.Write(config.Formatter.FormatStatus("", ""))
-		}
-	}
-	return read, err
-}
-
-// Close closes the reader (Config).
-func (config *Config) Close() error {
-	if config.Current < config.Size {
-		//print a full progress bar when closing prematurely
-		config.Current = config.Size
-		updateProgress(config)
-	}
-	return config.In.Close()
-}
-
-func updateProgress(config *Config) {
-	progress := jsonmessage.JSONProgress{Current: config.Current, Total: config.Size}
-	fmtMessage := config.Formatter.FormatProgress(config.ID, config.Action, &progress)
-	config.Out.Write(fmtMessage)
-}

+ 0 - 94
pkg/progressreader/progressreader_test.go

@@ -1,94 +0,0 @@
-package progressreader
-
-import (
-	"bufio"
-	"bytes"
-	"io"
-	"io/ioutil"
-	"testing"
-
-	"github.com/docker/docker/pkg/streamformatter"
-)
-
-func TestOutputOnPrematureClose(t *testing.T) {
-	var outBuf bytes.Buffer
-	content := []byte("TESTING")
-	reader := ioutil.NopCloser(bytes.NewReader(content))
-	writer := bufio.NewWriter(&outBuf)
-
-	prCfg := Config{
-		In:        reader,
-		Out:       writer,
-		Formatter: streamformatter.NewStreamFormatter(),
-		Size:      int64(len(content)),
-		NewLines:  true,
-		ID:        "Test",
-		Action:    "Read",
-	}
-	pr := New(prCfg)
-
-	part := make([]byte, 4, 4)
-	_, err := io.ReadFull(pr, part)
-	if err != nil {
-		pr.Close()
-		t.Fatal(err)
-	}
-
-	if err := writer.Flush(); err != nil {
-		pr.Close()
-		t.Fatal(err)
-	}
-
-	tlen := outBuf.Len()
-	pr.Close()
-	if err := writer.Flush(); err != nil {
-		t.Fatal(err)
-	}
-
-	if outBuf.Len() == tlen {
-		t.Fatalf("Expected some output when closing prematurely")
-	}
-}
-
-func TestCompleteSilently(t *testing.T) {
-	var outBuf bytes.Buffer
-	content := []byte("TESTING")
-	reader := ioutil.NopCloser(bytes.NewReader(content))
-	writer := bufio.NewWriter(&outBuf)
-
-	prCfg := Config{
-		In:        reader,
-		Out:       writer,
-		Formatter: streamformatter.NewStreamFormatter(),
-		Size:      int64(len(content)),
-		NewLines:  true,
-		ID:        "Test",
-		Action:    "Read",
-	}
-	pr := New(prCfg)
-
-	out, err := ioutil.ReadAll(pr)
-	if err != nil {
-		pr.Close()
-		t.Fatal(err)
-	}
-	if string(out) != "TESTING" {
-		pr.Close()
-		t.Fatalf("Unexpected output %q from reader", string(out))
-	}
-
-	if err := writer.Flush(); err != nil {
-		pr.Close()
-		t.Fatal(err)
-	}
-
-	tlen := outBuf.Len()
-	pr.Close()
-	if err := writer.Flush(); err != nil {
-		t.Fatal(err)
-	}
-
-	if outBuf.Len() > tlen {
-		t.Fatalf("Should have closed silently when read is complete")
-	}
-}

+ 39 - 0
pkg/streamformatter/streamformatter.go

@@ -7,6 +7,7 @@ import (
 	"io"
 
 	"github.com/docker/docker/pkg/jsonmessage"
+	"github.com/docker/docker/pkg/progress"
 )
 
 // StreamFormatter formats a stream, optionally using JSON.
@@ -92,6 +93,44 @@ func (sf *StreamFormatter) FormatProgress(id, action string, progress *jsonmessa
 	return []byte(action + " " + progress.String() + endl)
 }
 
+// NewProgressOutput returns a progress.Output object that can be passed to
+// progress.NewProgressReader.
+func (sf *StreamFormatter) NewProgressOutput(out io.Writer, newLines bool) progress.Output {
+	return &progressOutput{
+		sf:       sf,
+		out:      out,
+		newLines: newLines,
+	}
+}
+
+type progressOutput struct {
+	sf       *StreamFormatter
+	out      io.Writer
+	newLines bool
+}
+
+// WriteProgress formats progress information from a ProgressReader.
+func (out *progressOutput) WriteProgress(prog progress.Progress) error {
+	var formatted []byte
+	if prog.Message != "" {
+		formatted = out.sf.FormatStatus(prog.ID, prog.Message)
+	} else {
+		jsonProgress := jsonmessage.JSONProgress{Current: prog.Current, Total: prog.Total}
+		formatted = out.sf.FormatProgress(prog.ID, prog.Action, &jsonProgress)
+	}
+	_, err := out.out.Write(formatted)
+	if err != nil {
+		return err
+	}
+
+	if out.newLines && prog.LastUpdate {
+		_, err = out.out.Write(out.sf.FormatStatus("", ""))
+		return err
+	}
+
+	return nil
+}
+
 // StdoutFormatter is a streamFormatter that writes to the standard output.
 type StdoutFormatter struct {
 	io.Writer

+ 5 - 15
registry/session.go

@@ -17,7 +17,6 @@ import (
 	"net/url"
 	"strconv"
 	"strings"
-	"time"
 
 	"github.com/Sirupsen/logrus"
 	"github.com/docker/distribution/reference"
@@ -270,7 +269,6 @@ func (r *Session) GetRemoteImageJSON(imgID, registry string) ([]byte, int64, err
 // GetRemoteImageLayer retrieves an image layer from the registry
 func (r *Session) GetRemoteImageLayer(imgID, registry string, imgSize int64) (io.ReadCloser, error) {
 	var (
-		retries    = 5
 		statusCode = 0
 		res        *http.Response
 		err        error
@@ -281,14 +279,9 @@ func (r *Session) GetRemoteImageLayer(imgID, registry string, imgSize int64) (io
 	if err != nil {
 		return nil, fmt.Errorf("Error while getting from the server: %v", err)
 	}
-	// TODO(tiborvass): why are we doing retries at this level?
-	// These retries should be generic to both v1 and v2
-	for i := 1; i <= retries; i++ {
-		statusCode = 0
-		res, err = r.client.Do(req)
-		if err == nil {
-			break
-		}
+	statusCode = 0
+	res, err = r.client.Do(req)
+	if err != nil {
 		logrus.Debugf("Error contacting registry %s: %v", registry, err)
 		if res != nil {
 			if res.Body != nil {
@@ -296,11 +289,8 @@ func (r *Session) GetRemoteImageLayer(imgID, registry string, imgSize int64) (io
 			}
 			statusCode = res.StatusCode
 		}
-		if i == retries {
-			return nil, fmt.Errorf("Server error: Status %d while fetching image layer (%s)",
-				statusCode, imgID)
-		}
-		time.Sleep(time.Duration(i) * 5 * time.Second)
+		return nil, fmt.Errorf("Server error: Status %d while fetching image layer (%s)",
+			statusCode, imgID)
 	}
 
 	if res.StatusCode != 200 {