Przeglądaj źródła

Safer file io for configuration files

Signed-off-by: Tonis Tiigi <tonistiigi@gmail.com>
Tonis Tiigi 9 lat temu
rodzic
commit
ea3cbd3274

+ 3 - 2
container/container.go

@@ -23,6 +23,7 @@ import (
 	"github.com/docker/docker/image"
 	"github.com/docker/docker/layer"
 	"github.com/docker/docker/pkg/idtools"
+	"github.com/docker/docker/pkg/ioutils"
 	"github.com/docker/docker/pkg/promise"
 	"github.com/docker/docker/pkg/signal"
 	"github.com/docker/docker/pkg/symlink"
@@ -131,7 +132,7 @@ func (container *Container) ToDisk() error {
 		return err
 	}
 
-	jsonSource, err := os.Create(pth)
+	jsonSource, err := ioutils.NewAtomicFileWriter(pth, 0666)
 	if err != nil {
 		return err
 	}
@@ -191,7 +192,7 @@ func (container *Container) WriteHostConfig() error {
 		return err
 	}
 
-	f, err := os.Create(pth)
+	f, err := ioutils.NewAtomicFileWriter(pth, 0666)
 	if err != nil {
 		return err
 	}

+ 3 - 5
distribution/metadata/metadata.go

@@ -5,6 +5,8 @@ import (
 	"os"
 	"path/filepath"
 	"sync"
+
+	"github.com/docker/docker/pkg/ioutils"
 )
 
 // Store implements a K/V store for mapping distribution-related IDs
@@ -56,14 +58,10 @@ func (store *FSMetadataStore) Set(namespace, key string, value []byte) error {
 	defer store.Unlock()
 
 	path := store.path(namespace, key)
-	tempFilePath := path + ".tmp"
 	if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
 		return err
 	}
-	if err := ioutil.WriteFile(tempFilePath, value, 0644); err != nil {
-		return err
-	}
-	return os.Rename(tempFilePath, path)
+	return ioutils.AtomicWriteFile(path, value, 0644)
 }
 
 // Delete removes data indexed by namespace and key. The data file named after

+ 3 - 12
image/fs.go

@@ -9,6 +9,7 @@ import (
 
 	"github.com/Sirupsen/logrus"
 	"github.com/docker/distribution/digest"
+	"github.com/docker/docker/pkg/ioutils"
 )
 
 // IDWalkFunc is function called by StoreBackend.Walk
@@ -118,12 +119,7 @@ func (s *fs) Set(data []byte) (ID, error) {
 	}
 
 	id := ID(digest.FromBytes(data))
-	filePath := s.contentFile(id)
-	tempFilePath := s.contentFile(id) + ".tmp"
-	if err := ioutil.WriteFile(tempFilePath, data, 0600); err != nil {
-		return "", err
-	}
-	if err := os.Rename(tempFilePath, filePath); err != nil {
+	if err := ioutils.AtomicWriteFile(s.contentFile(id), data, 0600); err != nil {
 		return "", err
 	}
 
@@ -156,12 +152,7 @@ func (s *fs) SetMetadata(id ID, key string, data []byte) error {
 	if err := os.MkdirAll(baseDir, 0700); err != nil {
 		return err
 	}
-	filePath := filepath.Join(s.metadataDir(id), key)
-	tempFilePath := filePath + ".tmp"
-	if err := ioutil.WriteFile(tempFilePath, data, 0600); err != nil {
-		return err
-	}
-	return os.Rename(tempFilePath, filePath)
+	return ioutils.AtomicWriteFile(filepath.Join(s.metadataDir(id), key), data, 0600)
 }
 
 // GetMetadata returns metadata for a given ID.

+ 2 - 6
migrate/v1/migratev1.go

@@ -19,6 +19,7 @@ import (
 	"github.com/docker/docker/image"
 	imagev1 "github.com/docker/docker/image/v1"
 	"github.com/docker/docker/layer"
+	"github.com/docker/docker/pkg/ioutils"
 	"github.com/docker/docker/reference"
 )
 
@@ -160,12 +161,7 @@ func calculateLayerChecksum(graphDir, id string, ls checksumCalculator) error {
 		return err
 	}
 
-	tmpFile := filepath.Join(graphDir, id, migrationDiffIDFileName+".tmp")
-	if err := ioutil.WriteFile(tmpFile, []byte(diffID), 0600); err != nil {
-		return err
-	}
-
-	if err := os.Rename(tmpFile, filepath.Join(graphDir, id, migrationDiffIDFileName)); err != nil {
+	if err := ioutils.AtomicWriteFile(filepath.Join(graphDir, id, migrationDiffIDFileName), []byte(diffID), 0600); err != nil {
 		return err
 	}
 

+ 75 - 0
pkg/ioutils/fswriters.go

@@ -0,0 +1,75 @@
+package ioutils
+
+import (
+	"io"
+	"io/ioutil"
+	"os"
+	"path/filepath"
+)
+
+// NewAtomicFileWriter returns WriteCloser so that writing to it writes to a
+// temporary file and closing it atomically changes the temporary file to
+// destination path. Writing and closing concurrently is not allowed.
+func NewAtomicFileWriter(filename string, perm os.FileMode) (io.WriteCloser, error) {
+	f, err := ioutil.TempFile(filepath.Dir(filename), ".tmp-"+filepath.Base(filename))
+	if err != nil {
+		return nil, err
+	}
+	abspath, err := filepath.Abs(filename)
+	if err != nil {
+		return nil, err
+	}
+	return &atomicFileWriter{
+		f:  f,
+		fn: abspath,
+	}, nil
+}
+
+// AtomicWriteFile atomically writes data to a file named by filename.
+func AtomicWriteFile(filename string, data []byte, perm os.FileMode) error {
+	f, err := NewAtomicFileWriter(filename, perm)
+	if err != nil {
+		return err
+	}
+	n, err := f.Write(data)
+	if err == nil && n < len(data) {
+		err = io.ErrShortWrite
+	}
+	if err1 := f.Close(); err == nil {
+		err = err1
+	}
+	return err
+}
+
+type atomicFileWriter struct {
+	f        *os.File
+	fn       string
+	writeErr error
+}
+
+func (w *atomicFileWriter) Write(dt []byte) (int, error) {
+	n, err := w.f.Write(dt)
+	if err != nil {
+		w.writeErr = err
+	}
+	return n, err
+}
+
+func (w *atomicFileWriter) Close() (retErr error) {
+	defer func() {
+		if retErr != nil {
+			os.Remove(w.f.Name())
+		}
+	}()
+	if err := w.f.Sync(); err != nil {
+		w.f.Close()
+		return err
+	}
+	if err := w.f.Close(); err != nil {
+		return err
+	}
+	if w.writeErr == nil {
+		return os.Rename(w.f.Name(), w.fn)
+	}
+	return nil
+}

+ 31 - 0
pkg/ioutils/fswriters_test.go

@@ -0,0 +1,31 @@
+package ioutils
+
+import (
+	"bytes"
+	"io/ioutil"
+	"os"
+	"path/filepath"
+	"testing"
+)
+
+func TestAtomicWriteToFile(t *testing.T) {
+	tmpDir, err := ioutil.TempDir("", "atomic-writers-test")
+	if err != nil {
+		t.Fatalf("Error when creating temporary directory: %s", err)
+	}
+	defer os.RemoveAll(tmpDir)
+
+	expected := []byte("barbaz")
+	if err := AtomicWriteFile(filepath.Join(tmpDir, "foo"), expected, 0600); err != nil {
+		t.Fatalf("Error writing to file: %v", err)
+	}
+
+	actual, err := ioutil.ReadFile(filepath.Join(tmpDir, "foo"))
+	if err != nil {
+		t.Fatalf("Error reading from file: %v", err)
+	}
+
+	if bytes.Compare(actual, expected) != 0 {
+		t.Fatalf("Data mismatch, expected %q, got %q", expected, actual)
+	}
+}

+ 2 - 13
reference/store.go

@@ -4,7 +4,6 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
-	"io/ioutil"
 	"os"
 	"path/filepath"
 	"sort"
@@ -12,6 +11,7 @@ import (
 
 	"github.com/docker/distribution/digest"
 	"github.com/docker/docker/image"
+	"github.com/docker/docker/pkg/ioutils"
 )
 
 var (
@@ -256,18 +256,7 @@ func (store *store) save() error {
 	if err != nil {
 		return err
 	}
-
-	tempFilePath := store.jsonPath + ".tmp"
-
-	if err := ioutil.WriteFile(tempFilePath, jsonData, 0600); err != nil {
-		return err
-	}
-
-	if err := os.Rename(tempFilePath, store.jsonPath); err != nil {
-		return err
-	}
-
-	return nil
+	return ioutils.AtomicWriteFile(store.jsonPath, jsonData, 0600)
 }
 
 func (store *store) reload() error {