Browse code

daemon/logger: fix refcounting decompressed files

The refCounter used for sharing temporary decompressed log files and
tracking when the files can be deleted is keyed off the source file's
path. But the path of a log file is not stable: it is renamed on each
rotation. Consequently, when logging is configured with both rotation
and compression, multiple concurrent readers of a container's logs could
read logs out of order, see duplicates or decompress a log file which
has already been decompressed.

Replace refCounter with a new implementation, sharedTempFileConverter,
which is agnostic to the file path, keying off the source file's
identity instead. Additionally, sharedTempFileConverter handles the full
lifecycle of the temporary file, from creation to deletion. This is all
abstracted from the consumer: all the bookkeeping and cleanup is handled
behind the scenes when Close() is called on the returned reader value.
Only one file descriptor is used per temporary file, which is shared by
all readers.

A channel is used for concurrency control so that the lock can be
acquired inside a select statement. While not currently utilized, this
makes it possible to add support for cancellation to
sharedTempFileConverter in the future.

Signed-off-by: Cory Snider <csnider@mirantis.com>

Cory Snider authored on 2022/02/24 03:44:15
Showing 6 changed files
... ...
@@ -1,11 +1,16 @@
1 1
 package jsonfilelog // import "github.com/docker/docker/daemon/logger/jsonfilelog"
2 2
 
3 3
 import (
4
+	"bufio"
4 5
 	"bytes"
5 6
 	"encoding/json"
7
+	"fmt"
6 8
 	"io"
9
+	"os"
7 10
 	"path/filepath"
11
+	"strconv"
8 12
 	"testing"
13
+	"text/tabwriter"
9 14
 	"time"
10 15
 
11 16
 	"github.com/docker/docker/daemon/logger"
... ...
@@ -142,7 +147,8 @@ func TestUnexpectedEOF(t *testing.T) {
142 142
 }
143 143
 
144 144
 func TestReadLogs(t *testing.T) {
145
-	loggertest.Reader{
145
+	t.Parallel()
146
+	r := loggertest.Reader{
146 147
 		Factory: func(t *testing.T, info logger.Info) func(*testing.T) logger.Logger {
147 148
 			dir := t.TempDir()
148 149
 			info.LogPath = filepath.Join(dir, info.ContainerID+".log")
... ...
@@ -152,7 +158,67 @@ func TestReadLogs(t *testing.T) {
152 152
 				return l
153 153
 			}
154 154
 		},
155
-	}.Do(t)
155
+	}
156
+	t.Run("Tail", r.TestTail)
157
+	t.Run("Follow", r.TestFollow)
158
+}
159
+
160
+func TestTailLogsWithRotation(t *testing.T) {
161
+	t.Parallel()
162
+	compress := func(cmprs bool) {
163
+		t.Run(fmt.Sprintf("compress=%v", cmprs), func(t *testing.T) {
164
+			t.Parallel()
165
+			(&loggertest.Reader{
166
+				Factory: func(t *testing.T, info logger.Info) func(*testing.T) logger.Logger {
167
+					info.Config = map[string]string{
168
+						"compress": strconv.FormatBool(cmprs),
169
+						"max-size": "1b",
170
+						"max-file": "10",
171
+					}
172
+					dir := t.TempDir()
173
+					t.Cleanup(func() {
174
+						t.Logf("%s:\n%s", t.Name(), dirStringer{dir})
175
+					})
176
+					info.LogPath = filepath.Join(dir, info.ContainerID+".log")
177
+					return func(t *testing.T) logger.Logger {
178
+						l, err := New(info)
179
+						assert.NilError(t, err)
180
+						return l
181
+					}
182
+				},
183
+			}).TestTail(t)
184
+		})
185
+	}
186
+	compress(true)
187
+	compress(false)
188
+}
189
+
190
+type dirStringer struct {
191
+	d string
192
+}
193
+
194
+func (d dirStringer) String() string {
195
+	ls, err := os.ReadDir(d.d)
196
+	if err != nil {
197
+		return ""
198
+	}
199
+	buf := bytes.NewBuffer(nil)
200
+	tw := tabwriter.NewWriter(buf, 1, 8, 1, '\t', 0)
201
+	buf.WriteString("\n")
202
+
203
+	btw := bufio.NewWriter(tw)
204
+
205
+	for _, entry := range ls {
206
+		fi, err := entry.Info()
207
+		if err != nil {
208
+			return ""
209
+		}
210
+
211
+		btw.WriteString(fmt.Sprintf("%s\t%s\t%dB\t%s\n", fi.Name(), fi.Mode(), fi.Size(), fi.ModTime()))
212
+	}
213
+	btw.Flush()
214
+	tw.Flush()
215
+	return buf.String()
156 216
 }
157 217
 
158 218
 type readerWithErr struct {
... ...
@@ -79,7 +79,7 @@ func TestWriteLog(t *testing.T) {
79 79
 }
80 80
 
81 81
 func TestReadLog(t *testing.T) {
82
-	loggertest.Reader{
82
+	r := loggertest.Reader{
83 83
 		Factory: func(t *testing.T, info logger.Info) func(*testing.T) logger.Logger {
84 84
 			dir := t.TempDir()
85 85
 			info.LogPath = filepath.Join(dir, info.ContainerID+".log")
... ...
@@ -89,7 +89,9 @@ func TestReadLog(t *testing.T) {
89 89
 				return l
90 90
 			}
91 91
 		},
92
-	}.Do(t)
92
+	}
93
+	t.Run("Tail", r.TestTail)
94
+	t.Run("Follow", r.TestFollow)
93 95
 }
94 96
 
95 97
 func BenchmarkLogWrite(b *testing.B) {
... ...
@@ -1,6 +1,7 @@
1 1
 package loggertest // import "github.com/docker/docker/daemon/logger/loggertest"
2 2
 
3 3
 import (
4
+	"runtime"
4 5
 	"strings"
5 6
 	"testing"
6 7
 	"time"
... ...
@@ -29,13 +30,12 @@ var compareLog cmp.Options = []cmp.Option{
29 29
 	cmp.Transformer("string", func(b []byte) string { return string(b) }),
30 30
 }
31 31
 
32
-// Do tests the behavior of the LogReader implementation.
33
-func (tr Reader) Do(t *testing.T) {
34
-	t.Run("Live/Tail", func(t *testing.T) { tr.testTail(t, true) })
35
-	t.Run("Live/TailEmpty", func(t *testing.T) { tr.testTailEmptyLogs(t, true) })
36
-	t.Run("Live/Follow", tr.testFollow)
37
-	t.Run("Stopped/Tail", func(t *testing.T) { tr.testTail(t, false) })
38
-	t.Run("Stopped/TailEmpty", func(t *testing.T) { tr.testTailEmptyLogs(t, false) })
32
+// TestTail tests the behavior of the LogReader's tail implementation.
33
+func (tr Reader) TestTail(t *testing.T) {
34
+	t.Run("Live", func(t *testing.T) { tr.testTail(t, true) })
35
+	t.Run("LiveEmpty", func(t *testing.T) { tr.testTailEmptyLogs(t, true) })
36
+	t.Run("Stopped", func(t *testing.T) { tr.testTail(t, false) })
37
+	t.Run("StoppedEmpty", func(t *testing.T) { tr.testTailEmptyLogs(t, false) })
39 38
 }
40 39
 
41 40
 func makeTestMessages() []*logger.Message {
... ...
@@ -170,8 +170,11 @@ func (tr Reader) testTailEmptyLogs(t *testing.T, live bool) {
170 170
 	}
171 171
 }
172 172
 
173
-func (tr Reader) testFollow(t *testing.T) {
174
-	t.Parallel()
173
+// TestFollow tests the LogReader's follow implementation.
174
+//
175
+// The LogReader is expected to be able to follow an arbitrary number of
176
+// messages at a high rate with no dropped messages.
177
+func (tr Reader) TestFollow(t *testing.T) {
175 178
 	// Reader sends all logs and closes after logger is closed
176 179
 	// - Starting from empty log (like run)
177 180
 	t.Run("FromEmptyLog", func(t *testing.T) {
... ...
@@ -390,6 +393,7 @@ func logMessages(t *testing.T, l logger.Logger, messages []*logger.Message) []*l
390 390
 		// Copy the log message because the underlying log writer resets
391 391
 		// the log message and returns it to a buffer pool.
392 392
 		assert.NilError(t, l.Log(copyLogMessage(m)))
393
+		runtime.Gosched()
393 394
 
394 395
 		// Copy the log message again so as not to mutate the input.
395 396
 		expect := copyLogMessage(m)
... ...
@@ -437,6 +441,9 @@ func readMessage(t *testing.T, lw *logger.LogWatcher) *logger.Message {
437 437
 			default:
438 438
 			}
439 439
 		}
440
+		if msg != nil {
441
+			t.Logf("loggertest: ReadMessage [%v %v] %s", msg.Source, msg.Timestamp, msg.Line)
442
+		}
440 443
 		return msg
441 444
 	}
442 445
 }
... ...
@@ -10,7 +10,6 @@ import (
10 10
 	"os"
11 11
 	"runtime"
12 12
 	"strconv"
13
-	"strings"
14 13
 	"sync"
15 14
 	"time"
16 15
 
... ...
@@ -22,77 +21,29 @@ import (
22 22
 	"github.com/sirupsen/logrus"
23 23
 )
24 24
 
25
-const tmpLogfileSuffix = ".tmp"
26
-
27 25
 // rotateFileMetadata is a metadata of the gzip header of the compressed log file
28 26
 type rotateFileMetadata struct {
29 27
 	LastTime time.Time `json:"lastTime,omitempty"`
30 28
 }
31 29
 
32
-// refCounter is a counter of logfile being referenced
33
-type refCounter struct {
34
-	mu      sync.Mutex
35
-	counter map[string]int
36
-}
37
-
38
-// Reference increase the reference counter for specified logfile
39
-func (rc *refCounter) GetReference(fileName string, openRefFile func(fileName string, exists bool) (*os.File, error)) (*os.File, error) {
40
-	rc.mu.Lock()
41
-	defer rc.mu.Unlock()
42
-
43
-	var (
44
-		file *os.File
45
-		err  error
46
-	)
47
-	_, ok := rc.counter[fileName]
48
-	file, err = openRefFile(fileName, ok)
49
-	if err != nil {
50
-		return nil, err
51
-	}
52
-
53
-	if ok {
54
-		rc.counter[fileName]++
55
-	} else if file != nil {
56
-		rc.counter[file.Name()] = 1
57
-	}
58
-
59
-	return file, nil
60
-}
61
-
62
-// Dereference reduce the reference counter for specified logfile
63
-func (rc *refCounter) Dereference(fileName string) error {
64
-	rc.mu.Lock()
65
-	defer rc.mu.Unlock()
66
-
67
-	rc.counter[fileName]--
68
-	if rc.counter[fileName] <= 0 {
69
-		delete(rc.counter, fileName)
70
-		err := unlink(fileName)
71
-		if err != nil && !errors.Is(err, fs.ErrNotExist) {
72
-			return err
73
-		}
74
-	}
75
-	return nil
76
-}
77
-
78 30
 // LogFile is Logger implementation for default Docker logging.
79 31
 type LogFile struct {
80
-	mu              sync.RWMutex // protects the logfile access
81
-	f               *os.File     // store for closing
82
-	closed          bool
83
-	closedCh        chan struct{}
84
-	rotateMu        sync.Mutex // blocks the next rotation until the current rotation is completed
85
-	capacity        int64      // maximum size of each file
86
-	currentSize     int64      // current size of the latest file
87
-	maxFiles        int        // maximum number of files
88
-	compress        bool       // whether old versions of log files are compressed
89
-	lastTimestamp   time.Time  // timestamp of the last log
90
-	filesRefCounter refCounter // keep reference-counted of decompressed files
91
-	notifyReaders   *pubsub.Publisher
92
-	marshal         logger.MarshalFunc
93
-	createDecoder   MakeDecoderFn
94
-	getTailReader   GetTailReaderFunc
95
-	perms           os.FileMode
32
+	mu            sync.RWMutex // protects the logfile access
33
+	f             *os.File     // store for closing
34
+	closed        bool
35
+	closedCh      chan struct{}
36
+	rotateMu      sync.Mutex               // blocks the next rotation until the current rotation is completed
37
+	capacity      int64                    // maximum size of each file
38
+	currentSize   int64                    // current size of the latest file
39
+	maxFiles      int                      // maximum number of files
40
+	compress      bool                     // whether old versions of log files are compressed
41
+	lastTimestamp time.Time                // timestamp of the last log
42
+	decompress    *sharedTempFileConverter // keep reference-counted decompressed files
43
+	notifyReaders *pubsub.Publisher
44
+	marshal       logger.MarshalFunc
45
+	createDecoder MakeDecoderFn
46
+	getTailReader GetTailReaderFunc
47
+	perms         os.FileMode
96 48
 }
97 49
 
98 50
 // MakeDecoderFn creates a decoder
... ...
@@ -113,10 +64,16 @@ type Decoder interface {
113 113
 // SizeReaderAt defines a ReaderAt that also reports its size.
114 114
 // This is used for tailing log files.
115 115
 type SizeReaderAt interface {
116
+	io.Reader
116 117
 	io.ReaderAt
117 118
 	Size() int64
118 119
 }
119 120
 
121
+type readAtCloser interface {
122
+	io.ReaderAt
123
+	io.Closer
124
+}
125
+
120 126
 // GetTailReaderFunc is used to truncate a reader to only read as much as is required
121 127
 // in order to get the passed in number of log lines.
122 128
 // It returns the sectioned reader, the number of lines that the section reader
... ...
@@ -136,18 +93,18 @@ func NewLogFile(logPath string, capacity int64, maxFiles int, compress bool, mar
136 136
 	}
137 137
 
138 138
 	return &LogFile{
139
-		f:               log,
140
-		closedCh:        make(chan struct{}),
141
-		capacity:        capacity,
142
-		currentSize:     size,
143
-		maxFiles:        maxFiles,
144
-		compress:        compress,
145
-		filesRefCounter: refCounter{counter: make(map[string]int)},
146
-		notifyReaders:   pubsub.NewPublisher(0, 1),
147
-		marshal:         marshaller,
148
-		createDecoder:   decodeFunc,
149
-		perms:           perms,
150
-		getTailReader:   getTailReader,
139
+		f:             log,
140
+		closedCh:      make(chan struct{}),
141
+		capacity:      capacity,
142
+		currentSize:   size,
143
+		maxFiles:      maxFiles,
144
+		compress:      compress,
145
+		decompress:    newSharedTempFileConverter(decompress),
146
+		notifyReaders: pubsub.NewPublisher(0, 1),
147
+		marshal:       marshaller,
148
+		createDecoder: decodeFunc,
149
+		perms:         perms,
150
+		getTailReader: getTailReader,
151 151
 	}, nil
152 152
 }
153 153
 
... ...
@@ -411,25 +368,25 @@ func (w *LogFile) readLogsLocked(config logger.ReadConfig, watcher *logger.LogWa
411 411
 		closeFiles := func() {
412 412
 			for _, f := range files {
413 413
 				f.Close()
414
-				fileName := f.Name()
415
-				if strings.HasSuffix(fileName, tmpLogfileSuffix) {
416
-					err := w.filesRefCounter.Dereference(fileName)
417
-					if err != nil {
418
-						logrus.WithError(err).WithField("file", fileName).Error("Failed to dereference the log file")
419
-					}
420
-				}
421 414
 			}
422 415
 		}
423 416
 
424 417
 		readers := make([]SizeReaderAt, 0, len(files)+1)
425 418
 		for _, f := range files {
426
-			stat, err := f.Stat()
427
-			if err != nil {
428
-				watcher.Err <- errors.Wrap(err, "error reading size of rotated file")
429
-				closeFiles()
430
-				return
419
+			switch ff := f.(type) {
420
+			case SizeReaderAt:
421
+				readers = append(readers, ff)
422
+			case interface{ Stat() (fs.FileInfo, error) }:
423
+				stat, err := ff.Stat()
424
+				if err != nil {
425
+					watcher.Err <- errors.Wrap(err, "error reading size of rotated file")
426
+					closeFiles()
427
+					return
428
+				}
429
+				readers = append(readers, io.NewSectionReader(f, 0, stat.Size()))
430
+			default:
431
+				panic(fmt.Errorf("rotated file value %#v (%[1]T) has neither Size() nor Stat() methods", f))
431 432
 			}
432
-			readers = append(readers, io.NewSectionReader(f, 0, stat.Size()))
433 433
 		}
434 434
 		if currentChunk.Size() > 0 {
435 435
 			readers = append(readers, currentChunk)
... ...
@@ -457,7 +414,8 @@ func (w *LogFile) readLogsLocked(config logger.ReadConfig, watcher *logger.LogWa
457 457
 	followLogs(currentFile, watcher, w.closedCh, notifyRotate, notifyEvict, dec, config.Since, config.Until)
458 458
 }
459 459
 
460
-func (w *LogFile) openRotatedFiles(config logger.ReadConfig) (files []*os.File, err error) {
460
+// openRotatedFiles returns a slice of files open for reading, in order from oldest to newest.
461
+func (w *LogFile) openRotatedFiles(config logger.ReadConfig) (files []readAtCloser, err error) {
461 462
 	w.rotateMu.Lock()
462 463
 	defer w.rotateMu.Unlock()
463 464
 
... ...
@@ -467,44 +425,27 @@ func (w *LogFile) openRotatedFiles(config logger.ReadConfig) (files []*os.File,
467 467
 		}
468 468
 		for _, f := range files {
469 469
 			f.Close()
470
-			if strings.HasSuffix(f.Name(), tmpLogfileSuffix) {
471
-				err := unlink(f.Name())
472
-				if err != nil && !errors.Is(err, fs.ErrNotExist) {
473
-					logrus.Warnf("Failed to remove logfile: %v", err)
474
-				}
475
-			}
476 470
 		}
477 471
 	}()
478 472
 
479 473
 	for i := w.maxFiles; i > 1; i-- {
480
-		f, err := open(fmt.Sprintf("%s.%d", w.f.Name(), i-1))
474
+		var f readAtCloser
475
+		f, err = open(fmt.Sprintf("%s.%d", w.f.Name(), i-1))
481 476
 		if err != nil {
482 477
 			if !errors.Is(err, fs.ErrNotExist) {
483 478
 				return nil, errors.Wrap(err, "error opening rotated log file")
484 479
 			}
485 480
 
486
-			fileName := fmt.Sprintf("%s.%d.gz", w.f.Name(), i-1)
487
-			decompressedFileName := fileName + tmpLogfileSuffix
488
-			tmpFile, err := w.filesRefCounter.GetReference(decompressedFileName, func(refFileName string, exists bool) (*os.File, error) {
489
-				if exists {
490
-					return open(refFileName)
491
-				}
492
-				return decompressfile(fileName, refFileName, config.Since)
493
-			})
494
-
481
+			f, err = w.maybeDecompressFile(fmt.Sprintf("%s.%d.gz", w.f.Name(), i-1), config)
495 482
 			if err != nil {
496 483
 				if !errors.Is(err, fs.ErrNotExist) {
497
-					return nil, errors.Wrap(err, "error getting reference to decompressed log file")
484
+					return nil, err
498 485
 				}
499 486
 				continue
500
-			}
501
-			if tmpFile == nil {
487
+			} else if f == nil {
502 488
 				// Logs from before `config.Since` do not need to be read
503
-				break
489
+				continue
504 490
 			}
505
-
506
-			files = append(files, tmpFile)
507
-			continue
508 491
 		}
509 492
 		files = append(files, f)
510 493
 	}
... ...
@@ -512,7 +453,7 @@ func (w *LogFile) openRotatedFiles(config logger.ReadConfig) (files []*os.File,
512 512
 	return files, nil
513 513
 }
514 514
 
515
-func decompressfile(fileName, destFileName string, since time.Time) (*os.File, error) {
515
+func (w *LogFile) maybeDecompressFile(fileName string, config logger.ReadConfig) (readAtCloser, error) {
516 516
 	cf, err := open(fileName)
517 517
 	if err != nil {
518 518
 		return nil, errors.Wrap(err, "error opening file for decompression")
... ...
@@ -528,26 +469,26 @@ func decompressfile(fileName, destFileName string, since time.Time) (*os.File, e
528 528
 	// Extract the last log entry timestramp from the gzip header
529 529
 	extra := &rotateFileMetadata{}
530 530
 	err = json.Unmarshal(rc.Header.Extra, extra)
531
-	if err == nil && extra.LastTime.Before(since) {
531
+	if err == nil && !extra.LastTime.IsZero() && extra.LastTime.Before(config.Since) {
532 532
 		return nil, nil
533 533
 	}
534
+	tmpf, err := w.decompress.Do(cf)
535
+	return tmpf, errors.Wrap(err, "error decompressing log file")
536
+}
534 537
 
535
-	rs, err := openFile(destFileName, os.O_CREATE|os.O_RDWR, 0640)
538
+func decompress(dst io.WriteSeeker, src io.ReadSeeker) error {
539
+	if _, err := src.Seek(0, io.SeekStart); err != nil {
540
+		return err
541
+	}
542
+	rc, err := gzip.NewReader(src)
536 543
 	if err != nil {
537
-		return nil, errors.Wrap(err, "error creating file for copying decompressed log stream")
544
+		return err
538 545
 	}
539
-
540
-	_, err = pools.Copy(rs, rc)
546
+	_, err = pools.Copy(dst, rc)
541 547
 	if err != nil {
542
-		rs.Close()
543
-		rErr := unlink(rs.Name())
544
-		if rErr != nil && !errors.Is(rErr, fs.ErrNotExist) {
545
-			logrus.Errorf("Failed to remove logfile: %v", rErr)
546
-		}
547
-		return nil, errors.Wrap(err, "error while copying decompressed log stream to file")
548
+		return err
548 549
 	}
549
-
550
-	return rs, nil
550
+	return rc.Close()
551 551
 }
552 552
 
553 553
 func newSectionReader(f *os.File) (*io.SectionReader, error) {
... ...
@@ -597,7 +538,7 @@ func tailFiles(files []SizeReaderAt, watcher *logger.LogWatcher, dec Decoder, ge
597 597
 		}
598 598
 	} else {
599 599
 		for _, r := range files {
600
-			readers = append(readers, &wrappedReaderAt{ReaderAt: r})
600
+			readers = append(readers, r)
601 601
 		}
602 602
 	}
603 603
 
... ...
@@ -663,14 +604,3 @@ func watchFile(name string) (filenotify.FileWatcher, error) {
663 663
 
664 664
 	return fileWatcher, nil
665 665
 }
666
-
667
-type wrappedReaderAt struct {
668
-	io.ReaderAt
669
-	pos int64
670
-}
671
-
672
-func (r *wrappedReaderAt) Read(p []byte) (int, error) {
673
-	n, err := r.ReaderAt.ReadAt(p, r.pos)
674
-	r.pos += int64(n)
675
-	return n, err
676
-}
677 666
new file mode 100644
... ...
@@ -0,0 +1,227 @@
0
+package loggerutils // import "github.com/docker/docker/daemon/logger/loggerutils"
1
+
2
+import (
3
+	"io"
4
+	"io/fs"
5
+	"os"
6
+	"runtime"
7
+)
8
+
9
+type fileConvertFn func(dst io.WriteSeeker, src io.ReadSeeker) error
10
+
11
+type stfID uint64
12
+
13
+// sharedTempFileConverter converts files using a user-supplied function and
14
+// writes the results to temporary files which are automatically cleaned up on
15
+// close. If another request is made to convert the same file, the conversion
16
+// result and temporary file are reused if they have not yet been cleaned up.
17
+//
18
+// A file is considered the same as another file using the os.SameFile function,
19
+// which compares file identity (e.g. device and inode numbers on Linux) and is
20
+// robust to file renames. Input files are assumed to be immutable; no attempt
21
+// is made to ascertain whether the file contents have changed between requests.
22
+//
23
+// One file descriptor is used per source file, irrespective of the number of
24
+// concurrent readers of the converted contents.
25
+type sharedTempFileConverter struct {
26
+	// The directory where temporary converted files are to be written to.
27
+	// If set to the empty string, the default directory for temporary files
28
+	// is used.
29
+	TempDir string
30
+
31
+	conv fileConvertFn
32
+	st   chan stfcState
33
+}
34
+
35
+type stfcState struct {
36
+	fl     map[stfID]sharedTempFile
37
+	nextID stfID
38
+}
39
+
40
+type sharedTempFile struct {
41
+	src  os.FileInfo // Info about the source file for path-independent identification with os.SameFile.
42
+	fd   *os.File
43
+	size int64
44
+	ref  int                       // Reference count of open readers on the temporary file.
45
+	wait []chan<- stfConvertResult // Wait list for the conversion to complete.
46
+}
47
+
48
+type stfConvertResult struct {
49
+	fr  *sharedFileReader
50
+	err error
51
+}
52
+
53
+func newSharedTempFileConverter(conv fileConvertFn) *sharedTempFileConverter {
54
+	st := make(chan stfcState, 1)
55
+	st <- stfcState{fl: make(map[stfID]sharedTempFile)}
56
+	return &sharedTempFileConverter{conv: conv, st: st}
57
+}
58
+
59
+// Do returns a reader for the contents of f as converted by the c.conv function.
60
+// It is the caller's responsibility to close the returned reader.
61
+//
62
+// This function is safe for concurrent use by multiple goroutines.
63
+func (c *sharedTempFileConverter) Do(f *os.File) (*sharedFileReader, error) {
64
+	stat, err := f.Stat()
65
+	if err != nil {
66
+		return nil, err
67
+	}
68
+
69
+	st := <-c.st
70
+	for id, tf := range st.fl {
71
+		// os.SameFile can have false positives if one of the files was
72
+		// deleted before the other file was created -- such as during
73
+		// log rotations... https://github.com/golang/go/issues/36895
74
+		// Weed out those false positives by also comparing the files'
75
+		// ModTime, which conveniently also handles the case of true
76
+		// positives where the file has also been modified since it was
77
+		// first converted.
78
+		if os.SameFile(tf.src, stat) && tf.src.ModTime() == stat.ModTime() {
79
+			return c.openExisting(st, id, tf)
80
+		}
81
+	}
82
+	return c.openNew(st, f, stat)
83
+}
84
+
85
+func (c *sharedTempFileConverter) openNew(st stfcState, f *os.File, stat os.FileInfo) (*sharedFileReader, error) {
86
+	// Record that we are starting to convert this file so that any other
87
+	// requests for the same source file while the conversion is in progress
88
+	// can join.
89
+	id := st.nextID
90
+	st.nextID++
91
+	st.fl[id] = sharedTempFile{src: stat}
92
+	c.st <- st
93
+
94
+	dst, size, convErr := c.convert(f)
95
+
96
+	st = <-c.st
97
+	flid := st.fl[id]
98
+
99
+	if convErr != nil {
100
+		// Conversion failed. Delete it from the state so that future
101
+		// requests to convert the same file can try again fresh.
102
+		delete(st.fl, id)
103
+		c.st <- st
104
+		for _, w := range flid.wait {
105
+			w <- stfConvertResult{err: convErr}
106
+		}
107
+		return nil, convErr
108
+	}
109
+
110
+	flid.fd = dst
111
+	flid.size = size
112
+	flid.ref = len(flid.wait) + 1
113
+	for _, w := range flid.wait {
114
+		// Each waiter needs its own reader with an independent read pointer.
115
+		w <- stfConvertResult{fr: flid.Reader(c, id)}
116
+	}
117
+	flid.wait = nil
118
+	st.fl[id] = flid
119
+	c.st <- st
120
+	return flid.Reader(c, id), nil
121
+}
122
+
123
+func (c *sharedTempFileConverter) openExisting(st stfcState, id stfID, v sharedTempFile) (*sharedFileReader, error) {
124
+	if v.fd != nil {
125
+		// Already converted.
126
+		v.ref++
127
+		st.fl[id] = v
128
+		c.st <- st
129
+		return v.Reader(c, id), nil
130
+	}
131
+	// The file has not finished being converted.
132
+	// Add ourselves to the wait list. "Don't call us; we'll call you."
133
+	wait := make(chan stfConvertResult, 1)
134
+	v.wait = append(v.wait, wait)
135
+	st.fl[id] = v
136
+	c.st <- st
137
+
138
+	res := <-wait
139
+	return res.fr, res.err
140
+
141
+}
142
+
143
+func (c *sharedTempFileConverter) convert(f *os.File) (converted *os.File, size int64, err error) {
144
+	dst, err := os.CreateTemp(c.TempDir, "dockerdtemp.*")
145
+	if err != nil {
146
+		return nil, 0, err
147
+	}
148
+	defer func() {
149
+		_ = dst.Close()
150
+		// Delete the temporary file immediately so that final cleanup
151
+		// of the file on disk is deferred to the OS once we close all
152
+		// our file descriptors (or the process dies). Assuming no early
153
+		// returns due to errors, the file will be open by this process
154
+		// with a read-only descriptor at this point. As we don't care
155
+		// about being able to reuse the file name -- it's randomly
156
+		// generated and unique -- we can safely use os.Remove on
157
+		// Windows.
158
+		_ = os.Remove(dst.Name())
159
+	}()
160
+	err = c.conv(dst, f)
161
+	if err != nil {
162
+		return nil, 0, err
163
+	}
164
+	// Close the exclusive read-write file descriptor, catching any delayed
165
+	// write errors (and on Windows, releasing the share-locks on the file)
166
+	if err := dst.Close(); err != nil {
167
+		_ = os.Remove(dst.Name())
168
+		return nil, 0, err
169
+	}
170
+	// Open the file again read-only (without locking the file against
171
+	// deletion on Windows).
172
+	converted, err = open(dst.Name())
173
+	if err != nil {
174
+		return nil, 0, err
175
+	}
176
+
177
+	// The position of the file's read pointer doesn't matter as all readers
178
+	// will be accessing the file through its io.ReaderAt interface.
179
+	size, err = converted.Seek(0, io.SeekEnd)
180
+	if err != nil {
181
+		_ = converted.Close()
182
+		return nil, 0, err
183
+	}
184
+	return converted, size, nil
185
+}
186
+
187
+type sharedFileReader struct {
188
+	*io.SectionReader
189
+
190
+	c      *sharedTempFileConverter
191
+	id     stfID
192
+	closed bool
193
+}
194
+
195
+func (stf sharedTempFile) Reader(c *sharedTempFileConverter, id stfID) *sharedFileReader {
196
+	rdr := &sharedFileReader{SectionReader: io.NewSectionReader(stf.fd, 0, stf.size), c: c, id: id}
197
+	runtime.SetFinalizer(rdr, (*sharedFileReader).Close)
198
+	return rdr
199
+}
200
+
201
+func (r *sharedFileReader) Close() error {
202
+	if r.closed {
203
+		return fs.ErrClosed
204
+	}
205
+
206
+	st := <-r.c.st
207
+	flid, ok := st.fl[r.id]
208
+	if !ok {
209
+		panic("invariant violation: temp file state missing from map")
210
+	}
211
+	flid.ref--
212
+	lastRef := flid.ref <= 0
213
+	if lastRef {
214
+		delete(st.fl, r.id)
215
+	} else {
216
+		st.fl[r.id] = flid
217
+	}
218
+	r.closed = true
219
+	r.c.st <- st
220
+
221
+	if lastRef {
222
+		return flid.fd.Close()
223
+	}
224
+	runtime.SetFinalizer(r, nil)
225
+	return nil
226
+}
0 227
new file mode 100644
... ...
@@ -0,0 +1,256 @@
0
+package loggerutils // import "github.com/docker/docker/daemon/logger/loggerutils"
1
+
2
+import (
3
+	"io"
4
+	"io/fs"
5
+	"os"
6
+	"path/filepath"
7
+	"runtime"
8
+	"strings"
9
+	"sync"
10
+	"sync/atomic"
11
+	"testing"
12
+	"time"
13
+
14
+	"github.com/pkg/errors"
15
+	"gotest.tools/v3/assert"
16
+	"gotest.tools/v3/assert/cmp"
17
+)
18
+
19
+func TestSharedTempFileConverter(t *testing.T) {
20
+	t.Parallel()
21
+
22
+	t.Run("OneReaderAtATime", func(t *testing.T) {
23
+		t.Parallel()
24
+		dir := t.TempDir()
25
+		name := filepath.Join(dir, "test.txt")
26
+		createFile(t, name, "hello, world!")
27
+
28
+		uut := newSharedTempFileConverter(copyTransform(strings.ToUpper))
29
+		uut.TempDir = dir
30
+
31
+		for i := 0; i < 3; i++ {
32
+			t.Logf("Iteration %v", i)
33
+
34
+			rdr := convertPath(t, uut, name)
35
+			assert.Check(t, cmp.Equal("HELLO, WORLD!", readAll(t, rdr)))
36
+			assert.Check(t, rdr.Close())
37
+			assert.Check(t, cmp.Equal(fs.ErrClosed, rdr.Close()), "closing an already-closed reader should return an error")
38
+		}
39
+
40
+		assert.NilError(t, os.Remove(name))
41
+		checkDirEmpty(t, dir)
42
+	})
43
+
44
+	t.Run("RobustToRenames", func(t *testing.T) {
45
+		t.Parallel()
46
+		dir := t.TempDir()
47
+		apath := filepath.Join(dir, "test.txt")
48
+		createFile(t, apath, "file a")
49
+
50
+		var conversions int
51
+		uut := newSharedTempFileConverter(
52
+			func(dst io.WriteSeeker, src io.ReadSeeker) error {
53
+				conversions++
54
+				return copyTransform(strings.ToUpper)(dst, src)
55
+			},
56
+		)
57
+		uut.TempDir = dir
58
+
59
+		ra1 := convertPath(t, uut, apath)
60
+
61
+		// Rotate the file to a new name and write a new file in its place.
62
+		bpath := apath
63
+		apath = filepath.Join(dir, "test2.txt")
64
+		assert.NilError(t, os.Rename(bpath, apath))
65
+		createFile(t, bpath, "file b")
66
+
67
+		rb1 := convertPath(t, uut, bpath) // Same path, different file.
68
+		ra2 := convertPath(t, uut, apath) // New path, old file.
69
+		assert.Check(t, cmp.Equal(2, conversions), "expected only one conversion per unique file")
70
+
71
+		// Interleave reading and closing to shake out ref-counting bugs:
72
+		// closing one reader shouldn't affect any other open readers.
73
+		assert.Check(t, cmp.Equal("FILE A", readAll(t, ra1)))
74
+		assert.NilError(t, ra1.Close())
75
+		assert.Check(t, cmp.Equal("FILE A", readAll(t, ra2)))
76
+		assert.NilError(t, ra2.Close())
77
+		assert.Check(t, cmp.Equal("FILE B", readAll(t, rb1)))
78
+		assert.NilError(t, rb1.Close())
79
+
80
+		assert.NilError(t, os.Remove(apath))
81
+		assert.NilError(t, os.Remove(bpath))
82
+		checkDirEmpty(t, dir)
83
+	})
84
+
85
+	t.Run("ConcurrentRequests", func(t *testing.T) {
86
+		t.Parallel()
87
+		dir := t.TempDir()
88
+		name := filepath.Join(dir, "test.txt")
89
+		createFile(t, name, "hi there")
90
+
91
+		var conversions int32
92
+		notify := make(chan chan struct{}, 1)
93
+		firstConversionStarted := make(chan struct{})
94
+		notify <- firstConversionStarted
95
+		unblock := make(chan struct{})
96
+		uut := newSharedTempFileConverter(
97
+			func(dst io.WriteSeeker, src io.ReadSeeker) error {
98
+				t.Log("Convert: enter")
99
+				defer t.Log("Convert: exit")
100
+				select {
101
+				case c := <-notify:
102
+					close(c)
103
+				default:
104
+				}
105
+				<-unblock
106
+				atomic.AddInt32(&conversions, 1)
107
+				return copyTransform(strings.ToUpper)(dst, src)
108
+			},
109
+		)
110
+		uut.TempDir = dir
111
+
112
+		closers := make(chan io.Closer, 4)
113
+		var wg sync.WaitGroup
114
+		wg.Add(3)
115
+		for i := 0; i < 3; i++ {
116
+			i := i
117
+			go func() {
118
+				defer wg.Done()
119
+				t.Logf("goroutine %v: enter", i)
120
+				defer t.Logf("goroutine %v: exit", i)
121
+				f := convertPath(t, uut, name)
122
+				assert.Check(t, cmp.Equal("HI THERE", readAll(t, f)), "in goroutine %v", i)
123
+				closers <- f
124
+			}()
125
+		}
126
+
127
+		select {
128
+		case <-firstConversionStarted:
129
+		case <-time.After(2 * time.Second):
130
+			t.Fatal("the first conversion should have started by now")
131
+		}
132
+		close(unblock)
133
+		t.Log("starting wait")
134
+		wg.Wait()
135
+		t.Log("wait done")
136
+
137
+		f := convertPath(t, uut, name)
138
+		closers <- f
139
+		close(closers)
140
+		assert.Check(t, cmp.Equal("HI THERE", readAll(t, f)), "after all goroutines returned")
141
+		for c := range closers {
142
+			assert.Check(t, c.Close())
143
+		}
144
+
145
+		assert.Check(t, cmp.Equal(int32(1), conversions))
146
+
147
+		assert.NilError(t, os.Remove(name))
148
+		checkDirEmpty(t, dir)
149
+	})
150
+
151
+	t.Run("ConvertError", func(t *testing.T) {
152
+		t.Parallel()
153
+		dir := t.TempDir()
154
+		name := filepath.Join(dir, "test.txt")
155
+		createFile(t, name, "hi there")
156
+		src, err := open(name)
157
+		assert.NilError(t, err)
158
+		defer src.Close()
159
+
160
+		fakeErr := errors.New("fake error")
161
+		var start sync.WaitGroup
162
+		start.Add(3)
163
+		uut := newSharedTempFileConverter(
164
+			func(dst io.WriteSeeker, src io.ReadSeeker) error {
165
+				start.Wait()
166
+				runtime.Gosched()
167
+				if fakeErr != nil {
168
+					return fakeErr
169
+				}
170
+				return copyTransform(strings.ToUpper)(dst, src)
171
+			},
172
+		)
173
+		uut.TempDir = dir
174
+
175
+		var done sync.WaitGroup
176
+		done.Add(3)
177
+		for i := 0; i < 3; i++ {
178
+			i := i
179
+			go func() {
180
+				defer done.Done()
181
+				t.Logf("goroutine %v: enter", i)
182
+				defer t.Logf("goroutine %v: exit", i)
183
+				start.Done()
184
+				_, err := uut.Do(src)
185
+				assert.Check(t, errors.Is(err, fakeErr), "in goroutine %v", i)
186
+			}()
187
+		}
188
+		done.Wait()
189
+
190
+		// Conversion errors should not be "sticky". A subsequent
191
+		// request should retry from scratch.
192
+		fakeErr = errors.New("another fake error")
193
+		_, err = uut.Do(src)
194
+		assert.Check(t, errors.Is(err, fakeErr))
195
+
196
+		fakeErr = nil
197
+		f, err := uut.Do(src)
198
+		assert.Check(t, err)
199
+		assert.Check(t, cmp.Equal("HI THERE", readAll(t, f)))
200
+		assert.Check(t, f.Close())
201
+
202
+		// Files pending delete continue to show up in directory
203
+		// listings on Windows RS5. Close the remaining handle before
204
+		// deleting the file to prevent spurious failures with
205
+		// checkDirEmpty.
206
+		assert.Check(t, src.Close())
207
+		assert.NilError(t, os.Remove(name))
208
+		checkDirEmpty(t, dir)
209
+
210
+	})
211
+}
212
+
213
+func createFile(t *testing.T, path string, content string) {
214
+	t.Helper()
215
+	f, err := openFile(path, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0644)
216
+	assert.NilError(t, err)
217
+	_, err = io.WriteString(f, content)
218
+	assert.NilError(t, err)
219
+	assert.NilError(t, f.Close())
220
+}
221
+
222
+func convertPath(t *testing.T, uut *sharedTempFileConverter, path string) *sharedFileReader {
223
+	t.Helper()
224
+	f, err := open(path)
225
+	assert.NilError(t, err)
226
+	defer func() { assert.NilError(t, f.Close()) }()
227
+	r, err := uut.Do(f)
228
+	assert.NilError(t, err)
229
+	return r
230
+}
231
+
232
+func readAll(t *testing.T, r io.Reader) string {
233
+	t.Helper()
234
+	v, err := io.ReadAll(r)
235
+	assert.NilError(t, err)
236
+	return string(v)
237
+}
238
+
239
+func checkDirEmpty(t *testing.T, path string) {
240
+	t.Helper()
241
+	ls, err := os.ReadDir(path)
242
+	assert.NilError(t, err)
243
+	assert.Check(t, cmp.Len(ls, 0), "directory should be free of temp files")
244
+}
245
+
246
+func copyTransform(f func(string) string) func(dst io.WriteSeeker, src io.ReadSeeker) error {
247
+	return func(dst io.WriteSeeker, src io.ReadSeeker) error {
248
+		s, err := io.ReadAll(src)
249
+		if err != nil {
250
+			return err
251
+		}
252
+		_, err = io.WriteString(dst, f(string(s)))
253
+		return err
254
+	}
255
+}