Merge pull request #20706 from calavera/remove_concurrent_access_to_stdtypes

Make stdcopy.StdWriter thread safe.
This commit is contained in:
Brian Goff 2016-02-27 21:14:09 -05:00
commit ec268be52e
2 changed files with 44 additions and 42 deletions

View file

@ -3,12 +3,24 @@ package stdcopy
import (
"encoding/binary"
"errors"
"fmt"
"io"
"github.com/Sirupsen/logrus"
)
// StdType is the type of standard stream
// a writer can multiplex to.
type StdType byte
const (
// Stdin represents standard input stream type.
Stdin StdType = iota
// Stdout represents standard output stream type.
Stdout
// Stderr represents standard error steam type.
Stderr
stdWriterPrefixLen = 8
stdWriterFdIndex = 0
stdWriterSizeIndex = 4
@ -16,38 +28,32 @@ const (
startingBufLen = 32*1024 + stdWriterPrefixLen + 1
)
// StdType prefixes type and length to standard stream.
type StdType [stdWriterPrefixLen]byte
var (
// Stdin represents standard input stream type.
Stdin = StdType{0: 0}
// Stdout represents standard output stream type.
Stdout = StdType{0: 1}
// Stderr represents standard error steam type.
Stderr = StdType{0: 2}
)
// StdWriter is wrapper of io.Writer with extra customized info.
type StdWriter struct {
// stdWriter is wrapper of io.Writer with extra customized info.
type stdWriter struct {
io.Writer
prefix StdType
sizeBuf []byte
prefix byte
}
func (w *StdWriter) Write(buf []byte) (n int, err error) {
var n1, n2 int
// Write sends the buffer to the underneath writer.
// It insert the prefix header before the buffer,
// so stdcopy.StdCopy knows where to multiplex the output.
// It makes stdWriter to implement io.Writer.
func (w *stdWriter) Write(buf []byte) (n int, err error) {
if w == nil || w.Writer == nil {
return 0, errors.New("Writer not instantiated")
}
binary.BigEndian.PutUint32(w.prefix[4:], uint32(len(buf)))
n1, err = w.Writer.Write(w.prefix[:])
if err != nil {
n = n1 - stdWriterPrefixLen
} else {
n2, err = w.Writer.Write(buf)
n = n1 + n2 - stdWriterPrefixLen
if buf == nil {
return 0, nil
}
header := [stdWriterPrefixLen]byte{stdWriterFdIndex: w.prefix}
binary.BigEndian.PutUint32(header[stdWriterSizeIndex:], uint32(len(buf)))
line := append(header[:], buf...)
n, err = w.Writer.Write(line)
n -= stdWriterPrefixLen
if n < 0 {
n = 0
}
@ -60,16 +66,13 @@ func (w *StdWriter) Write(buf []byte) (n int, err error) {
// This allows multiple write streams (e.g. stdout and stderr) to be muxed into a single connection.
// `t` indicates the id of the stream to encapsulate.
// It can be stdcopy.Stdin, stdcopy.Stdout, stdcopy.Stderr.
func NewStdWriter(w io.Writer, t StdType) *StdWriter {
return &StdWriter{
Writer: w,
prefix: t,
sizeBuf: make([]byte, 4),
func NewStdWriter(w io.Writer, t StdType) io.Writer {
return &stdWriter{
Writer: w,
prefix: byte(t),
}
}
var errInvalidStdHeader = errors.New("Unrecognized input header")
// StdCopy is a modified version of io.Copy.
//
// StdCopy will demultiplex `src`, assuming that it contains two streams,
@ -110,18 +113,18 @@ func StdCopy(dstout, dsterr io.Writer, src io.Reader) (written int64, err error)
}
// Check the first byte to know where to write
switch buf[stdWriterFdIndex] {
case 0:
switch StdType(buf[stdWriterFdIndex]) {
case Stdin:
fallthrough
case 1:
case Stdout:
// Write on stdout
out = dstout
case 2:
case Stderr:
// Write on stderr
out = dsterr
default:
logrus.Debugf("Error selecting output fd: (%d)", buf[stdWriterFdIndex])
return 0, errInvalidStdHeader
return 0, fmt.Errorf("Unrecognized input header: %d", buf[stdWriterFdIndex])
}
// Retrieve the size of the frame

View file

@ -17,10 +17,9 @@ func TestNewStdWriter(t *testing.T) {
}
func TestWriteWithUnitializedStdWriter(t *testing.T) {
writer := StdWriter{
Writer: nil,
prefix: Stdout,
sizeBuf: make([]byte, 4),
writer := stdWriter{
Writer: nil,
prefix: byte(Stdout),
}
n, err := writer.Write([]byte("Something here"))
if n != 0 || err == nil {
@ -180,7 +179,7 @@ func TestStdCopyDetectsCorruptedFrame(t *testing.T) {
src: buffer}
written, err := StdCopy(ioutil.Discard, ioutil.Discard, reader)
if written != startingBufLen {
t.Fatalf("Expected 0 bytes read, got %d", written)
t.Fatalf("Expected %d bytes read, got %d", startingBufLen, written)
}
if err != nil {
t.Fatal("Didn't get nil error")