Browse code

api/pkg/stdcopy: move stdWriter to daemon/internal

Clients have no need for muxing streams using our StdCopy wire format.

Signed-off-by: Cory Snider <csnider@mirantis.com>

Cory Snider authored on 2025/10/10 08:45:01
Showing 10 changed files
... ...
@@ -1,12 +1,10 @@
1 1
 package stdcopy
2 2
 
3 3
 import (
4
-	"bytes"
5 4
 	"encoding/binary"
6 5
 	"errors"
7 6
 	"fmt"
8 7
 	"io"
9
-	"sync"
10 8
 )
11 9
 
12 10
 // StdType is the type of standard stream
... ...
@@ -28,71 +26,6 @@ const (
28 28
 	startingBufLen = 32*1024 + stdWriterPrefixLen + 1
29 29
 )
30 30
 
31
-var bufPool = &sync.Pool{New: func() any { return bytes.NewBuffer(nil) }}
32
-
33
-// stdWriter is wrapper of io.Writer with extra customized info.
34
-type stdWriter struct {
35
-	io.Writer
36
-	prefix byte
37
-}
38
-
39
-// Write sends the buffer to the underlying writer.
40
-// It inserts the prefix header before the buffer,
41
-// so [StdCopy] knows where to multiplex the output.
42
-//
43
-// It implements [io.Writer].
44
-func (w *stdWriter) Write(p []byte) (int, error) {
45
-	if w == nil || w.Writer == nil {
46
-		return 0, errors.New("writer not instantiated")
47
-	}
48
-	if p == nil {
49
-		return 0, nil
50
-	}
51
-
52
-	header := [stdWriterPrefixLen]byte{stdWriterFdIndex: w.prefix}
53
-	binary.BigEndian.PutUint32(header[stdWriterSizeIndex:], uint32(len(p)))
54
-	buf := bufPool.Get().(*bytes.Buffer)
55
-	buf.Write(header[:])
56
-	buf.Write(p)
57
-
58
-	n, err := w.Writer.Write(buf.Bytes())
59
-	n -= stdWriterPrefixLen
60
-	if n < 0 {
61
-		n = 0
62
-	}
63
-
64
-	buf.Reset()
65
-	bufPool.Put(buf)
66
-	return n, err
67
-}
68
-
69
-// NewStdWriter instantiates a new writer using a custom format to multiplex
70
-// multiple streams to a single writer. All messages written using this writer
71
-// are encapsulated using a custom format, and written to the underlying
72
-// stream "w".
73
-//
74
-// Writers created through NewStdWriter allow for multiple write streams
75
-// (e.g., stdout ([Stdout]) and stderr ([Stderr]) to be multiplexed into a
76
-// single connection. "streamType" indicates the type of stream to encapsulate,
77
-// commonly, [Stdout] or [Stderr]. The [Systemerr] stream can be used to
78
-// include server-side errors in the stream. Information on this stream
79
-// is returned as an error by [StdCopy] and terminates processing the
80
-// stream.
81
-//
82
-// The [Stdin] stream is present for completeness and should generally
83
-// NOT be used. It is output on [Stdout] when reading the stream with
84
-// [StdCopy].
85
-//
86
-// All streams must share the same underlying [io.Writer] to ensure proper
87
-// multiplexing. Each call to NewStdWriter wraps that shared writer with
88
-// a header indicating the target stream.
89
-func NewStdWriter(w io.Writer, streamType StdType) io.Writer {
90
-	return &stdWriter{
91
-		Writer: w,
92
-		prefix: byte(streamType),
93
-	}
94
-}
95
-
96 31
 // StdCopy is a modified version of [io.Copy] to de-multiplex messages
97 32
 // from "multiplexedSource" and copy them to destination streams
98 33
 // "destOut" and "destErr".
99 34
deleted file mode 100644
... ...
@@ -1,66 +0,0 @@
1
-package stdcopy_test
2
-
3
-import (
4
-	"errors"
5
-	"fmt"
6
-	"io"
7
-	"os"
8
-	"time"
9
-
10
-	"github.com/moby/moby/api/pkg/stdcopy"
11
-)
12
-
13
-func ExampleNewStdWriter() {
14
-	muxReader, muxStream := io.Pipe()
15
-	defer func() { _ = muxStream.Close() }()
16
-
17
-	// Start demuxing before the daemon starts writing.
18
-	done := make(chan error, 1)
19
-	go func() {
20
-		// using os.Stdout for both, otherwise output doesn't show up in the example.
21
-		osStdout := os.Stdout
22
-		osStderr := os.Stdout
23
-		_, err := stdcopy.StdCopy(osStdout, osStderr, muxReader)
24
-		done <- err
25
-	}()
26
-
27
-	// daemon writing to stdout, stderr, and systemErr.
28
-	stdout := stdcopy.NewStdWriter(muxStream, stdcopy.Stdout)
29
-	stderr := stdcopy.NewStdWriter(muxStream, stdcopy.Stderr)
30
-	systemErr := stdcopy.NewStdWriter(muxStream, stdcopy.Systemerr)
31
-
32
-	for range 10 {
33
-		_, _ = fmt.Fprintln(stdout, "hello from stdout")
34
-		_, _ = fmt.Fprintln(stderr, "hello from stderr")
35
-		time.Sleep(50 * time.Millisecond)
36
-	}
37
-	_, _ = fmt.Fprintln(systemErr, errors.New("something went wrong"))
38
-
39
-	// Wait for the demuxer to finish.
40
-	if err := <-done; err != nil {
41
-		fmt.Println(err)
42
-	}
43
-
44
-	// Output:
45
-	// hello from stdout
46
-	// hello from stderr
47
-	// hello from stdout
48
-	// hello from stderr
49
-	// hello from stdout
50
-	// hello from stderr
51
-	// hello from stdout
52
-	// hello from stderr
53
-	// hello from stdout
54
-	// hello from stderr
55
-	// hello from stdout
56
-	// hello from stderr
57
-	// hello from stdout
58
-	// hello from stderr
59
-	// hello from stdout
60
-	// hello from stderr
61
-	// hello from stdout
62
-	// hello from stderr
63
-	// hello from stdout
64
-	// hello from stderr
65
-	// error from daemon in stream: something went wrong
66
-}
67 1
deleted file mode 100644
... ...
@@ -1,292 +0,0 @@
1
-package stdcopy
2
-
3
-import (
4
-	"bytes"
5
-	"errors"
6
-	"io"
7
-	"strings"
8
-	"testing"
9
-)
10
-
11
-func TestNewStdWriter(t *testing.T) {
12
-	writer := NewStdWriter(io.Discard, Stdout)
13
-	if writer == nil {
14
-		t.Fatalf("NewStdWriter with an invalid StdType should not return nil.")
15
-	}
16
-}
17
-
18
-func TestWriteWithUninitializedStdWriter(t *testing.T) {
19
-	writer := stdWriter{
20
-		Writer: nil,
21
-		prefix: byte(Stdout),
22
-	}
23
-	n, err := writer.Write([]byte("Something here"))
24
-	if n != 0 || err == nil {
25
-		t.Fatalf("Should fail when given an incomplete or uninitialized StdWriter")
26
-	}
27
-}
28
-
29
-func TestWriteWithNilBytes(t *testing.T) {
30
-	writer := NewStdWriter(io.Discard, Stdout)
31
-	n, err := writer.Write(nil)
32
-	if err != nil {
33
-		t.Fatalf("Shouldn't have fail when given no data")
34
-	}
35
-	if n > 0 {
36
-		t.Fatalf("Write should have written 0 byte, but has written %d", n)
37
-	}
38
-}
39
-
40
-func TestWrite(t *testing.T) {
41
-	writer := NewStdWriter(io.Discard, Stdout)
42
-	data := []byte("Test StdWrite.Write")
43
-	n, err := writer.Write(data)
44
-	if err != nil {
45
-		t.Fatalf("Error while writing with StdWrite")
46
-	}
47
-	if n != len(data) {
48
-		t.Fatalf("Write should have written %d byte but wrote %d.", len(data), n)
49
-	}
50
-}
51
-
52
-type errWriter struct {
53
-	n   int
54
-	err error
55
-}
56
-
57
-func (f *errWriter) Write(buf []byte) (int, error) {
58
-	return f.n, f.err
59
-}
60
-
61
-func TestWriteWithWriterError(t *testing.T) {
62
-	expectedError := errors.New("expected")
63
-	expectedReturnedBytes := 10
64
-	writer := NewStdWriter(&errWriter{
65
-		n:   stdWriterPrefixLen + expectedReturnedBytes,
66
-		err: expectedError,
67
-	}, Stdout)
68
-	data := []byte("This won't get written, sigh")
69
-	n, err := writer.Write(data)
70
-	if !errors.Is(err, expectedError) {
71
-		t.Fatalf("Didn't get expected error.")
72
-	}
73
-	if n != expectedReturnedBytes {
74
-		t.Fatalf("Didn't get expected written bytes %d, got %d.",
75
-			expectedReturnedBytes, n)
76
-	}
77
-}
78
-
79
-func TestWriteDoesNotReturnNegativeWrittenBytes(t *testing.T) {
80
-	writer := NewStdWriter(&errWriter{n: -1}, Stdout)
81
-	data := []byte("This won't get written, sigh")
82
-	actual, _ := writer.Write(data)
83
-	if actual != 0 {
84
-		t.Fatalf("Expected returned written bytes equal to 0, got %d", actual)
85
-	}
86
-}
87
-
88
-func getSrcBuffer(stdOutBytes, stdErrBytes []byte) (*bytes.Buffer, error) {
89
-	buffer := new(bytes.Buffer)
90
-	dstOut := NewStdWriter(buffer, Stdout)
91
-	_, err := dstOut.Write(stdOutBytes)
92
-	if err != nil {
93
-		return buffer, err
94
-	}
95
-	dstErr := NewStdWriter(buffer, Stderr)
96
-	_, err = dstErr.Write(stdErrBytes)
97
-	return buffer, err
98
-}
99
-
100
-func TestStdCopyWriteAndRead(t *testing.T) {
101
-	stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
102
-	stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
103
-	buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
104
-	if err != nil {
105
-		t.Fatal(err)
106
-	}
107
-	written, err := StdCopy(io.Discard, io.Discard, buffer)
108
-	if err != nil {
109
-		t.Fatal(err)
110
-	}
111
-	expectedTotalWritten := len(stdOutBytes) + len(stdErrBytes)
112
-	if written != int64(expectedTotalWritten) {
113
-		t.Fatalf("Expected to have total of %d bytes written, got %d", expectedTotalWritten, written)
114
-	}
115
-}
116
-
117
-type customReader struct {
118
-	n            int
119
-	err          error
120
-	totalCalls   int
121
-	correctCalls int
122
-	src          *bytes.Buffer
123
-}
124
-
125
-func (f *customReader) Read(buf []byte) (int, error) {
126
-	f.totalCalls++
127
-	if f.totalCalls <= f.correctCalls {
128
-		return f.src.Read(buf)
129
-	}
130
-	return f.n, f.err
131
-}
132
-
133
-func TestStdCopyReturnsErrorReadingHeader(t *testing.T) {
134
-	expectedError := errors.New("error")
135
-	reader := &customReader{
136
-		err: expectedError,
137
-	}
138
-	written, err := StdCopy(io.Discard, io.Discard, reader)
139
-	if written != 0 {
140
-		t.Fatalf("Expected 0 bytes read, got %d", written)
141
-	}
142
-	if !errors.Is(err, expectedError) {
143
-		t.Fatalf("Didn't get expected error")
144
-	}
145
-}
146
-
147
-func TestStdCopyReturnsErrorReadingFrame(t *testing.T) {
148
-	expectedError := errors.New("error")
149
-	stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
150
-	stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
151
-	buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
152
-	if err != nil {
153
-		t.Fatal(err)
154
-	}
155
-	reader := &customReader{
156
-		correctCalls: 1,
157
-		n:            stdWriterPrefixLen + 1,
158
-		err:          expectedError,
159
-		src:          buffer,
160
-	}
161
-	written, err := StdCopy(io.Discard, io.Discard, reader)
162
-	if written != 0 {
163
-		t.Fatalf("Expected 0 bytes read, got %d", written)
164
-	}
165
-	if !errors.Is(err, expectedError) {
166
-		t.Fatalf("Didn't get expected error")
167
-	}
168
-}
169
-
170
-func TestStdCopyDetectsCorruptedFrame(t *testing.T) {
171
-	stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
172
-	stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
173
-	buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
174
-	if err != nil {
175
-		t.Fatal(err)
176
-	}
177
-	reader := &customReader{
178
-		correctCalls: 1,
179
-		n:            stdWriterPrefixLen + 1,
180
-		err:          io.EOF,
181
-		src:          buffer,
182
-	}
183
-	written, err := StdCopy(io.Discard, io.Discard, reader)
184
-	if written != startingBufLen {
185
-		t.Fatalf("Expected %d bytes read, got %d", startingBufLen, written)
186
-	}
187
-	if err != nil {
188
-		t.Fatal("Didn't get nil error")
189
-	}
190
-}
191
-
192
-func TestStdCopyWithInvalidInputHeader(t *testing.T) {
193
-	dstOut := NewStdWriter(io.Discard, Stdout)
194
-	dstErr := NewStdWriter(io.Discard, Stderr)
195
-	src := strings.NewReader("Invalid input")
196
-	_, err := StdCopy(dstOut, dstErr, src)
197
-	if err == nil {
198
-		t.Fatal("StdCopy with invalid input header should fail.")
199
-	}
200
-}
201
-
202
-func TestStdCopyWithCorruptedPrefix(t *testing.T) {
203
-	data := []byte{0x01, 0x02, 0x03}
204
-	src := bytes.NewReader(data)
205
-	written, err := StdCopy(nil, nil, src)
206
-	if err != nil {
207
-		t.Fatalf("StdCopy should not return an error with corrupted prefix.")
208
-	}
209
-	if written != 0 {
210
-		t.Fatalf("StdCopy should have written 0, but has written %d", written)
211
-	}
212
-}
213
-
214
-func TestStdCopyReturnsWriteErrors(t *testing.T) {
215
-	stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
216
-	stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
217
-	buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
218
-	if err != nil {
219
-		t.Fatal(err)
220
-	}
221
-	expectedError := errors.New("expected")
222
-
223
-	dstOut := &errWriter{err: expectedError}
224
-
225
-	written, err := StdCopy(dstOut, io.Discard, buffer)
226
-	if written != 0 {
227
-		t.Fatalf("StdCopy should have written 0, but has written %d", written)
228
-	}
229
-	if !errors.Is(err, expectedError) {
230
-		t.Fatalf("Didn't get expected error, got %v", err)
231
-	}
232
-}
233
-
234
-func TestStdCopyDetectsNotFullyWrittenFrames(t *testing.T) {
235
-	stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
236
-	stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
237
-	buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
238
-	if err != nil {
239
-		t.Fatal(err)
240
-	}
241
-	dstOut := &errWriter{n: startingBufLen - 10}
242
-
243
-	written, err := StdCopy(dstOut, io.Discard, buffer)
244
-	if written != 0 {
245
-		t.Fatalf("StdCopy should have return 0 written bytes, but returned %d", written)
246
-	}
247
-	if !errors.Is(err, io.ErrShortWrite) {
248
-		t.Fatalf("Didn't get expected io.ErrShortWrite error")
249
-	}
250
-}
251
-
252
-// TestStdCopyReturnsErrorFromSystem tests that StdCopy correctly returns an
253
-// error, when that error is muxed into the Systemerr stream.
254
-func TestStdCopyReturnsErrorFromSystem(t *testing.T) {
255
-	// write in the basic messages, just so there's some fluff in there
256
-	stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
257
-	stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
258
-	buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
259
-	if err != nil {
260
-		t.Fatal(err)
261
-	}
262
-	// add in an error message on the Systemerr stream
263
-	systemErrBytes := []byte(strings.Repeat("S", startingBufLen))
264
-	systemWriter := NewStdWriter(buffer, Systemerr)
265
-	_, err = systemWriter.Write(systemErrBytes)
266
-	if err != nil {
267
-		t.Fatal(err)
268
-	}
269
-
270
-	// now copy and demux. we should expect an error containing the string we
271
-	// wrote out
272
-	_, err = StdCopy(io.Discard, io.Discard, buffer)
273
-	if err == nil {
274
-		t.Fatal("expected error, got none")
275
-	}
276
-	if !strings.Contains(err.Error(), string(systemErrBytes)) {
277
-		t.Fatal("expected error to contain message")
278
-	}
279
-}
280
-
281
-func BenchmarkWrite(b *testing.B) {
282
-	w := NewStdWriter(io.Discard, Stdout)
283
-	data := []byte("Test line for testing stdwriter performance\n")
284
-	data = bytes.Repeat(data, 100)
285
-	b.SetBytes(int64(len(data)))
286
-	b.ResetTimer()
287
-	for i := 0; i < b.N; i++ {
288
-		if _, err := w.Write(data); err != nil {
289
-			b.Fatal(err)
290
-		}
291
-	}
292
-}
... ...
@@ -10,6 +10,7 @@ import (
10 10
 	containertypes "github.com/moby/moby/api/types/container"
11 11
 	"github.com/moby/moby/api/types/events"
12 12
 	"github.com/moby/moby/v2/daemon/container"
13
+	"github.com/moby/moby/v2/daemon/internal/stdcopymux"
13 14
 	"github.com/moby/moby/v2/daemon/internal/stream"
14 15
 	"github.com/moby/moby/v2/daemon/logger"
15 16
 	"github.com/moby/moby/v2/daemon/server/backend"
... ...
@@ -74,8 +75,8 @@ func (daemon *Daemon) ContainerAttach(prefixOrName string, req *backend.Containe
74 74
 	defer inStream.Close()
75 75
 
76 76
 	if multiplexed {
77
-		errStream = stdcopy.NewStdWriter(errStream, stdcopy.Stderr)
78
-		outStream = stdcopy.NewStdWriter(outStream, stdcopy.Stdout)
77
+		errStream = stdcopymux.NewStdWriter(errStream, stdcopy.Stderr)
78
+		outStream = stdcopymux.NewStdWriter(outStream, stdcopy.Stdout)
79 79
 	}
80 80
 
81 81
 	if cfg.UseStdin {
82 82
new file mode 100644
... ...
@@ -0,0 +1,67 @@
0
+package stdcopymux_test
1
+
2
+import (
3
+	"errors"
4
+	"fmt"
5
+	"io"
6
+	"os"
7
+	"time"
8
+
9
+	"github.com/moby/moby/api/pkg/stdcopy"
10
+	"github.com/moby/moby/v2/daemon/internal/stdcopymux"
11
+)
12
+
13
+func ExampleNewStdWriter() {
14
+	muxReader, muxStream := io.Pipe()
15
+	defer func() { _ = muxStream.Close() }()
16
+
17
+	// Start demuxing before the daemon starts writing.
18
+	done := make(chan error, 1)
19
+	go func() {
20
+		// using os.Stdout for both, otherwise output doesn't show up in the example.
21
+		osStdout := os.Stdout
22
+		osStderr := os.Stdout
23
+		_, err := stdcopy.StdCopy(osStdout, osStderr, muxReader)
24
+		done <- err
25
+	}()
26
+
27
+	// daemon writing to stdout, stderr, and systemErr.
28
+	stdout := stdcopymux.NewStdWriter(muxStream, stdcopy.Stdout)
29
+	stderr := stdcopymux.NewStdWriter(muxStream, stdcopy.Stderr)
30
+	systemErr := stdcopymux.NewStdWriter(muxStream, stdcopy.Systemerr)
31
+
32
+	for range 10 {
33
+		_, _ = fmt.Fprintln(stdout, "hello from stdout")
34
+		_, _ = fmt.Fprintln(stderr, "hello from stderr")
35
+		time.Sleep(50 * time.Millisecond)
36
+	}
37
+	_, _ = fmt.Fprintln(systemErr, errors.New("something went wrong"))
38
+
39
+	// Wait for the demuxer to finish.
40
+	if err := <-done; err != nil {
41
+		fmt.Println(err)
42
+	}
43
+
44
+	// Output:
45
+	// hello from stdout
46
+	// hello from stderr
47
+	// hello from stdout
48
+	// hello from stderr
49
+	// hello from stdout
50
+	// hello from stderr
51
+	// hello from stdout
52
+	// hello from stderr
53
+	// hello from stdout
54
+	// hello from stderr
55
+	// hello from stdout
56
+	// hello from stderr
57
+	// hello from stdout
58
+	// hello from stderr
59
+	// hello from stdout
60
+	// hello from stderr
61
+	// hello from stdout
62
+	// hello from stderr
63
+	// hello from stdout
64
+	// hello from stderr
65
+	// error from daemon in stream: something went wrong
66
+}
0 67
new file mode 100644
... ...
@@ -0,0 +1,296 @@
0
+package stdcopymux
1
+
2
+import (
3
+	"bytes"
4
+	"errors"
5
+	"io"
6
+	"strings"
7
+	"testing"
8
+
9
+	"github.com/moby/moby/api/pkg/stdcopy"
10
+)
11
+
12
+const startingBufLen = 32*1024 + 8 /* stdwriterPrefixLen */ + 1
13
+
14
+func TestNewStdWriter(t *testing.T) {
15
+	writer := NewStdWriter(io.Discard, stdcopy.Stdout)
16
+	if writer == nil {
17
+		t.Fatalf("NewStdWriter with an invalid StdType should not return nil.")
18
+	}
19
+}
20
+
21
+func TestWriteWithUninitializedStdWriter(t *testing.T) {
22
+	writer := stdWriter{
23
+		Writer: nil,
24
+		prefix: byte(stdcopy.Stdout),
25
+	}
26
+	n, err := writer.Write([]byte("Something here"))
27
+	if n != 0 || err == nil {
28
+		t.Fatalf("Should fail when given an incomplete or uninitialized StdWriter")
29
+	}
30
+}
31
+
32
+func TestWriteWithNilBytes(t *testing.T) {
33
+	writer := NewStdWriter(io.Discard, stdcopy.Stdout)
34
+	n, err := writer.Write(nil)
35
+	if err != nil {
36
+		t.Fatalf("Shouldn't have fail when given no data")
37
+	}
38
+	if n > 0 {
39
+		t.Fatalf("Write should have written 0 byte, but has written %d", n)
40
+	}
41
+}
42
+
43
+func TestWrite(t *testing.T) {
44
+	writer := NewStdWriter(io.Discard, stdcopy.Stdout)
45
+	data := []byte("Test StdWrite.Write")
46
+	n, err := writer.Write(data)
47
+	if err != nil {
48
+		t.Fatalf("Error while writing with StdWrite")
49
+	}
50
+	if n != len(data) {
51
+		t.Fatalf("Write should have written %d byte but wrote %d.", len(data), n)
52
+	}
53
+}
54
+
55
+type errWriter struct {
56
+	n   int
57
+	err error
58
+}
59
+
60
+func (f *errWriter) Write(buf []byte) (int, error) {
61
+	return f.n, f.err
62
+}
63
+
64
+func TestWriteWithWriterError(t *testing.T) {
65
+	expectedError := errors.New("expected")
66
+	expectedReturnedBytes := 10
67
+	writer := NewStdWriter(&errWriter{
68
+		n:   stdWriterPrefixLen + expectedReturnedBytes,
69
+		err: expectedError,
70
+	}, stdcopy.Stdout)
71
+	data := []byte("This won't get written, sigh")
72
+	n, err := writer.Write(data)
73
+	if !errors.Is(err, expectedError) {
74
+		t.Fatalf("Didn't get expected error.")
75
+	}
76
+	if n != expectedReturnedBytes {
77
+		t.Fatalf("Didn't get expected written bytes %d, got %d.",
78
+			expectedReturnedBytes, n)
79
+	}
80
+}
81
+
82
+func TestWriteDoesNotReturnNegativeWrittenBytes(t *testing.T) {
83
+	writer := NewStdWriter(&errWriter{n: -1}, stdcopy.Stdout)
84
+	data := []byte("This won't get written, sigh")
85
+	actual, _ := writer.Write(data)
86
+	if actual != 0 {
87
+		t.Fatalf("Expected returned written bytes equal to 0, got %d", actual)
88
+	}
89
+}
90
+
91
+func getSrcBuffer(stdOutBytes, stdErrBytes []byte) (*bytes.Buffer, error) {
92
+	buffer := new(bytes.Buffer)
93
+	dstOut := NewStdWriter(buffer, stdcopy.Stdout)
94
+	_, err := dstOut.Write(stdOutBytes)
95
+	if err != nil {
96
+		return buffer, err
97
+	}
98
+	dstErr := NewStdWriter(buffer, stdcopy.Stderr)
99
+	_, err = dstErr.Write(stdErrBytes)
100
+	return buffer, err
101
+}
102
+
103
+func TestStdCopyWriteAndRead(t *testing.T) {
104
+	stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
105
+	stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
106
+	buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
107
+	if err != nil {
108
+		t.Fatal(err)
109
+	}
110
+	written, err := stdcopy.StdCopy(io.Discard, io.Discard, buffer)
111
+	if err != nil {
112
+		t.Fatal(err)
113
+	}
114
+	expectedTotalWritten := len(stdOutBytes) + len(stdErrBytes)
115
+	if written != int64(expectedTotalWritten) {
116
+		t.Fatalf("Expected to have total of %d bytes written, got %d", expectedTotalWritten, written)
117
+	}
118
+}
119
+
120
+type customReader struct {
121
+	n            int
122
+	err          error
123
+	totalCalls   int
124
+	correctCalls int
125
+	src          *bytes.Buffer
126
+}
127
+
128
+func (f *customReader) Read(buf []byte) (int, error) {
129
+	f.totalCalls++
130
+	if f.totalCalls <= f.correctCalls {
131
+		return f.src.Read(buf)
132
+	}
133
+	return f.n, f.err
134
+}
135
+
136
+func TestStdCopyReturnsErrorReadingHeader(t *testing.T) {
137
+	expectedError := errors.New("error")
138
+	reader := &customReader{
139
+		err: expectedError,
140
+	}
141
+	written, err := stdcopy.StdCopy(io.Discard, io.Discard, reader)
142
+	if written != 0 {
143
+		t.Fatalf("Expected 0 bytes read, got %d", written)
144
+	}
145
+	if !errors.Is(err, expectedError) {
146
+		t.Fatalf("Didn't get expected error")
147
+	}
148
+}
149
+
150
+func TestStdCopyReturnsErrorReadingFrame(t *testing.T) {
151
+	expectedError := errors.New("error")
152
+	stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
153
+	stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
154
+	buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
155
+	if err != nil {
156
+		t.Fatal(err)
157
+	}
158
+	reader := &customReader{
159
+		correctCalls: 1,
160
+		n:            stdWriterPrefixLen + 1,
161
+		err:          expectedError,
162
+		src:          buffer,
163
+	}
164
+	written, err := stdcopy.StdCopy(io.Discard, io.Discard, reader)
165
+	if written != 0 {
166
+		t.Fatalf("Expected 0 bytes read, got %d", written)
167
+	}
168
+	if !errors.Is(err, expectedError) {
169
+		t.Fatalf("Didn't get expected error")
170
+	}
171
+}
172
+
173
+func TestStdCopyDetectsCorruptedFrame(t *testing.T) {
174
+	stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
175
+	stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
176
+	buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
177
+	if err != nil {
178
+		t.Fatal(err)
179
+	}
180
+	reader := &customReader{
181
+		correctCalls: 1,
182
+		n:            stdWriterPrefixLen + 1,
183
+		err:          io.EOF,
184
+		src:          buffer,
185
+	}
186
+	written, err := stdcopy.StdCopy(io.Discard, io.Discard, reader)
187
+	if written != startingBufLen {
188
+		t.Fatalf("Expected %d bytes read, got %d", startingBufLen, written)
189
+	}
190
+	if err != nil {
191
+		t.Fatal("Didn't get nil error")
192
+	}
193
+}
194
+
195
+func TestStdCopyWithInvalidInputHeader(t *testing.T) {
196
+	dstOut := NewStdWriter(io.Discard, stdcopy.Stdout)
197
+	dstErr := NewStdWriter(io.Discard, stdcopy.Stderr)
198
+	src := strings.NewReader("Invalid input")
199
+	_, err := stdcopy.StdCopy(dstOut, dstErr, src)
200
+	if err == nil {
201
+		t.Fatal("StdCopy with invalid input header should fail.")
202
+	}
203
+}
204
+
205
+func TestStdCopyWithCorruptedPrefix(t *testing.T) {
206
+	data := []byte{0x01, 0x02, 0x03}
207
+	src := bytes.NewReader(data)
208
+	written, err := stdcopy.StdCopy(nil, nil, src)
209
+	if err != nil {
210
+		t.Fatalf("StdCopy should not return an error with corrupted prefix.")
211
+	}
212
+	if written != 0 {
213
+		t.Fatalf("StdCopy should have written 0, but has written %d", written)
214
+	}
215
+}
216
+
217
+func TestStdCopyReturnsWriteErrors(t *testing.T) {
218
+	stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
219
+	stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
220
+	buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
221
+	if err != nil {
222
+		t.Fatal(err)
223
+	}
224
+	expectedError := errors.New("expected")
225
+
226
+	dstOut := &errWriter{err: expectedError}
227
+
228
+	written, err := stdcopy.StdCopy(dstOut, io.Discard, buffer)
229
+	if written != 0 {
230
+		t.Fatalf("StdCopy should have written 0, but has written %d", written)
231
+	}
232
+	if !errors.Is(err, expectedError) {
233
+		t.Fatalf("Didn't get expected error, got %v", err)
234
+	}
235
+}
236
+
237
+func TestStdCopyDetectsNotFullyWrittenFrames(t *testing.T) {
238
+	stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
239
+	stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
240
+	buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
241
+	if err != nil {
242
+		t.Fatal(err)
243
+	}
244
+	dstOut := &errWriter{n: startingBufLen - 10}
245
+
246
+	written, err := stdcopy.StdCopy(dstOut, io.Discard, buffer)
247
+	if written != 0 {
248
+		t.Fatalf("StdCopy should have return 0 written bytes, but returned %d", written)
249
+	}
250
+	if !errors.Is(err, io.ErrShortWrite) {
251
+		t.Fatalf("Didn't get expected io.ErrShortWrite error")
252
+	}
253
+}
254
+
255
+// TestStdCopyReturnsErrorFromSystem tests that StdCopy correctly returns an
256
+// error, when that error is muxed into the Systemerr stream.
257
+func TestStdCopyReturnsErrorFromSystem(t *testing.T) {
258
+	// write in the basic messages, just so there's some fluff in there
259
+	stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
260
+	stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
261
+	buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
262
+	if err != nil {
263
+		t.Fatal(err)
264
+	}
265
+	// add in an error message on the Systemerr stream
266
+	systemErrBytes := []byte(strings.Repeat("S", startingBufLen))
267
+	systemWriter := NewStdWriter(buffer, stdcopy.Systemerr)
268
+	_, err = systemWriter.Write(systemErrBytes)
269
+	if err != nil {
270
+		t.Fatal(err)
271
+	}
272
+
273
+	// now copy and demux. we should expect an error containing the string we
274
+	// wrote out
275
+	_, err = stdcopy.StdCopy(io.Discard, io.Discard, buffer)
276
+	if err == nil {
277
+		t.Fatal("expected error, got none")
278
+	}
279
+	if !strings.Contains(err.Error(), string(systemErrBytes)) {
280
+		t.Fatal("expected error to contain message")
281
+	}
282
+}
283
+
284
+func BenchmarkWrite(b *testing.B) {
285
+	w := NewStdWriter(io.Discard, stdcopy.Stdout)
286
+	data := []byte("Test line for testing stdwriter performance\n")
287
+	data = bytes.Repeat(data, 100)
288
+	b.SetBytes(int64(len(data)))
289
+	b.ResetTimer()
290
+	for i := 0; i < b.N; i++ {
291
+		if _, err := w.Write(data); err != nil {
292
+			b.Fatal(err)
293
+		}
294
+	}
295
+}
0 296
new file mode 100644
... ...
@@ -0,0 +1,82 @@
0
+package stdcopymux
1
+
2
+import (
3
+	"bytes"
4
+	"encoding/binary"
5
+	"errors"
6
+	"io"
7
+	"sync"
8
+
9
+	"github.com/moby/moby/api/pkg/stdcopy"
10
+)
11
+
12
+const (
13
+	stdWriterPrefixLen = 8
14
+	stdWriterFdIndex   = 0
15
+	stdWriterSizeIndex = 4
16
+)
17
+
18
+var bufPool = &sync.Pool{New: func() any { return bytes.NewBuffer(nil) }}
19
+
20
+// stdWriter is wrapper of io.Writer with extra customized info.
21
+type stdWriter struct {
22
+	io.Writer
23
+	prefix byte
24
+}
25
+
26
+// Write sends the buffer to the underlying writer.
27
+// It inserts the prefix header before the buffer,
28
+// so [StdCopy] knows where to multiplex the output.
29
+//
30
+// It implements [io.Writer].
31
+func (w *stdWriter) Write(p []byte) (int, error) {
32
+	if w == nil || w.Writer == nil {
33
+		return 0, errors.New("writer not instantiated")
34
+	}
35
+	if p == nil {
36
+		return 0, nil
37
+	}
38
+
39
+	header := [stdWriterPrefixLen]byte{stdWriterFdIndex: w.prefix}
40
+	binary.BigEndian.PutUint32(header[stdWriterSizeIndex:], uint32(len(p)))
41
+	buf := bufPool.Get().(*bytes.Buffer)
42
+	buf.Write(header[:])
43
+	buf.Write(p)
44
+
45
+	n, err := w.Writer.Write(buf.Bytes())
46
+	n -= stdWriterPrefixLen
47
+	if n < 0 {
48
+		n = 0
49
+	}
50
+
51
+	buf.Reset()
52
+	bufPool.Put(buf)
53
+	return n, err
54
+}
55
+
56
+// NewStdWriter instantiates a new writer using a custom format to multiplex
57
+// multiple streams to a single writer. All messages written using this writer
58
+// are encapsulated using a custom format, and written to the underlying
59
+// stream "w".
60
+//
61
+// Writers created through NewStdWriter allow for multiple write streams
62
+// (e.g., stdout ([Stdout]) and stderr ([Stderr]) to be multiplexed into a
63
+// single connection. "streamType" indicates the type of stream to encapsulate,
64
+// commonly, [Stdout] or [Stderr]. The [Systemerr] stream can be used to
65
+// include server-side errors in the stream. Information on this stream
66
+// is returned as an error by [StdCopy] and terminates processing the
67
+// stream.
68
+//
69
+// The [Stdin] stream is present for completeness and should generally
70
+// NOT be used. It is output on [Stdout] when reading the stream with
71
+// [StdCopy].
72
+//
73
+// All streams must share the same underlying [io.Writer] to ensure proper
74
+// multiplexing. Each call to NewStdWriter wraps that shared writer with
75
+// a header indicating the target stream.
76
+func NewStdWriter(w io.Writer, streamType stdcopy.StdType) io.Writer {
77
+	return &stdWriter{
78
+		Writer: w,
79
+		prefix: byte(streamType),
80
+	}
81
+}
... ...
@@ -9,6 +9,7 @@ import (
9 9
 	"sort"
10 10
 
11 11
 	"github.com/moby/moby/api/pkg/stdcopy"
12
+	"github.com/moby/moby/v2/daemon/internal/stdcopymux"
12 13
 	"github.com/moby/moby/v2/daemon/server/backend"
13 14
 	"github.com/moby/moby/v2/pkg/ioutils"
14 15
 )
... ...
@@ -33,9 +34,9 @@ func WriteLogStream(_ context.Context, w http.ResponseWriter, msgs <-chan *backe
33 33
 	errStream := outStream
34 34
 	sysErrStream := errStream
35 35
 	if mux {
36
-		sysErrStream = stdcopy.NewStdWriter(outStream, stdcopy.Systemerr)
37
-		errStream = stdcopy.NewStdWriter(outStream, stdcopy.Stderr)
38
-		outStream = stdcopy.NewStdWriter(outStream, stdcopy.Stdout)
36
+		sysErrStream = stdcopymux.NewStdWriter(outStream, stdcopy.Systemerr)
37
+		errStream = stdcopymux.NewStdWriter(outStream, stdcopy.Stderr)
38
+		outStream = stdcopymux.NewStdWriter(outStream, stdcopy.Stdout)
39 39
 	}
40 40
 
41 41
 	for {
... ...
@@ -11,6 +11,7 @@ import (
11 11
 	"github.com/moby/moby/api/types"
12 12
 	"github.com/moby/moby/api/types/container"
13 13
 	"github.com/moby/moby/api/types/versions"
14
+	"github.com/moby/moby/v2/daemon/internal/stdcopymux"
14 15
 	"github.com/moby/moby/v2/daemon/server/backend"
15 16
 	"github.com/moby/moby/v2/daemon/server/httputils"
16 17
 	"github.com/moby/moby/v2/errdefs"
... ...
@@ -130,8 +131,8 @@ func (c *containerRouter) postContainerExecStart(ctx context.Context, w http.Res
130 130
 		if options.Tty {
131 131
 			stdout = outStream
132 132
 		} else {
133
-			stderr = stdcopy.NewStdWriter(outStream, stdcopy.Stderr)
134
-			stdout = stdcopy.NewStdWriter(outStream, stdcopy.Stdout)
133
+			stderr = stdcopymux.NewStdWriter(outStream, stdcopy.Stderr)
134
+			stdout = stdcopymux.NewStdWriter(outStream, stdcopy.Stdout)
135 135
 		}
136 136
 	}
137 137
 
... ...
@@ -1,12 +1,10 @@
1 1
 package stdcopy
2 2
 
3 3
 import (
4
-	"bytes"
5 4
 	"encoding/binary"
6 5
 	"errors"
7 6
 	"fmt"
8 7
 	"io"
9
-	"sync"
10 8
 )
11 9
 
12 10
 // StdType is the type of standard stream
... ...
@@ -28,71 +26,6 @@ const (
28 28
 	startingBufLen = 32*1024 + stdWriterPrefixLen + 1
29 29
 )
30 30
 
31
-var bufPool = &sync.Pool{New: func() any { return bytes.NewBuffer(nil) }}
32
-
33
-// stdWriter is wrapper of io.Writer with extra customized info.
34
-type stdWriter struct {
35
-	io.Writer
36
-	prefix byte
37
-}
38
-
39
-// Write sends the buffer to the underlying writer.
40
-// It inserts the prefix header before the buffer,
41
-// so [StdCopy] knows where to multiplex the output.
42
-//
43
-// It implements [io.Writer].
44
-func (w *stdWriter) Write(p []byte) (int, error) {
45
-	if w == nil || w.Writer == nil {
46
-		return 0, errors.New("writer not instantiated")
47
-	}
48
-	if p == nil {
49
-		return 0, nil
50
-	}
51
-
52
-	header := [stdWriterPrefixLen]byte{stdWriterFdIndex: w.prefix}
53
-	binary.BigEndian.PutUint32(header[stdWriterSizeIndex:], uint32(len(p)))
54
-	buf := bufPool.Get().(*bytes.Buffer)
55
-	buf.Write(header[:])
56
-	buf.Write(p)
57
-
58
-	n, err := w.Writer.Write(buf.Bytes())
59
-	n -= stdWriterPrefixLen
60
-	if n < 0 {
61
-		n = 0
62
-	}
63
-
64
-	buf.Reset()
65
-	bufPool.Put(buf)
66
-	return n, err
67
-}
68
-
69
-// NewStdWriter instantiates a new writer using a custom format to multiplex
70
-// multiple streams to a single writer. All messages written using this writer
71
-// are encapsulated using a custom format, and written to the underlying
72
-// stream "w".
73
-//
74
-// Writers created through NewStdWriter allow for multiple write streams
75
-// (e.g., stdout ([Stdout]) and stderr ([Stderr]) to be multiplexed into a
76
-// single connection. "streamType" indicates the type of stream to encapsulate,
77
-// commonly, [Stdout] or [Stderr]. The [Systemerr] stream can be used to
78
-// include server-side errors in the stream. Information on this stream
79
-// is returned as an error by [StdCopy] and terminates processing the
80
-// stream.
81
-//
82
-// The [Stdin] stream is present for completeness and should generally
83
-// NOT be used. It is output on [Stdout] when reading the stream with
84
-// [StdCopy].
85
-//
86
-// All streams must share the same underlying [io.Writer] to ensure proper
87
-// multiplexing. Each call to NewStdWriter wraps that shared writer with
88
-// a header indicating the target stream.
89
-func NewStdWriter(w io.Writer, streamType StdType) io.Writer {
90
-	return &stdWriter{
91
-		Writer: w,
92
-		prefix: byte(streamType),
93
-	}
94
-}
95
-
96 31
 // StdCopy is a modified version of [io.Copy] to de-multiplex messages
97 32
 // from "multiplexedSource" and copy them to destination streams
98 33
 // "destOut" and "destErr".