Browse code

awslogs: Prevent close from being blocked on log

Before this change a call to `Close` could be blocked if the the channel
used to buffer logs is full.
When this happens the container state will end up wedged causing a
deadlock on anything that needs to lock the container state.

This removes the use of a channel which has semantics which are
difficult to manage to something more suitable for the situation.

Signed-off-by: Brian Goff <cpuguy83@gmail.com>

Brian Goff authored on 2024/04/24 03:26:08
Showing 4 changed files
... ...
@@ -8,7 +8,7 @@ import (
8 8
 	"regexp"
9 9
 	"sort"
10 10
 	"strconv"
11
-	"sync"
11
+	"sync/atomic"
12 12
 	"time"
13 13
 	"unicode/utf8"
14 14
 
... ...
@@ -76,10 +76,11 @@ type logStream struct {
76 76
 	forceFlushInterval time.Duration
77 77
 	multilinePattern   *regexp.Regexp
78 78
 	client             api
79
-	messages           chan *logger.Message
80
-	lock               sync.RWMutex
81
-	closed             bool
82
-	sequenceToken      *string
79
+
80
+	messages *loggerutils.MessageQueue
81
+	closed   atomic.Bool
82
+
83
+	sequenceToken *string
83 84
 }
84 85
 
85 86
 type logStreamConfig struct {
... ...
@@ -158,7 +159,7 @@ func New(info logger.Info) (logger.Logger, error) {
158 158
 		forceFlushInterval: containerStreamConfig.forceFlushInterval,
159 159
 		multilinePattern:   containerStreamConfig.multilinePattern,
160 160
 		client:             client,
161
-		messages:           make(chan *logger.Message, containerStreamConfig.maxBufferedEvents),
161
+		messages:           loggerutils.NewMessageQueue(containerStreamConfig.maxBufferedEvents),
162 162
 	}
163 163
 
164 164
 	creationDone := make(chan bool)
... ...
@@ -168,12 +169,10 @@ func New(info logger.Info) (logger.Logger, error) {
168 168
 			maxBackoff := 32
169 169
 			for {
170 170
 				// If logger is closed we are done
171
-				containerStream.lock.RLock()
172
-				if containerStream.closed {
173
-					containerStream.lock.RUnlock()
171
+				if containerStream.closed.Load() {
174 172
 					break
175 173
 				}
176
-				containerStream.lock.RUnlock()
174
+
177 175
 				err := containerStream.create()
178 176
 				if err == nil {
179 177
 					break
... ...
@@ -426,25 +425,26 @@ func (l *logStream) BufSize() int {
426 426
 	return maximumBytesPerEvent
427 427
 }
428 428
 
429
+var errClosed = errors.New("awslogs is closed")
430
+
429 431
 // Log submits messages for logging by an instance of the awslogs logging driver
430 432
 func (l *logStream) Log(msg *logger.Message) error {
431
-	l.lock.RLock()
432
-	defer l.lock.RUnlock()
433
-	if l.closed {
434
-		return errors.New("awslogs is closed")
433
+	// No need to check if we are closed here since the queue will be closed
434
+	// (i.e. returns false) in this case.
435
+	ctx := context.TODO()
436
+	if err := l.messages.Enqueue(ctx, msg); err != nil {
437
+		if err == loggerutils.ErrQueueClosed {
438
+			return errClosed
439
+		}
440
+		return err
435 441
 	}
436
-	l.messages <- msg
437 442
 	return nil
438 443
 }
439 444
 
440 445
 // Close closes the instance of the awslogs logging driver
441 446
 func (l *logStream) Close() error {
442
-	l.lock.Lock()
443
-	defer l.lock.Unlock()
444
-	if !l.closed {
445
-		close(l.messages)
446
-	}
447
-	l.closed = true
447
+	l.closed.Store(true)
448
+	l.messages.Close()
448 449
 	return nil
449 450
 }
450 451
 
... ...
@@ -561,6 +561,8 @@ func (l *logStream) collectBatch(created chan bool) {
561 561
 	var eventBuffer []byte
562 562
 	var eventBufferTimestamp int64
563 563
 	batch := newEventBatch()
564
+
565
+	chLogs := l.messages.Receiver()
564 566
 	for {
565 567
 		select {
566 568
 		case t := <-ticker.C:
... ...
@@ -576,7 +578,7 @@ func (l *logStream) collectBatch(created chan bool) {
576 576
 			}
577 577
 			l.publishBatch(batch)
578 578
 			batch.reset()
579
-		case msg, more := <-l.messages:
579
+		case msg, more := <-chLogs:
580 580
 			if !more {
581 581
 				// Flush event buffer and release resources
582 582
 				l.processEvent(batch, eventBuffer, eventBufferTimestamp)
... ...
@@ -356,9 +356,10 @@ func TestCreateAlreadyExists(t *testing.T) {
356 356
 func TestLogClosed(t *testing.T) {
357 357
 	mockClient := &mockClient{}
358 358
 	stream := &logStream{
359
-		client: mockClient,
360
-		closed: true,
359
+		client:   mockClient,
360
+		messages: loggerutils.NewMessageQueue(0),
361 361
 	}
362
+	stream.Close()
362 363
 	err := stream.Log(&logger.Message{})
363 364
 	assert.Check(t, err != nil)
364 365
 }
... ...
@@ -370,7 +371,7 @@ func TestLogBlocking(t *testing.T) {
370 370
 	mockClient := &mockClient{}
371 371
 	stream := &logStream{
372 372
 		client:   mockClient,
373
-		messages: make(chan *logger.Message),
373
+		messages: loggerutils.NewMessageQueue(0),
374 374
 	}
375 375
 
376 376
 	errorCh := make(chan error, 1)
... ...
@@ -387,14 +388,11 @@ func TestLogBlocking(t *testing.T) {
387 387
 		t.Fatal("Expected stream.Log to block: ", err)
388 388
 	default:
389 389
 	}
390
+
390 391
 	// assuming it is blocked, we can now try to drain the internal channel and
391 392
 	// unblock it
392
-	select {
393
-	case <-time.After(10 * time.Millisecond):
394
-		// if we're unable to drain the channel within 10ms, something seems broken
395
-		t.Fatal("Expected to be able to read from stream.messages but was unable to")
396
-	case <-stream.messages:
397
-	}
393
+	<-stream.messages.Receiver()
394
+
398 395
 	select {
399 396
 	case err := <-errorCh:
400 397
 		assert.NilError(t, err)
... ...
@@ -408,7 +406,7 @@ func TestLogBufferEmpty(t *testing.T) {
408 408
 	mockClient := &mockClient{}
409 409
 	stream := &logStream{
410 410
 		client:   mockClient,
411
-		messages: make(chan *logger.Message, 1),
411
+		messages: loggerutils.NewMessageQueue(1),
412 412
 	}
413 413
 	err := stream.Log(&logger.Message{})
414 414
 	assert.NilError(t, err)
... ...
@@ -556,7 +554,7 @@ func TestCollectBatchSimple(t *testing.T) {
556 556
 		logGroupName:  groupName,
557 557
 		logStreamName: streamName,
558 558
 		sequenceToken: aws.String(sequenceToken),
559
-		messages:      make(chan *logger.Message),
559
+		messages:      loggerutils.NewMessageQueue(0),
560 560
 	}
561 561
 	calls := make([]*cloudwatchlogs.PutLogEventsInput, 0)
562 562
 	mockClient.putLogEventsFunc = func(ctx context.Context, input *cloudwatchlogs.PutLogEventsInput, opts ...func(*cloudwatchlogs.Options)) (*cloudwatchlogs.PutLogEventsOutput, error) {
... ...
@@ -575,15 +573,20 @@ func TestCollectBatchSimple(t *testing.T) {
575 575
 	close(d)
576 576
 	go stream.collectBatch(d)
577 577
 
578
-	stream.Log(&logger.Message{
578
+	err := stream.Log(&logger.Message{
579 579
 		Line:      []byte(logline),
580 580
 		Timestamp: time.Time{},
581 581
 	})
582
+	assert.NilError(t, err)
582 583
 
583 584
 	ticks <- time.Time{}
584 585
 	ticks <- time.Time{}
585 586
 	stream.Close()
586 587
 
588
+	for len(calls) != 1 {
589
+		time.Sleep(10 * time.Millisecond)
590
+	}
591
+
587 592
 	assert.Assert(t, len(calls) == 1)
588 593
 	argument := calls[0]
589 594
 	assert.Assert(t, argument != nil)
... ...
@@ -598,7 +601,7 @@ func TestCollectBatchTicker(t *testing.T) {
598 598
 		logGroupName:  groupName,
599 599
 		logStreamName: streamName,
600 600
 		sequenceToken: aws.String(sequenceToken),
601
-		messages:      make(chan *logger.Message),
601
+		messages:      loggerutils.NewMessageQueue(0),
602 602
 	}
603 603
 	calls := make([]*cloudwatchlogs.PutLogEventsInput, 0)
604 604
 	called := make(chan struct{}, 50)
... ...
@@ -666,7 +669,7 @@ func TestCollectBatchMultilinePattern(t *testing.T) {
666 666
 		logStreamName:    streamName,
667 667
 		multilinePattern: multilinePattern,
668 668
 		sequenceToken:    aws.String(sequenceToken),
669
-		messages:         make(chan *logger.Message),
669
+		messages:         loggerutils.NewMessageQueue(0),
670 670
 	}
671 671
 	calls := make([]*cloudwatchlogs.PutLogEventsInput, 0)
672 672
 	called := make(chan struct{}, 50)
... ...
@@ -732,7 +735,7 @@ func BenchmarkCollectBatch(b *testing.B) {
732 732
 			logGroupName:  groupName,
733 733
 			logStreamName: streamName,
734 734
 			sequenceToken: aws.String(sequenceToken),
735
-			messages:      make(chan *logger.Message),
735
+			messages:      loggerutils.NewMessageQueue(0),
736 736
 		}
737 737
 		mockClient.putLogEventsFunc = func(ctx context.Context, input *cloudwatchlogs.PutLogEventsInput, opts ...func(*cloudwatchlogs.Options)) (*cloudwatchlogs.PutLogEventsOutput, error) {
738 738
 			return &cloudwatchlogs.PutLogEventsOutput{
... ...
@@ -765,7 +768,7 @@ func BenchmarkCollectBatchMultilinePattern(b *testing.B) {
765 765
 			logStreamName:    streamName,
766 766
 			multilinePattern: multilinePattern,
767 767
 			sequenceToken:    aws.String(sequenceToken),
768
-			messages:         make(chan *logger.Message),
768
+			messages:         loggerutils.NewMessageQueue(0),
769 769
 		}
770 770
 		mockClient.putLogEventsFunc = func(ctx context.Context, input *cloudwatchlogs.PutLogEventsInput, opts ...func(*cloudwatchlogs.Options)) (*cloudwatchlogs.PutLogEventsOutput, error) {
771 771
 			return &cloudwatchlogs.PutLogEventsOutput{
... ...
@@ -796,7 +799,7 @@ func TestCollectBatchMultilinePatternMaxEventAge(t *testing.T) {
796 796
 		logStreamName:    streamName,
797 797
 		multilinePattern: multilinePattern,
798 798
 		sequenceToken:    aws.String(sequenceToken),
799
-		messages:         make(chan *logger.Message),
799
+		messages:         loggerutils.NewMessageQueue(0),
800 800
 	}
801 801
 	calls := make([]*cloudwatchlogs.PutLogEventsInput, 0)
802 802
 	called := make(chan struct{}, 50)
... ...
@@ -870,7 +873,7 @@ func TestCollectBatchMultilinePatternNegativeEventAge(t *testing.T) {
870 870
 		logStreamName:    streamName,
871 871
 		multilinePattern: multilinePattern,
872 872
 		sequenceToken:    aws.String(sequenceToken),
873
-		messages:         make(chan *logger.Message),
873
+		messages:         loggerutils.NewMessageQueue(0),
874 874
 	}
875 875
 	calls := make([]*cloudwatchlogs.PutLogEventsInput, 0)
876 876
 	called := make(chan struct{}, 50)
... ...
@@ -927,7 +930,7 @@ func TestCollectBatchMultilinePatternMaxEventSize(t *testing.T) {
927 927
 		logStreamName:    streamName,
928 928
 		multilinePattern: multilinePattern,
929 929
 		sequenceToken:    aws.String(sequenceToken),
930
-		messages:         make(chan *logger.Message),
930
+		messages:         loggerutils.NewMessageQueue(0),
931 931
 	}
932 932
 	calls := make([]*cloudwatchlogs.PutLogEventsInput, 0)
933 933
 	called := make(chan struct{}, 50)
... ...
@@ -987,7 +990,7 @@ func TestCollectBatchClose(t *testing.T) {
987 987
 		logGroupName:  groupName,
988 988
 		logStreamName: streamName,
989 989
 		sequenceToken: aws.String(sequenceToken),
990
-		messages:      make(chan *logger.Message),
990
+		messages:      loggerutils.NewMessageQueue(0),
991 991
 	}
992 992
 	calls := make([]*cloudwatchlogs.PutLogEventsInput, 0)
993 993
 	called := make(chan struct{}, 50)
... ...
@@ -1089,7 +1092,7 @@ func TestCollectBatchLineSplit(t *testing.T) {
1089 1089
 		logGroupName:  groupName,
1090 1090
 		logStreamName: streamName,
1091 1091
 		sequenceToken: aws.String(sequenceToken),
1092
-		messages:      make(chan *logger.Message),
1092
+		messages:      loggerutils.NewMessageQueue(0),
1093 1093
 	}
1094 1094
 	calls := make([]*cloudwatchlogs.PutLogEventsInput, 0)
1095 1095
 	called := make(chan struct{}, 50)
... ...
@@ -1137,7 +1140,7 @@ func TestCollectBatchLineSplitWithBinary(t *testing.T) {
1137 1137
 		logGroupName:  groupName,
1138 1138
 		logStreamName: streamName,
1139 1139
 		sequenceToken: aws.String(sequenceToken),
1140
-		messages:      make(chan *logger.Message),
1140
+		messages:      loggerutils.NewMessageQueue(0),
1141 1141
 	}
1142 1142
 	calls := make([]*cloudwatchlogs.PutLogEventsInput, 0)
1143 1143
 	called := make(chan struct{}, 50)
... ...
@@ -1185,7 +1188,7 @@ func TestCollectBatchMaxEvents(t *testing.T) {
1185 1185
 		logGroupName:  groupName,
1186 1186
 		logStreamName: streamName,
1187 1187
 		sequenceToken: aws.String(sequenceToken),
1188
-		messages:      make(chan *logger.Message),
1188
+		messages:      loggerutils.NewMessageQueue(0),
1189 1189
 	}
1190 1190
 	calls := make([]*cloudwatchlogs.PutLogEventsInput, 0)
1191 1191
 	called := make(chan struct{}, 50)
... ...
@@ -1239,7 +1242,7 @@ func TestCollectBatchMaxTotalBytes(t *testing.T) {
1239 1239
 		logGroupName:  groupName,
1240 1240
 		logStreamName: streamName,
1241 1241
 		sequenceToken: aws.String(sequenceToken),
1242
-		messages:      make(chan *logger.Message),
1242
+		messages:      loggerutils.NewMessageQueue(0),
1243 1243
 	}
1244 1244
 	calls := make([]*cloudwatchlogs.PutLogEventsInput, 0)
1245 1245
 	called := make(chan struct{}, 50)
... ...
@@ -1320,7 +1323,7 @@ func TestCollectBatchMaxTotalBytesWithBinary(t *testing.T) {
1320 1320
 		logGroupName:  groupName,
1321 1321
 		logStreamName: streamName,
1322 1322
 		sequenceToken: aws.String(sequenceToken),
1323
-		messages:      make(chan *logger.Message),
1323
+		messages:      loggerutils.NewMessageQueue(0),
1324 1324
 	}
1325 1325
 	calls := make([]*cloudwatchlogs.PutLogEventsInput, 0)
1326 1326
 	called := make(chan struct{}, 50)
... ...
@@ -1394,7 +1397,7 @@ func TestCollectBatchWithDuplicateTimestamps(t *testing.T) {
1394 1394
 		logGroupName:  groupName,
1395 1395
 		logStreamName: streamName,
1396 1396
 		sequenceToken: aws.String(sequenceToken),
1397
-		messages:      make(chan *logger.Message),
1397
+		messages:      loggerutils.NewMessageQueue(0),
1398 1398
 	}
1399 1399
 	calls := make([]*cloudwatchlogs.PutLogEventsInput, 0)
1400 1400
 	called := make(chan struct{}, 50)
1401 1401
new file mode 100644
... ...
@@ -0,0 +1,156 @@
0
+package loggerutils
1
+
2
+import (
3
+	"context"
4
+	"sync"
5
+
6
+	"github.com/docker/docker/daemon/logger"
7
+	"github.com/pkg/errors"
8
+)
9
+
10
+// MessageQueue is a queue for log messages.
11
+//
12
+// [MessageQueue.Enqueue] will block when the queue is full.
13
+// To dequeue messages call [MessageQueue.Reciever] and pull messsages off the
14
+// returned channel.
15
+//
16
+// Closing only prevents new messages from being added to the queue.
17
+// The queue can still be drained after close.
18
+//
19
+// The zero value of MessageQueue is safe to use, but does not do any internal
20
+// buffering (queue size is 0).
21
+type MessageQueue struct {
22
+	maxSize int
23
+
24
+	mu      sync.Mutex
25
+	closing bool
26
+	closed  chan struct{}
27
+
28
+	// Blocks multiple calls to [MessageQueue.Close] until the queue is actually closed
29
+	closeWait chan struct{}
30
+
31
+	// We need to be able to safely close the send channel so that [MessageQueue.Dequeue]
32
+	// can drain the queue without blocking.
33
+	// This cond var helps deal with that.
34
+	cond        *sync.Cond
35
+	sendWaiters int
36
+
37
+	ch chan *logger.Message
38
+}
39
+
40
+// NewMessageQueue creates a new queue with the specified size.
41
+func NewMessageQueue(maxSize int) *MessageQueue {
42
+	var q MessageQueue
43
+	q.maxSize = maxSize
44
+	q.init()
45
+	return &q
46
+}
47
+
48
+func (q *MessageQueue) init() {
49
+	if q.cond == nil {
50
+		q.cond = sync.NewCond(&q.mu)
51
+	}
52
+
53
+	if q.ch == nil {
54
+		q.ch = make(chan *logger.Message, q.maxSize)
55
+	}
56
+
57
+	if q.closed == nil {
58
+		q.closed = make(chan struct{})
59
+	}
60
+
61
+	if q.closeWait == nil {
62
+		q.closeWait = make(chan struct{})
63
+	}
64
+}
65
+
66
+var ErrQueueClosed = errors.New("queue is closed")
67
+
68
+// Enqueue adds the provided message to the queue.
69
+// Enqueue blocks if the queue is full.
70
+//
71
+// The two possible error cases are:
72
+// 1. The provided context is cancelled
73
+// 2. [ErrQueueClosed] when the queue has been closed.
74
+func (q *MessageQueue) Enqueue(ctx context.Context, m *logger.Message) error {
75
+	q.mu.Lock()
76
+	q.init()
77
+
78
+	// Increment the waiter count
79
+	// This prevents the send channel from being closed while we are trying to send.
80
+	q.sendWaiters++
81
+	q.mu.Unlock()
82
+
83
+	defer func() {
84
+		q.mu.Lock()
85
+		// Decrement the waiter count and signal to any potential closer to check
86
+		// the wait count again.
87
+		// Only bother signaling if this is the last waiter.
88
+		q.sendWaiters--
89
+		if q.sendWaiters == 0 {
90
+			q.cond.Signal()
91
+		}
92
+		q.mu.Unlock()
93
+	}()
94
+
95
+	// Before trying to send on the channel, check if we care closed.
96
+	select {
97
+	case <-ctx.Done():
98
+		return ctx.Err()
99
+	case <-q.closed:
100
+		return ErrQueueClosed
101
+	default:
102
+	}
103
+
104
+	select {
105
+	case <-ctx.Done():
106
+		return ctx.Err()
107
+	case <-q.closed:
108
+		return ErrQueueClosed
109
+	case q.ch <- m:
110
+		return nil
111
+	}
112
+}
113
+
114
+// Close prevents any new messages from being added to the queue.
115
+func (q *MessageQueue) Close() {
116
+	q.mu.Lock()
117
+
118
+	q.init()
119
+
120
+	if q.closing {
121
+		// unlock the mutex here so that the goroutine waiting on the cond var can
122
+		// take the lock when signaled.
123
+		q.mu.Unlock()
124
+		<-q.closeWait
125
+		return
126
+	}
127
+
128
+	defer q.mu.Unlock()
129
+
130
+	// Prevent multiple Close calls from trying to close things.
131
+	q.closing = true
132
+
133
+	close(q.closed)
134
+
135
+	// Wait for any senders to finish
136
+	// Because we closed the channel above, this shouldn't block for a long period.
137
+	for q.sendWaiters > 0 {
138
+		q.cond.Wait()
139
+	}
140
+
141
+	close(q.ch)
142
+	close(q.closeWait)
143
+}
144
+
145
+// Receiver returns a channel that can be used to dequeue messages
146
+// The channel will be closed when the message queue is closed but may have
147
+// messages buffered.
148
+func (q *MessageQueue) Receiver() <-chan *logger.Message {
149
+	q.mu.Lock()
150
+	defer q.mu.Unlock()
151
+
152
+	q.init()
153
+
154
+	return q.ch
155
+}
0 156
new file mode 100644
... ...
@@ -0,0 +1,87 @@
0
+package loggerutils
1
+
2
+import (
3
+	"context"
4
+	"testing"
5
+	"time"
6
+
7
+	"github.com/docker/docker/daemon/logger"
8
+	"gotest.tools/v3/assert"
9
+)
10
+
11
+func TestQueue(t *testing.T) {
12
+	q := NewMessageQueue(2)
13
+	msg := &logger.Message{Line: []byte("hello")}
14
+
15
+	ctx := context.Background()
16
+	err := q.Enqueue(ctx, msg)
17
+	assert.Check(t, err)
18
+
19
+	recv := q.Receiver()
20
+	// These pointer values should be the same
21
+	assert.Equal(t, msg, <-recv)
22
+
23
+	err = q.Enqueue(ctx, msg)
24
+	assert.Check(t, err)
25
+
26
+	err = q.Enqueue(ctx, msg)
27
+	assert.Check(t, err)
28
+
29
+	q.Close()
30
+
31
+	// We have 2 messages in the queue
32
+	// Even though this is closed, we should get a true value from dequeue twice.
33
+	assert.Equal(t, msg, <-recv)
34
+	assert.Equal(t, msg, <-recv)
35
+
36
+	// This should not block and should return false
37
+	_, more := <-recv
38
+	assert.Check(t, !more, "expected no more messages in the queue")
39
+
40
+	// Test with unbuffered
41
+	q = &MessageQueue{}
42
+	recv = q.Receiver()
43
+
44
+	chAdd := make(chan error, 1)
45
+	go func() {
46
+		chAdd <- q.Enqueue(ctx, msg)
47
+	}()
48
+
49
+	assert.Equal(t, msg, <-recv)
50
+	assert.Assert(t, <-chAdd)
51
+
52
+	ctxC, cancel := context.WithCancel(ctx)
53
+	cancel()
54
+
55
+	err = q.Enqueue(ctxC, msg)
56
+	assert.ErrorIs(t, err, context.Canceled)
57
+
58
+	// Test that blocked senders do not cause a panic on close.
59
+	// This check is useful because the underlying implementation uses channels
60
+	// with the send channel eventually getting closed when q.Close is called.
61
+	go func() {
62
+		chAdd <- q.Enqueue(ctx, msg)
63
+	}()
64
+
65
+	// Wait for enqueue to be ready (or as close to ready as it can be)
66
+	for {
67
+		q.mu.Lock()
68
+		if q.sendWaiters == 1 {
69
+			q.mu.Unlock()
70
+			break
71
+		}
72
+		q.mu.Unlock()
73
+		time.Sleep(time.Millisecond)
74
+	}
75
+
76
+	q.Close()
77
+
78
+	select {
79
+	case <-time.After(5 * time.Second):
80
+	case err := <-chAdd:
81
+		assert.ErrorIs(t, err, ErrQueueClosed)
82
+	}
83
+
84
+	// Double-close should not cause any issues
85
+	q.Close()
86
+}