
Merge pull request #15955 from aaronlehmann/parallel-pull-race-2

Fix race condition when waiting for a concurrent layer pull
Jess Frazelle
parent commit c044dadade
6 changed files with 253 additions and 120 deletions
  1. graph/load.go (+5 -4)
  2. graph/pull_v1.go (+25 -16)
  3. graph/pull_v2.go (+60 -85)
  4. graph/tags.go (+9 -5)
  5. integration-cli/docker_cli_pull_local_test.go (+133 -0)
  6. pkg/progressreader/broadcaster.go (+21 -10)

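For readers skimming the diff, the changes in all six files converge on one pattern: the first puller to register a key in the pool owns the download and publishes its final error through a Broadcaster, while concurrent pullers of the same key wait on that Broadcaster and inherit the owner's result instead of returning early with a spurious success. The sketch below shows that pattern in isolation; the broadcaster, pool, and pullLayer names are simplified stand-ins for illustration, not the real docker/docker types.

package main

import (
	"errors"
	"fmt"
	"sync"
)

// broadcaster is a minimal stand-in for progressreader.Broadcaster: it records
// one result and releases all waiters once closed.
type broadcaster struct {
	once   sync.Once
	done   chan struct{}
	result error
}

func newBroadcaster() *broadcaster { return &broadcaster{done: make(chan struct{})} }

// CloseWithError stores the result and wakes every waiter exactly once.
func (b *broadcaster) CloseWithError(err error) {
	b.once.Do(func() {
		b.result = err
		close(b.done)
	})
}

// Wait blocks until CloseWithError has been called and returns its argument.
func (b *broadcaster) Wait() error {
	<-b.done
	return b.result
}

// pool deduplicates concurrent operations on the same key, loosely mirroring
// TagStore.poolAdd / poolRemoveWithError.
type pool struct {
	mu sync.Mutex
	m  map[string]*broadcaster
}

// add returns the broadcaster for key and whether another goroutine already owns it.
func (p *pool) add(key string) (*broadcaster, bool) {
	p.mu.Lock()
	defer p.mu.Unlock()
	if b, ok := p.m[key]; ok {
		return b, true
	}
	b := newBroadcaster()
	p.m[key] = b
	return b, false
}

// removeWithError publishes the owner's final error to all waiters.
func (p *pool) removeWithError(key string, err error) {
	p.mu.Lock()
	defer p.mu.Unlock()
	if b, ok := p.m[key]; ok {
		b.CloseWithError(err)
		delete(p.m, key)
	}
}

func pullLayer(p *pool, key string) (err error) {
	b, found := p.add(key)
	if found {
		// Someone else is already pulling this layer: wait and inherit their result.
		return b.Wait()
	}
	// The closure reads err at return time, as the real code's comment notes.
	defer func() { p.removeWithError(key, err) }()
	return errors.New("simulated pull failure")
}

func main() {
	p := &pool{m: make(map[string]*broadcaster)}
	var wg sync.WaitGroup
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			fmt.Println(pullLayer(p, "layer:abc")) // every goroutine sees the same error
		}()
	}
	wg.Wait()
}
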
graph/load.go (+5 -4)

@@ -106,13 +106,14 @@ func (s *TagStore) recursiveLoad(address, tmpImageDir string) error {
 		}

 		// ensure no two downloads of the same layer happen at the same time
-		if ps, found := s.poolAdd("pull", "layer:"+img.ID); found {
+		poolKey := "layer:" + img.ID
+		broadcaster, found := s.poolAdd("pull", poolKey)
+		if found {
 			logrus.Debugf("Image (id: %s) load is already running, waiting", img.ID)
-			ps.Wait()
-			return nil
+			return broadcaster.Wait()
 		}

-		defer s.poolRemove("pull", "layer:"+img.ID)
+		defer s.poolRemove("pull", poolKey)

 		if img.Parent != "" {
 			if !s.graph.Exists(img.Parent) {

graph/pull_v1.go (+25 -16)

@@ -138,16 +138,14 @@ func (p *v1Puller) pullRepository(askedTag string) error {
 			}

 			// ensure no two downloads of the same image happen at the same time
-			broadcaster, found := p.poolAdd("pull", "img:"+img.ID)
+			poolKey := "img:" + img.ID
+			broadcaster, found := p.poolAdd("pull", poolKey)
+			broadcaster.Add(out)
 			if found {
-				broadcaster.Add(out)
-				broadcaster.Wait()
-				out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), "Download complete", nil))
-				errors <- nil
+				errors <- broadcaster.Wait()
 				return
 			}
-			broadcaster.Add(out)
-			defer p.poolRemove("pull", "img:"+img.ID)
+			defer p.poolRemove("pull", poolKey)

 			// we need to retain it until tagging
 			p.graph.Retain(sessionID, img.ID)
@@ -188,6 +186,7 @@ func (p *v1Puller) pullRepository(askedTag string) error {
 				err := fmt.Errorf("Error pulling image (%s) from %s, %v", img.Tag, p.repoInfo.CanonicalName, lastErr)
 				broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), err.Error(), nil))
 				errors <- err
+				broadcaster.CloseWithError(err)
 				return
 			}
 			broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), "Download complete", nil))
@@ -225,8 +224,9 @@ func (p *v1Puller) pullRepository(askedTag string) error {
 	return nil
 }

-func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []string) (bool, error) {
-	history, err := p.session.GetRemoteHistory(imgID, endpoint)
+func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []string) (layersDownloaded bool, err error) {
+	var history []string
+	history, err = p.session.GetRemoteHistory(imgID, endpoint)
 	if err != nil {
 		return false, err
 	}
@@ -239,20 +239,28 @@ func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []stri
 	p.graph.Retain(sessionID, history[1:]...)
 	defer p.graph.Release(sessionID, history[1:]...)

-	layersDownloaded := false
+	layersDownloaded = false
 	for i := len(history) - 1; i >= 0; i-- {
 		id := history[i]

 		// ensure no two downloads of the same layer happen at the same time
-		broadcaster, found := p.poolAdd("pull", "layer:"+id)
+		poolKey := "layer:" + id
+		broadcaster, found := p.poolAdd("pull", poolKey)
+		broadcaster.Add(out)
 		if found {
 			logrus.Debugf("Image (id: %s) pull is already running, skipping", id)
-			broadcaster.Add(out)
-			broadcaster.Wait()
-		} else {
-			broadcaster.Add(out)
+			err = broadcaster.Wait()
+			if err != nil {
+				return layersDownloaded, err
+			}
+			continue
 		}
-		defer p.poolRemove("pull", "layer:"+id)
+
+		// This must use a closure so it captures the value of err when
+		// the function returns, not when the 'defer' is evaluated.
+		defer func() {
+			p.poolRemoveWithError("pull", poolKey, err)
+		}()

 		if !p.graph.Exists(id) {
 			broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Pulling metadata", nil))
@@ -328,6 +336,7 @@ func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []stri
 			}
 		}
 		broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Download complete", nil))
+		broadcaster.Close()
 	}
 	return layersDownloaded, nil
 }

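The comment added in graph/pull_v1.go ("This must use a closure so it captures the value of err...") is doing real work: deferring poolRemoveWithError("pull", poolKey, err) directly would evaluate err at the moment the defer statement executes, when it is still nil, so waiters would never see the failure. A minimal illustration of the difference, using a hypothetical report helper:

package main

import (
	"errors"
	"fmt"
)

func report(err error) { fmt.Println("reported:", err) }

// wrong evaluates err when the defer statement runs, so it always reports nil.
func wrong() (err error) {
	defer report(err) // err is still nil here
	return errors.New("boom")
}

// right defers a closure, so err is read when the function actually returns.
func right() (err error) {
	defer func() { report(err) }()
	return errors.New("boom")
}

func main() {
	wrong() // reported: <nil>
	right() // reported: boom
}
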
graph/pull_v2.go (+60 -85)

@@ -1,7 +1,6 @@
 package graph

 import (
-	"errors"
 	"fmt"
 	"io"
 	"io/ioutil"
@@ -74,14 +73,17 @@ func (p *v2Puller) pullV2Repository(tag string) (err error) {
 	}

 	broadcaster, found := p.poolAdd("pull", taggedName)
+	broadcaster.Add(p.config.OutStream)
 	if found {
 		// Another pull of the same repository is already taking place; just wait for it to finish
-		broadcaster.Add(p.config.OutStream)
-		broadcaster.Wait()
-		return nil
+		return broadcaster.Wait()
 	}
-	defer p.poolRemove("pull", taggedName)
-	broadcaster.Add(p.config.OutStream)
+
+	// This must use a closure so it captures the value of err when the
+	// function returns, not when the 'defer' is evaluated.
+	defer func() {
+		p.poolRemoveWithError("pull", taggedName, err)
+	}()

 	var layersDownloaded bool
 	for _, tag := range tags {
@@ -101,13 +103,14 @@ func (p *v2Puller) pullV2Repository(tag string) (err error) {

 // downloadInfo is used to pass information from download to extractor
 type downloadInfo struct {
-	img     *image.Image
-	tmpFile *os.File
-	digest  digest.Digest
-	layer   distribution.ReadSeekCloser
-	size    int64
-	err     chan error
-	out     io.Writer // Download progress is written here.
+	img         *image.Image
+	tmpFile     *os.File
+	digest      digest.Digest
+	layer       distribution.ReadSeekCloser
+	size        int64
+	err         chan error
+	poolKey     string
+	broadcaster *progressreader.Broadcaster
 }

 type errVerification struct{}
@@ -117,26 +120,6 @@ func (errVerification) Error() string { return "verification failed" }
 func (p *v2Puller) download(di *downloadInfo) {
 	logrus.Debugf("pulling blob %q to %s", di.digest, di.img.ID)

-	out := di.out
-
-	broadcaster, found := p.poolAdd("pull", "img:"+di.img.ID)
-	if found {
-		broadcaster.Add(out)
-		broadcaster.Wait()
-		out.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Download complete", nil))
-		di.err <- nil
-		return
-	}
-
-	broadcaster.Add(out)
-	defer p.poolRemove("pull", "img:"+di.img.ID)
-	tmpFile, err := ioutil.TempFile("", "GetImageBlob")
-	if err != nil {
-		di.err <- err
-		return
-	}
-	di.tmpFile = tmpFile
-
 	blobs := p.repo.Blobs(context.Background())

 	desc, err := blobs.Stat(context.Background(), di.digest)
@@ -163,16 +146,16 @@ func (p *v2Puller) download(di *downloadInfo) {

 	reader := progressreader.New(progressreader.Config{
 		In:        ioutil.NopCloser(io.TeeReader(layerDownload, verifier)),
-		Out:       broadcaster,
+		Out:       di.broadcaster,
 		Formatter: p.sf,
 		Size:      di.size,
 		NewLines:  false,
 		ID:        stringid.TruncateID(di.img.ID),
 		Action:    "Downloading",
 	})
-	io.Copy(tmpFile, reader)
+	io.Copy(di.tmpFile, reader)

-	broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Verifying Checksum", nil))
+	di.broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Verifying Checksum", nil))

 	if !verifier.Verified() {
 		err = fmt.Errorf("filesystem layer verification failed for digest %s", di.digest)
@@ -181,9 +164,9 @@ func (p *v2Puller) download(di *downloadInfo) {
 		return
 	}

-	broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Download complete", nil))
+	di.broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Download complete", nil))

-	logrus.Debugf("Downloaded %s to tempfile %s", di.img.ID, tmpFile.Name())
+	logrus.Debugf("Downloaded %s to tempfile %s", di.img.ID, di.tmpFile.Name())
 	di.layer = layerDownload

 	di.err <- nil
@@ -209,33 +192,6 @@ func (p *v2Puller) pullV2Tag(out io.Writer, tag, taggedName string) (verified bo
 		logrus.Printf("Image manifest for %s has been verified", taggedName)
 	}

-	// By using a pipeWriter for each of the downloads to write their progress
-	// to, we can avoid an issue where this function returns an error but
-	// leaves behind running download goroutines. By splitting the writer
-	// with a pipe, we can close the pipe if there is any error, consequently
-	// causing each download to cancel due to an error writing to this pipe.
-	pipeReader, pipeWriter := io.Pipe()
-	go func() {
-		if _, err := io.Copy(out, pipeReader); err != nil {
-			logrus.Errorf("error copying from layer download progress reader: %s", err)
-			if err := pipeReader.CloseWithError(err); err != nil {
-				logrus.Errorf("error closing the progress reader: %s", err)
-			}
-		}
-	}()
-	defer func() {
-		if err != nil {
-			// All operations on the pipe are synchronous. This call will wait
-			// until all current readers/writers are done using the pipe then
-			// set the error. All successive reads/writes will return with this
-			// error.
-			pipeWriter.CloseWithError(errors.New("download canceled"))
-		} else {
-			// If no error then just close the pipe.
-			pipeWriter.Close()
-		}
-	}()
-
 	out.Write(p.sf.FormatStatus(tag, "Pulling from %s", p.repo.Name()))

 	var downloads []*downloadInfo
@@ -243,6 +199,16 @@ func (p *v2Puller) pullV2Tag(out io.Writer, tag, taggedName string) (verified bo
 	var layerIDs []string
 	defer func() {
 		p.graph.Release(p.sessionID, layerIDs...)
+
+		for _, d := range downloads {
+			p.poolRemoveWithError("pull", d.poolKey, err)
+			if d.tmpFile != nil {
+				d.tmpFile.Close()
+				if err := os.RemoveAll(d.tmpFile.Name()); err != nil {
+					logrus.Errorf("Failed to remove temp file: %s", d.tmpFile.Name())
+				}
+			}
+		}
 	}()

 	for i := len(manifest.FSLayers) - 1; i >= 0; i-- {
@@ -263,29 +229,31 @@ func (p *v2Puller) pullV2Tag(out io.Writer, tag, taggedName string) (verified bo
 		out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), "Pulling fs layer", nil))

 		d := &downloadInfo{
-			img:    img,
-			digest: manifest.FSLayers[i].BlobSum,
+			img:     img,
+			poolKey: "layer:" + img.ID,
+			digest:  manifest.FSLayers[i].BlobSum,
 			// TODO: seems like this chan buffer solved hanging problem in go1.5,
 			// this can indicate some deeper problem that somehow we never take
 			// error from channel in loop below
 			err: make(chan error, 1),
-			out: pipeWriter,
 		}
-		downloads = append(downloads, d)

-		go p.download(d)
-	}
+		tmpFile, err := ioutil.TempFile("", "GetImageBlob")
+		if err != nil {
+			return false, err
+		}
+		d.tmpFile = tmpFile

-	// run clean for all downloads to prevent leftovers
-	for _, d := range downloads {
-		defer func(d *downloadInfo) {
-			if d.tmpFile != nil {
-				d.tmpFile.Close()
-				if err := os.RemoveAll(d.tmpFile.Name()); err != nil {
-					logrus.Errorf("Failed to remove temp file: %s", d.tmpFile.Name())
-				}
-			}
-		}(d)
+		downloads = append(downloads, d)
+
+		broadcaster, found := p.poolAdd("pull", d.poolKey)
+		broadcaster.Add(out)
+		d.broadcaster = broadcaster
+		if found {
+			d.err <- nil
+		} else {
+			go p.download(d)
+		}
 	}

 	var tagUpdated bool
@@ -293,14 +261,21 @@ func (p *v2Puller) pullV2Tag(out io.Writer, tag, taggedName string) (verified bo
 		if err := <-d.err; err != nil {
 			return false, err
 		}
+
 		if d.layer == nil {
+			// Wait for a different pull to download and extract
+			// this layer.
+			err = d.broadcaster.Wait()
+			if err != nil {
+				return false, err
+			}
 			continue
 		}
-		// if tmpFile is empty assume download and extracted elsewhere
+
 		d.tmpFile.Seek(0, 0)
 		reader := progressreader.New(progressreader.Config{
 			In:        d.tmpFile,
-			Out:       out,
+			Out:       d.broadcaster,
 			Formatter: p.sf,
 			Size:      d.size,
 			NewLines:  false,
@@ -317,8 +292,8 @@ func (p *v2Puller) pullV2Tag(out io.Writer, tag, taggedName string) (verified bo
 			return false, err
 		}

-		// FIXME: Pool release here for parallel tag pull (ensures any downloads block until fully extracted)
-		out.Write(p.sf.FormatProgress(stringid.TruncateID(d.img.ID), "Pull complete", nil))
+		d.broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(d.img.ID), "Pull complete", nil))
+		d.broadcaster.Close()
 		tagUpdated = true
 	}


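The reworked pullV2Tag above registers every layer in the pool before extraction starts, launches a download goroutine only for the keys it owns, and waits on the corresponding broadcaster for layers owned by another pull. A stripped-down, self-contained sketch of that dispatch is below; the registry, claim, and pullTag names are illustrative only, and error handling is omitted for brevity.

package main

import (
	"fmt"
	"sync"
)

// registry hands out ownership of layer keys exactly once across concurrent pulls.
type registry struct {
	mu     sync.Mutex
	owners map[string]chan struct{}
}

// claim returns a completion channel for key and whether the caller owns it.
func (r *registry) claim(key string) (done chan struct{}, owner bool) {
	r.mu.Lock()
	defer r.mu.Unlock()
	if ch, ok := r.owners[key]; ok {
		return ch, false
	}
	ch := make(chan struct{})
	r.owners[key] = ch
	return ch, true
}

func pullTag(r *registry, layers []string, wg *sync.WaitGroup) {
	defer wg.Done()
	for _, key := range layers {
		done, owner := r.claim(key)
		if owner {
			// The real code starts a download goroutine here; the sketch
			// downloads inline and then signals completion, standing in
			// for broadcaster.Close().
			fmt.Println("downloading", key)
			close(done)
			continue
		}
		<-done // stand-in for d.broadcaster.Wait()
		fmt.Println("reused", key)
	}
}

func main() {
	r := &registry{owners: make(map[string]chan struct{})}
	layers := []string{"layer:a", "layer:b"}
	var wg sync.WaitGroup
	for i := 0; i < 2; i++ {
		wg.Add(1)
		go pullTag(r, layers, &wg)
	}
	wg.Wait()
}
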
graph/tags.go (+9 -5)

@@ -462,18 +462,18 @@ func (store *TagStore) poolAdd(kind, key string) (*progressreader.Broadcaster, b
 	return broadcaster, false
 }

-func (store *TagStore) poolRemove(kind, key string) error {
+func (store *TagStore) poolRemoveWithError(kind, key string, broadcasterResult error) error {
 	store.Lock()
 	defer store.Unlock()
 	switch kind {
 	case "pull":
-		if ps, exists := store.pullingPool[key]; exists {
-			ps.Close()
+		if broadcaster, exists := store.pullingPool[key]; exists {
+			broadcaster.CloseWithError(broadcasterResult)
 			delete(store.pullingPool, key)
 		}
 	case "push":
-		if ps, exists := store.pushingPool[key]; exists {
-			ps.Close()
+		if broadcaster, exists := store.pushingPool[key]; exists {
+			broadcaster.CloseWithError(broadcasterResult)
 			delete(store.pushingPool, key)
 		}
 	default:
@@ -481,3 +481,7 @@ func (store *TagStore) poolRemove(kind, key string) error {
 	}
 	return nil
 }
+
+func (store *TagStore) poolRemove(kind, key string) error {
+	return store.poolRemoveWithError(kind, key, nil)
+}

integration-cli/docker_cli_pull_local_test.go (+133 -0)

@@ -2,6 +2,8 @@ package main

 import (
 	"fmt"
+	"os/exec"
+	"strings"

 	"github.com/go-check/check"
 )
@@ -37,3 +39,134 @@ func (s *DockerRegistrySuite) TestPullImageWithAliases(c *check.C) {
 		}
 	}
 }
+
+// TestConcurrentPullWholeRepo pulls the same repo concurrently.
+func (s *DockerRegistrySuite) TestConcurrentPullWholeRepo(c *check.C) {
+	repoName := fmt.Sprintf("%v/dockercli/busybox", privateRegistryURL)
+
+	repos := []string{}
+	for _, tag := range []string{"recent", "fresh", "todays"} {
+		repo := fmt.Sprintf("%v:%v", repoName, tag)
+		_, err := buildImage(repo, fmt.Sprintf(`
+		    FROM busybox
+		    ENTRYPOINT ["/bin/echo"]
+		    ENV FOO foo
+		    ENV BAR bar
+		    CMD echo %s
+		`, repo), true)
+		if err != nil {
+			c.Fatal(err)
+		}
+		dockerCmd(c, "push", repo)
+		repos = append(repos, repo)
+	}
+
+	// Clear local images store.
+	args := append([]string{"rmi"}, repos...)
+	dockerCmd(c, args...)
+
+	// Run multiple re-pulls concurrently
+	results := make(chan error)
+	numPulls := 3
+
+	for i := 0; i != numPulls; i++ {
+		go func() {
+			_, _, err := runCommandWithOutput(exec.Command(dockerBinary, "pull", "-a", repoName))
+			results <- err
+		}()
+	}
+
+	// These checks are separate from the loop above because the check
+	// package is not goroutine-safe.
+	for i := 0; i != numPulls; i++ {
+		err := <-results
+		c.Assert(err, check.IsNil, check.Commentf("concurrent pull failed with error: %v", err))
+	}
+
+	// Ensure all tags were pulled successfully
+	for _, repo := range repos {
+		dockerCmd(c, "inspect", repo)
+		out, _ := dockerCmd(c, "run", "--rm", repo)
+		if strings.TrimSpace(out) != "/bin/sh -c echo "+repo {
+			c.Fatalf("CMD did not contain /bin/sh -c echo %s: %s", repo, out)
+		}
+	}
+}
+
+// TestConcurrentFailingPull tries a concurrent pull that doesn't succeed.
+func (s *DockerRegistrySuite) TestConcurrentFailingPull(c *check.C) {
+	repoName := fmt.Sprintf("%v/dockercli/busybox", privateRegistryURL)
+
+	// Run multiple pulls concurrently
+	results := make(chan error)
+	numPulls := 3
+
+	for i := 0; i != numPulls; i++ {
+		go func() {
+			_, _, err := runCommandWithOutput(exec.Command(dockerBinary, "pull", repoName+":asdfasdf"))
+			results <- err
+		}()
+	}
+
+	// These checks are separate from the loop above because the check
+	// package is not goroutine-safe.
+	for i := 0; i != numPulls; i++ {
+		err := <-results
+		if err == nil {
+			c.Fatal("expected pull to fail")
+		}
+	}
+}
+
+// TestConcurrentPullMultipleTags pulls multiple tags from the same repo
+// concurrently.
+func (s *DockerRegistrySuite) TestConcurrentPullMultipleTags(c *check.C) {
+	repoName := fmt.Sprintf("%v/dockercli/busybox", privateRegistryURL)
+
+	repos := []string{}
+	for _, tag := range []string{"recent", "fresh", "todays"} {
+		repo := fmt.Sprintf("%v:%v", repoName, tag)
+		_, err := buildImage(repo, fmt.Sprintf(`
+		    FROM busybox
+		    ENTRYPOINT ["/bin/echo"]
+		    ENV FOO foo
+		    ENV BAR bar
+		    CMD echo %s
+		`, repo), true)
+		if err != nil {
+			c.Fatal(err)
+		}
+		dockerCmd(c, "push", repo)
+		repos = append(repos, repo)
+	}
+
+	// Clear local images store.
+	args := append([]string{"rmi"}, repos...)
+	dockerCmd(c, args...)
+
+	// Re-pull individual tags, in parallel
+	results := make(chan error)
+
+	for _, repo := range repos {
+		go func(repo string) {
+			_, _, err := runCommandWithOutput(exec.Command(dockerBinary, "pull", repo))
+			results <- err
+		}(repo)
+	}
+
+	// These checks are separate from the loop above because the check
+	// package is not goroutine-safe.
+	for range repos {
+		err := <-results
+		c.Assert(err, check.IsNil, check.Commentf("concurrent pull failed with error: %v", err))
+	}
+
+	// Ensure all tags were pulled successfully
+	for _, repo := range repos {
+		dockerCmd(c, "inspect", repo)
+		out, _ := dockerCmd(c, "run", "--rm", repo)
+		if strings.TrimSpace(out) != "/bin/sh -c echo "+repo {
+			c.Fatalf("CMD did not contain /bin/sh -c echo %s: %s", repo, out)
+		}
+	}
+}

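The new tests share one structural idiom: worker goroutines send their results over a channel, and all assertions happen on the test's own goroutine, because (as the comments note) the check package is not goroutine-safe. The same shape with the standard testing package, using a hypothetical doWork helper:

package main

import "testing"

// doWork stands in for running a docker pull; it is hypothetical.
func doWork() error { return nil }

func TestConcurrentWork(t *testing.T) {
	const n = 3
	results := make(chan error)
	for i := 0; i < n; i++ {
		go func() {
			// Only send on the channel here; never call t.Fatal from a
			// worker goroutine.
			results <- doWork()
		}()
	}
	for i := 0; i < n; i++ {
		if err := <-results; err != nil {
			t.Fatalf("worker %d failed: %v", i, err)
		}
	}
}
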
pkg/progressreader/broadcaster.go (+21 -10)

@@ -24,9 +24,9 @@ type Broadcaster struct {
 	history [][]byte
 	// wg is a WaitGroup used to wait for all writes to finish on Close
 	wg sync.WaitGroup
-	// isClosed is set to true when Close is called to avoid closing c
-	// multiple times.
-	isClosed bool
+	// result is the argument passed to the first call of Close, and
+	// returned to callers of Wait
+	result error
 }

 // NewBroadcaster returns a Broadcaster structure
@@ -134,23 +134,34 @@ func (broadcaster *Broadcaster) Add(w io.Writer) error {
 	return nil
 }

-// Close signals to all observers that the operation has finished.
-func (broadcaster *Broadcaster) Close() {
+// CloseWithError signals to all observers that the operation has finished. Its
+// argument is a result that should be returned to waiters blocking on Wait.
+func (broadcaster *Broadcaster) CloseWithError(result error) {
 	broadcaster.Lock()
-	if broadcaster.isClosed {
+	if broadcaster.closed() {
 		broadcaster.Unlock()
 		return
 	}
-	broadcaster.isClosed = true
+	broadcaster.result = result
 	close(broadcaster.c)
 	broadcaster.cond.Broadcast()
 	broadcaster.Unlock()

-	// Don't return from Close until all writers have caught up.
+	// Don't return until all writers have caught up.
 	broadcaster.wg.Wait()
 }

-// Wait blocks until the operation is marked as completed by the Done method.
-func (broadcaster *Broadcaster) Wait() {
+// Close signals to all observers that the operation has finished. It causes
+// all calls to Wait to return nil.
+func (broadcaster *Broadcaster) Close() {
+	broadcaster.CloseWithError(nil)
+}
+
+// Wait blocks until the operation is marked as completed by the Close method,
+// and all writer goroutines have completed. It returns the argument that was
+// passed to Close.
+func (broadcaster *Broadcaster) Wait() error {
 	<-broadcaster.c
+	broadcaster.wg.Wait()
+	return broadcaster.result
 }
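
Taken together, the Broadcaster changes give Wait the shape of an error-returning operation, so a waiter can simply return broadcaster.Wait(). A hypothetical usage sketch of the updated API follows; the zero-argument NewBroadcaster constructor and the history-replay behaviour of Add are assumptions based on the surrounding code, not shown in this diff.

package main

import (
	"errors"
	"fmt"
	"os"

	"github.com/docker/docker/pkg/progressreader"
)

func main() {
	b := progressreader.NewBroadcaster()

	// Observers added at any point are assumed to receive the buffered
	// history first, then live writes.
	if err := b.Add(os.Stdout); err != nil {
		fmt.Fprintln(os.Stderr, "add observer:", err)
		return
	}

	// The owning goroutine writes progress and publishes its final result.
	go func() {
		b.Write([]byte("Downloading\n"))
		b.CloseWithError(errors.New("simulated network error"))
		// b.Close() would be the success path: Wait would then return nil.
	}()

	// Every waiter gets the owner's result once all writes have drained.
	if err := b.Wait(); err != nil {
		fmt.Fprintln(os.Stderr, "pull failed:", err)
	}
}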