diff --git a/graph/load.go b/graph/load.go
index a58c5a3cf9..a3e3551252 100644
--- a/graph/load.go
+++ b/graph/load.go
@@ -106,13 +106,14 @@ func (s *TagStore) recursiveLoad(address, tmpImageDir string) error {
 	}
 
 	// ensure no two downloads of the same layer happen at the same time
-	if ps, found := s.poolAdd("pull", "layer:"+img.ID); found {
+	poolKey := "layer:" + img.ID
+	broadcaster, found := s.poolAdd("pull", poolKey)
+	if found {
 		logrus.Debugf("Image (id: %s) load is already running, waiting", img.ID)
-		ps.Wait()
-		return nil
+		return broadcaster.Wait()
 	}
 
-	defer s.poolRemove("pull", "layer:"+img.ID)
+	defer s.poolRemove("pull", poolKey)
 
 	if img.Parent != "" {
 		if !s.graph.Exists(img.Parent) {
diff --git a/graph/pull_v1.go b/graph/pull_v1.go
index 13741292fd..79fb1709fd 100644
--- a/graph/pull_v1.go
+++ b/graph/pull_v1.go
@@ -138,16 +138,14 @@ func (p *v1Puller) pullRepository(askedTag string) error {
 			}
 
 			// ensure no two downloads of the same image happen at the same time
-			broadcaster, found := p.poolAdd("pull", "img:"+img.ID)
+			poolKey := "img:" + img.ID
+			broadcaster, found := p.poolAdd("pull", poolKey)
+			broadcaster.Add(out)
 			if found {
-				broadcaster.Add(out)
-				broadcaster.Wait()
-				out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), "Download complete", nil))
-				errors <- nil
+				errors <- broadcaster.Wait()
 				return
 			}
-			broadcaster.Add(out)
-			defer p.poolRemove("pull", "img:"+img.ID)
+			defer p.poolRemove("pull", poolKey)
 
 			// we need to retain it until tagging
 			p.graph.Retain(sessionID, img.ID)
@@ -188,6 +186,7 @@ func (p *v1Puller) pullRepository(askedTag string) error {
 				err := fmt.Errorf("Error pulling image (%s) from %s, %v", img.Tag, p.repoInfo.CanonicalName, lastErr)
 				broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), err.Error(), nil))
 				errors <- err
+				broadcaster.CloseWithError(err)
 				return
 			}
 			broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), "Download complete", nil))
@@ -225,8 +224,9 @@ func (p *v1Puller) pullRepository(askedTag string) error {
 	return nil
 }
 
-func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []string) (bool, error) {
-	history, err := p.session.GetRemoteHistory(imgID, endpoint)
+func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []string) (layersDownloaded bool, err error) {
+	var history []string
+	history, err = p.session.GetRemoteHistory(imgID, endpoint)
 	if err != nil {
 		return false, err
 	}
@@ -239,20 +239,28 @@ func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []stri
 	p.graph.Retain(sessionID, history[1:]...)
 	defer p.graph.Release(sessionID, history[1:]...)
 
-	layersDownloaded := false
+	layersDownloaded = false
 	for i := len(history) - 1; i >= 0; i-- {
 		id := history[i]
 
 		// ensure no two downloads of the same layer happen at the same time
-		broadcaster, found := p.poolAdd("pull", "layer:"+id)
+		poolKey := "layer:" + id
+		broadcaster, found := p.poolAdd("pull", poolKey)
+		broadcaster.Add(out)
 		if found {
 			logrus.Debugf("Image (id: %s) pull is already running, skipping", id)
-			broadcaster.Add(out)
-			broadcaster.Wait()
-		} else {
-			broadcaster.Add(out)
+			err = broadcaster.Wait()
+			if err != nil {
+				return layersDownloaded, err
+			}
+			continue
 		}
-		defer p.poolRemove("pull", "layer:"+id)
+
+		// This must use a closure so it captures the value of err when
+		// the function returns, not when the 'defer' is evaluated.
+		defer func() {
+			p.poolRemoveWithError("pull", poolKey, err)
+		}()
 
 		if !p.graph.Exists(id) {
 			broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Pulling metadata", nil))
@@ -328,6 +336,7 @@ func (p *v1Puller) pullImage(out io.Writer, imgID, endpoint string, token []stri
 			}
 		}
 		broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(id), "Download complete", nil))
+		broadcaster.Close()
 	}
 	return layersDownloaded, nil
 }
diff --git a/graph/pull_v2.go b/graph/pull_v2.go
index 116fee77c5..0b6ac91703 100644
--- a/graph/pull_v2.go
+++ b/graph/pull_v2.go
@@ -1,7 +1,6 @@
 package graph
 
 import (
-	"errors"
 	"fmt"
 	"io"
 	"io/ioutil"
@@ -74,14 +73,17 @@ func (p *v2Puller) pullV2Repository(tag string) (err error) {
 	}
 
 	broadcaster, found := p.poolAdd("pull", taggedName)
+	broadcaster.Add(p.config.OutStream)
 	if found {
 		// Another pull of the same repository is already taking place; just wait for it to finish
-		broadcaster.Add(p.config.OutStream)
-		broadcaster.Wait()
-		return nil
+		return broadcaster.Wait()
 	}
-	defer p.poolRemove("pull", taggedName)
-	broadcaster.Add(p.config.OutStream)
+
+	// This must use a closure so it captures the value of err when the
+	// function returns, not when the 'defer' is evaluated.
+	defer func() {
+		p.poolRemoveWithError("pull", taggedName, err)
+	}()
 
 	var layersDownloaded bool
 	for _, tag := range tags {
@@ -101,13 +103,14 @@ func (p *v2Puller) pullV2Repository(tag string) (err error) {
 
 // downloadInfo is used to pass information from download to extractor
 type downloadInfo struct {
-	img     *image.Image
-	tmpFile *os.File
-	digest  digest.Digest
-	layer   distribution.ReadSeekCloser
-	size    int64
-	err     chan error
-	out     io.Writer // Download progress is written here.
+	img         *image.Image
+	tmpFile     *os.File
+	digest      digest.Digest
+	layer       distribution.ReadSeekCloser
+	size        int64
+	err         chan error
+	poolKey     string
+	broadcaster *progressreader.Broadcaster
 }
 
 type errVerification struct{}
@@ -117,26 +120,6 @@ func (errVerification) Error() string { return "verification failed" }
 func (p *v2Puller) download(di *downloadInfo) {
 	logrus.Debugf("pulling blob %q to %s", di.digest, di.img.ID)
 
-	out := di.out
-
-	broadcaster, found := p.poolAdd("pull", "img:"+di.img.ID)
-	if found {
-		broadcaster.Add(out)
-		broadcaster.Wait()
-		out.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Download complete", nil))
-		di.err <- nil
-		return
-	}
-
-	broadcaster.Add(out)
-	defer p.poolRemove("pull", "img:"+di.img.ID)
-	tmpFile, err := ioutil.TempFile("", "GetImageBlob")
-	if err != nil {
-		di.err <- err
-		return
-	}
-	di.tmpFile = tmpFile
-
 	blobs := p.repo.Blobs(context.Background())
 
 	desc, err := blobs.Stat(context.Background(), di.digest)
@@ -163,16 +146,16 @@ func (p *v2Puller) download(di *downloadInfo) {
 
 	reader := progressreader.New(progressreader.Config{
 		In:        ioutil.NopCloser(io.TeeReader(layerDownload, verifier)),
-		Out:       broadcaster,
+		Out:       di.broadcaster,
 		Formatter: p.sf,
 		Size:      di.size,
 		NewLines:  false,
 		ID:        stringid.TruncateID(di.img.ID),
 		Action:    "Downloading",
 	})
-	io.Copy(tmpFile, reader)
+	io.Copy(di.tmpFile, reader)
 
-	broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Verifying Checksum", nil))
+	di.broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Verifying Checksum", nil))
 
 	if !verifier.Verified() {
 		err = fmt.Errorf("filesystem layer verification failed for digest %s", di.digest)
@@ -181,9 +164,9 @@ func (p *v2Puller) download(di *downloadInfo) {
 		return
 	}
 
-	broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Download complete", nil))
+	di.broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Download complete", nil))
 
-	logrus.Debugf("Downloaded %s to tempfile %s", di.img.ID, tmpFile.Name())
+	logrus.Debugf("Downloaded %s to tempfile %s", di.img.ID, di.tmpFile.Name())
 
 	di.layer = layerDownload
 	di.err <- nil
@@ -209,33 +192,6 @@ func (p *v2Puller) pullV2Tag(out io.Writer, tag, taggedName string) (verified bo
 		logrus.Printf("Image manifest for %s has been verified", taggedName)
 	}
 
-	// By using a pipeWriter for each of the downloads to write their progress
-	// to, we can avoid an issue where this function returns an error but
-	// leaves behind running download goroutines. By splitting the writer
-	// with a pipe, we can close the pipe if there is any error, consequently
-	// causing each download to cancel due to an error writing to this pipe.
-	pipeReader, pipeWriter := io.Pipe()
-	go func() {
-		if _, err := io.Copy(out, pipeReader); err != nil {
-			logrus.Errorf("error copying from layer download progress reader: %s", err)
-			if err := pipeReader.CloseWithError(err); err != nil {
-				logrus.Errorf("error closing the progress reader: %s", err)
-			}
-		}
-	}()
-	defer func() {
-		if err != nil {
-			// All operations on the pipe are synchronous. This call will wait
-			// until all current readers/writers are done using the pipe then
-			// set the error. All successive reads/writes will return with this
-			// error.
-			pipeWriter.CloseWithError(errors.New("download canceled"))
-		} else {
-			// If no error then just close the pipe.
-			pipeWriter.Close()
-		}
-	}()
-
 	out.Write(p.sf.FormatStatus(tag, "Pulling from %s", p.repo.Name()))
 
 	var downloads []*downloadInfo
@@ -243,6 +199,16 @@ func (p *v2Puller) pullV2Tag(out io.Writer, tag, taggedName string) (verified bo
 	var layerIDs []string
 	defer func() {
 		p.graph.Release(p.sessionID, layerIDs...)
+
+		for _, d := range downloads {
+			p.poolRemoveWithError("pull", d.poolKey, err)
+			if d.tmpFile != nil {
+				d.tmpFile.Close()
+				if err := os.RemoveAll(d.tmpFile.Name()); err != nil {
+					logrus.Errorf("Failed to remove temp file: %s", d.tmpFile.Name())
+				}
+			}
+		}
 	}()
 
 	for i := len(manifest.FSLayers) - 1; i >= 0; i-- {
@@ -263,29 +229,31 @@ func (p *v2Puller) pullV2Tag(out io.Writer, tag, taggedName string) (verified bo
 		out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), "Pulling fs layer", nil))
 
 		d := &downloadInfo{
-			img:    img,
-			digest: manifest.FSLayers[i].BlobSum,
+			img:     img,
+			poolKey: "layer:" + img.ID,
+			digest:  manifest.FSLayers[i].BlobSum,
 			// TODO: seems like this chan buffer solved hanging problem in go1.5,
 			// this can indicate some deeper problem that somehow we never take
 			// error from channel in loop below
 			err: make(chan error, 1),
-			out: pipeWriter,
 		}
+
+		tmpFile, err := ioutil.TempFile("", "GetImageBlob")
+		if err != nil {
+			return false, err
+		}
+		d.tmpFile = tmpFile
+
 		downloads = append(downloads, d)
 
-		go p.download(d)
-	}
-
-	// run clean for all downloads to prevent leftovers
-	for _, d := range downloads {
-		defer func(d *downloadInfo) {
-			if d.tmpFile != nil {
-				d.tmpFile.Close()
-				if err := os.RemoveAll(d.tmpFile.Name()); err != nil {
-					logrus.Errorf("Failed to remove temp file: %s", d.tmpFile.Name())
-				}
-			}
-		}(d)
+		broadcaster, found := p.poolAdd("pull", d.poolKey)
+		broadcaster.Add(out)
+		d.broadcaster = broadcaster
+		if found {
+			d.err <- nil
+		} else {
+			go p.download(d)
+		}
 	}
 
 	var tagUpdated bool
@@ -293,14 +261,21 @@ func (p *v2Puller) pullV2Tag(out io.Writer, tag, taggedName string) (verified bo
 		if err := <-d.err; err != nil {
 			return false, err
 		}
+
 		if d.layer == nil {
+			// Wait for a different pull to download and extract
+			// this layer.
+			err = d.broadcaster.Wait()
+			if err != nil {
+				return false, err
+			}
 			continue
 		}
-		// if tmpFile is empty assume download and extracted elsewhere
+
 		d.tmpFile.Seek(0, 0)
 		reader := progressreader.New(progressreader.Config{
 			In:        d.tmpFile,
-			Out:       out,
+			Out:       d.broadcaster,
 			Formatter: p.sf,
 			Size:      d.size,
 			NewLines:  false,
@@ -317,8 +292,8 @@ func (p *v2Puller) pullV2Tag(out io.Writer, tag, taggedName string) (verified bo
 			return false, err
 		}
 
-		// FIXME: Pool release here for parallel tag pull (ensures any downloads block until fully extracted)
-		out.Write(p.sf.FormatProgress(stringid.TruncateID(d.img.ID), "Pull complete", nil))
+		d.broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(d.img.ID), "Pull complete", nil))
+		d.broadcaster.Close()
 		tagUpdated = true
 	}
 
diff --git a/graph/tags.go b/graph/tags.go
index 09a9bc8530..804898c1d3 100644
--- a/graph/tags.go
+++ b/graph/tags.go
@@ -462,18 +462,18 @@ func (store *TagStore) poolAdd(kind, key string) (*progressreader.Broadcaster, b
 	return broadcaster, false
 }
 
-func (store *TagStore) poolRemove(kind, key string) error {
+func (store *TagStore) poolRemoveWithError(kind, key string, broadcasterResult error) error {
 	store.Lock()
 	defer store.Unlock()
 	switch kind {
 	case "pull":
-		if ps, exists := store.pullingPool[key]; exists {
-			ps.Close()
+		if broadcaster, exists := store.pullingPool[key]; exists {
+			broadcaster.CloseWithError(broadcasterResult)
 			delete(store.pullingPool, key)
 		}
 	case "push":
-		if ps, exists := store.pushingPool[key]; exists {
-			ps.Close()
+		if broadcaster, exists := store.pushingPool[key]; exists {
+			broadcaster.CloseWithError(broadcasterResult)
 			delete(store.pushingPool, key)
 		}
 	default:
@@ -481,3 +481,7 @@ func (store *TagStore) poolRemove(kind, key string) error {
 	}
 	return nil
 }
+
+func (store *TagStore) poolRemove(kind, key string) error {
+	return store.poolRemoveWithError(kind, key, nil)
+}
diff --git a/integration-cli/docker_cli_pull_local_test.go b/integration-cli/docker_cli_pull_local_test.go
index 350e871cf7..a875f38456 100644
--- a/integration-cli/docker_cli_pull_local_test.go
+++ b/integration-cli/docker_cli_pull_local_test.go
@@ -2,6 +2,8 @@ package main
 
 import (
 	"fmt"
+	"os/exec"
+	"strings"
 
 	"github.com/go-check/check"
 )
@@ -37,3 +39,134 @@ func (s *DockerRegistrySuite) TestPullImageWithAliases(c *check.C) {
 		}
 	}
 }
+
+// TestConcurrentPullWholeRepo pulls the same repo concurrently.
+func (s *DockerRegistrySuite) TestConcurrentPullWholeRepo(c *check.C) {
+	repoName := fmt.Sprintf("%v/dockercli/busybox", privateRegistryURL)
+
+	repos := []string{}
+	for _, tag := range []string{"recent", "fresh", "todays"} {
+		repo := fmt.Sprintf("%v:%v", repoName, tag)
+		_, err := buildImage(repo, fmt.Sprintf(`
+		    FROM busybox
+		    ENTRYPOINT ["/bin/echo"]
+		    ENV FOO foo
+		    ENV BAR bar
+		    CMD echo %s
+		`, repo), true)
+		if err != nil {
+			c.Fatal(err)
+		}
+		dockerCmd(c, "push", repo)
+		repos = append(repos, repo)
+	}
+
+	// Clear local images store.
+	args := append([]string{"rmi"}, repos...)
+	dockerCmd(c, args...)
+
+	// Run multiple re-pulls concurrently
+	results := make(chan error)
+	numPulls := 3
+
+	for i := 0; i != numPulls; i++ {
+		go func() {
+			_, _, err := runCommandWithOutput(exec.Command(dockerBinary, "pull", "-a", repoName))
+			results <- err
+		}()
+	}
+
+	// These checks are separate from the loop above because the check
+	// package is not goroutine-safe.
+	for i := 0; i != numPulls; i++ {
+		err := <-results
+		c.Assert(err, check.IsNil, check.Commentf("concurrent pull failed with error: %v", err))
+	}
+
+	// Ensure all tags were pulled successfully
+	for _, repo := range repos {
+		dockerCmd(c, "inspect", repo)
+		out, _ := dockerCmd(c, "run", "--rm", repo)
+		if strings.TrimSpace(out) != "/bin/sh -c echo "+repo {
+			c.Fatalf("CMD did not contain /bin/sh -c echo %s: %s", repo, out)
+		}
+	}
+}
+
+// TestConcurrentFailingPull tries a concurrent pull that doesn't succeed.
+func (s *DockerRegistrySuite) TestConcurrentFailingPull(c *check.C) {
+	repoName := fmt.Sprintf("%v/dockercli/busybox", privateRegistryURL)
+
+	// Run multiple pulls concurrently
+	results := make(chan error)
+	numPulls := 3
+
+	for i := 0; i != numPulls; i++ {
+		go func() {
+			_, _, err := runCommandWithOutput(exec.Command(dockerBinary, "pull", repoName+":asdfasdf"))
+			results <- err
+		}()
+	}
+
+	// These checks are separate from the loop above because the check
+	// package is not goroutine-safe.
+	for i := 0; i != numPulls; i++ {
+		err := <-results
+		if err == nil {
+			c.Fatal("expected pull to fail")
+		}
+	}
+}
+
+// TestConcurrentPullMultipleTags pulls multiple tags from the same repo
+// concurrently.
+func (s *DockerRegistrySuite) TestConcurrentPullMultipleTags(c *check.C) {
+	repoName := fmt.Sprintf("%v/dockercli/busybox", privateRegistryURL)
+
+	repos := []string{}
+	for _, tag := range []string{"recent", "fresh", "todays"} {
+		repo := fmt.Sprintf("%v:%v", repoName, tag)
+		_, err := buildImage(repo, fmt.Sprintf(`
+		    FROM busybox
+		    ENTRYPOINT ["/bin/echo"]
+		    ENV FOO foo
+		    ENV BAR bar
+		    CMD echo %s
+		`, repo), true)
+		if err != nil {
+			c.Fatal(err)
+		}
+		dockerCmd(c, "push", repo)
+		repos = append(repos, repo)
+	}
+
+	// Clear local images store.
+	args := append([]string{"rmi"}, repos...)
+	dockerCmd(c, args...)
+
+	// Re-pull individual tags, in parallel
+	results := make(chan error)
+
+	for _, repo := range repos {
+		go func(repo string) {
+			_, _, err := runCommandWithOutput(exec.Command(dockerBinary, "pull", repo))
+			results <- err
+		}(repo)
+	}
+
+	// These checks are separate from the loop above because the check
+	// package is not goroutine-safe.
+	for range repos {
+		err := <-results
+		c.Assert(err, check.IsNil, check.Commentf("concurrent pull failed with error: %v", err))
+	}
+
+	// Ensure all tags were pulled successfully
+	for _, repo := range repos {
+		dockerCmd(c, "inspect", repo)
+		out, _ := dockerCmd(c, "run", "--rm", repo)
+		if strings.TrimSpace(out) != "/bin/sh -c echo "+repo {
+			c.Fatalf("CMD did not contain /bin/sh -c echo %s: %s", repo, out)
+		}
+	}
+}
diff --git a/pkg/progressreader/broadcaster.go b/pkg/progressreader/broadcaster.go
index 5118e9e2f8..a48ff226db 100644
--- a/pkg/progressreader/broadcaster.go
+++ b/pkg/progressreader/broadcaster.go
@@ -24,9 +24,9 @@ type Broadcaster struct {
 	history [][]byte
 	// wg is a WaitGroup used to wait for all writes to finish on Close
 	wg sync.WaitGroup
-	// isClosed is set to true when Close is called to avoid closing c
-	// multiple times.
-	isClosed bool
+	// result is the argument passed to the first call of Close, and
+	// returned to callers of Wait
+	result error
 }
 
 // NewBroadcaster returns a Broadcaster structure
@@ -134,23 +134,34 @@ func (broadcaster *Broadcaster) Add(w io.Writer) error {
 	return nil
 }
 
-// Close signals to all observers that the operation has finished.
-func (broadcaster *Broadcaster) Close() {
+// CloseWithError signals to all observers that the operation has finished. Its
+// argument is a result that should be returned to waiters blocking on Wait.
+func (broadcaster *Broadcaster) CloseWithError(result error) {
 	broadcaster.Lock()
-	if broadcaster.isClosed {
+	if broadcaster.closed() {
 		broadcaster.Unlock()
 		return
 	}
-	broadcaster.isClosed = true
+	broadcaster.result = result
 	close(broadcaster.c)
 	broadcaster.cond.Broadcast()
 	broadcaster.Unlock()
 
-	// Don't return from Close until all writers have caught up.
+	// Don't return until all writers have caught up.
 	broadcaster.wg.Wait()
 }
 
-// Wait blocks until the operation is marked as completed by the Done method.
-func (broadcaster *Broadcaster) Wait() {
-	<-broadcaster.c
+// Close signals to all observers that the operation has finished. It causes
+// all calls to Wait to return nil.
+func (broadcaster *Broadcaster) Close() {
+	broadcaster.CloseWithError(nil)
+}
+
+// Wait blocks until the operation is marked as completed by the Close method,
+// and all writer goroutines have completed. It returns the argument that was
+// passed to Close.
+func (broadcaster *Broadcaster) Wait() error {
+	<-broadcaster.c
+	broadcaster.wg.Wait()
+	return broadcaster.result
 }
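
Note for reviewers: the sketch below is a self-contained distillation of the synchronization pattern this diff introduces across load.go, pull_v1.go, pull_v2.go, and tags.go. It is not the progressreader package itself (the real Broadcaster also carries the history/writer machinery visible in the struct, which is omitted here), and the names waiter, pool, newWaiter, pullLayer, and the download callback are invented for illustration. What it demonstrates matches the patch: the first goroutine to register a key does the work and records the outcome with CloseWithError (Close is CloseWithError(nil)); later goroutines for the same key block in Wait and receive that same error; and the deferred closure ensures the final value of the named return err is what gets recorded.

package main

import (
	"errors"
	"fmt"
	"sync"
)

// waiter is a stripped-down stand-in for progressreader.Broadcaster: it only
// models "close with a result, return it from Wait".
type waiter struct {
	once   sync.Once
	done   chan struct{}
	result error
}

func newWaiter() *waiter {
	return &waiter{done: make(chan struct{})}
}

// CloseWithError records the outcome of the operation and releases waiters.
func (w *waiter) CloseWithError(err error) {
	w.once.Do(func() {
		w.result = err
		close(w.done)
	})
}

// Wait blocks until CloseWithError is called and returns the recorded result.
func (w *waiter) Wait() error {
	<-w.done
	return w.result
}

// pool mirrors the TagStore pulling pool: poolAdd returns the existing waiter
// for a key (found == true) or registers a new one.
type pool struct {
	mu     sync.Mutex
	active map[string]*waiter
}

func (p *pool) poolAdd(key string) (*waiter, bool) {
	p.mu.Lock()
	defer p.mu.Unlock()
	if w, ok := p.active[key]; ok {
		return w, true
	}
	w := newWaiter()
	p.active[key] = w
	return w, false
}

func (p *pool) poolRemoveWithError(key string, result error) {
	p.mu.Lock()
	defer p.mu.Unlock()
	if w, ok := p.active[key]; ok {
		w.CloseWithError(result)
		delete(p.active, key)
	}
}

// pullLayer has the shape of the new pull code path: only the first caller
// for a given layer runs download; everyone else waits and shares its error.
func pullLayer(p *pool, id string, download func() error) (err error) {
	poolKey := "layer:" + id
	w, found := p.poolAdd(poolKey)
	if found {
		// Another goroutine is already pulling this layer.
		return w.Wait()
	}
	// Deferred closure so the final value of err is captured when the
	// function returns, mirroring the comments in pull_v1.go/pull_v2.go.
	defer func() {
		p.poolRemoveWithError(poolKey, err)
	}()
	return download()
}

func main() {
	p := &pool{active: make(map[string]*waiter)}

	var wg sync.WaitGroup
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func(n int) {
			defer wg.Done()
			err := pullLayer(p, "abc123", func() error {
				return errors.New("simulated download failure")
			})
			fmt.Printf("puller %d got: %v\n", n, err)
		}(i)
	}
	wg.Wait()
}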