123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289 |
- package stdcopy
- import (
- "bytes"
- "errors"
- "io"
- "io/ioutil"
- "strings"
- "testing"
- )
- func TestNewStdWriter(t *testing.T) {
- writer := NewStdWriter(ioutil.Discard, Stdout)
- if writer == nil {
- t.Fatalf("NewStdWriter with an invalid StdType should not return nil.")
- }
- }
- func TestWriteWithUnitializedStdWriter(t *testing.T) {
- writer := stdWriter{
- Writer: nil,
- prefix: byte(Stdout),
- }
- n, err := writer.Write([]byte("Something here"))
- if n != 0 || err == nil {
- t.Fatalf("Should fail when given an uncomplete or uninitialized StdWriter")
- }
- }
- func TestWriteWithNilBytes(t *testing.T) {
- writer := NewStdWriter(ioutil.Discard, Stdout)
- n, err := writer.Write(nil)
- if err != nil {
- t.Fatalf("Shouldn't have fail when given no data")
- }
- if n > 0 {
- t.Fatalf("Write should have written 0 byte, but has written %d", n)
- }
- }
- func TestWrite(t *testing.T) {
- writer := NewStdWriter(ioutil.Discard, Stdout)
- data := []byte("Test StdWrite.Write")
- n, err := writer.Write(data)
- if err != nil {
- t.Fatalf("Error while writing with StdWrite")
- }
- if n != len(data) {
- t.Fatalf("Write should have written %d byte but wrote %d.", len(data), n)
- }
- }
// errWriter is a test io.Writer whose Write ignores its input and always
// returns the configured byte count and error. Tests use it to simulate
// short or failing writes.
type errWriter struct {
	n   int   // count reported by every Write call
	err error // error returned by every Write call
}

// Write discards buf and returns the canned (n, err) pair.
func (w *errWriter) Write(buf []byte) (int, error) {
	written, failure := w.n, w.err
	return written, failure
}
- func TestWriteWithWriterError(t *testing.T) {
- expectedError := errors.New("expected")
- expectedReturnedBytes := 10
- writer := NewStdWriter(&errWriter{
- n: stdWriterPrefixLen + expectedReturnedBytes,
- err: expectedError}, Stdout)
- data := []byte("This won't get written, sigh")
- n, err := writer.Write(data)
- if err != expectedError {
- t.Fatalf("Didn't get expected error.")
- }
- if n != expectedReturnedBytes {
- t.Fatalf("Didn't get expected written bytes %d, got %d.",
- expectedReturnedBytes, n)
- }
- }
- func TestWriteDoesNotReturnNegativeWrittenBytes(t *testing.T) {
- writer := NewStdWriter(&errWriter{n: -1}, Stdout)
- data := []byte("This won't get written, sigh")
- actual, _ := writer.Write(data)
- if actual != 0 {
- t.Fatalf("Expected returned written bytes equal to 0, got %d", actual)
- }
- }
- func getSrcBuffer(stdOutBytes, stdErrBytes []byte) (buffer *bytes.Buffer, err error) {
- buffer = new(bytes.Buffer)
- dstOut := NewStdWriter(buffer, Stdout)
- _, err = dstOut.Write(stdOutBytes)
- if err != nil {
- return
- }
- dstErr := NewStdWriter(buffer, Stderr)
- _, err = dstErr.Write(stdErrBytes)
- return
- }
- func TestStdCopyWriteAndRead(t *testing.T) {
- stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
- stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
- buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
- if err != nil {
- t.Fatal(err)
- }
- written, err := StdCopy(ioutil.Discard, ioutil.Discard, buffer)
- if err != nil {
- t.Fatal(err)
- }
- expectedTotalWritten := len(stdOutBytes) + len(stdErrBytes)
- if written != int64(expectedTotalWritten) {
- t.Fatalf("Expected to have total of %d bytes written, got %d", expectedTotalWritten, written)
- }
- }
// customReader is a test io.Reader. Its first correctCalls reads are served
// from src. After that, every call returns the canned (n, err) pair without
// touching the caller's buffer. totalCalls counts every Read invocation.
type customReader struct {
	n            int
	err          error
	totalCalls   int
	correctCalls int
	src          *bytes.Buffer
}

// Read forwards to src while the call budget lasts, then returns (n, err).
func (r *customReader) Read(buf []byte) (int, error) {
	r.totalCalls++
	if r.totalCalls > r.correctCalls {
		return r.n, r.err
	}
	return r.src.Read(buf)
}
- func TestStdCopyReturnsErrorReadingHeader(t *testing.T) {
- expectedError := errors.New("error")
- reader := &customReader{
- err: expectedError}
- written, err := StdCopy(ioutil.Discard, ioutil.Discard, reader)
- if written != 0 {
- t.Fatalf("Expected 0 bytes read, got %d", written)
- }
- if err != expectedError {
- t.Fatalf("Didn't get expected error")
- }
- }
- func TestStdCopyReturnsErrorReadingFrame(t *testing.T) {
- expectedError := errors.New("error")
- stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
- stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
- buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
- if err != nil {
- t.Fatal(err)
- }
- reader := &customReader{
- correctCalls: 1,
- n: stdWriterPrefixLen + 1,
- err: expectedError,
- src: buffer}
- written, err := StdCopy(ioutil.Discard, ioutil.Discard, reader)
- if written != 0 {
- t.Fatalf("Expected 0 bytes read, got %d", written)
- }
- if err != expectedError {
- t.Fatalf("Didn't get expected error")
- }
- }
- func TestStdCopyDetectsCorruptedFrame(t *testing.T) {
- stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
- stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
- buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
- if err != nil {
- t.Fatal(err)
- }
- reader := &customReader{
- correctCalls: 1,
- n: stdWriterPrefixLen + 1,
- err: io.EOF,
- src: buffer}
- written, err := StdCopy(ioutil.Discard, ioutil.Discard, reader)
- if written != startingBufLen {
- t.Fatalf("Expected %d bytes read, got %d", startingBufLen, written)
- }
- if err != nil {
- t.Fatal("Didn't get nil error")
- }
- }
- func TestStdCopyWithInvalidInputHeader(t *testing.T) {
- dstOut := NewStdWriter(ioutil.Discard, Stdout)
- dstErr := NewStdWriter(ioutil.Discard, Stderr)
- src := strings.NewReader("Invalid input")
- _, err := StdCopy(dstOut, dstErr, src)
- if err == nil {
- t.Fatal("StdCopy with invalid input header should fail.")
- }
- }
- func TestStdCopyWithCorruptedPrefix(t *testing.T) {
- data := []byte{0x01, 0x02, 0x03}
- src := bytes.NewReader(data)
- written, err := StdCopy(nil, nil, src)
- if err != nil {
- t.Fatalf("StdCopy should not return an error with corrupted prefix.")
- }
- if written != 0 {
- t.Fatalf("StdCopy should have written 0, but has written %d", written)
- }
- }
- func TestStdCopyReturnsWriteErrors(t *testing.T) {
- stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
- stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
- buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
- if err != nil {
- t.Fatal(err)
- }
- expectedError := errors.New("expected")
- dstOut := &errWriter{err: expectedError}
- written, err := StdCopy(dstOut, ioutil.Discard, buffer)
- if written != 0 {
- t.Fatalf("StdCopy should have written 0, but has written %d", written)
- }
- if err != expectedError {
- t.Fatalf("Didn't get expected error, got %v", err)
- }
- }
- func TestStdCopyDetectsNotFullyWrittenFrames(t *testing.T) {
- stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
- stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
- buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
- if err != nil {
- t.Fatal(err)
- }
- dstOut := &errWriter{n: startingBufLen - 10}
- written, err := StdCopy(dstOut, ioutil.Discard, buffer)
- if written != 0 {
- t.Fatalf("StdCopy should have return 0 written bytes, but returned %d", written)
- }
- if err != io.ErrShortWrite {
- t.Fatalf("Didn't get expected io.ErrShortWrite error")
- }
- }
- // TestStdCopyReturnsErrorFromSystem tests that StdCopy correctly returns an
- // error, when that error is muxed into the Systemerr stream.
- func TestStdCopyReturnsErrorFromSystem(t *testing.T) {
- // write in the basic messages, just so there's some fluff in there
- stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
- stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
- buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
- if err != nil {
- t.Fatal(err)
- }
- // add in an error message on the Systemerr stream
- systemErrBytes := []byte(strings.Repeat("S", startingBufLen))
- systemWriter := NewStdWriter(buffer, Systemerr)
- _, err = systemWriter.Write(systemErrBytes)
- if err != nil {
- t.Fatal(err)
- }
- // now copy and demux. we should expect an error containing the string we
- // wrote out
- _, err = StdCopy(ioutil.Discard, ioutil.Discard, buffer)
- if err == nil {
- t.Fatal("expected error, got none")
- }
- if !strings.Contains(err.Error(), string(systemErrBytes)) {
- t.Fatal("expected error to contain message")
- }
- }
- func BenchmarkWrite(b *testing.B) {
- w := NewStdWriter(ioutil.Discard, Stdout)
- data := []byte("Test line for testing stdwriter performance\n")
- data = bytes.Repeat(data, 100)
- b.SetBytes(int64(len(data)))
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- if _, err := w.Write(data); err != nil {
- b.Fatal(err)
- }
- }
- }
|