|
@@ -3,6 +3,7 @@ package stdcopy
|
|
import (
|
|
import (
|
|
"bytes"
|
|
"bytes"
|
|
"errors"
|
|
"errors"
|
|
|
|
+ "io"
|
|
"io/ioutil"
|
|
"io/ioutil"
|
|
"strings"
|
|
"strings"
|
|
"testing"
|
|
"testing"
|
|
@@ -85,17 +86,22 @@ func TestWriteDoesNotReturnNegativeWrittenBytes(t *testing.T) {
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
-func TestStdCopyWriteAndRead(t *testing.T) {
|
|
|
|
- buffer := new(bytes.Buffer)
|
|
|
|
- stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
|
|
|
|
|
|
+func getSrcBuffer(stdOutBytes, stdErrBytes []byte) (buffer *bytes.Buffer, err error) {
|
|
|
|
+ buffer = new(bytes.Buffer)
|
|
dstOut := NewStdWriter(buffer, Stdout)
|
|
dstOut := NewStdWriter(buffer, Stdout)
|
|
- _, err := dstOut.Write(stdOutBytes)
|
|
|
|
|
|
+ _, err = dstOut.Write(stdOutBytes)
|
|
if err != nil {
|
|
if err != nil {
|
|
- t.Fatal(err)
|
|
|
|
|
|
+ return
|
|
}
|
|
}
|
|
- stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
|
|
|
|
dstErr := NewStdWriter(buffer, Stderr)
|
|
dstErr := NewStdWriter(buffer, Stderr)
|
|
_, err = dstErr.Write(stdErrBytes)
|
|
_, err = dstErr.Write(stdErrBytes)
|
|
|
|
+ return
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func TestStdCopyWriteAndRead(t *testing.T) {
|
|
|
|
+ stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
|
|
|
|
+ stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
|
|
|
|
+ buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
|
|
if err != nil {
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
@@ -109,6 +115,78 @@ func TestStdCopyWriteAndRead(t *testing.T) {
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+type customReader struct {
|
|
|
|
+ n int
|
|
|
|
+ err error
|
|
|
|
+ totalCalls int
|
|
|
|
+ correctCalls int
|
|
|
|
+ src *bytes.Buffer
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func (f *customReader) Read(buf []byte) (int, error) {
|
|
|
|
+ f.totalCalls++
|
|
|
|
+ if f.totalCalls <= f.correctCalls {
|
|
|
|
+ return f.src.Read(buf)
|
|
|
|
+ }
|
|
|
|
+ return f.n, f.err
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func TestStdCopyReturnsErrorReadingHeader(t *testing.T) {
|
|
|
|
+ expectedError := errors.New("error")
|
|
|
|
+ reader := &customReader{
|
|
|
|
+ err: expectedError}
|
|
|
|
+ written, err := StdCopy(ioutil.Discard, ioutil.Discard, reader)
|
|
|
|
+ if written != 0 {
|
|
|
|
+ t.Fatalf("Expected 0 bytes read, got %d", written)
|
|
|
|
+ }
|
|
|
|
+ if err != expectedError {
|
|
|
|
+ t.Fatalf("Didn't get expected error")
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func TestStdCopyReturnsErrorReadingFrame(t *testing.T) {
|
|
|
|
+ expectedError := errors.New("error")
|
|
|
|
+ stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
|
|
|
|
+ stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
|
|
|
|
+ buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
|
|
|
|
+ if err != nil {
|
|
|
|
+ t.Fatal(err)
|
|
|
|
+ }
|
|
|
|
+ reader := &customReader{
|
|
|
|
+ correctCalls: 1,
|
|
|
|
+ n: stdWriterPrefixLen + 1,
|
|
|
|
+ err: expectedError,
|
|
|
|
+ src: buffer}
|
|
|
|
+ written, err := StdCopy(ioutil.Discard, ioutil.Discard, reader)
|
|
|
|
+ if written != 0 {
|
|
|
|
+ t.Fatalf("Expected 0 bytes read, got %d", written)
|
|
|
|
+ }
|
|
|
|
+ if err != expectedError {
|
|
|
|
+ t.Fatalf("Didn't get expected error")
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func TestStdCopyDetectsCorruptedFrame(t *testing.T) {
|
|
|
|
+ stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
|
|
|
|
+ stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
|
|
|
|
+ buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
|
|
|
|
+ if err != nil {
|
|
|
|
+ t.Fatal(err)
|
|
|
|
+ }
|
|
|
|
+ reader := &customReader{
|
|
|
|
+ correctCalls: 1,
|
|
|
|
+ n: stdWriterPrefixLen + 1,
|
|
|
|
+ err: io.EOF,
|
|
|
|
+ src: buffer}
|
|
|
|
+ written, err := StdCopy(ioutil.Discard, ioutil.Discard, reader)
|
|
|
|
+ if written != startingBufLen {
|
|
|
|
+ t.Fatalf("Expected 0 bytes read, got %d", written)
|
|
|
|
+ }
|
|
|
|
+ if err != nil {
|
|
|
|
+ t.Fatal("Didn't get nil error")
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
+
|
|
func TestStdCopyWithInvalidInputHeader(t *testing.T) {
|
|
func TestStdCopyWithInvalidInputHeader(t *testing.T) {
|
|
dstOut := NewStdWriter(ioutil.Discard, Stdout)
|
|
dstOut := NewStdWriter(ioutil.Discard, Stdout)
|
|
dstErr := NewStdWriter(ioutil.Discard, Stderr)
|
|
dstErr := NewStdWriter(ioutil.Discard, Stderr)
|
|
@@ -131,6 +209,44 @@ func TestStdCopyWithCorruptedPrefix(t *testing.T) {
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+func TestStdCopyReturnsWriteErrors(t *testing.T) {
|
|
|
|
+ stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
|
|
|
|
+ stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
|
|
|
|
+ buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
|
|
|
|
+ if err != nil {
|
|
|
|
+ t.Fatal(err)
|
|
|
|
+ }
|
|
|
|
+ expectedError := errors.New("expected")
|
|
|
|
+
|
|
|
|
+ dstOut := &errWriter{err: expectedError}
|
|
|
|
+
|
|
|
|
+ written, err := StdCopy(dstOut, ioutil.Discard, buffer)
|
|
|
|
+ if written != 0 {
|
|
|
|
+ t.Fatalf("StdCopy should have written 0, but has written %d", written)
|
|
|
|
+ }
|
|
|
|
+ if err != expectedError {
|
|
|
|
+ t.Fatalf("Didn't get expected error, got %v", err)
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+func TestStdCopyDetectsNotFullyWrittenFrames(t *testing.T) {
|
|
|
|
+ stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
|
|
|
|
+ stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
|
|
|
|
+ buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
|
|
|
|
+ if err != nil {
|
|
|
|
+ t.Fatal(err)
|
|
|
|
+ }
|
|
|
|
+ dstOut := &errWriter{n: startingBufLen - 10}
|
|
|
|
+
|
|
|
|
+ written, err := StdCopy(dstOut, ioutil.Discard, buffer)
|
|
|
|
+ if written != 0 {
|
|
|
|
+ t.Fatalf("StdCopy should have return 0 written bytes, but returned %d", written)
|
|
|
|
+ }
|
|
|
|
+ if err != io.ErrShortWrite {
|
|
|
|
+ t.Fatalf("Didn't get expected io.ErrShortWrite error")
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
+
|
|
func BenchmarkWrite(b *testing.B) {
|
|
func BenchmarkWrite(b *testing.B) {
|
|
w := NewStdWriter(ioutil.Discard, Stdout)
|
|
w := NewStdWriter(ioutil.Discard, Stdout)
|
|
data := []byte("Test line for testing stdwriter performance\n")
|
|
data := []byte("Test line for testing stdwriter performance\n")
|