Browse Source

Fix races accessing d.poolKey and d.tmpFile when pullV2Tag returns

Signed-off-by: Aaron Lehmann <aaron.lehmann@docker.com>
Aaron Lehmann 10 years ago
parent
commit
64faec8269
1 changed files with 33 additions and 39 deletions
  1. 33 39
      graph/pull_v2.go

+ 33 - 39
graph/pull_v2.go

@@ -110,7 +110,6 @@ type downloadInfo struct {
 	layer       distribution.ReadSeekCloser
 	size        int64
 	err         chan error
-	out         io.Writer // Download progress is written here.
 	poolKey     string
 	broadcaster *progressreader.Broadcaster
 }
@@ -122,22 +121,6 @@ func (errVerification) Error() string { return "verification failed" }
 func (p *v2Puller) download(di *downloadInfo) {
 	logrus.Debugf("pulling blob %q to %s", di.digest, di.img.ID)
 
-	di.poolKey = "layer:" + di.img.ID
-	broadcaster, found := p.poolAdd("pull", di.poolKey)
-	broadcaster.Add(di.out)
-	di.broadcaster = broadcaster
-	if found {
-		di.err <- nil
-		return
-	}
-
-	tmpFile, err := ioutil.TempFile("", "GetImageBlob")
-	if err != nil {
-		di.err <- err
-		return
-	}
-	di.tmpFile = tmpFile
-
 	blobs := p.repo.Blobs(context.Background())
 
 	desc, err := blobs.Stat(context.Background(), di.digest)
@@ -164,16 +147,16 @@ func (p *v2Puller) download(di *downloadInfo) {
 
 	reader := progressreader.New(progressreader.Config{
 		In:        ioutil.NopCloser(io.TeeReader(layerDownload, verifier)),
-		Out:       broadcaster,
+		Out:       di.broadcaster,
 		Formatter: p.sf,
 		Size:      di.size,
 		NewLines:  false,
 		ID:        stringid.TruncateID(di.img.ID),
 		Action:    "Downloading",
 	})
-	io.Copy(tmpFile, reader)
+	io.Copy(di.tmpFile, reader)
 
-	broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Verifying Checksum", nil))
+	di.broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Verifying Checksum", nil))
 
 	if !verifier.Verified() {
 		err = fmt.Errorf("filesystem layer verification failed for digest %s", di.digest)
@@ -182,9 +165,9 @@ func (p *v2Puller) download(di *downloadInfo) {
 		return
 	}
 
-	broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Download complete", nil))
+	di.broadcaster.Write(p.sf.FormatProgress(stringid.TruncateID(di.img.ID), "Download complete", nil))
 
-	logrus.Debugf("Downloaded %s to tempfile %s", di.img.ID, tmpFile.Name())
+	logrus.Debugf("Downloaded %s to tempfile %s", di.img.ID, di.tmpFile.Name())
 	di.layer = layerDownload
 
 	di.err <- nil
@@ -244,6 +227,16 @@ func (p *v2Puller) pullV2Tag(out io.Writer, tag, taggedName string) (verified bo
 	var layerIDs []string
 	defer func() {
 		p.graph.Release(p.sessionID, layerIDs...)
+
+		for _, d := range downloads {
+			p.poolRemoveWithError("pull", d.poolKey, err)
+			if d.tmpFile != nil {
+				d.tmpFile.Close()
+				if err := os.RemoveAll(d.tmpFile.Name()); err != nil {
+					logrus.Errorf("Failed to remove temp file: %s", d.tmpFile.Name())
+				}
+			}
+		}
 	}()
 
 	for i := len(manifest.FSLayers) - 1; i >= 0; i-- {
@@ -264,30 +257,31 @@ func (p *v2Puller) pullV2Tag(out io.Writer, tag, taggedName string) (verified bo
 		out.Write(p.sf.FormatProgress(stringid.TruncateID(img.ID), "Pulling fs layer", nil))
 
 		d := &downloadInfo{
-			img:    img,
-			digest: manifest.FSLayers[i].BlobSum,
+			img:     img,
+			poolKey: "layer:" + img.ID,
+			digest:  manifest.FSLayers[i].BlobSum,
 			// TODO: seems like this chan buffer solved hanging problem in go1.5,
 			// this can indicate some deeper problem that somehow we never take
 			// error from channel in loop below
 			err: make(chan error, 1),
-			out: pipeWriter,
 		}
-		downloads = append(downloads, d)
 
-		go p.download(d)
-	}
+		tmpFile, err := ioutil.TempFile("", "GetImageBlob")
+		if err != nil {
+			return false, err
+		}
+		d.tmpFile = tmpFile
 
-	// run clean for all downloads to prevent leftovers
-	for _, d := range downloads {
-		defer func(d *downloadInfo) {
-			p.poolRemoveWithError("pull", d.poolKey, err)
-			if d.tmpFile != nil {
-				d.tmpFile.Close()
-				if err := os.RemoveAll(d.tmpFile.Name()); err != nil {
-					logrus.Errorf("Failed to remove temp file: %s", d.tmpFile.Name())
-				}
-			}
-		}(d)
+		downloads = append(downloads, d)
+
+		broadcaster, found := p.poolAdd("pull", d.poolKey)
+		broadcaster.Add(pipeWriter)
+		d.broadcaster = broadcaster
+		if found {
+			d.err <- nil
+		} else {
+			go p.download(d)
+		}
 	}
 
 	var tagUpdated bool