
Merge pull request #35496 from cpuguy83/add_timeouts_to_splunk_post

Set timeout on splunk batch send
Yong Tang committed 7 years ago
commit 88e36dcc76

+ 12 - 3
daemon/logger/splunk/splunk.go

@@ -5,6 +5,7 @@ package splunk
 import (
 	"bytes"
 	"compress/gzip"
+	"context"
 	"crypto/tls"
 	"crypto/x509"
 	"encoding/json"
@@ -63,6 +64,8 @@ const (
 	envVarStreamChannelSize     = "SPLUNK_LOGGING_DRIVER_CHANNEL_SIZE"
 )
 
+var batchSendTimeout = 30 * time.Second
+
 type splunkLoggerInterface interface {
 	logger.Logger
 	worker()
@@ -416,13 +419,18 @@ func (l *splunkLogger) worker() {
 
 func (l *splunkLogger) postMessages(messages []*splunkMessage, lastChance bool) []*splunkMessage {
 	messagesLen := len(messages)
+
+	ctx, cancel := context.WithTimeout(context.Background(), batchSendTimeout)
+	defer cancel()
+
 	for i := 0; i < messagesLen; i += l.postMessagesBatchSize {
 		upperBound := i + l.postMessagesBatchSize
 		if upperBound > messagesLen {
 			upperBound = messagesLen
 		}
-		if err := l.tryPostMessages(messages[i:upperBound]); err != nil {
-			logrus.Error(err)
+
+		if err := l.tryPostMessages(ctx, messages[i:upperBound]); err != nil {
+			logrus.WithError(err).WithField("module", "logger/splunk").Warn("Error while sending logs")
 			if messagesLen-i >= l.bufferMaximum || lastChance {
 				// If this is last chance - print them all to the daemon log
 				if lastChance {
@@ -447,7 +455,7 @@ func (l *splunkLogger) postMessages(messages []*splunkMessage, lastChance bool)
 	return messages[:0]
 }
 
-func (l *splunkLogger) tryPostMessages(messages []*splunkMessage) error {
+func (l *splunkLogger) tryPostMessages(ctx context.Context, messages []*splunkMessage) error {
 	if len(messages) == 0 {
 		return nil
 	}
@@ -486,6 +494,7 @@ func (l *splunkLogger) tryPostMessages(messages []*splunkMessage) error {
 	if err != nil {
 		return err
 	}
+	req = req.WithContext(ctx)
 	req.Header.Set("Authorization", l.auth)
 	// Tell if we are sending gzip compressed body
 	if l.gzipCompression {
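The hunks above thread a context.WithTimeout through postMessages so that a hung HTTP Event Collector endpoint can no longer block the send loop forever: the context is created once per batch-send pass and attached to each request via req.WithContext. A minimal, standalone sketch of the same pattern, assuming nothing about the driver itself (sendBatch, client and endpoint are illustrative names):

package example

import (
	"bytes"
	"context"
	"net/http"
	"time"
)

// sendBatch posts a payload and gives up once the timeout elapses,
// mirroring how the driver wraps tryPostMessages in a context.
func sendBatch(client *http.Client, endpoint string, payload []byte, timeout time.Duration) error {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel() // always release the timer

	req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(payload))
	if err != nil {
		return err
	}
	// The transport aborts the request when ctx expires, so a stalled
	// server surfaces as an error instead of an indefinite hang.
	req = req.WithContext(ctx)

	resp, err := client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	return nil
}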

+ 49 - 2
daemon/logger/splunk/splunk_test.go

@@ -2,8 +2,10 @@ package splunk
 
 import (
 	"compress/gzip"
+	"context"
 	"fmt"
 	"os"
+	"runtime"
 	"testing"
 	"time"
 
@@ -1062,7 +1064,7 @@ func TestSkipVerify(t *testing.T) {
 		t.Fatal("No messages should be accepted at this point")
 	}
 
-	hec.simulateServerError = false
+	hec.simulateErr(false)
 
 	for i := defaultStreamChannelSize * 2; i < defaultStreamChannelSize*4; i++ {
 		if err := loggerDriver.Log(&logger.Message{Line: []byte(fmt.Sprintf("%d", i)), Source: "stdout", Timestamp: time.Now()}); err != nil {
@@ -1110,7 +1112,7 @@ func TestBufferMaximum(t *testing.T) {
 	}
 
 	hec := NewHTTPEventCollectorMock(t)
-	hec.simulateServerError = true
+	hec.simulateErr(true)
 	go hec.Serve()
 
 	info := logger.Info{
@@ -1308,3 +1310,48 @@ func TestCannotSendAfterClose(t *testing.T) {
 		t.Fatal(err)
 	}
 }
+
+func TestDeadlockOnBlockedEndpoint(t *testing.T) {
+	hec := NewHTTPEventCollectorMock(t)
+	go hec.Serve()
+	info := logger.Info{
+		Config: map[string]string{
+			splunkURLKey:   hec.URL(),
+			splunkTokenKey: hec.token,
+		},
+		ContainerID:        "containeriid",
+		ContainerName:      "/container_name",
+		ContainerImageID:   "contaimageid",
+		ContainerImageName: "container_image_name",
+	}
+
+	l, err := New(info)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	ctx, unblock := context.WithCancel(context.Background())
+	hec.withBlock(ctx)
+	defer unblock()
+
+	batchSendTimeout = 1 * time.Second
+
+	if err := l.Log(&logger.Message{}); err != nil {
+		t.Fatal(err)
+	}
+
+	done := make(chan struct{})
+	go func() {
+		l.Close()
+		close(done)
+	}()
+
+	select {
+	case <-time.After(60 * time.Second):
+		buf := make([]byte, 1e6)
+		buf = buf[:runtime.Stack(buf, true)]
+		t.Logf("STACK DUMP: \n\n%s\n\n", string(buf))
+		t.Fatal("timeout waiting for close to finish")
+	case <-done:
+	}
+}
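The new TestDeadlockOnBlockedEndpoint asserts that Close returns within a bound by running it in a goroutine and selecting on a done channel against a timer, dumping goroutine stacks if the timer fires first. The idiom on its own, independent of the splunk driver (finishWithin is an illustrative helper, not part of the test suite):

package example

import "time"

// finishWithin runs fn in a goroutine and reports whether it returns
// before the deadline elapses; fn keeps running in the background if not.
func finishWithin(fn func(), deadline time.Duration) bool {
	done := make(chan struct{})
	go func() {
		fn()
		close(done)
	}()
	select {
	case <-done:
		return true
	case <-time.After(deadline):
		return false
	}
}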

+ 26 - 1
daemon/logger/splunk/splunkhecmock_test.go

@@ -2,12 +2,14 @@ package splunk
 
 import (
 	"compress/gzip"
+	"context"
 	"encoding/json"
 	"fmt"
 	"io"
 	"io/ioutil"
 	"net"
 	"net/http"
+	"sync"
 	"testing"
 )
 
@@ -29,8 +31,10 @@ type HTTPEventCollectorMock struct {
 	tcpAddr     *net.TCPAddr
 	tcpListener *net.TCPListener
 
+	mu                  sync.Mutex
 	token               string
 	simulateServerError bool
+	blockingCtx         context.Context
 
 	test *testing.T
 
@@ -55,6 +59,18 @@ func NewHTTPEventCollectorMock(t *testing.T) *HTTPEventCollectorMock {
 		connectionVerified:  false}
 }
 
+func (hec *HTTPEventCollectorMock) simulateErr(b bool) {
+	hec.mu.Lock()
+	hec.simulateServerError = b
+	hec.mu.Unlock()
+}
+
+func (hec *HTTPEventCollectorMock) withBlock(ctx context.Context) {
+	hec.mu.Lock()
+	hec.blockingCtx = ctx
+	hec.mu.Unlock()
+}
+
 func (hec *HTTPEventCollectorMock) URL() string {
 	return "http://" + hec.tcpListener.Addr().String()
 }
@@ -72,7 +88,16 @@ func (hec *HTTPEventCollectorMock) ServeHTTP(writer http.ResponseWriter, request
 
 	hec.numOfRequests++
 
-	if hec.simulateServerError {
+	hec.mu.Lock()
+	simErr := hec.simulateServerError
+	ctx := hec.blockingCtx
+	hec.mu.Unlock()
+
+	if ctx != nil {
+		<-hec.blockingCtx.Done()
+	}
+
+	if simErr {
 		if request.Body != nil {
 			defer request.Body.Close()
 		}