ソースを参照

Make StdCopy works with huge amount of data

Guillaume J. Charmes 12 年 前
コミット
e854b7b2e6
6 ファイル変更212 行追加116 行削除
  1. 21 10
      api.go
  2. 2 2
      api_test.go
  3. 1 1
      commands.go
  4. 9 16
      server.go
  5. 179 0
      utils/stdcopy.go
  6. 0 87
      utils/utils.go

+ 21 - 10
api.go

@@ -766,32 +766,43 @@ func postContainersAttach(srv *Server, version float64, w http.ResponseWriter, r
 	}
 	name := vars["name"]
 
-	if _, err := srv.ContainerInspect(name); err != nil {
+	c, err := srv.ContainerInspect(name)
+	if err != nil {
 		return err
 	}
 
-	in, out, err := hijackServer(w)
+	inStream, outStream, err := hijackServer(w)
 	if err != nil {
 		return err
 	}
 	defer func() {
-		if tcpc, ok := in.(*net.TCPConn); ok {
+		if tcpc, ok := inStream.(*net.TCPConn); ok {
 			tcpc.CloseWrite()
 		} else {
-			in.Close()
+			inStream.Close()
 		}
 	}()
 	defer func() {
-		if tcpc, ok := out.(*net.TCPConn); ok {
+		if tcpc, ok := outStream.(*net.TCPConn); ok {
 			tcpc.CloseWrite()
-		} else if closer, ok := out.(io.Closer); ok {
+		} else if closer, ok := outStream.(io.Closer); ok {
 			closer.Close()
 		}
 	}()
 
-	fmt.Fprintf(out, "HTTP/1.1 200 OK\r\nContent-Type: application/vnd.docker.raw-stream\r\n\r\n")
-	if err := srv.ContainerAttach(name, logs, stream, stdin, stdout, stderr, in, out); err != nil {
-		fmt.Fprintf(out, "Error: %s\n", err)
+	var errStream io.Writer
+
+	fmt.Fprintf(outStream, "HTTP/1.1 200 OK\r\nContent-Type: application/vnd.docker.raw-stream\r\n\r\n")
+
+	if !c.Config.Tty && version >= 1.4 {
+		errStream = utils.NewStdWriter(outStream, utils.Stderr)
+		outStream = utils.NewStdWriter(outStream, utils.Stdout)
+	} else {
+		errStream = outStream
+	}
+
+	if err := srv.ContainerAttach(name, logs, stream, stdin, stdout, stderr, inStream, outStream, errStream); err != nil {
+		fmt.Fprintf(outStream, "Error: %s\n", err)
 	}
 	return nil
 }
@@ -834,7 +845,7 @@ func wsContainersAttach(srv *Server, version float64, w http.ResponseWriter, r *
 	h := websocket.Handler(func(ws *websocket.Conn) {
 		defer ws.Close()
 
-		if err := srv.ContainerAttach(name, logs, stream, stdin, stdout, stderr, ws, ws); err != nil {
+		if err := srv.ContainerAttach(name, logs, stream, stdin, stdout, stderr, ws, ws, ws); err != nil {
 			utils.Debugf("Error: %s", err)
 		}
 	})

+ 2 - 2
api_test.go

@@ -951,7 +951,7 @@ func TestPostContainersAttach(t *testing.T) {
 	})
 
 	setTimeout(t, "read/write assertion timed out", 2*time.Second, func() {
-		if err := assertPipe("hello\n", string(utils.Stdout)+"hello", stdout, stdinPipe, 15); err != nil {
+		if err := assertPipe("hello\n", string([]byte{1, 0, 0, 0, 6, 0, 0, 0})+"hello", stdout, stdinPipe, 15); err != nil {
 			t.Fatal(err)
 		}
 	})
@@ -1040,7 +1040,7 @@ func TestPostContainersAttachStderr(t *testing.T) {
 	})
 
 	setTimeout(t, "read/write assertion timed out", 2*time.Second, func() {
-		if err := assertPipe("hello\n", string(utils.Stderr)+"hello", stdout, stdinPipe, 15); err != nil {
+		if err := assertPipe("hello\n", string([]byte{2, 0, 0, 0, 6, 0, 0, 0})+"hello", stdout, stdinPipe, 15); err != nil {
 			t.Fatal(err)
 		}
 	})

+ 1 - 1
commands.go

@@ -1570,7 +1570,7 @@ func (cli *DockerCli) CmdRun(args ...string) error {
 			return err
 		}
 		if status != 0 {
-			return &utils.StatusError{status}
+			return &utils.StatusError{Status: status}
 		}
 	}
 

+ 9 - 16
server.go

@@ -1175,11 +1175,12 @@ func (srv *Server) ContainerResize(name string, h, w int) error {
 	return fmt.Errorf("No such container: %s", name)
 }
 
-func (srv *Server) ContainerAttach(name string, logs, stream, stdin, stdout, stderr bool, in io.ReadCloser, out io.Writer) error {
+func (srv *Server) ContainerAttach(name string, logs, stream, stdin, stdout, stderr bool, inStream io.ReadCloser, outStream, errStream io.Writer) error {
 	container := srv.runtime.Get(name)
 	if container == nil {
 		return fmt.Errorf("No such container: %s", name)
 	}
+
 	//logs
 	if logs {
 		cLog, err := container.ReadLog("json")
@@ -1190,7 +1191,7 @@ func (srv *Server) ContainerAttach(name string, logs, stream, stdin, stdout, std
 				cLog, err := container.ReadLog("stdout")
 				if err != nil {
 					utils.Debugf("Error reading logs (stdout): %s", err)
-				} else if _, err := io.Copy(out, cLog); err != nil {
+				} else if _, err := io.Copy(outStream, cLog); err != nil {
 					utils.Debugf("Error streaming logs (stdout): %s", err)
 				}
 			}
@@ -1198,7 +1199,7 @@ func (srv *Server) ContainerAttach(name string, logs, stream, stdin, stdout, std
 				cLog, err := container.ReadLog("stderr")
 				if err != nil {
 					utils.Debugf("Error reading logs (stderr): %s", err)
-				} else if _, err := io.Copy(out, cLog); err != nil {
+				} else if _, err := io.Copy(errStream, cLog); err != nil {
 					utils.Debugf("Error streaming logs (stderr): %s", err)
 				}
 			}
@@ -1215,7 +1216,7 @@ func (srv *Server) ContainerAttach(name string, logs, stream, stdin, stdout, std
 					break
 				}
 				if (l.Stream == "stdout" && stdout) || (l.Stream == "stderr" && stderr) {
-					fmt.Fprintf(out, "%s", l.Log)
+					fmt.Fprintf(outStream, "%s", l.Log)
 				}
 			}
 		}
@@ -1238,24 +1239,16 @@ func (srv *Server) ContainerAttach(name string, logs, stream, stdin, stdout, std
 			go func() {
 				defer w.Close()
 				defer utils.Debugf("Closing buffered stdin pipe")
-				io.Copy(w, in)
+				io.Copy(w, inStream)
 			}()
 			cStdin = r
-			cStdinCloser = in
+			cStdinCloser = inStream
 		}
 		if stdout {
-			if container.Config.Tty {
-				cStdout = out
-			} else {
-				cStdout = utils.NewStdWriter(out, utils.Stdout)
-			}
+			cStdout = outStream
 		}
 		if stderr {
-			if container.Config.Tty {
-				cStderr = out
-			} else {
-				cStderr = utils.NewStdWriter(out, utils.Stderr)
-			}
+			cStderr = errStream
 		}
 
 		<-container.Attach(cStdin, cStdinCloser, cStdout, cStderr)

+ 179 - 0
utils/stdcopy.go

@@ -0,0 +1,179 @@
+package utils
+
+import (
+	"encoding/binary"
+	"errors"
+	"io"
+	"unsafe"
+)
+
+func CheckBigEndian() bool {
+	var x uint32 = 0x01020304
+
+	if *(*byte)(unsafe.Pointer(&x)) == 0x01 {
+		return true
+	}
+	return false
+}
+
+const (
+	StdWriterPrefixLen = 8
+	StdWriterFdIndex   = 0
+	StdWriterSizeIndex = 4
+)
+
+type StdType [StdWriterPrefixLen]byte
+
+var (
+	Stdin  StdType = StdType{0: 0}
+	Stdout StdType = StdType{0: 1}
+	Stderr StdType = StdType{0: 2}
+)
+
+type StdWriter struct {
+	io.Writer
+	prefix    StdType
+	sizeBuf   []byte
+	byteOrder binary.ByteOrder
+}
+
+func (w *StdWriter) Write(buf []byte) (n int, err error) {
+	if w == nil || w.Writer == nil {
+		return 0, errors.New("Writer not instanciated")
+	}
+	w.byteOrder.PutUint32(w.prefix[4:], uint32(len(buf)))
+	buf = append(w.prefix[:], buf...)
+
+	n, err = w.Writer.Write(buf)
+	return n - StdWriterPrefixLen, err
+}
+
+// NewStdWriter instanciate a new Writer based on the given type `t`.
+// the utils package contains the valid parametres for `t`:
+func NewStdWriter(w io.Writer, t StdType) *StdWriter {
+	if len(t) != StdWriterPrefixLen {
+		return nil
+	}
+
+	var bo binary.ByteOrder
+
+	if CheckBigEndian() {
+		bo = binary.BigEndian
+	} else {
+		bo = binary.LittleEndian
+	}
+	return &StdWriter{
+		Writer:    w,
+		prefix:    t,
+		sizeBuf:   make([]byte, 4),
+		byteOrder: bo,
+	}
+}
+
+var ErrInvalidStdHeader = errors.New("Unrecognized input header")
+
+// StdCopy is a modified version of io.Copy.
+//
+// StdCopy copies from src to dstout or dsterr until either EOF is reached
+// on src or an error occurs.  It returns the number of bytes
+// copied and the first error encountered while copying, if any.
+//
+// A successful Copy returns err == nil, not err == EOF.
+// Because Copy is defined to read from src until EOF, it does
+// not treat an EOF from Read as an error to be reported.
+//
+// The source needs to be writter via StdWriter, dstout or dsterr is selected
+// based on the prefix added by StdWriter
+func StdCopy(dstout, dsterr io.Writer, src io.Reader) (written int64, err error) {
+	var (
+		buf       = make([]byte, 32*1024+StdWriterPrefixLen+1)
+		bufLen    = len(buf)
+		nr, nw    int
+		er, ew    error
+		out       io.Writer
+		byteOrder binary.ByteOrder
+		frameSize int
+	)
+
+	// Check the machine's endianness
+	if CheckBigEndian() {
+		byteOrder = binary.BigEndian
+	} else {
+		byteOrder = binary.LittleEndian
+	}
+
+	for {
+		// Make sure we have at least a full header
+		for nr < StdWriterPrefixLen {
+			var nr2 int
+			nr2, er = src.Read(buf[nr:])
+			if er == io.EOF {
+				return written, nil
+			}
+			if er != nil {
+				return 0, er
+			}
+			nr += nr2
+		}
+
+		// Check the first byte to know where to write
+		switch buf[StdWriterFdIndex] {
+		case 0:
+			fallthrough
+		case 1:
+			// Write on stdout
+			out = dstout
+		case 2:
+			// Write on stderr
+			out = dsterr
+		default:
+			Debugf("Error selecting output fd: (%d)", buf[StdWriterFdIndex])
+			return 0, ErrInvalidStdHeader
+		}
+
+		// Retrieve the size of the frame
+		frameSize = int(byteOrder.Uint32(buf[StdWriterSizeIndex : StdWriterSizeIndex+4]))
+
+		// Check if the buffer is big enough to read the frame.
+		// Extend it if necessary.
+		if frameSize+StdWriterPrefixLen > bufLen {
+			Debugf("Extending buffer cap.")
+			buf = append(buf, make([]byte, frameSize-len(buf)+1)...)
+			bufLen = len(buf)
+		}
+
+		// While the amount of bytes read is less than the size of the frame + header, we keep reading
+		for nr < frameSize+StdWriterPrefixLen {
+			var nr2 int
+			nr2, er = src.Read(buf[nr:])
+			if er == io.EOF {
+				return written, nil
+			}
+			if er != nil {
+				Debugf("Error reading frame: %s", er)
+				return 0, er
+			}
+			nr += nr2
+		}
+
+		// Write the retrieved frame (without header)
+		nw, ew = out.Write(buf[StdWriterPrefixLen : frameSize+StdWriterPrefixLen])
+		if nw > 0 {
+			written += int64(nw)
+		}
+		if ew != nil {
+			Debugf("Error writing frame: %s", ew)
+			return 0, ew
+		}
+		// If the frame has not been fully written: error
+		if nw != frameSize {
+			Debugf("Error Short Write: (%d on %d)", nw, frameSize)
+			return 0, io.ErrShortWrite
+		}
+
+		// Move the rest of the buffer to the beginning
+		copy(buf, buf[frameSize+StdWriterPrefixLen:])
+		// Move the index
+		nr -= frameSize + StdWriterPrefixLen
+	}
+}

+ 0 - 87
utils/utils.go

@@ -1021,90 +1021,3 @@ type StatusError struct {
 func (e *StatusError) Error() string {
 	return fmt.Sprintf("Status: %d", e.Status)
 }
-
-type StdType []byte
-
-const StdWriterPrefixLen = 8
-
-var (
-	Stdin  StdType = StdType("\001 stdin\002")
-	Stdout StdType = StdType("\001stdout\002")
-	Stderr StdType = StdType("\001stderr\002")
-)
-
-type StdWriter struct {
-	io.Writer
-	prefix []byte
-}
-
-func (w *StdWriter) Write(buf []byte) (n int, err error) {
-	n, err = w.Writer.Write(append(w.prefix, buf...))
-	if n >= len(buf)+StdWriterPrefixLen {
-		n -= StdWriterPrefixLen
-	}
-	return n, err
-}
-
-// NewStdWriter instanciate a new Writer based on the given type `t`.
-// the utils package contains the valid parametres for `t`:
-func NewStdWriter(w io.Writer, t StdType) *StdWriter {
-	if len(t) != StdWriterPrefixLen {
-		return nil
-	}
-	return &StdWriter{
-		Writer: w,
-		prefix: []byte(t),
-	}
-}
-
-// StdCopy is a modified version of io.Copy.
-//
-// StdCopy copies from src to dstout or dsterr until either EOF is reached
-// on src or an error occurs.  It returns the number of bytes
-// copied and the first error encountered while copying, if any.
-//
-// A successful Copy returns err == nil, not err == EOF.
-// Because Copy is defined to read from src until EOF, it does
-// not treat an EOF from Read as an error to be reported.
-//
-// The source needs to be writter via StdWriter, dstout or dsterr is selected
-// based on the prefix added by StdWriter
-func StdCopy(dstout, dsterr io.Writer, src io.Reader) (written int64, err error) {
-	var (
-		buf = make([]byte, 32*1024)
-		nw  int
-		ew  error
-	)
-
-	for {
-		nr, er := src.Read(buf)
-		if nr > 0 {
-			if bytes.Compare(buf[:StdWriterPrefixLen], Stdout) == 0 {
-				nw, ew = dstout.Write(buf[StdWriterPrefixLen:nr])
-			} else if bytes.Compare(buf[:StdWriterPrefixLen], Stderr) == 0 {
-				nw, ew = dsterr.Write(buf[StdWriterPrefixLen:nr])
-			} else if bytes.Compare(buf[:StdWriterPrefixLen], Stdin) == 0 {
-				nw, ew = dstout.Write(buf[StdWriterPrefixLen:nr])
-			}
-			if nw > 0 {
-				written += int64(nw)
-			}
-			if ew != nil {
-				err = ew
-				break
-			}
-			if nr-StdWriterPrefixLen != nw {
-				err = io.ErrShortWrite
-				break
-			}
-		}
-		if er == io.EOF {
-			break
-		}
-		if er != nil {
-			err = er
-			break
-		}
-	}
-	return written, err
-}