浏览代码

Set permission on atomic file write

Perform chmod before rename with the atomic file writer.
Ensure writeErr is set on short write and file is removed on write error.

Signed-off-by: Derek McGowan <derek@mcgstyle.net> (github: dmcgowan)
Derek McGowan 9 年之前
父节点
当前提交
1cd7490281
共有 2 个文件被更改,包括 19 次插入4 次删除
  1. 10 3
      pkg/ioutils/fswriters.go
  2. 9 1
      pkg/ioutils/fswriters_test.go

+ 10 - 3
pkg/ioutils/fswriters.go

@@ -15,13 +15,15 @@ func NewAtomicFileWriter(filename string, perm os.FileMode) (io.WriteCloser, err
 	if err != nil {
 		return nil, err
 	}
+
 	abspath, err := filepath.Abs(filename)
 	if err != nil {
 		return nil, err
 	}
 	return &atomicFileWriter{
-		f:  f,
-		fn: abspath,
+		f:    f,
+		fn:   abspath,
+		perm: perm,
 	}, nil
 }
 
@@ -34,6 +36,7 @@ func AtomicWriteFile(filename string, data []byte, perm os.FileMode) error {
 	n, err := f.Write(data)
 	if err == nil && n < len(data) {
 		err = io.ErrShortWrite
+		f.(*atomicFileWriter).writeErr = err
 	}
 	if err1 := f.Close(); err == nil {
 		err = err1
@@ -45,6 +48,7 @@ type atomicFileWriter struct {
 	f        *os.File
 	fn       string
 	writeErr error
+	perm     os.FileMode
 }
 
 func (w *atomicFileWriter) Write(dt []byte) (int, error) {
@@ -57,7 +61,7 @@ func (w *atomicFileWriter) Write(dt []byte) (int, error) {
 
 func (w *atomicFileWriter) Close() (retErr error) {
 	defer func() {
-		if retErr != nil {
+		if retErr != nil || w.writeErr != nil {
 			os.Remove(w.f.Name())
 		}
 	}()
@@ -68,6 +72,9 @@ func (w *atomicFileWriter) Close() (retErr error) {
 	if err := w.f.Close(); err != nil {
 		return err
 	}
+	if err := os.Chmod(w.f.Name(), w.perm); err != nil {
+		return err
+	}
 	if w.writeErr == nil {
 		return os.Rename(w.f.Name(), w.fn)
 	}

+ 9 - 1
pkg/ioutils/fswriters_test.go

@@ -16,7 +16,7 @@ func TestAtomicWriteToFile(t *testing.T) {
 	defer os.RemoveAll(tmpDir)
 
 	expected := []byte("barbaz")
-	if err := AtomicWriteFile(filepath.Join(tmpDir, "foo"), expected, 0600); err != nil {
+	if err := AtomicWriteFile(filepath.Join(tmpDir, "foo"), expected, 0666); err != nil {
 		t.Fatalf("Error writing to file: %v", err)
 	}
 
@@ -28,4 +28,12 @@ func TestAtomicWriteToFile(t *testing.T) {
 	if bytes.Compare(actual, expected) != 0 {
 		t.Fatalf("Data mismatch, expected %q, got %q", expected, actual)
 	}
+
+	st, err := os.Stat(filepath.Join(tmpDir, "foo"))
+	if err != nil {
+		t.Fatalf("Error statting file: %v", err)
+	}
+	if expected := os.FileMode(0666); st.Mode() != expected {
+		t.Fatalf("Mode mismatched, expected %o, got %o", expected, st.Mode())
+	}
 }