Browse Source

Merge pull request #4146 from alexlarsson/clean-up-archive-closing

Clean up archive closing
Michael Crosby 11 years ago
parent
commit
d3a2c15a5d
14 changed files with 83 additions and 65 deletions
  1. +24 −15 archive/archive.go
  2. +3 −2 archive/archive_test.go
  3. +1 −1 archive/diff.go
  4. +1 −0 buildfile.go
  5. +2 −2 commands.go
  6. +15 −3 container.go
  7. +5 −3 graph.go
  8. +1 −1 graphdriver/aufs/aufs.go
  9. +1 −1 graphdriver/driver.go
  10. +11 −3 image.go
  11. +2 −2 integration/utils_test.go
  12. +7 −1 runtime.go
  13. +10 −4 server.go
  14. +0 −27 utils.go

+ 24 - 15
archive/archive.go

@@ -19,9 +19,10 @@ import (
 )
 
 type (
-	Archive     io.Reader
-	Compression int
-	TarOptions  struct {
+	Archive       io.ReadCloser
+	ArchiveReader io.Reader
+	Compression   int
+	TarOptions    struct {
 		Includes    []string
 		Compression Compression
 	}
@@ -65,13 +66,13 @@ func DetectCompression(source []byte) Compression {
 	return Uncompressed
 }
 
-func xzDecompress(archive io.Reader) (io.Reader, error) {
+func xzDecompress(archive io.Reader) (io.ReadCloser, error) {
 	args := []string{"xz", "-d", "-c", "-q"}
 
 	return CmdStream(exec.Command(args[0], args[1:]...), archive)
 }
 
-func DecompressStream(archive io.Reader) (io.Reader, error) {
+func DecompressStream(archive io.Reader) (io.ReadCloser, error) {
 	buf := make([]byte, 10)
 	totalN := 0
 	for totalN < 10 {
@@ -90,11 +91,11 @@ func DecompressStream(archive io.Reader) (io.Reader, error) {
 
 	switch compression {
 	case Uncompressed:
-		return wrap, nil
+		return ioutil.NopCloser(wrap), nil
 	case Gzip:
 		return gzip.NewReader(wrap)
 	case Bzip2:
-		return bzip2.NewReader(wrap), nil
+		return ioutil.NopCloser(bzip2.NewReader(wrap)), nil
 	case Xz:
 		return xzDecompress(wrap)
 	default:
@@ -106,7 +107,7 @@ func CompressStream(dest io.WriteCloser, compression Compression) (io.WriteClose
 
 	switch compression {
 	case Uncompressed:
-		return dest, nil
+		return utils.NopWriteCloser(dest), nil
 	case Gzip:
 		return gzip.NewWriter(dest), nil
 	case Bzip2, Xz:
@@ -269,7 +270,7 @@ func createTarFile(path, extractDir string, hdr *tar.Header, reader *tar.Reader)
 
 // Tar creates an archive from the directory at `path`, and returns it as a
 // stream of bytes.
-func Tar(path string, compression Compression) (io.Reader, error) {
+func Tar(path string, compression Compression) (io.ReadCloser, error) {
 	return TarFilter(path, &TarOptions{Compression: compression})
 }
 
@@ -291,7 +292,7 @@ func escapeName(name string) string {
 
 // Tar creates an archive from the directory at `path`, only including files whose relative
 // paths are included in `filter`. If `filter` is nil, then all files are included.
-func TarFilter(srcPath string, options *TarOptions) (io.Reader, error) {
+func TarFilter(srcPath string, options *TarOptions) (io.ReadCloser, error) {
 	pipeReader, pipeWriter := io.Pipe()
 
 	compressWriter, err := CompressStream(pipeWriter, options.Compression)
@@ -337,6 +338,9 @@ func TarFilter(srcPath string, options *TarOptions) (io.Reader, error) {
 		if err := compressWriter.Close(); err != nil {
 			utils.Debugf("Can't close compress writer: %s\n", err)
 		}
+		if err := pipeWriter.Close(); err != nil {
+			utils.Debugf("Can't close pipe writer: %s\n", err)
+		}
 	}()
 
 	return pipeReader, nil
@@ -352,12 +356,13 @@ func Untar(archive io.Reader, dest string, options *TarOptions) error {
 		return fmt.Errorf("Empty archive")
 	}
 
-	archive, err := DecompressStream(archive)
+	decompressedArchive, err := DecompressStream(archive)
 	if err != nil {
 		return err
 	}
+	defer decompressedArchive.Close()
 
-	tr := tar.NewReader(archive)
+	tr := tar.NewReader(decompressedArchive)
 
 	var dirs []*tar.Header
 
@@ -432,15 +437,19 @@ func TarUntar(src string, dst string) error {
 	if err != nil {
 		return err
 	}
+	defer archive.Close()
 	return Untar(archive, dst, nil)
 }
 
 // UntarPath is a convenience function which looks for an archive
 // at filesystem path `src`, and unpacks it at `dst`.
 func UntarPath(src, dst string) error {
-	if archive, err := os.Open(src); err != nil {
+	archive, err := os.Open(src)
+	if err != nil {
 		return err
-	} else if err := Untar(archive, dst, nil); err != nil {
+	}
+	defer archive.Close()
+	if err := Untar(archive, dst, nil); err != nil {
 		return err
 	}
 	return nil
@@ -528,7 +537,7 @@ func CopyFileWithTar(src, dst string) (err error) {
 // CmdStream executes a command, and returns its stdout as a stream.
 // If the command fails to run or doesn't complete successfully, an error
 // will be returned, including anything written on stderr.
-func CmdStream(cmd *exec.Cmd, input io.Reader) (io.Reader, error) {
+func CmdStream(cmd *exec.Cmd, input io.Reader) (io.ReadCloser, error) {
 	if input != nil {
 		stdin, err := cmd.StdinPipe()
 		if err != nil {

+ 3 - 2
archive/archive_test.go

@@ -67,12 +67,13 @@ func tarUntar(t *testing.T, origin string, compression Compression) error {
 	if err != nil {
 		t.Fatal(err)
 	}
+	defer archive.Close()
 
 	buf := make([]byte, 10)
 	if _, err := archive.Read(buf); err != nil {
 		return err
 	}
-	archive = io.MultiReader(bytes.NewReader(buf), archive)
+	wrap := io.MultiReader(bytes.NewReader(buf), archive)
 
 	detectedCompression := DetectCompression(buf)
 	if detectedCompression.Extension() != compression.Extension() {
@@ -84,7 +85,7 @@ func tarUntar(t *testing.T, origin string, compression Compression) error {
 		return err
 	}
 	defer os.RemoveAll(tmp)
-	if err := Untar(archive, tmp, nil); err != nil {
+	if err := Untar(wrap, tmp, nil); err != nil {
 		return err
 	}
 	if _, err := os.Stat(tmp); err != nil {

+ 1 - 1
archive/diff.go

@@ -28,7 +28,7 @@ func timeToTimespec(time time.Time) (ts syscall.Timespec) {
 
 // ApplyLayer parses a diff in the standard layer format from `layer`, and
 // applies it to the directory `dest`.
-func ApplyLayer(dest string, layer Archive) error {
+func ApplyLayer(dest string, layer ArchiveReader) error {
 	// We need to be able to set any perms
 	oldmask := syscall.Umask(0)
 	defer syscall.Umask(oldmask)

+ 1 - 0
buildfile.go

@@ -464,6 +464,7 @@ func (b *buildFile) CmdAdd(args string) error {
 		}
 		tarSum := utils.TarSum{Reader: r, DisableCompression: true}
 		remoteHash = tarSum.Sum(nil)
+		r.Close()
 
 		// If the destination is a directory, figure out the filename.
 		if strings.HasSuffix(dest, "/") {

+ 2 - 2
commands.go

@@ -158,7 +158,7 @@ func MkBuildContext(dockerfile string, files [][2]string) (archive.Archive, erro
 	if err := tw.Close(); err != nil {
 		return nil, err
 	}
-	return buf, nil
+	return ioutil.NopCloser(buf), nil
 }
 
 func (cli *DockerCli) CmdBuild(args ...string) error {
@@ -206,7 +206,7 @@ func (cli *DockerCli) CmdBuild(args ...string) error {
 	// FIXME: ProgressReader shouldn't be this annoying to use
 	if context != nil {
 		sf := utils.NewStreamFormatter(false)
-		body = utils.ProgressReader(ioutil.NopCloser(context), 0, cli.err, sf, true, "", "Uploading context")
+		body = utils.ProgressReader(context, 0, cli.err, sf, true, "", "Uploading context")
 	}
 	// Upload the build context
 	v := &url.Values{}

+ 15 - 3
container.go

@@ -1288,7 +1288,11 @@ func (container *Container) ExportRw() (archive.Archive, error) {
 		container.Unmount()
 		return nil, err
 	}
-	return EofReader(archive, func() { container.Unmount() }), nil
+	return utils.NewReadCloserWrapper(archive, func() error {
+		err := archive.Close()
+		container.Unmount()
+		return err
+	}), nil
 }
 
 func (container *Container) Export() (archive.Archive, error) {
@@ -1301,7 +1305,11 @@ func (container *Container) Export() (archive.Archive, error) {
 		container.Unmount()
 		return nil, err
 	}
-	return EofReader(archive, func() { container.Unmount() }), nil
+	return utils.NewReadCloserWrapper(archive, func() error {
+		err := archive.Close()
+		container.Unmount()
+		return err
+	}), nil
 }
 
 func (container *Container) WaitTimeout(timeout time.Duration) error {
@@ -1455,7 +1463,11 @@ func (container *Container) Copy(resource string) (io.ReadCloser, error) {
 	if err != nil {
 		return nil, err
 	}
-	return utils.NewReadCloserWrapper(archive, container.Unmount), nil
+	return utils.NewReadCloserWrapper(archive, func() error {
+		err := archive.Close()
+		container.Unmount()
+		return err
+	}), nil
 }
 
 // Returns true if the container exposes a certain port

+ 5 - 3
graph.go

@@ -127,7 +127,7 @@ func (graph *Graph) Get(name string) (*Image, error) {
 }
 
 // Create creates a new image and registers it in the graph.
-func (graph *Graph) Create(layerData archive.Archive, container *Container, comment, author string, config *runconfig.Config) (*Image, error) {
+func (graph *Graph) Create(layerData archive.ArchiveReader, container *Container, comment, author string, config *runconfig.Config) (*Image, error) {
 	img := &Image{
 		ID:            GenerateID(),
 		Comment:       comment,
@@ -151,7 +151,7 @@ func (graph *Graph) Create(layerData archive.Archive, container *Container, comm
 
 // Register imports a pre-existing image into the graph.
 // FIXME: pass img as first argument
-func (graph *Graph) Register(jsonData []byte, layerData archive.Archive, img *Image) (err error) {
+func (graph *Graph) Register(jsonData []byte, layerData archive.ArchiveReader, img *Image) (err error) {
 	defer func() {
 		// If any error occurs, remove the new dir from the driver.
 		// Don't check for errors since the dir might not have been created.
@@ -226,7 +226,9 @@ func (graph *Graph) TempLayerArchive(id string, compression archive.Compression,
 	if err != nil {
 		return nil, err
 	}
-	return archive.NewTempArchive(utils.ProgressReader(ioutil.NopCloser(a), 0, output, sf, false, utils.TruncateID(id), "Buffering to disk"), tmp)
+	progress := utils.ProgressReader(a, 0, output, sf, false, utils.TruncateID(id), "Buffering to disk")
+	defer progress.Close()
+	return archive.NewTempArchive(progress, tmp)
 }
 
 // Mktemp creates a temporary sub-directory inside the graph's filesystem.

+ 1 - 1
graphdriver/aufs/aufs.go

@@ -271,7 +271,7 @@ func (a *Driver) Diff(id string) (archive.Archive, error) {
 	})
 }
 
-func (a *Driver) ApplyDiff(id string, diff archive.Archive) error {
+func (a *Driver) ApplyDiff(id string, diff archive.ArchiveReader) error {
 	return archive.Untar(diff, path.Join(a.rootPath(), "diff", id), nil)
 }
 

+ 1 - 1
graphdriver/driver.go

@@ -28,7 +28,7 @@ type Driver interface {
 type Differ interface {
 	Diff(id string) (archive.Archive, error)
 	Changes(id string) ([]archive.Change, error)
-	ApplyDiff(id string, diff archive.Archive) error
+	ApplyDiff(id string, diff archive.ArchiveReader) error
 	DiffSize(id string) (bytes int64, err error)
 }
 

+ 11 - 3
image.go

@@ -67,7 +67,7 @@ func LoadImage(root string) (*Image, error) {
 	return img, nil
 }
 
-func StoreImage(img *Image, jsonData []byte, layerData archive.Archive, root, layer string) error {
+func StoreImage(img *Image, jsonData []byte, layerData archive.ArchiveReader, root, layer string) error {
 	// Store the layer
 	var (
 		size   int64
@@ -174,7 +174,11 @@ func (img *Image) TarLayer() (arch archive.Archive, err error) {
 		if err != nil {
 			return nil, err
 		}
-		return EofReader(archive, func() { driver.Put(img.ID) }), nil
+		return utils.NewReadCloserWrapper(archive, func() error {
+			err := archive.Close()
+			driver.Put(img.ID)
+			return err
+		}), nil
 	}
 
 	parentFs, err := driver.Get(img.Parent)
@@ -190,7 +194,11 @@ func (img *Image) TarLayer() (arch archive.Archive, err error) {
 	if err != nil {
 		return nil, err
 	}
-	return EofReader(archive, func() { driver.Put(img.ID) }), nil
+	return utils.NewReadCloserWrapper(archive, func() error {
+		err := archive.Close()
+		driver.Put(img.ID)
+		return err
+	}), nil
 }
 
 func ValidateID(id string) error {

+ 2 - 2
integration/utils_test.go

@@ -319,7 +319,7 @@ func runContainer(eng *engine.Engine, r *docker.Runtime, args []string, t *testi
 }
 
 // FIXME: this is duplicated from graph_test.go in the docker package.
-func fakeTar() (io.Reader, error) {
+func fakeTar() (io.ReadCloser, error) {
 	content := []byte("Hello world!\n")
 	buf := new(bytes.Buffer)
 	tw := tar.NewWriter(buf)
@@ -333,7 +333,7 @@ func fakeTar() (io.Reader, error) {
 		tw.Write([]byte(content))
 	}
 	tw.Close()
-	return buf, nil
+	return ioutil.NopCloser(buf), nil
 }
 
 func getAllImages(eng *engine.Engine, t *testing.T) *engine.Table {

+ 7 - 1
runtime.go

@@ -531,6 +531,8 @@ func (runtime *Runtime) Commit(container *Container, repository, tag, comment, a
 	if err != nil {
 		return nil, err
 	}
+	defer rwTar.Close()
+
 	// Create a new image from the container's base layers + a new layer from container changes
 	img, err := runtime.graph.Create(rwTar, container, comment, author, config)
 	if err != nil {
@@ -817,7 +819,11 @@ func (runtime *Runtime) Diff(container *Container) (archive.Archive, error) {
 	if err != nil {
 		return nil, err
 	}
-	return EofReader(archive, func() { runtime.driver.Put(container.ID) }), nil
+	return utils.NewReadCloserWrapper(archive, func() error {
+		err := archive.Close()
+		runtime.driver.Put(container.ID)
+		return err
+	}), nil
 }
 
 func (runtime *Runtime) Run(c *Container, startCallback execdriver.StartCallback) (int, error) {

+ 10 - 4
server.go

@@ -292,6 +292,7 @@ func (srv *Server) ContainerExport(job *engine.Job) engine.Status {
 		if err != nil {
 			return job.Errorf("%s: %s", name, err)
 		}
+		defer data.Close()
 
 		// Stream the entire contents of the container (basically a volatile snapshot)
 		if _, err := io.Copy(job.Stdout, data); err != nil {
@@ -361,6 +362,7 @@ func (srv *Server) ImageExport(job *engine.Job) engine.Status {
 	if err != nil {
 		return job.Error(err)
 	}
+	defer fs.Close()
 
 	if _, err := io.Copy(job.Stdout, fs); err != nil {
 		return job.Error(err)
@@ -400,6 +402,7 @@ func (srv *Server) exportImage(image *Image, tempdir string) error {
 		if err != nil {
 			return err
 		}
+		defer fs.Close()
 
 		fsTar, err := os.Create(path.Join(tmpImageDir, "layer.tar"))
 		if err != nil {
@@ -436,14 +439,14 @@ func (srv *Server) Build(job *engine.Job) engine.Status {
 		authConfig     = &auth.AuthConfig{}
 		configFile     = &auth.ConfigFile{}
 		tag            string
-		context        io.Reader
+		context        io.ReadCloser
 	)
 	job.GetenvJson("authConfig", authConfig)
 	job.GetenvJson("configFile", configFile)
 	repoName, tag = utils.ParseRepositoryTag(repoName)
 
 	if remoteURL == "" {
-		context = job.Stdin
+		context = ioutil.NopCloser(job.Stdin)
 	} else if utils.IsGIT(remoteURL) {
 		if !strings.HasPrefix(remoteURL, "git://") {
 			remoteURL = "https://" + remoteURL
@@ -479,6 +482,7 @@ func (srv *Server) Build(job *engine.Job) engine.Status {
 		}
 		context = c
 	}
+	defer context.Close()
 
 	sf := utils.NewStreamFormatter(job.GetenvBool("json"))
 	b := NewBuildFile(srv,
@@ -1575,7 +1579,7 @@ func (srv *Server) ImageImport(job *engine.Job) engine.Status {
 		repo    = job.Args[1]
 		tag     string
 		sf      = utils.NewStreamFormatter(job.GetenvBool("json"))
-		archive io.Reader
+		archive archive.ArchiveReader
 		resp    *http.Response
 	)
 	if len(job.Args) > 2 {
@@ -1601,7 +1605,9 @@ func (srv *Server) ImageImport(job *engine.Job) engine.Status {
 		if err != nil {
 			return job.Error(err)
 		}
-		archive = utils.ProgressReader(resp.Body, int(resp.ContentLength), job.Stdout, sf, true, "", "Importing")
+		progressReader := utils.ProgressReader(resp.Body, int(resp.ContentLength), job.Stdout, sf, true, "", "Importing")
+		defer progressReader.Close()
+		archive = progressReader
 	}
 	img, err := srv.runtime.graph.Create(archive, nil, "Imported from "+src, "", nil)
 	if err != nil {

+ 0 - 27
utils.go

@@ -6,8 +6,6 @@ import (
 	"github.com/dotcloud/docker/pkg/namesgenerator"
 	"github.com/dotcloud/docker/runconfig"
 	"github.com/dotcloud/docker/utils"
-	"io"
-	"sync/atomic"
 )
 
 type Change struct {
@@ -56,28 +54,3 @@ func (c *checker) Exists(name string) bool {
 func generateRandomName(runtime *Runtime) (string, error) {
 	return namesgenerator.GenerateRandomName(&checker{runtime})
 }
-
-// Read an io.Reader and call a function when it returns EOF
-func EofReader(r io.Reader, callback func()) *eofReader {
-	return &eofReader{
-		Reader:   r,
-		callback: callback,
-	}
-}
-
-type eofReader struct {
-	io.Reader
-	gotEOF   int32
-	callback func()
-}
-
-func (r *eofReader) Read(p []byte) (n int, err error) {
-	n, err = r.Reader.Read(p)
-	if err == io.EOF {
-		// Use atomics to make the gotEOF check threadsafe
-		if atomic.CompareAndSwapInt32(&r.gotEOF, 0, 1) {
-			r.callback()
-		}
-	}
-	return
-}