Merge pull request #5526 from shykes/pr_out_beam_add_simple_framing_system_for_unixconn
This commit is contained in:
commit
10a50fcd8f
1 changed files with 136 additions and 30 deletions
166
pkg/beam/unix.go
166
pkg/beam/unix.go
|
@ -21,6 +21,43 @@ func debugCheckpoint(msg string, args ...interface{}) {
|
|||
|
||||
type UnixConn struct {
|
||||
*net.UnixConn
|
||||
fds []*os.File
|
||||
}
|
||||
|
||||
// Framing:
|
||||
// In order to handle framing in Send/Recieve, as these give frame
|
||||
// boundaries we use a very simple 4 bytes header. It is a big endiand
|
||||
// uint32 where the high bit is set if the message includes a file
|
||||
// descriptor. The rest of the uint32 is the length of the next frame.
|
||||
// We need the bit in order to be able to assign recieved fds to
|
||||
// the right message, as multiple messages may be coalesced into
|
||||
// a single recieve operation.
|
||||
func makeHeader(data []byte, fds []int) ([]byte, error) {
|
||||
header := make([]byte, 4)
|
||||
|
||||
length := uint32(len(data))
|
||||
|
||||
if length > 0x7fffffff {
|
||||
return nil, fmt.Errorf("Data to large")
|
||||
}
|
||||
|
||||
if len(fds) != 0 {
|
||||
length = length | 0x80000000
|
||||
}
|
||||
header[0] = byte((length >> 24) & 0xff)
|
||||
header[1] = byte((length >> 16) & 0xff)
|
||||
header[2] = byte((length >> 8) & 0xff)
|
||||
header[3] = byte((length >> 0) & 0xff)
|
||||
|
||||
return header, nil
|
||||
}
|
||||
|
||||
func parseHeader(header []byte) (uint32, bool) {
|
||||
length := uint32(header[0])<<24 | uint32(header[1])<<16 | uint32(header[2])<<8 | uint32(header[3])
|
||||
hasFd := length&0x80000000 != 0
|
||||
length = length & ^uint32(0x80000000)
|
||||
|
||||
return length, hasFd
|
||||
}
|
||||
|
||||
func FileConn(f *os.File) (*UnixConn, error) {
|
||||
|
@ -33,7 +70,7 @@ func FileConn(f *os.File) (*UnixConn, error) {
|
|||
conn.Close()
|
||||
return nil, fmt.Errorf("%d: not a unix connection", f.Fd())
|
||||
}
|
||||
return &UnixConn{uconn}, nil
|
||||
return &UnixConn{UnixConn: uconn}, nil
|
||||
|
||||
}
|
||||
|
||||
|
@ -52,7 +89,7 @@ func (conn *UnixConn) Send(data []byte, f *os.File) error {
|
|||
if f != nil {
|
||||
fds = append(fds, int(f.Fd()))
|
||||
}
|
||||
if err := sendUnix(conn.UnixConn, data, fds...); err != nil {
|
||||
if err := conn.sendUnix(data, fds...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -76,42 +113,104 @@ func (conn *UnixConn) Receive() (rdata []byte, rf *os.File, rerr error) {
|
|||
}
|
||||
debugCheckpoint("===DEBUG=== Receive() -> '%s'[%d]. Hit enter to continue.\n", rdata, fd)
|
||||
}()
|
||||
for {
|
||||
data, fds, err := receiveUnix(conn.UnixConn)
|
||||
|
||||
// Read header
|
||||
header := make([]byte, 4)
|
||||
nRead := uint32(0)
|
||||
|
||||
for nRead < 4 {
|
||||
n, err := conn.receiveUnix(header[nRead:])
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
var f *os.File
|
||||
if len(fds) > 1 {
|
||||
for _, fd := range fds[1:] {
|
||||
syscall.Close(fd)
|
||||
}
|
||||
}
|
||||
if len(fds) >= 1 {
|
||||
f = os.NewFile(uintptr(fds[0]), "")
|
||||
}
|
||||
return data, f, nil
|
||||
nRead = nRead + uint32(n)
|
||||
}
|
||||
panic("impossibru")
|
||||
return nil, nil, nil
|
||||
|
||||
length, hasFd := parseHeader(header)
|
||||
|
||||
if hasFd {
|
||||
if len(conn.fds) == 0 {
|
||||
return nil, nil, fmt.Errorf("No expected file descriptor in message")
|
||||
}
|
||||
|
||||
rf = conn.fds[0]
|
||||
conn.fds = conn.fds[1:]
|
||||
}
|
||||
|
||||
rdata = make([]byte, length)
|
||||
|
||||
nRead = 0
|
||||
for nRead < length {
|
||||
n, err := conn.receiveUnix(rdata[nRead:])
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
nRead = nRead + uint32(n)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func receiveUnix(conn *net.UnixConn) ([]byte, []int, error) {
|
||||
buf := make([]byte, 4096)
|
||||
oob := make([]byte, 4096)
|
||||
func (conn *UnixConn) receiveUnix(buf []byte) (int, error) {
|
||||
oob := make([]byte, syscall.CmsgSpace(4))
|
||||
bufn, oobn, _, _, err := conn.ReadMsgUnix(buf, oob)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return 0, err
|
||||
}
|
||||
return buf[:bufn], extractFds(oob[:oobn]), nil
|
||||
fd := extractFd(oob[:oobn])
|
||||
if fd != -1 {
|
||||
f := os.NewFile(uintptr(fd), "")
|
||||
conn.fds = append(conn.fds, f)
|
||||
}
|
||||
|
||||
return bufn, nil
|
||||
}
|
||||
|
||||
func sendUnix(conn *net.UnixConn, data []byte, fds ...int) error {
|
||||
_, _, err := conn.WriteMsgUnix(data, syscall.UnixRights(fds...), nil)
|
||||
return err
|
||||
func (conn *UnixConn) sendUnix(data []byte, fds ...int) error {
|
||||
header, err := makeHeader(data, fds)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// There is a bug in conn.WriteMsgUnix where it doesn't correctly return
|
||||
// the number of bytes writte (http://code.google.com/p/go/issues/detail?id=7645)
|
||||
// So, we can't rely on the return value from it. However, we must use it to
|
||||
// send the fds. In order to handle this we only write one byte using WriteMsgUnix
|
||||
// (when we have to), as that can only ever block or fully suceed. We then write
|
||||
// the rest with conn.Write()
|
||||
// The reader side should not rely on this though, as hopefully this gets fixed
|
||||
// in go later.
|
||||
written := 0
|
||||
if len(fds) != 0 {
|
||||
oob := syscall.UnixRights(fds...)
|
||||
wrote, _, err := conn.WriteMsgUnix(header[0:1], oob, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
written = written + wrote
|
||||
}
|
||||
|
||||
for written < len(header) {
|
||||
wrote, err := conn.Write(header[written:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
written = written + wrote
|
||||
}
|
||||
|
||||
written = 0
|
||||
for written < len(data) {
|
||||
wrote, err := conn.Write(data[written:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
written = written + wrote
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func extractFds(oob []byte) (fds []int) {
|
||||
func extractFd(oob []byte) int {
|
||||
// Grab forklock to make sure no forks accidentally inherit the new
|
||||
// fds before they are made CLOEXEC
|
||||
// There is a slight race condition between ReadMsgUnix returns and
|
||||
|
@ -122,20 +221,27 @@ func extractFds(oob []byte) (fds []int) {
|
|||
defer syscall.ForkLock.Unlock()
|
||||
scms, err := syscall.ParseSocketControlMessage(oob)
|
||||
if err != nil {
|
||||
return
|
||||
return -1
|
||||
}
|
||||
|
||||
foundFd := -1
|
||||
for _, scm := range scms {
|
||||
gotFds, err := syscall.ParseUnixRights(&scm)
|
||||
fds, err := syscall.ParseUnixRights(&scm)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
fds = append(fds, gotFds...)
|
||||
|
||||
for _, fd := range fds {
|
||||
syscall.CloseOnExec(fd)
|
||||
if foundFd == -1 {
|
||||
syscall.CloseOnExec(fd)
|
||||
foundFd = fd
|
||||
} else {
|
||||
syscall.Close(fd)
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
|
||||
return foundFd
|
||||
}
|
||||
|
||||
func socketpair() ([2]int, error) {
|
||||
|
|
Loading…
Add table
Reference in a new issue