
Fix race condition when waiting for a concurrent layer pull

Before, this only waited for the download to complete. There was no
guarantee that the layer had been registered in the graph and was ready to
use. This is especially problematic with v2 pulls, which wait for all
downloads before extracting layers.

Change Broadcaster to allow an error value to be propagated from Close
to the waiters.
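
A minimal sketch of the resulting pattern, assuming a stripped-down
Broadcaster (the real type in pkg/progressreader also fans progress
output out to attached writers, which is elided here):

    package progresssketch

    import "sync"

    // Broadcaster lets any number of waiters block until an operation
    // finishes, and hands them the operation's result.
    type Broadcaster struct {
    	sync.Mutex
    	c        chan struct{} // closed to wake all waiters
    	isClosed bool          // guards against closing c twice
    	result   error         // argument of the first CloseWithError
    }

    func NewBroadcaster() *Broadcaster {
    	return &Broadcaster{c: make(chan struct{})}
    }

    // CloseWithError records the result and wakes everyone blocked in
    // Wait. Only the first close takes effect.
    func (b *Broadcaster) CloseWithError(result error) {
    	b.Lock()
    	defer b.Unlock()
    	if b.isClosed {
    		return
    	}
    	b.isClosed = true
    	b.result = result
    	close(b.c)
    }

    // Close is the success case: waiters observe a nil error.
    func (b *Broadcaster) Close() {
    	b.CloseWithError(nil)
    }

    // Wait blocks until the operation is closed and returns its result.
    func (b *Broadcaster) Wait() error {
    	<-b.c
    	return b.result
    }

Recording only the first result means every waiter observes the same
outcome, no matter when it started waiting.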

Make the wait stop when the extraction is finished, rather than just the
download.
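
On the waiting side, a concurrent pull now blocks until the winning pull
has both downloaded and extracted the layer, and it inherits that pull's
result. A hypothetical waiter built on the sketch above (pool locking
and progress output elided; the names are illustrative, not Docker's):

    // pullLayer deduplicates layer pulls: the first caller does the
    // work, later callers block until download *and* extraction finish.
    func pullLayer(pool map[string]*Broadcaster, id string, doPull func() error) error {
    	key := "layer:" + id
    	if b, found := pool[key]; found {
    		// Another pull owns this layer; wait for it to be fully
    		// registered, and propagate its error if it failed.
    		return b.Wait()
    	}
    	b := NewBroadcaster()
    	pool[key] = b
    	err := doPull()       // download, extract, register in the graph
    	b.CloseWithError(err) // wakes all waiters with the same result
    	delete(pool, key)
    	return err
    }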

This also fixes v2 layer downloads to prefix the pool key with "layer:"
instead of "img:". "img:" is the wrong prefix, because this is what v1
uses for entire images. A v1 pull waiting for one of these operations to
finish would only wait for that particular layer, not all its
dependencies.
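
The two namespaces, as an illustrative sketch (in the real code the keys
are built inline at each poolAdd call site):

    // imgKey covers an entire v1 image: the layer plus all of its
    // parent layers. layerKey covers a single layer's download and
    // extraction, which is the right scope for v2 blob downloads.
    func imgKey(id string) string   { return "img:" + id }
    func layerKey(id string) string { return "layer:" + id }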

Signed-off-by: Aaron Lehmann <aaron.lehmann@docker.com>
Aaron Lehmann 9 years ago
parent
commit
23e68679f0

+ 5 - 4
graph/load.go

@@ -106,13 +106,14 @@ func (s *TagStore) recursiveLoad(address, tmpImageDir string) error {
 		}
 
 		// ensure no two downloads of the same layer happen at the same time
-		if ps, found := s.poolAdd("pull", "layer:"+img.ID); found {
+		poolKey := "layer:" + img.ID
+		broadcaster, found := s.poolAdd("pull", poolKey)
+		if found {
 			logrus.Debugf("Image (id: %s) load is already running, waiting", img.ID)
-			ps.Wait()
-			return nil
+			return broadcaster.Wait()
 		}
 
-		defer s.poolRemove("pull", "layer:"+img.ID)
+		defer s.poolRemove("pull", poolKey)
 
 		if img.Parent != "" {
 			if !s.graph.Exists(img.Parent) {

+ 25 - 16
graph/pull_v1.go

@@ -138,16 +138,14 @@ func (p *v1Puller) pullRepository(askedTag string) error {
 			}
 
 			// ensure no two downloads of the same image happen at the same time
-			broadcaster, found := p.poolAdd("pull", "img:"+img.ID)
+			poolKey := "img:" + img.ID
+			broadcaster, found := p.poolAdd("pull", poolKey)
+			broadcaster.Add(out)
 			if found {
-				broadcaster.Add(out)
-				broadcaster.Wait()
-				out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), "Download complete", nil))
-				errors <- nil
+				errors <- broadcaster.Wait()
 				return
 			}
-			broadcaster.Add(out)
-			defer p.poolRemove("pull", "img:"+img.ID)
+			defer p.poolRemove("pull", poolKey)
 
 			// we need to retain it until tagging
 			p.graph.Retain(sessionID, img.ID)
@@ -188,6 +186,7 @@ func (p *v1Puller) pullRepository(askedTag string) error {
 				err := fmt.Errorf("Error pulling image (%s) from %s, %v", img.Tag, p.repoInfo.CanonicalName, lastErr)
 				broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), err.Error(), nil))
 				errors <- err
+				broadcaster.CloseWithError(err)
 				return
 			}
 			broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), "Download complete", nil))
@@ -225,8 +224,9 @@ func (p *v1Puller) pullRepository(askedTag string) error {
 	return nil
 }
 
-func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []string) (bool, error) {
-	history, err := p.session.GetRemoteHistory(imgID, endpoint)
+func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []string) (layersDownloaded bool, err error) {
+	var history []string
+	history, err = p.session.GetRemoteHistory(imgID, endpoint)
 	if err != nil {
 		return false, err
 	}
@@ -239,20 +239,28 @@ func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []stri
 	p.graph.Retain(sessionID, history[1:]...)
 	defer p.graph.Release(sessionID, history[1:]...)
 
-	layersDownloaded := false
+	layersDownloaded = false
 	for i := len(history) - 1; i >= 0; i-- {
 		id := history[i]
 
 		// ensure no two downloads of the same layer happen at the same time
-		broadcaster, found := p.poolAdd("pull", "layer:"+id)
+		poolKey := "layer:" + id
+		broadcaster, found := p.poolAdd("pull", poolKey)
+		broadcaster.Add(out)
 		if found {
 			logrus.Debugf("Image (id: %s) pull is already running, skipping", id)
-			broadcaster.Add(out)
-			broadcaster.Wait()
-		} else {
-			broadcaster.Add(out)
+			err = broadcaster.Wait()
+			if err != nil {
+				return layersDownloaded, err
+			}
+			continue
 		}
-		defer p.poolRemove("pull", "layer:"+id)
+
+		// This must use a closure so it captures the value of err when
+		// the function returns, not when the 'defer' is evaluated.
+		defer func() {
+			p.poolRemoveWithError("pull", poolKey, err)
+		}()
 
 		if !p.graph.Exists(id) {
 			broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Pulling metadata", nil))
@@ -328,6 +336,7 @@ func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []stri
 			}
 		}
 		broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Download complete", nil))
+		broadcaster.Close()
 	}
 	return layersDownloaded, nil
 }

+ 33 - 24
graph/pull_v2.go

@@ -74,14 +74,17 @@ func (p *v2Puller) pullV2Repository(tag string) (err error) {
 	}
 
 	broadcaster, found := p.poolAdd("pull", taggedName)
+	broadcaster.Add(p.config.OutStream)
 	if found {
 		// Another pull of the same repository is already taking place; just wait for it to finish
-		broadcaster.Add(p.config.OutStream)
-		broadcaster.Wait()
-		return nil
+		return broadcaster.Wait()
 	}
-	defer p.poolRemove("pull", taggedName)
-	broadcaster.Add(p.config.OutStream)
+
+	// This must use a closure so it captures the value of err when the
+	// function returns, not when the 'defer' is evaluated.
+	defer func() {
+		p.poolRemoveWithError("pull", taggedName, err)
+	}()
 
 	var layersDownloaded bool
 	for _, tag := range tags {
@@ -101,13 +104,15 @@ func (p *v2Puller) pullV2Repository(tag string) (err error) {
 
 // downloadInfo is used to pass information from download to extractor
 type downloadInfo struct {
-	img     *image.Image
-	tmpFile *os.File
-	digest  digest.Digest
-	layer   distribution.ReadSeekCloser
-	size    int64
-	err     chan error
-	out     io.Writer // Download progress is written here.
+	img         *image.Image
+	tmpFile     *os.File
+	digest      digest.Digest
+	layer       distribution.ReadSeekCloser
+	size        int64
+	err         chan error
+	out         io.Writer // Download progress is written here.
+	poolKey     string
+	broadcaster *progressreader.Broadcaster
 }
 
 type errVerification struct{}
@@ -117,19 +122,15 @@ func (errVerification) Error() string { return "verification failed" }
 func (p *v2Puller) download(di *downloadInfo) {
 	logrus.Debugf("pulling blob %q to %s", di.digest, di.img.ID)
 
-	out := di.out
-
-	broadcaster, found := p.poolAdd("pull", "img:"+di.img.ID)
+	di.poolKey = "layer:" + di.img.ID
+	broadcaster, found := p.poolAdd("pull", di.poolKey)
+	broadcaster.Add(di.out)
+	di.broadcaster = broadcaster
 	if found {
-		broadcaster.Add(out)
-		broadcaster.Wait()
-		out.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Download complete", nil))
 		di.err <- nil
 		return
 	}
 
-	broadcaster.Add(out)
-	defer p.poolRemove("pull", "img:"+di.img.ID)
 	tmpFile, err := ioutil.TempFile("", "GetImageBlob")
 	if err != nil {
 		di.err <- err
@@ -279,6 +280,7 @@ func (p *v2Puller) pullV2Tag(out io.Writer, tag, taggedName string) (verified bo
 	// run clean for all downloads to prevent leftovers
 	for _, d := range downloads {
 		defer func(d *downloadInfo) {
+			p.poolRemoveWithError("pull", d.poolKey, err)
 			if d.tmpFile != nil {
 				d.tmpFile.Close()
 				if err := os.RemoveAll(d.tmpFile.Name()); err != nil {
@@ -293,14 +295,21 @@ func (p *v2Puller) pullV2Tag(out io.Writer, tag, taggedName string) (verified bo
 		if err := <-d.err; err != nil {
 			return false, err
 		}
+
 		if d.layer == nil {
+			// Wait for a different pull to download and extract
+			// this layer.
+			err = d.broadcaster.Wait()
+			if err != nil {
+				return false, err
+			}
 			continue
 		}
-		// if tmpFile is empty assume download and extracted elsewhere
+
 		d.tmpFile.Seek(0, 0)
 		reader := progressreader.New(progressreader.Config{
 			In:        d.tmpFile,
-			Out:       out,
+			Out:       d.broadcaster,
 			Formatter: p.sf,
 			Size:      d.size,
 			NewLines:  false,
@@ -317,8 +326,8 @@ func (p *v2Puller) pullV2Tag(out io.Writer, tag, taggedName string) (verified bo
 			return false, err
 		}
 
-		// FIXME: Pool release here for parallel tag pull (ensures any downloads block until fully extracted)
-		out.Write(p.sf.FormatProgress(stringid.TruncateID(d.img.ID), "Pull complete", nil))
+		d.broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(d.img.ID), "Pull complete", nil))
+		d.broadcaster.Close()
 		tagUpdated = true
 	}
 

+ 9 - 5
graph/tags.go

@@ -462,18 +462,18 @@ func (store *TagStore) poolAdd(kind, key string) (*progressreader.Broadcaster, b
 	return broadcaster, false
 }
 
-func (store *TagStore) poolRemove(kind, key string) error {
+func (store *TagStore) poolRemoveWithError(kind, key string, broadcasterResult error) error {
 	store.Lock()
 	defer store.Unlock()
 	switch kind {
 	case "pull":
-		if ps, exists := store.pullingPool[key]; exists {
-			ps.Close()
+		if broadcaster, exists := store.pullingPool[key]; exists {
+			broadcaster.CloseWithError(broadcasterResult)
 			delete(store.pullingPool, key)
 		}
 	case "push":
-		if ps, exists := store.pushingPool[key]; exists {
-			ps.Close()
+		if broadcaster, exists := store.pushingPool[key]; exists {
+			broadcaster.CloseWithError(broadcasterResult)
 			delete(store.pushingPool, key)
 		}
 	default:
@@ -481,3 +481,7 @@ func (store *TagStore) poolRemove(kind, key string) error {
 	}
 	return nil
 }
+
+func (store *TagStore) poolRemove(kind, key string) error {
+	return store.poolRemoveWithError(kind, key, nil)
+}

+ 133 - 0
integration-cli/docker_cli_pull_local_test.go

@@ -2,6 +2,8 @@ package main
 
 import (
 	"fmt"
+	"os/exec"
+	"strings"
 
 	"github.com/go-check/check"
 )
@@ -37,3 +39,134 @@ func (s *DockerRegistrySuite) TestPullImageWithAliases(c *check.C) {
 		}
 	}
 }
+
+// TestConcurrentPullWholeRepo pulls the same repo concurrently.
+func (s *DockerRegistrySuite) TestConcurrentPullWholeRepo(c *check.C) {
+	repoName := fmt.Sprintf("%v/dockercli/busybox", privateRegistryURL)
+
+	repos := []string{}
+	for _, tag := range []string{"recent", "fresh", "todays"} {
+		repo := fmt.Sprintf("%v:%v", repoName, tag)
+		_, err := buildImage(repo, fmt.Sprintf(`
+		    FROM busybox
+		    ENTRYPOINT ["/bin/echo"]
+		    ENV FOO foo
+		    ENV BAR bar
+		    CMD echo %s
+		`, repo), true)
+		if err != nil {
+			c.Fatal(err)
+		}
+		dockerCmd(c, "push", repo)
+		repos = append(repos, repo)
+	}
+
+	// Clear local images store.
+	args := append([]string{"rmi"}, repos...)
+	dockerCmd(c, args...)
+
+	// Run multiple re-pulls concurrently
+	results := make(chan error)
+	numPulls := 3
+
+	for i := 0; i != numPulls; i++ {
+		go func() {
+			_, _, err := runCommandWithOutput(exec.Command(dockerBinary, "pull", "-a", repoName))
+			results <- err
+		}()
+	}
+
+	// These checks are separate from the loop above because the check
+	// package is not goroutine-safe.
+	for i := 0; i != numPulls; i++ {
+		err := <-results
+		c.Assert(err, check.IsNil, check.Commentf("concurrent pull failed with error: %v", err))
+	}
+
+	// Ensure all tags were pulled successfully
+	for _, repo := range repos {
+		dockerCmd(c, "inspect", repo)
+		out, _ := dockerCmd(c, "run", "--rm", repo)
+		if strings.TrimSpace(out) != "/bin/sh -c echo "+repo {
+			c.Fatalf("CMD did not contain /bin/sh -c echo %s: %s", repo, out)
+		}
+	}
+}
+
+// TestConcurrentFailingPull tries a concurrent pull that doesn't succeed.
+func (s *DockerRegistrySuite) TestConcurrentFailingPull(c *check.C) {
+	repoName := fmt.Sprintf("%v/dockercli/busybox", privateRegistryURL)
+
+	// Run multiple pulls concurrently
+	results := make(chan error)
+	numPulls := 3
+
+	for i := 0; i != numPulls; i++ {
+		go func() {
+			_, _, err := runCommandWithOutput(exec.Command(dockerBinary, "pull", repoName+":asdfasdf"))
+			results <- err
+		}()
+	}
+
+	// These checks are separate from the loop above because the check
+	// package is not goroutine-safe.
+	for i := 0; i != numPulls; i++ {
+		err := <-results
+		if err == nil {
+			c.Fatal("expected pull to fail")
+		}
+	}
+}
+
+// TestConcurrentPullMultipleTags pulls multiple tags from the same repo
+// concurrently.
+func (s *DockerRegistrySuite) TestConcurrentPullMultipleTags(c *check.C) {
+	repoName := fmt.Sprintf("%v/dockercli/busybox", privateRegistryURL)
+
+	repos := []string{}
+	for _, tag := range []string{"recent", "fresh", "todays"} {
+		repo := fmt.Sprintf("%v:%v", repoName, tag)
+		_, err := buildImage(repo, fmt.Sprintf(`
+		    FROM busybox
+		    ENTRYPOINT ["/bin/echo"]
+		    ENV FOO foo
+		    ENV BAR bar
+		    CMD echo %s
+		`, repo), true)
+		if err != nil {
+			c.Fatal(err)
+		}
+		dockerCmd(c, "push", repo)
+		repos = append(repos, repo)
+	}
+
+	// Clear local images store.
+	args := append([]string{"rmi"}, repos...)
+	dockerCmd(c, args...)
+
+	// Re-pull individual tags, in parallel
+	results := make(chan error)
+
+	for _, repo := range repos {
+		go func(repo string) {
+			_, _, err := runCommandWithOutput(exec.Command(dockerBinary, "pull", repo))
+			results <- err
+		}(repo)
+	}
+
+	// These checks are separate from the loop above because the check
+	// package is not goroutine-safe.
+	for range repos {
+		err := <-results
+		c.Assert(err, check.IsNil, check.Commentf("concurrent pull failed with error: %v", err))
+	}
+
+	// Ensure all tags were pulled successfully
+	for _, repo := range repos {
+		dockerCmd(c, "inspect", repo)
+		out, _ := dockerCmd(c, "run", "--rm", repo)
+		if strings.TrimSpace(out) != "/bin/sh -c echo "+repo {
+			c.Fatalf("CMD did not contain /bin/sh -c echo %s: %s", repo, out)
+		}
+	}
+}

+ 17 - 4
pkg/progressreader/broadcaster.go

@@ -27,6 +27,9 @@ type Broadcaster struct {
 	// isClosed is set to true when Close is called to avoid closing c
 	// multiple times.
 	isClosed bool
+	// result is the argument passed to the first call of Close, and
+	// returned to callers of Wait
+	result error
 }
 
 // NewBroadcaster returns a Broadcaster structure
@@ -134,23 +137,33 @@ func (broadcaster *Broadcaster) Add(w io.Writer) error {
 	return nil
 }
 
-// Close signals to all observers that the operation has finished.
-func (broadcaster *Broadcaster) Close() {
+// CloseWithError signals to all observers that the operation has finished. Its
+// argument is a result that should be returned to waiters blocking on Wait.
+func (broadcaster *Broadcaster) CloseWithError(result error) {
 	broadcaster.Lock()
 	if broadcaster.isClosed {
 		broadcaster.Unlock()
 		return
 	}
 	broadcaster.isClosed = true
+	broadcaster.result = result
 	close(broadcaster.c)
 	broadcaster.cond.Broadcast()
 	broadcaster.Unlock()
 
-	// Don't return from Close until all writers have caught up.
+	// Don't return until all writers have caught up.
 	broadcaster.wg.Wait()
 }
 
+// Close signals to all observers that the operation has finished. It causes
+// all calls to Wait to return nil.
+func (broadcaster *Broadcaster) Close() {
+	broadcaster.CloseWithError(nil)
+}
+
 // Wait blocks until the operation is marked as completed by the Done method.
-func (broadcaster *Broadcaster) Wait() {
+// It returns the argument that was passed to Close.
+func (broadcaster *Broadcaster) Wait() error {
 	<-broadcaster.c
+	return broadcaster.result
 }