Browse code

Fix goroutine/fd leak when client disconnects

In cases where the client disconnects and there is nothing to read from
a stdio stream after that disconnect, the copy goroutines and file
descriptors are leaked because `io.Copy` is just blocked waiting for
data from the container's I/O stream.

This fix only applies to Linux.
Windows will need a separate fix.

Signed-off-by: Brian Goff <cpuguy83@gmail.com>

Brian Goff authored on 2023/02/22 07:32:15
Showing 8 changed files
... ...
@@ -775,7 +775,7 @@ func (s *containerRouter) postContainersAttach(ctx context.Context, w http.Respo
775 775
 	}
776 776
 
777 777
 	contentType := types.MediaTypeRawStream
778
-	setupStreams := func(multiplexed bool) (io.ReadCloser, io.Writer, io.Writer, error) {
778
+	setupStreams := func(multiplexed bool, cancel func()) (io.ReadCloser, io.Writer, io.Writer, error) {
779 779
 		conn, _, err := hijacker.Hijack()
780 780
 		if err != nil {
781 781
 			return nil, nil, nil, err
... ...
@@ -793,6 +793,8 @@ func (s *containerRouter) postContainersAttach(ctx context.Context, w http.Respo
793 793
 			fmt.Fprintf(conn, "HTTP/1.1 200 OK\r\nContent-Type: application/vnd.docker.raw-stream\r\n\r\n")
794 794
 		}
795 795
 
796
+		go notifyClosed(ctx, conn, cancel)
797
+
796 798
 		closer := func() error {
797 799
 			httputils.CloseStreams(conn)
798 800
 			return nil
... ...
@@ -841,7 +843,7 @@ func (s *containerRouter) wsContainersAttach(ctx context.Context, w http.Respons
841 841
 
842 842
 	version := httputils.VersionFromContext(ctx)
843 843
 
844
-	setupStreams := func(multiplexed bool) (io.ReadCloser, io.Writer, io.Writer, error) {
844
+	setupStreams := func(multiplexed bool, cancel func()) (io.ReadCloser, io.Writer, io.Writer, error) {
845 845
 		wsChan := make(chan *websocket.Conn)
846 846
 		h := func(conn *websocket.Conn) {
847 847
 			wsChan <- conn
... ...
@@ -860,6 +862,8 @@ func (s *containerRouter) wsContainersAttach(ctx context.Context, w http.Respons
860 860
 		if versions.GreaterThanOrEqualTo(version, "1.28") {
861 861
 			conn.PayloadType = websocket.BinaryFrame
862 862
 		}
863
+
864
+		// TODO: Close notifications
863 865
 		return conn, conn, conn, nil
864 866
 	}
865 867
 
866 868
new file mode 100644
... ...
@@ -0,0 +1,54 @@
0
+package container
1
+
2
+import (
3
+	"context"
4
+	"net"
5
+	"syscall"
6
+
7
+	"github.com/containerd/log"
8
+	"github.com/docker/docker/internal/unix_noeintr"
9
+	"golang.org/x/sys/unix"
10
+)
11
+
12
+func notifyClosed(ctx context.Context, conn net.Conn, notify func()) {
13
+	sc, ok := conn.(syscall.Conn)
14
+	if !ok {
15
+		log.G(ctx).Debug("notifyClosed: conn does not support close notifications")
16
+		return
17
+	}
18
+
19
+	rc, err := sc.SyscallConn()
20
+	if err != nil {
21
+		log.G(ctx).WithError(err).Warn("notifyClosed: failed get raw conn for close notifications")
22
+		return
23
+	}
24
+
25
+	epFd, err := unix_noeintr.EpollCreate()
26
+	if err != nil {
27
+		log.G(ctx).WithError(err).Warn("notifyClosed: failed to create epoll fd")
28
+		return
29
+	}
30
+	defer unix.Close(epFd)
31
+
32
+	err = rc.Control(func(fd uintptr) {
33
+		err := unix_noeintr.EpollCtl(epFd, unix.EPOLL_CTL_ADD, int(fd), &unix.EpollEvent{
34
+			Events: unix.EPOLLHUP,
35
+			Fd:     int32(fd),
36
+		})
37
+		if err != nil {
38
+			log.G(ctx).WithError(err).Warn("notifyClosed: failed to register fd for close notifications")
39
+			return
40
+		}
41
+
42
+		events := make([]unix.EpollEvent, 1)
43
+		if _, err := unix_noeintr.EpollWait(epFd, events, -1); err != nil {
44
+			log.G(ctx).WithError(err).Warn("notifyClosed: failed to wait for close notifications")
45
+			return
46
+		}
47
+		notify()
48
+	})
49
+	if err != nil {
50
+		log.G(ctx).WithError(err).Warn("notifyClosed: failed to register for close notifications")
51
+		return
52
+	}
53
+}
0 54
new file mode 100644
... ...
@@ -0,0 +1,10 @@
0
+//go:build !linux
1
+
2
+package container
3
+
4
+import (
5
+	"context"
6
+	"net"
7
+)
8
+
9
// notifyClosed is the fallback used on platforms other than Linux. It does
// nothing: it returns immediately and never invokes the notify callback.
func notifyClosed(_ context.Context, _ net.Conn, _ func()) {
}
... ...
@@ -30,7 +30,7 @@ type ContainerRmConfig struct {
30 30
 
31 31
 // ContainerAttachConfig holds the streams to use when connecting to a container to view logs.
32 32
 type ContainerAttachConfig struct {
33
-	GetStreams func(multiplexed bool) (io.ReadCloser, io.Writer, io.Writer, error)
33
+	GetStreams func(multiplexed bool, cancel func()) (io.ReadCloser, io.Writer, io.Writer, error)
34 34
 	UseStdin   bool
35 35
 	UseStdout  bool
36 36
 	UseStderr  bool
... ...
@@ -52,10 +52,27 @@ func (daemon *Daemon) ContainerAttach(prefixOrName string, c *backend.ContainerA
52 52
 	ctr.StreamConfig.AttachStreams(&cfg)
53 53
 
54 54
 	multiplexed := !ctr.Config.Tty && c.MuxStreams
55
-	inStream, outStream, errStream, err := c.GetStreams(multiplexed)
55
+
56
+	clientCtx, closeNotify := context.WithCancel(context.Background())
57
+	defer closeNotify()
58
+	go func() {
59
+		<-clientCtx.Done()
60
+		// The client has disconnected
61
+		// In this case we need to close the container's output streams so that the goroutines used to copy
62
+		// to the client streams are unblocked and can exit.
63
+		if cfg.CStdout != nil {
64
+			cfg.CStdout.Close()
65
+		}
66
+		if cfg.CStderr != nil {
67
+			cfg.CStderr.Close()
68
+		}
69
+	}()
70
+
71
+	inStream, outStream, errStream, err := c.GetStreams(multiplexed, closeNotify)
56 72
 	if err != nil {
57 73
 		return err
58 74
 	}
75
+
59 76
 	defer inStream.Close()
60 77
 
61 78
 	if multiplexed {
... ...
@@ -2,13 +2,18 @@ package container // import "github.com/docker/docker/integration/container"
2 2
 
3 3
 import (
4 4
 	"testing"
5
+	"time"
5 6
 
6 7
 	"github.com/docker/docker/api/types"
7 8
 	"github.com/docker/docker/api/types/container"
8 9
 	"github.com/docker/docker/api/types/network"
10
+	systemutil "github.com/docker/docker/integration/internal/system"
9 11
 	"github.com/docker/docker/testutil"
12
+	"github.com/docker/docker/testutil/daemon"
10 13
 	"gotest.tools/v3/assert"
11 14
 	is "gotest.tools/v3/assert/cmp"
15
+	"gotest.tools/v3/poll"
16
+	"gotest.tools/v3/skip"
12 17
 )
13 18
 
14 19
 func TestAttach(t *testing.T) {
... ...
@@ -59,3 +64,58 @@ func TestAttach(t *testing.T) {
59 59
 		})
60 60
 	}
61 61
 }
62
+
63
+// Regression test for #37182
64
+func TestAttachDisconnectLeak(t *testing.T) {
65
+	skip.If(t, testEnv.DaemonInfo.OSType != "linux", "Bug still exists on Windows")
66
+	t.Parallel()
67
+
68
+	ctx := testutil.StartSpan(baseContext, t)
69
+
70
+	// Use a new daemon to make sure stuff from other tests isn't affecting the
71
+	// goroutine count.
72
+	d := daemon.New(t)
73
+	defer d.Cleanup(t)
74
+
75
+	d.StartWithBusybox(ctx, t, "--iptables=false")
76
+
77
+	client := d.NewClientT(t)
78
+
79
+	resp, err := client.ContainerCreate(ctx,
80
+		&container.Config{
81
+			Image: "busybox",
82
+			Cmd:   []string{"/bin/sh", "-c", "while true; usleep 100000; done"},
83
+		},
84
+		&container.HostConfig{},
85
+		&network.NetworkingConfig{},
86
+		nil,
87
+		"",
88
+	)
89
+	assert.NilError(t, err)
90
+	cID := resp.ID
91
+	defer client.ContainerRemove(ctx, cID, container.RemoveOptions{
92
+		Force: true,
93
+	})
94
+
95
+	nGoroutines := systemutil.WaitForStableGoroutineCount(ctx, t, client)
96
+
97
+	attach, err := client.ContainerAttach(ctx, cID, container.AttachOptions{
98
+		Stdout: true,
99
+	})
100
+	assert.NilError(t, err)
101
+	defer attach.Close()
102
+
103
+	poll.WaitOn(t, func(_ poll.LogT) poll.Result {
104
+		count := systemutil.WaitForStableGoroutineCount(ctx, t, client)
105
+		if count > nGoroutines {
106
+			return poll.Success()
107
+		}
108
+		return poll.Continue("waiting for goroutines to increase from %d, current: %d", nGoroutines, count)
109
+	},
110
+		poll.WithTimeout(time.Minute),
111
+	)
112
+
113
+	attach.Close()
114
+
115
+	poll.WaitOn(t, systemutil.CheckGoroutineCount(ctx, client, nGoroutines), poll.WithTimeout(time.Minute))
116
+}
62 117
new file mode 100644
... ...
@@ -0,0 +1,78 @@
0
+package system
1
+
2
+import (
3
+	"context"
4
+	"time"
5
+
6
+	"github.com/docker/docker/client"
7
+	"gotest.tools/v3/poll"
8
+)
9
+
10
+// WaitForStableGoroutineCount polls the daemon Info API and returns the reported goroutine count
11
+// after multiple calls return the same number.
12
+func WaitForStableGoroutineCount(ctx context.Context, t poll.TestingT, apiClient client.SystemAPIClient, opts ...poll.SettingOp) int {
13
+	var out int
14
+	// Use a longish delay to make sure the goroutine count is actually stable.
15
+	defaults := []poll.SettingOp{poll.WithTimeout(time.Minute), poll.WithDelay(time.Second)}
16
+	opts = append(defaults, opts...)
17
+
18
+	poll.WaitOn(t, StableGoroutineCount(ctx, apiClient, &out), opts...)
19
+	return out
20
+}
21
+
22
+// StableGoroutineCount is a [poll.Check] that polls the daemon info API until the goroutine count is the same for 3 iterations.
23
+func StableGoroutineCount(ctx context.Context, apiClient client.SystemAPIClient, count *int) poll.Check {
24
+	var (
25
+		numStable int
26
+		nRoutines int
27
+	)
28
+
29
+	return func(t poll.LogT) poll.Result {
30
+		n, err := getGoroutineNumber(ctx, apiClient)
31
+		if err != nil {
32
+			return poll.Error(err)
33
+		}
34
+
35
+		last := nRoutines
36
+
37
+		if nRoutines == n {
38
+			numStable++
39
+		} else {
40
+			numStable = 0
41
+			nRoutines = n
42
+		}
43
+
44
+		if numStable > 3 {
45
+			*count = n
46
+			return poll.Success()
47
+		}
48
+		return poll.Continue("goroutine count is not stable: last %d, current %d, stable iters: %d", last, n, numStable)
49
+	}
50
+}
51
+
52
+// CheckGoroutineCount returns a [poll.Check] that polls the daemon info API until the expected number of goroutines is hit.
53
+func CheckGoroutineCount(ctx context.Context, apiClient client.SystemAPIClient, expected int) poll.Check {
54
+	first := true
55
+	return func(t poll.LogT) poll.Result {
56
+		n, err := getGoroutineNumber(ctx, apiClient)
57
+		if err != nil {
58
+			return poll.Error(err)
59
+		}
60
+		if n > expected {
61
+			if first {
62
+				t.Log("Waiting for goroutines to stabilize")
63
+				first = false
64
+			}
65
+			return poll.Continue("exepcted %d goroutines, got %d", expected, n)
66
+		}
67
+		return poll.Success()
68
+	}
69
+}
70
+
71
+func getGoroutineNumber(ctx context.Context, apiClient client.SystemAPIClient) (int, error) {
72
+	info, err := apiClient.Info(ctx)
73
+	if err != nil {
74
+		return 0, err
75
+	}
76
+	return info.NGoroutines, nil
77
+}
0 78
new file mode 100644
... ...
@@ -0,0 +1,37 @@
0
+package unix_noeintr
1
+
2
+import (
3
+	"errors"
4
+
5
+	"golang.org/x/sys/unix"
6
+)
7
+
8
+func EpollCreate() (int, error) {
9
+	for {
10
+		fd, err := unix.EpollCreate1(unix.EPOLL_CLOEXEC)
11
+		if errors.Is(err, unix.EINTR) {
12
+			continue
13
+		}
14
+		return fd, err
15
+	}
16
+}
17
+
18
+func EpollCtl(epFd int, op int, fd int, event *unix.EpollEvent) error {
19
+	for {
20
+		err := unix.EpollCtl(epFd, op, fd, event)
21
+		if errors.Is(err, unix.EINTR) {
22
+			continue
23
+		}
24
+		return err
25
+	}
26
+}
27
+
28
+func EpollWait(epFd int, events []unix.EpollEvent, msec int) (int, error) {
29
+	for {
30
+		n, err := unix.EpollWait(epFd, events, msec)
31
+		if errors.Is(err, unix.EINTR) {
32
+			continue
33
+		}
34
+		return n, err
35
+	}
36
+}