瀏覽代碼

Move temporary download file to download descriptor scope

This will allow it to be reused between download attempts in a
subsequent commit.

Signed-off-by: Aaron Lehmann <aaron.lehmann@docker.com>
Aaron Lehmann 9 年之前
父節點
當前提交
f425529e7e
共有 5 個文件被更改,包括 57 次插入25 次删除
  1. 0 14
      distribution/pull.go
  2. 18 5
      distribution/pull_v1.go
  3. 30 6
      distribution/pull_v2.go
  4. 6 0
      distribution/xfer/download.go
  5. 3 0
      distribution/xfer/download_test.go

+ 0 - 14
distribution/pull.go

@@ -2,7 +2,6 @@ package distribution
 
 
 import (
 import (
 	"fmt"
 	"fmt"
-	"os"
 
 
 	"github.com/Sirupsen/logrus"
 	"github.com/Sirupsen/logrus"
 	"github.com/docker/docker/api"
 	"github.com/docker/docker/api"
@@ -187,16 +186,3 @@ func validateRepoName(name string) error {
 	}
 	}
 	return nil
 	return nil
 }
 }
-
-// tmpFileClose creates a closer function for a temporary file that closes the file
-// and also deletes it.
-func tmpFileCloser(tmpFile *os.File) func() error {
-	return func() error {
-		tmpFile.Close()
-		if err := os.RemoveAll(tmpFile.Name()); err != nil {
-			logrus.Errorf("Failed to remove temp file: %s", tmpFile.Name())
-		}
-
-		return nil
-	}
-}

+ 18 - 5
distribution/pull_v1.go

@@ -7,6 +7,7 @@ import (
 	"io/ioutil"
 	"io/ioutil"
 	"net"
 	"net"
 	"net/url"
 	"net/url"
+	"os"
 	"strings"
 	"strings"
 	"time"
 	"time"
 
 
@@ -279,6 +280,7 @@ type v1LayerDescriptor struct {
 	layersDownloaded *bool
 	layersDownloaded *bool
 	layerSize        int64
 	layerSize        int64
 	session          *registry.Session
 	session          *registry.Session
+	tmpFile          *os.File
 }
 }
 
 
 func (ld *v1LayerDescriptor) Key() string {
 func (ld *v1LayerDescriptor) Key() string {
@@ -308,7 +310,7 @@ func (ld *v1LayerDescriptor) Download(ctx context.Context, progressOutput progre
 	}
 	}
 	*ld.layersDownloaded = true
 	*ld.layersDownloaded = true
 
 
-	tmpFile, err := ioutil.TempFile("", "GetImageBlob")
+	ld.tmpFile, err = ioutil.TempFile("", "GetImageBlob")
 	if err != nil {
 	if err != nil {
 		layerReader.Close()
 		layerReader.Close()
 		return nil, 0, err
 		return nil, 0, err
@@ -317,17 +319,28 @@ func (ld *v1LayerDescriptor) Download(ctx context.Context, progressOutput progre
 	reader := progress.NewProgressReader(ioutils.NewCancelReadCloser(ctx, layerReader), progressOutput, ld.layerSize, ld.ID(), "Downloading")
 	reader := progress.NewProgressReader(ioutils.NewCancelReadCloser(ctx, layerReader), progressOutput, ld.layerSize, ld.ID(), "Downloading")
 	defer reader.Close()
 	defer reader.Close()
 
 
-	_, err = io.Copy(tmpFile, reader)
+	_, err = io.Copy(ld.tmpFile, reader)
 	if err != nil {
 	if err != nil {
+		ld.Close()
 		return nil, 0, err
 		return nil, 0, err
 	}
 	}
 
 
 	progress.Update(progressOutput, ld.ID(), "Download complete")
 	progress.Update(progressOutput, ld.ID(), "Download complete")
 
 
-	logrus.Debugf("Downloaded %s to tempfile %s", ld.ID(), tmpFile.Name())
+	logrus.Debugf("Downloaded %s to tempfile %s", ld.ID(), ld.tmpFile.Name())
 
 
-	tmpFile.Seek(0, 0)
-	return ioutils.NewReadCloserWrapper(tmpFile, tmpFileCloser(tmpFile)), ld.layerSize, nil
+	ld.tmpFile.Seek(0, 0)
+	return ld.tmpFile, ld.layerSize, nil
+}
+
+func (ld *v1LayerDescriptor) Close() {
+	if ld.tmpFile != nil {
+		ld.tmpFile.Close()
+		if err := os.RemoveAll(ld.tmpFile.Name()); err != nil {
+			logrus.Errorf("Failed to remove temp file: %s", ld.tmpFile.Name())
+		}
+		ld.tmpFile = nil
+	}
 }
 }
 
 
 func (ld *v1LayerDescriptor) Registered(diffID layer.DiffID) {
 func (ld *v1LayerDescriptor) Registered(diffID layer.DiffID) {

+ 30 - 6
distribution/pull_v2.go

@@ -114,6 +114,7 @@ type v2LayerDescriptor struct {
 	repoInfo          *registry.RepositoryInfo
 	repoInfo          *registry.RepositoryInfo
 	repo              distribution.Repository
 	repo              distribution.Repository
 	V2MetadataService *metadata.V2MetadataService
 	V2MetadataService *metadata.V2MetadataService
+	tmpFile           *os.File
 }
 }
 
 
 func (ld *v2LayerDescriptor) Key() string {
 func (ld *v2LayerDescriptor) Key() string {
@@ -131,6 +132,18 @@ func (ld *v2LayerDescriptor) DiffID() (layer.DiffID, error) {
 func (ld *v2LayerDescriptor) Download(ctx context.Context, progressOutput progress.Output) (io.ReadCloser, int64, error) {
 func (ld *v2LayerDescriptor) Download(ctx context.Context, progressOutput progress.Output) (io.ReadCloser, int64, error) {
 	logrus.Debugf("pulling blob %q", ld.digest)
 	logrus.Debugf("pulling blob %q", ld.digest)
 
 
+	var err error
+
+	if ld.tmpFile == nil {
+		ld.tmpFile, err = createDownloadFile()
+	} else {
+		_, err = ld.tmpFile.Seek(0, os.SEEK_SET)
+	}
+	if err != nil {
+		return nil, 0, xfer.DoNotRetry{Err: err}
+	}
+
+	tmpFile := ld.tmpFile
 	blobs := ld.repo.Blobs(ctx)
 	blobs := ld.repo.Blobs(ctx)
 
 
 	layerDownload, err := blobs.Open(ctx, ld.digest)
 	layerDownload, err := blobs.Open(ctx, ld.digest)
@@ -164,17 +177,13 @@ func (ld *v2LayerDescriptor) Download(ctx context.Context, progressOutput progre
 		return nil, 0, xfer.DoNotRetry{Err: err}
 		return nil, 0, xfer.DoNotRetry{Err: err}
 	}
 	}
 
 
-	tmpFile, err := ioutil.TempFile("", "GetImageBlob")
-	if err != nil {
-		return nil, 0, xfer.DoNotRetry{Err: err}
-	}
-
 	_, err = io.Copy(tmpFile, io.TeeReader(reader, verifier))
 	_, err = io.Copy(tmpFile, io.TeeReader(reader, verifier))
 	if err != nil {
 	if err != nil {
 		tmpFile.Close()
 		tmpFile.Close()
 		if err := os.Remove(tmpFile.Name()); err != nil {
 		if err := os.Remove(tmpFile.Name()); err != nil {
 			logrus.Errorf("Failed to remove temp file: %s", tmpFile.Name())
 			logrus.Errorf("Failed to remove temp file: %s", tmpFile.Name())
 		}
 		}
+		ld.tmpFile = nil
 		return nil, 0, retryOnError(err)
 		return nil, 0, retryOnError(err)
 	}
 	}
 
 
@@ -188,6 +197,7 @@ func (ld *v2LayerDescriptor) Download(ctx context.Context, progressOutput progre
 		if err := os.Remove(tmpFile.Name()); err != nil {
 		if err := os.Remove(tmpFile.Name()); err != nil {
 			logrus.Errorf("Failed to remove temp file: %s", tmpFile.Name())
 			logrus.Errorf("Failed to remove temp file: %s", tmpFile.Name())
 		}
 		}
+		ld.tmpFile = nil
 
 
 		return nil, 0, xfer.DoNotRetry{Err: err}
 		return nil, 0, xfer.DoNotRetry{Err: err}
 	}
 	}
@@ -202,9 +212,19 @@ func (ld *v2LayerDescriptor) Download(ctx context.Context, progressOutput progre
 		if err := os.Remove(tmpFile.Name()); err != nil {
 		if err := os.Remove(tmpFile.Name()); err != nil {
 			logrus.Errorf("Failed to remove temp file: %s", tmpFile.Name())
 			logrus.Errorf("Failed to remove temp file: %s", tmpFile.Name())
 		}
 		}
+		ld.tmpFile = nil
 		return nil, 0, xfer.DoNotRetry{Err: err}
 		return nil, 0, xfer.DoNotRetry{Err: err}
 	}
 	}
-	return ioutils.NewReadCloserWrapper(tmpFile, tmpFileCloser(tmpFile)), size, nil
+	return tmpFile, size, nil
+}
+
+func (ld *v2LayerDescriptor) Close() {
+	if ld.tmpFile != nil {
+		ld.tmpFile.Close()
+		if err := os.RemoveAll(ld.tmpFile.Name()); err != nil {
+			logrus.Errorf("Failed to remove temp file: %s", ld.tmpFile.Name())
+		}
+	}
 }
 }
 
 
 func (ld *v2LayerDescriptor) Registered(diffID layer.DiffID) {
 func (ld *v2LayerDescriptor) Registered(diffID layer.DiffID) {
@@ -711,3 +731,7 @@ func fixManifestLayers(m *schema1.Manifest) error {
 
 
 	return nil
 	return nil
 }
 }
+
+func createDownloadFile() (*os.File, error) {
+	return ioutil.TempFile("", "GetImageBlob")
+}

+ 6 - 0
distribution/xfer/download.go

@@ -59,6 +59,10 @@ type DownloadDescriptor interface {
 	DiffID() (layer.DiffID, error)
 	DiffID() (layer.DiffID, error)
 	// Download is called to perform the download.
 	// Download is called to perform the download.
 	Download(ctx context.Context, progressOutput progress.Output) (io.ReadCloser, int64, error)
 	Download(ctx context.Context, progressOutput progress.Output) (io.ReadCloser, int64, error)
+	// Close is called when the download manager is finished with this
+	// descriptor and will not call Download again or read from the reader
+	// that Download returned.
+	Close()
 }
 }
 
 
 // DownloadDescriptorWithRegistered is a DownloadDescriptor that has an
 // DownloadDescriptorWithRegistered is a DownloadDescriptor that has an
@@ -229,6 +233,8 @@ func (ldm *LayerDownloadManager) makeDownloadFunc(descriptor DownloadDescriptor,
 				retries        int
 				retries        int
 			)
 			)
 
 
+			defer descriptor.Close()
+
 			for {
 			for {
 				downloadReader, size, err = descriptor.Download(d.Transfer.Context(), progressOutput)
 				downloadReader, size, err = descriptor.Download(d.Transfer.Context(), progressOutput)
 				if err == nil {
 				if err == nil {

+ 3 - 0
distribution/xfer/download_test.go

@@ -199,6 +199,9 @@ func (d *mockDownloadDescriptor) Download(ctx context.Context, progressOutput pr
 	return d.mockTarStream(), 0, nil
 	return d.mockTarStream(), 0, nil
 }
 }
 
 
+func (d *mockDownloadDescriptor) Close() {
+}
+
 func downloadDescriptors(currentDownloads *int32) []DownloadDescriptor {
 func downloadDescriptors(currentDownloads *int32) []DownloadDescriptor {
 	return []DownloadDescriptor{
 	return []DownloadDescriptor{
 		&mockDownloadDescriptor{
 		&mockDownloadDescriptor{