Browse Source

Properly close archives

All archives that are created from somewhere generally have to be closed, because
at some point there is a file or a pipe or something that backs them. So, we
make archive.Archive a ReadCloser. However, code consuming archives does not
typically close them so we add an archive.ArchiveReader and use that when we're
only reading.

We then change all the Tar/Archive places to create ReadClosers, and to properly
close them everywhere.

As an added bonus we can use ReadCloserWrapper rather than EofReader in several places,
which is good as EofReader doesn't always work right. For instance, many compression
schemes like gzip know they are at EOF before having read the EOF from the stream, so the
EofReader never sees an EOF.

Docker-DCO-1.1-Signed-off-by: Alexander Larsson <alexl@redhat.com> (github: alexlarsson)
Alexander Larsson 11 years ago
parent
commit
f198ee525a
13 changed files with 71 additions and 30 deletions
  1. 12 7
      archive/archive.go
  2. 3 2
      archive/archive_test.go
  3. 1 1
      archive/diff.go
  4. 1 0
      buildfile.go
  5. 2 2
      commands.go
  6. 15 3
      container.go
  7. 5 3
      graph.go
  8. 1 1
      graphdriver/aufs/aufs.go
  9. 1 1
      graphdriver/driver.go
  10. 11 3
      image.go
  11. 2 2
      integration/utils_test.go
  12. 7 1
      runtime.go
  13. 10 4
      server.go

+ 12 - 7
archive/archive.go

@@ -19,9 +19,10 @@ import (
 )
 )
 
 
 type (
 type (
-	Archive     io.Reader
-	Compression int
-	TarOptions  struct {
+	Archive       io.ReadCloser
+	ArchiveReader io.Reader
+	Compression   int
+	TarOptions    struct {
 		Includes    []string
 		Includes    []string
 		Compression Compression
 		Compression Compression
 	}
 	}
@@ -269,7 +270,7 @@ func createTarFile(path, extractDir string, hdr *tar.Header, reader *tar.Reader)
 
 
 // Tar creates an archive from the directory at `path`, and returns it as a
 // Tar creates an archive from the directory at `path`, and returns it as a
 // stream of bytes.
 // stream of bytes.
-func Tar(path string, compression Compression) (io.Reader, error) {
+func Tar(path string, compression Compression) (io.ReadCloser, error) {
 	return TarFilter(path, &TarOptions{Compression: compression})
 	return TarFilter(path, &TarOptions{Compression: compression})
 }
 }
 
 
@@ -291,7 +292,7 @@ func escapeName(name string) string {
 
 
 // Tar creates an archive from the directory at `path`, only including files whose relative
 // Tar creates an archive from the directory at `path`, only including files whose relative
 // paths are included in `filter`. If `filter` is nil, then all files are included.
 // paths are included in `filter`. If `filter` is nil, then all files are included.
-func TarFilter(srcPath string, options *TarOptions) (io.Reader, error) {
+func TarFilter(srcPath string, options *TarOptions) (io.ReadCloser, error) {
 	pipeReader, pipeWriter := io.Pipe()
 	pipeReader, pipeWriter := io.Pipe()
 
 
 	compressWriter, err := CompressStream(pipeWriter, options.Compression)
 	compressWriter, err := CompressStream(pipeWriter, options.Compression)
@@ -436,15 +437,19 @@ func TarUntar(src string, dst string) error {
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
+	defer archive.Close()
 	return Untar(archive, dst, nil)
 	return Untar(archive, dst, nil)
 }
 }
 
 
 // UntarPath is a convenience function which looks for an archive
 // UntarPath is a convenience function which looks for an archive
 // at filesystem path `src`, and unpacks it at `dst`.
 // at filesystem path `src`, and unpacks it at `dst`.
 func UntarPath(src, dst string) error {
 func UntarPath(src, dst string) error {
-	if archive, err := os.Open(src); err != nil {
+	archive, err := os.Open(src)
+	if err != nil {
 		return err
 		return err
-	} else if err := Untar(archive, dst, nil); err != nil {
+	}
+	defer archive.Close()
+	if err := Untar(archive, dst, nil); err != nil {
 		return err
 		return err
 	}
 	}
 	return nil
 	return nil

+ 3 - 2
archive/archive_test.go

@@ -67,12 +67,13 @@ func tarUntar(t *testing.T, origin string, compression Compression) error {
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
+	defer archive.Close()
 
 
 	buf := make([]byte, 10)
 	buf := make([]byte, 10)
 	if _, err := archive.Read(buf); err != nil {
 	if _, err := archive.Read(buf); err != nil {
 		return err
 		return err
 	}
 	}
-	archive = io.MultiReader(bytes.NewReader(buf), archive)
+	wrap := io.MultiReader(bytes.NewReader(buf), archive)
 
 
 	detectedCompression := DetectCompression(buf)
 	detectedCompression := DetectCompression(buf)
 	if detectedCompression.Extension() != compression.Extension() {
 	if detectedCompression.Extension() != compression.Extension() {
@@ -84,7 +85,7 @@ func tarUntar(t *testing.T, origin string, compression Compression) error {
 		return err
 		return err
 	}
 	}
 	defer os.RemoveAll(tmp)
 	defer os.RemoveAll(tmp)
-	if err := Untar(archive, tmp, nil); err != nil {
+	if err := Untar(wrap, tmp, nil); err != nil {
 		return err
 		return err
 	}
 	}
 	if _, err := os.Stat(tmp); err != nil {
 	if _, err := os.Stat(tmp); err != nil {

+ 1 - 1
archive/diff.go

@@ -28,7 +28,7 @@ func timeToTimespec(time time.Time) (ts syscall.Timespec) {
 
 
 // ApplyLayer parses a diff in the standard layer format from `layer`, and
 // ApplyLayer parses a diff in the standard layer format from `layer`, and
 // applies it to the directory `dest`.
 // applies it to the directory `dest`.
-func ApplyLayer(dest string, layer Archive) error {
+func ApplyLayer(dest string, layer ArchiveReader) error {
 	// We need to be able to set any perms
 	// We need to be able to set any perms
 	oldmask := syscall.Umask(0)
 	oldmask := syscall.Umask(0)
 	defer syscall.Umask(oldmask)
 	defer syscall.Umask(oldmask)

+ 1 - 0
buildfile.go

@@ -464,6 +464,7 @@ func (b *buildFile) CmdAdd(args string) error {
 		}
 		}
 		tarSum := utils.TarSum{Reader: r, DisableCompression: true}
 		tarSum := utils.TarSum{Reader: r, DisableCompression: true}
 		remoteHash = tarSum.Sum(nil)
 		remoteHash = tarSum.Sum(nil)
+		r.Close()
 
 
 		// If the destination is a directory, figure out the filename.
 		// If the destination is a directory, figure out the filename.
 		if strings.HasSuffix(dest, "/") {
 		if strings.HasSuffix(dest, "/") {

+ 2 - 2
commands.go

@@ -158,7 +158,7 @@ func MkBuildContext(dockerfile string, files [][2]string) (archive.Archive, erro
 	if err := tw.Close(); err != nil {
 	if err := tw.Close(); err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
-	return buf, nil
+	return ioutil.NopCloser(buf), nil
 }
 }
 
 
 func (cli *DockerCli) CmdBuild(args ...string) error {
 func (cli *DockerCli) CmdBuild(args ...string) error {
@@ -206,7 +206,7 @@ func (cli *DockerCli) CmdBuild(args ...string) error {
 	// FIXME: ProgressReader shouldn't be this annoying to use
 	// FIXME: ProgressReader shouldn't be this annoying to use
 	if context != nil {
 	if context != nil {
 		sf := utils.NewStreamFormatter(false)
 		sf := utils.NewStreamFormatter(false)
-		body = utils.ProgressReader(ioutil.NopCloser(context), 0, cli.err, sf, true, "", "Uploading context")
+		body = utils.ProgressReader(context, 0, cli.err, sf, true, "", "Uploading context")
 	}
 	}
 	// Upload the build context
 	// Upload the build context
 	v := &url.Values{}
 	v := &url.Values{}

+ 15 - 3
container.go

@@ -1288,7 +1288,11 @@ func (container *Container) ExportRw() (archive.Archive, error) {
 		container.Unmount()
 		container.Unmount()
 		return nil, err
 		return nil, err
 	}
 	}
-	return EofReader(archive, func() { container.Unmount() }), nil
+	return utils.NewReadCloserWrapper(archive, func() error {
+		err := archive.Close()
+		container.Unmount()
+		return err
+	}), nil
 }
 }
 
 
 func (container *Container) Export() (archive.Archive, error) {
 func (container *Container) Export() (archive.Archive, error) {
@@ -1301,7 +1305,11 @@ func (container *Container) Export() (archive.Archive, error) {
 		container.Unmount()
 		container.Unmount()
 		return nil, err
 		return nil, err
 	}
 	}
-	return EofReader(archive, func() { container.Unmount() }), nil
+	return utils.NewReadCloserWrapper(archive, func() error {
+		err := archive.Close()
+		container.Unmount()
+		return err
+	}), nil
 }
 }
 
 
 func (container *Container) WaitTimeout(timeout time.Duration) error {
 func (container *Container) WaitTimeout(timeout time.Duration) error {
@@ -1455,7 +1463,11 @@ func (container *Container) Copy(resource string) (io.ReadCloser, error) {
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
-	return utils.NewReadCloserWrapper(archive, container.Unmount), nil
+	return utils.NewReadCloserWrapper(archive, func() error {
+		err := archive.Close()
+		container.Unmount()
+		return err
+	}), nil
 }
 }
 
 
 // Returns true if the container exposes a certain port
 // Returns true if the container exposes a certain port

+ 5 - 3
graph.go

@@ -127,7 +127,7 @@ func (graph *Graph) Get(name string) (*Image, error) {
 }
 }
 
 
 // Create creates a new image and registers it in the graph.
 // Create creates a new image and registers it in the graph.
-func (graph *Graph) Create(layerData archive.Archive, container *Container, comment, author string, config *runconfig.Config) (*Image, error) {
+func (graph *Graph) Create(layerData archive.ArchiveReader, container *Container, comment, author string, config *runconfig.Config) (*Image, error) {
 	img := &Image{
 	img := &Image{
 		ID:            GenerateID(),
 		ID:            GenerateID(),
 		Comment:       comment,
 		Comment:       comment,
@@ -151,7 +151,7 @@ func (graph *Graph) Create(layerData archive.Archive, container *Container, comm
 
 
 // Register imports a pre-existing image into the graph.
 // Register imports a pre-existing image into the graph.
 // FIXME: pass img as first argument
 // FIXME: pass img as first argument
-func (graph *Graph) Register(jsonData []byte, layerData archive.Archive, img *Image) (err error) {
+func (graph *Graph) Register(jsonData []byte, layerData archive.ArchiveReader, img *Image) (err error) {
 	defer func() {
 	defer func() {
 		// If any error occurs, remove the new dir from the driver.
 		// If any error occurs, remove the new dir from the driver.
 		// Don't check for errors since the dir might not have been created.
 		// Don't check for errors since the dir might not have been created.
@@ -226,7 +226,9 @@ func (graph *Graph) TempLayerArchive(id string, compression archive.Compression,
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
-	return archive.NewTempArchive(utils.ProgressReader(ioutil.NopCloser(a), 0, output, sf, false, utils.TruncateID(id), "Buffering to disk"), tmp)
+	progress := utils.ProgressReader(a, 0, output, sf, false, utils.TruncateID(id), "Buffering to disk")
+	defer progress.Close()
+	return archive.NewTempArchive(progress, tmp)
 }
 }
 
 
 // Mktemp creates a temporary sub-directory inside the graph's filesystem.
 // Mktemp creates a temporary sub-directory inside the graph's filesystem.

+ 1 - 1
graphdriver/aufs/aufs.go

@@ -271,7 +271,7 @@ func (a *Driver) Diff(id string) (archive.Archive, error) {
 	})
 	})
 }
 }
 
 
-func (a *Driver) ApplyDiff(id string, diff archive.Archive) error {
+func (a *Driver) ApplyDiff(id string, diff archive.ArchiveReader) error {
 	return archive.Untar(diff, path.Join(a.rootPath(), "diff", id), nil)
 	return archive.Untar(diff, path.Join(a.rootPath(), "diff", id), nil)
 }
 }
 
 

+ 1 - 1
graphdriver/driver.go

@@ -28,7 +28,7 @@ type Driver interface {
 type Differ interface {
 type Differ interface {
 	Diff(id string) (archive.Archive, error)
 	Diff(id string) (archive.Archive, error)
 	Changes(id string) ([]archive.Change, error)
 	Changes(id string) ([]archive.Change, error)
-	ApplyDiff(id string, diff archive.Archive) error
+	ApplyDiff(id string, diff archive.ArchiveReader) error
 	DiffSize(id string) (bytes int64, err error)
 	DiffSize(id string) (bytes int64, err error)
 }
 }
 
 

+ 11 - 3
image.go

@@ -67,7 +67,7 @@ func LoadImage(root string) (*Image, error) {
 	return img, nil
 	return img, nil
 }
 }
 
 
-func StoreImage(img *Image, jsonData []byte, layerData archive.Archive, root, layer string) error {
+func StoreImage(img *Image, jsonData []byte, layerData archive.ArchiveReader, root, layer string) error {
 	// Store the layer
 	// Store the layer
 	var (
 	var (
 		size   int64
 		size   int64
@@ -174,7 +174,11 @@ func (img *Image) TarLayer() (arch archive.Archive, err error) {
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
-		return EofReader(archive, func() { driver.Put(img.ID) }), nil
+		return utils.NewReadCloserWrapper(archive, func() error {
+			err := archive.Close()
+			driver.Put(img.ID)
+			return err
+		}), nil
 	}
 	}
 
 
 	parentFs, err := driver.Get(img.Parent)
 	parentFs, err := driver.Get(img.Parent)
@@ -190,7 +194,11 @@ func (img *Image) TarLayer() (arch archive.Archive, err error) {
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
-	return EofReader(archive, func() { driver.Put(img.ID) }), nil
+	return utils.NewReadCloserWrapper(archive, func() error {
+		err := archive.Close()
+		driver.Put(img.ID)
+		return err
+	}), nil
 }
 }
 
 
 func ValidateID(id string) error {
 func ValidateID(id string) error {

+ 2 - 2
integration/utils_test.go

@@ -319,7 +319,7 @@ func runContainer(eng *engine.Engine, r *docker.Runtime, args []string, t *testi
 }
 }
 
 
 // FIXME: this is duplicated from graph_test.go in the docker package.
 // FIXME: this is duplicated from graph_test.go in the docker package.
-func fakeTar() (io.Reader, error) {
+func fakeTar() (io.ReadCloser, error) {
 	content := []byte("Hello world!\n")
 	content := []byte("Hello world!\n")
 	buf := new(bytes.Buffer)
 	buf := new(bytes.Buffer)
 	tw := tar.NewWriter(buf)
 	tw := tar.NewWriter(buf)
@@ -333,7 +333,7 @@ func fakeTar() (io.Reader, error) {
 		tw.Write([]byte(content))
 		tw.Write([]byte(content))
 	}
 	}
 	tw.Close()
 	tw.Close()
-	return buf, nil
+	return ioutil.NopCloser(buf), nil
 }
 }
 
 
 func getAllImages(eng *engine.Engine, t *testing.T) *engine.Table {
 func getAllImages(eng *engine.Engine, t *testing.T) *engine.Table {

+ 7 - 1
runtime.go

@@ -531,6 +531,8 @@ func (runtime *Runtime) Commit(container *Container, repository, tag, comment, a
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
+	defer rwTar.Close()
+
 	// Create a new image from the container's base layers + a new layer from container changes
 	// Create a new image from the container's base layers + a new layer from container changes
 	img, err := runtime.graph.Create(rwTar, container, comment, author, config)
 	img, err := runtime.graph.Create(rwTar, container, comment, author, config)
 	if err != nil {
 	if err != nil {
@@ -817,7 +819,11 @@ func (runtime *Runtime) Diff(container *Container) (archive.Archive, error) {
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
-	return EofReader(archive, func() { runtime.driver.Put(container.ID) }), nil
+	return utils.NewReadCloserWrapper(archive, func() error {
+		err := archive.Close()
+		runtime.driver.Put(container.ID)
+		return err
+	}), nil
 }
 }
 
 
 func (runtime *Runtime) Run(c *Container, startCallback execdriver.StartCallback) (int, error) {
 func (runtime *Runtime) Run(c *Container, startCallback execdriver.StartCallback) (int, error) {

+ 10 - 4
server.go

@@ -292,6 +292,7 @@ func (srv *Server) ContainerExport(job *engine.Job) engine.Status {
 		if err != nil {
 		if err != nil {
 			return job.Errorf("%s: %s", name, err)
 			return job.Errorf("%s: %s", name, err)
 		}
 		}
+		defer data.Close()
 
 
 		// Stream the entire contents of the container (basically a volatile snapshot)
 		// Stream the entire contents of the container (basically a volatile snapshot)
 		if _, err := io.Copy(job.Stdout, data); err != nil {
 		if _, err := io.Copy(job.Stdout, data); err != nil {
@@ -361,6 +362,7 @@ func (srv *Server) ImageExport(job *engine.Job) engine.Status {
 	if err != nil {
 	if err != nil {
 		return job.Error(err)
 		return job.Error(err)
 	}
 	}
+	defer fs.Close()
 
 
 	if _, err := io.Copy(job.Stdout, fs); err != nil {
 	if _, err := io.Copy(job.Stdout, fs); err != nil {
 		return job.Error(err)
 		return job.Error(err)
@@ -400,6 +402,7 @@ func (srv *Server) exportImage(image *Image, tempdir string) error {
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
+		defer fs.Close()
 
 
 		fsTar, err := os.Create(path.Join(tmpImageDir, "layer.tar"))
 		fsTar, err := os.Create(path.Join(tmpImageDir, "layer.tar"))
 		if err != nil {
 		if err != nil {
@@ -436,14 +439,14 @@ func (srv *Server) Build(job *engine.Job) engine.Status {
 		authConfig     = &auth.AuthConfig{}
 		authConfig     = &auth.AuthConfig{}
 		configFile     = &auth.ConfigFile{}
 		configFile     = &auth.ConfigFile{}
 		tag            string
 		tag            string
-		context        io.Reader
+		context        io.ReadCloser
 	)
 	)
 	job.GetenvJson("authConfig", authConfig)
 	job.GetenvJson("authConfig", authConfig)
 	job.GetenvJson("configFile", configFile)
 	job.GetenvJson("configFile", configFile)
 	repoName, tag = utils.ParseRepositoryTag(repoName)
 	repoName, tag = utils.ParseRepositoryTag(repoName)
 
 
 	if remoteURL == "" {
 	if remoteURL == "" {
-		context = job.Stdin
+		context = ioutil.NopCloser(job.Stdin)
 	} else if utils.IsGIT(remoteURL) {
 	} else if utils.IsGIT(remoteURL) {
 		if !strings.HasPrefix(remoteURL, "git://") {
 		if !strings.HasPrefix(remoteURL, "git://") {
 			remoteURL = "https://" + remoteURL
 			remoteURL = "https://" + remoteURL
@@ -479,6 +482,7 @@ func (srv *Server) Build(job *engine.Job) engine.Status {
 		}
 		}
 		context = c
 		context = c
 	}
 	}
+	defer context.Close()
 
 
 	sf := utils.NewStreamFormatter(job.GetenvBool("json"))
 	sf := utils.NewStreamFormatter(job.GetenvBool("json"))
 	b := NewBuildFile(srv,
 	b := NewBuildFile(srv,
@@ -1575,7 +1579,7 @@ func (srv *Server) ImageImport(job *engine.Job) engine.Status {
 		repo    = job.Args[1]
 		repo    = job.Args[1]
 		tag     string
 		tag     string
 		sf      = utils.NewStreamFormatter(job.GetenvBool("json"))
 		sf      = utils.NewStreamFormatter(job.GetenvBool("json"))
-		archive io.Reader
+		archive archive.ArchiveReader
 		resp    *http.Response
 		resp    *http.Response
 	)
 	)
 	if len(job.Args) > 2 {
 	if len(job.Args) > 2 {
@@ -1601,7 +1605,9 @@ func (srv *Server) ImageImport(job *engine.Job) engine.Status {
 		if err != nil {
 		if err != nil {
 			return job.Error(err)
 			return job.Error(err)
 		}
 		}
-		archive = utils.ProgressReader(resp.Body, int(resp.ContentLength), job.Stdout, sf, true, "", "Importing")
+		progressReader := utils.ProgressReader(resp.Body, int(resp.ContentLength), job.Stdout, sf, true, "", "Importing")
+		defer progressReader.Close()
+		archive = progressReader
 	}
 	}
 	img, err := srv.runtime.graph.Create(archive, nil, "Imported from "+src, "", nil)
 	img, err := srv.runtime.graph.Create(archive, nil, "Imported from "+src, "", nil)
 	if err != nil {
 	if err != nil {