Parcourir la source

Merge pull request #20706 from calavera/remove_concurrent_access_to_stdtypes

Make stdcopy.StdWriter thread safe.
Brian Goff il y a 9 ans
Parent
commit
ec268be52e
2 fichiers modifiés avec 44 ajouts et 42 suppressions
  1. 40 37
      pkg/stdcopy/stdcopy.go
  2. 4 5
      pkg/stdcopy/stdcopy_test.go

+ 40 - 37
pkg/stdcopy/stdcopy.go

@@ -3,12 +3,24 @@ package stdcopy
 import (
 	"encoding/binary"
 	"errors"
+	"fmt"
 	"io"
 
 	"github.com/Sirupsen/logrus"
 )
 
+// StdType is the type of standard stream
+// a writer can multiplex to.
+type StdType byte
+
 const (
+	// Stdin represents standard input stream type.
+	Stdin StdType = iota
+	// Stdout represents standard output stream type.
+	Stdout
+	// Stderr represents standard error steam type.
+	Stderr
+
 	stdWriterPrefixLen = 8
 	stdWriterFdIndex   = 0
 	stdWriterSizeIndex = 4
@@ -16,38 +28,32 @@ const (
 	startingBufLen = 32*1024 + stdWriterPrefixLen + 1
 )
 
-// StdType prefixes type and length to standard stream.
-type StdType [stdWriterPrefixLen]byte
-
-var (
-	// Stdin represents standard input stream type.
-	Stdin = StdType{0: 0}
-	// Stdout represents standard output stream type.
-	Stdout = StdType{0: 1}
-	// Stderr represents standard error steam type.
-	Stderr = StdType{0: 2}
-)
-
-// StdWriter is wrapper of io.Writer with extra customized info.
-type StdWriter struct {
+// stdWriter is wrapper of io.Writer with extra customized info.
+type stdWriter struct {
 	io.Writer
-	prefix  StdType
-	sizeBuf []byte
+	prefix byte
 }
 
-func (w *StdWriter) Write(buf []byte) (n int, err error) {
-	var n1, n2 int
+// Write sends the buffer to the underneath writer.
+// It insert the prefix header before the buffer,
+// so stdcopy.StdCopy knows where to multiplex the output.
+// It makes stdWriter to implement io.Writer.
+func (w *stdWriter) Write(buf []byte) (n int, err error) {
 	if w == nil || w.Writer == nil {
 		return 0, errors.New("Writer not instantiated")
 	}
-	binary.BigEndian.PutUint32(w.prefix[4:], uint32(len(buf)))
-	n1, err = w.Writer.Write(w.prefix[:])
-	if err != nil {
-		n = n1 - stdWriterPrefixLen
-	} else {
-		n2, err = w.Writer.Write(buf)
-		n = n1 + n2 - stdWriterPrefixLen
+	if buf == nil {
+		return 0, nil
 	}
+
+	header := [stdWriterPrefixLen]byte{stdWriterFdIndex: w.prefix}
+	binary.BigEndian.PutUint32(header[stdWriterSizeIndex:], uint32(len(buf)))
+
+	line := append(header[:], buf...)
+
+	n, err = w.Writer.Write(line)
+	n -= stdWriterPrefixLen
+
 	if n < 0 {
 		n = 0
 	}
@@ -60,16 +66,13 @@ func (w *StdWriter) Write(buf []byte) (n int, err error) {
 // This allows multiple write streams (e.g. stdout and stderr) to be muxed into a single connection.
 // `t` indicates the id of the stream to encapsulate.
 // It can be stdcopy.Stdin, stdcopy.Stdout, stdcopy.Stderr.
-func NewStdWriter(w io.Writer, t StdType) *StdWriter {
-	return &StdWriter{
-		Writer:  w,
-		prefix:  t,
-		sizeBuf: make([]byte, 4),
+func NewStdWriter(w io.Writer, t StdType) io.Writer {
+	return &stdWriter{
+		Writer: w,
+		prefix: byte(t),
 	}
 }
 
-var errInvalidStdHeader = errors.New("Unrecognized input header")
-
 // StdCopy is a modified version of io.Copy.
 //
 // StdCopy will demultiplex `src`, assuming that it contains two streams,
@@ -110,18 +113,18 @@ func StdCopy(dstout, dsterr io.Writer, src io.Reader) (written int64, err error)
 		}
 
 		// Check the first byte to know where to write
-		switch buf[stdWriterFdIndex] {
-		case 0:
+		switch StdType(buf[stdWriterFdIndex]) {
+		case Stdin:
 			fallthrough
-		case 1:
+		case Stdout:
 			// Write on stdout
 			out = dstout
-		case 2:
+		case Stderr:
 			// Write on stderr
 			out = dsterr
 		default:
 			logrus.Debugf("Error selecting output fd: (%d)", buf[stdWriterFdIndex])
-			return 0, errInvalidStdHeader
+			return 0, fmt.Errorf("Unrecognized input header: %d", buf[stdWriterFdIndex])
 		}
 
 		// Retrieve the size of the frame

+ 4 - 5
pkg/stdcopy/stdcopy_test.go

@@ -17,10 +17,9 @@ func TestNewStdWriter(t *testing.T) {
 }
 
 func TestWriteWithUnitializedStdWriter(t *testing.T) {
-	writer := StdWriter{
-		Writer:  nil,
-		prefix:  Stdout,
-		sizeBuf: make([]byte, 4),
+	writer := stdWriter{
+		Writer: nil,
+		prefix: byte(Stdout),
 	}
 	n, err := writer.Write([]byte("Something here"))
 	if n != 0 || err == nil {
@@ -180,7 +179,7 @@ func TestStdCopyDetectsCorruptedFrame(t *testing.T) {
 		src:          buffer}
 	written, err := StdCopy(ioutil.Discard, ioutil.Discard, reader)
 	if written != startingBufLen {
-		t.Fatalf("Expected 0 bytes read, got %d", written)
+		t.Fatalf("Expected %d bytes read, got %d", startingBufLen, written)
 	}
 	if err != nil {
 		t.Fatal("Didn't get nil error")