Browse code

Make stdcopy.stdWriter goroutine safe.

Stop using global variables as prefixes to inject the writer header.
That can cause issues when two writers set the length of the buffer in
the same header concurrently.

Stop Writing to the internal buffer twice for each write. This could
mess up with the ordering information is written.

Signed-off-by: David Calavera <david.calavera@gmail.com>

David Calavera authored on 2016/02/26 08:22:19
Showing 2 changed files
... ...
@@ -3,12 +3,24 @@ package stdcopy
3 3
 import (
4 4
 	"encoding/binary"
5 5
 	"errors"
6
+	"fmt"
6 7
 	"io"
7 8
 
8 9
 	"github.com/Sirupsen/logrus"
9 10
 )
10 11
 
12
+// StdType is the type of standard stream
13
+// a writer can multiplex to.
14
+type StdType byte
15
+
11 16
 const (
17
+	// Stdin represents standard input stream type.
18
+	Stdin StdType = iota
19
+	// Stdout represents standard output stream type.
20
+	Stdout
21
+	// Stderr represents standard error steam type.
22
+	Stderr
23
+
12 24
 	stdWriterPrefixLen = 8
13 25
 	stdWriterFdIndex   = 0
14 26
 	stdWriterSizeIndex = 4
... ...
@@ -16,38 +28,32 @@ const (
16 16
 	startingBufLen = 32*1024 + stdWriterPrefixLen + 1
17 17
 )
18 18
 
19
-// StdType prefixes type and length to standard stream.
20
-type StdType [stdWriterPrefixLen]byte
21
-
22
-var (
23
-	// Stdin represents standard input stream type.
24
-	Stdin = StdType{0: 0}
25
-	// Stdout represents standard output stream type.
26
-	Stdout = StdType{0: 1}
27
-	// Stderr represents standard error steam type.
28
-	Stderr = StdType{0: 2}
29
-)
30
-
31
-// StdWriter is wrapper of io.Writer with extra customized info.
32
-type StdWriter struct {
19
+// stdWriter is wrapper of io.Writer with extra customized info.
20
+type stdWriter struct {
33 21
 	io.Writer
34
-	prefix  StdType
35
-	sizeBuf []byte
22
+	prefix byte
36 23
 }
37 24
 
38
-func (w *StdWriter) Write(buf []byte) (n int, err error) {
39
-	var n1, n2 int
25
+// Write sends the buffer to the underneath writer.
26
+// It insert the prefix header before the buffer,
27
+// so stdcopy.StdCopy knows where to multiplex the output.
28
+// It makes stdWriter to implement io.Writer.
29
+func (w *stdWriter) Write(buf []byte) (n int, err error) {
40 30
 	if w == nil || w.Writer == nil {
41 31
 		return 0, errors.New("Writer not instantiated")
42 32
 	}
43
-	binary.BigEndian.PutUint32(w.prefix[4:], uint32(len(buf)))
44
-	n1, err = w.Writer.Write(w.prefix[:])
45
-	if err != nil {
46
-		n = n1 - stdWriterPrefixLen
47
-	} else {
48
-		n2, err = w.Writer.Write(buf)
49
-		n = n1 + n2 - stdWriterPrefixLen
33
+	if buf == nil {
34
+		return 0, nil
50 35
 	}
36
+
37
+	header := [stdWriterPrefixLen]byte{stdWriterFdIndex: w.prefix}
38
+	binary.BigEndian.PutUint32(header[stdWriterSizeIndex:], uint32(len(buf)))
39
+
40
+	line := append(header[:], buf...)
41
+
42
+	n, err = w.Writer.Write(line)
43
+	n -= stdWriterPrefixLen
44
+
51 45
 	if n < 0 {
52 46
 		n = 0
53 47
 	}
... ...
@@ -60,16 +66,13 @@ func (w *StdWriter) Write(buf []byte) (n int, err error) {
60 60
 // This allows multiple write streams (e.g. stdout and stderr) to be muxed into a single connection.
61 61
 // `t` indicates the id of the stream to encapsulate.
62 62
 // It can be stdcopy.Stdin, stdcopy.Stdout, stdcopy.Stderr.
63
-func NewStdWriter(w io.Writer, t StdType) *StdWriter {
64
-	return &StdWriter{
65
-		Writer:  w,
66
-		prefix:  t,
67
-		sizeBuf: make([]byte, 4),
63
+func NewStdWriter(w io.Writer, t StdType) io.Writer {
64
+	return &stdWriter{
65
+		Writer: w,
66
+		prefix: byte(t),
68 67
 	}
69 68
 }
70 69
 
71
-var errInvalidStdHeader = errors.New("Unrecognized input header")
72
-
73 70
 // StdCopy is a modified version of io.Copy.
74 71
 //
75 72
 // StdCopy will demultiplex `src`, assuming that it contains two streams,
... ...
@@ -110,18 +113,18 @@ func StdCopy(dstout, dsterr io.Writer, src io.Reader) (written int64, err error)
110 110
 		}
111 111
 
112 112
 		// Check the first byte to know where to write
113
-		switch buf[stdWriterFdIndex] {
114
-		case 0:
113
+		switch StdType(buf[stdWriterFdIndex]) {
114
+		case Stdin:
115 115
 			fallthrough
116
-		case 1:
116
+		case Stdout:
117 117
 			// Write on stdout
118 118
 			out = dstout
119
-		case 2:
119
+		case Stderr:
120 120
 			// Write on stderr
121 121
 			out = dsterr
122 122
 		default:
123 123
 			logrus.Debugf("Error selecting output fd: (%d)", buf[stdWriterFdIndex])
124
-			return 0, errInvalidStdHeader
124
+			return 0, fmt.Errorf("Unrecognized input header: %d", buf[stdWriterFdIndex])
125 125
 		}
126 126
 
127 127
 		// Retrieve the size of the frame
... ...
@@ -17,10 +17,9 @@ func TestNewStdWriter(t *testing.T) {
17 17
 }
18 18
 
19 19
 func TestWriteWithUnitializedStdWriter(t *testing.T) {
20
-	writer := StdWriter{
21
-		Writer:  nil,
22
-		prefix:  Stdout,
23
-		sizeBuf: make([]byte, 4),
20
+	writer := stdWriter{
21
+		Writer: nil,
22
+		prefix: byte(Stdout),
24 23
 	}
25 24
 	n, err := writer.Write([]byte("Something here"))
26 25
 	if n != 0 || err == nil {
... ...
@@ -180,7 +179,7 @@ func TestStdCopyDetectsCorruptedFrame(t *testing.T) {
180 180
 		src:          buffer}
181 181
 	written, err := StdCopy(ioutil.Discard, ioutil.Discard, reader)
182 182
 	if written != startingBufLen {
183
-		t.Fatalf("Expected 0 bytes read, got %d", written)
183
+		t.Fatalf("Expected %d bytes read, got %d", startingBufLen, written)
184 184
 	}
185 185
 	if err != nil {
186 186
 		t.Fatal("Didn't get nil error")