Browse code

integration: Extract stream demultiplexing from container.Exec

The original code in container.Exec was potentially leaking the copy
goroutine when the context was cancelled or timed out. The new
`demultiplexStreams()` function won't return until the goroutine has
finished its work, and to ensure that it takes care of closing the
hijacked connection.

Signed-off-by: Albin Kerouanton <albinker@gmail.com>

Albin Kerouanton authored on 2023/07/06 17:39:13
Showing 2 changed files
... ...
@@ -1,14 +1,17 @@
1 1
 package container
2 2
 
3 3
 import (
4
+	"bytes"
4 5
 	"context"
5 6
 	"runtime"
7
+	"sync"
6 8
 	"testing"
7 9
 
8 10
 	"github.com/docker/docker/api/types"
9 11
 	"github.com/docker/docker/api/types/container"
10 12
 	"github.com/docker/docker/api/types/network"
11 13
 	"github.com/docker/docker/client"
14
+	"github.com/docker/docker/pkg/stdcopy"
12 15
 	ocispec "github.com/opencontainers/image-spec/specs-go/v1"
13 16
 	"gotest.tools/v3/assert"
14 17
 )
... ...
@@ -71,3 +74,36 @@ func Run(ctx context.Context, t *testing.T, client client.APIClient, ops ...func
71 71
 
72 72
 	return id
73 73
 }
74
+
75
+type streams struct {
76
+	stdout, stderr bytes.Buffer
77
+}
78
+
79
+// demultiplexStreams starts a goroutine to demultiplex stdout and stderr from the types.HijackedResponse resp and
80
+// waits until either multiplexed stream reaches EOF or the context expires. It unconditionally closes resp and waits
81
+// until the demultiplexing goroutine has finished its work before returning.
82
+func demultiplexStreams(ctx context.Context, resp types.HijackedResponse) (streams, error) {
83
+	var s streams
84
+	outputDone := make(chan error, 1)
85
+
86
+	var wg sync.WaitGroup
87
+	wg.Add(1)
88
+	go func() {
89
+		_, err := stdcopy.StdCopy(&s.stdout, &s.stderr, resp.Reader)
90
+		outputDone <- err
91
+		wg.Done()
92
+	}()
93
+
94
+	var err error
95
+	select {
96
+	case copyErr := <-outputDone:
97
+		err = copyErr
98
+		break
99
+	case <-ctx.Done():
100
+		err = ctx.Err()
101
+	}
102
+
103
+	resp.Close()
104
+	wg.Wait()
105
+	return s, err
106
+}
... ...
@@ -6,7 +6,6 @@ import (
6 6
 
7 7
 	"github.com/docker/docker/api/types"
8 8
 	"github.com/docker/docker/client"
9
-	"github.com/docker/docker/pkg/stdcopy"
10 9
 )
11 10
 
12 11
 // ExecResult represents a result returned from Exec()
... ...
@@ -58,27 +57,11 @@ func Exec(ctx context.Context, cli client.APIClient, id string, cmd []string, op
58 58
 	if err != nil {
59 59
 		return ExecResult{}, err
60 60
 	}
61
-	defer aresp.Close()
62 61
 
63 62
 	// read the output
64
-	var outBuf, errBuf bytes.Buffer
65
-	outputDone := make(chan error, 1)
66
-
67
-	go func() {
68
-		// StdCopy demultiplexes the stream into two buffers
69
-		_, err = stdcopy.StdCopy(&outBuf, &errBuf, aresp.Reader)
70
-		outputDone <- err
71
-	}()
72
-
73
-	select {
74
-	case err := <-outputDone:
75
-		if err != nil {
76
-			return ExecResult{}, err
77
-		}
78
-		break
79
-
80
-	case <-ctx.Done():
81
-		return ExecResult{}, ctx.Err()
63
+	s, err := demultiplexStreams(ctx, aresp)
64
+	if err != nil {
65
+		return ExecResult{}, err
82 66
 	}
83 67
 
84 68
 	// get the exit code
... ...
@@ -87,5 +70,5 @@ func Exec(ctx context.Context, cli client.APIClient, id string, cmd []string, op
87 87
 		return ExecResult{}, err
88 88
 	}
89 89
 
90
-	return ExecResult{ExitCode: iresp.ExitCode, outBuffer: &outBuf, errBuffer: &errBuf}, nil
90
+	return ExecResult{ExitCode: iresp.ExitCode, outBuffer: &s.stdout, errBuffer: &s.stderr}, nil
91 91
 }