Przeglądaj źródła

Merge pull request #5526 from shykes/pr_out_beam_add_simple_framing_system_for_unixconn

Solomon Hykes 11 lat temu
rodzic
commit
10a50fcd8f
1 zmienionych plików z 134 dodań i 28 usunięć
  1. 134 28
      pkg/beam/unix.go

+ 134 - 28
pkg/beam/unix.go

@@ -21,6 +21,43 @@ func debugCheckpoint(msg string, args ...interface{}) {
 
 type UnixConn struct {
 	*net.UnixConn
+	fds []*os.File
+}
+
+// Framing:
+// In order to handle framing in Send/Recieve, as these give frame
+// boundaries we use a very simple 4 bytes header. It is a big endiand
+// uint32 where the high bit is set if the message includes a file
+// descriptor. The rest of the uint32 is the length of the next frame.
+// We need the bit in order to be able to assign recieved fds to
+// the right message, as multiple messages may be coalesced into
+// a single recieve operation.
+func makeHeader(data []byte, fds []int) ([]byte, error) {
+	header := make([]byte, 4)
+
+	length := uint32(len(data))
+
+	if length > 0x7fffffff {
+		return nil, fmt.Errorf("Data to large")
+	}
+
+	if len(fds) != 0 {
+		length = length | 0x80000000
+	}
+	header[0] = byte((length >> 24) & 0xff)
+	header[1] = byte((length >> 16) & 0xff)
+	header[2] = byte((length >> 8) & 0xff)
+	header[3] = byte((length >> 0) & 0xff)
+
+	return header, nil
+}
+
+func parseHeader(header []byte) (uint32, bool) {
+	length := uint32(header[0])<<24 | uint32(header[1])<<16 | uint32(header[2])<<8 | uint32(header[3])
+	hasFd := length&0x80000000 != 0
+	length = length & ^uint32(0x80000000)
+
+	return length, hasFd
 }
 
 func FileConn(f *os.File) (*UnixConn, error) {
@@ -33,7 +70,7 @@ func FileConn(f *os.File) (*UnixConn, error) {
 		conn.Close()
 		return nil, fmt.Errorf("%d: not a unix connection", f.Fd())
 	}
-	return &UnixConn{uconn}, nil
+	return &UnixConn{UnixConn: uconn}, nil
 
 }
 
@@ -52,7 +89,7 @@ func (conn *UnixConn) Send(data []byte, f *os.File) error {
 	if f != nil {
 		fds = append(fds, int(f.Fd()))
 	}
-	if err := sendUnix(conn.UnixConn, data, fds...); err != nil {
+	if err := conn.sendUnix(data, fds...); err != nil {
 		return err
 	}
 
@@ -76,42 +113,104 @@ func (conn *UnixConn) Receive() (rdata []byte, rf *os.File, rerr error) {
 		}
 		debugCheckpoint("===DEBUG=== Receive() -> '%s'[%d]. Hit enter to continue.\n", rdata, fd)
 	}()
-	for {
-		data, fds, err := receiveUnix(conn.UnixConn)
+
+	// Read header
+	header := make([]byte, 4)
+	nRead := uint32(0)
+
+	for nRead < 4 {
+		n, err := conn.receiveUnix(header[nRead:])
 		if err != nil {
 			return nil, nil, err
 		}
-		var f *os.File
-		if len(fds) > 1 {
-			for _, fd := range fds[1:] {
-				syscall.Close(fd)
-			}
+		nRead = nRead + uint32(n)
+	}
+
+	length, hasFd := parseHeader(header)
+
+	if hasFd {
+		if len(conn.fds) == 0 {
+			return nil, nil, fmt.Errorf("No expected file descriptor in message")
 		}
-		if len(fds) >= 1 {
-			f = os.NewFile(uintptr(fds[0]), "")
+
+		rf = conn.fds[0]
+		conn.fds = conn.fds[1:]
+	}
+
+	rdata = make([]byte, length)
+
+	nRead = 0
+	for nRead < length {
+		n, err := conn.receiveUnix(rdata[nRead:])
+		if err != nil {
+			return nil, nil, err
 		}
-		return data, f, nil
+		nRead = nRead + uint32(n)
 	}
-	panic("impossibru")
-	return nil, nil, nil
+
+	return
 }
 
-func receiveUnix(conn *net.UnixConn) ([]byte, []int, error) {
-	buf := make([]byte, 4096)
-	oob := make([]byte, 4096)
+func (conn *UnixConn) receiveUnix(buf []byte) (int, error) {
+	oob := make([]byte, syscall.CmsgSpace(4))
 	bufn, oobn, _, _, err := conn.ReadMsgUnix(buf, oob)
 	if err != nil {
-		return nil, nil, err
+		return 0, err
+	}
+	fd := extractFd(oob[:oobn])
+	if fd != -1 {
+		f := os.NewFile(uintptr(fd), "")
+		conn.fds = append(conn.fds, f)
 	}
-	return buf[:bufn], extractFds(oob[:oobn]), nil
+
+	return bufn, nil
 }
 
-func sendUnix(conn *net.UnixConn, data []byte, fds ...int) error {
-	_, _, err := conn.WriteMsgUnix(data, syscall.UnixRights(fds...), nil)
-	return err
+func (conn *UnixConn) sendUnix(data []byte, fds ...int) error {
+	header, err := makeHeader(data, fds)
+	if err != nil {
+		return err
+	}
+
+	// There is a bug in conn.WriteMsgUnix where it doesn't correctly return
+	// the number of bytes writte (http://code.google.com/p/go/issues/detail?id=7645)
+	// So, we can't rely on the return value from it. However, we must use it to
+	// send the fds. In order to handle this we only write one byte using WriteMsgUnix
+	// (when we have to), as that can only ever block or fully suceed. We then write
+	// the rest with conn.Write()
+	// The reader side should not rely on this though, as hopefully this gets fixed
+	// in go later.
+	written := 0
+	if len(fds) != 0 {
+		oob := syscall.UnixRights(fds...)
+		wrote, _, err := conn.WriteMsgUnix(header[0:1], oob, nil)
+		if err != nil {
+			return err
+		}
+		written = written + wrote
+	}
+
+	for written < len(header) {
+		wrote, err := conn.Write(header[written:])
+		if err != nil {
+			return err
+		}
+		written = written + wrote
+	}
+
+	written = 0
+	for written < len(data) {
+		wrote, err := conn.Write(data[written:])
+		if err != nil {
+			return err
+		}
+		written = written + wrote
+	}
+
+	return nil
 }
 
-func extractFds(oob []byte) (fds []int) {
+func extractFd(oob []byte) int {
 	// Grab forklock to make sure no forks accidentally inherit the new
 	// fds before they are made CLOEXEC
 	// There is a slight race condition between ReadMsgUnix returns and
@@ -122,20 +221,27 @@ func extractFds(oob []byte) (fds []int) {
 	defer syscall.ForkLock.Unlock()
 	scms, err := syscall.ParseSocketControlMessage(oob)
 	if err != nil {
-		return
+		return -1
 	}
+
+	foundFd := -1
 	for _, scm := range scms {
-		gotFds, err := syscall.ParseUnixRights(&scm)
+		fds, err := syscall.ParseUnixRights(&scm)
 		if err != nil {
 			continue
 		}
-		fds = append(fds, gotFds...)
 
 		for _, fd := range fds {
-			syscall.CloseOnExec(fd)
+			if foundFd == -1 {
+				syscall.CloseOnExec(fd)
+				foundFd = fd
+			} else {
+				syscall.Close(fd)
+			}
 		}
 	}
-	return
+
+	return foundFd
 }
 
 func socketpair() ([2]int, error) {