Browse code

Set timeout on splunk batch send

Before this change, if the splunk endpoint is blocked it will cause a
deadlock on `Close()`.
This sets a reasonable timeout for the http request to send a log batch.

Signed-off-by: Brian Goff <cpuguy83@gmail.com>

Brian Goff authored on 2017/11/15 00:15:38
Showing 3 changed files
... ...
@@ -5,6 +5,7 @@ package splunk
5 5
 import (
6 6
 	"bytes"
7 7
 	"compress/gzip"
8
+	"context"
8 9
 	"crypto/tls"
9 10
 	"crypto/x509"
10 11
 	"encoding/json"
... ...
@@ -63,6 +64,8 @@ const (
63 63
 	envVarStreamChannelSize     = "SPLUNK_LOGGING_DRIVER_CHANNEL_SIZE"
64 64
 )
65 65
 
66
+var batchSendTimeout = 30 * time.Second
67
+
66 68
 type splunkLoggerInterface interface {
67 69
 	logger.Logger
68 70
 	worker()
... ...
@@ -416,13 +419,18 @@ func (l *splunkLogger) worker() {
416 416
 
417 417
 func (l *splunkLogger) postMessages(messages []*splunkMessage, lastChance bool) []*splunkMessage {
418 418
 	messagesLen := len(messages)
419
+
420
+	ctx, cancel := context.WithTimeout(context.Background(), batchSendTimeout)
421
+	defer cancel()
422
+
419 423
 	for i := 0; i < messagesLen; i += l.postMessagesBatchSize {
420 424
 		upperBound := i + l.postMessagesBatchSize
421 425
 		if upperBound > messagesLen {
422 426
 			upperBound = messagesLen
423 427
 		}
424
-		if err := l.tryPostMessages(messages[i:upperBound]); err != nil {
425
-			logrus.Error(err)
428
+
429
+		if err := l.tryPostMessages(ctx, messages[i:upperBound]); err != nil {
430
+			logrus.WithError(err).WithField("module", "logger/splunk").Warn("Error while sending logs")
426 431
 			if messagesLen-i >= l.bufferMaximum || lastChance {
427 432
 				// If this is last chance - print them all to the daemon log
428 433
 				if lastChance {
... ...
@@ -447,7 +455,7 @@ func (l *splunkLogger) postMessages(messages []*splunkMessage, lastChance bool)
447 447
 	return messages[:0]
448 448
 }
449 449
 
450
-func (l *splunkLogger) tryPostMessages(messages []*splunkMessage) error {
450
+func (l *splunkLogger) tryPostMessages(ctx context.Context, messages []*splunkMessage) error {
451 451
 	if len(messages) == 0 {
452 452
 		return nil
453 453
 	}
... ...
@@ -486,6 +494,7 @@ func (l *splunkLogger) tryPostMessages(messages []*splunkMessage) error {
486 486
 	if err != nil {
487 487
 		return err
488 488
 	}
489
+	req = req.WithContext(ctx)
489 490
 	req.Header.Set("Authorization", l.auth)
490 491
 	// Tell if we are sending gzip compressed body
491 492
 	if l.gzipCompression {
... ...
@@ -2,8 +2,10 @@ package splunk
2 2
 
3 3
 import (
4 4
 	"compress/gzip"
5
+	"context"
5 6
 	"fmt"
6 7
 	"os"
8
+	"runtime"
7 9
 	"testing"
8 10
 	"time"
9 11
 
... ...
@@ -1062,7 +1064,7 @@ func TestSkipVerify(t *testing.T) {
1062 1062
 		t.Fatal("No messages should be accepted at this point")
1063 1063
 	}
1064 1064
 
1065
-	hec.simulateServerError = false
1065
+	hec.simulateErr(false)
1066 1066
 
1067 1067
 	for i := defaultStreamChannelSize * 2; i < defaultStreamChannelSize*4; i++ {
1068 1068
 		if err := loggerDriver.Log(&logger.Message{Line: []byte(fmt.Sprintf("%d", i)), Source: "stdout", Timestamp: time.Now()}); err != nil {
... ...
@@ -1110,7 +1112,7 @@ func TestBufferMaximum(t *testing.T) {
1110 1110
 	}
1111 1111
 
1112 1112
 	hec := NewHTTPEventCollectorMock(t)
1113
-	hec.simulateServerError = true
1113
+	hec.simulateErr(true)
1114 1114
 	go hec.Serve()
1115 1115
 
1116 1116
 	info := logger.Info{
... ...
@@ -1308,3 +1310,48 @@ func TestCannotSendAfterClose(t *testing.T) {
1308 1308
 		t.Fatal(err)
1309 1309
 	}
1310 1310
 }
1311
+
1312
+func TestDeadlockOnBlockedEndpoint(t *testing.T) {
1313
+	hec := NewHTTPEventCollectorMock(t)
1314
+	go hec.Serve()
1315
+	info := logger.Info{
1316
+		Config: map[string]string{
1317
+			splunkURLKey:   hec.URL(),
1318
+			splunkTokenKey: hec.token,
1319
+		},
1320
+		ContainerID:        "containeriid",
1321
+		ContainerName:      "/container_name",
1322
+		ContainerImageID:   "contaimageid",
1323
+		ContainerImageName: "container_image_name",
1324
+	}
1325
+
1326
+	l, err := New(info)
1327
+	if err != nil {
1328
+		t.Fatal(err)
1329
+	}
1330
+
1331
+	ctx, unblock := context.WithCancel(context.Background())
1332
+	hec.withBlock(ctx)
1333
+	defer unblock()
1334
+
1335
+	batchSendTimeout = 1 * time.Second
1336
+
1337
+	if err := l.Log(&logger.Message{}); err != nil {
1338
+		t.Fatal(err)
1339
+	}
1340
+
1341
+	done := make(chan struct{})
1342
+	go func() {
1343
+		l.Close()
1344
+		close(done)
1345
+	}()
1346
+
1347
+	select {
1348
+	case <-time.After(60 * time.Second):
1349
+		buf := make([]byte, 1e6)
1350
+		buf = buf[:runtime.Stack(buf, true)]
1351
+		t.Logf("STACK DUMP: \n\n%s\n\n", string(buf))
1352
+		t.Fatal("timeout waiting for close to finish")
1353
+	case <-done:
1354
+	}
1355
+}
... ...
@@ -2,12 +2,14 @@ package splunk
2 2
 
3 3
 import (
4 4
 	"compress/gzip"
5
+	"context"
5 6
 	"encoding/json"
6 7
 	"fmt"
7 8
 	"io"
8 9
 	"io/ioutil"
9 10
 	"net"
10 11
 	"net/http"
12
+	"sync"
11 13
 	"testing"
12 14
 )
13 15
 
... ...
@@ -29,8 +31,10 @@ type HTTPEventCollectorMock struct {
29 29
 	tcpAddr     *net.TCPAddr
30 30
 	tcpListener *net.TCPListener
31 31
 
32
+	mu                  sync.Mutex
32 33
 	token               string
33 34
 	simulateServerError bool
35
+	blockingCtx         context.Context
34 36
 
35 37
 	test *testing.T
36 38
 
... ...
@@ -55,6 +59,18 @@ func NewHTTPEventCollectorMock(t *testing.T) *HTTPEventCollectorMock {
55 55
 		connectionVerified:  false}
56 56
 }
57 57
 
58
+func (hec *HTTPEventCollectorMock) simulateErr(b bool) {
59
+	hec.mu.Lock()
60
+	hec.simulateServerError = b
61
+	hec.mu.Unlock()
62
+}
63
+
64
+func (hec *HTTPEventCollectorMock) withBlock(ctx context.Context) {
65
+	hec.mu.Lock()
66
+	hec.blockingCtx = ctx
67
+	hec.mu.Unlock()
68
+}
69
+
58 70
 func (hec *HTTPEventCollectorMock) URL() string {
59 71
 	return "http://" + hec.tcpListener.Addr().String()
60 72
 }
... ...
@@ -72,7 +88,16 @@ func (hec *HTTPEventCollectorMock) ServeHTTP(writer http.ResponseWriter, request
72 72
 
73 73
 	hec.numOfRequests++
74 74
 
75
-	if hec.simulateServerError {
75
+	hec.mu.Lock()
76
+	simErr := hec.simulateServerError
77
+	ctx := hec.blockingCtx
78
+	hec.mu.Unlock()
79
+
80
+	if ctx != nil {
81
+		<-hec.blockingCtx.Done()
82
+	}
83
+
84
+	if simErr {
76 85
 		if request.Body != nil {
77 86
 			defer request.Body.Close()
78 87
 		}