Before this change a call to `Close` could be blocked if the the channel
used to buffer logs is full.
When this happens the container state will end up wedged causing a
deadlock on anything that needs to lock the container state.
This removes the use of a channel which has semantics which are
difficult to manage to something more suitable for the situation.
Signed-off-by: Brian Goff <cpuguy83@gmail.com>
| ... | ... |
@@ -8,7 +8,7 @@ import ( |
| 8 | 8 |
"regexp" |
| 9 | 9 |
"sort" |
| 10 | 10 |
"strconv" |
| 11 |
- "sync" |
|
| 11 |
+ "sync/atomic" |
|
| 12 | 12 |
"time" |
| 13 | 13 |
"unicode/utf8" |
| 14 | 14 |
|
| ... | ... |
@@ -76,10 +76,11 @@ type logStream struct {
|
| 76 | 76 |
forceFlushInterval time.Duration |
| 77 | 77 |
multilinePattern *regexp.Regexp |
| 78 | 78 |
client api |
| 79 |
- messages chan *logger.Message |
|
| 80 |
- lock sync.RWMutex |
|
| 81 |
- closed bool |
|
| 82 |
- sequenceToken *string |
|
| 79 |
+ |
|
| 80 |
+ messages *loggerutils.MessageQueue |
|
| 81 |
+ closed atomic.Bool |
|
| 82 |
+ |
|
| 83 |
+ sequenceToken *string |
|
| 83 | 84 |
} |
| 84 | 85 |
|
| 85 | 86 |
type logStreamConfig struct {
|
| ... | ... |
@@ -158,7 +159,7 @@ func New(info logger.Info) (logger.Logger, error) {
|
| 158 | 158 |
forceFlushInterval: containerStreamConfig.forceFlushInterval, |
| 159 | 159 |
multilinePattern: containerStreamConfig.multilinePattern, |
| 160 | 160 |
client: client, |
| 161 |
- messages: make(chan *logger.Message, containerStreamConfig.maxBufferedEvents), |
|
| 161 |
+ messages: loggerutils.NewMessageQueue(containerStreamConfig.maxBufferedEvents), |
|
| 162 | 162 |
} |
| 163 | 163 |
|
| 164 | 164 |
creationDone := make(chan bool) |
| ... | ... |
@@ -168,12 +169,10 @@ func New(info logger.Info) (logger.Logger, error) {
|
| 168 | 168 |
maxBackoff := 32 |
| 169 | 169 |
for {
|
| 170 | 170 |
// If logger is closed we are done |
| 171 |
- containerStream.lock.RLock() |
|
| 172 |
- if containerStream.closed {
|
|
| 173 |
- containerStream.lock.RUnlock() |
|
| 171 |
+ if containerStream.closed.Load() {
|
|
| 174 | 172 |
break |
| 175 | 173 |
} |
| 176 |
- containerStream.lock.RUnlock() |
|
| 174 |
+ |
|
| 177 | 175 |
err := containerStream.create() |
| 178 | 176 |
if err == nil {
|
| 179 | 177 |
break |
| ... | ... |
@@ -426,25 +425,26 @@ func (l *logStream) BufSize() int {
|
| 426 | 426 |
return maximumBytesPerEvent |
| 427 | 427 |
} |
| 428 | 428 |
|
| 429 |
+var errClosed = errors.New("awslogs is closed")
|
|
| 430 |
+ |
|
| 429 | 431 |
// Log submits messages for logging by an instance of the awslogs logging driver |
| 430 | 432 |
func (l *logStream) Log(msg *logger.Message) error {
|
| 431 |
- l.lock.RLock() |
|
| 432 |
- defer l.lock.RUnlock() |
|
| 433 |
- if l.closed {
|
|
| 434 |
- return errors.New("awslogs is closed")
|
|
| 433 |
+ // No need to check if we are closed here since the queue will be closed |
|
| 434 |
+ // (i.e. returns false) in this case. |
|
| 435 |
+ ctx := context.TODO() |
|
| 436 |
+ if err := l.messages.Enqueue(ctx, msg); err != nil {
|
|
| 437 |
+ if err == loggerutils.ErrQueueClosed {
|
|
| 438 |
+ return errClosed |
|
| 439 |
+ } |
|
| 440 |
+ return err |
|
| 435 | 441 |
} |
| 436 |
- l.messages <- msg |
|
| 437 | 442 |
return nil |
| 438 | 443 |
} |
| 439 | 444 |
|
| 440 | 445 |
// Close closes the instance of the awslogs logging driver |
| 441 | 446 |
func (l *logStream) Close() error {
|
| 442 |
- l.lock.Lock() |
|
| 443 |
- defer l.lock.Unlock() |
|
| 444 |
- if !l.closed {
|
|
| 445 |
- close(l.messages) |
|
| 446 |
- } |
|
| 447 |
- l.closed = true |
|
| 447 |
+ l.closed.Store(true) |
|
| 448 |
+ l.messages.Close() |
|
| 448 | 449 |
return nil |
| 449 | 450 |
} |
| 450 | 451 |
|
| ... | ... |
@@ -561,6 +561,8 @@ func (l *logStream) collectBatch(created chan bool) {
|
| 561 | 561 |
var eventBuffer []byte |
| 562 | 562 |
var eventBufferTimestamp int64 |
| 563 | 563 |
batch := newEventBatch() |
| 564 |
+ |
|
| 565 |
+ chLogs := l.messages.Receiver() |
|
| 564 | 566 |
for {
|
| 565 | 567 |
select {
|
| 566 | 568 |
case t := <-ticker.C: |
| ... | ... |
@@ -576,7 +578,7 @@ func (l *logStream) collectBatch(created chan bool) {
|
| 576 | 576 |
} |
| 577 | 577 |
l.publishBatch(batch) |
| 578 | 578 |
batch.reset() |
| 579 |
- case msg, more := <-l.messages: |
|
| 579 |
+ case msg, more := <-chLogs: |
|
| 580 | 580 |
if !more {
|
| 581 | 581 |
// Flush event buffer and release resources |
| 582 | 582 |
l.processEvent(batch, eventBuffer, eventBufferTimestamp) |
| ... | ... |
@@ -356,9 +356,10 @@ func TestCreateAlreadyExists(t *testing.T) {
|
| 356 | 356 |
func TestLogClosed(t *testing.T) {
|
| 357 | 357 |
mockClient := &mockClient{}
|
| 358 | 358 |
stream := &logStream{
|
| 359 |
- client: mockClient, |
|
| 360 |
- closed: true, |
|
| 359 |
+ client: mockClient, |
|
| 360 |
+ messages: loggerutils.NewMessageQueue(0), |
|
| 361 | 361 |
} |
| 362 |
+ stream.Close() |
|
| 362 | 363 |
err := stream.Log(&logger.Message{})
|
| 363 | 364 |
assert.Check(t, err != nil) |
| 364 | 365 |
} |
| ... | ... |
@@ -370,7 +371,7 @@ func TestLogBlocking(t *testing.T) {
|
| 370 | 370 |
mockClient := &mockClient{}
|
| 371 | 371 |
stream := &logStream{
|
| 372 | 372 |
client: mockClient, |
| 373 |
- messages: make(chan *logger.Message), |
|
| 373 |
+ messages: loggerutils.NewMessageQueue(0), |
|
| 374 | 374 |
} |
| 375 | 375 |
|
| 376 | 376 |
errorCh := make(chan error, 1) |
| ... | ... |
@@ -387,14 +388,11 @@ func TestLogBlocking(t *testing.T) {
|
| 387 | 387 |
t.Fatal("Expected stream.Log to block: ", err)
|
| 388 | 388 |
default: |
| 389 | 389 |
} |
| 390 |
+ |
|
| 390 | 391 |
// assuming it is blocked, we can now try to drain the internal channel and |
| 391 | 392 |
// unblock it |
| 392 |
- select {
|
|
| 393 |
- case <-time.After(10 * time.Millisecond): |
|
| 394 |
- // if we're unable to drain the channel within 10ms, something seems broken |
|
| 395 |
- t.Fatal("Expected to be able to read from stream.messages but was unable to")
|
|
| 396 |
- case <-stream.messages: |
|
| 397 |
- } |
|
| 393 |
+ <-stream.messages.Receiver() |
|
| 394 |
+ |
|
| 398 | 395 |
select {
|
| 399 | 396 |
case err := <-errorCh: |
| 400 | 397 |
assert.NilError(t, err) |
| ... | ... |
@@ -408,7 +406,7 @@ func TestLogBufferEmpty(t *testing.T) {
|
| 408 | 408 |
mockClient := &mockClient{}
|
| 409 | 409 |
stream := &logStream{
|
| 410 | 410 |
client: mockClient, |
| 411 |
- messages: make(chan *logger.Message, 1), |
|
| 411 |
+ messages: loggerutils.NewMessageQueue(1), |
|
| 412 | 412 |
} |
| 413 | 413 |
err := stream.Log(&logger.Message{})
|
| 414 | 414 |
assert.NilError(t, err) |
| ... | ... |
@@ -556,7 +554,7 @@ func TestCollectBatchSimple(t *testing.T) {
|
| 556 | 556 |
logGroupName: groupName, |
| 557 | 557 |
logStreamName: streamName, |
| 558 | 558 |
sequenceToken: aws.String(sequenceToken), |
| 559 |
- messages: make(chan *logger.Message), |
|
| 559 |
+ messages: loggerutils.NewMessageQueue(0), |
|
| 560 | 560 |
} |
| 561 | 561 |
calls := make([]*cloudwatchlogs.PutLogEventsInput, 0) |
| 562 | 562 |
mockClient.putLogEventsFunc = func(ctx context.Context, input *cloudwatchlogs.PutLogEventsInput, opts ...func(*cloudwatchlogs.Options)) (*cloudwatchlogs.PutLogEventsOutput, error) {
|
| ... | ... |
@@ -575,15 +573,20 @@ func TestCollectBatchSimple(t *testing.T) {
|
| 575 | 575 |
close(d) |
| 576 | 576 |
go stream.collectBatch(d) |
| 577 | 577 |
|
| 578 |
- stream.Log(&logger.Message{
|
|
| 578 |
+ err := stream.Log(&logger.Message{
|
|
| 579 | 579 |
Line: []byte(logline), |
| 580 | 580 |
Timestamp: time.Time{},
|
| 581 | 581 |
}) |
| 582 |
+ assert.NilError(t, err) |
|
| 582 | 583 |
|
| 583 | 584 |
ticks <- time.Time{}
|
| 584 | 585 |
ticks <- time.Time{}
|
| 585 | 586 |
stream.Close() |
| 586 | 587 |
|
| 588 |
+ for len(calls) != 1 {
|
|
| 589 |
+ time.Sleep(10 * time.Millisecond) |
|
| 590 |
+ } |
|
| 591 |
+ |
|
| 587 | 592 |
assert.Assert(t, len(calls) == 1) |
| 588 | 593 |
argument := calls[0] |
| 589 | 594 |
assert.Assert(t, argument != nil) |
| ... | ... |
@@ -598,7 +601,7 @@ func TestCollectBatchTicker(t *testing.T) {
|
| 598 | 598 |
logGroupName: groupName, |
| 599 | 599 |
logStreamName: streamName, |
| 600 | 600 |
sequenceToken: aws.String(sequenceToken), |
| 601 |
- messages: make(chan *logger.Message), |
|
| 601 |
+ messages: loggerutils.NewMessageQueue(0), |
|
| 602 | 602 |
} |
| 603 | 603 |
calls := make([]*cloudwatchlogs.PutLogEventsInput, 0) |
| 604 | 604 |
called := make(chan struct{}, 50)
|
| ... | ... |
@@ -666,7 +669,7 @@ func TestCollectBatchMultilinePattern(t *testing.T) {
|
| 666 | 666 |
logStreamName: streamName, |
| 667 | 667 |
multilinePattern: multilinePattern, |
| 668 | 668 |
sequenceToken: aws.String(sequenceToken), |
| 669 |
- messages: make(chan *logger.Message), |
|
| 669 |
+ messages: loggerutils.NewMessageQueue(0), |
|
| 670 | 670 |
} |
| 671 | 671 |
calls := make([]*cloudwatchlogs.PutLogEventsInput, 0) |
| 672 | 672 |
called := make(chan struct{}, 50)
|
| ... | ... |
@@ -732,7 +735,7 @@ func BenchmarkCollectBatch(b *testing.B) {
|
| 732 | 732 |
logGroupName: groupName, |
| 733 | 733 |
logStreamName: streamName, |
| 734 | 734 |
sequenceToken: aws.String(sequenceToken), |
| 735 |
- messages: make(chan *logger.Message), |
|
| 735 |
+ messages: loggerutils.NewMessageQueue(0), |
|
| 736 | 736 |
} |
| 737 | 737 |
mockClient.putLogEventsFunc = func(ctx context.Context, input *cloudwatchlogs.PutLogEventsInput, opts ...func(*cloudwatchlogs.Options)) (*cloudwatchlogs.PutLogEventsOutput, error) {
|
| 738 | 738 |
return &cloudwatchlogs.PutLogEventsOutput{
|
| ... | ... |
@@ -765,7 +768,7 @@ func BenchmarkCollectBatchMultilinePattern(b *testing.B) {
|
| 765 | 765 |
logStreamName: streamName, |
| 766 | 766 |
multilinePattern: multilinePattern, |
| 767 | 767 |
sequenceToken: aws.String(sequenceToken), |
| 768 |
- messages: make(chan *logger.Message), |
|
| 768 |
+ messages: loggerutils.NewMessageQueue(0), |
|
| 769 | 769 |
} |
| 770 | 770 |
mockClient.putLogEventsFunc = func(ctx context.Context, input *cloudwatchlogs.PutLogEventsInput, opts ...func(*cloudwatchlogs.Options)) (*cloudwatchlogs.PutLogEventsOutput, error) {
|
| 771 | 771 |
return &cloudwatchlogs.PutLogEventsOutput{
|
| ... | ... |
@@ -796,7 +799,7 @@ func TestCollectBatchMultilinePatternMaxEventAge(t *testing.T) {
|
| 796 | 796 |
logStreamName: streamName, |
| 797 | 797 |
multilinePattern: multilinePattern, |
| 798 | 798 |
sequenceToken: aws.String(sequenceToken), |
| 799 |
- messages: make(chan *logger.Message), |
|
| 799 |
+ messages: loggerutils.NewMessageQueue(0), |
|
| 800 | 800 |
} |
| 801 | 801 |
calls := make([]*cloudwatchlogs.PutLogEventsInput, 0) |
| 802 | 802 |
called := make(chan struct{}, 50)
|
| ... | ... |
@@ -870,7 +873,7 @@ func TestCollectBatchMultilinePatternNegativeEventAge(t *testing.T) {
|
| 870 | 870 |
logStreamName: streamName, |
| 871 | 871 |
multilinePattern: multilinePattern, |
| 872 | 872 |
sequenceToken: aws.String(sequenceToken), |
| 873 |
- messages: make(chan *logger.Message), |
|
| 873 |
+ messages: loggerutils.NewMessageQueue(0), |
|
| 874 | 874 |
} |
| 875 | 875 |
calls := make([]*cloudwatchlogs.PutLogEventsInput, 0) |
| 876 | 876 |
called := make(chan struct{}, 50)
|
| ... | ... |
@@ -927,7 +930,7 @@ func TestCollectBatchMultilinePatternMaxEventSize(t *testing.T) {
|
| 927 | 927 |
logStreamName: streamName, |
| 928 | 928 |
multilinePattern: multilinePattern, |
| 929 | 929 |
sequenceToken: aws.String(sequenceToken), |
| 930 |
- messages: make(chan *logger.Message), |
|
| 930 |
+ messages: loggerutils.NewMessageQueue(0), |
|
| 931 | 931 |
} |
| 932 | 932 |
calls := make([]*cloudwatchlogs.PutLogEventsInput, 0) |
| 933 | 933 |
called := make(chan struct{}, 50)
|
| ... | ... |
@@ -987,7 +990,7 @@ func TestCollectBatchClose(t *testing.T) {
|
| 987 | 987 |
logGroupName: groupName, |
| 988 | 988 |
logStreamName: streamName, |
| 989 | 989 |
sequenceToken: aws.String(sequenceToken), |
| 990 |
- messages: make(chan *logger.Message), |
|
| 990 |
+ messages: loggerutils.NewMessageQueue(0), |
|
| 991 | 991 |
} |
| 992 | 992 |
calls := make([]*cloudwatchlogs.PutLogEventsInput, 0) |
| 993 | 993 |
called := make(chan struct{}, 50)
|
| ... | ... |
@@ -1089,7 +1092,7 @@ func TestCollectBatchLineSplit(t *testing.T) {
|
| 1089 | 1089 |
logGroupName: groupName, |
| 1090 | 1090 |
logStreamName: streamName, |
| 1091 | 1091 |
sequenceToken: aws.String(sequenceToken), |
| 1092 |
- messages: make(chan *logger.Message), |
|
| 1092 |
+ messages: loggerutils.NewMessageQueue(0), |
|
| 1093 | 1093 |
} |
| 1094 | 1094 |
calls := make([]*cloudwatchlogs.PutLogEventsInput, 0) |
| 1095 | 1095 |
called := make(chan struct{}, 50)
|
| ... | ... |
@@ -1137,7 +1140,7 @@ func TestCollectBatchLineSplitWithBinary(t *testing.T) {
|
| 1137 | 1137 |
logGroupName: groupName, |
| 1138 | 1138 |
logStreamName: streamName, |
| 1139 | 1139 |
sequenceToken: aws.String(sequenceToken), |
| 1140 |
- messages: make(chan *logger.Message), |
|
| 1140 |
+ messages: loggerutils.NewMessageQueue(0), |
|
| 1141 | 1141 |
} |
| 1142 | 1142 |
calls := make([]*cloudwatchlogs.PutLogEventsInput, 0) |
| 1143 | 1143 |
called := make(chan struct{}, 50)
|
| ... | ... |
@@ -1185,7 +1188,7 @@ func TestCollectBatchMaxEvents(t *testing.T) {
|
| 1185 | 1185 |
logGroupName: groupName, |
| 1186 | 1186 |
logStreamName: streamName, |
| 1187 | 1187 |
sequenceToken: aws.String(sequenceToken), |
| 1188 |
- messages: make(chan *logger.Message), |
|
| 1188 |
+ messages: loggerutils.NewMessageQueue(0), |
|
| 1189 | 1189 |
} |
| 1190 | 1190 |
calls := make([]*cloudwatchlogs.PutLogEventsInput, 0) |
| 1191 | 1191 |
called := make(chan struct{}, 50)
|
| ... | ... |
@@ -1239,7 +1242,7 @@ func TestCollectBatchMaxTotalBytes(t *testing.T) {
|
| 1239 | 1239 |
logGroupName: groupName, |
| 1240 | 1240 |
logStreamName: streamName, |
| 1241 | 1241 |
sequenceToken: aws.String(sequenceToken), |
| 1242 |
- messages: make(chan *logger.Message), |
|
| 1242 |
+ messages: loggerutils.NewMessageQueue(0), |
|
| 1243 | 1243 |
} |
| 1244 | 1244 |
calls := make([]*cloudwatchlogs.PutLogEventsInput, 0) |
| 1245 | 1245 |
called := make(chan struct{}, 50)
|
| ... | ... |
@@ -1320,7 +1323,7 @@ func TestCollectBatchMaxTotalBytesWithBinary(t *testing.T) {
|
| 1320 | 1320 |
logGroupName: groupName, |
| 1321 | 1321 |
logStreamName: streamName, |
| 1322 | 1322 |
sequenceToken: aws.String(sequenceToken), |
| 1323 |
- messages: make(chan *logger.Message), |
|
| 1323 |
+ messages: loggerutils.NewMessageQueue(0), |
|
| 1324 | 1324 |
} |
| 1325 | 1325 |
calls := make([]*cloudwatchlogs.PutLogEventsInput, 0) |
| 1326 | 1326 |
called := make(chan struct{}, 50)
|
| ... | ... |
@@ -1394,7 +1397,7 @@ func TestCollectBatchWithDuplicateTimestamps(t *testing.T) {
|
| 1394 | 1394 |
logGroupName: groupName, |
| 1395 | 1395 |
logStreamName: streamName, |
| 1396 | 1396 |
sequenceToken: aws.String(sequenceToken), |
| 1397 |
- messages: make(chan *logger.Message), |
|
| 1397 |
+ messages: loggerutils.NewMessageQueue(0), |
|
| 1398 | 1398 |
} |
| 1399 | 1399 |
calls := make([]*cloudwatchlogs.PutLogEventsInput, 0) |
| 1400 | 1400 |
called := make(chan struct{}, 50)
|
| 1401 | 1401 |
new file mode 100644 |
| ... | ... |
@@ -0,0 +1,156 @@ |
| 0 |
+package loggerutils |
|
| 1 |
+ |
|
| 2 |
+import ( |
|
| 3 |
+ "context" |
|
| 4 |
+ "sync" |
|
| 5 |
+ |
|
| 6 |
+ "github.com/docker/docker/daemon/logger" |
|
| 7 |
+ "github.com/pkg/errors" |
|
| 8 |
+) |
|
| 9 |
+ |
|
| 10 |
+// MessageQueue is a queue for log messages. |
|
| 11 |
+// |
|
| 12 |
+// [MessageQueue.Enqueue] will block when the queue is full. |
|
| 13 |
+// To dequeue messages call [MessageQueue.Reciever] and pull messsages off the |
|
| 14 |
+// returned channel. |
|
| 15 |
+// |
|
| 16 |
+// Closing only prevents new messages from being added to the queue. |
|
| 17 |
+// The queue can still be drained after close. |
|
| 18 |
+// |
|
| 19 |
+// The zero value of MessageQueue is safe to use, but does not do any internal |
|
| 20 |
+// buffering (queue size is 0). |
|
| 21 |
+type MessageQueue struct {
|
|
| 22 |
+ maxSize int |
|
| 23 |
+ |
|
| 24 |
+ mu sync.Mutex |
|
| 25 |
+ closing bool |
|
| 26 |
+ closed chan struct{}
|
|
| 27 |
+ |
|
| 28 |
+ // Blocks multiple calls to [MessageQueue.Close] until the queue is actually closed |
|
| 29 |
+ closeWait chan struct{}
|
|
| 30 |
+ |
|
| 31 |
+ // We need to be able to safely close the send channel so that [MessageQueue.Dequeue] |
|
| 32 |
+ // can drain the queue without blocking. |
|
| 33 |
+ // This cond var helps deal with that. |
|
| 34 |
+ cond *sync.Cond |
|
| 35 |
+ sendWaiters int |
|
| 36 |
+ |
|
| 37 |
+ ch chan *logger.Message |
|
| 38 |
+} |
|
| 39 |
+ |
|
| 40 |
+// NewMessageQueue creates a new queue with the specified size. |
|
| 41 |
+func NewMessageQueue(maxSize int) *MessageQueue {
|
|
| 42 |
+ var q MessageQueue |
|
| 43 |
+ q.maxSize = maxSize |
|
| 44 |
+ q.init() |
|
| 45 |
+ return &q |
|
| 46 |
+} |
|
| 47 |
+ |
|
| 48 |
+func (q *MessageQueue) init() {
|
|
| 49 |
+ if q.cond == nil {
|
|
| 50 |
+ q.cond = sync.NewCond(&q.mu) |
|
| 51 |
+ } |
|
| 52 |
+ |
|
| 53 |
+ if q.ch == nil {
|
|
| 54 |
+ q.ch = make(chan *logger.Message, q.maxSize) |
|
| 55 |
+ } |
|
| 56 |
+ |
|
| 57 |
+ if q.closed == nil {
|
|
| 58 |
+ q.closed = make(chan struct{})
|
|
| 59 |
+ } |
|
| 60 |
+ |
|
| 61 |
+ if q.closeWait == nil {
|
|
| 62 |
+ q.closeWait = make(chan struct{})
|
|
| 63 |
+ } |
|
| 64 |
+} |
|
| 65 |
+ |
|
| 66 |
+var ErrQueueClosed = errors.New("queue is closed")
|
|
| 67 |
+ |
|
| 68 |
+// Enqueue adds the provided message to the queue. |
|
| 69 |
+// Enqueue blocks if the queue is full. |
|
| 70 |
+// |
|
| 71 |
+// The two possible error cases are: |
|
| 72 |
+// 1. The provided context is cancelled |
|
| 73 |
+// 2. [ErrQueueClosed] when the queue has been closed. |
|
| 74 |
+func (q *MessageQueue) Enqueue(ctx context.Context, m *logger.Message) error {
|
|
| 75 |
+ q.mu.Lock() |
|
| 76 |
+ q.init() |
|
| 77 |
+ |
|
| 78 |
+ // Increment the waiter count |
|
| 79 |
+ // This prevents the send channel from being closed while we are trying to send. |
|
| 80 |
+ q.sendWaiters++ |
|
| 81 |
+ q.mu.Unlock() |
|
| 82 |
+ |
|
| 83 |
+ defer func() {
|
|
| 84 |
+ q.mu.Lock() |
|
| 85 |
+ // Decrement the waiter count and signal to any potential closer to check |
|
| 86 |
+ // the wait count again. |
|
| 87 |
+ // Only bother signaling if this is the last waiter. |
|
| 88 |
+ q.sendWaiters-- |
|
| 89 |
+ if q.sendWaiters == 0 {
|
|
| 90 |
+ q.cond.Signal() |
|
| 91 |
+ } |
|
| 92 |
+ q.mu.Unlock() |
|
| 93 |
+ }() |
|
| 94 |
+ |
|
| 95 |
+ // Before trying to send on the channel, check if we care closed. |
|
| 96 |
+ select {
|
|
| 97 |
+ case <-ctx.Done(): |
|
| 98 |
+ return ctx.Err() |
|
| 99 |
+ case <-q.closed: |
|
| 100 |
+ return ErrQueueClosed |
|
| 101 |
+ default: |
|
| 102 |
+ } |
|
| 103 |
+ |
|
| 104 |
+ select {
|
|
| 105 |
+ case <-ctx.Done(): |
|
| 106 |
+ return ctx.Err() |
|
| 107 |
+ case <-q.closed: |
|
| 108 |
+ return ErrQueueClosed |
|
| 109 |
+ case q.ch <- m: |
|
| 110 |
+ return nil |
|
| 111 |
+ } |
|
| 112 |
+} |
|
| 113 |
+ |
|
| 114 |
+// Close prevents any new messages from being added to the queue. |
|
| 115 |
+func (q *MessageQueue) Close() {
|
|
| 116 |
+ q.mu.Lock() |
|
| 117 |
+ |
|
| 118 |
+ q.init() |
|
| 119 |
+ |
|
| 120 |
+ if q.closing {
|
|
| 121 |
+ // unlock the mutex here so that the goroutine waiting on the cond var can |
|
| 122 |
+ // take the lock when signaled. |
|
| 123 |
+ q.mu.Unlock() |
|
| 124 |
+ <-q.closeWait |
|
| 125 |
+ return |
|
| 126 |
+ } |
|
| 127 |
+ |
|
| 128 |
+ defer q.mu.Unlock() |
|
| 129 |
+ |
|
| 130 |
+ // Prevent multiple Close calls from trying to close things. |
|
| 131 |
+ q.closing = true |
|
| 132 |
+ |
|
| 133 |
+ close(q.closed) |
|
| 134 |
+ |
|
| 135 |
+ // Wait for any senders to finish |
|
| 136 |
+ // Because we closed the channel above, this shouldn't block for a long period. |
|
| 137 |
+ for q.sendWaiters > 0 {
|
|
| 138 |
+ q.cond.Wait() |
|
| 139 |
+ } |
|
| 140 |
+ |
|
| 141 |
+ close(q.ch) |
|
| 142 |
+ close(q.closeWait) |
|
| 143 |
+} |
|
| 144 |
+ |
|
| 145 |
+// Receiver returns a channel that can be used to dequeue messages |
|
| 146 |
+// The channel will be closed when the message queue is closed but may have |
|
| 147 |
+// messages buffered. |
|
| 148 |
+func (q *MessageQueue) Receiver() <-chan *logger.Message {
|
|
| 149 |
+ q.mu.Lock() |
|
| 150 |
+ defer q.mu.Unlock() |
|
| 151 |
+ |
|
| 152 |
+ q.init() |
|
| 153 |
+ |
|
| 154 |
+ return q.ch |
|
| 155 |
+} |
| 0 | 156 |
new file mode 100644 |
| ... | ... |
@@ -0,0 +1,87 @@ |
| 0 |
+package loggerutils |
|
| 1 |
+ |
|
| 2 |
+import ( |
|
| 3 |
+ "context" |
|
| 4 |
+ "testing" |
|
| 5 |
+ "time" |
|
| 6 |
+ |
|
| 7 |
+ "github.com/docker/docker/daemon/logger" |
|
| 8 |
+ "gotest.tools/v3/assert" |
|
| 9 |
+) |
|
| 10 |
+ |
|
| 11 |
+func TestQueue(t *testing.T) {
|
|
| 12 |
+ q := NewMessageQueue(2) |
|
| 13 |
+ msg := &logger.Message{Line: []byte("hello")}
|
|
| 14 |
+ |
|
| 15 |
+ ctx := context.Background() |
|
| 16 |
+ err := q.Enqueue(ctx, msg) |
|
| 17 |
+ assert.Check(t, err) |
|
| 18 |
+ |
|
| 19 |
+ recv := q.Receiver() |
|
| 20 |
+ // These pointer values should be the same |
|
| 21 |
+ assert.Equal(t, msg, <-recv) |
|
| 22 |
+ |
|
| 23 |
+ err = q.Enqueue(ctx, msg) |
|
| 24 |
+ assert.Check(t, err) |
|
| 25 |
+ |
|
| 26 |
+ err = q.Enqueue(ctx, msg) |
|
| 27 |
+ assert.Check(t, err) |
|
| 28 |
+ |
|
| 29 |
+ q.Close() |
|
| 30 |
+ |
|
| 31 |
+ // We have 2 messages in the queue |
|
| 32 |
+ // Even though this is closed, we should get a true value from dequeue twice. |
|
| 33 |
+ assert.Equal(t, msg, <-recv) |
|
| 34 |
+ assert.Equal(t, msg, <-recv) |
|
| 35 |
+ |
|
| 36 |
+ // This should not block and should return false |
|
| 37 |
+ _, more := <-recv |
|
| 38 |
+ assert.Check(t, !more, "expected no more messages in the queue") |
|
| 39 |
+ |
|
| 40 |
+ // Test with unbuffered |
|
| 41 |
+ q = &MessageQueue{}
|
|
| 42 |
+ recv = q.Receiver() |
|
| 43 |
+ |
|
| 44 |
+ chAdd := make(chan error, 1) |
|
| 45 |
+ go func() {
|
|
| 46 |
+ chAdd <- q.Enqueue(ctx, msg) |
|
| 47 |
+ }() |
|
| 48 |
+ |
|
| 49 |
+ assert.Equal(t, msg, <-recv) |
|
| 50 |
+ assert.Assert(t, <-chAdd) |
|
| 51 |
+ |
|
| 52 |
+ ctxC, cancel := context.WithCancel(ctx) |
|
| 53 |
+ cancel() |
|
| 54 |
+ |
|
| 55 |
+ err = q.Enqueue(ctxC, msg) |
|
| 56 |
+ assert.ErrorIs(t, err, context.Canceled) |
|
| 57 |
+ |
|
| 58 |
+ // Test that blocked senders do not cause a panic on close. |
|
| 59 |
+ // This check is useful because the underlying implementation uses channels |
|
| 60 |
+ // with the send channel eventually getting closed when q.Close is called. |
|
| 61 |
+ go func() {
|
|
| 62 |
+ chAdd <- q.Enqueue(ctx, msg) |
|
| 63 |
+ }() |
|
| 64 |
+ |
|
| 65 |
+ // Wait for enqueue to be ready (or as close to ready as it can be) |
|
| 66 |
+ for {
|
|
| 67 |
+ q.mu.Lock() |
|
| 68 |
+ if q.sendWaiters == 1 {
|
|
| 69 |
+ q.mu.Unlock() |
|
| 70 |
+ break |
|
| 71 |
+ } |
|
| 72 |
+ q.mu.Unlock() |
|
| 73 |
+ time.Sleep(time.Millisecond) |
|
| 74 |
+ } |
|
| 75 |
+ |
|
| 76 |
+ q.Close() |
|
| 77 |
+ |
|
| 78 |
+ select {
|
|
| 79 |
+ case <-time.After(5 * time.Second): |
|
| 80 |
+ case err := <-chAdd: |
|
| 81 |
+ assert.ErrorIs(t, err, ErrQueueClosed) |
|
| 82 |
+ } |
|
| 83 |
+ |
|
| 84 |
+ // Double-close should not cause any issues |
|
| 85 |
+ q.Close() |
|
| 86 |
+} |