Some corner cases and error conditions are covered while reading
and writing
Signed-off-by: Federico Gimenez <fgimenez@coit.es>
... | ... |
@@ -3,6 +3,7 @@ package stdcopy |
3 | 3 |
import ( |
4 | 4 |
"bytes" |
5 | 5 |
"errors" |
6 |
+ "io" |
|
6 | 7 |
"io/ioutil" |
7 | 8 |
"strings" |
8 | 9 |
"testing" |
... | ... |
@@ -85,17 +86,22 @@ func TestWriteDoesNotReturnNegativeWrittenBytes(t *testing.T) { |
85 | 85 |
} |
86 | 86 |
} |
87 | 87 |
|
88 |
-func TestStdCopyWriteAndRead(t *testing.T) { |
|
89 |
- buffer := new(bytes.Buffer) |
|
90 |
- stdOutBytes := []byte(strings.Repeat("o", startingBufLen)) |
|
88 |
+func getSrcBuffer(stdOutBytes, stdErrBytes []byte) (buffer *bytes.Buffer, err error) { |
|
89 |
+ buffer = new(bytes.Buffer) |
|
91 | 90 |
dstOut := NewStdWriter(buffer, Stdout) |
92 |
- _, err := dstOut.Write(stdOutBytes) |
|
91 |
+ _, err = dstOut.Write(stdOutBytes) |
|
93 | 92 |
if err != nil { |
94 |
- t.Fatal(err) |
|
93 |
+ return |
|
95 | 94 |
} |
96 |
- stdErrBytes := []byte(strings.Repeat("e", startingBufLen)) |
|
97 | 95 |
dstErr := NewStdWriter(buffer, Stderr) |
98 | 96 |
_, err = dstErr.Write(stdErrBytes) |
97 |
+ return |
|
98 |
+} |
|
99 |
+ |
|
100 |
+func TestStdCopyWriteAndRead(t *testing.T) { |
|
101 |
+ stdOutBytes := []byte(strings.Repeat("o", startingBufLen)) |
|
102 |
+ stdErrBytes := []byte(strings.Repeat("e", startingBufLen)) |
|
103 |
+ buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes) |
|
99 | 104 |
if err != nil { |
100 | 105 |
t.Fatal(err) |
101 | 106 |
} |
... | ... |
@@ -109,6 +115,78 @@ func TestStdCopyWriteAndRead(t *testing.T) { |
109 | 109 |
} |
110 | 110 |
} |
111 | 111 |
|
112 |
+type customReader struct { |
|
113 |
+ n int |
|
114 |
+ err error |
|
115 |
+ totalCalls int |
|
116 |
+ correctCalls int |
|
117 |
+ src *bytes.Buffer |
|
118 |
+} |
|
119 |
+ |
|
120 |
+func (f *customReader) Read(buf []byte) (int, error) { |
|
121 |
+ f.totalCalls++ |
|
122 |
+ if f.totalCalls <= f.correctCalls { |
|
123 |
+ return f.src.Read(buf) |
|
124 |
+ } |
|
125 |
+ return f.n, f.err |
|
126 |
+} |
|
127 |
+ |
|
128 |
+func TestStdCopyReturnsErrorReadingHeader(t *testing.T) { |
|
129 |
+ expectedError := errors.New("error") |
|
130 |
+ reader := &customReader{ |
|
131 |
+ err: expectedError} |
|
132 |
+ written, err := StdCopy(ioutil.Discard, ioutil.Discard, reader) |
|
133 |
+ if written != 0 { |
|
134 |
+ t.Fatalf("Expected 0 bytes read, got %d", written) |
|
135 |
+ } |
|
136 |
+ if err != expectedError { |
|
137 |
+ t.Fatalf("Didn't get expected error") |
|
138 |
+ } |
|
139 |
+} |
|
140 |
+ |
|
141 |
+func TestStdCopyReturnsErrorReadingFrame(t *testing.T) { |
|
142 |
+ expectedError := errors.New("error") |
|
143 |
+ stdOutBytes := []byte(strings.Repeat("o", startingBufLen)) |
|
144 |
+ stdErrBytes := []byte(strings.Repeat("e", startingBufLen)) |
|
145 |
+ buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes) |
|
146 |
+ if err != nil { |
|
147 |
+ t.Fatal(err) |
|
148 |
+ } |
|
149 |
+ reader := &customReader{ |
|
150 |
+ correctCalls: 1, |
|
151 |
+ n: stdWriterPrefixLen + 1, |
|
152 |
+ err: expectedError, |
|
153 |
+ src: buffer} |
|
154 |
+ written, err := StdCopy(ioutil.Discard, ioutil.Discard, reader) |
|
155 |
+ if written != 0 { |
|
156 |
+ t.Fatalf("Expected 0 bytes read, got %d", written) |
|
157 |
+ } |
|
158 |
+ if err != expectedError { |
|
159 |
+ t.Fatalf("Didn't get expected error") |
|
160 |
+ } |
|
161 |
+} |
|
162 |
+ |
|
163 |
+func TestStdCopyDetectsCorruptedFrame(t *testing.T) { |
|
164 |
+ stdOutBytes := []byte(strings.Repeat("o", startingBufLen)) |
|
165 |
+ stdErrBytes := []byte(strings.Repeat("e", startingBufLen)) |
|
166 |
+ buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes) |
|
167 |
+ if err != nil { |
|
168 |
+ t.Fatal(err) |
|
169 |
+ } |
|
170 |
+ reader := &customReader{ |
|
171 |
+ correctCalls: 1, |
|
172 |
+ n: stdWriterPrefixLen + 1, |
|
173 |
+ err: io.EOF, |
|
174 |
+ src: buffer} |
|
175 |
+ written, err := StdCopy(ioutil.Discard, ioutil.Discard, reader) |
|
176 |
+ if written != startingBufLen { |
|
177 |
+ t.Fatalf("Expected 0 bytes read, got %d", written) |
|
178 |
+ } |
|
179 |
+ if err != nil { |
|
180 |
+ t.Fatal("Didn't get nil error") |
|
181 |
+ } |
|
182 |
+} |
|
183 |
+ |
|
112 | 184 |
func TestStdCopyWithInvalidInputHeader(t *testing.T) { |
113 | 185 |
dstOut := NewStdWriter(ioutil.Discard, Stdout) |
114 | 186 |
dstErr := NewStdWriter(ioutil.Discard, Stderr) |
... | ... |
@@ -131,6 +209,44 @@ func TestStdCopyWithCorruptedPrefix(t *testing.T) { |
131 | 131 |
} |
132 | 132 |
} |
133 | 133 |
|
134 |
+func TestStdCopyReturnsWriteErrors(t *testing.T) { |
|
135 |
+ stdOutBytes := []byte(strings.Repeat("o", startingBufLen)) |
|
136 |
+ stdErrBytes := []byte(strings.Repeat("e", startingBufLen)) |
|
137 |
+ buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes) |
|
138 |
+ if err != nil { |
|
139 |
+ t.Fatal(err) |
|
140 |
+ } |
|
141 |
+ expectedError := errors.New("expected") |
|
142 |
+ |
|
143 |
+ dstOut := &errWriter{err: expectedError} |
|
144 |
+ |
|
145 |
+ written, err := StdCopy(dstOut, ioutil.Discard, buffer) |
|
146 |
+ if written != 0 { |
|
147 |
+ t.Fatalf("StdCopy should have written 0, but has written %d", written) |
|
148 |
+ } |
|
149 |
+ if err != expectedError { |
|
150 |
+ t.Fatalf("Didn't get expected error, got %v", err) |
|
151 |
+ } |
|
152 |
+} |
|
153 |
+ |
|
154 |
+func TestStdCopyDetectsNotFullyWrittenFrames(t *testing.T) { |
|
155 |
+ stdOutBytes := []byte(strings.Repeat("o", startingBufLen)) |
|
156 |
+ stdErrBytes := []byte(strings.Repeat("e", startingBufLen)) |
|
157 |
+ buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes) |
|
158 |
+ if err != nil { |
|
159 |
+ t.Fatal(err) |
|
160 |
+ } |
|
161 |
+ dstOut := &errWriter{n: startingBufLen - 10} |
|
162 |
+ |
|
163 |
+ written, err := StdCopy(dstOut, ioutil.Discard, buffer) |
|
164 |
+ if written != 0 { |
|
165 |
+ t.Fatalf("StdCopy should have return 0 written bytes, but returned %d", written) |
|
166 |
+ } |
|
167 |
+ if err != io.ErrShortWrite { |
|
168 |
+ t.Fatalf("Didn't get expected io.ErrShortWrite error") |
|
169 |
+ } |
|
170 |
+} |
|
171 |
+ |
|
134 | 172 |
func BenchmarkWrite(b *testing.B) { |
135 | 173 |
w := NewStdWriter(ioutil.Discard, Stdout) |
136 | 174 |
data := []byte("Test line for testing stdwriter performance\n") |