소스 검색

Set socket timeout on netlink sockets

If the file descriptor of a netlink socket is closed, a recvfrom
blocked on it does not return. This may create deadlock conditions.
The current solution is to make sure that all the netlink sockets in use
have a proper timeout set on them, so that a blocked receive can return.

Added test to emulate the watchMiss condition

Signed-off-by: Flavio Crisciani <flavio.crisciani@docker.com>
Flavio Crisciani 7 년 전
부모
커밋
6736b223ec
4개의 변경된 파일, 70개의 추가 그리고 4개의 삭제
  1. 11 0
      libnetwork/drivers/overlay/ov_network.go
  2. 36 0
      libnetwork/drivers/overlay/overlay_test.go
  3. 16 0
      libnetwork/ipvs/ipvs.go
  4. 7 4
      libnetwork/ipvs/netlink.go

+ 11 - 0
libnetwork/drivers/overlay/ov_network.go

@@ -696,6 +696,12 @@ func (n *network) initSandbox(restore bool) error {
 	var nlSock *nl.NetlinkSocket
 	sbox.InvokeFunc(func() {
 		nlSock, err = nl.Subscribe(syscall.NETLINK_ROUTE, syscall.RTNLGRP_NEIGH)
+		if err != nil {
+			return
+		}
+		// set the receive timeout to not remain stuck on the RecvFrom if the fd gets closed
+		tv := syscall.NsecToTimeval(soTimeout.Nanoseconds())
+		err = nlSock.SetReceiveTimeout(&tv)
 	})
 	n.setNetlinkSocket(nlSock)
 
@@ -721,6 +727,11 @@ func (n *network) watchMiss(nlSock *nl.NetlinkSocket) {
 				// The netlink socket got closed, simply exit to not leak this goroutine
 				return
 			}
+			// When the receive timeout expires the receive will return EAGAIN
+			if err == syscall.EAGAIN {
+				// we continue here to avoid spam for timeouts
+				continue
+			}
 			logrus.Errorf("Failed to receive from netlink: %v ", err)
 			continue
 		}

+ 36 - 0
libnetwork/drivers/overlay/overlay_test.go

@@ -1,7 +1,9 @@
 package overlay
 
 import (
+	"context"
 	"net"
+	"syscall"
 	"testing"
 	"time"
 
@@ -12,6 +14,7 @@ import (
 	"github.com/docker/libnetwork/driverapi"
 	"github.com/docker/libnetwork/netlabel"
 	_ "github.com/docker/libnetwork/testutils"
+	"github.com/vishvananda/netlink/nl"
 )
 
 func init() {
@@ -135,3 +138,36 @@ func TestOverlayType(t *testing.T) {
 			dt.d.Type())
 	}
 }
+
+// Test that the netlink socket close unblock the watchMiss to avoid deadlock
+func TestNetlinkSocket(t *testing.T) {
+	// This is the same code used by the overlay driver to create the netlink interface
+	// for the watch miss
+	nlSock, err := nl.Subscribe(syscall.NETLINK_ROUTE, syscall.RTNLGRP_NEIGH)
+	if err != nil {
+		t.Fatal()
+	}
+	// set the receive timeout to not remain stuck on the RecvFrom if the fd gets closed
+	tv := syscall.NsecToTimeval(soTimeout.Nanoseconds())
+	err = nlSock.SetReceiveTimeout(&tv)
+	if err != nil {
+		t.Fatal()
+	}
+	n := &network{id: "testnetid"}
+	ch := make(chan error)
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
+	go func() {
+		n.watchMiss(nlSock)
+		ch <- nil
+	}()
+	time.Sleep(5 * time.Second)
+	nlSock.Close()
+	select {
+	case <-ch:
+	case <-ctx.Done():
+		{
+			t.Fatalf("Timeout expired")
+		}
+	}
+}

+ 16 - 0
libnetwork/ipvs/ipvs.go

@@ -5,12 +5,19 @@ package ipvs
 import (
 	"net"
 	"syscall"
+	"time"
 
 	"fmt"
+
 	"github.com/vishvananda/netlink/nl"
 	"github.com/vishvananda/netns"
 )
 
+const (
+	netlinkRecvSocketsTimeout = 3 * time.Second
+	netlinkSendSocketTimeout  = 30 * time.Second
+)
+
 // Service defines an IPVS service in its entirety.
 type Service struct {
 	// Virtual service address.
@@ -82,6 +89,15 @@ func New(path string) (*Handle, error) {
 	if err != nil {
 		return nil, err
 	}
+	// Add operation timeout to avoid deadlocks
+	tv := syscall.NsecToTimeval(netlinkSendSocketTimeout.Nanoseconds())
+	if err := sock.SetSendTimeout(&tv); err != nil {
+		return nil, err
+	}
+	tv = syscall.NsecToTimeval(netlinkRecvSocketsTimeout.Nanoseconds())
+	if err := sock.SetReceiveTimeout(&tv); err != nil {
+		return nil, err
+	}
 
 	return &Handle{sock: sock}, nil
 }

+ 7 - 4
libnetwork/ipvs/netlink.go

@@ -203,10 +203,6 @@ func newGenlRequest(familyID int, cmd uint8) *nl.NetlinkRequest {
 }
 
 func execute(s *nl.NetlinkSocket, req *nl.NetlinkRequest, resType uint16) ([][]byte, error) {
-	var (
-		err error
-	)
-
 	if err := s.Send(req); err != nil {
 		return nil, err
 	}
@@ -222,6 +218,13 @@ done:
 	for {
 		msgs, err := s.Receive()
 		if err != nil {
+			if s.GetFd() == -1 {
+				return nil, fmt.Errorf("Socket got closed on receive")
+			}
+			if err == syscall.EAGAIN {
+				// timeout fired
+				continue
+			}
 			return nil, err
 		}
 		for _, m := range msgs {