Browse Source

Improved push and pull with upload manager and download manager

This commit adds a transfer manager which deduplicates and schedules
transfers, and also an upload manager and download manager that build on
top of the transfer manager to provide high-level interfaces for uploads
and downloads. The push and pull code is modified to use these building
blocks.

Some benefits of the changes:

- Simplification of push/pull code
- Pushes can upload layers concurrently
- Failed downloads and uploads are retried after backoff delays
- Cancellation is supported, but individual transfers will only be
  cancelled if all pushes or pulls using them are cancelled.
- The distribution code is decoupled from Docker Engine packages and API
  conventions (i.e. streamformatter), which will make it easier to split
  out.

This commit also includes unit tests for the new distribution/xfer
package. The tests cover 87.8% of the statements in the package.

Signed-off-by: Aaron Lehmann <aaron.lehmann@docker.com>
Aaron Lehmann 9 years ago
parent
commit
572ce80230

+ 6 - 20
api/client/build.go

@@ -23,7 +23,7 @@ import (
 	"github.com/docker/docker/pkg/httputils"
 	"github.com/docker/docker/pkg/httputils"
 	"github.com/docker/docker/pkg/jsonmessage"
 	"github.com/docker/docker/pkg/jsonmessage"
 	flag "github.com/docker/docker/pkg/mflag"
 	flag "github.com/docker/docker/pkg/mflag"
-	"github.com/docker/docker/pkg/progressreader"
+	"github.com/docker/docker/pkg/progress"
 	"github.com/docker/docker/pkg/streamformatter"
 	"github.com/docker/docker/pkg/streamformatter"
 	"github.com/docker/docker/pkg/ulimit"
 	"github.com/docker/docker/pkg/ulimit"
 	"github.com/docker/docker/pkg/units"
 	"github.com/docker/docker/pkg/units"
@@ -169,16 +169,9 @@ func (cli *DockerCli) CmdBuild(args ...string) error {
 	context = replaceDockerfileTarWrapper(context, newDockerfile, relDockerfile)
 	context = replaceDockerfileTarWrapper(context, newDockerfile, relDockerfile)
 
 
 	// Setup an upload progress bar
 	// Setup an upload progress bar
-	// FIXME: ProgressReader shouldn't be this annoying to use
-	sf := streamformatter.NewStreamFormatter()
-	var body io.Reader = progressreader.New(progressreader.Config{
-		In:        context,
-		Out:       cli.out,
-		Formatter: sf,
-		NewLines:  true,
-		ID:        "",
-		Action:    "Sending build context to Docker daemon",
-	})
+	progressOutput := streamformatter.NewStreamFormatter().NewProgressOutput(cli.out, true)
+
+	var body io.Reader = progress.NewProgressReader(context, progressOutput, 0, "", "Sending build context to Docker daemon")
 
 
 	var memory int64
 	var memory int64
 	if *flMemoryString != "" {
 	if *flMemoryString != "" {
@@ -447,17 +440,10 @@ func getContextFromURL(out io.Writer, remoteURL, dockerfileName string) (absCont
 		return "", "", fmt.Errorf("unable to download remote context %s: %v", remoteURL, err)
 		return "", "", fmt.Errorf("unable to download remote context %s: %v", remoteURL, err)
 	}
 	}
 	defer response.Body.Close()
 	defer response.Body.Close()
+	progressOutput := streamformatter.NewStreamFormatter().NewProgressOutput(out, true)
 
 
 	// Pass the response body through a progress reader.
 	// Pass the response body through a progress reader.
-	progReader := &progressreader.Config{
-		In:        response.Body,
-		Out:       out,
-		Formatter: streamformatter.NewStreamFormatter(),
-		Size:      response.ContentLength,
-		NewLines:  true,
-		ID:        "",
-		Action:    fmt.Sprintf("Downloading build context from remote url: %s", remoteURL),
-	}
+	progReader := progress.NewProgressReader(response.Body, progressOutput, response.ContentLength, "", fmt.Sprintf("Downloading build context from remote url: %s", remoteURL))
 
 
 	return getContextFromReader(progReader, dockerfileName)
 	return getContextFromReader(progReader, dockerfileName)
 }
 }

+ 6 - 12
api/server/router/local/image.go

@@ -23,7 +23,7 @@ import (
 	"github.com/docker/docker/pkg/archive"
 	"github.com/docker/docker/pkg/archive"
 	"github.com/docker/docker/pkg/chrootarchive"
 	"github.com/docker/docker/pkg/chrootarchive"
 	"github.com/docker/docker/pkg/ioutils"
 	"github.com/docker/docker/pkg/ioutils"
-	"github.com/docker/docker/pkg/progressreader"
+	"github.com/docker/docker/pkg/progress"
 	"github.com/docker/docker/pkg/streamformatter"
 	"github.com/docker/docker/pkg/streamformatter"
 	"github.com/docker/docker/pkg/ulimit"
 	"github.com/docker/docker/pkg/ulimit"
 	"github.com/docker/docker/runconfig"
 	"github.com/docker/docker/runconfig"
@@ -325,7 +325,7 @@ func (s *router) postBuild(ctx context.Context, w http.ResponseWriter, r *http.R
 	sf := streamformatter.NewJSONStreamFormatter()
 	sf := streamformatter.NewJSONStreamFormatter()
 	errf := func(err error) error {
 	errf := func(err error) error {
 		// Do not write the error in the http output if it's still empty.
 		// Do not write the error in the http output if it's still empty.
-		// This prevents from writing a 200(OK) when there is an interal error.
+		// This prevents from writing a 200(OK) when there is an internal error.
 		if !output.Flushed() {
 		if !output.Flushed() {
 			return err
 			return err
 		}
 		}
@@ -401,23 +401,17 @@ func (s *router) postBuild(ctx context.Context, w http.ResponseWriter, r *http.R
 	remoteURL := r.FormValue("remote")
 	remoteURL := r.FormValue("remote")
 
 
 	// Currently, only used if context is from a remote url.
 	// Currently, only used if context is from a remote url.
-	// The field `In` is set by DetectContextFromRemoteURL.
 	// Look at code in DetectContextFromRemoteURL for more information.
 	// Look at code in DetectContextFromRemoteURL for more information.
-	pReader := &progressreader.Config{
-		// TODO: make progressreader streamformatter-agnostic
-		Out:       output,
-		Formatter: sf,
-		Size:      r.ContentLength,
-		NewLines:  true,
-		ID:        "Downloading context",
-		Action:    remoteURL,
+	createProgressReader := func(in io.ReadCloser) io.ReadCloser {
+		progressOutput := sf.NewProgressOutput(output, true)
+		return progress.NewProgressReader(in, progressOutput, r.ContentLength, "Downloading context", remoteURL)
 	}
 	}
 
 
 	var (
 	var (
 		context        builder.ModifiableContext
 		context        builder.ModifiableContext
 		dockerfileName string
 		dockerfileName string
 	)
 	)
-	context, dockerfileName, err = daemonbuilder.DetectContextFromRemoteURL(r.Body, remoteURL, pReader)
+	context, dockerfileName, err = daemonbuilder.DetectContextFromRemoteURL(r.Body, remoteURL, createProgressReader)
 	if err != nil {
 	if err != nil {
 		return errf(err)
 		return errf(err)
 	}
 	}

+ 5 - 11
builder/dockerfile/internals.go

@@ -29,7 +29,7 @@ import (
 	"github.com/docker/docker/pkg/httputils"
 	"github.com/docker/docker/pkg/httputils"
 	"github.com/docker/docker/pkg/ioutils"
 	"github.com/docker/docker/pkg/ioutils"
 	"github.com/docker/docker/pkg/jsonmessage"
 	"github.com/docker/docker/pkg/jsonmessage"
-	"github.com/docker/docker/pkg/progressreader"
+	"github.com/docker/docker/pkg/progress"
 	"github.com/docker/docker/pkg/streamformatter"
 	"github.com/docker/docker/pkg/streamformatter"
 	"github.com/docker/docker/pkg/stringid"
 	"github.com/docker/docker/pkg/stringid"
 	"github.com/docker/docker/pkg/stringutils"
 	"github.com/docker/docker/pkg/stringutils"
@@ -264,17 +264,11 @@ func (b *Builder) download(srcURL string) (fi builder.FileInfo, err error) {
 		return
 		return
 	}
 	}
 
 
+	stdoutFormatter := b.Stdout.(*streamformatter.StdoutFormatter)
+	progressOutput := stdoutFormatter.StreamFormatter.NewProgressOutput(stdoutFormatter.Writer, true)
+	progressReader := progress.NewProgressReader(resp.Body, progressOutput, resp.ContentLength, "", "Downloading")
 	// Download and dump result to tmp file
 	// Download and dump result to tmp file
-	if _, err = io.Copy(tmpFile, progressreader.New(progressreader.Config{
-		In: resp.Body,
-		// TODO: make progressreader streamformatter agnostic
-		Out:       b.Stdout.(*streamformatter.StdoutFormatter).Writer,
-		Formatter: b.Stdout.(*streamformatter.StdoutFormatter).StreamFormatter,
-		Size:      resp.ContentLength,
-		NewLines:  true,
-		ID:        "",
-		Action:    "Downloading",
-	})); err != nil {
+	if _, err = io.Copy(tmpFile, progressReader); err != nil {
 		tmpFile.Close()
 		tmpFile.Close()
 		return
 		return
 	}
 	}

+ 70 - 9
daemon/daemon.go

@@ -34,6 +34,7 @@ import (
 	"github.com/docker/docker/daemon/network"
 	"github.com/docker/docker/daemon/network"
 	"github.com/docker/docker/distribution"
 	"github.com/docker/docker/distribution"
 	dmetadata "github.com/docker/docker/distribution/metadata"
 	dmetadata "github.com/docker/docker/distribution/metadata"
+	"github.com/docker/docker/distribution/xfer"
 	derr "github.com/docker/docker/errors"
 	derr "github.com/docker/docker/errors"
 	"github.com/docker/docker/image"
 	"github.com/docker/docker/image"
 	"github.com/docker/docker/image/tarexport"
 	"github.com/docker/docker/image/tarexport"
@@ -49,7 +50,9 @@ import (
 	"github.com/docker/docker/pkg/namesgenerator"
 	"github.com/docker/docker/pkg/namesgenerator"
 	"github.com/docker/docker/pkg/nat"
 	"github.com/docker/docker/pkg/nat"
 	"github.com/docker/docker/pkg/parsers/filters"
 	"github.com/docker/docker/pkg/parsers/filters"
+	"github.com/docker/docker/pkg/progress"
 	"github.com/docker/docker/pkg/signal"
 	"github.com/docker/docker/pkg/signal"
+	"github.com/docker/docker/pkg/streamformatter"
 	"github.com/docker/docker/pkg/stringid"
 	"github.com/docker/docker/pkg/stringid"
 	"github.com/docker/docker/pkg/stringutils"
 	"github.com/docker/docker/pkg/stringutils"
 	"github.com/docker/docker/pkg/sysinfo"
 	"github.com/docker/docker/pkg/sysinfo"
@@ -66,6 +69,16 @@ import (
 	lntypes "github.com/docker/libnetwork/types"
 	lntypes "github.com/docker/libnetwork/types"
 	"github.com/docker/libtrust"
 	"github.com/docker/libtrust"
 	"github.com/opencontainers/runc/libcontainer"
 	"github.com/opencontainers/runc/libcontainer"
+	"golang.org/x/net/context"
+)
+
+const (
+	// maxDownloadConcurrency is the maximum number of downloads that
+	// may take place at a time for each pull.
+	maxDownloadConcurrency = 3
+	// maxUploadConcurrency is the maximum number of uploads that
+	// may take place at a time for each push.
+	maxUploadConcurrency = 5
 )
 )
 
 
 var (
 var (
@@ -126,7 +139,8 @@ type Daemon struct {
 	containers                *contStore
 	containers                *contStore
 	execCommands              *exec.Store
 	execCommands              *exec.Store
 	tagStore                  tag.Store
 	tagStore                  tag.Store
-	distributionPool          *distribution.Pool
+	downloadManager           *xfer.LayerDownloadManager
+	uploadManager             *xfer.LayerUploadManager
 	distributionMetadataStore dmetadata.Store
 	distributionMetadataStore dmetadata.Store
 	trustKey                  libtrust.PrivateKey
 	trustKey                  libtrust.PrivateKey
 	idIndex                   *truncindex.TruncIndex
 	idIndex                   *truncindex.TruncIndex
@@ -738,7 +752,8 @@ func NewDaemon(config *Config, registryService *registry.Service) (daemon *Daemo
 		return nil, err
 		return nil, err
 	}
 	}
 
 
-	distributionPool := distribution.NewPool()
+	d.downloadManager = xfer.NewLayerDownloadManager(d.layerStore, maxDownloadConcurrency)
+	d.uploadManager = xfer.NewLayerUploadManager(maxUploadConcurrency)
 
 
 	ifs, err := image.NewFSStoreBackend(filepath.Join(imageRoot, "imagedb"))
 	ifs, err := image.NewFSStoreBackend(filepath.Join(imageRoot, "imagedb"))
 	if err != nil {
 	if err != nil {
@@ -834,7 +849,6 @@ func NewDaemon(config *Config, registryService *registry.Service) (daemon *Daemo
 	d.containers = &contStore{s: make(map[string]*container.Container)}
 	d.containers = &contStore{s: make(map[string]*container.Container)}
 	d.execCommands = exec.NewStore()
 	d.execCommands = exec.NewStore()
 	d.tagStore = tagStore
 	d.tagStore = tagStore
-	d.distributionPool = distributionPool
 	d.distributionMetadataStore = distributionMetadataStore
 	d.distributionMetadataStore = distributionMetadataStore
 	d.trustKey = trustKey
 	d.trustKey = trustKey
 	d.idIndex = truncindex.NewTruncIndex([]string{})
 	d.idIndex = truncindex.NewTruncIndex([]string{})
@@ -1038,23 +1052,53 @@ func (daemon *Daemon) TagImage(newTag reference.Named, imageName string) error {
 	return nil
 	return nil
 }
 }
 
 
+func writeDistributionProgress(cancelFunc func(), outStream io.Writer, progressChan <-chan progress.Progress) {
+	progressOutput := streamformatter.NewJSONStreamFormatter().NewProgressOutput(outStream, false)
+	operationCancelled := false
+
+	for prog := range progressChan {
+		if err := progressOutput.WriteProgress(prog); err != nil && !operationCancelled {
+			logrus.Errorf("error writing progress to client: %v", err)
+			cancelFunc()
+			operationCancelled = true
+			// Don't return, because we need to continue draining
+			// progressChan until it's closed to avoid a deadlock.
+		}
+	}
+}
+
 // PullImage initiates a pull operation. image is the repository name to pull, and
 // PullImage initiates a pull operation. image is the repository name to pull, and
 // tag may be either empty, or indicate a specific tag to pull.
 // tag may be either empty, or indicate a specific tag to pull.
 func (daemon *Daemon) PullImage(ref reference.Named, metaHeaders map[string][]string, authConfig *cliconfig.AuthConfig, outStream io.Writer) error {
 func (daemon *Daemon) PullImage(ref reference.Named, metaHeaders map[string][]string, authConfig *cliconfig.AuthConfig, outStream io.Writer) error {
+	// Include a buffer so that slow client connections don't affect
+	// transfer performance.
+	progressChan := make(chan progress.Progress, 100)
+
+	writesDone := make(chan struct{})
+
+	ctx, cancelFunc := context.WithCancel(context.Background())
+
+	go func() {
+		writeDistributionProgress(cancelFunc, outStream, progressChan)
+		close(writesDone)
+	}()
+
 	imagePullConfig := &distribution.ImagePullConfig{
 	imagePullConfig := &distribution.ImagePullConfig{
 		MetaHeaders:     metaHeaders,
 		MetaHeaders:     metaHeaders,
 		AuthConfig:      authConfig,
 		AuthConfig:      authConfig,
-		OutStream:       outStream,
+		ProgressOutput:  progress.ChanOutput(progressChan),
 		RegistryService: daemon.RegistryService,
 		RegistryService: daemon.RegistryService,
 		EventsService:   daemon.EventsService,
 		EventsService:   daemon.EventsService,
 		MetadataStore:   daemon.distributionMetadataStore,
 		MetadataStore:   daemon.distributionMetadataStore,
-		LayerStore:      daemon.layerStore,
 		ImageStore:      daemon.imageStore,
 		ImageStore:      daemon.imageStore,
 		TagStore:        daemon.tagStore,
 		TagStore:        daemon.tagStore,
-		Pool:            daemon.distributionPool,
+		DownloadManager: daemon.downloadManager,
 	}
 	}
 
 
-	return distribution.Pull(ref, imagePullConfig)
+	err := distribution.Pull(ctx, ref, imagePullConfig)
+	close(progressChan)
+	<-writesDone
+	return err
 }
 }
 
 
 // ExportImage exports a list of images to the given output stream. The
 // ExportImage exports a list of images to the given output stream. The
@@ -1069,10 +1113,23 @@ func (daemon *Daemon) ExportImage(names []string, outStream io.Writer) error {
 
 
 // PushImage initiates a push operation on the repository named localName.
 // PushImage initiates a push operation on the repository named localName.
 func (daemon *Daemon) PushImage(ref reference.Named, metaHeaders map[string][]string, authConfig *cliconfig.AuthConfig, outStream io.Writer) error {
 func (daemon *Daemon) PushImage(ref reference.Named, metaHeaders map[string][]string, authConfig *cliconfig.AuthConfig, outStream io.Writer) error {
+	// Include a buffer so that slow client connections don't affect
+	// transfer performance.
+	progressChan := make(chan progress.Progress, 100)
+
+	writesDone := make(chan struct{})
+
+	ctx, cancelFunc := context.WithCancel(context.Background())
+
+	go func() {
+		writeDistributionProgress(cancelFunc, outStream, progressChan)
+		close(writesDone)
+	}()
+
 	imagePushConfig := &distribution.ImagePushConfig{
 	imagePushConfig := &distribution.ImagePushConfig{
 		MetaHeaders:     metaHeaders,
 		MetaHeaders:     metaHeaders,
 		AuthConfig:      authConfig,
 		AuthConfig:      authConfig,
-		OutStream:       outStream,
+		ProgressOutput:  progress.ChanOutput(progressChan),
 		RegistryService: daemon.RegistryService,
 		RegistryService: daemon.RegistryService,
 		EventsService:   daemon.EventsService,
 		EventsService:   daemon.EventsService,
 		MetadataStore:   daemon.distributionMetadataStore,
 		MetadataStore:   daemon.distributionMetadataStore,
@@ -1080,9 +1137,13 @@ func (daemon *Daemon) PushImage(ref reference.Named, metaHeaders map[string][]st
 		ImageStore:      daemon.imageStore,
 		ImageStore:      daemon.imageStore,
 		TagStore:        daemon.tagStore,
 		TagStore:        daemon.tagStore,
 		TrustKey:        daemon.trustKey,
 		TrustKey:        daemon.trustKey,
+		UploadManager:   daemon.uploadManager,
 	}
 	}
 
 
-	return distribution.Push(ref, imagePushConfig)
+	err := distribution.Push(ctx, ref, imagePushConfig)
+	close(progressChan)
+	<-writesDone
+	return err
 }
 }
 
 
 // LookupImage looks up an image by name and returns it as an ImageInspect
 // LookupImage looks up an image by name and returns it as an ImageInspect

+ 2 - 4
daemon/daemonbuilder/builder.go

@@ -21,7 +21,6 @@ import (
 	"github.com/docker/docker/pkg/httputils"
 	"github.com/docker/docker/pkg/httputils"
 	"github.com/docker/docker/pkg/idtools"
 	"github.com/docker/docker/pkg/idtools"
 	"github.com/docker/docker/pkg/ioutils"
 	"github.com/docker/docker/pkg/ioutils"
-	"github.com/docker/docker/pkg/progressreader"
 	"github.com/docker/docker/pkg/urlutil"
 	"github.com/docker/docker/pkg/urlutil"
 	"github.com/docker/docker/registry"
 	"github.com/docker/docker/registry"
 	"github.com/docker/docker/runconfig"
 	"github.com/docker/docker/runconfig"
@@ -239,7 +238,7 @@ func (d Docker) Start(c *container.Container) error {
 // DetectContextFromRemoteURL returns a context and in certain cases the name of the dockerfile to be used
 // DetectContextFromRemoteURL returns a context and in certain cases the name of the dockerfile to be used
 // irrespective of user input.
 // irrespective of user input.
 // progressReader is only used if remoteURL is actually a URL (not empty, and not a Git endpoint).
 // progressReader is only used if remoteURL is actually a URL (not empty, and not a Git endpoint).
-func DetectContextFromRemoteURL(r io.ReadCloser, remoteURL string, progressReader *progressreader.Config) (context builder.ModifiableContext, dockerfileName string, err error) {
+func DetectContextFromRemoteURL(r io.ReadCloser, remoteURL string, createProgressReader func(in io.ReadCloser) io.ReadCloser) (context builder.ModifiableContext, dockerfileName string, err error) {
 	switch {
 	switch {
 	case remoteURL == "":
 	case remoteURL == "":
 		context, err = builder.MakeTarSumContext(r)
 		context, err = builder.MakeTarSumContext(r)
@@ -262,8 +261,7 @@ func DetectContextFromRemoteURL(r io.ReadCloser, remoteURL string, progressReade
 			},
 			},
 			// fallback handler (tar context)
 			// fallback handler (tar context)
 			"": func(rc io.ReadCloser) (io.ReadCloser, error) {
 			"": func(rc io.ReadCloser) (io.ReadCloser, error) {
-				progressReader.In = rc
-				return progressReader, nil
+				return createProgressReader(rc), nil
 			},
 			},
 		})
 		})
 	default:
 	default:

+ 3 - 11
daemon/import.go

@@ -13,7 +13,7 @@ import (
 	"github.com/docker/docker/image"
 	"github.com/docker/docker/image"
 	"github.com/docker/docker/layer"
 	"github.com/docker/docker/layer"
 	"github.com/docker/docker/pkg/httputils"
 	"github.com/docker/docker/pkg/httputils"
-	"github.com/docker/docker/pkg/progressreader"
+	"github.com/docker/docker/pkg/progress"
 	"github.com/docker/docker/pkg/streamformatter"
 	"github.com/docker/docker/pkg/streamformatter"
 	"github.com/docker/docker/runconfig"
 	"github.com/docker/docker/runconfig"
 )
 )
@@ -47,16 +47,8 @@ func (daemon *Daemon) ImportImage(src string, newRef reference.Named, msg string
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
-		progressReader := progressreader.New(progressreader.Config{
-			In:        resp.Body,
-			Out:       outStream,
-			Formatter: sf,
-			Size:      resp.ContentLength,
-			NewLines:  true,
-			ID:        "",
-			Action:    "Importing",
-		})
-		archive = progressReader
+		progressOutput := sf.NewProgressOutput(outStream, true)
+		archive = progress.NewProgressReader(resp.Body, progressOutput, resp.ContentLength, "", "Importing")
 	}
 	}
 
 
 	defer archive.Close()
 	defer archive.Close()

+ 5 - 5
distribution/metadata/v1_id_service.go

@@ -23,20 +23,20 @@ func (idserv *V1IDService) namespace() string {
 }
 }
 
 
 // Get finds a layer by its V1 ID.
 // Get finds a layer by its V1 ID.
-func (idserv *V1IDService) Get(v1ID, registry string) (layer.ChainID, error) {
+func (idserv *V1IDService) Get(v1ID, registry string) (layer.DiffID, error) {
 	if err := v1.ValidateID(v1ID); err != nil {
 	if err := v1.ValidateID(v1ID); err != nil {
-		return layer.ChainID(""), err
+		return layer.DiffID(""), err
 	}
 	}
 
 
 	idBytes, err := idserv.store.Get(idserv.namespace(), registry+","+v1ID)
 	idBytes, err := idserv.store.Get(idserv.namespace(), registry+","+v1ID)
 	if err != nil {
 	if err != nil {
-		return layer.ChainID(""), err
+		return layer.DiffID(""), err
 	}
 	}
-	return layer.ChainID(idBytes), nil
+	return layer.DiffID(idBytes), nil
 }
 }
 
 
 // Set associates an image with a V1 ID.
 // Set associates an image with a V1 ID.
-func (idserv *V1IDService) Set(v1ID, registry string, id layer.ChainID) error {
+func (idserv *V1IDService) Set(v1ID, registry string, id layer.DiffID) error {
 	if err := v1.ValidateID(v1ID); err != nil {
 	if err := v1.ValidateID(v1ID); err != nil {
 		return err
 		return err
 	}
 	}

+ 4 - 4
distribution/metadata/v1_id_service_test.go

@@ -24,22 +24,22 @@ func TestV1IDService(t *testing.T) {
 	testVectors := []struct {
 	testVectors := []struct {
 		registry string
 		registry string
 		v1ID     string
 		v1ID     string
-		layerID  layer.ChainID
+		layerID  layer.DiffID
 	}{
 	}{
 		{
 		{
 			registry: "registry1",
 			registry: "registry1",
 			v1ID:     "f0cd5ca10b07f35512fc2f1cbf9a6cefbdb5cba70ac6b0c9e5988f4497f71937",
 			v1ID:     "f0cd5ca10b07f35512fc2f1cbf9a6cefbdb5cba70ac6b0c9e5988f4497f71937",
-			layerID:  layer.ChainID("sha256:a3ed95caeb02ffe68cdd9fd84406680ae93d633cb16422d00e8a7c22955b46d4"),
+			layerID:  layer.DiffID("sha256:a3ed95caeb02ffe68cdd9fd84406680ae93d633cb16422d00e8a7c22955b46d4"),
 		},
 		},
 		{
 		{
 			registry: "registry2",
 			registry: "registry2",
 			v1ID:     "9e3447ca24cb96d86ebd5960cb34d1299b07e0a0e03801d90b9969a2c187dd6e",
 			v1ID:     "9e3447ca24cb96d86ebd5960cb34d1299b07e0a0e03801d90b9969a2c187dd6e",
-			layerID:  layer.ChainID("sha256:86e0e091d0da6bde2456dbb48306f3956bbeb2eae1b5b9a43045843f69fe4aaa"),
+			layerID:  layer.DiffID("sha256:86e0e091d0da6bde2456dbb48306f3956bbeb2eae1b5b9a43045843f69fe4aaa"),
 		},
 		},
 		{
 		{
 			registry: "registry1",
 			registry: "registry1",
 			v1ID:     "9e3447ca24cb96d86ebd5960cb34d1299b07e0a0e03801d90b9969a2c187dd6e",
 			v1ID:     "9e3447ca24cb96d86ebd5960cb34d1299b07e0a0e03801d90b9969a2c187dd6e",
-			layerID:  layer.ChainID("sha256:03f4658f8b782e12230c1783426bd3bacce651ce582a4ffb6fbbfa2079428ecb"),
+			layerID:  layer.DiffID("sha256:03f4658f8b782e12230c1783426bd3bacce651ce582a4ffb6fbbfa2079428ecb"),
 		},
 		},
 	}
 	}
 
 

+ 0 - 51
distribution/pool.go

@@ -1,51 +0,0 @@
-package distribution
-
-import (
-	"sync"
-
-	"github.com/docker/docker/pkg/broadcaster"
-)
-
-// A Pool manages concurrent pulls. It deduplicates in-progress downloads.
-type Pool struct {
-	sync.Mutex
-	pullingPool map[string]*broadcaster.Buffered
-}
-
-// NewPool creates a new Pool.
-func NewPool() *Pool {
-	return &Pool{
-		pullingPool: make(map[string]*broadcaster.Buffered),
-	}
-}
-
-// add checks if a pull is already running, and returns (broadcaster, true)
-// if a running operation is found. Otherwise, it creates a new one and returns
-// (broadcaster, false).
-func (pool *Pool) add(key string) (*broadcaster.Buffered, bool) {
-	pool.Lock()
-	defer pool.Unlock()
-
-	if p, exists := pool.pullingPool[key]; exists {
-		return p, true
-	}
-
-	broadcaster := broadcaster.NewBuffered()
-	pool.pullingPool[key] = broadcaster
-
-	return broadcaster, false
-}
-
-func (pool *Pool) removeWithError(key string, broadcasterResult error) error {
-	pool.Lock()
-	defer pool.Unlock()
-	if broadcaster, exists := pool.pullingPool[key]; exists {
-		broadcaster.CloseWithError(broadcasterResult)
-		delete(pool.pullingPool, key)
-	}
-	return nil
-}
-
-func (pool *Pool) remove(key string) error {
-	return pool.removeWithError(key, nil)
-}

+ 0 - 28
distribution/pool_test.go

@@ -1,28 +0,0 @@
-package distribution
-
-import (
-	"testing"
-)
-
-func TestPools(t *testing.T) {
-	p := NewPool()
-
-	if _, found := p.add("test1"); found {
-		t.Fatal("Expected pull test1 not to be in progress")
-	}
-	if _, found := p.add("test2"); found {
-		t.Fatal("Expected pull test2 not to be in progress")
-	}
-	if _, found := p.add("test1"); !found {
-		t.Fatalf("Expected pull test1 to be in progress`")
-	}
-	if err := p.remove("test2"); err != nil {
-		t.Fatal(err)
-	}
-	if err := p.remove("test2"); err != nil {
-		t.Fatal(err)
-	}
-	if err := p.remove("test1"); err != nil {
-		t.Fatal(err)
-	}
-}

+ 36 - 21
distribution/pull.go

@@ -2,7 +2,7 @@ package distribution
 
 
 import (
 import (
 	"fmt"
 	"fmt"
-	"io"
+	"os"
 	"strings"
 	"strings"
 
 
 	"github.com/Sirupsen/logrus"
 	"github.com/Sirupsen/logrus"
@@ -10,11 +10,12 @@ import (
 	"github.com/docker/docker/cliconfig"
 	"github.com/docker/docker/cliconfig"
 	"github.com/docker/docker/daemon/events"
 	"github.com/docker/docker/daemon/events"
 	"github.com/docker/docker/distribution/metadata"
 	"github.com/docker/docker/distribution/metadata"
+	"github.com/docker/docker/distribution/xfer"
 	"github.com/docker/docker/image"
 	"github.com/docker/docker/image"
-	"github.com/docker/docker/layer"
-	"github.com/docker/docker/pkg/streamformatter"
+	"github.com/docker/docker/pkg/progress"
 	"github.com/docker/docker/registry"
 	"github.com/docker/docker/registry"
 	"github.com/docker/docker/tag"
 	"github.com/docker/docker/tag"
+	"golang.org/x/net/context"
 )
 )
 
 
 // ImagePullConfig stores pull configuration.
 // ImagePullConfig stores pull configuration.
@@ -25,9 +26,9 @@ type ImagePullConfig struct {
 	// AuthConfig holds authentication credentials for authenticating with
 	// AuthConfig holds authentication credentials for authenticating with
 	// the registry.
 	// the registry.
 	AuthConfig *cliconfig.AuthConfig
 	AuthConfig *cliconfig.AuthConfig
-	// OutStream is the output writer for showing the status of the pull
+	// ProgressOutput is the interface for showing the status of the pull
 	// operation.
 	// operation.
-	OutStream io.Writer
+	ProgressOutput progress.Output
 	// RegistryService is the registry service to use for TLS configuration
 	// RegistryService is the registry service to use for TLS configuration
 	// and endpoint lookup.
 	// and endpoint lookup.
 	RegistryService *registry.Service
 	RegistryService *registry.Service
@@ -36,14 +37,12 @@ type ImagePullConfig struct {
 	// MetadataStore is the storage backend for distribution-specific
 	// MetadataStore is the storage backend for distribution-specific
 	// metadata.
 	// metadata.
 	MetadataStore metadata.Store
 	MetadataStore metadata.Store
-	// LayerStore manages layers.
-	LayerStore layer.Store
 	// ImageStore manages images.
 	// ImageStore manages images.
 	ImageStore image.Store
 	ImageStore image.Store
 	// TagStore manages tags.
 	// TagStore manages tags.
 	TagStore tag.Store
 	TagStore tag.Store
-	// Pool manages concurrent pulls.
-	Pool *Pool
+	// DownloadManager manages concurrent pulls.
+	DownloadManager *xfer.LayerDownloadManager
 }
 }
 
 
 // Puller is an interface that abstracts pulling for different API versions.
 // Puller is an interface that abstracts pulling for different API versions.
@@ -51,7 +50,7 @@ type Puller interface {
 	// Pull tries to pull the image referenced by `tag`
 	// Pull tries to pull the image referenced by `tag`
 	// Pull returns an error if any, as well as a boolean that determines whether to retry Pull on the next configured endpoint.
 	// Pull returns an error if any, as well as a boolean that determines whether to retry Pull on the next configured endpoint.
 	//
 	//
-	Pull(ref reference.Named) (fallback bool, err error)
+	Pull(ctx context.Context, ref reference.Named) (fallback bool, err error)
 }
 }
 
 
 // newPuller returns a Puller interface that will pull from either a v1 or v2
 // newPuller returns a Puller interface that will pull from either a v1 or v2
@@ -59,14 +58,13 @@ type Puller interface {
 // whether a v1 or v2 puller will be created. The other parameters are passed
 // whether a v1 or v2 puller will be created. The other parameters are passed
 // through to the underlying puller implementation for use during the actual
 // through to the underlying puller implementation for use during the actual
 // pull operation.
 // pull operation.
-func newPuller(endpoint registry.APIEndpoint, repoInfo *registry.RepositoryInfo, imagePullConfig *ImagePullConfig, sf *streamformatter.StreamFormatter) (Puller, error) {
+func newPuller(endpoint registry.APIEndpoint, repoInfo *registry.RepositoryInfo, imagePullConfig *ImagePullConfig) (Puller, error) {
 	switch endpoint.Version {
 	switch endpoint.Version {
 	case registry.APIVersion2:
 	case registry.APIVersion2:
 		return &v2Puller{
 		return &v2Puller{
 			blobSumService: metadata.NewBlobSumService(imagePullConfig.MetadataStore),
 			blobSumService: metadata.NewBlobSumService(imagePullConfig.MetadataStore),
 			endpoint:       endpoint,
 			endpoint:       endpoint,
 			config:         imagePullConfig,
 			config:         imagePullConfig,
-			sf:             sf,
 			repoInfo:       repoInfo,
 			repoInfo:       repoInfo,
 		}, nil
 		}, nil
 	case registry.APIVersion1:
 	case registry.APIVersion1:
@@ -74,7 +72,6 @@ func newPuller(endpoint registry.APIEndpoint, repoInfo *registry.RepositoryInfo,
 			v1IDService: metadata.NewV1IDService(imagePullConfig.MetadataStore),
 			v1IDService: metadata.NewV1IDService(imagePullConfig.MetadataStore),
 			endpoint:    endpoint,
 			endpoint:    endpoint,
 			config:      imagePullConfig,
 			config:      imagePullConfig,
-			sf:          sf,
 			repoInfo:    repoInfo,
 			repoInfo:    repoInfo,
 		}, nil
 		}, nil
 	}
 	}
@@ -83,9 +80,7 @@ func newPuller(endpoint registry.APIEndpoint, repoInfo *registry.RepositoryInfo,
 
 
 // Pull initiates a pull operation. image is the repository name to pull, and
 // Pull initiates a pull operation. image is the repository name to pull, and
 // tag may be either empty, or indicate a specific tag to pull.
 // tag may be either empty, or indicate a specific tag to pull.
-func Pull(ref reference.Named, imagePullConfig *ImagePullConfig) error {
-	var sf = streamformatter.NewJSONStreamFormatter()
-
+func Pull(ctx context.Context, ref reference.Named, imagePullConfig *ImagePullConfig) error {
 	// Resolve the Repository name from fqn to RepositoryInfo
 	// Resolve the Repository name from fqn to RepositoryInfo
 	repoInfo, err := imagePullConfig.RegistryService.ResolveRepository(ref)
 	repoInfo, err := imagePullConfig.RegistryService.ResolveRepository(ref)
 	if err != nil {
 	if err != nil {
@@ -120,12 +115,19 @@ func Pull(ref reference.Named, imagePullConfig *ImagePullConfig) error {
 	for _, endpoint := range endpoints {
 	for _, endpoint := range endpoints {
 		logrus.Debugf("Trying to pull %s from %s %s", repoInfo.LocalName, endpoint.URL, endpoint.Version)
 		logrus.Debugf("Trying to pull %s from %s %s", repoInfo.LocalName, endpoint.URL, endpoint.Version)
 
 
-		puller, err := newPuller(endpoint, repoInfo, imagePullConfig, sf)
+		puller, err := newPuller(endpoint, repoInfo, imagePullConfig)
 		if err != nil {
 		if err != nil {
 			errors = append(errors, err.Error())
 			errors = append(errors, err.Error())
 			continue
 			continue
 		}
 		}
-		if fallback, err := puller.Pull(ref); err != nil {
+		if fallback, err := puller.Pull(ctx, ref); err != nil {
+			// Was this pull cancelled? If so, don't try to fall
+			// back.
+			select {
+			case <-ctx.Done():
+				fallback = false
+			default:
+			}
 			if fallback {
 			if fallback {
 				if _, ok := err.(registry.ErrNoSupport); !ok {
 				if _, ok := err.(registry.ErrNoSupport); !ok {
 					// Because we found an error that's not ErrNoSupport, discard all subsequent ErrNoSupport errors.
 					// Because we found an error that's not ErrNoSupport, discard all subsequent ErrNoSupport errors.
@@ -165,11 +167,11 @@ func Pull(ref reference.Named, imagePullConfig *ImagePullConfig) error {
 // status message indicates that a newer image was downloaded. Otherwise, it
 // status message indicates that a newer image was downloaded. Otherwise, it
 // indicates that the image is up to date. requestedTag is the tag the message
 // indicates that the image is up to date. requestedTag is the tag the message
 // will refer to.
 // will refer to.
-func writeStatus(requestedTag string, out io.Writer, sf *streamformatter.StreamFormatter, layersDownloaded bool) {
+func writeStatus(requestedTag string, out progress.Output, layersDownloaded bool) {
 	if layersDownloaded {
 	if layersDownloaded {
-		out.Write(sf.FormatStatus("", "Status: Downloaded newer image for %s", requestedTag))
+		progress.Message(out, "", "Status: Downloaded newer image for "+requestedTag)
 	} else {
 	} else {
-		out.Write(sf.FormatStatus("", "Status: Image is up to date for %s", requestedTag))
+		progress.Message(out, "", "Status: Image is up to date for "+requestedTag)
 	}
 	}
 }
 }
 
 
@@ -183,3 +185,16 @@ func validateRepoName(name string) error {
 	}
 	}
 	return nil
 	return nil
 }
 }
+
+// tmpFileClose creates a closer function for a temporary file that closes the file
+// and also deletes it.
+func tmpFileCloser(tmpFile *os.File) func() error {
+	return func() error {
+		tmpFile.Close()
+		if err := os.RemoveAll(tmpFile.Name()); err != nil {
+			logrus.Errorf("Failed to remove temp file: %s", tmpFile.Name())
+		}
+
+		return nil
+	}
+}

+ 117 - 228
distribution/pull_v1.go

@@ -1,43 +1,42 @@
 package distribution
 package distribution
 
 
 import (
 import (
-	"encoding/json"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
+	"io/ioutil"
 	"net"
 	"net"
 	"net/url"
 	"net/url"
 	"strings"
 	"strings"
-	"sync"
 	"time"
 	"time"
 
 
 	"github.com/Sirupsen/logrus"
 	"github.com/Sirupsen/logrus"
 	"github.com/docker/distribution/reference"
 	"github.com/docker/distribution/reference"
 	"github.com/docker/distribution/registry/client/transport"
 	"github.com/docker/distribution/registry/client/transport"
 	"github.com/docker/docker/distribution/metadata"
 	"github.com/docker/docker/distribution/metadata"
+	"github.com/docker/docker/distribution/xfer"
 	"github.com/docker/docker/image"
 	"github.com/docker/docker/image"
 	"github.com/docker/docker/image/v1"
 	"github.com/docker/docker/image/v1"
 	"github.com/docker/docker/layer"
 	"github.com/docker/docker/layer"
-	"github.com/docker/docker/pkg/archive"
-	"github.com/docker/docker/pkg/progressreader"
-	"github.com/docker/docker/pkg/streamformatter"
+	"github.com/docker/docker/pkg/ioutils"
+	"github.com/docker/docker/pkg/progress"
 	"github.com/docker/docker/pkg/stringid"
 	"github.com/docker/docker/pkg/stringid"
 	"github.com/docker/docker/registry"
 	"github.com/docker/docker/registry"
+	"golang.org/x/net/context"
 )
 )
 
 
 type v1Puller struct {
 type v1Puller struct {
 	v1IDService *metadata.V1IDService
 	v1IDService *metadata.V1IDService
 	endpoint    registry.APIEndpoint
 	endpoint    registry.APIEndpoint
 	config      *ImagePullConfig
 	config      *ImagePullConfig
-	sf          *streamformatter.StreamFormatter
 	repoInfo    *registry.RepositoryInfo
 	repoInfo    *registry.RepositoryInfo
 	session     *registry.Session
 	session     *registry.Session
 }
 }
 
 
-func (p *v1Puller) Pull(ref reference.Named) (fallback bool, err error) {
+func (p *v1Puller) Pull(ctx context.Context, ref reference.Named) (fallback bool, err error) {
 	if _, isDigested := ref.(reference.Digested); isDigested {
 	if _, isDigested := ref.(reference.Digested); isDigested {
 		// Allowing fallback, because HTTPS v1 is before HTTP v2
 		// Allowing fallback, because HTTPS v1 is before HTTP v2
-		return true, registry.ErrNoSupport{errors.New("Cannot pull by digest with v1 registry")}
+		return true, registry.ErrNoSupport{Err: errors.New("Cannot pull by digest with v1 registry")}
 	}
 	}
 
 
 	tlsConfig, err := p.config.RegistryService.TLSConfig(p.repoInfo.Index.Name)
 	tlsConfig, err := p.config.RegistryService.TLSConfig(p.repoInfo.Index.Name)
@@ -62,19 +61,17 @@ func (p *v1Puller) Pull(ref reference.Named) (fallback bool, err error) {
 		logrus.Debugf("Fallback from error: %s", err)
 		logrus.Debugf("Fallback from error: %s", err)
 		return true, err
 		return true, err
 	}
 	}
-	if err := p.pullRepository(ref); err != nil {
+	if err := p.pullRepository(ctx, ref); err != nil {
 		// TODO(dmcgowan): Check if should fallback
 		// TODO(dmcgowan): Check if should fallback
 		return false, err
 		return false, err
 	}
 	}
-	out := p.config.OutStream
-	out.Write(p.sf.FormatStatus("", "%s: this image was pulled from a legacy registry.  Important: This registry version will not be supported in future versions of docker.", p.repoInfo.CanonicalName.Name()))
+	progress.Message(p.config.ProgressOutput, "", p.repoInfo.CanonicalName.Name()+": this image was pulled from a legacy registry.  Important: This registry version will not be supported in future versions of docker.")
 
 
 	return false, nil
 	return false, nil
 }
 }
 
 
-func (p *v1Puller) pullRepository(ref reference.Named) error {
-	out := p.config.OutStream
-	out.Write(p.sf.FormatStatus("", "Pulling repository %s", p.repoInfo.CanonicalName.Name()))
+func (p *v1Puller) pullRepository(ctx context.Context, ref reference.Named) error {
+	progress.Message(p.config.ProgressOutput, "", "Pulling repository "+p.repoInfo.CanonicalName.Name())
 
 
 	repoData, err := p.session.GetRepositoryData(p.repoInfo.RemoteName)
 	repoData, err := p.session.GetRepositoryData(p.repoInfo.RemoteName)
 	if err != nil {
 	if err != nil {
@@ -112,46 +109,18 @@ func (p *v1Puller) pullRepository(ref reference.Named) error {
 		}
 		}
 	}
 	}
 
 
-	errors := make(chan error)
-	layerDownloaded := make(chan struct{})
-
 	layersDownloaded := false
 	layersDownloaded := false
-	var wg sync.WaitGroup
 	for _, imgData := range repoData.ImgList {
 	for _, imgData := range repoData.ImgList {
 		if isTagged && imgData.Tag != tagged.Tag() {
 		if isTagged && imgData.Tag != tagged.Tag() {
 			continue
 			continue
 		}
 		}
 
 
-		wg.Add(1)
-		go func(img *registry.ImgData) {
-			p.downloadImage(out, repoData, img, layerDownloaded, errors)
-			wg.Done()
-		}(imgData)
-	}
-
-	go func() {
-		wg.Wait()
-		close(errors)
-	}()
-
-	var lastError error
-selectLoop:
-	for {
-		select {
-		case err, ok := <-errors:
-			if !ok {
-				break selectLoop
-			}
-			lastError = err
-		case <-layerDownloaded:
-			layersDownloaded = true
+		err := p.downloadImage(ctx, repoData, imgData, &layersDownloaded)
+		if err != nil {
+			return err
 		}
 		}
 	}
 	}
 
 
-	if lastError != nil {
-		return lastError
-	}
-
 	localNameRef := p.repoInfo.LocalName
 	localNameRef := p.repoInfo.LocalName
 	if isTagged {
 	if isTagged {
 		localNameRef, err = reference.WithTag(localNameRef, tagged.Tag())
 		localNameRef, err = reference.WithTag(localNameRef, tagged.Tag())
@@ -159,194 +128,143 @@ selectLoop:
 			localNameRef = p.repoInfo.LocalName
 			localNameRef = p.repoInfo.LocalName
 		}
 		}
 	}
 	}
-	writeStatus(localNameRef.String(), out, p.sf, layersDownloaded)
+	writeStatus(localNameRef.String(), p.config.ProgressOutput, layersDownloaded)
 	return nil
 	return nil
 }
 }
 
 
-func (p *v1Puller) downloadImage(out io.Writer, repoData *registry.RepositoryData, img *registry.ImgData, layerDownloaded chan struct{}, errors chan error) {
+func (p *v1Puller) downloadImage(ctx context.Context, repoData *registry.RepositoryData, img *registry.ImgData, layersDownloaded *bool) error {
 	if img.Tag == "" {
 	if img.Tag == "" {
 		logrus.Debugf("Image (id: %s) present in this repository but untagged, skipping", img.ID)
 		logrus.Debugf("Image (id: %s) present in this repository but untagged, skipping", img.ID)
-		return
+		return nil
 	}
 	}
 
 
 	localNameRef, err := reference.WithTag(p.repoInfo.LocalName, img.Tag)
 	localNameRef, err := reference.WithTag(p.repoInfo.LocalName, img.Tag)
 	if err != nil {
 	if err != nil {
 		retErr := fmt.Errorf("Image (id: %s) has invalid tag: %s", img.ID, img.Tag)
 		retErr := fmt.Errorf("Image (id: %s) has invalid tag: %s", img.ID, img.Tag)
 		logrus.Debug(retErr.Error())
 		logrus.Debug(retErr.Error())
-		errors <- retErr
+		return retErr
 	}
 	}
 
 
 	if err := v1.ValidateID(img.ID); err != nil {
 	if err := v1.ValidateID(img.ID); err != nil {
-		errors <- err
-		return
+		return err
 	}
 	}
 
 
-	out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s", img.Tag, p.repoInfo.CanonicalName.Name()), nil))
+	progress.Updatef(p.config.ProgressOutput, stringid.TruncateID(img.ID), "Pulling image (%s) from %s", img.Tag, p.repoInfo.CanonicalName.Name())
 	success := false
 	success := false
 	var lastErr error
 	var lastErr error
-	var isDownloaded bool
 	for _, ep := range p.repoInfo.Index.Mirrors {
 	for _, ep := range p.repoInfo.Index.Mirrors {
 		ep += "v1/"
 		ep += "v1/"
-		out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s, mirror: %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep), nil))
-		if isDownloaded, err = p.pullImage(out, img.ID, ep, localNameRef); err != nil {
+		progress.Updatef(p.config.ProgressOutput, stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s, mirror: %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep))
+		if err = p.pullImage(ctx, img.ID, ep, localNameRef, layersDownloaded); err != nil {
 			// Don't report errors when pulling from mirrors.
 			// Don't report errors when pulling from mirrors.
 			logrus.Debugf("Error pulling image (%s) from %s, mirror: %s, %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep, err)
 			logrus.Debugf("Error pulling image (%s) from %s, mirror: %s, %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep, err)
 			continue
 			continue
 		}
 		}
-		if isDownloaded {
-			layerDownloaded <- struct{}{}
-		}
 		success = true
 		success = true
 		break
 		break
 	}
 	}
 	if !success {
 	if !success {
 		for _, ep := range repoData.Endpoints {
 		for _, ep := range repoData.Endpoints {
-			out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Pulling image (%s) from %s, endpoint: %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep), nil))
-			if isDownloaded, err = p.pullImage(out, img.ID, ep, localNameRef); err != nil {
+			progress.Updatef(p.config.ProgressOutput, stringid.TruncateID(img.ID), "Pulling image (%s) from %s, endpoint: %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep)
+			if err = p.pullImage(ctx, img.ID, ep, localNameRef, layersDownloaded); err != nil {
 				// It's not ideal that only the last error is returned, it would be better to concatenate the errors.
 				// It's not ideal that only the last error is returned, it would be better to concatenate the errors.
 				// As the error is also given to the output stream the user will see the error.
 				// As the error is also given to the output stream the user will see the error.
 				lastErr = err
 				lastErr = err
-				out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), fmt.Sprintf("Error pulling image (%s) from %s, endpoint: %s, %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep, err), nil))
+				progress.Updatef(p.config.ProgressOutput, stringid.TruncateID(img.ID), "Error pulling image (%s) from %s, endpoint: %s, %s", img.Tag, p.repoInfo.CanonicalName.Name(), ep, err)
 				continue
 				continue
 			}
 			}
-			if isDownloaded {
-				layerDownloaded <- struct{}{}
-			}
 			success = true
 			success = true
 			break
 			break
 		}
 		}
 	}
 	}
 	if !success {
 	if !success {
 		err := fmt.Errorf("Error pulling image (%s) from %s, %v", img.Tag, p.repoInfo.CanonicalName.Name(), lastErr)
 		err := fmt.Errorf("Error pulling image (%s) from %s, %v", img.Tag, p.repoInfo.CanonicalName.Name(), lastErr)
-		out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), err.Error(), nil))
-		errors <- err
-		return
+		progress.Update(p.config.ProgressOutput, stringid.TruncateID(img.ID), err.Error())
+		return err
 	}
 	}
-	out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), "Download complete", nil))
+	progress.Update(p.config.ProgressOutput, stringid.TruncateID(img.ID), "Download complete")
+	return nil
 }
 }
 
 
-func (p *v1Puller) pullImage(out io.Writer, v1ID, endpoint string, localNameRef reference.Named) (layersDownloaded bool, err error) {
+func (p *v1Puller) pullImage(ctx context.Context, v1ID, endpoint string, localNameRef reference.Named, layersDownloaded *bool) (err error) {
 	var history []string
 	var history []string
 	history, err = p.session.GetRemoteHistory(v1ID, endpoint)
 	history, err = p.session.GetRemoteHistory(v1ID, endpoint)
 	if err != nil {
 	if err != nil {
-		return false, err
+		return err
 	}
 	}
 	if len(history) < 1 {
 	if len(history) < 1 {
-		return false, fmt.Errorf("empty history for image %s", v1ID)
+		return fmt.Errorf("empty history for image %s", v1ID)
 	}
 	}
-	out.Write(p.sf.FormatProgress(stringid.TruncateID(v1ID), "Pulling dependent layers", nil))
-	// FIXME: Try to stream the images?
-	// FIXME: Launch the getRemoteImage() in goroutines
+	progress.Update(p.config.ProgressOutput, stringid.TruncateID(v1ID), "Pulling dependent layers")
 
 
 	var (
 	var (
-		referencedLayers []layer.Layer
-		parentID         layer.ChainID
-		newHistory       []image.History
-		img              *image.V1Image
-		imgJSON          []byte
-		imgSize          int64
+		descriptors []xfer.DownloadDescriptor
+		newHistory  []image.History
+		imgJSON     []byte
+		imgSize     int64
 	)
 	)
 
 
-	defer func() {
-		for _, l := range referencedLayers {
-			layer.ReleaseAndLog(p.config.LayerStore, l)
-		}
-	}()
-
-	layersDownloaded = false
-
-	// Iterate over layers from top-most to bottom-most, checking if any
-	// already exist on disk.
-	var i int
-	for i = 0; i != len(history); i++ {
-		v1LayerID := history[i]
-		// Do we have a mapping for this particular v1 ID on this
-		// registry?
-		if layerID, err := p.v1IDService.Get(v1LayerID, p.repoInfo.Index.Name); err == nil {
-			// Does the layer actually exist
-			if l, err := p.config.LayerStore.Get(layerID); err == nil {
-				for j := i; j >= 0; j-- {
-					logrus.Debugf("Layer already exists: %s", history[j])
-					out.Write(p.sf.FormatProgress(stringid.TruncateID(history[j]), "Already exists", nil))
-				}
-				referencedLayers = append(referencedLayers, l)
-				parentID = layerID
-				break
-			}
-		}
-	}
-
-	needsDownload := i
-
 	// Iterate over layers, in order from bottom-most to top-most. Download
 	// Iterate over layers, in order from bottom-most to top-most. Download
-	// config for all layers, and download actual layer data if needed.
-	for i = len(history) - 1; i >= 0; i-- {
+	// config for all layers and create descriptors.
+	for i := len(history) - 1; i >= 0; i-- {
 		v1LayerID := history[i]
 		v1LayerID := history[i]
-		imgJSON, imgSize, err = p.downloadLayerConfig(out, v1LayerID, endpoint)
+		imgJSON, imgSize, err = p.downloadLayerConfig(v1LayerID, endpoint)
 		if err != nil {
 		if err != nil {
-			return layersDownloaded, err
-		}
-
-		img = &image.V1Image{}
-		if err := json.Unmarshal(imgJSON, img); err != nil {
-			return layersDownloaded, err
-		}
-
-		if i < needsDownload {
-			l, err := p.downloadLayer(out, v1LayerID, endpoint, parentID, imgSize, &layersDownloaded)
-
-			// Note: This needs to be done even in the error case to avoid
-			// stale references to the layer.
-			if l != nil {
-				referencedLayers = append(referencedLayers, l)
-			}
-			if err != nil {
-				return layersDownloaded, err
-			}
-
-			parentID = l.ChainID()
+			return err
 		}
 		}
 
 
 		// Create a new-style config from the legacy configs
 		// Create a new-style config from the legacy configs
 		h, err := v1.HistoryFromConfig(imgJSON, false)
 		h, err := v1.HistoryFromConfig(imgJSON, false)
 		if err != nil {
 		if err != nil {
-			return layersDownloaded, err
+			return err
 		}
 		}
 		newHistory = append(newHistory, h)
 		newHistory = append(newHistory, h)
+
+		layerDescriptor := &v1LayerDescriptor{
+			v1LayerID:        v1LayerID,
+			indexName:        p.repoInfo.Index.Name,
+			endpoint:         endpoint,
+			v1IDService:      p.v1IDService,
+			layersDownloaded: layersDownloaded,
+			layerSize:        imgSize,
+			session:          p.session,
+		}
+
+		descriptors = append(descriptors, layerDescriptor)
 	}
 	}
 
 
 	rootFS := image.NewRootFS()
 	rootFS := image.NewRootFS()
-	l := referencedLayers[len(referencedLayers)-1]
-	for l != nil {
-		rootFS.DiffIDs = append([]layer.DiffID{l.DiffID()}, rootFS.DiffIDs...)
-		l = l.Parent()
+	resultRootFS, release, err := p.config.DownloadManager.Download(ctx, *rootFS, descriptors, p.config.ProgressOutput)
+	if err != nil {
+		return err
 	}
 	}
+	defer release()
 
 
-	config, err := v1.MakeConfigFromV1Config(imgJSON, rootFS, newHistory)
+	config, err := v1.MakeConfigFromV1Config(imgJSON, &resultRootFS, newHistory)
 	if err != nil {
 	if err != nil {
-		return layersDownloaded, err
+		return err
 	}
 	}
 
 
 	imageID, err := p.config.ImageStore.Create(config)
 	imageID, err := p.config.ImageStore.Create(config)
 	if err != nil {
 	if err != nil {
-		return layersDownloaded, err
+		return err
 	}
 	}
 
 
 	if err := p.config.TagStore.AddTag(localNameRef, imageID, true); err != nil {
 	if err := p.config.TagStore.AddTag(localNameRef, imageID, true); err != nil {
-		return layersDownloaded, err
+		return err
 	}
 	}
 
 
-	return layersDownloaded, nil
+	return nil
 }
 }
 
 
-func (p *v1Puller) downloadLayerConfig(out io.Writer, v1LayerID, endpoint string) (imgJSON []byte, imgSize int64, err error) {
-	out.Write(p.sf.FormatProgress(stringid.TruncateID(v1LayerID), "Pulling metadata", nil))
+func (p *v1Puller) downloadLayerConfig(v1LayerID, endpoint string) (imgJSON []byte, imgSize int64, err error) {
+	progress.Update(p.config.ProgressOutput, stringid.TruncateID(v1LayerID), "Pulling metadata")
 
 
 	retries := 5
 	retries := 5
 	for j := 1; j <= retries; j++ {
 	for j := 1; j <= retries; j++ {
 		imgJSON, imgSize, err := p.session.GetRemoteImageJSON(v1LayerID, endpoint)
 		imgJSON, imgSize, err := p.session.GetRemoteImageJSON(v1LayerID, endpoint)
 		if err != nil && j == retries {
 		if err != nil && j == retries {
-			out.Write(p.sf.FormatProgress(stringid.TruncateID(v1LayerID), "Error pulling layer metadata", nil))
+			progress.Update(p.config.ProgressOutput, stringid.TruncateID(v1LayerID), "Error pulling layer metadata")
 			return nil, 0, err
 			return nil, 0, err
 		} else if err != nil {
 		} else if err != nil {
 			time.Sleep(time.Duration(j) * 500 * time.Millisecond)
 			time.Sleep(time.Duration(j) * 500 * time.Millisecond)
@@ -360,95 +278,66 @@ func (p *v1Puller) downloadLayerConfig(out io.Writer, v1LayerID, endpoint string
 	return nil, 0, nil
 	return nil, 0, nil
 }
 }
 
 
-func (p *v1Puller) downloadLayer(out io.Writer, v1LayerID, endpoint string, parentID layer.ChainID, layerSize int64, layersDownloaded *bool) (l layer.Layer, err error) {
-	// ensure no two downloads of the same layer happen at the same time
-	poolKey := "layer:" + v1LayerID
-	broadcaster, found := p.config.Pool.add(poolKey)
-	broadcaster.Add(out)
-	if found {
-		logrus.Debugf("Image (id: %s) pull is already running, skipping", v1LayerID)
-		if err = broadcaster.Wait(); err != nil {
-			return nil, err
-		}
-		layerID, err := p.v1IDService.Get(v1LayerID, p.repoInfo.Index.Name)
-		if err != nil {
-			return nil, err
-		}
-		// Does the layer actually exist
-		l, err := p.config.LayerStore.Get(layerID)
-		if err != nil {
-			return nil, err
-		}
-		return l, nil
-	}
+type v1LayerDescriptor struct {
+	v1LayerID        string
+	indexName        string
+	endpoint         string
+	v1IDService      *metadata.V1IDService
+	layersDownloaded *bool
+	layerSize        int64
+	session          *registry.Session
+}
 
 
-	// This must use a closure so it captures the value of err when
-	// the function returns, not when the 'defer' is evaluated.
-	defer func() {
-		p.config.Pool.removeWithError(poolKey, err)
-	}()
+func (ld *v1LayerDescriptor) Key() string {
+	return "v1:" + ld.v1LayerID
+}
 
 
-	retries := 5
-	for j := 1; j <= retries; j++ {
-		// Get the layer
-		status := "Pulling fs layer"
-		if j > 1 {
-			status = fmt.Sprintf("Pulling fs layer [retries: %d]", j)
-		}
-		broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(v1LayerID), status, nil))
-		layerReader, err := p.session.GetRemoteImageLayer(v1LayerID, endpoint, layerSize)
+func (ld *v1LayerDescriptor) ID() string {
+	return stringid.TruncateID(ld.v1LayerID)
+}
+
+func (ld *v1LayerDescriptor) DiffID() (layer.DiffID, error) {
+	return ld.v1IDService.Get(ld.v1LayerID, ld.indexName)
+}
+
+func (ld *v1LayerDescriptor) Download(ctx context.Context, progressOutput progress.Output) (io.ReadCloser, int64, error) {
+	progress.Update(progressOutput, ld.ID(), "Pulling fs layer")
+	layerReader, err := ld.session.GetRemoteImageLayer(ld.v1LayerID, ld.endpoint, ld.layerSize)
+	if err != nil {
+		progress.Update(progressOutput, ld.ID(), "Error pulling dependent layers")
 		if uerr, ok := err.(*url.Error); ok {
 		if uerr, ok := err.(*url.Error); ok {
 			err = uerr.Err
 			err = uerr.Err
 		}
 		}
-		if terr, ok := err.(net.Error); ok && terr.Timeout() && j < retries {
-			time.Sleep(time.Duration(j) * 500 * time.Millisecond)
-			continue
-		} else if err != nil {
-			broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(v1LayerID), "Error pulling dependent layers", nil))
-			return nil, err
-		}
-		*layersDownloaded = true
-		defer layerReader.Close()
-
-		reader := progressreader.New(progressreader.Config{
-			In:        layerReader,
-			Out:       broadcaster,
-			Formatter: p.sf,
-			Size:      layerSize,
-			NewLines:  false,
-			ID:        stringid.TruncateID(v1LayerID),
-			Action:    "Downloading",
-		})
-
-		inflatedLayerData, err := archive.DecompressStream(reader)
-		if err != nil {
-			return nil, fmt.Errorf("could not get decompression stream: %v", err)
-		}
-
-		l, err := p.config.LayerStore.Register(inflatedLayerData, parentID)
-		if err != nil {
-			return nil, fmt.Errorf("failed to register layer: %v", err)
+		if terr, ok := err.(net.Error); ok && terr.Timeout() {
+			return nil, 0, err
 		}
 		}
-		logrus.Debugf("layer %s registered successfully", l.DiffID())
+		return nil, 0, xfer.DoNotRetry{Err: err}
+	}
+	*ld.layersDownloaded = true
 
 
-		if terr, ok := err.(net.Error); ok && terr.Timeout() && j < retries {
-			time.Sleep(time.Duration(j) * 500 * time.Millisecond)
-			continue
-		} else if err != nil {
-			broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(v1LayerID), "Error downloading dependent layers", nil))
-			return nil, err
-		}
+	tmpFile, err := ioutil.TempFile("", "GetImageBlob")
+	if err != nil {
+		layerReader.Close()
+		return nil, 0, err
+	}
 
 
-		// Cache mapping from this v1 ID to content-addressable layer ID
-		if err := p.v1IDService.Set(v1LayerID, p.repoInfo.Index.Name, l.ChainID()); err != nil {
-			return nil, err
-		}
+	reader := progress.NewProgressReader(ioutils.NewCancelReadCloser(ctx, layerReader), progressOutput, ld.layerSize, ld.ID(), "Downloading")
+	defer reader.Close()
 
 
-		broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(v1LayerID), "Download complete", nil))
-		broadcaster.Close()
-		return l, nil
+	_, err = io.Copy(tmpFile, reader)
+	if err != nil {
+		return nil, 0, err
 	}
 	}
 
 
-	// not reached
-	return nil, nil
+	progress.Update(progressOutput, ld.ID(), "Download complete")
+
+	logrus.Debugf("Downloaded %s to tempfile %s", ld.ID(), tmpFile.Name())
+
+	tmpFile.Seek(0, 0)
+	return ioutils.NewReadCloserWrapper(tmpFile, tmpFileCloser(tmpFile)), ld.layerSize, nil
+}
+
+func (ld *v1LayerDescriptor) Registered(diffID layer.DiffID) {
+	// Cache mapping from this layer's DiffID to the blobsum
+	ld.v1IDService.Set(ld.v1LayerID, ld.indexName, diffID)
 }
 }

+ 92 - 205
distribution/pull_v2.go

@@ -15,13 +15,12 @@ import (
 	"github.com/docker/distribution/manifest/schema1"
 	"github.com/docker/distribution/manifest/schema1"
 	"github.com/docker/distribution/reference"
 	"github.com/docker/distribution/reference"
 	"github.com/docker/docker/distribution/metadata"
 	"github.com/docker/docker/distribution/metadata"
+	"github.com/docker/docker/distribution/xfer"
 	"github.com/docker/docker/image"
 	"github.com/docker/docker/image"
 	"github.com/docker/docker/image/v1"
 	"github.com/docker/docker/image/v1"
 	"github.com/docker/docker/layer"
 	"github.com/docker/docker/layer"
-	"github.com/docker/docker/pkg/archive"
-	"github.com/docker/docker/pkg/broadcaster"
-	"github.com/docker/docker/pkg/progressreader"
-	"github.com/docker/docker/pkg/streamformatter"
+	"github.com/docker/docker/pkg/ioutils"
+	"github.com/docker/docker/pkg/progress"
 	"github.com/docker/docker/pkg/stringid"
 	"github.com/docker/docker/pkg/stringid"
 	"github.com/docker/docker/registry"
 	"github.com/docker/docker/registry"
 	"golang.org/x/net/context"
 	"golang.org/x/net/context"
@@ -31,23 +30,19 @@ type v2Puller struct {
 	blobSumService *metadata.BlobSumService
 	blobSumService *metadata.BlobSumService
 	endpoint       registry.APIEndpoint
 	endpoint       registry.APIEndpoint
 	config         *ImagePullConfig
 	config         *ImagePullConfig
-	sf             *streamformatter.StreamFormatter
 	repoInfo       *registry.RepositoryInfo
 	repoInfo       *registry.RepositoryInfo
 	repo           distribution.Repository
 	repo           distribution.Repository
-	sessionID      string
 }
 }
 
 
-func (p *v2Puller) Pull(ref reference.Named) (fallback bool, err error) {
+func (p *v2Puller) Pull(ctx context.Context, ref reference.Named) (fallback bool, err error) {
 	// TODO(tiborvass): was ReceiveTimeout
 	// TODO(tiborvass): was ReceiveTimeout
 	p.repo, err = NewV2Repository(p.repoInfo, p.endpoint, p.config.MetaHeaders, p.config.AuthConfig, "pull")
 	p.repo, err = NewV2Repository(p.repoInfo, p.endpoint, p.config.MetaHeaders, p.config.AuthConfig, "pull")
 	if err != nil {
 	if err != nil {
-		logrus.Debugf("Error getting v2 registry: %v", err)
+		logrus.Warnf("Error getting v2 registry: %v", err)
 		return true, err
 		return true, err
 	}
 	}
 
 
-	p.sessionID = stringid.GenerateRandomID()
-
-	if err := p.pullV2Repository(ref); err != nil {
+	if err := p.pullV2Repository(ctx, ref); err != nil {
 		if registry.ContinueOnError(err) {
 		if registry.ContinueOnError(err) {
 			logrus.Debugf("Error trying v2 registry: %v", err)
 			logrus.Debugf("Error trying v2 registry: %v", err)
 			return true, err
 			return true, err
@@ -57,7 +52,7 @@ func (p *v2Puller) Pull(ref reference.Named) (fallback bool, err error) {
 	return false, nil
 	return false, nil
 }
 }
 
 
-func (p *v2Puller) pullV2Repository(ref reference.Named) (err error) {
+func (p *v2Puller) pullV2Repository(ctx context.Context, ref reference.Named) (err error) {
 	var refs []reference.Named
 	var refs []reference.Named
 	taggedName := p.repoInfo.LocalName
 	taggedName := p.repoInfo.LocalName
 	if tagged, isTagged := ref.(reference.Tagged); isTagged {
 	if tagged, isTagged := ref.(reference.Tagged); isTagged {
@@ -73,7 +68,7 @@ func (p *v2Puller) pullV2Repository(ref reference.Named) (err error) {
 		}
 		}
 		refs = []reference.Named{taggedName}
 		refs = []reference.Named{taggedName}
 	} else {
 	} else {
-		manSvc, err := p.repo.Manifests(context.Background())
+		manSvc, err := p.repo.Manifests(ctx)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
@@ -98,98 +93,109 @@ func (p *v2Puller) pullV2Repository(ref reference.Named) (err error) {
 	for _, pullRef := range refs {
 	for _, pullRef := range refs {
 		// pulledNew is true if either new layers were downloaded OR if existing images were newly tagged
 		// pulledNew is true if either new layers were downloaded OR if existing images were newly tagged
 		// TODO(tiborvass): should we change the name of `layersDownload`? What about message in WriteStatus?
 		// TODO(tiborvass): should we change the name of `layersDownload`? What about message in WriteStatus?
-		pulledNew, err := p.pullV2Tag(p.config.OutStream, pullRef)
+		pulledNew, err := p.pullV2Tag(ctx, pullRef)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
 		layersDownloaded = layersDownloaded || pulledNew
 		layersDownloaded = layersDownloaded || pulledNew
 	}
 	}
 
 
-	writeStatus(taggedName.String(), p.config.OutStream, p.sf, layersDownloaded)
+	writeStatus(taggedName.String(), p.config.ProgressOutput, layersDownloaded)
 
 
 	return nil
 	return nil
 }
 }
 
 
-// downloadInfo is used to pass information from download to extractor
-type downloadInfo struct {
-	tmpFile     *os.File
-	digest      digest.Digest
-	layer       distribution.ReadSeekCloser
-	size        int64
-	err         chan error
-	poolKey     string
-	broadcaster *broadcaster.Buffered
+type v2LayerDescriptor struct {
+	digest         digest.Digest
+	repo           distribution.Repository
+	blobSumService *metadata.BlobSumService
 }
 }
 
 
-type errVerification struct{}
+func (ld *v2LayerDescriptor) Key() string {
+	return "v2:" + ld.digest.String()
+}
 
 
-func (errVerification) Error() string { return "verification failed" }
+func (ld *v2LayerDescriptor) ID() string {
+	return stringid.TruncateID(ld.digest.String())
+}
 
 
-func (p *v2Puller) download(di *downloadInfo) {
-	logrus.Debugf("pulling blob %q", di.digest)
+func (ld *v2LayerDescriptor) DiffID() (layer.DiffID, error) {
+	return ld.blobSumService.GetDiffID(ld.digest)
+}
+
+func (ld *v2LayerDescriptor) Download(ctx context.Context, progressOutput progress.Output) (io.ReadCloser, int64, error) {
+	logrus.Debugf("pulling blob %q", ld.digest)
 
 
-	blobs := p.repo.Blobs(context.Background())
+	blobs := ld.repo.Blobs(ctx)
 
 
-	layerDownload, err := blobs.Open(context.Background(), di.digest)
+	layerDownload, err := blobs.Open(ctx, ld.digest)
 	if err != nil {
 	if err != nil {
-		logrus.Debugf("Error fetching layer: %v", err)
-		di.err <- err
-		return
+		logrus.Debugf("Error statting layer: %v", err)
+		if err == distribution.ErrBlobUnknown {
+			return nil, 0, xfer.DoNotRetry{Err: err}
+		}
+		return nil, 0, retryOnError(err)
 	}
 	}
-	defer layerDownload.Close()
 
 
-	di.size, err = layerDownload.Seek(0, os.SEEK_END)
+	size, err := layerDownload.Seek(0, os.SEEK_END)
 	if err != nil {
 	if err != nil {
 		// Seek failed, perhaps because there was no Content-Length
 		// Seek failed, perhaps because there was no Content-Length
 		// header. This shouldn't fail the download, because we can
 		// header. This shouldn't fail the download, because we can
 		// still continue without a progress bar.
 		// still continue without a progress bar.
-		di.size = 0
+		size = 0
 	} else {
 	} else {
 		// Restore the seek offset at the beginning of the stream.
 		// Restore the seek offset at the beginning of the stream.
 		_, err = layerDownload.Seek(0, os.SEEK_SET)
 		_, err = layerDownload.Seek(0, os.SEEK_SET)
 		if err != nil {
 		if err != nil {
-			di.err <- err
-			return
+			return nil, 0, err
 		}
 		}
 	}
 	}
 
 
-	verifier, err := digest.NewDigestVerifier(di.digest)
+	reader := progress.NewProgressReader(ioutils.NewCancelReadCloser(ctx, layerDownload), progressOutput, size, ld.ID(), "Downloading")
+	defer reader.Close()
+
+	verifier, err := digest.NewDigestVerifier(ld.digest)
 	if err != nil {
 	if err != nil {
-		di.err <- err
-		return
+		return nil, 0, xfer.DoNotRetry{Err: err}
 	}
 	}
 
 
-	digestStr := di.digest.String()
+	tmpFile, err := ioutil.TempFile("", "GetImageBlob")
+	if err != nil {
+		return nil, 0, xfer.DoNotRetry{Err: err}
+	}
 
 
-	reader := progressreader.New(progressreader.Config{
-		In:        ioutil.NopCloser(io.TeeReader(layerDownload, verifier)),
-		Out:       di.broadcaster,
-		Formatter: p.sf,
-		Size:      di.size,
-		NewLines:  false,
-		ID:        stringid.TruncateID(digestStr),
-		Action:    "Downloading",
-	})
-	io.Copy(di.tmpFile, reader)
+	_, err = io.Copy(tmpFile, io.TeeReader(reader, verifier))
+	if err != nil {
+		return nil, 0, retryOnError(err)
+	}
 
 
-	di.broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(digestStr), "Verifying Checksum", nil))
+	progress.Update(progressOutput, ld.ID(), "Verifying Checksum")
 
 
 	if !verifier.Verified() {
 	if !verifier.Verified() {
-		err = fmt.Errorf("filesystem layer verification failed for digest %s", di.digest)
+		err = fmt.Errorf("filesystem layer verification failed for digest %s", ld.digest)
 		logrus.Error(err)
 		logrus.Error(err)
-		di.err <- err
-		return
+		tmpFile.Close()
+		if err := os.RemoveAll(tmpFile.Name()); err != nil {
+			logrus.Errorf("Failed to remove temp file: %s", tmpFile.Name())
+		}
+
+		return nil, 0, xfer.DoNotRetry{Err: err}
 	}
 	}
 
 
-	di.broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(digestStr), "Download complete", nil))
+	progress.Update(progressOutput, ld.ID(), "Download complete")
 
 
-	logrus.Debugf("Downloaded %s to tempfile %s", digestStr, di.tmpFile.Name())
-	di.layer = layerDownload
+	logrus.Debugf("Downloaded %s to tempfile %s", ld.ID(), tmpFile.Name())
+
+	tmpFile.Seek(0, 0)
+	return ioutils.NewReadCloserWrapper(tmpFile, tmpFileCloser(tmpFile)), size, nil
+}
 
 
-	di.err <- nil
+func (ld *v2LayerDescriptor) Registered(diffID layer.DiffID) {
+	// Cache mapping from this layer's DiffID to the blobsum
+	ld.blobSumService.Add(diffID, ld.digest)
 }
 }
 
 
-func (p *v2Puller) pullV2Tag(out io.Writer, ref reference.Named) (tagUpdated bool, err error) {
+func (p *v2Puller) pullV2Tag(ctx context.Context, ref reference.Named) (tagUpdated bool, err error) {
 	tagOrDigest := ""
 	tagOrDigest := ""
 	if tagged, isTagged := ref.(reference.Tagged); isTagged {
 	if tagged, isTagged := ref.(reference.Tagged); isTagged {
 		tagOrDigest = tagged.Tag()
 		tagOrDigest = tagged.Tag()
@@ -201,7 +207,7 @@ func (p *v2Puller) pullV2Tag(out io.Writer, ref reference.Named) (tagUpdated boo
 
 
 	logrus.Debugf("Pulling ref from V2 registry: %q", tagOrDigest)
 	logrus.Debugf("Pulling ref from V2 registry: %q", tagOrDigest)
 
 
-	manSvc, err := p.repo.Manifests(context.Background())
+	manSvc, err := p.repo.Manifests(ctx)
 	if err != nil {
 	if err != nil {
 		return false, err
 		return false, err
 	}
 	}
@@ -231,33 +237,17 @@ func (p *v2Puller) pullV2Tag(out io.Writer, ref reference.Named) (tagUpdated boo
 		return false, err
 		return false, err
 	}
 	}
 
 
-	out.Write(p.sf.FormatStatus(tagOrDigest, "Pulling from %s", p.repo.Name()))
+	progress.Message(p.config.ProgressOutput, tagOrDigest, "Pulling from "+p.repo.Name())
 
 
-	var downloads []*downloadInfo
-
-	defer func() {
-		for _, d := range downloads {
-			p.config.Pool.removeWithError(d.poolKey, err)
-			if d.tmpFile != nil {
-				d.tmpFile.Close()
-				if err := os.RemoveAll(d.tmpFile.Name()); err != nil {
-					logrus.Errorf("Failed to remove temp file: %s", d.tmpFile.Name())
-				}
-			}
-		}
-	}()
+	var descriptors []xfer.DownloadDescriptor
 
 
 	// Image history converted to the new format
 	// Image history converted to the new format
 	var history []image.History
 	var history []image.History
 
 
-	poolKey := "v2layer:"
-	notFoundLocally := false
-
 	// Note that the order of this loop is in the direction of bottom-most
 	// Note that the order of this loop is in the direction of bottom-most
 	// to top-most, so that the downloads slice gets ordered correctly.
 	// to top-most, so that the downloads slice gets ordered correctly.
 	for i := len(verifiedManifest.FSLayers) - 1; i >= 0; i-- {
 	for i := len(verifiedManifest.FSLayers) - 1; i >= 0; i-- {
 		blobSum := verifiedManifest.FSLayers[i].BlobSum
 		blobSum := verifiedManifest.FSLayers[i].BlobSum
-		poolKey += blobSum.String()
 
 
 		var throwAway struct {
 		var throwAway struct {
 			ThrowAway bool `json:"throwaway,omitempty"`
 			ThrowAway bool `json:"throwaway,omitempty"`
@@ -276,119 +266,22 @@ func (p *v2Puller) pullV2Tag(out io.Writer, ref reference.Named) (tagUpdated boo
 			continue
 			continue
 		}
 		}
 
 
-		// Do we have a layer on disk corresponding to the set of
-		// blobsums up to this point?
-		if !notFoundLocally {
-			notFoundLocally = true
-			diffID, err := p.blobSumService.GetDiffID(blobSum)
-			if err == nil {
-				rootFS.Append(diffID)
-				if l, err := p.config.LayerStore.Get(rootFS.ChainID()); err == nil {
-					notFoundLocally = false
-					logrus.Debugf("Layer already exists: %s", blobSum.String())
-					out.Write(p.sf.FormatProgress(stringid.TruncateID(blobSum.String()), "Already exists", nil))
-					defer layer.ReleaseAndLog(p.config.LayerStore, l)
-					continue
-				} else {
-					rootFS.DiffIDs = rootFS.DiffIDs[:len(rootFS.DiffIDs)-1]
-				}
-			}
+		layerDescriptor := &v2LayerDescriptor{
+			digest:         blobSum,
+			repo:           p.repo,
+			blobSumService: p.blobSumService,
 		}
 		}
 
 
-		out.Write(p.sf.FormatProgress(stringid.TruncateID(blobSum.String()), "Pulling fs layer", nil))
-
-		tmpFile, err := ioutil.TempFile("", "GetImageBlob")
-		if err != nil {
-			return false, err
-		}
-
-		d := &downloadInfo{
-			poolKey: poolKey,
-			digest:  blobSum,
-			tmpFile: tmpFile,
-			// TODO: seems like this chan buffer solved hanging problem in go1.5,
-			// this can indicate some deeper problem that somehow we never take
-			// error from channel in loop below
-			err: make(chan error, 1),
-		}
-
-		downloads = append(downloads, d)
-
-		broadcaster, found := p.config.Pool.add(d.poolKey)
-		broadcaster.Add(out)
-		d.broadcaster = broadcaster
-		if found {
-			d.err <- nil
-		} else {
-			go p.download(d)
-		}
+		descriptors = append(descriptors, layerDescriptor)
 	}
 	}
 
 
-	for _, d := range downloads {
-		if err := <-d.err; err != nil {
-			return false, err
-		}
-
-		if d.layer == nil {
-			// Wait for a different pull to download and extract
-			// this layer.
-			err = d.broadcaster.Wait()
-			if err != nil {
-				return false, err
-			}
-
-			diffID, err := p.blobSumService.GetDiffID(d.digest)
-			if err != nil {
-				return false, err
-			}
-			rootFS.Append(diffID)
-
-			l, err := p.config.LayerStore.Get(rootFS.ChainID())
-			if err != nil {
-				return false, err
-			}
-
-			defer layer.ReleaseAndLog(p.config.LayerStore, l)
-
-			continue
-		}
-
-		d.tmpFile.Seek(0, 0)
-		reader := progressreader.New(progressreader.Config{
-			In:        d.tmpFile,
-			Out:       d.broadcaster,
-			Formatter: p.sf,
-			Size:      d.size,
-			NewLines:  false,
-			ID:        stringid.TruncateID(d.digest.String()),
-			Action:    "Extracting",
-		})
-
-		inflatedLayerData, err := archive.DecompressStream(reader)
-		if err != nil {
-			return false, fmt.Errorf("could not get decompression stream: %v", err)
-		}
-
-		l, err := p.config.LayerStore.Register(inflatedLayerData, rootFS.ChainID())
-		if err != nil {
-			return false, fmt.Errorf("failed to register layer: %v", err)
-		}
-		logrus.Debugf("layer %s registered successfully", l.DiffID())
-		rootFS.Append(l.DiffID())
-
-		// Cache mapping from this layer's DiffID to the blobsum
-		if err := p.blobSumService.Add(l.DiffID(), d.digest); err != nil {
-			return false, err
-		}
-
-		defer layer.ReleaseAndLog(p.config.LayerStore, l)
-
-		d.broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(d.digest.String()), "Pull complete", nil))
-		d.broadcaster.Close()
-		tagUpdated = true
+	resultRootFS, release, err := p.config.DownloadManager.Download(ctx, *rootFS, descriptors, p.config.ProgressOutput)
+	if err != nil {
+		return false, err
 	}
 	}
+	defer release()
 
 
-	config, err := v1.MakeConfigFromV1Config([]byte(verifiedManifest.History[0].V1Compatibility), rootFS, history)
+	config, err := v1.MakeConfigFromV1Config([]byte(verifiedManifest.History[0].V1Compatibility), &resultRootFS, history)
 	if err != nil {
 	if err != nil {
 		return false, err
 		return false, err
 	}
 	}
@@ -403,30 +296,24 @@ func (p *v2Puller) pullV2Tag(out io.Writer, ref reference.Named) (tagUpdated boo
 		return false, err
 		return false, err
 	}
 	}
 
 
-	// Check for new tag if no layers downloaded
-	var oldTagImageID image.ID
-	if !tagUpdated {
-		oldTagImageID, err = p.config.TagStore.Get(ref)
-		if err != nil || oldTagImageID != imageID {
-			tagUpdated = true
-		}
+	if manifestDigest != "" {
+		progress.Message(p.config.ProgressOutput, "", "Digest: "+manifestDigest.String())
 	}
 	}
 
 
-	if tagUpdated {
-		if canonical, ok := ref.(reference.Canonical); ok {
-			if err = p.config.TagStore.AddDigest(canonical, imageID, true); err != nil {
-				return false, err
-			}
-		} else if err = p.config.TagStore.AddTag(ref, imageID, true); err != nil {
-			return false, err
-		}
+	oldTagImageID, err := p.config.TagStore.Get(ref)
+	if err == nil && oldTagImageID == imageID {
+		return false, nil
 	}
 	}
 
 
-	if manifestDigest != "" {
-		out.Write(p.sf.FormatStatus("", "Digest: %s", manifestDigest))
+	if canonical, ok := ref.(reference.Canonical); ok {
+		if err = p.config.TagStore.AddDigest(canonical, imageID, true); err != nil {
+			return false, err
+		}
+	} else if err = p.config.TagStore.AddTag(ref, imageID, true); err != nil {
+		return false, err
 	}
 	}
 
 
-	return tagUpdated, nil
+	return true, nil
 }
 }
 
 
 func verifyManifest(signedManifest *schema1.SignedManifest, ref reference.Reference) (m *schema1.Manifest, err error) {
 func verifyManifest(signedManifest *schema1.SignedManifest, ref reference.Reference) (m *schema1.Manifest, err error) {

+ 22 - 14
distribution/push.go

@@ -12,12 +12,14 @@ import (
 	"github.com/docker/docker/cliconfig"
 	"github.com/docker/docker/cliconfig"
 	"github.com/docker/docker/daemon/events"
 	"github.com/docker/docker/daemon/events"
 	"github.com/docker/docker/distribution/metadata"
 	"github.com/docker/docker/distribution/metadata"
+	"github.com/docker/docker/distribution/xfer"
 	"github.com/docker/docker/image"
 	"github.com/docker/docker/image"
 	"github.com/docker/docker/layer"
 	"github.com/docker/docker/layer"
-	"github.com/docker/docker/pkg/streamformatter"
+	"github.com/docker/docker/pkg/progress"
 	"github.com/docker/docker/registry"
 	"github.com/docker/docker/registry"
 	"github.com/docker/docker/tag"
 	"github.com/docker/docker/tag"
 	"github.com/docker/libtrust"
 	"github.com/docker/libtrust"
+	"golang.org/x/net/context"
 )
 )
 
 
 // ImagePushConfig stores push configuration.
 // ImagePushConfig stores push configuration.
@@ -28,9 +30,9 @@ type ImagePushConfig struct {
 	// AuthConfig holds authentication credentials for authenticating with
 	// AuthConfig holds authentication credentials for authenticating with
 	// the registry.
 	// the registry.
 	AuthConfig *cliconfig.AuthConfig
 	AuthConfig *cliconfig.AuthConfig
-	// OutStream is the output writer for showing the status of the push
+	// ProgressOutput is the interface for showing the status of the push
 	// operation.
 	// operation.
-	OutStream io.Writer
+	ProgressOutput progress.Output
 	// RegistryService is the registry service to use for TLS configuration
 	// RegistryService is the registry service to use for TLS configuration
 	// and endpoint lookup.
 	// and endpoint lookup.
 	RegistryService *registry.Service
 	RegistryService *registry.Service
@@ -48,6 +50,8 @@ type ImagePushConfig struct {
 	// TrustKey is the private key for legacy signatures. This is typically
 	// TrustKey is the private key for legacy signatures. This is typically
 	// an ephemeral key, since these signatures are no longer verified.
 	// an ephemeral key, since these signatures are no longer verified.
 	TrustKey libtrust.PrivateKey
 	TrustKey libtrust.PrivateKey
+	// UploadManager dispatches uploads.
+	UploadManager *xfer.LayerUploadManager
 }
 }
 
 
 // Pusher is an interface that abstracts pushing for different API versions.
 // Pusher is an interface that abstracts pushing for different API versions.
@@ -56,7 +60,7 @@ type Pusher interface {
 	// Push returns an error if any, as well as a boolean that determines whether to retry Push on the next configured endpoint.
 	// Push returns an error if any, as well as a boolean that determines whether to retry Push on the next configured endpoint.
 	//
 	//
 	// TODO(tiborvass): have Push() take a reference to repository + tag, so that the pusher itself is repository-agnostic.
 	// TODO(tiborvass): have Push() take a reference to repository + tag, so that the pusher itself is repository-agnostic.
-	Push() (fallback bool, err error)
+	Push(ctx context.Context) (fallback bool, err error)
 }
 }
 
 
 const compressionBufSize = 32768
 const compressionBufSize = 32768
@@ -66,7 +70,7 @@ const compressionBufSize = 32768
 // whether a v1 or v2 pusher will be created. The other parameters are passed
 // whether a v1 or v2 pusher will be created. The other parameters are passed
 // through to the underlying pusher implementation for use during the actual
 // through to the underlying pusher implementation for use during the actual
 // push operation.
 // push operation.
-func NewPusher(ref reference.Named, endpoint registry.APIEndpoint, repoInfo *registry.RepositoryInfo, imagePushConfig *ImagePushConfig, sf *streamformatter.StreamFormatter) (Pusher, error) {
+func NewPusher(ref reference.Named, endpoint registry.APIEndpoint, repoInfo *registry.RepositoryInfo, imagePushConfig *ImagePushConfig) (Pusher, error) {
 	switch endpoint.Version {
 	switch endpoint.Version {
 	case registry.APIVersion2:
 	case registry.APIVersion2:
 		return &v2Pusher{
 		return &v2Pusher{
@@ -75,8 +79,7 @@ func NewPusher(ref reference.Named, endpoint registry.APIEndpoint, repoInfo *reg
 			endpoint:       endpoint,
 			endpoint:       endpoint,
 			repoInfo:       repoInfo,
 			repoInfo:       repoInfo,
 			config:         imagePushConfig,
 			config:         imagePushConfig,
-			sf:             sf,
-			layersPushed:   make(map[digest.Digest]bool),
+			layersPushed:   pushMap{layersPushed: make(map[digest.Digest]bool)},
 		}, nil
 		}, nil
 	case registry.APIVersion1:
 	case registry.APIVersion1:
 		return &v1Pusher{
 		return &v1Pusher{
@@ -85,7 +88,6 @@ func NewPusher(ref reference.Named, endpoint registry.APIEndpoint, repoInfo *reg
 			endpoint:    endpoint,
 			endpoint:    endpoint,
 			repoInfo:    repoInfo,
 			repoInfo:    repoInfo,
 			config:      imagePushConfig,
 			config:      imagePushConfig,
-			sf:          sf,
 		}, nil
 		}, nil
 	}
 	}
 	return nil, fmt.Errorf("unknown version %d for registry %s", endpoint.Version, endpoint.URL)
 	return nil, fmt.Errorf("unknown version %d for registry %s", endpoint.Version, endpoint.URL)
@@ -94,11 +96,9 @@ func NewPusher(ref reference.Named, endpoint registry.APIEndpoint, repoInfo *reg
 // Push initiates a push operation on the repository named localName.
 // Push initiates a push operation on the repository named localName.
 // ref is the specific variant of the image to be pushed.
 // ref is the specific variant of the image to be pushed.
 // If no tag is provided, all tags will be pushed.
 // If no tag is provided, all tags will be pushed.
-func Push(ref reference.Named, imagePushConfig *ImagePushConfig) error {
+func Push(ctx context.Context, ref reference.Named, imagePushConfig *ImagePushConfig) error {
 	// FIXME: Allow to interrupt current push when new push of same image is done.
 	// FIXME: Allow to interrupt current push when new push of same image is done.
 
 
-	var sf = streamformatter.NewJSONStreamFormatter()
-
 	// Resolve the Repository name from fqn to RepositoryInfo
 	// Resolve the Repository name from fqn to RepositoryInfo
 	repoInfo, err := imagePushConfig.RegistryService.ResolveRepository(ref)
 	repoInfo, err := imagePushConfig.RegistryService.ResolveRepository(ref)
 	if err != nil {
 	if err != nil {
@@ -110,7 +110,7 @@ func Push(ref reference.Named, imagePushConfig *ImagePushConfig) error {
 		return err
 		return err
 	}
 	}
 
 
-	imagePushConfig.OutStream.Write(sf.FormatStatus("", "The push refers to a repository [%s]", repoInfo.CanonicalName))
+	progress.Messagef(imagePushConfig.ProgressOutput, "", "The push refers to a repository [%s]", repoInfo.CanonicalName.String())
 
 
 	associations := imagePushConfig.TagStore.ReferencesByName(repoInfo.LocalName)
 	associations := imagePushConfig.TagStore.ReferencesByName(repoInfo.LocalName)
 	if len(associations) == 0 {
 	if len(associations) == 0 {
@@ -121,12 +121,20 @@ func Push(ref reference.Named, imagePushConfig *ImagePushConfig) error {
 	for _, endpoint := range endpoints {
 	for _, endpoint := range endpoints {
 		logrus.Debugf("Trying to push %s to %s %s", repoInfo.CanonicalName, endpoint.URL, endpoint.Version)
 		logrus.Debugf("Trying to push %s to %s %s", repoInfo.CanonicalName, endpoint.URL, endpoint.Version)
 
 
-		pusher, err := NewPusher(ref, endpoint, repoInfo, imagePushConfig, sf)
+		pusher, err := NewPusher(ref, endpoint, repoInfo, imagePushConfig)
 		if err != nil {
 		if err != nil {
 			lastErr = err
 			lastErr = err
 			continue
 			continue
 		}
 		}
-		if fallback, err := pusher.Push(); err != nil {
+		if fallback, err := pusher.Push(ctx); err != nil {
+			// Was this push cancelled? If so, don't try to fall
+			// back.
+			select {
+			case <-ctx.Done():
+				fallback = false
+			default:
+			}
+
 			if fallback {
 			if fallback {
 				lastErr = err
 				lastErr = err
 				continue
 				continue

+ 20 - 32
distribution/push_v1.go

@@ -2,8 +2,6 @@ package distribution
 
 
 import (
 import (
 	"fmt"
 	"fmt"
-	"io"
-	"io/ioutil"
 	"sync"
 	"sync"
 
 
 	"github.com/Sirupsen/logrus"
 	"github.com/Sirupsen/logrus"
@@ -15,25 +13,23 @@ import (
 	"github.com/docker/docker/image/v1"
 	"github.com/docker/docker/image/v1"
 	"github.com/docker/docker/layer"
 	"github.com/docker/docker/layer"
 	"github.com/docker/docker/pkg/ioutils"
 	"github.com/docker/docker/pkg/ioutils"
-	"github.com/docker/docker/pkg/progressreader"
-	"github.com/docker/docker/pkg/streamformatter"
+	"github.com/docker/docker/pkg/progress"
 	"github.com/docker/docker/pkg/stringid"
 	"github.com/docker/docker/pkg/stringid"
 	"github.com/docker/docker/registry"
 	"github.com/docker/docker/registry"
+	"golang.org/x/net/context"
 )
 )
 
 
 type v1Pusher struct {
 type v1Pusher struct {
+	ctx         context.Context
 	v1IDService *metadata.V1IDService
 	v1IDService *metadata.V1IDService
 	endpoint    registry.APIEndpoint
 	endpoint    registry.APIEndpoint
 	ref         reference.Named
 	ref         reference.Named
 	repoInfo    *registry.RepositoryInfo
 	repoInfo    *registry.RepositoryInfo
 	config      *ImagePushConfig
 	config      *ImagePushConfig
-	sf          *streamformatter.StreamFormatter
 	session     *registry.Session
 	session     *registry.Session
-
-	out io.Writer
 }
 }
 
 
-func (p *v1Pusher) Push() (fallback bool, err error) {
+func (p *v1Pusher) Push(ctx context.Context) (fallback bool, err error) {
 	tlsConfig, err := p.config.RegistryService.TLSConfig(p.repoInfo.Index.Name)
 	tlsConfig, err := p.config.RegistryService.TLSConfig(p.repoInfo.Index.Name)
 	if err != nil {
 	if err != nil {
 		return false, err
 		return false, err
@@ -55,7 +51,7 @@ func (p *v1Pusher) Push() (fallback bool, err error) {
 		// TODO(dmcgowan): Check if should fallback
 		// TODO(dmcgowan): Check if should fallback
 		return true, err
 		return true, err
 	}
 	}
-	if err := p.pushRepository(); err != nil {
+	if err := p.pushRepository(ctx); err != nil {
 		// TODO(dmcgowan): Check if should fallback
 		// TODO(dmcgowan): Check if should fallback
 		return false, err
 		return false, err
 	}
 	}
@@ -306,12 +302,12 @@ func (p *v1Pusher) lookupImageOnEndpoint(wg *sync.WaitGroup, endpoint string, im
 			logrus.Errorf("Error in LookupRemoteImage: %s", err)
 			logrus.Errorf("Error in LookupRemoteImage: %s", err)
 			imagesToPush <- v1ID
 			imagesToPush <- v1ID
 		} else {
 		} else {
-			p.out.Write(p.sf.FormatStatus("", "Image %s already pushed, skipping", stringid.TruncateID(v1ID)))
+			progress.Messagef(p.config.ProgressOutput, "", "Image %s already pushed, skipping", stringid.TruncateID(v1ID))
 		}
 		}
 	}
 	}
 }
 }
 
 
-func (p *v1Pusher) pushImageToEndpoint(endpoint string, imageList []v1Image, tags map[image.ID][]string, repo *registry.RepositoryData) error {
+func (p *v1Pusher) pushImageToEndpoint(ctx context.Context, endpoint string, imageList []v1Image, tags map[image.ID][]string, repo *registry.RepositoryData) error {
 	workerCount := len(imageList)
 	workerCount := len(imageList)
 	// start a maximum of 5 workers to check if images exist on the specified endpoint.
 	// start a maximum of 5 workers to check if images exist on the specified endpoint.
 	if workerCount > 5 {
 	if workerCount > 5 {
@@ -349,14 +345,14 @@ func (p *v1Pusher) pushImageToEndpoint(endpoint string, imageList []v1Image, tag
 	for _, img := range imageList {
 	for _, img := range imageList {
 		v1ID := img.V1ID()
 		v1ID := img.V1ID()
 		if _, push := shouldPush[v1ID]; push {
 		if _, push := shouldPush[v1ID]; push {
-			if _, err := p.pushImage(img, endpoint); err != nil {
+			if _, err := p.pushImage(ctx, img, endpoint); err != nil {
 				// FIXME: Continue on error?
 				// FIXME: Continue on error?
 				return err
 				return err
 			}
 			}
 		}
 		}
 		if topImage, isTopImage := img.(*v1TopImage); isTopImage {
 		if topImage, isTopImage := img.(*v1TopImage); isTopImage {
 			for _, tag := range tags[topImage.imageID] {
 			for _, tag := range tags[topImage.imageID] {
-				p.out.Write(p.sf.FormatStatus("", "Pushing tag for rev [%s] on {%s}", stringid.TruncateID(v1ID), endpoint+"repositories/"+p.repoInfo.RemoteName.Name()+"/tags/"+tag))
+				progress.Messagef(p.config.ProgressOutput, "", "Pushing tag for rev [%s] on {%s}", stringid.TruncateID(v1ID), endpoint+"repositories/"+p.repoInfo.RemoteName.Name()+"/tags/"+tag)
 				if err := p.session.PushRegistryTag(p.repoInfo.RemoteName, v1ID, tag, endpoint); err != nil {
 				if err := p.session.PushRegistryTag(p.repoInfo.RemoteName, v1ID, tag, endpoint); err != nil {
 					return err
 					return err
 				}
 				}
@@ -367,8 +363,7 @@ func (p *v1Pusher) pushImageToEndpoint(endpoint string, imageList []v1Image, tag
 }
 }
 
 
 // pushRepository pushes layers that do not already exist on the registry.
 // pushRepository pushes layers that do not already exist on the registry.
-func (p *v1Pusher) pushRepository() error {
-	p.out = ioutils.NewWriteFlusher(p.config.OutStream)
+func (p *v1Pusher) pushRepository(ctx context.Context) error {
 	imgList, tags, referencedLayers, err := p.getImageList()
 	imgList, tags, referencedLayers, err := p.getImageList()
 	defer func() {
 	defer func() {
 		for _, l := range referencedLayers {
 		for _, l := range referencedLayers {
@@ -378,7 +373,7 @@ func (p *v1Pusher) pushRepository() error {
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	p.out.Write(p.sf.FormatStatus("", "Sending image list"))
+	progress.Message(p.config.ProgressOutput, "", "Sending image list")
 
 
 	imageIndex := createImageIndex(imgList, tags)
 	imageIndex := createImageIndex(imgList, tags)
 	for _, data := range imageIndex {
 	for _, data := range imageIndex {
@@ -391,10 +386,10 @@ func (p *v1Pusher) pushRepository() error {
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	p.out.Write(p.sf.FormatStatus("", "Pushing repository %s", p.repoInfo.CanonicalName))
+	progress.Message(p.config.ProgressOutput, "", "Pushing repository "+p.repoInfo.CanonicalName.String())
 	// push the repository to each of the endpoints only if it does not exist.
 	// push the repository to each of the endpoints only if it does not exist.
 	for _, endpoint := range repoData.Endpoints {
 	for _, endpoint := range repoData.Endpoints {
-		if err := p.pushImageToEndpoint(endpoint, imgList, tags, repoData); err != nil {
+		if err := p.pushImageToEndpoint(ctx, endpoint, imgList, tags, repoData); err != nil {
 			return err
 			return err
 		}
 		}
 	}
 	}
@@ -402,11 +397,11 @@ func (p *v1Pusher) pushRepository() error {
 	return err
 	return err
 }
 }
 
 
-func (p *v1Pusher) pushImage(v1Image v1Image, ep string) (checksum string, err error) {
+func (p *v1Pusher) pushImage(ctx context.Context, v1Image v1Image, ep string) (checksum string, err error) {
 	v1ID := v1Image.V1ID()
 	v1ID := v1Image.V1ID()
 
 
 	jsonRaw := v1Image.Config()
 	jsonRaw := v1Image.Config()
-	p.out.Write(p.sf.FormatProgress(stringid.TruncateID(v1ID), "Pushing", nil))
+	progress.Update(p.config.ProgressOutput, stringid.TruncateID(v1ID), "Pushing")
 
 
 	// General rule is to use ID for graph accesses and compatibilityID for
 	// General rule is to use ID for graph accesses and compatibilityID for
 	// calls to session.registry()
 	// calls to session.registry()
@@ -417,7 +412,7 @@ func (p *v1Pusher) pushImage(v1Image v1Image, ep string) (checksum string, err e
 	// Send the json
 	// Send the json
 	if err := p.session.PushImageJSONRegistry(imgData, jsonRaw, ep); err != nil {
 	if err := p.session.PushImageJSONRegistry(imgData, jsonRaw, ep); err != nil {
 		if err == registry.ErrAlreadyExists {
 		if err == registry.ErrAlreadyExists {
-			p.out.Write(p.sf.FormatProgress(stringid.TruncateID(v1ID), "Image already pushed, skipping", nil))
+			progress.Update(p.config.ProgressOutput, stringid.TruncateID(v1ID), "Image already pushed, skipping")
 			return "", nil
 			return "", nil
 		}
 		}
 		return "", err
 		return "", err
@@ -437,15 +432,8 @@ func (p *v1Pusher) pushImage(v1Image v1Image, ep string) (checksum string, err e
 	// Send the layer
 	// Send the layer
 	logrus.Debugf("rendered layer for %s of [%d] size", v1ID, size)
 	logrus.Debugf("rendered layer for %s of [%d] size", v1ID, size)
 
 
-	reader := progressreader.New(progressreader.Config{
-		In:        ioutil.NopCloser(arch),
-		Out:       p.out,
-		Formatter: p.sf,
-		Size:      size,
-		NewLines:  false,
-		ID:        stringid.TruncateID(v1ID),
-		Action:    "Pushing",
-	})
+	reader := progress.NewProgressReader(ioutils.NewCancelReadCloser(ctx, arch), p.config.ProgressOutput, size, stringid.TruncateID(v1ID), "Pushing")
+	defer reader.Close()
 
 
 	checksum, checksumPayload, err := p.session.PushImageLayerRegistry(v1ID, reader, ep, jsonRaw)
 	checksum, checksumPayload, err := p.session.PushImageLayerRegistry(v1ID, reader, ep, jsonRaw)
 	if err != nil {
 	if err != nil {
@@ -458,10 +446,10 @@ func (p *v1Pusher) pushImage(v1Image v1Image, ep string) (checksum string, err e
 		return "", err
 		return "", err
 	}
 	}
 
 
-	if err := p.v1IDService.Set(v1ID, p.repoInfo.Index.Name, l.ChainID()); err != nil {
+	if err := p.v1IDService.Set(v1ID, p.repoInfo.Index.Name, l.DiffID()); err != nil {
 		logrus.Warnf("Could not set v1 ID mapping: %v", err)
 		logrus.Warnf("Could not set v1 ID mapping: %v", err)
 	}
 	}
 
 
-	p.out.Write(p.sf.FormatProgress(stringid.TruncateID(v1ID), "Image successfully pushed", nil))
+	progress.Update(p.config.ProgressOutput, stringid.TruncateID(v1ID), "Image successfully pushed")
 	return imgData.Checksum, nil
 	return imgData.Checksum, nil
 }
 }

+ 118 - 105
distribution/push_v2.go

@@ -5,7 +5,7 @@ import (
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
-	"io/ioutil"
+	"sync"
 	"time"
 	"time"
 
 
 	"github.com/Sirupsen/logrus"
 	"github.com/Sirupsen/logrus"
@@ -15,11 +15,12 @@ import (
 	"github.com/docker/distribution/manifest/schema1"
 	"github.com/docker/distribution/manifest/schema1"
 	"github.com/docker/distribution/reference"
 	"github.com/docker/distribution/reference"
 	"github.com/docker/docker/distribution/metadata"
 	"github.com/docker/docker/distribution/metadata"
+	"github.com/docker/docker/distribution/xfer"
 	"github.com/docker/docker/image"
 	"github.com/docker/docker/image"
 	"github.com/docker/docker/image/v1"
 	"github.com/docker/docker/image/v1"
 	"github.com/docker/docker/layer"
 	"github.com/docker/docker/layer"
-	"github.com/docker/docker/pkg/progressreader"
-	"github.com/docker/docker/pkg/streamformatter"
+	"github.com/docker/docker/pkg/ioutils"
+	"github.com/docker/docker/pkg/progress"
 	"github.com/docker/docker/pkg/stringid"
 	"github.com/docker/docker/pkg/stringid"
 	"github.com/docker/docker/registry"
 	"github.com/docker/docker/registry"
 	"github.com/docker/docker/tag"
 	"github.com/docker/docker/tag"
@@ -32,16 +33,20 @@ type v2Pusher struct {
 	endpoint       registry.APIEndpoint
 	endpoint       registry.APIEndpoint
 	repoInfo       *registry.RepositoryInfo
 	repoInfo       *registry.RepositoryInfo
 	config         *ImagePushConfig
 	config         *ImagePushConfig
-	sf             *streamformatter.StreamFormatter
 	repo           distribution.Repository
 	repo           distribution.Repository
 
 
 	// layersPushed is the set of layers known to exist on the remote side.
 	// layersPushed is the set of layers known to exist on the remote side.
 	// This avoids redundant queries when pushing multiple tags that
 	// This avoids redundant queries when pushing multiple tags that
 	// involve the same layers.
 	// involve the same layers.
+	layersPushed pushMap
+}
+
+type pushMap struct {
+	sync.Mutex
 	layersPushed map[digest.Digest]bool
 	layersPushed map[digest.Digest]bool
 }
 }
 
 
-func (p *v2Pusher) Push() (fallback bool, err error) {
+func (p *v2Pusher) Push(ctx context.Context) (fallback bool, err error) {
 	p.repo, err = NewV2Repository(p.repoInfo, p.endpoint, p.config.MetaHeaders, p.config.AuthConfig, "push", "pull")
 	p.repo, err = NewV2Repository(p.repoInfo, p.endpoint, p.config.MetaHeaders, p.config.AuthConfig, "push", "pull")
 	if err != nil {
 	if err != nil {
 		logrus.Debugf("Error getting v2 registry: %v", err)
 		logrus.Debugf("Error getting v2 registry: %v", err)
@@ -75,7 +80,7 @@ func (p *v2Pusher) Push() (fallback bool, err error) {
 	}
 	}
 
 
 	for _, association := range associations {
 	for _, association := range associations {
-		if err := p.pushV2Tag(association); err != nil {
+		if err := p.pushV2Tag(ctx, association); err != nil {
 			return false, err
 			return false, err
 		}
 		}
 	}
 	}
@@ -83,7 +88,7 @@ func (p *v2Pusher) Push() (fallback bool, err error) {
 	return false, nil
 	return false, nil
 }
 }
 
 
-func (p *v2Pusher) pushV2Tag(association tag.Association) error {
+func (p *v2Pusher) pushV2Tag(ctx context.Context, association tag.Association) error {
 	ref := association.Ref
 	ref := association.Ref
 	logrus.Debugf("Pushing repository: %s", ref.String())
 	logrus.Debugf("Pushing repository: %s", ref.String())
 
 
@@ -92,8 +97,6 @@ func (p *v2Pusher) pushV2Tag(association tag.Association) error {
 		return fmt.Errorf("could not find image from tag %s: %v", ref.String(), err)
 		return fmt.Errorf("could not find image from tag %s: %v", ref.String(), err)
 	}
 	}
 
 
-	out := p.config.OutStream
-
 	var l layer.Layer
 	var l layer.Layer
 
 
 	topLayerID := img.RootFS.ChainID()
 	topLayerID := img.RootFS.ChainID()
@@ -107,33 +110,41 @@ func (p *v2Pusher) pushV2Tag(association tag.Association) error {
 		defer layer.ReleaseAndLog(p.config.LayerStore, l)
 		defer layer.ReleaseAndLog(p.config.LayerStore, l)
 	}
 	}
 
 
-	fsLayers := make(map[layer.DiffID]schema1.FSLayer)
+	var descriptors []xfer.UploadDescriptor
 
 
 	// Push empty layer if necessary
 	// Push empty layer if necessary
 	for _, h := range img.History {
 	for _, h := range img.History {
 		if h.EmptyLayer {
 		if h.EmptyLayer {
-			dgst, err := p.pushLayerIfNecessary(out, layer.EmptyLayer)
-			if err != nil {
-				return err
+			descriptors = []xfer.UploadDescriptor{
+				&v2PushDescriptor{
+					layer:          layer.EmptyLayer,
+					blobSumService: p.blobSumService,
+					repo:           p.repo,
+					layersPushed:   &p.layersPushed,
+				},
 			}
 			}
-			p.layersPushed[dgst] = true
-			fsLayers[layer.EmptyLayer.DiffID()] = schema1.FSLayer{BlobSum: dgst}
 			break
 			break
 		}
 		}
 	}
 	}
 
 
+	// Loop bounds condition is to avoid pushing the base layer on Windows.
 	for i := 0; i < len(img.RootFS.DiffIDs); i++ {
 	for i := 0; i < len(img.RootFS.DiffIDs); i++ {
-		dgst, err := p.pushLayerIfNecessary(out, l)
-		if err != nil {
-			return err
+		descriptor := &v2PushDescriptor{
+			layer:          l,
+			blobSumService: p.blobSumService,
+			repo:           p.repo,
+			layersPushed:   &p.layersPushed,
 		}
 		}
-
-		p.layersPushed[dgst] = true
-		fsLayers[l.DiffID()] = schema1.FSLayer{BlobSum: dgst}
+		descriptors = append(descriptors, descriptor)
 
 
 		l = l.Parent()
 		l = l.Parent()
 	}
 	}
 
 
+	fsLayers, err := p.config.UploadManager.Upload(ctx, descriptors, p.config.ProgressOutput)
+	if err != nil {
+		return err
+	}
+
 	var tag string
 	var tag string
 	if tagged, isTagged := ref.(reference.Tagged); isTagged {
 	if tagged, isTagged := ref.(reference.Tagged); isTagged {
 		tag = tagged.Tag()
 		tag = tagged.Tag()
@@ -157,59 +168,124 @@ func (p *v2Pusher) pushV2Tag(association tag.Association) error {
 		if tagged, isTagged := ref.(reference.Tagged); isTagged {
 		if tagged, isTagged := ref.(reference.Tagged); isTagged {
 			// NOTE: do not change this format without first changing the trust client
 			// NOTE: do not change this format without first changing the trust client
 			// code. This information is used to determine what was pushed and should be signed.
 			// code. This information is used to determine what was pushed and should be signed.
-			out.Write(p.sf.FormatStatus("", "%s: digest: %s size: %d", tagged.Tag(), manifestDigest, manifestSize))
+			progress.Messagef(p.config.ProgressOutput, "", "%s: digest: %s size: %d", tagged.Tag(), manifestDigest, manifestSize)
 		}
 		}
 	}
 	}
 
 
-	manSvc, err := p.repo.Manifests(context.Background())
+	manSvc, err := p.repo.Manifests(ctx)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
 	return manSvc.Put(signed)
 	return manSvc.Put(signed)
 }
 }
 
 
-func (p *v2Pusher) pushLayerIfNecessary(out io.Writer, l layer.Layer) (digest.Digest, error) {
-	logrus.Debugf("Pushing layer: %s", l.DiffID())
+type v2PushDescriptor struct {
+	layer          layer.Layer
+	blobSumService *metadata.BlobSumService
+	repo           distribution.Repository
+	layersPushed   *pushMap
+}
+
+func (pd *v2PushDescriptor) Key() string {
+	return "v2push:" + pd.repo.Name() + " " + pd.layer.DiffID().String()
+}
+
+func (pd *v2PushDescriptor) ID() string {
+	return stringid.TruncateID(pd.layer.DiffID().String())
+}
+
+func (pd *v2PushDescriptor) DiffID() layer.DiffID {
+	return pd.layer.DiffID()
+}
+
+func (pd *v2PushDescriptor) Upload(ctx context.Context, progressOutput progress.Output) (digest.Digest, error) {
+	diffID := pd.DiffID()
+
+	logrus.Debugf("Pushing layer: %s", diffID)
 
 
 	// Do we have any blobsums associated with this layer's DiffID?
 	// Do we have any blobsums associated with this layer's DiffID?
-	possibleBlobsums, err := p.blobSumService.GetBlobSums(l.DiffID())
+	possibleBlobsums, err := pd.blobSumService.GetBlobSums(diffID)
 	if err == nil {
 	if err == nil {
-		dgst, exists, err := p.blobSumAlreadyExists(possibleBlobsums)
+		dgst, exists, err := blobSumAlreadyExists(ctx, possibleBlobsums, pd.repo, pd.layersPushed)
 		if err != nil {
 		if err != nil {
-			out.Write(p.sf.FormatProgress(stringid.TruncateID(string(l.DiffID())), "Image push failed", nil))
-			return "", err
+			progress.Update(progressOutput, pd.ID(), "Image push failed")
+			return "", retryOnError(err)
 		}
 		}
 		if exists {
 		if exists {
-			out.Write(p.sf.FormatProgress(stringid.TruncateID(string(l.DiffID())), "Layer already exists", nil))
+			progress.Update(progressOutput, pd.ID(), "Layer already exists")
 			return dgst, nil
 			return dgst, nil
 		}
 		}
 	}
 	}
 
 
 	// if digest was empty or not saved, or if blob does not exist on the remote repository,
 	// if digest was empty or not saved, or if blob does not exist on the remote repository,
 	// then push the blob.
 	// then push the blob.
-	pushDigest, err := p.pushV2Layer(p.repo.Blobs(context.Background()), l)
+	bs := pd.repo.Blobs(ctx)
+
+	// Send the layer
+	layerUpload, err := bs.Create(ctx)
+	if err != nil {
+		return "", retryOnError(err)
+	}
+	defer layerUpload.Close()
+
+	arch, err := pd.layer.TarStream()
 	if err != nil {
 	if err != nil {
-		return "", err
+		return "", xfer.DoNotRetry{Err: err}
 	}
 	}
+
+	// don't care if this fails; best effort
+	size, _ := pd.layer.DiffSize()
+
+	reader := progress.NewProgressReader(ioutils.NewCancelReadCloser(ctx, arch), progressOutput, size, pd.ID(), "Pushing")
+	defer reader.Close()
+	compressedReader := compress(reader)
+
+	digester := digest.Canonical.New()
+	tee := io.TeeReader(compressedReader, digester.Hash())
+
+	nn, err := layerUpload.ReadFrom(tee)
+	compressedReader.Close()
+	if err != nil {
+		return "", retryOnError(err)
+	}
+
+	pushDigest := digester.Digest()
+	if _, err := layerUpload.Commit(ctx, distribution.Descriptor{Digest: pushDigest}); err != nil {
+		return "", retryOnError(err)
+	}
+
+	logrus.Debugf("uploaded layer %s (%s), %d bytes", diffID, pushDigest, nn)
+	progress.Update(progressOutput, pd.ID(), "Pushed")
+
 	// Cache mapping from this layer's DiffID to the blobsum
 	// Cache mapping from this layer's DiffID to the blobsum
-	if err := p.blobSumService.Add(l.DiffID(), pushDigest); err != nil {
-		return "", err
+	if err := pd.blobSumService.Add(diffID, pushDigest); err != nil {
+		return "", xfer.DoNotRetry{Err: err}
 	}
 	}
 
 
+	pd.layersPushed.Lock()
+	pd.layersPushed.layersPushed[pushDigest] = true
+	pd.layersPushed.Unlock()
+
 	return pushDigest, nil
 	return pushDigest, nil
 }
 }
 
 
 // blobSumAlreadyExists checks if the registry already know about any of the
 // blobSumAlreadyExists checks if the registry already know about any of the
 // blobsums passed in the "blobsums" slice. If it finds one that the registry
 // blobsums passed in the "blobsums" slice. If it finds one that the registry
 // knows about, it returns the known digest and "true".
 // knows about, it returns the known digest and "true".
-func (p *v2Pusher) blobSumAlreadyExists(blobsums []digest.Digest) (digest.Digest, bool, error) {
+func blobSumAlreadyExists(ctx context.Context, blobsums []digest.Digest, repo distribution.Repository, layersPushed *pushMap) (digest.Digest, bool, error) {
+	layersPushed.Lock()
 	for _, dgst := range blobsums {
 	for _, dgst := range blobsums {
-		if p.layersPushed[dgst] {
+		if layersPushed.layersPushed[dgst] {
 			// it is already known that the push is not needed and
 			// it is already known that the push is not needed and
 			// therefore doing a stat is unnecessary
 			// therefore doing a stat is unnecessary
+			layersPushed.Unlock()
 			return dgst, true, nil
 			return dgst, true, nil
 		}
 		}
-		_, err := p.repo.Blobs(context.Background()).Stat(context.Background(), dgst)
+	}
+	layersPushed.Unlock()
+
+	for _, dgst := range blobsums {
+		_, err := repo.Blobs(ctx).Stat(ctx, dgst)
 		switch err {
 		switch err {
 		case nil:
 		case nil:
 			return dgst, true, nil
 			return dgst, true, nil
@@ -226,7 +302,7 @@ func (p *v2Pusher) blobSumAlreadyExists(blobsums []digest.Digest) (digest.Digest
 // FSLayer digests.
 // FSLayer digests.
 // FIXME: This should be moved to the distribution repo, since it will also
 // FIXME: This should be moved to the distribution repo, since it will also
 // be useful for converting new manifests to the old format.
 // be useful for converting new manifests to the old format.
-func CreateV2Manifest(name, tag string, img *image.Image, fsLayers map[layer.DiffID]schema1.FSLayer) (*schema1.Manifest, error) {
+func CreateV2Manifest(name, tag string, img *image.Image, fsLayers map[layer.DiffID]digest.Digest) (*schema1.Manifest, error) {
 	if len(img.History) == 0 {
 	if len(img.History) == 0 {
 		return nil, errors.New("empty history when trying to create V2 manifest")
 		return nil, errors.New("empty history when trying to create V2 manifest")
 	}
 	}
@@ -271,7 +347,7 @@ func CreateV2Manifest(name, tag string, img *image.Image, fsLayers map[layer.Dif
 		if !present {
 		if !present {
 			return nil, fmt.Errorf("missing layer in CreateV2Manifest: %s", diffID.String())
 			return nil, fmt.Errorf("missing layer in CreateV2Manifest: %s", diffID.String())
 		}
 		}
-		dgst, err := digest.FromBytes([]byte(fsLayer.BlobSum.Hex() + " " + parent))
+		dgst, err := digest.FromBytes([]byte(fsLayer.Hex() + " " + parent))
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
@@ -294,7 +370,7 @@ func CreateV2Manifest(name, tag string, img *image.Image, fsLayers map[layer.Dif
 
 
 		reversedIndex := len(img.History) - i - 1
 		reversedIndex := len(img.History) - i - 1
 		history[reversedIndex].V1Compatibility = string(jsonBytes)
 		history[reversedIndex].V1Compatibility = string(jsonBytes)
-		fsLayerList[reversedIndex] = fsLayer
+		fsLayerList[reversedIndex] = schema1.FSLayer{BlobSum: fsLayer}
 
 
 		parent = v1ID
 		parent = v1ID
 	}
 	}
@@ -315,11 +391,11 @@ func CreateV2Manifest(name, tag string, img *image.Image, fsLayers map[layer.Dif
 		return nil, fmt.Errorf("missing layer in CreateV2Manifest: %s", diffID.String())
 		return nil, fmt.Errorf("missing layer in CreateV2Manifest: %s", diffID.String())
 	}
 	}
 
 
-	dgst, err := digest.FromBytes([]byte(fsLayer.BlobSum.Hex() + " " + parent + " " + string(img.RawJSON())))
+	dgst, err := digest.FromBytes([]byte(fsLayer.Hex() + " " + parent + " " + string(img.RawJSON())))
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
-	fsLayerList[0] = fsLayer
+	fsLayerList[0] = schema1.FSLayer{BlobSum: fsLayer}
 
 
 	// Top-level v1compatibility string should be a modified version of the
 	// Top-level v1compatibility string should be a modified version of the
 	// image config.
 	// image config.
@@ -346,66 +422,3 @@ func CreateV2Manifest(name, tag string, img *image.Image, fsLayers map[layer.Dif
 		History:      history,
 		History:      history,
 	}, nil
 	}, nil
 }
 }
-
-func rawJSON(value interface{}) *json.RawMessage {
-	jsonval, err := json.Marshal(value)
-	if err != nil {
-		return nil
-	}
-	return (*json.RawMessage)(&jsonval)
-}
-
-func (p *v2Pusher) pushV2Layer(bs distribution.BlobService, l layer.Layer) (digest.Digest, error) {
-	out := p.config.OutStream
-	displayID := stringid.TruncateID(string(l.DiffID()))
-
-	out.Write(p.sf.FormatProgress(displayID, "Preparing", nil))
-
-	arch, err := l.TarStream()
-	if err != nil {
-		return "", err
-	}
-	defer arch.Close()
-
-	// Send the layer
-	layerUpload, err := bs.Create(context.Background())
-	if err != nil {
-		return "", err
-	}
-	defer layerUpload.Close()
-
-	// don't care if this fails; best effort
-	size, _ := l.DiffSize()
-
-	reader := progressreader.New(progressreader.Config{
-		In:        ioutil.NopCloser(arch), // we'll take care of close here.
-		Out:       out,
-		Formatter: p.sf,
-		Size:      size,
-		NewLines:  false,
-		ID:        displayID,
-		Action:    "Pushing",
-	})
-
-	compressedReader := compress(reader)
-
-	digester := digest.Canonical.New()
-	tee := io.TeeReader(compressedReader, digester.Hash())
-
-	out.Write(p.sf.FormatProgress(displayID, "Pushing", nil))
-	nn, err := layerUpload.ReadFrom(tee)
-	compressedReader.Close()
-	if err != nil {
-		return "", err
-	}
-
-	dgst := digester.Digest()
-	if _, err := layerUpload.Commit(context.Background(), distribution.Descriptor{Digest: dgst}); err != nil {
-		return "", err
-	}
-
-	logrus.Debugf("uploaded layer %s (%s), %d bytes", l.DiffID(), dgst, nn)
-	out.Write(p.sf.FormatProgress(displayID, "Pushed", nil))
-
-	return dgst, nil
-}

+ 4 - 4
distribution/push_v2_test.go

@@ -116,10 +116,10 @@ func TestCreateV2Manifest(t *testing.T) {
 		t.Fatalf("json decoding failed: %v", err)
 		t.Fatalf("json decoding failed: %v", err)
 	}
 	}
 
 
-	fsLayers := map[layer.DiffID]schema1.FSLayer{
-		layer.DiffID("sha256:c6f988f4874bb0add23a778f753c65efe992244e148a1d2ec2a8b664fb66bbd1"): {BlobSum: digest.Digest("sha256:a3ed95caeb02ffe68cdd9fd84406680ae93d633cb16422d00e8a7c22955b46d4")},
-		layer.DiffID("sha256:5f70bf18a086007016e948b04aed3b82103a36bea41755b6cddfaf10ace3c6ef"): {BlobSum: digest.Digest("sha256:86e0e091d0da6bde2456dbb48306f3956bbeb2eae1b5b9a43045843f69fe4aaa")},
-		layer.DiffID("sha256:13f53e08df5a220ab6d13c58b2bf83a59cbdc2e04d0a3f041ddf4b0ba4112d49"): {BlobSum: digest.Digest("sha256:b4ed95caeb02ffe68cdd9fd84406680ae93d633cb16422d00e8a7c22955b46d4")},
+	fsLayers := map[layer.DiffID]digest.Digest{
+		layer.DiffID("sha256:c6f988f4874bb0add23a778f753c65efe992244e148a1d2ec2a8b664fb66bbd1"): digest.Digest("sha256:a3ed95caeb02ffe68cdd9fd84406680ae93d633cb16422d00e8a7c22955b46d4"),
+		layer.DiffID("sha256:5f70bf18a086007016e948b04aed3b82103a36bea41755b6cddfaf10ace3c6ef"): digest.Digest("sha256:86e0e091d0da6bde2456dbb48306f3956bbeb2eae1b5b9a43045843f69fe4aaa"),
+		layer.DiffID("sha256:13f53e08df5a220ab6d13c58b2bf83a59cbdc2e04d0a3f041ddf4b0ba4112d49"): digest.Digest("sha256:b4ed95caeb02ffe68cdd9fd84406680ae93d633cb16422d00e8a7c22955b46d4"),
 	}
 	}
 
 
 	manifest, err := CreateV2Manifest("testrepo", "testtag", img, fsLayers)
 	manifest, err := CreateV2Manifest("testrepo", "testtag", img, fsLayers)

+ 23 - 1
distribution/registry.go

@@ -13,10 +13,12 @@ import (
 	"github.com/docker/distribution"
 	"github.com/docker/distribution"
 	"github.com/docker/distribution/digest"
 	"github.com/docker/distribution/digest"
 	"github.com/docker/distribution/manifest/schema1"
 	"github.com/docker/distribution/manifest/schema1"
+	"github.com/docker/distribution/registry/api/errcode"
 	"github.com/docker/distribution/registry/client"
 	"github.com/docker/distribution/registry/client"
 	"github.com/docker/distribution/registry/client/auth"
 	"github.com/docker/distribution/registry/client/auth"
 	"github.com/docker/distribution/registry/client/transport"
 	"github.com/docker/distribution/registry/client/transport"
 	"github.com/docker/docker/cliconfig"
 	"github.com/docker/docker/cliconfig"
+	"github.com/docker/docker/distribution/xfer"
 	"github.com/docker/docker/registry"
 	"github.com/docker/docker/registry"
 	"golang.org/x/net/context"
 	"golang.org/x/net/context"
 )
 )
@@ -59,7 +61,7 @@ func NewV2Repository(repoInfo *registry.RepositoryInfo, endpoint registry.APIEnd
 	authTransport := transport.NewTransport(base, modifiers...)
 	authTransport := transport.NewTransport(base, modifiers...)
 	pingClient := &http.Client{
 	pingClient := &http.Client{
 		Transport: authTransport,
 		Transport: authTransport,
-		Timeout:   5 * time.Second,
+		Timeout:   15 * time.Second,
 	}
 	}
 	endpointStr := strings.TrimRight(endpoint.URL, "/") + "/v2/"
 	endpointStr := strings.TrimRight(endpoint.URL, "/") + "/v2/"
 	req, err := http.NewRequest("GET", endpointStr, nil)
 	req, err := http.NewRequest("GET", endpointStr, nil)
@@ -132,3 +134,23 @@ func (th *existingTokenHandler) AuthorizeRequest(req *http.Request, params map[s
 	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", th.token))
 	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", th.token))
 	return nil
 	return nil
 }
 }
+
+// retryOnError wraps the error in xfer.DoNotRetry if we should not retry the
+// operation after this error.
+func retryOnError(err error) error {
+	switch v := err.(type) {
+	case errcode.Errors:
+		return retryOnError(v[0])
+	case errcode.Error:
+		switch v.Code {
+		case errcode.ErrorCodeUnauthorized, errcode.ErrorCodeUnsupported, errcode.ErrorCodeDenied:
+			return xfer.DoNotRetry{Err: err}
+		}
+
+	}
+	// let's be nice and fallback if the error is a completely
+	// unexpected one.
+	// If new errors have to be handled in some way, please
+	// add them to the switch above.
+	return err
+}

+ 3 - 4
distribution/registry_unit_test.go

@@ -11,9 +11,9 @@ import (
 	"github.com/docker/distribution/reference"
 	"github.com/docker/distribution/reference"
 	"github.com/docker/distribution/registry/client/auth"
 	"github.com/docker/distribution/registry/client/auth"
 	"github.com/docker/docker/cliconfig"
 	"github.com/docker/docker/cliconfig"
-	"github.com/docker/docker/pkg/streamformatter"
 	"github.com/docker/docker/registry"
 	"github.com/docker/docker/registry"
 	"github.com/docker/docker/utils"
 	"github.com/docker/docker/utils"
+	"golang.org/x/net/context"
 )
 )
 
 
 func TestTokenPassThru(t *testing.T) {
 func TestTokenPassThru(t *testing.T) {
@@ -72,8 +72,7 @@ func TestTokenPassThru(t *testing.T) {
 		MetaHeaders: http.Header{},
 		MetaHeaders: http.Header{},
 		AuthConfig:  authConfig,
 		AuthConfig:  authConfig,
 	}
 	}
-	sf := streamformatter.NewJSONStreamFormatter()
-	puller, err := newPuller(endpoint, repoInfo, imagePullConfig, sf)
+	puller, err := newPuller(endpoint, repoInfo, imagePullConfig)
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
@@ -86,7 +85,7 @@ func TestTokenPassThru(t *testing.T) {
 	logrus.Debug("About to pull")
 	logrus.Debug("About to pull")
 	// We expect it to fail, since we haven't mock'd the full registry exchange in our handler above
 	// We expect it to fail, since we haven't mock'd the full registry exchange in our handler above
 	tag, _ := reference.WithTag(n, "tag_goes_here")
 	tag, _ := reference.WithTag(n, "tag_goes_here")
-	_ = p.pullV2Repository(tag)
+	_ = p.pullV2Repository(context.Background(), tag)
 
 
 	if !gotToken {
 	if !gotToken {
 		t.Fatal("Failed to receive registry token")
 		t.Fatal("Failed to receive registry token")

+ 420 - 0
distribution/xfer/download.go

@@ -0,0 +1,420 @@
+package xfer
+
+import (
+	"errors"
+	"fmt"
+	"io"
+	"time"
+
+	"github.com/Sirupsen/logrus"
+	"github.com/docker/docker/image"
+	"github.com/docker/docker/layer"
+	"github.com/docker/docker/pkg/archive"
+	"github.com/docker/docker/pkg/ioutils"
+	"github.com/docker/docker/pkg/progress"
+	"golang.org/x/net/context"
+)
+
+const maxDownloadAttempts = 5
+
+// LayerDownloadManager figures out which layers need to be downloaded, then
+// registers and downloads those, taking into account dependencies between
+// layers.
+type LayerDownloadManager struct {
+	layerStore layer.Store
+	tm         TransferManager
+}
+
+// NewLayerDownloadManager returns a new LayerDownloadManager.
+func NewLayerDownloadManager(layerStore layer.Store, concurrencyLimit int) *LayerDownloadManager {
+	return &LayerDownloadManager{
+		layerStore: layerStore,
+		tm:         NewTransferManager(concurrencyLimit),
+	}
+}
+
+type downloadTransfer struct {
+	Transfer
+
+	layerStore layer.Store
+	layer      layer.Layer
+	err        error
+}
+
+// result returns the layer resulting from the download, if the download
+// and registration were successful.
+func (d *downloadTransfer) result() (layer.Layer, error) {
+	return d.layer, d.err
+}
+
+// A DownloadDescriptor references a layer that may need to be downloaded.
+type DownloadDescriptor interface {
+	// Key returns the key used to deduplicate downloads.
+	Key() string
+	// ID returns the ID for display purposes.
+	ID() string
+	// DiffID should return the DiffID for this layer, or an error
+	// if it is unknown (for example, if it has not been downloaded
+	// before).
+	DiffID() (layer.DiffID, error)
+	// Download is called to perform the download.
+	Download(ctx context.Context, progressOutput progress.Output) (io.ReadCloser, int64, error)
+}
+
+// DownloadDescriptorWithRegistered is a DownloadDescriptor that has an
+// additional Registered method which gets called after a downloaded layer is
+// registered. This allows the user of the download manager to know the DiffID
+// of each registered layer. This method is called if a cast to
+// DownloadDescriptorWithRegistered is successful.
+type DownloadDescriptorWithRegistered interface {
+	DownloadDescriptor
+	Registered(diffID layer.DiffID)
+}
+
+// Download is a blocking function which ensures the requested layers are
+// present in the layer store. It uses the string returned by the Key method to
+// deduplicate downloads. If a given layer is not already known to present in
+// the layer store, and the key is not used by an in-progress download, the
+// Download method is called to get the layer tar data. Layers are then
+// registered in the appropriate order.  The caller must call the returned
+// release function once it is is done with the returned RootFS object.
+func (ldm *LayerDownloadManager) Download(ctx context.Context, initialRootFS image.RootFS, layers []DownloadDescriptor, progressOutput progress.Output) (image.RootFS, func(), error) {
+	var (
+		topLayer       layer.Layer
+		topDownload    *downloadTransfer
+		watcher        *Watcher
+		missingLayer   bool
+		transferKey    = ""
+		downloadsByKey = make(map[string]*downloadTransfer)
+	)
+
+	rootFS := initialRootFS
+	for _, descriptor := range layers {
+		key := descriptor.Key()
+		transferKey += key
+
+		if !missingLayer {
+			missingLayer = true
+			diffID, err := descriptor.DiffID()
+			if err == nil {
+				getRootFS := rootFS
+				getRootFS.Append(diffID)
+				l, err := ldm.layerStore.Get(getRootFS.ChainID())
+				if err == nil {
+					// Layer already exists.
+					logrus.Debugf("Layer already exists: %s", descriptor.ID())
+					progress.Update(progressOutput, descriptor.ID(), "Already exists")
+					if topLayer != nil {
+						layer.ReleaseAndLog(ldm.layerStore, topLayer)
+					}
+					topLayer = l
+					missingLayer = false
+					rootFS.Append(diffID)
+					continue
+				}
+			}
+		}
+
+		// Does this layer have the same data as a previous layer in
+		// the stack? If so, avoid downloading it more than once.
+		var topDownloadUncasted Transfer
+		if existingDownload, ok := downloadsByKey[key]; ok {
+			xferFunc := ldm.makeDownloadFuncFromDownload(descriptor, existingDownload, topDownload)
+			defer topDownload.Transfer.Release(watcher)
+			topDownloadUncasted, watcher = ldm.tm.Transfer(transferKey, xferFunc, progressOutput)
+			topDownload = topDownloadUncasted.(*downloadTransfer)
+			continue
+		}
+
+		// Layer is not known to exist - download and register it.
+		progress.Update(progressOutput, descriptor.ID(), "Pulling fs layer")
+
+		var xferFunc DoFunc
+		if topDownload != nil {
+			xferFunc = ldm.makeDownloadFunc(descriptor, "", topDownload)
+			defer topDownload.Transfer.Release(watcher)
+		} else {
+			xferFunc = ldm.makeDownloadFunc(descriptor, rootFS.ChainID(), nil)
+		}
+		topDownloadUncasted, watcher = ldm.tm.Transfer(transferKey, xferFunc, progressOutput)
+		topDownload = topDownloadUncasted.(*downloadTransfer)
+		downloadsByKey[key] = topDownload
+	}
+
+	if topDownload == nil {
+		return rootFS, func() { layer.ReleaseAndLog(ldm.layerStore, topLayer) }, nil
+	}
+
+	// Won't be using the list built up so far - will generate it
+	// from downloaded layers instead.
+	rootFS.DiffIDs = []layer.DiffID{}
+
+	defer func() {
+		if topLayer != nil {
+			layer.ReleaseAndLog(ldm.layerStore, topLayer)
+		}
+	}()
+
+	select {
+	case <-ctx.Done():
+		topDownload.Transfer.Release(watcher)
+		return rootFS, func() {}, ctx.Err()
+	case <-topDownload.Done():
+		break
+	}
+
+	l, err := topDownload.result()
+	if err != nil {
+		topDownload.Transfer.Release(watcher)
+		return rootFS, func() {}, err
+	}
+
+	// Must do this exactly len(layers) times, so we don't include the
+	// base layer on Windows.
+	for range layers {
+		if l == nil {
+			topDownload.Transfer.Release(watcher)
+			return rootFS, func() {}, errors.New("internal error: too few parent layers")
+		}
+		rootFS.DiffIDs = append([]layer.DiffID{l.DiffID()}, rootFS.DiffIDs...)
+		l = l.Parent()
+	}
+	return rootFS, func() { topDownload.Transfer.Release(watcher) }, err
+}
+
+// makeDownloadFunc returns a function that performs the layer download and
+// registration. If parentDownload is non-nil, it waits for that download to
+// complete before the registration step, and registers the downloaded data
+// on top of parentDownload's resulting layer. Otherwise, it registers the
+// layer on top of the ChainID given by parentLayer.
+func (ldm *LayerDownloadManager) makeDownloadFunc(descriptor DownloadDescriptor, parentLayer layer.ChainID, parentDownload *downloadTransfer) DoFunc {
+	return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer {
+		d := &downloadTransfer{
+			Transfer:   NewTransfer(),
+			layerStore: ldm.layerStore,
+		}
+
+		go func() {
+			defer func() {
+				close(progressChan)
+			}()
+
+			progressOutput := progress.ChanOutput(progressChan)
+
+			select {
+			case <-start:
+			default:
+				progress.Update(progressOutput, descriptor.ID(), "Waiting")
+				<-start
+			}
+
+			if parentDownload != nil {
+				// Did the parent download already fail or get
+				// cancelled?
+				select {
+				case <-parentDownload.Done():
+					_, err := parentDownload.result()
+					if err != nil {
+						d.err = err
+						return
+					}
+				default:
+				}
+			}
+
+			var (
+				downloadReader io.ReadCloser
+				size           int64
+				err            error
+				retries        int
+			)
+
+			for {
+				downloadReader, size, err = descriptor.Download(d.Transfer.Context(), progressOutput)
+				if err == nil {
+					break
+				}
+
+				// If an error was returned because the context
+				// was cancelled, we shouldn't retry.
+				select {
+				case <-d.Transfer.Context().Done():
+					d.err = err
+					return
+				default:
+				}
+
+				retries++
+				if _, isDNR := err.(DoNotRetry); isDNR || retries == maxDownloadAttempts {
+					logrus.Errorf("Download failed: %v", err)
+					d.err = err
+					return
+				}
+
+				logrus.Errorf("Download failed, retrying: %v", err)
+				delay := retries * 5
+				ticker := time.NewTicker(time.Second)
+
+			selectLoop:
+				for {
+					progress.Updatef(progressOutput, descriptor.ID(), "Retrying in %d seconds", delay)
+					select {
+					case <-ticker.C:
+						delay--
+						if delay == 0 {
+							ticker.Stop()
+							break selectLoop
+						}
+					case <-d.Transfer.Context().Done():
+						ticker.Stop()
+						d.err = errors.New("download cancelled during retry delay")
+						return
+					}
+
+				}
+			}
+
+			close(inactive)
+
+			if parentDownload != nil {
+				select {
+				case <-d.Transfer.Context().Done():
+					d.err = errors.New("layer registration cancelled")
+					downloadReader.Close()
+					return
+				case <-parentDownload.Done():
+				}
+
+				l, err := parentDownload.result()
+				if err != nil {
+					d.err = err
+					downloadReader.Close()
+					return
+				}
+				parentLayer = l.ChainID()
+			}
+
+			reader := progress.NewProgressReader(ioutils.NewCancelReadCloser(d.Transfer.Context(), downloadReader), progressOutput, size, descriptor.ID(), "Extracting")
+			defer reader.Close()
+
+			inflatedLayerData, err := archive.DecompressStream(reader)
+			if err != nil {
+				d.err = fmt.Errorf("could not get decompression stream: %v", err)
+				return
+			}
+
+			d.layer, err = d.layerStore.Register(inflatedLayerData, parentLayer)
+			if err != nil {
+				select {
+				case <-d.Transfer.Context().Done():
+					d.err = errors.New("layer registration cancelled")
+				default:
+					d.err = fmt.Errorf("failed to register layer: %v", err)
+				}
+				return
+			}
+
+			progress.Update(progressOutput, descriptor.ID(), "Pull complete")
+			withRegistered, hasRegistered := descriptor.(DownloadDescriptorWithRegistered)
+			if hasRegistered {
+				withRegistered.Registered(d.layer.DiffID())
+			}
+
+			// Doesn't actually need to be its own goroutine, but
+			// done like this so we can defer close(c).
+			go func() {
+				<-d.Transfer.Released()
+				if d.layer != nil {
+					layer.ReleaseAndLog(d.layerStore, d.layer)
+				}
+			}()
+		}()
+
+		return d
+	}
+}
+
+// makeDownloadFuncFromDownload returns a function that performs the layer
+// registration when the layer data is coming from an existing download. It
+// waits for sourceDownload and parentDownload to complete, and then
+// reregisters the data from sourceDownload's top layer on top of
+// parentDownload. This function does not log progress output because it would
+// interfere with the progress reporting for sourceDownload, which has the same
+// Key.
+func (ldm *LayerDownloadManager) makeDownloadFuncFromDownload(descriptor DownloadDescriptor, sourceDownload *downloadTransfer, parentDownload *downloadTransfer) DoFunc {
+	return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer {
+		d := &downloadTransfer{
+			Transfer:   NewTransfer(),
+			layerStore: ldm.layerStore,
+		}
+
+		go func() {
+			defer func() {
+				close(progressChan)
+			}()
+
+			<-start
+
+			close(inactive)
+
+			select {
+			case <-d.Transfer.Context().Done():
+				d.err = errors.New("layer registration cancelled")
+				return
+			case <-parentDownload.Done():
+			}
+
+			l, err := parentDownload.result()
+			if err != nil {
+				d.err = err
+				return
+			}
+			parentLayer := l.ChainID()
+
+			// sourceDownload should have already finished if
+			// parentDownload finished, but wait for it explicitly
+			// to be sure.
+			select {
+			case <-d.Transfer.Context().Done():
+				d.err = errors.New("layer registration cancelled")
+				return
+			case <-sourceDownload.Done():
+			}
+
+			l, err = sourceDownload.result()
+			if err != nil {
+				d.err = err
+				return
+			}
+
+			layerReader, err := l.TarStream()
+			if err != nil {
+				d.err = err
+				return
+			}
+			defer layerReader.Close()
+
+			d.layer, err = d.layerStore.Register(layerReader, parentLayer)
+			if err != nil {
+				d.err = fmt.Errorf("failed to register layer: %v", err)
+				return
+			}
+
+			withRegistered, hasRegistered := descriptor.(DownloadDescriptorWithRegistered)
+			if hasRegistered {
+				withRegistered.Registered(d.layer.DiffID())
+			}
+
+			// Doesn't actually need to be its own goroutine, but
+			// done like this so we can defer close(c).
+			go func() {
+				<-d.Transfer.Released()
+				if d.layer != nil {
+					layer.ReleaseAndLog(d.layerStore, d.layer)
+				}
+			}()
+		}()
+
+		return d
+	}
+}

+ 332 - 0
distribution/xfer/download_test.go

@@ -0,0 +1,332 @@
+package xfer
+
+import (
+	"bytes"
+	"errors"
+	"io"
+	"io/ioutil"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/docker/distribution/digest"
+	"github.com/docker/docker/image"
+	"github.com/docker/docker/layer"
+	"github.com/docker/docker/pkg/archive"
+	"github.com/docker/docker/pkg/progress"
+	"golang.org/x/net/context"
+)
+
+const maxDownloadConcurrency = 3
+
+type mockLayer struct {
+	layerData bytes.Buffer
+	diffID    layer.DiffID
+	chainID   layer.ChainID
+	parent    layer.Layer
+}
+
+func (ml *mockLayer) TarStream() (io.ReadCloser, error) {
+	return ioutil.NopCloser(bytes.NewBuffer(ml.layerData.Bytes())), nil
+}
+
+func (ml *mockLayer) ChainID() layer.ChainID {
+	return ml.chainID
+}
+
+func (ml *mockLayer) DiffID() layer.DiffID {
+	return ml.diffID
+}
+
+func (ml *mockLayer) Parent() layer.Layer {
+	return ml.parent
+}
+
+func (ml *mockLayer) Size() (size int64, err error) {
+	return 0, nil
+}
+
+func (ml *mockLayer) DiffSize() (size int64, err error) {
+	return 0, nil
+}
+
+func (ml *mockLayer) Metadata() (map[string]string, error) {
+	return make(map[string]string), nil
+}
+
+type mockLayerStore struct {
+	layers map[layer.ChainID]*mockLayer
+}
+
+func createChainIDFromParent(parent layer.ChainID, dgsts ...layer.DiffID) layer.ChainID {
+	if len(dgsts) == 0 {
+		return parent
+	}
+	if parent == "" {
+		return createChainIDFromParent(layer.ChainID(dgsts[0]), dgsts[1:]...)
+	}
+	// H = "H(n-1) SHA256(n)"
+	dgst, err := digest.FromBytes([]byte(string(parent) + " " + string(dgsts[0])))
+	if err != nil {
+		// Digest calculation is not expected to throw an error,
+		// any error at this point is a program error
+		panic(err)
+	}
+	return createChainIDFromParent(layer.ChainID(dgst), dgsts[1:]...)
+}
+
+func (ls *mockLayerStore) Register(reader io.Reader, parentID layer.ChainID) (layer.Layer, error) {
+	var (
+		parent layer.Layer
+		err    error
+	)
+
+	if parentID != "" {
+		parent, err = ls.Get(parentID)
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	l := &mockLayer{parent: parent}
+	_, err = l.layerData.ReadFrom(reader)
+	if err != nil {
+		return nil, err
+	}
+	diffID, err := digest.FromBytes(l.layerData.Bytes())
+	if err != nil {
+		return nil, err
+	}
+	l.diffID = layer.DiffID(diffID)
+	l.chainID = createChainIDFromParent(parentID, l.diffID)
+
+	ls.layers[l.chainID] = l
+	return l, nil
+}
+
+func (ls *mockLayerStore) Get(chainID layer.ChainID) (layer.Layer, error) {
+	l, ok := ls.layers[chainID]
+	if !ok {
+		return nil, layer.ErrLayerDoesNotExist
+	}
+	return l, nil
+}
+
+func (ls *mockLayerStore) Release(l layer.Layer) ([]layer.Metadata, error) {
+	return []layer.Metadata{}, nil
+}
+
+func (ls *mockLayerStore) Mount(id string, parent layer.ChainID, label string, init layer.MountInit) (layer.RWLayer, error) {
+	return nil, errors.New("not implemented")
+}
+
+func (ls *mockLayerStore) Unmount(id string) error {
+	return errors.New("not implemented")
+}
+
+func (ls *mockLayerStore) DeleteMount(id string) ([]layer.Metadata, error) {
+	return nil, errors.New("not implemented")
+}
+
+func (ls *mockLayerStore) Changes(id string) ([]archive.Change, error) {
+	return nil, errors.New("not implemented")
+}
+
+type mockDownloadDescriptor struct {
+	currentDownloads *int32
+	id               string
+	diffID           layer.DiffID
+	registeredDiffID layer.DiffID
+	expectedDiffID   layer.DiffID
+	simulateRetries  int
+}
+
+// Key returns the key used to deduplicate downloads.
+func (d *mockDownloadDescriptor) Key() string {
+	return d.id
+}
+
+// ID returns the ID for display purposes.
+func (d *mockDownloadDescriptor) ID() string {
+	return d.id
+}
+
+// DiffID should return the DiffID for this layer, or an error
+// if it is unknown (for example, if it has not been downloaded
+// before).
+func (d *mockDownloadDescriptor) DiffID() (layer.DiffID, error) {
+	if d.diffID != "" {
+		return d.diffID, nil
+	}
+	return "", errors.New("no diffID available")
+}
+
+func (d *mockDownloadDescriptor) Registered(diffID layer.DiffID) {
+	d.registeredDiffID = diffID
+}
+
+func (d *mockDownloadDescriptor) mockTarStream() io.ReadCloser {
+	// The mock implementation returns the ID repeated 5 times as a tar
+	// stream instead of actual tar data. The data is ignored except for
+	// computing IDs.
+	return ioutil.NopCloser(bytes.NewBuffer([]byte(d.id + d.id + d.id + d.id + d.id)))
+}
+
+// Download is called to perform the download.
+func (d *mockDownloadDescriptor) Download(ctx context.Context, progressOutput progress.Output) (io.ReadCloser, int64, error) {
+	if d.currentDownloads != nil {
+		defer atomic.AddInt32(d.currentDownloads, -1)
+
+		if atomic.AddInt32(d.currentDownloads, 1) > maxDownloadConcurrency {
+			return nil, 0, errors.New("concurrency limit exceeded")
+		}
+	}
+
+	// Sleep a bit to simulate a time-consuming download.
+	for i := int64(0); i <= 10; i++ {
+		select {
+		case <-ctx.Done():
+			return nil, 0, ctx.Err()
+		case <-time.After(10 * time.Millisecond):
+			progressOutput.WriteProgress(progress.Progress{ID: d.ID(), Action: "Downloading", Current: i, Total: 10})
+		}
+	}
+
+	if d.simulateRetries != 0 {
+		d.simulateRetries--
+		return nil, 0, errors.New("simulating retry")
+	}
+
+	return d.mockTarStream(), 0, nil
+}
+
+func downloadDescriptors(currentDownloads *int32) []DownloadDescriptor {
+	return []DownloadDescriptor{
+		&mockDownloadDescriptor{
+			currentDownloads: currentDownloads,
+			id:               "id1",
+			expectedDiffID:   layer.DiffID("sha256:68e2c75dc5c78ea9240689c60d7599766c213ae210434c53af18470ae8c53ec1"),
+		},
+		&mockDownloadDescriptor{
+			currentDownloads: currentDownloads,
+			id:               "id2",
+			expectedDiffID:   layer.DiffID("sha256:64a636223116aa837973a5d9c2bdd17d9b204e4f95ac423e20e65dfbb3655473"),
+		},
+		&mockDownloadDescriptor{
+			currentDownloads: currentDownloads,
+			id:               "id3",
+			expectedDiffID:   layer.DiffID("sha256:58745a8bbd669c25213e9de578c4da5c8ee1c836b3581432c2b50e38a6753300"),
+		},
+		&mockDownloadDescriptor{
+			currentDownloads: currentDownloads,
+			id:               "id2",
+			expectedDiffID:   layer.DiffID("sha256:64a636223116aa837973a5d9c2bdd17d9b204e4f95ac423e20e65dfbb3655473"),
+		},
+		&mockDownloadDescriptor{
+			currentDownloads: currentDownloads,
+			id:               "id4",
+			expectedDiffID:   layer.DiffID("sha256:0dfb5b9577716cc173e95af7c10289322c29a6453a1718addc00c0c5b1330936"),
+			simulateRetries:  1,
+		},
+		&mockDownloadDescriptor{
+			currentDownloads: currentDownloads,
+			id:               "id5",
+			expectedDiffID:   layer.DiffID("sha256:0a5f25fa1acbc647f6112a6276735d0fa01e4ee2aa7ec33015e337350e1ea23d"),
+		},
+	}
+}
+
+func TestSuccessfulDownload(t *testing.T) {
+	layerStore := &mockLayerStore{make(map[layer.ChainID]*mockLayer)}
+	ldm := NewLayerDownloadManager(layerStore, maxDownloadConcurrency)
+
+	progressChan := make(chan progress.Progress)
+	progressDone := make(chan struct{})
+	receivedProgress := make(map[string]int64)
+
+	go func() {
+		for p := range progressChan {
+			if p.Action == "Downloading" {
+				receivedProgress[p.ID] = p.Current
+			} else if p.Action == "Already exists" {
+				receivedProgress[p.ID] = -1
+			}
+		}
+		close(progressDone)
+	}()
+
+	var currentDownloads int32
+	descriptors := downloadDescriptors(&currentDownloads)
+
+	firstDescriptor := descriptors[0].(*mockDownloadDescriptor)
+
+	// Pre-register the first layer to simulate an already-existing layer
+	l, err := layerStore.Register(firstDescriptor.mockTarStream(), "")
+	if err != nil {
+		t.Fatal(err)
+	}
+	firstDescriptor.diffID = l.DiffID()
+
+	rootFS, releaseFunc, err := ldm.Download(context.Background(), *image.NewRootFS(), descriptors, progress.ChanOutput(progressChan))
+	if err != nil {
+		t.Fatalf("download error: %v", err)
+	}
+
+	releaseFunc()
+
+	close(progressChan)
+	<-progressDone
+
+	if len(rootFS.DiffIDs) != len(descriptors) {
+		t.Fatal("got wrong number of diffIDs in rootfs")
+	}
+
+	for i, d := range descriptors {
+		descriptor := d.(*mockDownloadDescriptor)
+
+		if descriptor.diffID != "" {
+			if receivedProgress[d.ID()] != -1 {
+				t.Fatalf("did not get 'already exists' message for %v", d.ID())
+			}
+		} else if receivedProgress[d.ID()] != 10 {
+			t.Fatalf("missing or wrong progress output for %v (got: %d)", d.ID(), receivedProgress[d.ID()])
+		}
+
+		if rootFS.DiffIDs[i] != descriptor.expectedDiffID {
+			t.Fatalf("rootFS item %d has the wrong diffID (expected: %v got: %v)", i, descriptor.expectedDiffID, rootFS.DiffIDs[i])
+		}
+
+		if descriptor.diffID == "" && descriptor.registeredDiffID != rootFS.DiffIDs[i] {
+			t.Fatal("diffID mismatch between rootFS and Registered callback")
+		}
+	}
+}
+
+func TestCancelledDownload(t *testing.T) {
+	ldm := NewLayerDownloadManager(&mockLayerStore{make(map[layer.ChainID]*mockLayer)}, maxDownloadConcurrency)
+
+	progressChan := make(chan progress.Progress)
+	progressDone := make(chan struct{})
+
+	go func() {
+		for range progressChan {
+		}
+		close(progressDone)
+	}()
+
+	ctx, cancel := context.WithCancel(context.Background())
+
+	go func() {
+		<-time.After(time.Millisecond)
+		cancel()
+	}()
+
+	descriptors := downloadDescriptors(nil)
+	_, _, err := ldm.Download(ctx, *image.NewRootFS(), descriptors, progress.ChanOutput(progressChan))
+	if err != context.Canceled {
+		t.Fatal("expected download to be cancelled")
+	}
+
+	close(progressChan)
+	<-progressDone
+}

+ 343 - 0
distribution/xfer/transfer.go

@@ -0,0 +1,343 @@
+package xfer
+
+import (
+	"sync"
+
+	"github.com/docker/docker/pkg/progress"
+	"golang.org/x/net/context"
+)
+
+// DoNotRetry is an error wrapper indicating that the error cannot be resolved
+// with a retry.
+type DoNotRetry struct {
+	Err error
+}
+
+// Error returns the stringified representation of the encapsulated error.
+func (e DoNotRetry) Error() string {
+	return e.Err.Error()
+}
+
+// Watcher is returned by Watch and can be passed to Release to stop watching.
+type Watcher struct {
+	// signalChan is used to signal to the watcher goroutine that
+	// new progress information is available, or that the transfer
+	// has finished.
+	signalChan chan struct{}
+	// releaseChan signals to the watcher goroutine that the watcher
+	// should be detached.
+	releaseChan chan struct{}
+	// running remains open as long as the watcher is watching the
+	// transfer. It gets closed if the transfer finishes or the
+	// watcher is detached.
+	running chan struct{}
+}
+
+// Transfer represents an in-progress transfer.
+type Transfer interface {
+	Watch(progressOutput progress.Output) *Watcher
+	Release(*Watcher)
+	Context() context.Context
+	Cancel()
+	Done() <-chan struct{}
+	Released() <-chan struct{}
+	Broadcast(masterProgressChan <-chan progress.Progress)
+}
+
+type transfer struct {
+	mu sync.Mutex
+
+	ctx    context.Context
+	cancel context.CancelFunc
+
+	// watchers keeps track of the goroutines monitoring progress output,
+	// indexed by the channels that release them.
+	watchers map[chan struct{}]*Watcher
+
+	// lastProgress is the most recently received progress event.
+	lastProgress progress.Progress
+	// hasLastProgress is true when lastProgress has been set.
+	hasLastProgress bool
+
+	// running remains open as long as the transfer is in progress.
+	running chan struct{}
+	// hasWatchers stays open until all watchers release the trasnfer.
+	hasWatchers chan struct{}
+
+	// broadcastDone is true if the master progress channel has closed.
+	broadcastDone bool
+	// broadcastSyncChan allows watchers to "ping" the broadcasting
+	// goroutine to wait for it for deplete its input channel. This ensures
+	// a detaching watcher won't miss an event that was sent before it
+	// started detaching.
+	broadcastSyncChan chan struct{}
+}
+
+// NewTransfer creates a new transfer.
+func NewTransfer() Transfer {
+	t := &transfer{
+		watchers:          make(map[chan struct{}]*Watcher),
+		running:           make(chan struct{}),
+		hasWatchers:       make(chan struct{}),
+		broadcastSyncChan: make(chan struct{}),
+	}
+
+	// This uses context.Background instead of a caller-supplied context
+	// so that a transfer won't be cancelled automatically if the client
+	// which requested it is ^C'd (there could be other viewers).
+	t.ctx, t.cancel = context.WithCancel(context.Background())
+
+	return t
+}
+
+// Broadcast copies the progress and error output to all viewers.
+func (t *transfer) Broadcast(masterProgressChan <-chan progress.Progress) {
+	for {
+		var (
+			p  progress.Progress
+			ok bool
+		)
+		select {
+		case p, ok = <-masterProgressChan:
+		default:
+			// We've depleted the channel, so now we can handle
+			// reads on broadcastSyncChan to let detaching watchers
+			// know we're caught up.
+			select {
+			case <-t.broadcastSyncChan:
+				continue
+			case p, ok = <-masterProgressChan:
+			}
+		}
+
+		t.mu.Lock()
+		if ok {
+			t.lastProgress = p
+			t.hasLastProgress = true
+			for _, w := range t.watchers {
+				select {
+				case w.signalChan <- struct{}{}:
+				default:
+				}
+			}
+
+		} else {
+			t.broadcastDone = true
+		}
+		t.mu.Unlock()
+		if !ok {
+			close(t.running)
+			return
+		}
+	}
+}
+
+// Watch adds a watcher to the transfer. The supplied channel gets progress
+// updates and is closed when the transfer finishes.
+func (t *transfer) Watch(progressOutput progress.Output) *Watcher {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+
+	w := &Watcher{
+		releaseChan: make(chan struct{}),
+		signalChan:  make(chan struct{}),
+		running:     make(chan struct{}),
+	}
+
+	if t.broadcastDone {
+		close(w.running)
+		return w
+	}
+
+	t.watchers[w.releaseChan] = w
+
+	go func() {
+		defer func() {
+			close(w.running)
+		}()
+		done := false
+		for {
+			t.mu.Lock()
+			hasLastProgress := t.hasLastProgress
+			lastProgress := t.lastProgress
+			t.mu.Unlock()
+
+			// This might write the last progress item a
+			// second time (since channel closure also gets
+			// us here), but that's fine.
+			if hasLastProgress {
+				progressOutput.WriteProgress(lastProgress)
+			}
+
+			if done {
+				return
+			}
+
+			select {
+			case <-w.signalChan:
+			case <-w.releaseChan:
+				done = true
+				// Since the watcher is going to detach, make
+				// sure the broadcaster is caught up so we
+				// don't miss anything.
+				select {
+				case t.broadcastSyncChan <- struct{}{}:
+				case <-t.running:
+				}
+			case <-t.running:
+				done = true
+			}
+		}
+	}()
+
+	return w
+}
+
+// Release is the inverse of Watch; indicating that the watcher no longer wants
+// to be notified about the progress of the transfer. All calls to Watch must
+// be paired with later calls to Release so that the lifecycle of the transfer
+// is properly managed.
+func (t *transfer) Release(watcher *Watcher) {
+	t.mu.Lock()
+	delete(t.watchers, watcher.releaseChan)
+
+	if len(t.watchers) == 0 {
+		close(t.hasWatchers)
+		t.cancel()
+	}
+	t.mu.Unlock()
+
+	close(watcher.releaseChan)
+	// Block until the watcher goroutine completes
+	<-watcher.running
+}
+
+// Done returns a channel which is closed if the transfer completes or is
+// cancelled. Note that having 0 watchers causes a transfer to be cancelled.
+func (t *transfer) Done() <-chan struct{} {
+	// Note that this doesn't return t.ctx.Done() because that channel will
+	// be closed the moment Cancel is called, and we need to return a
+	// channel that blocks until a cancellation is actually acknowledged by
+	// the transfer function.
+	return t.running
+}
+
+// Released returns a channel which is closed once all watchers release the
+// transfer.
+func (t *transfer) Released() <-chan struct{} {
+	return t.hasWatchers
+}
+
+// Context returns the context associated with the transfer.
+func (t *transfer) Context() context.Context {
+	return t.ctx
+}
+
+// Cancel cancels the context associated with the transfer.
+func (t *transfer) Cancel() {
+	t.cancel()
+}
+
+// DoFunc is a function called by the transfer manager to actually perform
+// a transfer. It should be non-blocking. It should wait until the start channel
+// is closed before transfering any data. If the function closes inactive, that
+// signals to the transfer manager that the job is no longer actively moving
+// data - for example, it may be waiting for a dependent tranfer to finish.
+// This prevents it from taking up a slot.
+type DoFunc func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer
+
+// TransferManager is used by LayerDownloadManager and LayerUploadManager to
+// schedule and deduplicate transfers. It is up to the TransferManager
+// implementation to make the scheduling and concurrency decisions.
+type TransferManager interface {
+	// Transfer checks if a transfer with the given key is in progress. If
+	// so, it returns progress and error output from that transfer.
+	// Otherwise, it will call xferFunc to initiate the transfer.
+	Transfer(key string, xferFunc DoFunc, progressOutput progress.Output) (Transfer, *Watcher)
+}
+
+type transferManager struct {
+	mu sync.Mutex
+
+	concurrencyLimit int
+	activeTransfers  int
+	transfers        map[string]Transfer
+	waitingTransfers []chan struct{}
+}
+
+// NewTransferManager returns a new TransferManager.
+func NewTransferManager(concurrencyLimit int) TransferManager {
+	return &transferManager{
+		concurrencyLimit: concurrencyLimit,
+		transfers:        make(map[string]Transfer),
+	}
+}
+
+// Transfer checks if a transfer matching the given key is in progress. If not,
+// it starts one by calling xferFunc. The caller supplies a channel which
+// receives progress output from the transfer.
+func (tm *transferManager) Transfer(key string, xferFunc DoFunc, progressOutput progress.Output) (Transfer, *Watcher) {
+	tm.mu.Lock()
+	defer tm.mu.Unlock()
+
+	if xfer, present := tm.transfers[key]; present {
+		// Transfer is already in progress.
+		watcher := xfer.Watch(progressOutput)
+		return xfer, watcher
+	}
+
+	start := make(chan struct{})
+	inactive := make(chan struct{})
+
+	if tm.activeTransfers < tm.concurrencyLimit {
+		close(start)
+		tm.activeTransfers++
+	} else {
+		tm.waitingTransfers = append(tm.waitingTransfers, start)
+	}
+
+	masterProgressChan := make(chan progress.Progress)
+	xfer := xferFunc(masterProgressChan, start, inactive)
+	watcher := xfer.Watch(progressOutput)
+	go xfer.Broadcast(masterProgressChan)
+	tm.transfers[key] = xfer
+
+	// When the transfer is finished, remove from the map.
+	go func() {
+		for {
+			select {
+			case <-inactive:
+				tm.mu.Lock()
+				tm.inactivate(start)
+				tm.mu.Unlock()
+				inactive = nil
+			case <-xfer.Done():
+				tm.mu.Lock()
+				if inactive != nil {
+					tm.inactivate(start)
+				}
+				delete(tm.transfers, key)
+				tm.mu.Unlock()
+				return
+			}
+		}
+	}()
+
+	return xfer, watcher
+}
+
+func (tm *transferManager) inactivate(start chan struct{}) {
+	// If the transfer was started, remove it from the activeTransfers
+	// count.
+	select {
+	case <-start:
+		// Start next transfer if any are waiting
+		if len(tm.waitingTransfers) != 0 {
+			close(tm.waitingTransfers[0])
+			tm.waitingTransfers = tm.waitingTransfers[1:]
+		} else {
+			tm.activeTransfers--
+		}
+	default:
+	}
+}

+ 385 - 0
distribution/xfer/transfer_test.go

@@ -0,0 +1,385 @@
+package xfer
+
+import (
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/docker/docker/pkg/progress"
+)
+
+func TestTransfer(t *testing.T) {
+	makeXferFunc := func(id string) DoFunc {
+		return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer {
+			select {
+			case <-start:
+			default:
+				t.Fatalf("transfer function not started even though concurrency limit not reached")
+			}
+
+			xfer := NewTransfer()
+			go func() {
+				for i := 0; i <= 10; i++ {
+					progressChan <- progress.Progress{ID: id, Action: "testing", Current: int64(i), Total: 10}
+					time.Sleep(10 * time.Millisecond)
+				}
+				close(progressChan)
+			}()
+			return xfer
+		}
+	}
+
+	tm := NewTransferManager(5)
+	progressChan := make(chan progress.Progress)
+	progressDone := make(chan struct{})
+	receivedProgress := make(map[string]int64)
+
+	go func() {
+		for p := range progressChan {
+			val, present := receivedProgress[p.ID]
+			if !present {
+				if p.Current != 0 {
+					t.Fatalf("got unexpected progress value: %d (expected 0)", p.Current)
+				}
+			} else if p.Current == 10 {
+				// Special case: last progress output may be
+				// repeated because the transfer finishing
+				// causes the latest progress output to be
+				// written to the channel (in case the watcher
+				// missed it).
+				if p.Current != 9 && p.Current != 10 {
+					t.Fatalf("got unexpected progress value: %d (expected %d)", p.Current, val+1)
+				}
+			} else if p.Current != val+1 {
+				t.Fatalf("got unexpected progress value: %d (expected %d)", p.Current, val+1)
+			}
+			receivedProgress[p.ID] = p.Current
+		}
+		close(progressDone)
+	}()
+
+	// Start a few transfers
+	ids := []string{"id1", "id2", "id3"}
+	xfers := make([]Transfer, len(ids))
+	watchers := make([]*Watcher, len(ids))
+	for i, id := range ids {
+		xfers[i], watchers[i] = tm.Transfer(id, makeXferFunc(id), progress.ChanOutput(progressChan))
+	}
+
+	for i, xfer := range xfers {
+		<-xfer.Done()
+		xfer.Release(watchers[i])
+	}
+	close(progressChan)
+	<-progressDone
+
+	for _, id := range ids {
+		if receivedProgress[id] != 10 {
+			t.Fatalf("final progress value %d instead of 10", receivedProgress[id])
+		}
+	}
+}
+
+func TestConcurrencyLimit(t *testing.T) {
+	concurrencyLimit := 3
+	var runningJobs int32
+
+	makeXferFunc := func(id string) DoFunc {
+		return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer {
+			xfer := NewTransfer()
+			go func() {
+				<-start
+				totalJobs := atomic.AddInt32(&runningJobs, 1)
+				if int(totalJobs) > concurrencyLimit {
+					t.Fatalf("too many jobs running")
+				}
+				for i := 0; i <= 10; i++ {
+					progressChan <- progress.Progress{ID: id, Action: "testing", Current: int64(i), Total: 10}
+					time.Sleep(10 * time.Millisecond)
+				}
+				atomic.AddInt32(&runningJobs, -1)
+				close(progressChan)
+			}()
+			return xfer
+		}
+	}
+
+	tm := NewTransferManager(concurrencyLimit)
+	progressChan := make(chan progress.Progress)
+	progressDone := make(chan struct{})
+	receivedProgress := make(map[string]int64)
+
+	go func() {
+		for p := range progressChan {
+			receivedProgress[p.ID] = p.Current
+		}
+		close(progressDone)
+	}()
+
+	// Start more transfers than the concurrency limit
+	ids := []string{"id1", "id2", "id3", "id4", "id5", "id6", "id7", "id8"}
+	xfers := make([]Transfer, len(ids))
+	watchers := make([]*Watcher, len(ids))
+	for i, id := range ids {
+		xfers[i], watchers[i] = tm.Transfer(id, makeXferFunc(id), progress.ChanOutput(progressChan))
+	}
+
+	for i, xfer := range xfers {
+		<-xfer.Done()
+		xfer.Release(watchers[i])
+	}
+	close(progressChan)
+	<-progressDone
+
+	for _, id := range ids {
+		if receivedProgress[id] != 10 {
+			t.Fatalf("final progress value %d instead of 10", receivedProgress[id])
+		}
+	}
+}
+
+func TestInactiveJobs(t *testing.T) {
+	concurrencyLimit := 3
+	var runningJobs int32
+	testDone := make(chan struct{})
+
+	makeXferFunc := func(id string) DoFunc {
+		return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer {
+			xfer := NewTransfer()
+			go func() {
+				<-start
+				totalJobs := atomic.AddInt32(&runningJobs, 1)
+				if int(totalJobs) > concurrencyLimit {
+					t.Fatalf("too many jobs running")
+				}
+				for i := 0; i <= 10; i++ {
+					progressChan <- progress.Progress{ID: id, Action: "testing", Current: int64(i), Total: 10}
+					time.Sleep(10 * time.Millisecond)
+				}
+				atomic.AddInt32(&runningJobs, -1)
+				close(inactive)
+				<-testDone
+				close(progressChan)
+			}()
+			return xfer
+		}
+	}
+
+	tm := NewTransferManager(concurrencyLimit)
+	progressChan := make(chan progress.Progress)
+	progressDone := make(chan struct{})
+	receivedProgress := make(map[string]int64)
+
+	go func() {
+		for p := range progressChan {
+			receivedProgress[p.ID] = p.Current
+		}
+		close(progressDone)
+	}()
+
+	// Start more transfers than the concurrency limit
+	ids := []string{"id1", "id2", "id3", "id4", "id5", "id6", "id7", "id8"}
+	xfers := make([]Transfer, len(ids))
+	watchers := make([]*Watcher, len(ids))
+	for i, id := range ids {
+		xfers[i], watchers[i] = tm.Transfer(id, makeXferFunc(id), progress.ChanOutput(progressChan))
+	}
+
+	close(testDone)
+	for i, xfer := range xfers {
+		<-xfer.Done()
+		xfer.Release(watchers[i])
+	}
+	close(progressChan)
+	<-progressDone
+
+	for _, id := range ids {
+		if receivedProgress[id] != 10 {
+			t.Fatalf("final progress value %d instead of 10", receivedProgress[id])
+		}
+	}
+}
+
+func TestWatchRelease(t *testing.T) {
+	ready := make(chan struct{})
+
+	makeXferFunc := func(id string) DoFunc {
+		return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer {
+			xfer := NewTransfer()
+			go func() {
+				defer func() {
+					close(progressChan)
+				}()
+				<-ready
+				for i := int64(0); ; i++ {
+					select {
+					case <-time.After(10 * time.Millisecond):
+					case <-xfer.Context().Done():
+						return
+					}
+					progressChan <- progress.Progress{ID: id, Action: "testing", Current: i, Total: 10}
+				}
+			}()
+			return xfer
+		}
+	}
+
+	tm := NewTransferManager(5)
+
+	type watcherInfo struct {
+		watcher               *Watcher
+		progressChan          chan progress.Progress
+		progressDone          chan struct{}
+		receivedFirstProgress chan struct{}
+	}
+
+	progressConsumer := func(w watcherInfo) {
+		first := true
+		for range w.progressChan {
+			if first {
+				close(w.receivedFirstProgress)
+			}
+			first = false
+		}
+		close(w.progressDone)
+	}
+
+	// Start a transfer
+	watchers := make([]watcherInfo, 5)
+	var xfer Transfer
+	watchers[0].progressChan = make(chan progress.Progress)
+	watchers[0].progressDone = make(chan struct{})
+	watchers[0].receivedFirstProgress = make(chan struct{})
+	xfer, watchers[0].watcher = tm.Transfer("id1", makeXferFunc("id1"), progress.ChanOutput(watchers[0].progressChan))
+	go progressConsumer(watchers[0])
+
+	// Give it multiple watchers
+	for i := 1; i != len(watchers); i++ {
+		watchers[i].progressChan = make(chan progress.Progress)
+		watchers[i].progressDone = make(chan struct{})
+		watchers[i].receivedFirstProgress = make(chan struct{})
+		watchers[i].watcher = xfer.Watch(progress.ChanOutput(watchers[i].progressChan))
+		go progressConsumer(watchers[i])
+	}
+
+	// Now that the watchers are set up, allow the transfer goroutine to
+	// proceed.
+	close(ready)
+
+	// Confirm that each watcher gets progress output.
+	for _, w := range watchers {
+		<-w.receivedFirstProgress
+	}
+
+	// Release one watcher every 5ms
+	for _, w := range watchers {
+		xfer.Release(w.watcher)
+		<-time.After(5 * time.Millisecond)
+	}
+
+	// Now that all watchers have been released, Released() should
+	// return a closed channel.
+	<-xfer.Released()
+
+	// Done() should return a closed channel because the xfer func returned
+	// due to cancellation.
+	<-xfer.Done()
+
+	for _, w := range watchers {
+		close(w.progressChan)
+		<-w.progressDone
+	}
+}
+
+func TestDuplicateTransfer(t *testing.T) {
+	ready := make(chan struct{})
+
+	var xferFuncCalls int32
+
+	makeXferFunc := func(id string) DoFunc {
+		return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer {
+			atomic.AddInt32(&xferFuncCalls, 1)
+			xfer := NewTransfer()
+			go func() {
+				defer func() {
+					close(progressChan)
+				}()
+				<-ready
+				for i := int64(0); ; i++ {
+					select {
+					case <-time.After(10 * time.Millisecond):
+					case <-xfer.Context().Done():
+						return
+					}
+					progressChan <- progress.Progress{ID: id, Action: "testing", Current: i, Total: 10}
+				}
+			}()
+			return xfer
+		}
+	}
+
+	tm := NewTransferManager(5)
+
+	type transferInfo struct {
+		xfer                  Transfer
+		watcher               *Watcher
+		progressChan          chan progress.Progress
+		progressDone          chan struct{}
+		receivedFirstProgress chan struct{}
+	}
+
+	progressConsumer := func(t transferInfo) {
+		first := true
+		for range t.progressChan {
+			if first {
+				close(t.receivedFirstProgress)
+			}
+			first = false
+		}
+		close(t.progressDone)
+	}
+
+	// Try to start multiple transfers with the same ID
+	transfers := make([]transferInfo, 5)
+	for i := range transfers {
+		t := &transfers[i]
+		t.progressChan = make(chan progress.Progress)
+		t.progressDone = make(chan struct{})
+		t.receivedFirstProgress = make(chan struct{})
+		t.xfer, t.watcher = tm.Transfer("id1", makeXferFunc("id1"), progress.ChanOutput(t.progressChan))
+		go progressConsumer(*t)
+	}
+
+	// Allow the transfer goroutine to proceed.
+	close(ready)
+
+	// Confirm that each watcher gets progress output.
+	for _, t := range transfers {
+		<-t.receivedFirstProgress
+	}
+
+	// Confirm that the transfer function was called exactly once.
+	if xferFuncCalls != 1 {
+		t.Fatal("transfer function wasn't called exactly once")
+	}
+
+	// Release one watcher every 5ms
+	for _, t := range transfers {
+		t.xfer.Release(t.watcher)
+		<-time.After(5 * time.Millisecond)
+	}
+
+	for _, t := range transfers {
+		// Now that all watchers have been released, Released() should
+		// return a closed channel.
+		<-t.xfer.Released()
+		// Done() should return a closed channel because the xfer func returned
+		// due to cancellation.
+		<-t.xfer.Done()
+	}
+
+	for _, t := range transfers {
+		close(t.progressChan)
+		<-t.progressDone
+	}
+}

+ 159 - 0
distribution/xfer/upload.go

@@ -0,0 +1,159 @@
+package xfer
+
+import (
+	"errors"
+	"time"
+
+	"github.com/Sirupsen/logrus"
+	"github.com/docker/distribution/digest"
+	"github.com/docker/docker/layer"
+	"github.com/docker/docker/pkg/progress"
+	"golang.org/x/net/context"
+)
+
+const maxUploadAttempts = 5
+
+// LayerUploadManager provides task management and progress reporting for
+// uploads.
+type LayerUploadManager struct {
+	tm TransferManager
+}
+
+// NewLayerUploadManager returns a new LayerUploadManager.
+func NewLayerUploadManager(concurrencyLimit int) *LayerUploadManager {
+	return &LayerUploadManager{
+		tm: NewTransferManager(concurrencyLimit),
+	}
+}
+
+type uploadTransfer struct {
+	Transfer
+
+	diffID layer.DiffID
+	digest digest.Digest
+	err    error
+}
+
+// An UploadDescriptor references a layer that may need to be uploaded.
+type UploadDescriptor interface {
+	// Key returns the key used to deduplicate uploads.
+	Key() string
+	// ID returns the ID for display purposes.
+	ID() string
+	// DiffID should return the DiffID for this layer.
+	DiffID() layer.DiffID
+	// Upload is called to perform the Upload.
+	Upload(ctx context.Context, progressOutput progress.Output) (digest.Digest, error)
+}
+
+// Upload is a blocking function which ensures the listed layers are present on
+// the remote registry. It uses the string returned by the Key method to
+// deduplicate uploads.
+func (lum *LayerUploadManager) Upload(ctx context.Context, layers []UploadDescriptor, progressOutput progress.Output) (map[layer.DiffID]digest.Digest, error) {
+	var (
+		uploads          []*uploadTransfer
+		digests          = make(map[layer.DiffID]digest.Digest)
+		dedupDescriptors = make(map[string]struct{})
+	)
+
+	for _, descriptor := range layers {
+		progress.Update(progressOutput, descriptor.ID(), "Preparing")
+
+		key := descriptor.Key()
+		if _, present := dedupDescriptors[key]; present {
+			continue
+		}
+		dedupDescriptors[key] = struct{}{}
+
+		xferFunc := lum.makeUploadFunc(descriptor)
+		upload, watcher := lum.tm.Transfer(descriptor.Key(), xferFunc, progressOutput)
+		defer upload.Release(watcher)
+		uploads = append(uploads, upload.(*uploadTransfer))
+	}
+
+	for _, upload := range uploads {
+		select {
+		case <-ctx.Done():
+			return nil, ctx.Err()
+		case <-upload.Transfer.Done():
+			if upload.err != nil {
+				return nil, upload.err
+			}
+			digests[upload.diffID] = upload.digest
+		}
+	}
+
+	return digests, nil
+}
+
+func (lum *LayerUploadManager) makeUploadFunc(descriptor UploadDescriptor) DoFunc {
+	return func(progressChan chan<- progress.Progress, start <-chan struct{}, inactive chan<- struct{}) Transfer {
+		u := &uploadTransfer{
+			Transfer: NewTransfer(),
+			diffID:   descriptor.DiffID(),
+		}
+
+		go func() {
+			defer func() {
+				close(progressChan)
+			}()
+
+			progressOutput := progress.ChanOutput(progressChan)
+
+			select {
+			case <-start:
+			default:
+				progress.Update(progressOutput, descriptor.ID(), "Waiting")
+				<-start
+			}
+
+			retries := 0
+			for {
+				digest, err := descriptor.Upload(u.Transfer.Context(), progressOutput)
+				if err == nil {
+					u.digest = digest
+					break
+				}
+
+				// If an error was returned because the context
+				// was cancelled, we shouldn't retry.
+				select {
+				case <-u.Transfer.Context().Done():
+					u.err = err
+					return
+				default:
+				}
+
+				retries++
+				if _, isDNR := err.(DoNotRetry); isDNR || retries == maxUploadAttempts {
+					logrus.Errorf("Upload failed: %v", err)
+					u.err = err
+					return
+				}
+
+				logrus.Errorf("Upload failed, retrying: %v", err)
+				delay := retries * 5
+				ticker := time.NewTicker(time.Second)
+
+			selectLoop:
+				for {
+					progress.Updatef(progressOutput, descriptor.ID(), "Retrying in %d seconds", delay)
+					select {
+					case <-ticker.C:
+						delay--
+						if delay == 0 {
+							ticker.Stop()
+							break selectLoop
+						}
+					case <-u.Transfer.Context().Done():
+						ticker.Stop()
+						u.err = errors.New("upload cancelled during retry delay")
+						return
+					}
+				}
+			}
+		}()
+
+		return u
+	}
+}

+ 153 - 0
distribution/xfer/upload_test.go

@@ -0,0 +1,153 @@
+package xfer
+
+import (
+	"errors"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/docker/distribution/digest"
+	"github.com/docker/docker/layer"
+	"github.com/docker/docker/pkg/progress"
+	"golang.org/x/net/context"
+)
+
+const maxUploadConcurrency = 3
+
+type mockUploadDescriptor struct {
+	currentUploads  *int32
+	diffID          layer.DiffID
+	simulateRetries int
+}
+
+// Key returns the key used to deduplicate downloads.
+func (u *mockUploadDescriptor) Key() string {
+	return u.diffID.String()
+}
+
+// ID returns the ID for display purposes.
+func (u *mockUploadDescriptor) ID() string {
+	return u.diffID.String()
+}
+
+// DiffID should return the DiffID for this layer.
+func (u *mockUploadDescriptor) DiffID() layer.DiffID {
+	return u.diffID
+}
+
+// Upload is called to perform the upload.
+func (u *mockUploadDescriptor) Upload(ctx context.Context, progressOutput progress.Output) (digest.Digest, error) {
+	if u.currentUploads != nil {
+		defer atomic.AddInt32(u.currentUploads, -1)
+
+		if atomic.AddInt32(u.currentUploads, 1) > maxUploadConcurrency {
+			return "", errors.New("concurrency limit exceeded")
+		}
+	}
+
+	// Sleep a bit to simulate a time-consuming upload.
+	for i := int64(0); i <= 10; i++ {
+		select {
+		case <-ctx.Done():
+			return "", ctx.Err()
+		case <-time.After(10 * time.Millisecond):
+			progressOutput.WriteProgress(progress.Progress{ID: u.ID(), Current: i, Total: 10})
+		}
+	}
+
+	if u.simulateRetries != 0 {
+		u.simulateRetries--
+		return "", errors.New("simulating retry")
+	}
+
+	// For the mock implementation, use SHA256(DiffID) as the returned
+	// digest.
+	return digest.FromBytes([]byte(u.diffID.String()))
+}
+
+func uploadDescriptors(currentUploads *int32) []UploadDescriptor {
+	return []UploadDescriptor{
+		&mockUploadDescriptor{currentUploads, layer.DiffID("sha256:cbbf2f9a99b47fc460d422812b6a5adff7dfee951d8fa2e4a98caa0382cfbdbf"), 0},
+		&mockUploadDescriptor{currentUploads, layer.DiffID("sha256:1515325234325236634634608943609283523908626098235490238423902343"), 0},
+		&mockUploadDescriptor{currentUploads, layer.DiffID("sha256:6929356290463485374960346430698374523437683470934634534953453453"), 0},
+		&mockUploadDescriptor{currentUploads, layer.DiffID("sha256:cbbf2f9a99b47fc460d422812b6a5adff7dfee951d8fa2e4a98caa0382cfbdbf"), 0},
+		&mockUploadDescriptor{currentUploads, layer.DiffID("sha256:8159352387436803946235346346368745389534789534897538734598734987"), 1},
+		&mockUploadDescriptor{currentUploads, layer.DiffID("sha256:4637863963478346897346987346987346789346789364879364897364987346"), 0},
+	}
+}
+
+var expectedDigests = map[layer.DiffID]digest.Digest{
+	layer.DiffID("sha256:cbbf2f9a99b47fc460d422812b6a5adff7dfee951d8fa2e4a98caa0382cfbdbf"): digest.Digest("sha256:c5095d6cf7ee42b7b064371dcc1dc3fb4af197f04d01a60009d484bd432724fc"),
+	layer.DiffID("sha256:1515325234325236634634608943609283523908626098235490238423902343"): digest.Digest("sha256:968cbfe2ff5269ea1729b3804767a1f57ffbc442d3bc86f47edbf7e688a4f36e"),
+	layer.DiffID("sha256:6929356290463485374960346430698374523437683470934634534953453453"): digest.Digest("sha256:8a5e56ab4b477a400470a7d5d4c1ca0c91235fd723ab19cc862636a06f3a735d"),
+	layer.DiffID("sha256:8159352387436803946235346346368745389534789534897538734598734987"): digest.Digest("sha256:5e733e5cd3688512fc240bd5c178e72671c9915947d17bb8451750d827944cb2"),
+	layer.DiffID("sha256:4637863963478346897346987346987346789346789364879364897364987346"): digest.Digest("sha256:ec4bb98d15e554a9f66c3ef9296cf46772c0ded3b1592bd8324d96e2f60f460c"),
+}
+
+func TestSuccessfulUpload(t *testing.T) {
+	lum := NewLayerUploadManager(maxUploadConcurrency)
+
+	progressChan := make(chan progress.Progress)
+	progressDone := make(chan struct{})
+	receivedProgress := make(map[string]int64)
+
+	go func() {
+		for p := range progressChan {
+			receivedProgress[p.ID] = p.Current
+		}
+		close(progressDone)
+	}()
+
+	var currentUploads int32
+	descriptors := uploadDescriptors(&currentUploads)
+
+	digests, err := lum.Upload(context.Background(), descriptors, progress.ChanOutput(progressChan))
+	if err != nil {
+		t.Fatalf("upload error: %v", err)
+	}
+
+	close(progressChan)
+	<-progressDone
+
+	if len(digests) != len(expectedDigests) {
+		t.Fatal("wrong number of keys in digests map")
+	}
+
+	for key, val := range expectedDigests {
+		if digests[key] != val {
+			t.Fatalf("mismatch in digest array for key %v (expected %v, got %v)", key, val, digests[key])
+		}
+		if receivedProgress[key.String()] != 10 {
+			t.Fatalf("missing or wrong progress output for %v", key)
+		}
+	}
+}
+
+func TestCancelledUpload(t *testing.T) {
+	lum := NewLayerUploadManager(maxUploadConcurrency)
+
+	progressChan := make(chan progress.Progress)
+	progressDone := make(chan struct{})
+
+	go func() {
+		for range progressChan {
+		}
+		close(progressDone)
+	}()
+
+	ctx, cancel := context.WithCancel(context.Background())
+
+	go func() {
+		<-time.After(time.Millisecond)
+		cancel()
+	}()
+
+	descriptors := uploadDescriptors(nil)
+	_, err := lum.Upload(ctx, descriptors, progress.ChanOutput(progressChan))
+	if err != context.Canceled {
+		t.Fatal("expected upload to be cancelled")
+	}
+
+	close(progressChan)
+	<-progressDone
+}

+ 4 - 10
integration-cli/docker_cli_pull_test.go

@@ -140,7 +140,7 @@ func (s *DockerHubPullSuite) TestPullAllTagsFromCentralRegistry(c *check.C) {
 }
 }
 
 
 // TestPullClientDisconnect kills the client during a pull operation and verifies that the operation
 // TestPullClientDisconnect kills the client during a pull operation and verifies that the operation
-// still succesfully completes on the daemon side.
+// gets cancelled.
 //
 //
 // Ref: docker/docker#15589
 // Ref: docker/docker#15589
 func (s *DockerHubPullSuite) TestPullClientDisconnect(c *check.C) {
 func (s *DockerHubPullSuite) TestPullClientDisconnect(c *check.C) {
@@ -161,14 +161,8 @@ func (s *DockerHubPullSuite) TestPullClientDisconnect(c *check.C) {
 	err = pullCmd.Process.Kill()
 	err = pullCmd.Process.Kill()
 	c.Assert(err, checker.IsNil)
 	c.Assert(err, checker.IsNil)
 
 
-	maxAttempts := 20
-	for i := 0; ; i++ {
-		if _, err := s.CmdWithError("inspect", repoName); err == nil {
-			break
-		}
-		if i >= maxAttempts {
-			c.Fatal("timeout reached: image was not pulled after client disconnected")
-		}
-		time.Sleep(500 * time.Millisecond)
+	time.Sleep(2 * time.Second)
+	if _, err := s.CmdWithError("inspect", repoName); err == nil {
+		c.Fatal("image was pulled after client disconnected")
 	}
 	}
 }
 }

+ 0 - 167
pkg/broadcaster/buffered.go

@@ -1,167 +0,0 @@
-package broadcaster
-
-import (
-	"errors"
-	"io"
-	"sync"
-)
-
-// Buffered keeps track of one or more observers watching the progress
-// of an operation. For example, if multiple clients are trying to pull an
-// image, they share a Buffered struct for the download operation.
-type Buffered struct {
-	sync.Mutex
-	// c is a channel that observers block on, waiting for the operation
-	// to finish.
-	c chan struct{}
-	// cond is a condition variable used to wake up observers when there's
-	// new data available.
-	cond *sync.Cond
-	// history is a buffer of the progress output so far, so a new observer
-	// can catch up. The history is stored as a slice of separate byte
-	// slices, so that if the writer is a WriteFlusher, the flushes will
-	// happen in the right places.
-	history [][]byte
-	// wg is a WaitGroup used to wait for all writes to finish on Close
-	wg sync.WaitGroup
-	// result is the argument passed to the first call of Close, and
-	// returned to callers of Wait
-	result error
-}
-
-// NewBuffered returns an initialized Buffered structure.
-func NewBuffered() *Buffered {
-	b := &Buffered{
-		c: make(chan struct{}),
-	}
-	b.cond = sync.NewCond(b)
-	return b
-}
-
-// closed returns true if and only if the broadcaster has been closed
-func (broadcaster *Buffered) closed() bool {
-	select {
-	case <-broadcaster.c:
-		return true
-	default:
-		return false
-	}
-}
-
-// receiveWrites runs as a goroutine so that writes don't block the Write
-// function. It writes the new data in broadcaster.history each time there's
-// activity on the broadcaster.cond condition variable.
-func (broadcaster *Buffered) receiveWrites(observer io.Writer) {
-	n := 0
-
-	broadcaster.Lock()
-
-	// The condition variable wait is at the end of this loop, so that the
-	// first iteration will write the history so far.
-	for {
-		newData := broadcaster.history[n:]
-		// Make a copy of newData so we can release the lock
-		sendData := make([][]byte, len(newData), len(newData))
-		copy(sendData, newData)
-		broadcaster.Unlock()
-
-		for len(sendData) > 0 {
-			_, err := observer.Write(sendData[0])
-			if err != nil {
-				broadcaster.wg.Done()
-				return
-			}
-			n++
-			sendData = sendData[1:]
-		}
-
-		broadcaster.Lock()
-
-		// If we are behind, we need to catch up instead of waiting
-		// or handling a closure.
-		if len(broadcaster.history) != n {
-			continue
-		}
-
-		// detect closure of the broadcast writer
-		if broadcaster.closed() {
-			broadcaster.Unlock()
-			broadcaster.wg.Done()
-			return
-		}
-
-		broadcaster.cond.Wait()
-
-		// Mutex is still locked as the loop continues
-	}
-}
-
-// Write adds data to the history buffer, and also writes it to all current
-// observers.
-func (broadcaster *Buffered) Write(p []byte) (n int, err error) {
-	broadcaster.Lock()
-	defer broadcaster.Unlock()
-
-	// Is the broadcaster closed? If so, the write should fail.
-	if broadcaster.closed() {
-		return 0, errors.New("attempted write to a closed broadcaster.Buffered")
-	}
-
-	// Add message in p to the history slice
-	newEntry := make([]byte, len(p), len(p))
-	copy(newEntry, p)
-	broadcaster.history = append(broadcaster.history, newEntry)
-
-	broadcaster.cond.Broadcast()
-
-	return len(p), nil
-}
-
-// Add adds an observer to the broadcaster. The new observer receives the
-// data from the history buffer, and also all subsequent data.
-func (broadcaster *Buffered) Add(w io.Writer) error {
-	// The lock is acquired here so that Add can't race with Close
-	broadcaster.Lock()
-	defer broadcaster.Unlock()
-
-	if broadcaster.closed() {
-		return errors.New("attempted to add observer to a closed broadcaster.Buffered")
-	}
-
-	broadcaster.wg.Add(1)
-	go broadcaster.receiveWrites(w)
-
-	return nil
-}
-
-// CloseWithError signals to all observers that the operation has finished. Its
-// argument is a result that should be returned to waiters blocking on Wait.
-func (broadcaster *Buffered) CloseWithError(result error) {
-	broadcaster.Lock()
-	if broadcaster.closed() {
-		broadcaster.Unlock()
-		return
-	}
-	broadcaster.result = result
-	close(broadcaster.c)
-	broadcaster.cond.Broadcast()
-	broadcaster.Unlock()
-
-	// Don't return until all writers have caught up.
-	broadcaster.wg.Wait()
-}
-
-// Close signals to all observers that the operation has finished. It causes
-// all calls to Wait to return nil.
-func (broadcaster *Buffered) Close() {
-	broadcaster.CloseWithError(nil)
-}
-
-// Wait blocks until the operation is marked as completed by the Close method,
-// and all writer goroutines have completed. It returns the argument that was
-// passed to Close.
-func (broadcaster *Buffered) Wait() error {
-	<-broadcaster.c
-	broadcaster.wg.Wait()
-	return broadcaster.result
-}

+ 71 - 0
pkg/ioutils/readers.go

@@ -4,6 +4,8 @@ import (
 	"crypto/sha256"
 	"crypto/sha256"
 	"encoding/hex"
 	"encoding/hex"
 	"io"
 	"io"
+
+	"golang.org/x/net/context"
 )
 )
 
 
 type readCloserWrapper struct {
 type readCloserWrapper struct {
@@ -81,3 +83,72 @@ func (r *OnEOFReader) runFunc() {
 		r.Fn = nil
 		r.Fn = nil
 	}
 	}
 }
 }
+
+// cancelReadCloser wraps an io.ReadCloser with a context for cancelling read
+// operations.
+type cancelReadCloser struct {
+	cancel func()
+	pR     *io.PipeReader // Stream to read from
+	pW     *io.PipeWriter
+}
+
+// NewCancelReadCloser creates a wrapper that closes the ReadCloser when the
+// context is cancelled. The returned io.ReadCloser must be closed when it is
+// no longer needed.
+func NewCancelReadCloser(ctx context.Context, in io.ReadCloser) io.ReadCloser {
+	pR, pW := io.Pipe()
+
+	// Create a context used to signal when the pipe is closed
+	doneCtx, cancel := context.WithCancel(context.Background())
+
+	p := &cancelReadCloser{
+		cancel: cancel,
+		pR:     pR,
+		pW:     pW,
+	}
+
+	go func() {
+		_, err := io.Copy(pW, in)
+		select {
+		case <-ctx.Done():
+			// If the context was closed, p.closeWithError
+			// was already called. Calling it again would
+			// change the error that Read returns.
+		default:
+			p.closeWithError(err)
+		}
+		in.Close()
+	}()
+	go func() {
+		for {
+			select {
+			case <-ctx.Done():
+				p.closeWithError(ctx.Err())
+			case <-doneCtx.Done():
+				return
+			}
+		}
+	}()
+
+	return p
+}
+
+// Read wraps the Read method of the pipe that provides data from the wrapped
+// ReadCloser.
+func (p *cancelReadCloser) Read(buf []byte) (n int, err error) {
+	return p.pR.Read(buf)
+}
+
+// closeWithError closes the wrapper and its underlying reader. It will
+// cause future calls to Read to return err.
+func (p *cancelReadCloser) closeWithError(err error) {
+	p.pW.CloseWithError(err)
+	p.cancel()
+}
+
+// Close closes the wrapper its underlying reader. It will cause
+// future calls to Read to return io.EOF.
+func (p *cancelReadCloser) Close() error {
+	p.closeWithError(io.EOF)
+	return nil
+}

+ 27 - 0
pkg/ioutils/readers_test.go

@@ -2,8 +2,12 @@ package ioutils
 
 
 import (
 import (
 	"fmt"
 	"fmt"
+	"io/ioutil"
 	"strings"
 	"strings"
 	"testing"
 	"testing"
+	"time"
+
+	"golang.org/x/net/context"
 )
 )
 
 
 // Implement io.Reader
 // Implement io.Reader
@@ -65,3 +69,26 @@ func TestHashData(t *testing.T) {
 		t.Fatalf("Expecting %s, got %s", expected, actual)
 		t.Fatalf("Expecting %s, got %s", expected, actual)
 	}
 	}
 }
 }
+
+type perpetualReader struct{}
+
+func (p *perpetualReader) Read(buf []byte) (n int, err error) {
+	for i := 0; i != len(buf); i++ {
+		buf[i] = 'a'
+	}
+	return len(buf), nil
+}
+
+func TestCancelReadCloser(t *testing.T) {
+	ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond)
+	cancelReadCloser := NewCancelReadCloser(ctx, ioutil.NopCloser(&perpetualReader{}))
+	for {
+		var buf [128]byte
+		_, err := cancelReadCloser.Read(buf[:])
+		if err == context.DeadlineExceeded {
+			break
+		} else if err != nil {
+			t.Fatalf("got unexpected error: %v", err)
+		}
+	}
+}

+ 63 - 0
pkg/progress/progress.go

@@ -0,0 +1,63 @@
+package progress
+
+import (
+	"fmt"
+)
+
+// Progress represents the progress of a transfer.
+type Progress struct {
+	ID string
+
+	// Progress contains a Message or...
+	Message string
+
+	// ...progress of an action
+	Action  string
+	Current int64
+	Total   int64
+
+	LastUpdate bool
+}
+
+// Output is an interface for writing progress information. It's
+// like a writer for progress, but we don't call it Writer because
+// that would be confusing next to ProgressReader (also, because it
+// doesn't implement the io.Writer interface).
+type Output interface {
+	WriteProgress(Progress) error
+}
+
+type chanOutput chan<- Progress
+
+func (out chanOutput) WriteProgress(p Progress) error {
+	out <- p
+	return nil
+}
+
+// ChanOutput returns a Output that writes progress updates to the
+// supplied channel.
+func ChanOutput(progressChan chan<- Progress) Output {
+	return chanOutput(progressChan)
+}
+
+// Update is a convenience function to write a progress update to the channel.
+func Update(out Output, id, action string) {
+	out.WriteProgress(Progress{ID: id, Action: action})
+}
+
+// Updatef is a convenience function to write a printf-formatted progress update
+// to the channel.
+func Updatef(out Output, id, format string, a ...interface{}) {
+	Update(out, id, fmt.Sprintf(format, a...))
+}
+
+// Message is a convenience function to write a progress message to the channel.
+func Message(out Output, id, message string) {
+	out.WriteProgress(Progress{ID: id, Message: message})
+}
+
+// Messagef is a convenience function to write a printf-formatted progress
+// message to the channel.
+func Messagef(out Output, id, format string, a ...interface{}) {
+	Message(out, id, fmt.Sprintf(format, a...))
+}

+ 59 - 0
pkg/progress/progressreader.go

@@ -0,0 +1,59 @@
+package progress
+
+import (
+	"io"
+)
+
+// Reader is a Reader with progress bar.
+type Reader struct {
+	in         io.ReadCloser // Stream to read from
+	out        Output        // Where to send progress bar to
+	size       int64
+	current    int64
+	lastUpdate int64
+	id         string
+	action     string
+}
+
+// NewProgressReader creates a new ProgressReader.
+func NewProgressReader(in io.ReadCloser, out Output, size int64, id, action string) *Reader {
+	return &Reader{
+		in:     in,
+		out:    out,
+		size:   size,
+		id:     id,
+		action: action,
+	}
+}
+
+func (p *Reader) Read(buf []byte) (n int, err error) {
+	read, err := p.in.Read(buf)
+	p.current += int64(read)
+	updateEvery := int64(1024 * 512) //512kB
+	if p.size > 0 {
+		// Update progress for every 1% read if 1% < 512kB
+		if increment := int64(0.01 * float64(p.size)); increment < updateEvery {
+			updateEvery = increment
+		}
+	}
+	if p.current-p.lastUpdate > updateEvery || err != nil {
+		p.updateProgress(err != nil && read == 0)
+		p.lastUpdate = p.current
+	}
+
+	return read, err
+}
+
+// Close closes the progress reader and its underlying reader.
+func (p *Reader) Close() error {
+	if p.current < p.size {
+		// print a full progress bar when closing prematurely
+		p.current = p.size
+		p.updateProgress(false)
+	}
+	return p.in.Close()
+}
+
+func (p *Reader) updateProgress(last bool) {
+	p.out.WriteProgress(Progress{ID: p.id, Action: p.action, Current: p.current, Total: p.size, LastUpdate: last})
+}

+ 75 - 0
pkg/progress/progressreader_test.go

@@ -0,0 +1,75 @@
+package progress
+
+import (
+	"bytes"
+	"io"
+	"io/ioutil"
+	"testing"
+)
+
+func TestOutputOnPrematureClose(t *testing.T) {
+	content := []byte("TESTING")
+	reader := ioutil.NopCloser(bytes.NewReader(content))
+	progressChan := make(chan Progress, 10)
+
+	pr := NewProgressReader(reader, ChanOutput(progressChan), int64(len(content)), "Test", "Read")
+
+	part := make([]byte, 4, 4)
+	_, err := io.ReadFull(pr, part)
+	if err != nil {
+		pr.Close()
+		t.Fatal(err)
+	}
+
+drainLoop:
+	for {
+		select {
+		case <-progressChan:
+		default:
+			break drainLoop
+		}
+	}
+
+	pr.Close()
+
+	select {
+	case <-progressChan:
+	default:
+		t.Fatalf("Expected some output when closing prematurely")
+	}
+}
+
+func TestCompleteSilently(t *testing.T) {
+	content := []byte("TESTING")
+	reader := ioutil.NopCloser(bytes.NewReader(content))
+	progressChan := make(chan Progress, 10)
+
+	pr := NewProgressReader(reader, ChanOutput(progressChan), int64(len(content)), "Test", "Read")
+
+	out, err := ioutil.ReadAll(pr)
+	if err != nil {
+		pr.Close()
+		t.Fatal(err)
+	}
+	if string(out) != "TESTING" {
+		pr.Close()
+		t.Fatalf("Unexpected output %q from reader", string(out))
+	}
+
+drainLoop:
+	for {
+		select {
+		case <-progressChan:
+		default:
+			break drainLoop
+		}
+	}
+
+	pr.Close()
+
+	select {
+	case <-progressChan:
+		t.Fatalf("Should have closed silently when read is complete")
+	default:
+	}
+}

+ 0 - 68
pkg/progressreader/progressreader.go

@@ -1,68 +0,0 @@
-// Package progressreader provides a Reader with a progress bar that can be
-// printed out using the streamformatter package.
-package progressreader
-
-import (
-	"io"
-
-	"github.com/docker/docker/pkg/jsonmessage"
-	"github.com/docker/docker/pkg/streamformatter"
-)
-
-// Config contains the configuration for a Reader with progress bar.
-type Config struct {
-	In         io.ReadCloser // Stream to read from
-	Out        io.Writer     // Where to send progress bar to
-	Formatter  *streamformatter.StreamFormatter
-	Size       int64
-	Current    int64
-	LastUpdate int64
-	NewLines   bool
-	ID         string
-	Action     string
-}
-
-// New creates a new Config.
-func New(newReader Config) *Config {
-	return &newReader
-}
-
-func (config *Config) Read(p []byte) (n int, err error) {
-	read, err := config.In.Read(p)
-	config.Current += int64(read)
-	updateEvery := int64(1024 * 512) //512kB
-	if config.Size > 0 {
-		// Update progress for every 1% read if 1% < 512kB
-		if increment := int64(0.01 * float64(config.Size)); increment < updateEvery {
-			updateEvery = increment
-		}
-	}
-	if config.Current-config.LastUpdate > updateEvery || err != nil {
-		updateProgress(config)
-		config.LastUpdate = config.Current
-	}
-
-	if err != nil && read == 0 {
-		updateProgress(config)
-		if config.NewLines {
-			config.Out.Write(config.Formatter.FormatStatus("", ""))
-		}
-	}
-	return read, err
-}
-
-// Close closes the reader (Config).
-func (config *Config) Close() error {
-	if config.Current < config.Size {
-		//print a full progress bar when closing prematurely
-		config.Current = config.Size
-		updateProgress(config)
-	}
-	return config.In.Close()
-}
-
-func updateProgress(config *Config) {
-	progress := jsonmessage.JSONProgress{Current: config.Current, Total: config.Size}
-	fmtMessage := config.Formatter.FormatProgress(config.ID, config.Action, &progress)
-	config.Out.Write(fmtMessage)
-}

+ 0 - 94
pkg/progressreader/progressreader_test.go

@@ -1,94 +0,0 @@
-package progressreader
-
-import (
-	"bufio"
-	"bytes"
-	"io"
-	"io/ioutil"
-	"testing"
-
-	"github.com/docker/docker/pkg/streamformatter"
-)
-
-func TestOutputOnPrematureClose(t *testing.T) {
-	var outBuf bytes.Buffer
-	content := []byte("TESTING")
-	reader := ioutil.NopCloser(bytes.NewReader(content))
-	writer := bufio.NewWriter(&outBuf)
-
-	prCfg := Config{
-		In:        reader,
-		Out:       writer,
-		Formatter: streamformatter.NewStreamFormatter(),
-		Size:      int64(len(content)),
-		NewLines:  true,
-		ID:        "Test",
-		Action:    "Read",
-	}
-	pr := New(prCfg)
-
-	part := make([]byte, 4, 4)
-	_, err := io.ReadFull(pr, part)
-	if err != nil {
-		pr.Close()
-		t.Fatal(err)
-	}
-
-	if err := writer.Flush(); err != nil {
-		pr.Close()
-		t.Fatal(err)
-	}
-
-	tlen := outBuf.Len()
-	pr.Close()
-	if err := writer.Flush(); err != nil {
-		t.Fatal(err)
-	}
-
-	if outBuf.Len() == tlen {
-		t.Fatalf("Expected some output when closing prematurely")
-	}
-}
-
-func TestCompleteSilently(t *testing.T) {
-	var outBuf bytes.Buffer
-	content := []byte("TESTING")
-	reader := ioutil.NopCloser(bytes.NewReader(content))
-	writer := bufio.NewWriter(&outBuf)
-
-	prCfg := Config{
-		In:        reader,
-		Out:       writer,
-		Formatter: streamformatter.NewStreamFormatter(),
-		Size:      int64(len(content)),
-		NewLines:  true,
-		ID:        "Test",
-		Action:    "Read",
-	}
-	pr := New(prCfg)
-
-	out, err := ioutil.ReadAll(pr)
-	if err != nil {
-		pr.Close()
-		t.Fatal(err)
-	}
-	if string(out) != "TESTING" {
-		pr.Close()
-		t.Fatalf("Unexpected output %q from reader", string(out))
-	}
-
-	if err := writer.Flush(); err != nil {
-		pr.Close()
-		t.Fatal(err)
-	}
-
-	tlen := outBuf.Len()
-	pr.Close()
-	if err := writer.Flush(); err != nil {
-		t.Fatal(err)
-	}
-
-	if outBuf.Len() > tlen {
-		t.Fatalf("Should have closed silently when read is complete")
-	}
-}

+ 39 - 0
pkg/streamformatter/streamformatter.go

@@ -7,6 +7,7 @@ import (
 	"io"
 	"io"
 
 
 	"github.com/docker/docker/pkg/jsonmessage"
 	"github.com/docker/docker/pkg/jsonmessage"
+	"github.com/docker/docker/pkg/progress"
 )
 )
 
 
 // StreamFormatter formats a stream, optionally using JSON.
 // StreamFormatter formats a stream, optionally using JSON.
@@ -92,6 +93,44 @@ func (sf *StreamFormatter) FormatProgress(id, action string, progress *jsonmessa
 	return []byte(action + " " + progress.String() + endl)
 	return []byte(action + " " + progress.String() + endl)
 }
 }
 
 
+// NewProgressOutput returns a progress.Output object that can be passed to
+// progress.NewProgressReader.
+func (sf *StreamFormatter) NewProgressOutput(out io.Writer, newLines bool) progress.Output {
+	return &progressOutput{
+		sf:       sf,
+		out:      out,
+		newLines: newLines,
+	}
+}
+
+type progressOutput struct {
+	sf       *StreamFormatter
+	out      io.Writer
+	newLines bool
+}
+
+// WriteProgress formats progress information from a ProgressReader.
+func (out *progressOutput) WriteProgress(prog progress.Progress) error {
+	var formatted []byte
+	if prog.Message != "" {
+		formatted = out.sf.FormatStatus(prog.ID, prog.Message)
+	} else {
+		jsonProgress := jsonmessage.JSONProgress{Current: prog.Current, Total: prog.Total}
+		formatted = out.sf.FormatProgress(prog.ID, prog.Action, &jsonProgress)
+	}
+	_, err := out.out.Write(formatted)
+	if err != nil {
+		return err
+	}
+
+	if out.newLines && prog.LastUpdate {
+		_, err = out.out.Write(out.sf.FormatStatus("", ""))
+		return err
+	}
+
+	return nil
+}
+
 // StdoutFormatter is a streamFormatter that writes to the standard output.
 // StdoutFormatter is a streamFormatter that writes to the standard output.
 type StdoutFormatter struct {
 type StdoutFormatter struct {
 	io.Writer
 	io.Writer

+ 5 - 15
registry/session.go

@@ -17,7 +17,6 @@ import (
 	"net/url"
 	"net/url"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
-	"time"
 
 
 	"github.com/Sirupsen/logrus"
 	"github.com/Sirupsen/logrus"
 	"github.com/docker/distribution/reference"
 	"github.com/docker/distribution/reference"
@@ -270,7 +269,6 @@ func (r *Session) GetRemoteImageJSON(imgID, registry string) ([]byte, int64, err
 // GetRemoteImageLayer retrieves an image layer from the registry
 // GetRemoteImageLayer retrieves an image layer from the registry
 func (r *Session) GetRemoteImageLayer(imgID, registry string, imgSize int64) (io.ReadCloser, error) {
 func (r *Session) GetRemoteImageLayer(imgID, registry string, imgSize int64) (io.ReadCloser, error) {
 	var (
 	var (
-		retries    = 5
 		statusCode = 0
 		statusCode = 0
 		res        *http.Response
 		res        *http.Response
 		err        error
 		err        error
@@ -281,14 +279,9 @@ func (r *Session) GetRemoteImageLayer(imgID, registry string, imgSize int64) (io
 	if err != nil {
 	if err != nil {
 		return nil, fmt.Errorf("Error while getting from the server: %v", err)
 		return nil, fmt.Errorf("Error while getting from the server: %v", err)
 	}
 	}
-	// TODO(tiborvass): why are we doing retries at this level?
-	// These retries should be generic to both v1 and v2
-	for i := 1; i <= retries; i++ {
-		statusCode = 0
-		res, err = r.client.Do(req)
-		if err == nil {
-			break
-		}
+	statusCode = 0
+	res, err = r.client.Do(req)
+	if err != nil {
 		logrus.Debugf("Error contacting registry %s: %v", registry, err)
 		logrus.Debugf("Error contacting registry %s: %v", registry, err)
 		if res != nil {
 		if res != nil {
 			if res.Body != nil {
 			if res.Body != nil {
@@ -296,11 +289,8 @@ func (r *Session) GetRemoteImageLayer(imgID, registry string, imgSize int64) (io
 			}
 			}
 			statusCode = res.StatusCode
 			statusCode = res.StatusCode
 		}
 		}
-		if i == retries {
-			return nil, fmt.Errorf("Server error: Status %d while fetching image layer (%s)",
-				statusCode, imgID)
-		}
-		time.Sleep(time.Duration(i) * 5 * time.Second)
+		return nil, fmt.Errorf("Server error: Status %d while fetching image layer (%s)",
+			statusCode, imgID)
 	}
 	}
 
 
 	if res.StatusCode != 200 {
 	if res.StatusCode != 200 {