Browse code

Merge remote-tracking branch 'robryk/writebroadcaster-stuff'

Solomon Hykes authored on 2013/04/03 12:35:13
Showing 3 changed files
... ...
@@ -171,14 +171,14 @@ func (container *Container) startPty() error {
171 171
 
172 172
 	// Copy the PTYs to our broadcasters
173 173
 	go func() {
174
-		defer container.stdout.Close()
174
+		defer container.stdout.CloseWriters()
175 175
 		Debugf("[startPty] Begin of stdout pipe")
176 176
 		io.Copy(container.stdout, stdoutMaster)
177 177
 		Debugf("[startPty] End of stdout pipe")
178 178
 	}()
179 179
 
180 180
 	go func() {
181
-		defer container.stderr.Close()
181
+		defer container.stderr.CloseWriters()
182 182
 		Debugf("[startPty] Begin of stderr pipe")
183 183
 		io.Copy(container.stderr, stderrMaster)
184 184
 		Debugf("[startPty] End of stderr pipe")
... ...
@@ -391,10 +391,10 @@ func (container *Container) monitor() {
391 391
 			Debugf("%s: Error close stdin: %s", container.Id, err)
392 392
 		}
393 393
 	}
394
-	if err := container.stdout.Close(); err != nil {
394
+	if err := container.stdout.CloseWriters(); err != nil {
395 395
 		Debugf("%s: Error close stdout: %s", container.Id, err)
396 396
 	}
397
-	if err := container.stderr.Close(); err != nil {
397
+	if err := container.stderr.CloseWriters(); err != nil {
398 398
 		Debugf("%s: Error close stderr: %s", container.Id, err)
399 399
 	}
400 400
 
... ...
@@ -2,7 +2,6 @@ package docker
2 2
 
3 3
 import (
4 4
 	"bytes"
5
-	"container/list"
6 5
 	"errors"
7 6
 	"fmt"
8 7
 	"github.com/dotcloud/docker/rcli"
... ...
@@ -215,52 +214,48 @@ func (r *bufReader) Close() error {
215 215
 }
216 216
 
217 217
 type writeBroadcaster struct {
218
-	writers *list.List
218
+	mu      sync.Mutex
219
+	writers map[io.WriteCloser]struct{}
219 220
 }
220 221
 
221 222
 func (w *writeBroadcaster) AddWriter(writer io.WriteCloser) {
222
-	w.writers.PushBack(writer)
223
+	w.mu.Lock()
224
+	w.writers[writer] = struct{}{}
225
+	w.mu.Unlock()
223 226
 }
224 227
 
225 228
 // FIXME: Is that function used?
229
+// FIXME: This relies on the concrete writer type used having equality operator
226 230
 func (w *writeBroadcaster) RemoveWriter(writer io.WriteCloser) {
227
-	for e := w.writers.Front(); e != nil; e = e.Next() {
228
-		v := e.Value.(io.Writer)
229
-		if v == writer {
230
-			w.writers.Remove(e)
231
-			return
232
-		}
233
-	}
231
+	w.mu.Lock()
232
+	delete(w.writers, writer)
233
+	w.mu.Unlock()
234 234
 }
235 235
 
236 236
 func (w *writeBroadcaster) Write(p []byte) (n int, err error) {
237
-	failed := []*list.Element{}
238
-	for e := w.writers.Front(); e != nil; e = e.Next() {
239
-		writer := e.Value.(io.Writer)
237
+	w.mu.Lock()
238
+	defer w.mu.Unlock()
239
+	for writer := range w.writers {
240 240
 		if n, err := writer.Write(p); err != nil || n != len(p) {
241 241
 			// On error, evict the writer
242
-			failed = append(failed, e)
242
+			delete(w.writers, writer)
243 243
 		}
244 244
 	}
245
-	// We cannot remove while iterating, so it has to be done in
246
-	// a separate step
247
-	for _, e := range failed {
248
-		w.writers.Remove(e)
249
-	}
250 245
 	return len(p), nil
251 246
 }
252 247
 
253
-func (w *writeBroadcaster) Close() error {
254
-	for e := w.writers.Front(); e != nil; e = e.Next() {
255
-		writer := e.Value.(io.WriteCloser)
248
+func (w *writeBroadcaster) CloseWriters() error {
249
+	w.mu.Lock()
250
+	defer w.mu.Unlock()
251
+	for writer := range w.writers {
256 252
 		writer.Close()
257 253
 	}
258
-	w.writers.Init()
254
+	w.writers = make(map[io.WriteCloser]struct{})
259 255
 	return nil
260 256
 }
261 257
 
262 258
 func newWriteBroadcaster() *writeBroadcaster {
263
-	return &writeBroadcaster{list.New()}
259
+	return &writeBroadcaster{writers: make(map[io.WriteCloser]struct{})}
264 260
 }
265 261
 
266 262
 func getTotalUsedFds() int {
... ...
@@ -122,7 +122,29 @@ func TestWriteBroadcaster(t *testing.T) {
122 122
 		t.Errorf("Buffer contains %v", bufferC.String())
123 123
 	}
124 124
 
125
-	writer.Close()
125
+	writer.CloseWriters()
126
+}
127
+
128
+type devNullCloser int
129
+
130
+func (d devNullCloser) Close() error {
131
+	return nil
132
+}
133
+
134
+func (d devNullCloser) Write(buf []byte) (int, error) {
135
+	return len(buf), nil
136
+}
137
+
138
+// This test checks for races. It is only useful when run with the race detector.
139
+func TestRaceWriteBroadcaster(t *testing.T) {
140
+	writer := newWriteBroadcaster()
141
+	c := make(chan bool)
142
+	go func() {
143
+		writer.AddWriter(devNullCloser(0))
144
+		c <- true
145
+	}()
146
+	writer.Write([]byte("hello"))
147
+	<-c
126 148
 }
127 149
 
128 150
 // Test the behavior of TruncIndex, an index for querying IDs from a non-conflicting prefix.