Pārlūkot izejas kodu

Merge pull request #16581 from fgimenez/11584-stdcopy-test-coverage

Added test coverage to StdCopy closes #11584
Jess Frazelle 9 gadi atpakaļ
vecāks
revīzija
5bde858db5
1 mainītis faili ar 122 papildinājumiem un 6 dzēšanām
  1. 122 6
      pkg/stdcopy/stdcopy_test.go

+ 122 - 6
pkg/stdcopy/stdcopy_test.go

@@ -3,6 +3,7 @@ package stdcopy
 import (
 import (
 	"bytes"
 	"bytes"
 	"errors"
 	"errors"
+	"io"
 	"io/ioutil"
 	"io/ioutil"
 	"strings"
 	"strings"
 	"testing"
 	"testing"
@@ -85,17 +86,22 @@ func TestWriteDoesNotReturnNegativeWrittenBytes(t *testing.T) {
 	}
 	}
 }
 }
 
 
-func TestStdCopyWriteAndRead(t *testing.T) {
-	buffer := new(bytes.Buffer)
-	stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
+func getSrcBuffer(stdOutBytes, stdErrBytes []byte) (buffer *bytes.Buffer, err error) {
+	buffer = new(bytes.Buffer)
 	dstOut := NewStdWriter(buffer, Stdout)
 	dstOut := NewStdWriter(buffer, Stdout)
-	_, err := dstOut.Write(stdOutBytes)
+	_, err = dstOut.Write(stdOutBytes)
 	if err != nil {
 	if err != nil {
-		t.Fatal(err)
+		return
 	}
 	}
-	stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
 	dstErr := NewStdWriter(buffer, Stderr)
 	dstErr := NewStdWriter(buffer, Stderr)
 	_, err = dstErr.Write(stdErrBytes)
 	_, err = dstErr.Write(stdErrBytes)
+	return
+}
+
+func TestStdCopyWriteAndRead(t *testing.T) {
+	stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
+	stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
+	buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
@@ -109,6 +115,78 @@ func TestStdCopyWriteAndRead(t *testing.T) {
 	}
 	}
 }
 }
 
 
+type customReader struct {
+	n            int
+	err          error
+	totalCalls   int
+	correctCalls int
+	src          *bytes.Buffer
+}
+
+func (f *customReader) Read(buf []byte) (int, error) {
+	f.totalCalls++
+	if f.totalCalls <= f.correctCalls {
+		return f.src.Read(buf)
+	}
+	return f.n, f.err
+}
+
+func TestStdCopyReturnsErrorReadingHeader(t *testing.T) {
+	expectedError := errors.New("error")
+	reader := &customReader{
+		err: expectedError}
+	written, err := StdCopy(ioutil.Discard, ioutil.Discard, reader)
+	if written != 0 {
+		t.Fatalf("Expected 0 bytes read, got %d", written)
+	}
+	if err != expectedError {
+		t.Fatalf("Didn't get expected error")
+	}
+}
+
+func TestStdCopyReturnsErrorReadingFrame(t *testing.T) {
+	expectedError := errors.New("error")
+	stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
+	stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
+	buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
+	if err != nil {
+		t.Fatal(err)
+	}
+	reader := &customReader{
+		correctCalls: 1,
+		n:            stdWriterPrefixLen + 1,
+		err:          expectedError,
+		src:          buffer}
+	written, err := StdCopy(ioutil.Discard, ioutil.Discard, reader)
+	if written != 0 {
+		t.Fatalf("Expected 0 bytes read, got %d", written)
+	}
+	if err != expectedError {
+		t.Fatalf("Didn't get expected error")
+	}
+}
+
+func TestStdCopyDetectsCorruptedFrame(t *testing.T) {
+	stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
+	stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
+	buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
+	if err != nil {
+		t.Fatal(err)
+	}
+	reader := &customReader{
+		correctCalls: 1,
+		n:            stdWriterPrefixLen + 1,
+		err:          io.EOF,
+		src:          buffer}
+	written, err := StdCopy(ioutil.Discard, ioutil.Discard, reader)
+	if written != startingBufLen {
+		t.Fatalf("Expected 0 bytes read, got %d", written)
+	}
+	if err != nil {
+		t.Fatal("Didn't get nil error")
+	}
+}
+
 func TestStdCopyWithInvalidInputHeader(t *testing.T) {
 func TestStdCopyWithInvalidInputHeader(t *testing.T) {
 	dstOut := NewStdWriter(ioutil.Discard, Stdout)
 	dstOut := NewStdWriter(ioutil.Discard, Stdout)
 	dstErr := NewStdWriter(ioutil.Discard, Stderr)
 	dstErr := NewStdWriter(ioutil.Discard, Stderr)
@@ -131,6 +209,44 @@ func TestStdCopyWithCorruptedPrefix(t *testing.T) {
 	}
 	}
 }
 }
 
 
+func TestStdCopyReturnsWriteErrors(t *testing.T) {
+	stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
+	stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
+	buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
+	if err != nil {
+		t.Fatal(err)
+	}
+	expectedError := errors.New("expected")
+
+	dstOut := &errWriter{err: expectedError}
+
+	written, err := StdCopy(dstOut, ioutil.Discard, buffer)
+	if written != 0 {
+		t.Fatalf("StdCopy should have written 0, but has written %d", written)
+	}
+	if err != expectedError {
+		t.Fatalf("Didn't get expected error, got %v", err)
+	}
+}
+
+func TestStdCopyDetectsNotFullyWrittenFrames(t *testing.T) {
+	stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
+	stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
+	buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
+	if err != nil {
+		t.Fatal(err)
+	}
+	dstOut := &errWriter{n: startingBufLen - 10}
+
+	written, err := StdCopy(dstOut, ioutil.Discard, buffer)
+	if written != 0 {
+		t.Fatalf("StdCopy should have return 0 written bytes, but returned %d", written)
+	}
+	if err != io.ErrShortWrite {
+		t.Fatalf("Didn't get expected io.ErrShortWrite error")
+	}
+}
+
 func BenchmarkWrite(b *testing.B) {
 func BenchmarkWrite(b *testing.B) {
 	w := NewStdWriter(ioutil.Discard, Stdout)
 	w := NewStdWriter(ioutil.Discard, Stdout)
 	data := []byte("Test line for testing stdwriter performance\n")
 	data := []byte("Test line for testing stdwriter performance\n")