|
@@ -21,6 +21,43 @@ func debugCheckpoint(msg string, args ...interface{}) {
|
|
|
|
|
|
type UnixConn struct {
|
|
|
*net.UnixConn
|
|
|
+ fds []*os.File
|
|
|
+}
|
|
|
+
|
|
|
+// Framing:
|
|
|
+// In order to handle framing in Send/Recieve, as these give frame
|
|
|
+// boundaries we use a very simple 4 bytes header. It is a big endiand
|
|
|
+// uint32 where the high bit is set if the message includes a file
|
|
|
+// descriptor. The rest of the uint32 is the length of the next frame.
|
|
|
+// We need the bit in order to be able to assign recieved fds to
|
|
|
+// the right message, as multiple messages may be coalesced into
|
|
|
+// a single recieve operation.
|
|
|
+func makeHeader(data []byte, fds []int) ([]byte, error) {
|
|
|
+ header := make([]byte, 4)
|
|
|
+
|
|
|
+ length := uint32(len(data))
|
|
|
+
|
|
|
+ if length > 0x7fffffff {
|
|
|
+ return nil, fmt.Errorf("Data to large")
|
|
|
+ }
|
|
|
+
|
|
|
+ if len(fds) != 0 {
|
|
|
+ length = length | 0x80000000
|
|
|
+ }
|
|
|
+ header[0] = byte((length >> 24) & 0xff)
|
|
|
+ header[1] = byte((length >> 16) & 0xff)
|
|
|
+ header[2] = byte((length >> 8) & 0xff)
|
|
|
+ header[3] = byte((length >> 0) & 0xff)
|
|
|
+
|
|
|
+ return header, nil
|
|
|
+}
|
|
|
+
|
|
|
+func parseHeader(header []byte) (uint32, bool) {
|
|
|
+ length := uint32(header[0])<<24 | uint32(header[1])<<16 | uint32(header[2])<<8 | uint32(header[3])
|
|
|
+ hasFd := length&0x80000000 != 0
|
|
|
+ length = length & ^uint32(0x80000000)
|
|
|
+
|
|
|
+ return length, hasFd
|
|
|
}
|
|
|
|
|
|
func FileConn(f *os.File) (*UnixConn, error) {
|
|
@@ -33,7 +70,7 @@ func FileConn(f *os.File) (*UnixConn, error) {
|
|
|
conn.Close()
|
|
|
return nil, fmt.Errorf("%d: not a unix connection", f.Fd())
|
|
|
}
|
|
|
- return &UnixConn{uconn}, nil
|
|
|
+ return &UnixConn{UnixConn: uconn}, nil
|
|
|
|
|
|
}
|
|
|
|
|
@@ -52,7 +89,7 @@ func (conn *UnixConn) Send(data []byte, f *os.File) error {
|
|
|
if f != nil {
|
|
|
fds = append(fds, int(f.Fd()))
|
|
|
}
|
|
|
- if err := sendUnix(conn.UnixConn, data, fds...); err != nil {
|
|
|
+ if err := conn.sendUnix(data, fds...); err != nil {
|
|
|
return err
|
|
|
}
|
|
|
|
|
@@ -76,42 +113,104 @@ func (conn *UnixConn) Receive() (rdata []byte, rf *os.File, rerr error) {
|
|
|
}
|
|
|
debugCheckpoint("===DEBUG=== Receive() -> '%s'[%d]. Hit enter to continue.\n", rdata, fd)
|
|
|
}()
|
|
|
- for {
|
|
|
- data, fds, err := receiveUnix(conn.UnixConn)
|
|
|
+
|
|
|
+ // Read header
|
|
|
+ header := make([]byte, 4)
|
|
|
+ nRead := uint32(0)
|
|
|
+
|
|
|
+ for nRead < 4 {
|
|
|
+ n, err := conn.receiveUnix(header[nRead:])
|
|
|
if err != nil {
|
|
|
return nil, nil, err
|
|
|
}
|
|
|
- var f *os.File
|
|
|
- if len(fds) > 1 {
|
|
|
- for _, fd := range fds[1:] {
|
|
|
- syscall.Close(fd)
|
|
|
- }
|
|
|
+ nRead = nRead + uint32(n)
|
|
|
+ }
|
|
|
+
|
|
|
+ length, hasFd := parseHeader(header)
|
|
|
+
|
|
|
+ if hasFd {
|
|
|
+ if len(conn.fds) == 0 {
|
|
|
+ return nil, nil, fmt.Errorf("No expected file descriptor in message")
|
|
|
}
|
|
|
- if len(fds) >= 1 {
|
|
|
- f = os.NewFile(uintptr(fds[0]), "")
|
|
|
+
|
|
|
+ rf = conn.fds[0]
|
|
|
+ conn.fds = conn.fds[1:]
|
|
|
+ }
|
|
|
+
|
|
|
+ rdata = make([]byte, length)
|
|
|
+
|
|
|
+ nRead = 0
|
|
|
+ for nRead < length {
|
|
|
+ n, err := conn.receiveUnix(rdata[nRead:])
|
|
|
+ if err != nil {
|
|
|
+ return nil, nil, err
|
|
|
}
|
|
|
- return data, f, nil
|
|
|
+ nRead = nRead + uint32(n)
|
|
|
}
|
|
|
- panic("impossibru")
|
|
|
- return nil, nil, nil
|
|
|
+
|
|
|
+ return
|
|
|
}
|
|
|
|
|
|
-func receiveUnix(conn *net.UnixConn) ([]byte, []int, error) {
|
|
|
- buf := make([]byte, 4096)
|
|
|
- oob := make([]byte, 4096)
|
|
|
+func (conn *UnixConn) receiveUnix(buf []byte) (int, error) {
|
|
|
+ oob := make([]byte, syscall.CmsgSpace(4))
|
|
|
bufn, oobn, _, _, err := conn.ReadMsgUnix(buf, oob)
|
|
|
if err != nil {
|
|
|
- return nil, nil, err
|
|
|
+ return 0, err
|
|
|
+ }
|
|
|
+ fd := extractFd(oob[:oobn])
|
|
|
+ if fd != -1 {
|
|
|
+ f := os.NewFile(uintptr(fd), "")
|
|
|
+ conn.fds = append(conn.fds, f)
|
|
|
}
|
|
|
- return buf[:bufn], extractFds(oob[:oobn]), nil
|
|
|
+
|
|
|
+ return bufn, nil
|
|
|
}
|
|
|
|
|
|
-func sendUnix(conn *net.UnixConn, data []byte, fds ...int) error {
|
|
|
- _, _, err := conn.WriteMsgUnix(data, syscall.UnixRights(fds...), nil)
|
|
|
- return err
|
|
|
+func (conn *UnixConn) sendUnix(data []byte, fds ...int) error {
|
|
|
+ header, err := makeHeader(data, fds)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ // There is a bug in conn.WriteMsgUnix where it doesn't correctly return
|
|
|
+ // the number of bytes writte (http://code.google.com/p/go/issues/detail?id=7645)
|
|
|
+ // So, we can't rely on the return value from it. However, we must use it to
|
|
|
+ // send the fds. In order to handle this we only write one byte using WriteMsgUnix
|
|
|
+ // (when we have to), as that can only ever block or fully suceed. We then write
|
|
|
+ // the rest with conn.Write()
|
|
|
+ // The reader side should not rely on this though, as hopefully this gets fixed
|
|
|
+ // in go later.
|
|
|
+ written := 0
|
|
|
+ if len(fds) != 0 {
|
|
|
+ oob := syscall.UnixRights(fds...)
|
|
|
+ wrote, _, err := conn.WriteMsgUnix(header[0:1], oob, nil)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ written = written + wrote
|
|
|
+ }
|
|
|
+
|
|
|
+ for written < len(header) {
|
|
|
+ wrote, err := conn.Write(header[written:])
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ written = written + wrote
|
|
|
+ }
|
|
|
+
|
|
|
+ written = 0
|
|
|
+ for written < len(data) {
|
|
|
+ wrote, err := conn.Write(data[written:])
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ written = written + wrote
|
|
|
+ }
|
|
|
+
|
|
|
+ return nil
|
|
|
}
|
|
|
|
|
|
-func extractFds(oob []byte) (fds []int) {
|
|
|
+func extractFd(oob []byte) int {
|
|
|
// Grab forklock to make sure no forks accidentally inherit the new
|
|
|
// fds before they are made CLOEXEC
|
|
|
// There is a slight race condition between ReadMsgUnix returns and
|
|
@@ -122,20 +221,27 @@ func extractFds(oob []byte) (fds []int) {
|
|
|
defer syscall.ForkLock.Unlock()
|
|
|
scms, err := syscall.ParseSocketControlMessage(oob)
|
|
|
if err != nil {
|
|
|
- return
|
|
|
+ return -1
|
|
|
}
|
|
|
+
|
|
|
+ foundFd := -1
|
|
|
for _, scm := range scms {
|
|
|
- gotFds, err := syscall.ParseUnixRights(&scm)
|
|
|
+ fds, err := syscall.ParseUnixRights(&scm)
|
|
|
if err != nil {
|
|
|
continue
|
|
|
}
|
|
|
- fds = append(fds, gotFds...)
|
|
|
|
|
|
for _, fd := range fds {
|
|
|
- syscall.CloseOnExec(fd)
|
|
|
+ if foundFd == -1 {
|
|
|
+ syscall.CloseOnExec(fd)
|
|
|
+ foundFd = fd
|
|
|
+ } else {
|
|
|
+ syscall.Close(fd)
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
- return
|
|
|
+
|
|
|
+ return foundFd
|
|
|
}
|
|
|
|
|
|
func socketpair() ([2]int, error) {
|