Merge pull request #5526 from shykes/pr_out_beam_add_simple_framing_system_for_unixconn

This commit is contained in:
Solomon Hykes 2014-05-01 11:06:14 -07:00
commit 10a50fcd8f

View file

@ -21,6 +21,43 @@ func debugCheckpoint(msg string, args ...interface{}) {
type UnixConn struct {
*net.UnixConn
fds []*os.File
}
// Framing:
// In order to handle framing in Send/Recieve, as these give frame
// boundaries we use a very simple 4 bytes header. It is a big endiand
// uint32 where the high bit is set if the message includes a file
// descriptor. The rest of the uint32 is the length of the next frame.
// We need the bit in order to be able to assign recieved fds to
// the right message, as multiple messages may be coalesced into
// a single recieve operation.
func makeHeader(data []byte, fds []int) ([]byte, error) {
header := make([]byte, 4)
length := uint32(len(data))
if length > 0x7fffffff {
return nil, fmt.Errorf("Data to large")
}
if len(fds) != 0 {
length = length | 0x80000000
}
header[0] = byte((length >> 24) & 0xff)
header[1] = byte((length >> 16) & 0xff)
header[2] = byte((length >> 8) & 0xff)
header[3] = byte((length >> 0) & 0xff)
return header, nil
}
func parseHeader(header []byte) (uint32, bool) {
length := uint32(header[0])<<24 | uint32(header[1])<<16 | uint32(header[2])<<8 | uint32(header[3])
hasFd := length&0x80000000 != 0
length = length & ^uint32(0x80000000)
return length, hasFd
}
func FileConn(f *os.File) (*UnixConn, error) {
@ -33,7 +70,7 @@ func FileConn(f *os.File) (*UnixConn, error) {
conn.Close()
return nil, fmt.Errorf("%d: not a unix connection", f.Fd())
}
return &UnixConn{uconn}, nil
return &UnixConn{UnixConn: uconn}, nil
}
@ -52,7 +89,7 @@ func (conn *UnixConn) Send(data []byte, f *os.File) error {
if f != nil {
fds = append(fds, int(f.Fd()))
}
if err := sendUnix(conn.UnixConn, data, fds...); err != nil {
if err := conn.sendUnix(data, fds...); err != nil {
return err
}
@ -76,42 +113,104 @@ func (conn *UnixConn) Receive() (rdata []byte, rf *os.File, rerr error) {
}
debugCheckpoint("===DEBUG=== Receive() -> '%s'[%d]. Hit enter to continue.\n", rdata, fd)
}()
for {
data, fds, err := receiveUnix(conn.UnixConn)
// Read header
header := make([]byte, 4)
nRead := uint32(0)
for nRead < 4 {
n, err := conn.receiveUnix(header[nRead:])
if err != nil {
return nil, nil, err
}
var f *os.File
if len(fds) > 1 {
for _, fd := range fds[1:] {
syscall.Close(fd)
}
}
if len(fds) >= 1 {
f = os.NewFile(uintptr(fds[0]), "")
}
return data, f, nil
nRead = nRead + uint32(n)
}
panic("impossibru")
return nil, nil, nil
length, hasFd := parseHeader(header)
if hasFd {
if len(conn.fds) == 0 {
return nil, nil, fmt.Errorf("No expected file descriptor in message")
}
rf = conn.fds[0]
conn.fds = conn.fds[1:]
}
rdata = make([]byte, length)
nRead = 0
for nRead < length {
n, err := conn.receiveUnix(rdata[nRead:])
if err != nil {
return nil, nil, err
}
nRead = nRead + uint32(n)
}
return
}
func receiveUnix(conn *net.UnixConn) ([]byte, []int, error) {
buf := make([]byte, 4096)
oob := make([]byte, 4096)
func (conn *UnixConn) receiveUnix(buf []byte) (int, error) {
oob := make([]byte, syscall.CmsgSpace(4))
bufn, oobn, _, _, err := conn.ReadMsgUnix(buf, oob)
if err != nil {
return nil, nil, err
return 0, err
}
return buf[:bufn], extractFds(oob[:oobn]), nil
fd := extractFd(oob[:oobn])
if fd != -1 {
f := os.NewFile(uintptr(fd), "")
conn.fds = append(conn.fds, f)
}
return bufn, nil
}
func sendUnix(conn *net.UnixConn, data []byte, fds ...int) error {
_, _, err := conn.WriteMsgUnix(data, syscall.UnixRights(fds...), nil)
return err
func (conn *UnixConn) sendUnix(data []byte, fds ...int) error {
header, err := makeHeader(data, fds)
if err != nil {
return err
}
// There is a bug in conn.WriteMsgUnix where it doesn't correctly return
// the number of bytes writte (http://code.google.com/p/go/issues/detail?id=7645)
// So, we can't rely on the return value from it. However, we must use it to
// send the fds. In order to handle this we only write one byte using WriteMsgUnix
// (when we have to), as that can only ever block or fully suceed. We then write
// the rest with conn.Write()
// The reader side should not rely on this though, as hopefully this gets fixed
// in go later.
written := 0
if len(fds) != 0 {
oob := syscall.UnixRights(fds...)
wrote, _, err := conn.WriteMsgUnix(header[0:1], oob, nil)
if err != nil {
return err
}
written = written + wrote
}
for written < len(header) {
wrote, err := conn.Write(header[written:])
if err != nil {
return err
}
written = written + wrote
}
written = 0
for written < len(data) {
wrote, err := conn.Write(data[written:])
if err != nil {
return err
}
written = written + wrote
}
return nil
}
func extractFds(oob []byte) (fds []int) {
func extractFd(oob []byte) int {
// Grab forklock to make sure no forks accidentally inherit the new
// fds before they are made CLOEXEC
// There is a slight race condition between ReadMsgUnix returns and
@ -122,20 +221,27 @@ func extractFds(oob []byte) (fds []int) {
defer syscall.ForkLock.Unlock()
scms, err := syscall.ParseSocketControlMessage(oob)
if err != nil {
return
return -1
}
foundFd := -1
for _, scm := range scms {
gotFds, err := syscall.ParseUnixRights(&scm)
fds, err := syscall.ParseUnixRights(&scm)
if err != nil {
continue
}
fds = append(fds, gotFds...)
for _, fd := range fds {
syscall.CloseOnExec(fd)
if foundFd == -1 {
syscall.CloseOnExec(fd)
foundFd = fd
} else {
syscall.Close(fd)
}
}
}
return
return foundFd
}
func socketpair() ([2]int, error) {